first add files

This commit is contained in:
2023-10-08 20:59:00 +08:00
parent b494be364b
commit 1dac226337
991 changed files with 368151 additions and 40 deletions

View File

@@ -0,0 +1,11 @@
# ext/__init__.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from .. import util as _sa_util
_sa_util.preloaded.import_prefix("sqlalchemy.ext")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,22 @@
# ext/asyncio/__init__.py
# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from .engine import async_engine_from_config
from .engine import AsyncConnection
from .engine import AsyncEngine
from .engine import AsyncTransaction
from .engine import create_async_engine
from .events import AsyncConnectionEvents
from .events import AsyncSessionEvents
from .result import AsyncMappingResult
from .result import AsyncResult
from .result import AsyncScalarResult
from .scoping import async_scoped_session
from .session import async_object_session
from .session import async_session
from .session import AsyncSession
from .session import AsyncSessionTransaction

View File

@@ -0,0 +1,89 @@
import abc
import functools
import weakref
from . import exc as async_exc
class ReversibleProxy:
# weakref.ref(async proxy object) -> weakref.ref(sync proxied object)
_proxy_objects = {}
__slots__ = ("__weakref__",)
def _assign_proxied(self, target):
if target is not None:
target_ref = weakref.ref(target, ReversibleProxy._target_gced)
proxy_ref = weakref.ref(
self,
functools.partial(ReversibleProxy._target_gced, target_ref),
)
ReversibleProxy._proxy_objects[target_ref] = proxy_ref
return target
@classmethod
def _target_gced(cls, ref, proxy_ref=None):
cls._proxy_objects.pop(ref, None)
@classmethod
def _regenerate_proxy_for_target(cls, target):
raise NotImplementedError()
@classmethod
def _retrieve_proxy_for_target(cls, target, regenerate=True):
try:
proxy_ref = cls._proxy_objects[weakref.ref(target)]
except KeyError:
pass
else:
proxy = proxy_ref()
if proxy is not None:
return proxy
if regenerate:
return cls._regenerate_proxy_for_target(target)
else:
return None
class StartableContext(abc.ABC):
__slots__ = ()
@abc.abstractmethod
async def start(self, is_ctxmanager=False):
pass
def __await__(self):
return self.start().__await__()
async def __aenter__(self):
return await self.start(is_ctxmanager=True)
@abc.abstractmethod
async def __aexit__(self, type_, value, traceback):
pass
def _raise_for_not_started(self):
raise async_exc.AsyncContextNotStarted(
"%s context has not been started and object has not been awaited."
% (self.__class__.__name__)
)
class ProxyComparable(ReversibleProxy):
__slots__ = ()
def __hash__(self):
return id(self)
def __eq__(self, other):
return (
isinstance(other, self.__class__)
and self._proxied == other._proxied
)
def __ne__(self, other):
return (
not isinstance(other, self.__class__)
or self._proxied != other._proxied
)

View File

@@ -0,0 +1,828 @@
# ext/asyncio/engine.py
# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
import asyncio
from . import exc as async_exc
from .base import ProxyComparable
from .base import StartableContext
from .result import _ensure_sync_result
from .result import AsyncResult
from ... import exc
from ... import inspection
from ... import util
from ...engine import create_engine as _create_engine
from ...engine.base import NestedTransaction
from ...future import Connection
from ...future import Engine
from ...util.concurrency import greenlet_spawn
def create_async_engine(*arg, **kw):
"""Create a new async engine instance.
Arguments passed to :func:`_asyncio.create_async_engine` are mostly
identical to those passed to the :func:`_sa.create_engine` function.
The specified dialect must be an asyncio-compatible dialect
such as :ref:`dialect-postgresql-asyncpg`.
.. versionadded:: 1.4
"""
if kw.get("server_side_cursors", False):
raise async_exc.AsyncMethodRequired(
"Can't set server_side_cursors for async engine globally; "
"use the connection.stream() method for an async "
"streaming result set"
)
kw["future"] = True
sync_engine = _create_engine(*arg, **kw)
return AsyncEngine(sync_engine)
def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
"""Create a new AsyncEngine instance using a configuration dictionary.
This function is analogous to the :func:`_sa.engine_from_config` function
in SQLAlchemy Core, except that the requested dialect must be an
asyncio-compatible dialect such as :ref:`dialect-postgresql-asyncpg`.
The argument signature of the function is identical to that
of :func:`_sa.engine_from_config`.
.. versionadded:: 1.4.29
"""
options = {
key[len(prefix) :]: value
for key, value in configuration.items()
if key.startswith(prefix)
}
options["_coerce_config"] = True
options.update(kwargs)
url = options.pop("url")
return create_async_engine(url, **options)
class AsyncConnectable:
__slots__ = "_slots_dispatch", "__weakref__"
@util.create_proxy_methods(
Connection,
":class:`_future.Connection`",
":class:`_asyncio.AsyncConnection`",
classmethods=[],
methods=[],
attributes=[
"closed",
"invalidated",
"dialect",
"default_isolation_level",
],
)
class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"""An asyncio proxy for a :class:`_engine.Connection`.
:class:`_asyncio.AsyncConnection` is acquired using the
:meth:`_asyncio.AsyncEngine.connect`
method of :class:`_asyncio.AsyncEngine`::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
async with engine.connect() as conn:
result = await conn.execute(select(table))
.. versionadded:: 1.4
""" # noqa
# AsyncConnection is a thin proxy; no state should be added here
# that is not retrievable from the "sync" engine / connection, e.g.
# current transaction, info, etc. It should be possible to
# create a new AsyncConnection that matches this one given only the
# "sync" elements.
__slots__ = (
"engine",
"sync_engine",
"sync_connection",
)
def __init__(self, async_engine, sync_connection=None):
self.engine = async_engine
self.sync_engine = async_engine.sync_engine
self.sync_connection = self._assign_proxied(sync_connection)
sync_connection: Connection
"""Reference to the sync-style :class:`_engine.Connection` this
:class:`_asyncio.AsyncConnection` proxies requests towards.
This instance can be used as an event target.
.. seealso::
:ref:`asyncio_events`
"""
sync_engine: Engine
"""Reference to the sync-style :class:`_engine.Engine` this
:class:`_asyncio.AsyncConnection` is associated with via its underlying
:class:`_engine.Connection`.
This instance can be used as an event target.
.. seealso::
:ref:`asyncio_events`
"""
@classmethod
def _regenerate_proxy_for_target(cls, target):
return AsyncConnection(
AsyncEngine._retrieve_proxy_for_target(target.engine), target
)
async def start(self, is_ctxmanager=False):
"""Start this :class:`_asyncio.AsyncConnection` object's context
outside of using a Python ``with:`` block.
"""
if self.sync_connection:
raise exc.InvalidRequestError("connection is already started")
self.sync_connection = self._assign_proxied(
await (greenlet_spawn(self.sync_engine.connect))
)
return self
@property
def connection(self):
"""Not implemented for async; call
:meth:`_asyncio.AsyncConnection.get_raw_connection`.
"""
raise exc.InvalidRequestError(
"AsyncConnection.connection accessor is not implemented as the "
"attribute may need to reconnect on an invalidated connection. "
"Use the get_raw_connection() method."
)
async def get_raw_connection(self):
"""Return the pooled DBAPI-level connection in use by this
:class:`_asyncio.AsyncConnection`.
This is a SQLAlchemy connection-pool proxied connection
which then has the attribute
:attr:`_pool._ConnectionFairy.driver_connection` that refers to the
actual driver connection. Its
:attr:`_pool._ConnectionFairy.dbapi_connection` refers instead
to an :class:`_engine.AdaptedConnection` instance that
adapts the driver connection to the DBAPI protocol.
"""
conn = self._sync_connection()
return await greenlet_spawn(getattr, conn, "connection")
@property
def _proxied(self):
return self.sync_connection
@property
def info(self):
"""Return the :attr:`_engine.Connection.info` dictionary of the
underlying :class:`_engine.Connection`.
This dictionary is freely writable for user-defined state to be
associated with the database connection.
This attribute is only available if the :class:`.AsyncConnection` is
currently connected. If the :attr:`.AsyncConnection.closed` attribute
is ``True``, then accessing this attribute will raise
:class:`.ResourceClosedError`.
.. versionadded:: 1.4.0b2
"""
return self.sync_connection.info
def _sync_connection(self):
if not self.sync_connection:
self._raise_for_not_started()
return self.sync_connection
def begin(self):
"""Begin a transaction prior to autobegin occurring."""
self._sync_connection()
return AsyncTransaction(self)
def begin_nested(self):
"""Begin a nested transaction and return a transaction handle."""
self._sync_connection()
return AsyncTransaction(self, nested=True)
async def invalidate(self, exception=None):
"""Invalidate the underlying DBAPI connection associated with
this :class:`_engine.Connection`.
See the method :meth:`_engine.Connection.invalidate` for full
detail on this method.
"""
conn = self._sync_connection()
return await greenlet_spawn(conn.invalidate, exception=exception)
async def get_isolation_level(self):
conn = self._sync_connection()
return await greenlet_spawn(conn.get_isolation_level)
async def set_isolation_level(self):
conn = self._sync_connection()
return await greenlet_spawn(conn.get_isolation_level)
def in_transaction(self):
"""Return True if a transaction is in progress.
.. versionadded:: 1.4.0b2
"""
conn = self._sync_connection()
return conn.in_transaction()
def in_nested_transaction(self):
"""Return True if a transaction is in progress.
.. versionadded:: 1.4.0b2
"""
conn = self._sync_connection()
return conn.in_nested_transaction()
def get_transaction(self):
"""Return an :class:`.AsyncTransaction` representing the current
transaction, if any.
This makes use of the underlying synchronous connection's
:meth:`_engine.Connection.get_transaction` method to get the current
:class:`_engine.Transaction`, which is then proxied in a new
:class:`.AsyncTransaction` object.
.. versionadded:: 1.4.0b2
"""
conn = self._sync_connection()
trans = conn.get_transaction()
if trans is not None:
return AsyncTransaction._retrieve_proxy_for_target(trans)
else:
return None
def get_nested_transaction(self):
"""Return an :class:`.AsyncTransaction` representing the current
nested (savepoint) transaction, if any.
This makes use of the underlying synchronous connection's
:meth:`_engine.Connection.get_nested_transaction` method to get the
current :class:`_engine.Transaction`, which is then proxied in a new
:class:`.AsyncTransaction` object.
.. versionadded:: 1.4.0b2
"""
conn = self._sync_connection()
trans = conn.get_nested_transaction()
if trans is not None:
return AsyncTransaction._retrieve_proxy_for_target(trans)
else:
return None
async def execution_options(self, **opt):
r"""Set non-SQL options for the connection which take effect
during execution.
This returns this :class:`_asyncio.AsyncConnection` object with
the new options added.
See :meth:`_future.Connection.execution_options` for full details
on this method.
"""
conn = self._sync_connection()
c2 = await greenlet_spawn(conn.execution_options, **opt)
assert c2 is conn
return self
async def commit(self):
"""Commit the transaction that is currently in progress.
This method commits the current transaction if one has been started.
If no transaction was started, the method has no effect, assuming
the connection is in a non-invalidated state.
A transaction is begun on a :class:`_future.Connection` automatically
whenever a statement is first executed, or when the
:meth:`_future.Connection.begin` method is called.
"""
conn = self._sync_connection()
await greenlet_spawn(conn.commit)
async def rollback(self):
"""Roll back the transaction that is currently in progress.
This method rolls back the current transaction if one has been started.
If no transaction was started, the method has no effect. If a
transaction was started and the connection is in an invalidated state,
the transaction is cleared using this method.
A transaction is begun on a :class:`_future.Connection` automatically
whenever a statement is first executed, or when the
:meth:`_future.Connection.begin` method is called.
"""
conn = self._sync_connection()
await greenlet_spawn(conn.rollback)
async def close(self):
"""Close this :class:`_asyncio.AsyncConnection`.
This has the effect of also rolling back the transaction if one
is in place.
"""
conn = self._sync_connection()
await greenlet_spawn(conn.close)
async def exec_driver_sql(
self,
statement,
parameters=None,
execution_options=util.EMPTY_DICT,
):
r"""Executes a driver-level SQL string and return buffered
:class:`_engine.Result`.
"""
conn = self._sync_connection()
result = await greenlet_spawn(
conn.exec_driver_sql,
statement,
parameters,
execution_options,
_require_await=True,
)
return await _ensure_sync_result(result, self.exec_driver_sql)
async def stream(
self,
statement,
parameters=None,
execution_options=util.EMPTY_DICT,
):
"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object."""
conn = self._sync_connection()
result = await greenlet_spawn(
conn._execute_20,
statement,
parameters,
util.EMPTY_DICT.merge_with(
execution_options, {"stream_results": True}
),
_require_await=True,
)
if not result.context._is_server_side:
# TODO: real exception here
assert False, "server side result expected"
return AsyncResult(result)
async def execute(
self,
statement,
parameters=None,
execution_options=util.EMPTY_DICT,
):
r"""Executes a SQL statement construct and return a buffered
:class:`_engine.Result`.
:param object: The statement to be executed. This is always
an object that is in both the :class:`_expression.ClauseElement` and
:class:`_expression.Executable` hierarchies, including:
* :class:`_expression.Select`
* :class:`_expression.Insert`, :class:`_expression.Update`,
:class:`_expression.Delete`
* :class:`_expression.TextClause` and
:class:`_expression.TextualSelect`
* :class:`_schema.DDL` and objects which inherit from
:class:`_schema.DDLElement`
:param parameters: parameters which will be bound into the statement.
This may be either a dictionary of parameter names to values,
or a mutable sequence (e.g. a list) of dictionaries. When a
list of dictionaries is passed, the underlying statement execution
will make use of the DBAPI ``cursor.executemany()`` method.
When a single dictionary is passed, the DBAPI ``cursor.execute()``
method will be used.
:param execution_options: optional dictionary of execution options,
which will be associated with the statement execution. This
dictionary can provide a subset of the options that are accepted
by :meth:`_future.Connection.execution_options`.
:return: a :class:`_engine.Result` object.
"""
conn = self._sync_connection()
result = await greenlet_spawn(
conn._execute_20,
statement,
parameters,
execution_options,
_require_await=True,
)
return await _ensure_sync_result(result, self.execute)
async def scalar(
self,
statement,
parameters=None,
execution_options=util.EMPTY_DICT,
):
r"""Executes a SQL statement construct and returns a scalar object.
This method is shorthand for invoking the
:meth:`_engine.Result.scalar` method after invoking the
:meth:`_future.Connection.execute` method. Parameters are equivalent.
:return: a scalar Python value representing the first column of the
first row returned.
"""
result = await self.execute(statement, parameters, execution_options)
return result.scalar()
async def scalars(
self,
statement,
parameters=None,
execution_options=util.EMPTY_DICT,
):
r"""Executes a SQL statement construct and returns a scalar objects.
This method is shorthand for invoking the
:meth:`_engine.Result.scalars` method after invoking the
:meth:`_future.Connection.execute` method. Parameters are equivalent.
:return: a :class:`_engine.ScalarResult` object.
.. versionadded:: 1.4.24
"""
result = await self.execute(statement, parameters, execution_options)
return result.scalars()
async def stream_scalars(
self,
statement,
parameters=None,
execution_options=util.EMPTY_DICT,
):
r"""Executes a SQL statement and returns a streaming scalar result
object.
This method is shorthand for invoking the
:meth:`_engine.AsyncResult.scalars` method after invoking the
:meth:`_future.Connection.stream` method. Parameters are equivalent.
:return: an :class:`_asyncio.AsyncScalarResult` object.
.. versionadded:: 1.4.24
"""
result = await self.stream(statement, parameters, execution_options)
return result.scalars()
async def run_sync(self, fn, *arg, **kw):
"""Invoke the given sync callable passing self as the first argument.
This method maintains the asyncio event loop all the way through
to the database connection by running the given callable in a
specially instrumented greenlet.
E.g.::
with async_engine.begin() as conn:
await conn.run_sync(metadata.create_all)
.. note::
The provided callable is invoked inline within the asyncio event
loop, and will block on traditional IO calls. IO within this
callable should only call into SQLAlchemy's asyncio database
APIs which will be properly adapted to the greenlet context.
.. seealso::
:ref:`session_run_sync`
"""
conn = self._sync_connection()
return await greenlet_spawn(fn, conn, *arg, **kw)
def __await__(self):
return self.start().__await__()
async def __aexit__(self, type_, value, traceback):
await asyncio.shield(self.close())
@util.create_proxy_methods(
Engine,
":class:`_future.Engine`",
":class:`_asyncio.AsyncEngine`",
classmethods=[],
methods=[
"clear_compiled_cache",
"update_execution_options",
"get_execution_options",
],
attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
)
class AsyncEngine(ProxyComparable, AsyncConnectable):
"""An asyncio proxy for a :class:`_engine.Engine`.
:class:`_asyncio.AsyncEngine` is acquired using the
:func:`_asyncio.create_async_engine` function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
.. versionadded:: 1.4
""" # noqa
# AsyncEngine is a thin proxy; no state should be added here
# that is not retrievable from the "sync" engine / connection, e.g.
# current transaction, info, etc. It should be possible to
# create a new AsyncEngine that matches this one given only the
# "sync" elements.
__slots__ = ("sync_engine", "_proxied")
_connection_cls = AsyncConnection
_option_cls: type
class _trans_ctx(StartableContext):
def __init__(self, conn):
self.conn = conn
async def start(self, is_ctxmanager=False):
await self.conn.start(is_ctxmanager=is_ctxmanager)
self.transaction = self.conn.begin()
await self.transaction.__aenter__()
return self.conn
async def __aexit__(self, type_, value, traceback):
async def go():
await self.transaction.__aexit__(type_, value, traceback)
await self.conn.close()
await asyncio.shield(go())
def __init__(self, sync_engine):
if not sync_engine.dialect.is_async:
raise exc.InvalidRequestError(
"The asyncio extension requires an async driver to be used. "
f"The loaded {sync_engine.dialect.driver!r} is not async."
)
self.sync_engine = self._proxied = self._assign_proxied(sync_engine)
sync_engine: Engine
"""Reference to the sync-style :class:`_engine.Engine` this
:class:`_asyncio.AsyncEngine` proxies requests towards.
This instance can be used as an event target.
.. seealso::
:ref:`asyncio_events`
"""
@classmethod
def _regenerate_proxy_for_target(cls, target):
return AsyncEngine(target)
def begin(self):
"""Return a context manager which when entered will deliver an
:class:`_asyncio.AsyncConnection` with an
:class:`_asyncio.AsyncTransaction` established.
E.g.::
async with async_engine.begin() as conn:
await conn.execute(
text("insert into table (x, y, z) values (1, 2, 3)")
)
await conn.execute(text("my_special_procedure(5)"))
"""
conn = self.connect()
return self._trans_ctx(conn)
def connect(self):
"""Return an :class:`_asyncio.AsyncConnection` object.
The :class:`_asyncio.AsyncConnection` will procure a database
connection from the underlying connection pool when it is entered
as an async context manager::
async with async_engine.connect() as conn:
result = await conn.execute(select(user_table))
The :class:`_asyncio.AsyncConnection` may also be started outside of a
context manager by invoking its :meth:`_asyncio.AsyncConnection.start`
method.
"""
return self._connection_cls(self)
async def raw_connection(self):
"""Return a "raw" DBAPI connection from the connection pool.
.. seealso::
:ref:`dbapi_connections`
"""
return await greenlet_spawn(self.sync_engine.raw_connection)
def execution_options(self, **opt):
"""Return a new :class:`_asyncio.AsyncEngine` that will provide
:class:`_asyncio.AsyncConnection` objects with the given execution
options.
Proxied from :meth:`_future.Engine.execution_options`. See that
method for details.
"""
return AsyncEngine(self.sync_engine.execution_options(**opt))
async def dispose(self):
"""Dispose of the connection pool used by this
:class:`_asyncio.AsyncEngine`.
This will close all connection pool connections that are
**currently checked in**. See the documentation for the underlying
:meth:`_future.Engine.dispose` method for further notes.
.. seealso::
:meth:`_future.Engine.dispose`
"""
await greenlet_spawn(self.sync_engine.dispose)
class AsyncTransaction(ProxyComparable, StartableContext):
"""An asyncio proxy for a :class:`_engine.Transaction`."""
__slots__ = ("connection", "sync_transaction", "nested")
def __init__(self, connection, nested=False):
self.connection = connection # AsyncConnection
self.sync_transaction = None # sqlalchemy.engine.Transaction
self.nested = nested
@classmethod
def _regenerate_proxy_for_target(cls, target):
sync_connection = target.connection
sync_transaction = target
nested = isinstance(target, NestedTransaction)
async_connection = AsyncConnection._retrieve_proxy_for_target(
sync_connection
)
assert async_connection is not None
obj = cls.__new__(cls)
obj.connection = async_connection
obj.sync_transaction = obj._assign_proxied(sync_transaction)
obj.nested = nested
return obj
def _sync_transaction(self):
if not self.sync_transaction:
self._raise_for_not_started()
return self.sync_transaction
@property
def _proxied(self):
return self.sync_transaction
@property
def is_valid(self):
return self._sync_transaction().is_valid
@property
def is_active(self):
return self._sync_transaction().is_active
async def close(self):
"""Close this :class:`.Transaction`.
If this transaction is the base transaction in a begin/commit
nesting, the transaction will rollback(). Otherwise, the
method returns.
This is used to cancel a Transaction without affecting the scope of
an enclosing transaction.
"""
await greenlet_spawn(self._sync_transaction().close)
async def rollback(self):
"""Roll back this :class:`.Transaction`."""
await greenlet_spawn(self._sync_transaction().rollback)
async def commit(self):
"""Commit this :class:`.Transaction`."""
await greenlet_spawn(self._sync_transaction().commit)
async def start(self, is_ctxmanager=False):
"""Start this :class:`_asyncio.AsyncTransaction` object's context
outside of using a Python ``with:`` block.
"""
self.sync_transaction = self._assign_proxied(
await greenlet_spawn(
self.connection._sync_connection().begin_nested
if self.nested
else self.connection._sync_connection().begin
)
)
if is_ctxmanager:
self.sync_transaction.__enter__()
return self
async def __aexit__(self, type_, value, traceback):
await greenlet_spawn(
self._sync_transaction().__exit__, type_, value, traceback
)
def _get_sync_engine_or_connection(async_engine):
if isinstance(async_engine, AsyncConnection):
return async_engine.sync_connection
try:
return async_engine.sync_engine
except AttributeError as e:
raise exc.ArgumentError(
"AsyncEngine expected, got %r" % async_engine
) from e
@inspection._inspects(AsyncConnection)
def _no_insp_for_async_conn_yet(subject):
raise exc.NoInspectionAvailable(
"Inspection on an AsyncConnection is currently not supported. "
"Please use ``run_sync`` to pass a callable where it's possible "
"to call ``inspect`` on the passed connection.",
code="xd3s",
)
@inspection._inspects(AsyncEngine)
def _no_insp_for_async_engine_xyet(subject):
raise exc.NoInspectionAvailable(
"Inspection on an AsyncEngine is currently not supported. "
"Please obtain a connection then use ``conn.run_sync`` to pass a "
"callable where it's possible to call ``inspect`` on the "
"passed connection.",
code="xd3s",
)

View File

@@ -0,0 +1,44 @@
# ext/asyncio/events.py
# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from .engine import AsyncConnectable
from .session import AsyncSession
from ...engine import events as engine_event
from ...orm import events as orm_event
class AsyncConnectionEvents(engine_event.ConnectionEvents):
_target_class_doc = "SomeEngine"
_dispatch_target = AsyncConnectable
@classmethod
def _no_async_engine_events(cls):
raise NotImplementedError(
"asynchronous events are not implemented at this time. Apply "
"synchronous listeners to the AsyncEngine.sync_engine or "
"AsyncConnection.sync_connection attributes."
)
@classmethod
def _listen(cls, event_key, retval=False):
cls._no_async_engine_events()
class AsyncSessionEvents(orm_event.SessionEvents):
_target_class_doc = "SomeSession"
_dispatch_target = AsyncSession
@classmethod
def _no_async_engine_events(cls):
raise NotImplementedError(
"asynchronous events are not implemented at this time. Apply "
"synchronous listeners to the AsyncSession.sync_session."
)
@classmethod
def _listen(cls, event_key, retval=False):
cls._no_async_engine_events()

View File

@@ -0,0 +1,21 @@
# ext/asyncio/exc.py
# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from ... import exc
class AsyncMethodRequired(exc.InvalidRequestError):
"""an API can't be used because its result would not be
compatible with async"""
class AsyncContextNotStarted(exc.InvalidRequestError):
"""a startable context manager has not been started."""
class AsyncContextAlreadyStarted(exc.InvalidRequestError):
"""a startable context manager is already started."""

View File

@@ -0,0 +1,671 @@
# ext/asyncio/result.py
# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
import operator
from . import exc as async_exc
from ...engine.result import _NO_ROW
from ...engine.result import FilterResult
from ...engine.result import FrozenResult
from ...engine.result import MergedResult
from ...sql.base import _generative
from ...util.concurrency import greenlet_spawn
class AsyncCommon(FilterResult):
async def close(self):
"""Close this result."""
await greenlet_spawn(self._real_result.close)
class AsyncResult(AsyncCommon):
"""An asyncio wrapper around a :class:`_result.Result` object.
The :class:`_asyncio.AsyncResult` only applies to statement executions that
use a server-side cursor. It is returned only from the
:meth:`_asyncio.AsyncConnection.stream` and
:meth:`_asyncio.AsyncSession.stream` methods.
.. note:: As is the case with :class:`_engine.Result`, this object is
used for ORM results returned by :meth:`_asyncio.AsyncSession.execute`,
which can yield instances of ORM mapped objects either individually or
within tuple-like rows. Note that these result objects do not
deduplicate instances or rows automatically as is the case with the
legacy :class:`_orm.Query` object. For in-Python de-duplication of
instances or rows, use the :meth:`_asyncio.AsyncResult.unique` modifier
method.
.. versionadded:: 1.4
"""
def __init__(self, real_result):
self._real_result = real_result
self._metadata = real_result._metadata
self._unique_filter_state = real_result._unique_filter_state
# BaseCursorResult pre-generates the "_row_getter". Use that
# if available rather than building a second one
if "_row_getter" in real_result.__dict__:
self._set_memoized_attribute(
"_row_getter", real_result.__dict__["_row_getter"]
)
def keys(self):
"""Return the :meth:`_engine.Result.keys` collection from the
underlying :class:`_engine.Result`.
"""
return self._metadata.keys
@_generative
def unique(self, strategy=None):
"""Apply unique filtering to the objects returned by this
:class:`_asyncio.AsyncResult`.
Refer to :meth:`_engine.Result.unique` in the synchronous
SQLAlchemy API for a complete behavioral description.
"""
self._unique_filter_state = (set(), strategy)
def columns(self, *col_expressions):
r"""Establish the columns that should be returned in each row.
Refer to :meth:`_engine.Result.columns` in the synchronous
SQLAlchemy API for a complete behavioral description.
"""
return self._column_slices(col_expressions)
async def partitions(self, size=None):
"""Iterate through sub-lists of rows of the size given.
An async iterator is returned::
async def scroll_results(connection):
result = await connection.stream(select(users_table))
async for partition in result.partitions(100):
print("list of rows: %s" % partition)
.. seealso::
:meth:`_engine.Result.partitions`
"""
getter = self._manyrow_getter
while True:
partition = await greenlet_spawn(getter, self, size)
if partition:
yield partition
else:
break
async def fetchone(self):
"""Fetch one row.
When all rows are exhausted, returns None.
This method is provided for backwards compatibility with
SQLAlchemy 1.x.x.
To fetch the first row of a result only, use the
:meth:`_engine.Result.first` method. To iterate through all
rows, iterate the :class:`_engine.Result` object directly.
:return: a :class:`.Row` object if no filters are applied, or None
if no rows remain.
"""
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
return None
else:
return row
async def fetchmany(self, size=None):
"""Fetch many rows.
When all rows are exhausted, returns an empty list.
This method is provided for backwards compatibility with
SQLAlchemy 1.x.x.
To fetch rows in groups, use the
:meth:`._asyncio.AsyncResult.partitions` method.
:return: a list of :class:`.Row` objects.
.. seealso::
:meth:`_asyncio.AsyncResult.partitions`
"""
return await greenlet_spawn(self._manyrow_getter, self, size)
async def all(self):
"""Return all rows in a list.
Closes the result set after invocation. Subsequent invocations
will return an empty list.
:return: a list of :class:`.Row` objects.
"""
return await greenlet_spawn(self._allrows)
def __aiter__(self):
return self
async def __anext__(self):
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
async def first(self):
"""Fetch the first row or None if no row is present.
Closes the result set and discards remaining rows.
.. note:: This method returns one **row**, e.g. tuple, by default. To
return exactly one single scalar value, that is, the first column of
the first row, use the :meth:`_asyncio.AsyncResult.scalar` method,
or combine :meth:`_asyncio.AsyncResult.scalars` and
:meth:`_asyncio.AsyncResult.first`.
:return: a :class:`.Row` object, or None
if no rows remain.
.. seealso::
:meth:`_asyncio.AsyncResult.scalar`
:meth:`_asyncio.AsyncResult.one`
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
async def one_or_none(self):
"""Return at most one result or raise an exception.
Returns ``None`` if the result has no rows.
Raises :class:`.MultipleResultsFound`
if multiple rows are returned.
.. versionadded:: 1.4
:return: The first :class:`.Row` or None if no row is available.
:raises: :class:`.MultipleResultsFound`
.. seealso::
:meth:`_asyncio.AsyncResult.first`
:meth:`_asyncio.AsyncResult.one`
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
async def scalar_one(self):
"""Return exactly one scalar result or raise an exception.
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
then :meth:`_asyncio.AsyncResult.one`.
.. seealso::
:meth:`_asyncio.AsyncResult.one`
:meth:`_asyncio.AsyncResult.scalars`
"""
return await greenlet_spawn(self._only_one_row, True, True, True)
async def scalar_one_or_none(self):
"""Return exactly one or no scalar result.
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
then :meth:`_asyncio.AsyncResult.one_or_none`.
.. seealso::
:meth:`_asyncio.AsyncResult.one_or_none`
:meth:`_asyncio.AsyncResult.scalars`
"""
return await greenlet_spawn(self._only_one_row, True, False, True)
async def one(self):
"""Return exactly one row or raise an exception.
Raises :class:`.NoResultFound` if the result returns no
rows, or :class:`.MultipleResultsFound` if multiple rows
would be returned.
.. note:: This method returns one **row**, e.g. tuple, by default.
To return exactly one single scalar value, that is, the first
column of the first row, use the
:meth:`_asyncio.AsyncResult.scalar_one` method, or combine
:meth:`_asyncio.AsyncResult.scalars` and
:meth:`_asyncio.AsyncResult.one`.
.. versionadded:: 1.4
:return: The first :class:`.Row`.
:raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound`
.. seealso::
:meth:`_asyncio.AsyncResult.first`
:meth:`_asyncio.AsyncResult.one_or_none`
:meth:`_asyncio.AsyncResult.scalar_one`
"""
return await greenlet_spawn(self._only_one_row, True, True, False)
async def scalar(self):
"""Fetch the first column of the first row, and close the result set.
Returns None if there are no rows to fetch.
No validation is performed to test if additional rows remain.
After calling this method, the object is fully closed,
e.g. the :meth:`_engine.CursorResult.close`
method will have been called.
:return: a Python scalar value , or None if no rows remain.
"""
return await greenlet_spawn(self._only_one_row, False, False, True)
async def freeze(self):
"""Return a callable object that will produce copies of this
:class:`_asyncio.AsyncResult` when invoked.
The callable object returned is an instance of
:class:`_engine.FrozenResult`.
This is used for result set caching. The method must be called
on the result when it has been unconsumed, and calling the method
will consume the result fully. When the :class:`_engine.FrozenResult`
is retrieved from a cache, it can be called any number of times where
it will produce a new :class:`_engine.Result` object each time
against its stored set of rows.
.. seealso::
:ref:`do_orm_execute_re_executing` - example usage within the
ORM to implement a result-set cache.
"""
return await greenlet_spawn(FrozenResult, self)
def merge(self, *others):
"""Merge this :class:`_asyncio.AsyncResult` with other compatible
result objects.
The object returned is an instance of :class:`_engine.MergedResult`,
which will be composed of iterators from the given result
objects.
The new result will use the metadata from this result object.
The subsequent result objects must be against an identical
set of result / cursor metadata, otherwise the behavior is
undefined.
"""
return MergedResult(self._metadata, (self,) + others)
def scalars(self, index=0):
"""Return an :class:`_asyncio.AsyncScalarResult` filtering object which
will return single elements rather than :class:`_row.Row` objects.
Refer to :meth:`_result.Result.scalars` in the synchronous
SQLAlchemy API for a complete behavioral description.
:param index: integer or row key indicating the column to be fetched
from each row, defaults to ``0`` indicating the first column.
:return: a new :class:`_asyncio.AsyncScalarResult` filtering object
referring to this :class:`_asyncio.AsyncResult` object.
"""
return AsyncScalarResult(self._real_result, index)
def mappings(self):
"""Apply a mappings filter to returned rows, returning an instance of
:class:`_asyncio.AsyncMappingResult`.
When this filter is applied, fetching rows will return
:class:`.RowMapping` objects instead of :class:`.Row` objects.
Refer to :meth:`_result.Result.mappings` in the synchronous
SQLAlchemy API for a complete behavioral description.
:return: a new :class:`_asyncio.AsyncMappingResult` filtering object
referring to the underlying :class:`_result.Result` object.
"""
return AsyncMappingResult(self._real_result)
class AsyncScalarResult(AsyncCommon):
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
rather than :class:`_row.Row` values.
The :class:`_asyncio.AsyncScalarResult` object is acquired by calling the
:meth:`_asyncio.AsyncResult.scalars` method.
Refer to the :class:`_result.ScalarResult` object in the synchronous
SQLAlchemy API for a complete behavioral description.
.. versionadded:: 1.4
"""
_generate_rows = False
def __init__(self, real_result, index):
self._real_result = real_result
if real_result._source_supports_scalars:
self._metadata = real_result._metadata
self._post_creational_filter = None
else:
self._metadata = real_result._metadata._reduce([index])
self._post_creational_filter = operator.itemgetter(0)
self._unique_filter_state = real_result._unique_filter_state
def unique(self, strategy=None):
"""Apply unique filtering to the objects returned by this
:class:`_asyncio.AsyncScalarResult`.
See :meth:`_asyncio.AsyncResult.unique` for usage details.
"""
self._unique_filter_state = (set(), strategy)
return self
async def partitions(self, size=None):
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
scalar values, rather than :class:`_result.Row` objects,
are returned.
"""
getter = self._manyrow_getter
while True:
partition = await greenlet_spawn(getter, self, size)
if partition:
yield partition
else:
break
async def fetchall(self):
"""A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
return await greenlet_spawn(self._allrows)
async def fetchmany(self, size=None):
"""Fetch many objects.
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
scalar values, rather than :class:`_result.Row` objects,
are returned.
"""
return await greenlet_spawn(self._manyrow_getter, self, size)
async def all(self):
"""Return all scalar values in a list.
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
scalar values, rather than :class:`_result.Row` objects,
are returned.
"""
return await greenlet_spawn(self._allrows)
def __aiter__(self):
return self
async def __anext__(self):
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
async def first(self):
"""Fetch the first object or None if no object is present.
Equivalent to :meth:`_asyncio.AsyncResult.first` except that
scalar values, rather than :class:`_result.Row` objects,
are returned.
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
async def one_or_none(self):
"""Return at most one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
scalar values, rather than :class:`_result.Row` objects,
are returned.
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
async def one(self):
"""Return exactly one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one` except that
scalar values, rather than :class:`_result.Row` objects,
are returned.
"""
return await greenlet_spawn(self._only_one_row, True, True, False)
class AsyncMappingResult(AsyncCommon):
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary
values rather than :class:`_engine.Row` values.
The :class:`_asyncio.AsyncMappingResult` object is acquired by calling the
:meth:`_asyncio.AsyncResult.mappings` method.
Refer to the :class:`_result.MappingResult` object in the synchronous
SQLAlchemy API for a complete behavioral description.
.. versionadded:: 1.4
"""
_generate_rows = True
_post_creational_filter = operator.attrgetter("_mapping")
def __init__(self, result):
self._real_result = result
self._unique_filter_state = result._unique_filter_state
self._metadata = result._metadata
if result._source_supports_scalars:
self._metadata = self._metadata._reduce([0])
def keys(self):
"""Return an iterable view which yields the string keys that would
be represented by each :class:`.Row`.
The view also can be tested for key containment using the Python
``in`` operator, which will test both for the string keys represented
in the view, as well as for alternate keys such as column objects.
.. versionchanged:: 1.4 a key view object is returned rather than a
plain list.
"""
return self._metadata.keys
def unique(self, strategy=None):
"""Apply unique filtering to the objects returned by this
:class:`_asyncio.AsyncMappingResult`.
See :meth:`_asyncio.AsyncResult.unique` for usage details.
"""
self._unique_filter_state = (set(), strategy)
return self
def columns(self, *col_expressions):
r"""Establish the columns that should be returned in each row."""
return self._column_slices(col_expressions)
async def partitions(self, size=None):
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
mapping values, rather than :class:`_result.Row` objects,
are returned.
"""
getter = self._manyrow_getter
while True:
partition = await greenlet_spawn(getter, self, size)
if partition:
yield partition
else:
break
async def fetchall(self):
"""A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
return await greenlet_spawn(self._allrows)
async def fetchone(self):
"""Fetch one object.
Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that
mapping values, rather than :class:`_result.Row` objects,
are returned.
"""
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
return None
else:
return row
async def fetchmany(self, size=None):
"""Fetch many objects.
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
mapping values, rather than :class:`_result.Row` objects,
are returned.
"""
return await greenlet_spawn(self._manyrow_getter, self, size)
async def all(self):
"""Return all scalar values in a list.
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
mapping values, rather than :class:`_result.Row` objects,
are returned.
"""
return await greenlet_spawn(self._allrows)
def __aiter__(self):
return self
async def __anext__(self):
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
async def first(self):
"""Fetch the first object or None if no object is present.
Equivalent to :meth:`_asyncio.AsyncResult.first` except that
mapping values, rather than :class:`_result.Row` objects,
are returned.
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
async def one_or_none(self):
"""Return at most one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
mapping values, rather than :class:`_result.Row` objects,
are returned.
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
async def one(self):
"""Return exactly one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one` except that
mapping values, rather than :class:`_result.Row` objects,
are returned.
"""
return await greenlet_spawn(self._only_one_row, True, True, False)
async def _ensure_sync_result(result, calling_method):
if not result._is_cursor:
cursor_result = getattr(result, "raw", None)
else:
cursor_result = result
if cursor_result and cursor_result.context._is_server_side:
await greenlet_spawn(cursor_result.close)
raise async_exc.AsyncMethodRequired(
"Can't use the %s.%s() method with a "
"server-side cursor. "
"Use the %s.stream() method for an async "
"streaming result set."
% (
calling_method.__self__.__class__.__name__,
calling_method.__name__,
calling_method.__self__.__class__.__name__,
)
)
return result

View File

@@ -0,0 +1,107 @@
# ext/asyncio/scoping.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from .session import AsyncSession
from ...orm.scoping import ScopedSessionMixin
from ...util import create_proxy_methods
from ...util import ScopedRegistry
@create_proxy_methods(
AsyncSession,
":class:`_asyncio.AsyncSession`",
":class:`_asyncio.scoping.async_scoped_session`",
classmethods=["close_all", "object_session", "identity_key"],
methods=[
"__contains__",
"__iter__",
"add",
"add_all",
"begin",
"begin_nested",
"close",
"commit",
"connection",
"delete",
"execute",
"expire",
"expire_all",
"expunge",
"expunge_all",
"flush",
"get",
"get_bind",
"is_modified",
"invalidate",
"merge",
"refresh",
"rollback",
"scalar",
"scalars",
"stream",
"stream_scalars",
],
attributes=[
"bind",
"dirty",
"deleted",
"new",
"identity_map",
"is_active",
"autoflush",
"no_autoflush",
"info",
],
)
class async_scoped_session(ScopedSessionMixin):
"""Provides scoped management of :class:`.AsyncSession` objects.
See the section :ref:`asyncio_scoped_session` for usage details.
.. versionadded:: 1.4.19
"""
_support_async = True
def __init__(self, session_factory, scopefunc):
"""Construct a new :class:`_asyncio.async_scoped_session`.
:param session_factory: a factory to create new :class:`_asyncio.AsyncSession`
instances. This is usually, but not necessarily, an instance
of :class:`_orm.sessionmaker` which itself was passed the
:class:`_asyncio.AsyncSession` to its :paramref:`_orm.sessionmaker.class_`
parameter::
async_session_factory = sessionmaker(some_async_engine, class_= AsyncSession)
AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task)
:param scopefunc: function which defines
the current scope. A function such as ``asyncio.current_task``
may be useful here.
""" # noqa: E501
self.session_factory = session_factory
self.registry = ScopedRegistry(session_factory, scopefunc)
@property
def _proxied(self):
return self.registry()
async def remove(self):
"""Dispose of the current :class:`.AsyncSession`, if present.
Different from scoped_session's remove method, this method would use
await to wait for the close method of AsyncSession.
"""
if self.registry.has():
await self.registry().close()
self.registry.clear()

View File

@@ -0,0 +1,759 @@
# ext/asyncio/session.py
# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
import asyncio
from . import engine
from . import result as _result
from .base import ReversibleProxy
from .base import StartableContext
from .result import _ensure_sync_result
from ... import util
from ...orm import object_session
from ...orm import Session
from ...orm import state as _instance_state
from ...util.concurrency import greenlet_spawn
_EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True})
_STREAM_OPTIONS = util.immutabledict({"stream_results": True})
@util.create_proxy_methods(
Session,
":class:`_orm.Session`",
":class:`_asyncio.AsyncSession`",
classmethods=["object_session", "identity_key"],
methods=[
"__contains__",
"__iter__",
"add",
"add_all",
"expire",
"expire_all",
"expunge",
"expunge_all",
"is_modified",
"in_transaction",
"in_nested_transaction",
],
attributes=[
"dirty",
"deleted",
"new",
"identity_map",
"is_active",
"autoflush",
"no_autoflush",
"info",
],
)
class AsyncSession(ReversibleProxy):
"""Asyncio version of :class:`_orm.Session`.
The :class:`_asyncio.AsyncSession` is a proxy for a traditional
:class:`_orm.Session` instance.
.. versionadded:: 1.4
To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session`
implementations, see the
:paramref:`_asyncio.AsyncSession.sync_session_class` parameter.
"""
_is_asyncio = True
dispatch = None
def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
r"""Construct a new :class:`_asyncio.AsyncSession`.
All parameters other than ``sync_session_class`` are passed to the
``sync_session_class`` callable directly to instantiate a new
:class:`_orm.Session`. Refer to :meth:`_orm.Session.__init__` for
parameter documentation.
:param sync_session_class:
A :class:`_orm.Session` subclass or other callable which will be used
to construct the :class:`_orm.Session` which will be proxied. This
parameter may be used to provide custom :class:`_orm.Session`
subclasses. Defaults to the
:attr:`_asyncio.AsyncSession.sync_session_class` class-level
attribute.
.. versionadded:: 1.4.24
"""
kw["future"] = True
if bind:
self.bind = bind
bind = engine._get_sync_engine_or_connection(bind)
if binds:
self.binds = binds
binds = {
key: engine._get_sync_engine_or_connection(b)
for key, b in binds.items()
}
if sync_session_class:
self.sync_session_class = sync_session_class
self.sync_session = self._proxied = self._assign_proxied(
self.sync_session_class(bind=bind, binds=binds, **kw)
)
sync_session_class = Session
"""The class or callable that provides the
underlying :class:`_orm.Session` instance for a particular
:class:`_asyncio.AsyncSession`.
At the class level, this attribute is the default value for the
:paramref:`_asyncio.AsyncSession.sync_session_class` parameter. Custom
subclasses of :class:`_asyncio.AsyncSession` can override this.
At the instance level, this attribute indicates the current class or
callable that was used to provide the :class:`_orm.Session` instance for
this :class:`_asyncio.AsyncSession` instance.
.. versionadded:: 1.4.24
"""
sync_session: Session
"""Reference to the underlying :class:`_orm.Session` this
:class:`_asyncio.AsyncSession` proxies requests towards.
This instance can be used as an event target.
.. seealso::
:ref:`asyncio_events`
"""
async def refresh(
self, instance, attribute_names=None, with_for_update=None
):
"""Expire and refresh the attributes on the given instance.
A query will be issued to the database and all attributes will be
refreshed with their current database value.
This is the async version of the :meth:`_orm.Session.refresh` method.
See that method for a complete description of all options.
.. seealso::
:meth:`_orm.Session.refresh` - main documentation for refresh
"""
return await greenlet_spawn(
self.sync_session.refresh,
instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
)
async def run_sync(self, fn, *arg, **kw):
"""Invoke the given sync callable passing sync self as the first
argument.
This method maintains the asyncio event loop all the way through
to the database connection by running the given callable in a
specially instrumented greenlet.
E.g.::
with AsyncSession(async_engine) as session:
await session.run_sync(some_business_method)
.. note::
The provided callable is invoked inline within the asyncio event
loop, and will block on traditional IO calls. IO within this
callable should only call into SQLAlchemy's asyncio database
APIs which will be properly adapted to the greenlet context.
.. seealso::
:ref:`session_run_sync`
"""
return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
async def execute(
self,
statement,
params=None,
execution_options=util.EMPTY_DICT,
bind_arguments=None,
**kw
):
"""Execute a statement and return a buffered
:class:`_engine.Result` object.
.. seealso::
:meth:`_orm.Session.execute` - main documentation for execute
"""
if execution_options:
execution_options = util.immutabledict(execution_options).union(
_EXECUTE_OPTIONS
)
else:
execution_options = _EXECUTE_OPTIONS
result = await greenlet_spawn(
self.sync_session.execute,
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw
)
return await _ensure_sync_result(result, self.execute)
async def scalar(
self,
statement,
params=None,
execution_options=util.EMPTY_DICT,
bind_arguments=None,
**kw
):
"""Execute a statement and return a scalar result.
.. seealso::
:meth:`_orm.Session.scalar` - main documentation for scalar
"""
result = await self.execute(
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw
)
return result.scalar()
async def scalars(
self,
statement,
params=None,
execution_options=util.EMPTY_DICT,
bind_arguments=None,
**kw
):
"""Execute a statement and return scalar results.
:return: a :class:`_result.ScalarResult` object
.. versionadded:: 1.4.24
.. seealso::
:meth:`_orm.Session.scalars` - main documentation for scalars
:meth:`_asyncio.AsyncSession.stream_scalars` - streaming version
"""
result = await self.execute(
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw
)
return result.scalars()
async def get(
self,
entity,
ident,
options=None,
populate_existing=False,
with_for_update=None,
identity_token=None,
):
"""Return an instance based on the given primary key identifier,
or ``None`` if not found.
.. seealso::
:meth:`_orm.Session.get` - main documentation for get
"""
return await greenlet_spawn(
self.sync_session.get,
entity,
ident,
options=options,
populate_existing=populate_existing,
with_for_update=with_for_update,
identity_token=identity_token,
)
async def stream(
self,
statement,
params=None,
execution_options=util.EMPTY_DICT,
bind_arguments=None,
**kw
):
"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object.
"""
if execution_options:
execution_options = util.immutabledict(execution_options).union(
_STREAM_OPTIONS
)
else:
execution_options = _STREAM_OPTIONS
result = await greenlet_spawn(
self.sync_session.execute,
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw
)
return _result.AsyncResult(result)
async def stream_scalars(
self,
statement,
params=None,
execution_options=util.EMPTY_DICT,
bind_arguments=None,
**kw
):
"""Execute a statement and return a stream of scalar results.
:return: an :class:`_asyncio.AsyncScalarResult` object
.. versionadded:: 1.4.24
.. seealso::
:meth:`_orm.Session.scalars` - main documentation for scalars
:meth:`_asyncio.AsyncSession.scalars` - non streaming version
"""
result = await self.stream(
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw
)
return result.scalars()
async def delete(self, instance):
"""Mark an instance as deleted.
The database delete operation occurs upon ``flush()``.
As this operation may need to cascade along unloaded relationships,
it is awaitable to allow for those queries to take place.
.. seealso::
:meth:`_orm.Session.delete` - main documentation for delete
"""
return await greenlet_spawn(self.sync_session.delete, instance)
async def merge(self, instance, load=True, options=None):
"""Copy the state of a given instance into a corresponding instance
within this :class:`_asyncio.AsyncSession`.
.. seealso::
:meth:`_orm.Session.merge` - main documentation for merge
"""
return await greenlet_spawn(
self.sync_session.merge, instance, load=load, options=options
)
async def flush(self, objects=None):
"""Flush all the object changes to the database.
.. seealso::
:meth:`_orm.Session.flush` - main documentation for flush
"""
await greenlet_spawn(self.sync_session.flush, objects=objects)
def get_transaction(self):
"""Return the current root transaction in progress, if any.
:return: an :class:`_asyncio.AsyncSessionTransaction` object, or
``None``.
.. versionadded:: 1.4.18
"""
trans = self.sync_session.get_transaction()
if trans is not None:
return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
else:
return None
def get_nested_transaction(self):
"""Return the current nested transaction in progress, if any.
:return: an :class:`_asyncio.AsyncSessionTransaction` object, or
``None``.
.. versionadded:: 1.4.18
"""
trans = self.sync_session.get_nested_transaction()
if trans is not None:
return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
else:
return None
def get_bind(self, mapper=None, clause=None, bind=None, **kw):
"""Return a "bind" to which the synchronous proxied :class:`_orm.Session`
is bound.
Unlike the :meth:`_orm.Session.get_bind` method, this method is
currently **not** used by this :class:`.AsyncSession` in any way
in order to resolve engines for requests.
.. note::
This method proxies directly to the :meth:`_orm.Session.get_bind`
method, however is currently **not** useful as an override target,
in contrast to that of the :meth:`_orm.Session.get_bind` method.
The example below illustrates how to implement custom
:meth:`_orm.Session.get_bind` schemes that work with
:class:`.AsyncSession` and :class:`.AsyncEngine`.
The pattern introduced at :ref:`session_custom_partitioning`
illustrates how to apply a custom bind-lookup scheme to a
:class:`_orm.Session` given a set of :class:`_engine.Engine` objects.
To apply a corresponding :meth:`_orm.Session.get_bind` implementation
for use with a :class:`.AsyncSession` and :class:`.AsyncEngine`
objects, continue to subclass :class:`_orm.Session` and apply it to
:class:`.AsyncSession` using
:paramref:`.AsyncSession.sync_session_class`. The inner method must
continue to return :class:`_engine.Engine` instances, which can be
acquired from a :class:`_asyncio.AsyncEngine` using the
:attr:`_asyncio.AsyncEngine.sync_engine` attribute::
# using example from "Custom Vertical Partitioning"
import random
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session, sessionmaker
# construct async engines w/ async drivers
engines = {
'leader':create_async_engine("sqlite+aiosqlite:///leader.db"),
'other':create_async_engine("sqlite+aiosqlite:///other.db"),
'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"),
'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"),
}
class RoutingSession(Session):
def get_bind(self, mapper=None, clause=None, **kw):
# within get_bind(), return sync engines
if mapper and issubclass(mapper.class_, MyOtherClass):
return engines['other'].sync_engine
elif self._flushing or isinstance(clause, (Update, Delete)):
return engines['leader'].sync_engine
else:
return engines[
random.choice(['follower1','follower2'])
].sync_engine
# apply to AsyncSession using sync_session_class
AsyncSessionMaker = sessionmaker(
class_=AsyncSession,
sync_session_class=RoutingSession
)
The :meth:`_orm.Session.get_bind` method is called in a non-asyncio,
implicitly non-blocking context in the same manner as ORM event hooks
and functions that are invoked via :meth:`.AsyncSession.run_sync`, so
routines that wish to run SQL commands inside of
:meth:`_orm.Session.get_bind` can continue to do so using
blocking-style code, which will be translated to implicitly async calls
at the point of invoking IO on the database drivers.
""" # noqa: E501
return self.sync_session.get_bind(
mapper=mapper, clause=clause, bind=bind, **kw
)
async def connection(self, **kw):
r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
this :class:`.Session` object's transactional state.
This method may also be used to establish execution options for the
database connection used by the current transaction.
.. versionadded:: 1.4.24 Added \**kw arguments which are passed
through to the underlying :meth:`_orm.Session.connection` method.
.. seealso::
:meth:`_orm.Session.connection` - main documentation for
"connection"
"""
sync_connection = await greenlet_spawn(
self.sync_session.connection, **kw
)
return engine.AsyncConnection._retrieve_proxy_for_target(
sync_connection
)
def begin(self, **kw):
"""Return an :class:`_asyncio.AsyncSessionTransaction` object.
The underlying :class:`_orm.Session` will perform the
"begin" action when the :class:`_asyncio.AsyncSessionTransaction`
object is entered::
async with async_session.begin():
# .. ORM transaction is begun
Note that database IO will not normally occur when the session-level
transaction is begun, as database transactions begin on an
on-demand basis. However, the begin block is async to accommodate
for a :meth:`_orm.SessionEvents.after_transaction_create`
event hook that may perform IO.
For a general description of ORM begin, see
:meth:`_orm.Session.begin`.
"""
return AsyncSessionTransaction(self)
def begin_nested(self, **kw):
"""Return an :class:`_asyncio.AsyncSessionTransaction` object
which will begin a "nested" transaction, e.g. SAVEPOINT.
Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`.
For a general description of ORM begin nested, see
:meth:`_orm.Session.begin_nested`.
"""
return AsyncSessionTransaction(self, nested=True)
async def rollback(self):
"""Rollback the current transaction in progress."""
return await greenlet_spawn(self.sync_session.rollback)
async def commit(self):
"""Commit the current transaction in progress."""
return await greenlet_spawn(self.sync_session.commit)
async def close(self):
"""Close out the transactional resources and ORM objects used by this
:class:`_asyncio.AsyncSession`.
This expunges all ORM objects associated with this
:class:`_asyncio.AsyncSession`, ends any transaction in progress and
:term:`releases` any :class:`_asyncio.AsyncConnection` objects which
this :class:`_asyncio.AsyncSession` itself has checked out from
associated :class:`_asyncio.AsyncEngine` objects. The operation then
leaves the :class:`_asyncio.AsyncSession` in a state which it may be
used again.
.. tip::
The :meth:`_asyncio.AsyncSession.close` method **does not prevent
the Session from being used again**. The
:class:`_asyncio.AsyncSession` itself does not actually have a
distinct "closed" state; it merely means the
:class:`_asyncio.AsyncSession` will release all database
connections and ORM objects.
.. seealso::
:ref:`session_closing` - detail on the semantics of
:meth:`_asyncio.AsyncSession.close`
"""
await greenlet_spawn(self.sync_session.close)
async def invalidate(self):
"""Close this Session, using connection invalidation.
For a complete description, see :meth:`_orm.Session.invalidate`.
"""
return await greenlet_spawn(self.sync_session.invalidate)
@classmethod
async def close_all(self):
"""Close all :class:`_asyncio.AsyncSession` sessions."""
return await greenlet_spawn(self.sync_session.close_all)
async def __aenter__(self):
return self
async def __aexit__(self, type_, value, traceback):
await asyncio.shield(self.close())
def _maker_context_manager(self):
# no @contextlib.asynccontextmanager until python3.7, gr
return _AsyncSessionContextManager(self)
class _AsyncSessionContextManager:
def __init__(self, async_session):
self.async_session = async_session
async def __aenter__(self):
self.trans = self.async_session.begin()
await self.trans.__aenter__()
return self.async_session
async def __aexit__(self, type_, value, traceback):
async def go():
await self.trans.__aexit__(type_, value, traceback)
await self.async_session.__aexit__(type_, value, traceback)
await asyncio.shield(go())
class AsyncSessionTransaction(ReversibleProxy, StartableContext):
"""A wrapper for the ORM :class:`_orm.SessionTransaction` object.
This object is provided so that a transaction-holding object
for the :meth:`_asyncio.AsyncSession.begin` may be returned.
The object supports both explicit calls to
:meth:`_asyncio.AsyncSessionTransaction.commit` and
:meth:`_asyncio.AsyncSessionTransaction.rollback`, as well as use as an
async context manager.
.. versionadded:: 1.4
"""
__slots__ = ("session", "sync_transaction", "nested")
def __init__(self, session, nested=False):
self.session = session
self.nested = nested
self.sync_transaction = None
@property
def is_active(self):
return (
self._sync_transaction() is not None
and self._sync_transaction().is_active
)
def _sync_transaction(self):
if not self.sync_transaction:
self._raise_for_not_started()
return self.sync_transaction
async def rollback(self):
"""Roll back this :class:`_asyncio.AsyncTransaction`."""
await greenlet_spawn(self._sync_transaction().rollback)
async def commit(self):
"""Commit this :class:`_asyncio.AsyncTransaction`."""
await greenlet_spawn(self._sync_transaction().commit)
async def start(self, is_ctxmanager=False):
self.sync_transaction = self._assign_proxied(
await greenlet_spawn(
self.session.sync_session.begin_nested
if self.nested
else self.session.sync_session.begin
)
)
if is_ctxmanager:
self.sync_transaction.__enter__()
return self
async def __aexit__(self, type_, value, traceback):
await greenlet_spawn(
self._sync_transaction().__exit__, type_, value, traceback
)
def async_object_session(instance):
"""Return the :class:`_asyncio.AsyncSession` to which the given instance
belongs.
This function makes use of the sync-API function
:class:`_orm.object_session` to retrieve the :class:`_orm.Session` which
refers to the given instance, and from there links it to the original
:class:`_asyncio.AsyncSession`.
If the :class:`_asyncio.AsyncSession` has been garbage collected, the
return value is ``None``.
This functionality is also available from the
:attr:`_orm.InstanceState.async_session` accessor.
:param instance: an ORM mapped instance
:return: an :class:`_asyncio.AsyncSession` object, or ``None``.
.. versionadded:: 1.4.18
"""
session = object_session(instance)
if session is not None:
return async_session(session)
else:
return None
def async_session(session):
"""Return the :class:`_asyncio.AsyncSession` which is proxying the given
:class:`_orm.Session` object, if any.
:param session: a :class:`_orm.Session` instance.
:return: a :class:`_asyncio.AsyncSession` instance, or ``None``.
.. versionadded:: 1.4.18
"""
return AsyncSession._retrieve_proxy_for_target(session, regenerate=False)
_instance_state._async_provider = async_session

File diff suppressed because it is too large Load Diff

648
lib/sqlalchemy/ext/baked.py Normal file
View File

@@ -0,0 +1,648 @@
# sqlalchemy/ext/baked.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Baked query extension.
Provides a creational pattern for the :class:`.query.Query` object which
allows the fully constructed object, Core select statement, and string
compiled result to be fully cached.
"""
import logging
from .. import exc as sa_exc
from .. import util
from ..orm import exc as orm_exc
from ..orm import strategy_options
from ..orm.query import Query
from ..orm.session import Session
from ..sql import func
from ..sql import literal_column
from ..sql import util as sql_util
from ..util import collections_abc
log = logging.getLogger(__name__)
class Bakery(object):
"""Callable which returns a :class:`.BakedQuery`.
This object is returned by the class method
:meth:`.BakedQuery.bakery`. It exists as an object
so that the "cache" can be easily inspected.
.. versionadded:: 1.2
"""
__slots__ = "cls", "cache"
def __init__(self, cls_, cache):
self.cls = cls_
self.cache = cache
def __call__(self, initial_fn, *args):
return self.cls(self.cache, initial_fn, args)
class BakedQuery(object):
"""A builder object for :class:`.query.Query` objects."""
__slots__ = "steps", "_bakery", "_cache_key", "_spoiled"
def __init__(self, bakery, initial_fn, args=()):
self._cache_key = ()
self._update_cache_key(initial_fn, args)
self.steps = [initial_fn]
self._spoiled = False
self._bakery = bakery
@classmethod
def bakery(cls, size=200, _size_alert=None):
"""Construct a new bakery.
:return: an instance of :class:`.Bakery`
"""
return Bakery(cls, util.LRUCache(size, size_alert=_size_alert))
def _clone(self):
b1 = BakedQuery.__new__(BakedQuery)
b1._cache_key = self._cache_key
b1.steps = list(self.steps)
b1._bakery = self._bakery
b1._spoiled = self._spoiled
return b1
def _update_cache_key(self, fn, args=()):
self._cache_key += (fn.__code__,) + args
def __iadd__(self, other):
if isinstance(other, tuple):
self.add_criteria(*other)
else:
self.add_criteria(other)
return self
def __add__(self, other):
if isinstance(other, tuple):
return self.with_criteria(*other)
else:
return self.with_criteria(other)
def add_criteria(self, fn, *args):
"""Add a criteria function to this :class:`.BakedQuery`.
This is equivalent to using the ``+=`` operator to
modify a :class:`.BakedQuery` in-place.
"""
self._update_cache_key(fn, args)
self.steps.append(fn)
return self
def with_criteria(self, fn, *args):
"""Add a criteria function to a :class:`.BakedQuery` cloned from this
one.
This is equivalent to using the ``+`` operator to
produce a new :class:`.BakedQuery` with modifications.
"""
return self._clone().add_criteria(fn, *args)
def for_session(self, session):
"""Return a :class:`_baked.Result` object for this
:class:`.BakedQuery`.
This is equivalent to calling the :class:`.BakedQuery` as a
Python callable, e.g. ``result = my_baked_query(session)``.
"""
return Result(self, session)
def __call__(self, session):
return self.for_session(session)
def spoil(self, full=False):
"""Cancel any query caching that will occur on this BakedQuery object.
The BakedQuery can continue to be used normally, however additional
creational functions will not be cached; they will be called
on every invocation.
This is to support the case where a particular step in constructing
a baked query disqualifies the query from being cacheable, such
as a variant that relies upon some uncacheable value.
:param full: if False, only functions added to this
:class:`.BakedQuery` object subsequent to the spoil step will be
non-cached; the state of the :class:`.BakedQuery` up until
this point will be pulled from the cache. If True, then the
entire :class:`_query.Query` object is built from scratch each
time, with all creational functions being called on each
invocation.
"""
if not full and not self._spoiled:
_spoil_point = self._clone()
_spoil_point._cache_key += ("_query_only",)
self.steps = [_spoil_point._retrieve_baked_query]
self._spoiled = True
return self
def _effective_key(self, session):
"""Return the key that actually goes into the cache dictionary for
this :class:`.BakedQuery`, taking into account the given
:class:`.Session`.
This basically means we also will include the session's query_class,
as the actual :class:`_query.Query` object is part of what's cached
and needs to match the type of :class:`_query.Query` that a later
session will want to use.
"""
return self._cache_key + (session._query_cls,)
def _with_lazyload_options(self, options, effective_path, cache_path=None):
"""Cloning version of _add_lazyload_options."""
q = self._clone()
q._add_lazyload_options(options, effective_path, cache_path=cache_path)
return q
def _add_lazyload_options(self, options, effective_path, cache_path=None):
"""Used by per-state lazy loaders to add options to the
"lazy load" query from a parent query.
Creates a cache key based on given load path and query options;
if a repeatable cache key cannot be generated, the query is
"spoiled" so that it won't use caching.
"""
key = ()
if not cache_path:
cache_path = effective_path
for opt in options:
if opt._is_legacy_option or opt._is_compile_state:
ck = opt._generate_cache_key()
if ck is None:
self.spoil(full=True)
else:
assert not ck[1], (
"loader options with variable bound parameters "
"not supported with baked queries. Please "
"use new-style select() statements for cached "
"ORM queries."
)
key += ck[0]
self.add_criteria(
lambda q: q._with_current_path(effective_path).options(*options),
cache_path.path,
key,
)
def _retrieve_baked_query(self, session):
query = self._bakery.get(self._effective_key(session), None)
if query is None:
query = self._as_query(session)
self._bakery[self._effective_key(session)] = query.with_session(
None
)
return query.with_session(session)
def _bake(self, session):
query = self._as_query(session)
query.session = None
# in 1.4, this is where before_compile() event is
# invoked
statement = query._statement_20()
# if the query is not safe to cache, we still do everything as though
# we did cache it, since the receiver of _bake() assumes subqueryload
# context was set up, etc.
#
# note also we want to cache the statement itself because this
# allows the statement itself to hold onto its cache key that is
# used by the Connection, which in itself is more expensive to
# generate than what BakedQuery was able to provide in 1.3 and prior
if statement._compile_options._bake_ok:
self._bakery[self._effective_key(session)] = (
query,
statement,
)
return query, statement
def to_query(self, query_or_session):
"""Return the :class:`_query.Query` object for use as a subquery.
This method should be used within the lambda callable being used
to generate a step of an enclosing :class:`.BakedQuery`. The
parameter should normally be the :class:`_query.Query` object that
is passed to the lambda::
sub_bq = self.bakery(lambda s: s.query(User.name))
sub_bq += lambda q: q.filter(
User.id == Address.user_id).correlate(Address)
main_bq = self.bakery(lambda s: s.query(Address))
main_bq += lambda q: q.filter(
sub_bq.to_query(q).exists())
In the case where the subquery is used in the first callable against
a :class:`.Session`, the :class:`.Session` is also accepted::
sub_bq = self.bakery(lambda s: s.query(User.name))
sub_bq += lambda q: q.filter(
User.id == Address.user_id).correlate(Address)
main_bq = self.bakery(
lambda s: s.query(
Address.id, sub_bq.to_query(q).scalar_subquery())
)
:param query_or_session: a :class:`_query.Query` object or a class
:class:`.Session` object, that is assumed to be within the context
of an enclosing :class:`.BakedQuery` callable.
.. versionadded:: 1.3
"""
if isinstance(query_or_session, Session):
session = query_or_session
elif isinstance(query_or_session, Query):
session = query_or_session.session
if session is None:
raise sa_exc.ArgumentError(
"Given Query needs to be associated with a Session"
)
else:
raise TypeError(
"Query or Session object expected, got %r."
% type(query_or_session)
)
return self._as_query(session)
def _as_query(self, session):
query = self.steps[0](session)
for step in self.steps[1:]:
query = step(query)
return query
class Result(object):
"""Invokes a :class:`.BakedQuery` against a :class:`.Session`.
The :class:`_baked.Result` object is where the actual :class:`.query.Query`
object gets created, or retrieved from the cache,
against a target :class:`.Session`, and is then invoked for results.
"""
__slots__ = "bq", "session", "_params", "_post_criteria"
def __init__(self, bq, session):
self.bq = bq
self.session = session
self._params = {}
self._post_criteria = []
def params(self, *args, **kw):
"""Specify parameters to be replaced into the string SQL statement."""
if len(args) == 1:
kw.update(args[0])
elif len(args) > 0:
raise sa_exc.ArgumentError(
"params() takes zero or one positional argument, "
"which is a dictionary."
)
self._params.update(kw)
return self
def _using_post_criteria(self, fns):
if fns:
self._post_criteria.extend(fns)
return self
def with_post_criteria(self, fn):
"""Add a criteria function that will be applied post-cache.
This adds a function that will be run against the
:class:`_query.Query` object after it is retrieved from the
cache. This currently includes **only** the
:meth:`_query.Query.params` and :meth:`_query.Query.execution_options`
methods.
.. warning:: :meth:`_baked.Result.with_post_criteria`
functions are applied
to the :class:`_query.Query`
object **after** the query's SQL statement
object has been retrieved from the cache. Only
:meth:`_query.Query.params` and
:meth:`_query.Query.execution_options`
methods should be used.
.. versionadded:: 1.2
"""
return self._using_post_criteria([fn])
def _as_query(self):
q = self.bq._as_query(self.session).params(self._params)
for fn in self._post_criteria:
q = fn(q)
return q
def __str__(self):
return str(self._as_query())
def __iter__(self):
return self._iter().__iter__()
def _iter(self):
bq = self.bq
if not self.session.enable_baked_queries or bq._spoiled:
return self._as_query()._iter()
query, statement = bq._bakery.get(
bq._effective_key(self.session), (None, None)
)
if query is None:
query, statement = bq._bake(self.session)
if self._params:
q = query.params(self._params)
else:
q = query
for fn in self._post_criteria:
q = fn(q)
params = q._params
execution_options = dict(q._execution_options)
execution_options.update(
{
"_sa_orm_load_options": q.load_options,
"compiled_cache": bq._bakery,
}
)
result = self.session.execute(
statement, params, execution_options=execution_options
)
if result._attributes.get("is_single_entity", False):
result = result.scalars()
if result._attributes.get("filtered", False):
result = result.unique()
return result
def count(self):
"""return the 'count'.
Equivalent to :meth:`_query.Query.count`.
Note this uses a subquery to ensure an accurate count regardless
of the structure of the original statement.
.. versionadded:: 1.1.6
"""
col = func.count(literal_column("*"))
bq = self.bq.with_criteria(lambda q: q._from_self(col))
return bq.for_session(self.session).params(self._params).scalar()
def scalar(self):
"""Return the first element of the first result or None
if no rows present. If multiple rows are returned,
raises MultipleResultsFound.
Equivalent to :meth:`_query.Query.scalar`.
.. versionadded:: 1.1.6
"""
try:
ret = self.one()
if not isinstance(ret, collections_abc.Sequence):
return ret
return ret[0]
except orm_exc.NoResultFound:
return None
def first(self):
"""Return the first row.
Equivalent to :meth:`_query.Query.first`.
"""
bq = self.bq.with_criteria(lambda q: q.slice(0, 1))
return (
bq.for_session(self.session)
.params(self._params)
._using_post_criteria(self._post_criteria)
._iter()
.first()
)
def one(self):
"""Return exactly one result or raise an exception.
Equivalent to :meth:`_query.Query.one`.
"""
return self._iter().one()
def one_or_none(self):
"""Return one or zero results, or raise an exception for multiple
rows.
Equivalent to :meth:`_query.Query.one_or_none`.
.. versionadded:: 1.0.9
"""
return self._iter().one_or_none()
def all(self):
"""Return all rows.
Equivalent to :meth:`_query.Query.all`.
"""
return self._iter().all()
def get(self, ident):
"""Retrieve an object based on identity.
Equivalent to :meth:`_query.Query.get`.
"""
query = self.bq.steps[0](self.session)
return query._get_impl(ident, self._load_on_pk_identity)
def _load_on_pk_identity(self, session, query, primary_key_identity, **kw):
"""Load the given primary key identity from the database."""
mapper = query._raw_columns[0]._annotations["parententity"]
_get_clause, _get_params = mapper._get_clause
def setup(query):
_lcl_get_clause = _get_clause
q = query._clone()
q._get_condition()
q._order_by = None
# None present in ident - turn those comparisons
# into "IS NULL"
if None in primary_key_identity:
nones = set(
[
_get_params[col].key
for col, value in zip(
mapper.primary_key, primary_key_identity
)
if value is None
]
)
_lcl_get_clause = sql_util.adapt_criterion_to_null(
_lcl_get_clause, nones
)
# TODO: can mapper._get_clause be pre-adapted?
q._where_criteria = (
sql_util._deep_annotate(_lcl_get_clause, {"_orm_adapt": True}),
)
for fn in self._post_criteria:
q = fn(q)
return q
# cache the query against a key that includes
# which positions in the primary key are NULL
# (remember, we can map to an OUTER JOIN)
bq = self.bq
# add the clause we got from mapper._get_clause to the cache
# key so that if a race causes multiple calls to _get_clause,
# we've cached on ours
bq = bq._clone()
bq._cache_key += (_get_clause,)
bq = bq.with_criteria(
setup, tuple(elem is None for elem in primary_key_identity)
)
params = dict(
[
(_get_params[primary_key].key, id_val)
for id_val, primary_key in zip(
primary_key_identity, mapper.primary_key
)
]
)
result = list(bq.for_session(self.session).params(**params))
l = len(result)
if l > 1:
raise orm_exc.MultipleResultsFound()
elif l:
return result[0]
else:
return None
@util.deprecated(
"1.2", "Baked lazy loading is now the default implementation."
)
def bake_lazy_loaders():
"""Enable the use of baked queries for all lazyloaders systemwide.
The "baked" implementation of lazy loading is now the sole implementation
for the base lazy loader; this method has no effect except for a warning.
"""
pass
@util.deprecated(
"1.2", "Baked lazy loading is now the default implementation."
)
def unbake_lazy_loaders():
"""Disable the use of baked queries for all lazyloaders systemwide.
This method now raises NotImplementedError() as the "baked" implementation
is the only lazy load implementation. The
:paramref:`_orm.relationship.bake_queries` flag may be used to disable
the caching of queries on a per-relationship basis.
"""
raise NotImplementedError(
"Baked lazy loading is now the default implementation"
)
@strategy_options.loader_option()
def baked_lazyload(loadopt, attr):
"""Indicate that the given attribute should be loaded using "lazy"
loading with a "baked" query used in the load.
"""
return loadopt.set_relationship_strategy(attr, {"lazy": "baked_select"})
@baked_lazyload._add_unbound_fn
@util.deprecated(
"1.2",
"Baked lazy loading is now the default "
"implementation for lazy loading.",
)
def baked_lazyload(*keys):
return strategy_options._UnboundLoad._from_keys(
strategy_options._UnboundLoad.baked_lazyload, keys, False, {}
)
@baked_lazyload._add_unbound_all_fn
@util.deprecated(
"1.2",
"Baked lazy loading is now the default "
"implementation for lazy loading.",
)
def baked_lazyload_all(*keys):
return strategy_options._UnboundLoad._from_keys(
strategy_options._UnboundLoad.baked_lazyload, keys, True, {}
)
baked_lazyload = baked_lazyload._unbound_fn
baked_lazyload_all = baked_lazyload_all._unbound_all_fn
bakery = BakedQuery.bakery

View File

@@ -0,0 +1,613 @@
# ext/compiler.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""Provides an API for creation of custom ClauseElements and compilers.
Synopsis
========
Usage involves the creation of one or more
:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or
more callables defining its compilation::
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import ColumnClause
class MyColumn(ColumnClause):
inherit_cache = True
@compiles(MyColumn)
def compile_mycolumn(element, compiler, **kw):
return "[%s]" % element.name
Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`,
the base expression element for named column objects. The ``compiles``
decorator registers itself with the ``MyColumn`` class so that it is invoked
when the object is compiled to a string::
from sqlalchemy import select
s = select(MyColumn('x'), MyColumn('y'))
print(str(s))
Produces::
SELECT [x], [y]
Dialect-specific compilation rules
==================================
Compilers can also be made dialect-specific. The appropriate compiler will be
invoked for the dialect in use::
from sqlalchemy.schema import DDLElement
class AlterColumn(DDLElement):
inherit_cache = False
def __init__(self, column, cmd):
self.column = column
self.cmd = cmd
@compiles(AlterColumn)
def visit_alter_column(element, compiler, **kw):
return "ALTER COLUMN %s ..." % element.column.name
@compiles(AlterColumn, 'postgresql')
def visit_alter_column(element, compiler, **kw):
return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name,
element.column.name)
The second ``visit_alter_table`` will be invoked when any ``postgresql``
dialect is used.
.. _compilerext_compiling_subelements:
Compiling sub-elements of a custom expression construct
=======================================================
The ``compiler`` argument is the
:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object
can be inspected for any information about the in-progress compilation,
including ``compiler.dialect``, ``compiler.statement`` etc. The
:class:`~sqlalchemy.sql.compiler.SQLCompiler` and
:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()``
method which can be used for compilation of embedded attributes::
from sqlalchemy.sql.expression import Executable, ClauseElement
class InsertFromSelect(Executable, ClauseElement):
inherit_cache = False
def __init__(self, table, select):
self.table = table
self.select = select
@compiles(InsertFromSelect)
def visit_insert_from_select(element, compiler, **kw):
return "INSERT INTO %s (%s)" % (
compiler.process(element.table, asfrom=True, **kw),
compiler.process(element.select, **kw)
)
insert = InsertFromSelect(t1, select(t1).where(t1.c.x>5))
print(insert)
Produces::
"INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z
FROM mytable WHERE mytable.x > :x_1)"
.. note::
The above ``InsertFromSelect`` construct is only an example, this actual
functionality is already available using the
:meth:`_expression.Insert.from_select` method.
.. note::
The above ``InsertFromSelect`` construct probably wants to have "autocommit"
enabled. See :ref:`enabling_compiled_autocommit` for this step.
Cross Compiling between SQL and DDL compilers
---------------------------------------------
SQL and DDL constructs are each compiled using different base compilers -
``SQLCompiler`` and ``DDLCompiler``. A common need is to access the
compilation rules of SQL expressions from within a DDL expression. The
``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as
below where we generate a CHECK constraint that embeds a SQL expression::
@compiles(MyConstraint)
def compile_my_constraint(constraint, ddlcompiler, **kw):
kw['literal_binds'] = True
return "CONSTRAINT %s CHECK (%s)" % (
constraint.name,
ddlcompiler.sql_compiler.process(
constraint.expression, **kw)
)
Above, we add an additional flag to the process step as called by
:meth:`.SQLCompiler.process`, which is the ``literal_binds`` flag. This
indicates that any SQL expression which refers to a :class:`.BindParameter`
object or other "literal" object such as those which refer to strings or
integers should be rendered **in-place**, rather than being referred to as
a bound parameter; when emitting DDL, bound parameters are typically not
supported.
.. _enabling_compiled_autocommit:
Enabling Autocommit on a Construct
==================================
Recall from the section :ref:`autocommit` that the :class:`_engine.Engine`,
when
asked to execute a construct in the absence of a user-defined transaction,
detects if the given construct represents DML or DDL, that is, a data
modification or data definition statement, which requires (or may require,
in the case of DDL) that the transaction generated by the DBAPI be committed
(recall that DBAPI always has a transaction going on regardless of what
SQLAlchemy does). Checking for this is actually accomplished by checking for
the "autocommit" execution option on the construct. When building a
construct like an INSERT derivation, a new DDL type, or perhaps a stored
procedure that alters data, the "autocommit" option needs to be set in order
for the statement to function with "connectionless" execution
(as described in :ref:`dbengine_implicit`).
Currently a quick way to do this is to subclass :class:`.Executable`, then
add the "autocommit" flag to the ``_execution_options`` dictionary (note this
is a "frozen" dictionary which supplies a generative ``union()`` method)::
from sqlalchemy.sql.expression import Executable, ClauseElement
class MyInsertThing(Executable, ClauseElement):
_execution_options = \
Executable._execution_options.union({'autocommit': True})
More succinctly, if the construct is truly similar to an INSERT, UPDATE, or
DELETE, :class:`.UpdateBase` can be used, which already is a subclass
of :class:`.Executable`, :class:`_expression.ClauseElement` and includes the
``autocommit`` flag::
from sqlalchemy.sql.expression import UpdateBase
class MyInsertThing(UpdateBase):
def __init__(self, ...):
...
DDL elements that subclass :class:`.DDLElement` already have the
"autocommit" flag turned on.
Changing the default compilation of existing constructs
=======================================================
The compiler extension applies just as well to the existing constructs. When
overriding the compilation of a built in SQL construct, the @compiles
decorator is invoked upon the appropriate class (be sure to use the class,
i.e. ``Insert`` or ``Select``, instead of the creation function such
as ``insert()`` or ``select()``).
Within the new compilation function, to get at the "original" compilation
routine, use the appropriate visit_XXX method - this
because compiler.process() will call upon the overriding routine and cause
an endless loop. Such as, to add "prefix" to all insert statements::
from sqlalchemy.sql.expression import Insert
@compiles(Insert)
def prefix_inserts(insert, compiler, **kw):
return compiler.visit_insert(insert.prefix_with("some prefix"), **kw)
The above compiler will prefix all INSERT statements with "some prefix" when
compiled.
.. _type_compilation_extension:
Changing Compilation of Types
=============================
``compiler`` works for types, too, such as below where we implement the
MS-SQL specific 'max' keyword for ``String``/``VARCHAR``::
@compiles(String, 'mssql')
@compiles(VARCHAR, 'mssql')
def compile_varchar(element, compiler, **kw):
if element.length == 'max':
return "VARCHAR('max')"
else:
return compiler.visit_VARCHAR(element, **kw)
foo = Table('foo', metadata,
Column('data', VARCHAR('max'))
)
Subclassing Guidelines
======================
A big part of using the compiler extension is subclassing SQLAlchemy
expression constructs. To make this easier, the expression and
schema packages feature a set of "bases" intended for common tasks.
A synopsis is as follows:
* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root
expression class. Any SQL expression can be derived from this base, and is
probably the best choice for longer constructs such as specialized INSERT
statements.
* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all
"column-like" elements. Anything that you'd place in the "columns" clause of
a SELECT statement (as well as order by and group by) can derive from this -
the object will automatically have Python "comparison" behavior.
:class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a
``type`` member which is expression's return type. This can be established
at the instance level in the constructor, or at the class level if its
generally constant::
class timestamp(ColumnElement):
type = TIMESTAMP()
inherit_cache = True
* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a
``ColumnElement`` and a "from clause" like object, and represents a SQL
function or stored procedure type of call. Since most databases support
statements along the line of "SELECT FROM <some function>"
``FunctionElement`` adds in the ability to be used in the FROM clause of a
``select()`` construct::
from sqlalchemy.sql.expression import FunctionElement
class coalesce(FunctionElement):
name = 'coalesce'
inherit_cache = True
@compiles(coalesce)
def compile(element, compiler, **kw):
return "coalesce(%s)" % compiler.process(element.clauses, **kw)
@compiles(coalesce, 'oracle')
def compile(element, compiler, **kw):
if len(element.clauses) > 2:
raise TypeError("coalesce only supports two arguments on Oracle")
return "nvl(%s)" % compiler.process(element.clauses, **kw)
* :class:`.DDLElement` - The root of all DDL expressions,
like CREATE TABLE, ALTER TABLE, etc. Compilation of :class:`.DDLElement`
subclasses is issued by a :class:`.DDLCompiler` instead of a
:class:`.SQLCompiler`. :class:`.DDLElement` can also be used as an event hook
in conjunction with event hooks like :meth:`.DDLEvents.before_create` and
:meth:`.DDLEvents.after_create`, allowing the construct to be invoked
automatically during CREATE TABLE and DROP TABLE sequences.
.. seealso::
:ref:`metadata_ddl_toplevel` - contains examples of associating
:class:`.DDL` objects (which are themselves :class:`.DDLElement`
instances) with :class:`.DDLEvents` event hooks.
* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which
should be used with any expression class that represents a "standalone"
SQL statement that can be passed directly to an ``execute()`` method. It
is already implicit within ``DDLElement`` and ``FunctionElement``.
Most of the above constructs also respond to SQL statement caching. A
subclassed construct will want to define the caching behavior for the object,
which usually means setting the flag ``inherit_cache`` to the value of
``False`` or ``True``. See the next section :ref:`compilerext_caching`
for background.
.. _compilerext_caching:
Enabling Caching Support for Custom Constructs
==============================================
SQLAlchemy as of version 1.4 includes a
:ref:`SQL compilation caching facility <sql_caching>` which will allow
equivalent SQL constructs to cache their stringified form, along with other
structural information used to fetch results from the statement.
For reasons discussed at :ref:`caching_caveats`, the implementation of this
caching system takes a conservative approach towards including custom SQL
constructs and/or subclasses within the caching system. This includes that
any user-defined SQL constructs, including all the examples for this
extension, will not participate in caching by default unless they positively
assert that they are able to do so. The :attr:`.HasCacheKey.inherit_cache`
attribute when set to ``True`` at the class level of a specific subclass
will indicate that instances of this class may be safely cached, using the
cache key generation scheme of the immediate superclass. This applies
for example to the "synopsis" example indicated previously::
class MyColumn(ColumnClause):
inherit_cache = True
@compiles(MyColumn)
def compile_mycolumn(element, compiler, **kw):
return "[%s]" % element.name
Above, the ``MyColumn`` class does not include any new state that
affects its SQL compilation; the cache key of ``MyColumn`` instances will
make use of that of the ``ColumnClause`` superclass, meaning it will take
into account the class of the object (``MyColumn``), the string name and
datatype of the object::
>>> MyColumn("some_name", String())._generate_cache_key()
CacheKey(
key=('0', <class '__main__.MyColumn'>,
'name', 'some_name',
'type', (<class 'sqlalchemy.sql.sqltypes.String'>,
('length', None), ('collation', None))
), bindparams=[])
For objects that are likely to be **used liberally as components within many
larger statements**, such as :class:`_schema.Column` subclasses and custom SQL
datatypes, it's important that **caching be enabled as much as possible**, as
this may otherwise negatively affect performance.
An example of an object that **does** contain state which affects its SQL
compilation is the one illustrated at :ref:`compilerext_compiling_subelements`;
this is an "INSERT FROM SELECT" construct that combines together a
:class:`_schema.Table` as well as a :class:`_sql.Select` construct, each of
which independently affect the SQL string generation of the construct. For
this class, the example illustrates that it simply does not participate in
caching::
class InsertFromSelect(Executable, ClauseElement):
inherit_cache = False
def __init__(self, table, select):
self.table = table
self.select = select
@compiles(InsertFromSelect)
def visit_insert_from_select(element, compiler, **kw):
return "INSERT INTO %s (%s)" % (
compiler.process(element.table, asfrom=True, **kw),
compiler.process(element.select, **kw)
)
While it is also possible that the above ``InsertFromSelect`` could be made to
produce a cache key that is composed of that of the :class:`_schema.Table` and
:class:`_sql.Select` components together, the API for this is not at the moment
fully public. However, for an "INSERT FROM SELECT" construct, which is only
used by itself for specific operations, caching is not as critical as in the
previous example.
For objects that are **used in relative isolation and are generally
standalone**, such as custom :term:`DML` constructs like an "INSERT FROM
SELECT", **caching is generally less critical** as the lack of caching for such
a construct will have only localized implications for that specific operation.
Further Examples
================
"UTC timestamp" function
-------------------------
A function that works like "CURRENT_TIMESTAMP" except applies the
appropriate conversions so that the time is in UTC time. Timestamps are best
stored in relational databases as UTC, without time zones. UTC so that your
database doesn't think time has gone backwards in the hour when daylight
savings ends, without timezones because timezones are like character
encodings - they're best applied only at the endpoints of an application
(i.e. convert to UTC upon user input, re-apply desired timezone upon display).
For PostgreSQL and Microsoft SQL Server::
from sqlalchemy.sql import expression
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import DateTime
class utcnow(expression.FunctionElement):
type = DateTime()
inherit_cache = True
@compiles(utcnow, 'postgresql')
def pg_utcnow(element, compiler, **kw):
return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
@compiles(utcnow, 'mssql')
def ms_utcnow(element, compiler, **kw):
return "GETUTCDATE()"
Example usage::
from sqlalchemy import (
Table, Column, Integer, String, DateTime, MetaData
)
metadata = MetaData()
event = Table("event", metadata,
Column("id", Integer, primary_key=True),
Column("description", String(50), nullable=False),
Column("timestamp", DateTime, server_default=utcnow())
)
"GREATEST" function
-------------------
The "GREATEST" function is given any number of arguments and returns the one
that is of the highest value - its equivalent to Python's ``max``
function. A SQL standard version versus a CASE based version which only
accommodates two arguments::
from sqlalchemy.sql import expression, case
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import Numeric
class greatest(expression.FunctionElement):
type = Numeric()
name = 'greatest'
inherit_cache = True
@compiles(greatest)
def default_greatest(element, compiler, **kw):
return compiler.visit_function(element)
@compiles(greatest, 'sqlite')
@compiles(greatest, 'mssql')
@compiles(greatest, 'oracle')
def case_greatest(element, compiler, **kw):
arg1, arg2 = list(element.clauses)
return compiler.process(case([(arg1 > arg2, arg1)], else_=arg2), **kw)
Example usage::
Session.query(Account).\
filter(
greatest(
Account.checking_balance,
Account.savings_balance) > 10000
)
"false" expression
------------------
Render a "false" constant expression, rendering as "0" on platforms that
don't have a "false" constant::
from sqlalchemy.sql import expression
from sqlalchemy.ext.compiler import compiles
class sql_false(expression.ColumnElement):
inherit_cache = True
@compiles(sql_false)
def default_false(element, compiler, **kw):
return "false"
@compiles(sql_false, 'mssql')
@compiles(sql_false, 'mysql')
@compiles(sql_false, 'oracle')
def int_false(element, compiler, **kw):
return "0"
Example usage::
from sqlalchemy import select, union_all
exp = union_all(
select(users.c.name, sql_false().label("enrolled")),
select(customers.c.name, customers.c.enrolled)
)
"""
from .. import exc
from .. import util
from ..sql import sqltypes
def compiles(class_, *specs):
"""Register a function as a compiler for a
given :class:`_expression.ClauseElement` type."""
def decorate(fn):
# get an existing @compiles handler
existing = class_.__dict__.get("_compiler_dispatcher", None)
# get the original handler. All ClauseElement classes have one
# of these, but some TypeEngine classes will not.
existing_dispatch = getattr(class_, "_compiler_dispatch", None)
if not existing:
existing = _dispatcher()
if existing_dispatch:
def _wrap_existing_dispatch(element, compiler, **kw):
try:
return existing_dispatch(element, compiler, **kw)
except exc.UnsupportedCompilationError as uce:
util.raise_(
exc.UnsupportedCompilationError(
compiler,
type(element),
message="%s construct has no default "
"compilation handler." % type(element),
),
from_=uce,
)
existing.specs["default"] = _wrap_existing_dispatch
# TODO: why is the lambda needed ?
setattr(
class_,
"_compiler_dispatch",
lambda *arg, **kw: existing(*arg, **kw),
)
setattr(class_, "_compiler_dispatcher", existing)
if specs:
for s in specs:
existing.specs[s] = fn
else:
existing.specs["default"] = fn
return fn
return decorate
def deregister(class_):
"""Remove all custom compilers associated with a given
:class:`_expression.ClauseElement` type.
"""
if hasattr(class_, "_compiler_dispatcher"):
class_._compiler_dispatch = class_._original_compiler_dispatch
del class_._compiler_dispatcher
class _dispatcher(object):
def __init__(self):
self.specs = {}
def __call__(self, element, compiler, **kw):
# TODO: yes, this could also switch off of DBAPI in use.
fn = self.specs.get(compiler.dialect.name, None)
if not fn:
try:
fn = self.specs["default"]
except KeyError as ke:
util.raise_(
exc.UnsupportedCompilationError(
compiler,
type(element),
message="%s construct has no default "
"compilation handler." % type(element),
),
replace_context=ke,
)
# if compilation includes add_to_result_map, collect add_to_result_map
# arguments from the user-defined callable, which are probably none
# because this is not public API. if it wasn't called, then call it
# ourselves.
arm = kw.get("add_to_result_map", None)
if arm:
arm_collection = []
kw["add_to_result_map"] = lambda *args: arm_collection.append(args)
expr = fn(element, compiler, **kw)
if arm:
if not arm_collection:
arm_collection.append(
(None, None, (element,), sqltypes.NULLTYPE)
)
for tup in arm_collection:
arm(*tup)
return expr

View File

@@ -0,0 +1,64 @@
# ext/declarative/__init__.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from .extensions import AbstractConcreteBase
from .extensions import ConcreteBase
from .extensions import DeferredReflection
from .extensions import instrument_declarative
from ... import util
from ...orm.decl_api import as_declarative as _as_declarative
from ...orm.decl_api import declarative_base as _declarative_base
from ...orm.decl_api import DeclarativeMeta
from ...orm.decl_api import declared_attr
from ...orm.decl_api import has_inherited_table as _has_inherited_table
from ...orm.decl_api import synonym_for as _synonym_for
@util.moved_20(
"The ``declarative_base()`` function is now available as "
":func:`sqlalchemy.orm.declarative_base`."
)
def declarative_base(*arg, **kw):
return _declarative_base(*arg, **kw)
@util.moved_20(
"The ``as_declarative()`` function is now available as "
":func:`sqlalchemy.orm.as_declarative`"
)
def as_declarative(*arg, **kw):
return _as_declarative(*arg, **kw)
@util.moved_20(
"The ``has_inherited_table()`` function is now available as "
":func:`sqlalchemy.orm.has_inherited_table`."
)
def has_inherited_table(*arg, **kw):
return _has_inherited_table(*arg, **kw)
@util.moved_20(
"The ``synonym_for()`` function is now available as "
":func:`sqlalchemy.orm.synonym_for`"
)
def synonym_for(*arg, **kw):
return _synonym_for(*arg, **kw)
__all__ = [
"declarative_base",
"synonym_for",
"has_inherited_table",
"instrument_declarative",
"declared_attr",
"as_declarative",
"ConcreteBase",
"AbstractConcreteBase",
"DeclarativeMeta",
"DeferredReflection",
]

View File

@@ -0,0 +1,463 @@
# ext/declarative/extensions.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Public API functions and helpers for declarative."""
from ... import inspection
from ... import util
from ...orm import exc as orm_exc
from ...orm import registry
from ...orm import relationships
from ...orm.base import _mapper_or_none
from ...orm.clsregistry import _resolver
from ...orm.decl_base import _DeferredMapperConfig
from ...orm.util import polymorphic_union
from ...schema import Table
from ...util import OrderedDict
@util.deprecated(
"2.0",
"the instrument_declarative function is deprecated "
"and will be removed in SQLAlhcemy 2.0. Please use "
":meth:`_orm.registry.map_declaratively",
)
def instrument_declarative(cls, cls_registry, metadata):
"""Given a class, configure the class declaratively,
using the given registry, which can be any dictionary, and
MetaData object.
"""
registry(metadata=metadata, class_registry=cls_registry).map_declaratively(
cls
)
class ConcreteBase(object):
"""A helper class for 'concrete' declarative mappings.
:class:`.ConcreteBase` will use the :func:`.polymorphic_union`
function automatically, against all tables mapped as a subclass
to this class. The function is called via the
``__declare_last__()`` function, which is essentially
a hook for the :meth:`.after_configured` event.
:class:`.ConcreteBase` produces a mapped
table for the class itself. Compare to :class:`.AbstractConcreteBase`,
which does not.
Example::
from sqlalchemy.ext.declarative import ConcreteBase
class Employee(ConcreteBase, Base):
__tablename__ = 'employee'
employee_id = Column(Integer, primary_key=True)
name = Column(String(50))
__mapper_args__ = {
'polymorphic_identity':'employee',
'concrete':True}
class Manager(Employee):
__tablename__ = 'manager'
employee_id = Column(Integer, primary_key=True)
name = Column(String(50))
manager_data = Column(String(40))
__mapper_args__ = {
'polymorphic_identity':'manager',
'concrete':True}
The name of the discriminator column used by :func:`.polymorphic_union`
defaults to the name ``type``. To suit the use case of a mapping where an
actual column in a mapped table is already named ``type``, the
discriminator name can be configured by setting the
``_concrete_discriminator_name`` attribute::
class Employee(ConcreteBase, Base):
_concrete_discriminator_name = '_concrete_discriminator'
.. versionadded:: 1.3.19 Added the ``_concrete_discriminator_name``
attribute to :class:`_declarative.ConcreteBase` so that the
virtual discriminator column name can be customized.
.. versionchanged:: 1.4.2 The ``_concrete_discriminator_name`` attribute
need only be placed on the basemost class to take correct effect for
all subclasses. An explicit error message is now raised if the
mapped column names conflict with the discriminator name, whereas
in the 1.3.x series there would be some warnings and then a non-useful
query would be generated.
.. seealso::
:class:`.AbstractConcreteBase`
:ref:`concrete_inheritance`
"""
@classmethod
def _create_polymorphic_union(cls, mappers, discriminator_name):
return polymorphic_union(
OrderedDict(
(mp.polymorphic_identity, mp.local_table) for mp in mappers
),
discriminator_name,
"pjoin",
)
@classmethod
def __declare_first__(cls):
m = cls.__mapper__
if m.with_polymorphic:
return
discriminator_name = (
getattr(cls, "_concrete_discriminator_name", None) or "type"
)
mappers = list(m.self_and_descendants)
pjoin = cls._create_polymorphic_union(mappers, discriminator_name)
m._set_with_polymorphic(("*", pjoin))
m._set_polymorphic_on(pjoin.c[discriminator_name])
class AbstractConcreteBase(ConcreteBase):
"""A helper class for 'concrete' declarative mappings.
:class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union`
function automatically, against all tables mapped as a subclass
to this class. The function is called via the
``__declare_last__()`` function, which is essentially
a hook for the :meth:`.after_configured` event.
:class:`.AbstractConcreteBase` does produce a mapped class
for the base class, however it is not persisted to any table; it
is instead mapped directly to the "polymorphic" selectable directly
and is only used for selecting. Compare to :class:`.ConcreteBase`,
which does create a persisted table for the base class.
.. note::
The :class:`.AbstractConcreteBase` class does not intend to set up the
mapping for the base class until all the subclasses have been defined,
as it needs to create a mapping against a selectable that will include
all subclass tables. In order to achieve this, it waits for the
**mapper configuration event** to occur, at which point it scans
through all the configured subclasses and sets up a mapping that will
query against all subclasses at once.
While this event is normally invoked automatically, in the case of
:class:`.AbstractConcreteBase`, it may be necessary to invoke it
explicitly after **all** subclass mappings are defined, if the first
operation is to be a query against this base class. To do so, invoke
:func:`.configure_mappers` once all the desired classes have been
configured::
from sqlalchemy.orm import configure_mappers
configure_mappers()
.. seealso::
:func:`_orm.configure_mappers`
Example::
from sqlalchemy.ext.declarative import AbstractConcreteBase
class Employee(AbstractConcreteBase, Base):
pass
class Manager(Employee):
__tablename__ = 'manager'
employee_id = Column(Integer, primary_key=True)
name = Column(String(50))
manager_data = Column(String(40))
__mapper_args__ = {
'polymorphic_identity':'manager',
'concrete':True}
configure_mappers()
The abstract base class is handled by declarative in a special way;
at class configuration time, it behaves like a declarative mixin
or an ``__abstract__`` base class. Once classes are configured
and mappings are produced, it then gets mapped itself, but
after all of its descendants. This is a very unique system of mapping
not found in any other SQLAlchemy system.
Using this approach, we can specify columns and properties
that will take place on mapped subclasses, in the way that
we normally do as in :ref:`declarative_mixins`::
class Company(Base):
__tablename__ = 'company'
id = Column(Integer, primary_key=True)
class Employee(AbstractConcreteBase, Base):
employee_id = Column(Integer, primary_key=True)
@declared_attr
def company_id(cls):
return Column(ForeignKey('company.id'))
@declared_attr
def company(cls):
return relationship("Company")
class Manager(Employee):
__tablename__ = 'manager'
name = Column(String(50))
manager_data = Column(String(40))
__mapper_args__ = {
'polymorphic_identity':'manager',
'concrete':True}
configure_mappers()
When we make use of our mappings however, both ``Manager`` and
``Employee`` will have an independently usable ``.company`` attribute::
session.query(Employee).filter(Employee.company.has(id=5))
.. versionchanged:: 1.0.0 - The mechanics of :class:`.AbstractConcreteBase`
have been reworked to support relationships established directly
on the abstract base, without any special configurational steps.
.. seealso::
:class:`.ConcreteBase`
:ref:`concrete_inheritance`
"""
__no_table__ = True
@classmethod
def __declare_first__(cls):
cls._sa_decl_prepare_nocascade()
@classmethod
def _sa_decl_prepare_nocascade(cls):
if getattr(cls, "__mapper__", None):
return
to_map = _DeferredMapperConfig.config_for_cls(cls)
# can't rely on 'self_and_descendants' here
# since technically an immediate subclass
# might not be mapped, but a subclass
# may be.
mappers = []
stack = list(cls.__subclasses__())
while stack:
klass = stack.pop()
stack.extend(klass.__subclasses__())
mn = _mapper_or_none(klass)
if mn is not None:
mappers.append(mn)
discriminator_name = (
getattr(cls, "_concrete_discriminator_name", None) or "type"
)
pjoin = cls._create_polymorphic_union(mappers, discriminator_name)
# For columns that were declared on the class, these
# are normally ignored with the "__no_table__" mapping,
# unless they have a different attribute key vs. col name
# and are in the properties argument.
# In that case, ensure we update the properties entry
# to the correct column from the pjoin target table.
declared_cols = set(to_map.declared_columns)
for k, v in list(to_map.properties.items()):
if v in declared_cols:
to_map.properties[k] = pjoin.c[v.key]
to_map.local_table = pjoin
m_args = to_map.mapper_args_fn or dict
def mapper_args():
args = m_args()
args["polymorphic_on"] = pjoin.c[discriminator_name]
return args
to_map.mapper_args_fn = mapper_args
m = to_map.map()
for scls in cls.__subclasses__():
sm = _mapper_or_none(scls)
if sm and sm.concrete and cls in scls.__bases__:
sm._set_concrete_base(m)
@classmethod
def _sa_raise_deferred_config(cls):
raise orm_exc.UnmappedClassError(
cls,
msg="Class %s is a subclass of AbstractConcreteBase and "
"has a mapping pending until all subclasses are defined. "
"Call the sqlalchemy.orm.configure_mappers() function after "
"all subclasses have been defined to "
"complete the mapping of this class."
% orm_exc._safe_cls_name(cls),
)
class DeferredReflection(object):
"""A helper class for construction of mappings based on
a deferred reflection step.
Normally, declarative can be used with reflection by
setting a :class:`_schema.Table` object using autoload_with=engine
as the ``__table__`` attribute on a declarative class.
The caveat is that the :class:`_schema.Table` must be fully
reflected, or at the very least have a primary key column,
at the point at which a normal declarative mapping is
constructed, meaning the :class:`_engine.Engine` must be available
at class declaration time.
The :class:`.DeferredReflection` mixin moves the construction
of mappers to be at a later point, after a specific
method is called which first reflects all :class:`_schema.Table`
objects created so far. Classes can define it as such::
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import DeferredReflection
Base = declarative_base()
class MyClass(DeferredReflection, Base):
__tablename__ = 'mytable'
Above, ``MyClass`` is not yet mapped. After a series of
classes have been defined in the above fashion, all tables
can be reflected and mappings created using
:meth:`.prepare`::
engine = create_engine("someengine://...")
DeferredReflection.prepare(engine)
The :class:`.DeferredReflection` mixin can be applied to individual
classes, used as the base for the declarative base itself,
or used in a custom abstract class. Using an abstract base
allows that only a subset of classes to be prepared for a
particular prepare step, which is necessary for applications
that use more than one engine. For example, if an application
has two engines, you might use two bases, and prepare each
separately, e.g.::
class ReflectedOne(DeferredReflection, Base):
__abstract__ = True
class ReflectedTwo(DeferredReflection, Base):
__abstract__ = True
class MyClass(ReflectedOne):
__tablename__ = 'mytable'
class MyOtherClass(ReflectedOne):
__tablename__ = 'myothertable'
class YetAnotherClass(ReflectedTwo):
__tablename__ = 'yetanothertable'
# ... etc.
Above, the class hierarchies for ``ReflectedOne`` and
``ReflectedTwo`` can be configured separately::
ReflectedOne.prepare(engine_one)
ReflectedTwo.prepare(engine_two)
.. seealso::
:ref:`orm_declarative_reflected_deferred_reflection` - in the
:ref:`orm_declarative_table_config_toplevel` section.
"""
@classmethod
def prepare(cls, engine):
"""Reflect all :class:`_schema.Table` objects for all current
:class:`.DeferredReflection` subclasses"""
to_map = _DeferredMapperConfig.classes_for_base(cls)
with inspection.inspect(engine)._inspection_context() as insp:
for thingy in to_map:
cls._sa_decl_prepare(thingy.local_table, insp)
thingy.map()
mapper = thingy.cls.__mapper__
metadata = mapper.class_.metadata
for rel in mapper._props.values():
if (
isinstance(rel, relationships.RelationshipProperty)
and rel.secondary is not None
):
if isinstance(rel.secondary, Table):
cls._reflect_table(rel.secondary, insp)
elif isinstance(rel.secondary, str):
_, resolve_arg = _resolver(rel.parent.class_, rel)
rel.secondary = resolve_arg(rel.secondary)
rel.secondary._resolvers += (
cls._sa_deferred_table_resolver(
insp, metadata
),
)
# controversy! do we resolve it here? or leave
# it deferred? I think doing it here is necessary
# so the connection does not leak.
rel.secondary = rel.secondary()
@classmethod
def _sa_deferred_table_resolver(cls, inspector, metadata):
def _resolve(key):
t1 = Table(key, metadata)
cls._reflect_table(t1, inspector)
return t1
return _resolve
@classmethod
def _sa_decl_prepare(cls, local_table, inspector):
# autoload Table, which is already
# present in the metadata. This
# will fill in db-loaded columns
# into the existing Table object.
if local_table is not None:
cls._reflect_table(local_table, inspector)
@classmethod
def _sa_raise_deferred_config(cls):
raise orm_exc.UnmappedClassError(
cls,
msg="Class %s is a subclass of DeferredReflection. "
"Mappings are not produced until the .prepare() "
"method is called on the class hierarchy."
% orm_exc._safe_cls_name(cls),
)
@classmethod
def _reflect_table(cls, table, inspector):
Table(
table.name,
table.metadata,
extend_existing=True,
autoload_replace=False,
autoload_with=inspector,
schema=table.schema,
)

View File

@@ -0,0 +1,256 @@
# ext/horizontal_shard.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Horizontal sharding support.
Defines a rudimental 'horizontal sharding' system which allows a Session to
distribute queries and persistence operations across multiple databases.
For a usage example, see the :ref:`examples_sharding` example included in
the source distribution.
"""
from .. import event
from .. import exc
from .. import inspect
from .. import util
from ..orm.query import Query
from ..orm.session import Session
__all__ = ["ShardedSession", "ShardedQuery"]
class ShardedQuery(Query):
def __init__(self, *args, **kwargs):
super(ShardedQuery, self).__init__(*args, **kwargs)
self.id_chooser = self.session.id_chooser
self.query_chooser = self.session.query_chooser
self.execute_chooser = self.session.execute_chooser
self._shard_id = None
def set_shard(self, shard_id):
"""Return a new query, limited to a single shard ID.
All subsequent operations with the returned query will
be against the single shard regardless of other state.
The shard_id can be passed for a 2.0 style execution to the
bind_arguments dictionary of :meth:`.Session.execute`::
results = session.execute(
stmt,
bind_arguments={"shard_id": "my_shard"}
)
"""
return self.execution_options(_sa_shard_id=shard_id)
class ShardedSession(Session):
def __init__(
self,
shard_chooser,
id_chooser,
execute_chooser=None,
shards=None,
query_cls=ShardedQuery,
**kwargs
):
"""Construct a ShardedSession.
:param shard_chooser: A callable which, passed a Mapper, a mapped
instance, and possibly a SQL clause, returns a shard ID. This id
may be based off of the attributes present within the object, or on
some round-robin scheme. If the scheme is based on a selection, it
should set whatever state on the instance to mark it in the future as
participating in that shard.
:param id_chooser: A callable, passed a query and a tuple of identity
values, which should return a list of shard ids where the ID might
reside. The databases will be queried in the order of this listing.
:param execute_chooser: For a given :class:`.ORMExecuteState`,
returns the list of shard_ids
where the query should be issued. Results from all shards returned
will be combined together into a single listing.
.. versionchanged:: 1.4 The ``execute_chooser`` parameter
supersedes the ``query_chooser`` parameter.
:param shards: A dictionary of string shard names
to :class:`~sqlalchemy.engine.Engine` objects.
"""
query_chooser = kwargs.pop("query_chooser", None)
super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs)
event.listen(
self, "do_orm_execute", execute_and_instances, retval=True
)
self.shard_chooser = shard_chooser
self.id_chooser = id_chooser
if query_chooser:
util.warn_deprecated(
"The ``query_choser`` parameter is deprecated; "
"please use ``execute_chooser``.",
"1.4",
)
if execute_chooser:
raise exc.ArgumentError(
"Can't pass query_chooser and execute_chooser "
"at the same time."
)
def execute_chooser(orm_context):
return query_chooser(orm_context.statement)
self.execute_chooser = execute_chooser
else:
self.execute_chooser = execute_chooser
self.query_chooser = query_chooser
self.__binds = {}
if shards is not None:
for k in shards:
self.bind_shard(k, shards[k])
def _identity_lookup(
self,
mapper,
primary_key_identity,
identity_token=None,
lazy_loaded_from=None,
**kw
):
"""override the default :meth:`.Session._identity_lookup` method so
that we search for a given non-token primary key identity across all
possible identity tokens (e.g. shard ids).
.. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from
the :class:`_query.Query` object to the :class:`.Session`.
"""
if identity_token is not None:
return super(ShardedSession, self)._identity_lookup(
mapper,
primary_key_identity,
identity_token=identity_token,
**kw
)
else:
q = self.query(mapper)
if lazy_loaded_from:
q = q._set_lazyload_from(lazy_loaded_from)
for shard_id in self.id_chooser(q, primary_key_identity):
obj = super(ShardedSession, self)._identity_lookup(
mapper,
primary_key_identity,
identity_token=shard_id,
lazy_loaded_from=lazy_loaded_from,
**kw
)
if obj is not None:
return obj
return None
def _choose_shard_and_assign(self, mapper, instance, **kw):
if instance is not None:
state = inspect(instance)
if state.key:
token = state.key[2]
assert token is not None
return token
elif state.identity_token:
return state.identity_token
shard_id = self.shard_chooser(mapper, instance, **kw)
if instance is not None:
state.identity_token = shard_id
return shard_id
def connection_callable(
self, mapper=None, instance=None, shard_id=None, **kwargs
):
"""Provide a :class:`_engine.Connection` to use in the unit of work
flush process.
"""
if shard_id is None:
shard_id = self._choose_shard_and_assign(mapper, instance)
if self.in_transaction():
return self.get_transaction().connection(mapper, shard_id=shard_id)
else:
return self.get_bind(
mapper, shard_id=shard_id, instance=instance
).connect(**kwargs)
def get_bind(
self, mapper=None, shard_id=None, instance=None, clause=None, **kw
):
if shard_id is None:
shard_id = self._choose_shard_and_assign(
mapper, instance, clause=clause
)
return self.__binds[shard_id]
def bind_shard(self, shard_id, bind):
self.__binds[shard_id] = bind
def execute_and_instances(orm_context):
if orm_context.is_select:
load_options = active_options = orm_context.load_options
update_options = None
elif orm_context.is_update or orm_context.is_delete:
load_options = None
update_options = active_options = orm_context.update_delete_options
else:
load_options = update_options = active_options = None
session = orm_context.session
def iter_for_shard(shard_id, load_options, update_options):
execution_options = dict(orm_context.local_execution_options)
bind_arguments = dict(orm_context.bind_arguments)
bind_arguments["shard_id"] = shard_id
if orm_context.is_select:
load_options += {"_refresh_identity_token": shard_id}
execution_options["_sa_orm_load_options"] = load_options
elif orm_context.is_update or orm_context.is_delete:
update_options += {"_refresh_identity_token": shard_id}
execution_options["_sa_orm_update_options"] = update_options
return orm_context.invoke_statement(
bind_arguments=bind_arguments, execution_options=execution_options
)
if active_options and active_options._refresh_identity_token is not None:
shard_id = active_options._refresh_identity_token
elif "_sa_shard_id" in orm_context.execution_options:
shard_id = orm_context.execution_options["_sa_shard_id"]
elif "shard_id" in orm_context.bind_arguments:
shard_id = orm_context.bind_arguments["shard_id"]
else:
shard_id = None
if shard_id is not None:
return iter_for_shard(shard_id, load_options, update_options)
else:
partial = []
for shard_id in session.execute_chooser(orm_context):
result_ = iter_for_shard(shard_id, load_options, update_options)
partial.append(result_)
return partial[0].merge(*partial[1:])

1206
lib/sqlalchemy/ext/hybrid.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,352 @@
# ext/index.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Define attributes on ORM-mapped classes that have "index" attributes for
columns with :class:`_types.Indexable` types.
"index" means the attribute is associated with an element of an
:class:`_types.Indexable` column with the predefined index to access it.
The :class:`_types.Indexable` types include types such as
:class:`_types.ARRAY`, :class:`_types.JSON` and
:class:`_postgresql.HSTORE`.
The :mod:`~sqlalchemy.ext.indexable` extension provides
:class:`_schema.Column`-like interface for any element of an
:class:`_types.Indexable` typed column. In simple cases, it can be
treated as a :class:`_schema.Column` - mapped attribute.
.. versionadded:: 1.1
Synopsis
========
Given ``Person`` as a model with a primary key and JSON data field.
While this field may have any number of elements encoded within it,
we would like to refer to the element called ``name`` individually
as a dedicated attribute which behaves like a standalone column::
from sqlalchemy import Column, JSON, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.indexable import index_property
Base = declarative_base()
class Person(Base):
__tablename__ = 'person'
id = Column(Integer, primary_key=True)
data = Column(JSON)
name = index_property('data', 'name')
Above, the ``name`` attribute now behaves like a mapped column. We
can compose a new ``Person`` and set the value of ``name``::
>>> person = Person(name='Alchemist')
The value is now accessible::
>>> person.name
'Alchemist'
Behind the scenes, the JSON field was initialized to a new blank dictionary
and the field was set::
>>> person.data
{"name": "Alchemist'}
The field is mutable in place::
>>> person.name = 'Renamed'
>>> person.name
'Renamed'
>>> person.data
{'name': 'Renamed'}
When using :class:`.index_property`, the change that we make to the indexable
structure is also automatically tracked as history; we no longer need
to use :class:`~.mutable.MutableDict` in order to track this change
for the unit of work.
Deletions work normally as well::
>>> del person.name
>>> person.data
{}
Above, deletion of ``person.name`` deletes the value from the dictionary,
but not the dictionary itself.
A missing key will produce ``AttributeError``::
>>> person = Person()
>>> person.name
...
AttributeError: 'name'
Unless you set a default value::
>>> class Person(Base):
>>> __tablename__ = 'person'
>>>
>>> id = Column(Integer, primary_key=True)
>>> data = Column(JSON)
>>>
>>> name = index_property('data', 'name', default=None) # See default
>>> person = Person()
>>> print(person.name)
None
The attributes are also accessible at the class level.
Below, we illustrate ``Person.name`` used to generate
an indexed SQL criteria::
>>> from sqlalchemy.orm import Session
>>> session = Session()
>>> query = session.query(Person).filter(Person.name == 'Alchemist')
The above query is equivalent to::
>>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist')
Multiple :class:`.index_property` objects can be chained to produce
multiple levels of indexing::
from sqlalchemy import Column, JSON, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.indexable import index_property
Base = declarative_base()
class Person(Base):
__tablename__ = 'person'
id = Column(Integer, primary_key=True)
data = Column(JSON)
birthday = index_property('data', 'birthday')
year = index_property('birthday', 'year')
month = index_property('birthday', 'month')
day = index_property('birthday', 'day')
Above, a query such as::
q = session.query(Person).filter(Person.year == '1980')
On a PostgreSQL backend, the above query will render as::
SELECT person.id, person.data
FROM person
WHERE person.data -> %(data_1)s -> %(param_1)s = %(param_2)s
Default Values
==============
:class:`.index_property` includes special behaviors for when the indexed
data structure does not exist, and a set operation is called:
* For an :class:`.index_property` that is given an integer index value,
the default data structure will be a Python list of ``None`` values,
at least as long as the index value; the value is then set at its
place in the list. This means for an index value of zero, the list
will be initialized to ``[None]`` before setting the given value,
and for an index value of five, the list will be initialized to
``[None, None, None, None, None]`` before setting the fifth element
to the given value. Note that an existing list is **not** extended
in place to receive a value.
* for an :class:`.index_property` that is given any other kind of index
value (e.g. strings usually), a Python dictionary is used as the
default data structure.
* The default data structure can be set to any Python callable using the
:paramref:`.index_property.datatype` parameter, overriding the previous
rules.
Subclassing
===========
:class:`.index_property` can be subclassed, in particular for the common
use case of providing coercion of values or SQL expressions as they are
accessed. Below is a common recipe for use with a PostgreSQL JSON type,
where we want to also include automatic casting plus ``astext()``::
class pg_json_property(index_property):
def __init__(self, attr_name, index, cast_type):
super(pg_json_property, self).__init__(attr_name, index)
self.cast_type = cast_type
def expr(self, model):
expr = super(pg_json_property, self).expr(model)
return expr.astext.cast(self.cast_type)
The above subclass can be used with the PostgreSQL-specific
version of :class:`_postgresql.JSON`::
from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.dialects.postgresql import JSON
Base = declarative_base()
class Person(Base):
__tablename__ = 'person'
id = Column(Integer, primary_key=True)
data = Column(JSON)
age = pg_json_property('data', 'age', Integer)
The ``age`` attribute at the instance level works as before; however
when rendering SQL, PostgreSQL's ``->>`` operator will be used
for indexed access, instead of the usual index operator of ``->``::
>>> query = session.query(Person).filter(Person.age < 20)
The above query will render::
SELECT person.id, person.data
FROM person
WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s
""" # noqa
from __future__ import absolute_import
from .. import inspect
from .. import util
from ..ext.hybrid import hybrid_property
from ..orm.attributes import flag_modified
__all__ = ["index_property"]
class index_property(hybrid_property): # noqa
"""A property generator. The generated property describes an object
attribute that corresponds to an :class:`_types.Indexable`
column.
.. versionadded:: 1.1
.. seealso::
:mod:`sqlalchemy.ext.indexable`
"""
_NO_DEFAULT_ARGUMENT = object()
def __init__(
self,
attr_name,
index,
default=_NO_DEFAULT_ARGUMENT,
datatype=None,
mutable=True,
onebased=True,
):
"""Create a new :class:`.index_property`.
:param attr_name:
An attribute name of an `Indexable` typed column, or other
attribute that returns an indexable structure.
:param index:
The index to be used for getting and setting this value. This
should be the Python-side index value for integers.
:param default:
A value which will be returned instead of `AttributeError`
when there is not a value at given index.
:param datatype: default datatype to use when the field is empty.
By default, this is derived from the type of index used; a
Python list for an integer index, or a Python dictionary for
any other style of index. For a list, the list will be
initialized to a list of None values that is at least
``index`` elements long.
:param mutable: if False, writes and deletes to the attribute will
be disallowed.
:param onebased: assume the SQL representation of this value is
one-based; that is, the first index in SQL is 1, not zero.
"""
if mutable:
super(index_property, self).__init__(
self.fget, self.fset, self.fdel, self.expr
)
else:
super(index_property, self).__init__(
self.fget, None, None, self.expr
)
self.attr_name = attr_name
self.index = index
self.default = default
is_numeric = isinstance(index, int)
onebased = is_numeric and onebased
if datatype is not None:
self.datatype = datatype
else:
if is_numeric:
self.datatype = lambda: [None for x in range(index + 1)]
else:
self.datatype = dict
self.onebased = onebased
def _fget_default(self, err=None):
if self.default == self._NO_DEFAULT_ARGUMENT:
util.raise_(AttributeError(self.attr_name), replace_context=err)
else:
return self.default
def fget(self, instance):
attr_name = self.attr_name
column_value = getattr(instance, attr_name)
if column_value is None:
return self._fget_default()
try:
value = column_value[self.index]
except (KeyError, IndexError) as err:
return self._fget_default(err)
else:
return value
def fset(self, instance, value):
attr_name = self.attr_name
column_value = getattr(instance, attr_name, None)
if column_value is None:
column_value = self.datatype()
setattr(instance, attr_name, column_value)
column_value[self.index] = value
setattr(instance, attr_name, column_value)
if attr_name in inspect(instance).mapper.attrs:
flag_modified(instance, attr_name)
def fdel(self, instance):
attr_name = self.attr_name
column_value = getattr(instance, attr_name)
if column_value is None:
raise AttributeError(self.attr_name)
try:
del column_value[self.index]
except KeyError as err:
util.raise_(AttributeError(self.attr_name), replace_context=err)
else:
setattr(instance, attr_name, column_value)
flag_modified(instance, attr_name)
def expr(self, model):
column = getattr(model, self.attr_name)
index = self.index
if self.onebased:
index += 1
return column[index]

View File

@@ -0,0 +1,416 @@
"""Extensible class instrumentation.
The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate
systems of class instrumentation within the ORM. Class instrumentation
refers to how the ORM places attributes on the class which maintain
data and track changes to that data, as well as event hooks installed
on the class.
.. note::
The extension package is provided for the benefit of integration
with other object management packages, which already perform
their own instrumentation. It is not intended for general use.
For examples of how the instrumentation extension is used,
see the example :ref:`examples_instrumentation`.
"""
import weakref
from .. import util
from ..orm import attributes
from ..orm import base as orm_base
from ..orm import collections
from ..orm import exc as orm_exc
from ..orm import instrumentation as orm_instrumentation
from ..orm.instrumentation import _default_dict_getter
from ..orm.instrumentation import _default_manager_getter
from ..orm.instrumentation import _default_state_getter
from ..orm.instrumentation import ClassManager
from ..orm.instrumentation import InstrumentationFactory
INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__"
"""Attribute, elects custom instrumentation when present on a mapped class.
Allows a class to specify a slightly or wildly different technique for
tracking changes made to mapped attributes and collections.
Only one instrumentation implementation is allowed in a given object
inheritance hierarchy.
The value of this attribute must be a callable and will be passed a class
object. The callable must return one of:
- An instance of an :class:`.InstrumentationManager` or subclass
- An object implementing all or some of InstrumentationManager (TODO)
- A dictionary of callables, implementing all or some of the above (TODO)
- An instance of a :class:`.ClassManager` or subclass
This attribute is consulted by SQLAlchemy instrumentation
resolution, once the :mod:`sqlalchemy.ext.instrumentation` module
has been imported. If custom finders are installed in the global
instrumentation_finders list, they may or may not choose to honor this
attribute.
"""
def find_native_user_instrumentation_hook(cls):
"""Find user-specified instrumentation management for a class."""
return getattr(cls, INSTRUMENTATION_MANAGER, None)
instrumentation_finders = [find_native_user_instrumentation_hook]
"""An extensible sequence of callables which return instrumentation
implementations
When a class is registered, each callable will be passed a class object.
If None is returned, the
next finder in the sequence is consulted. Otherwise the return must be an
instrumentation factory that follows the same guidelines as
sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER.
By default, the only finder is find_native_user_instrumentation_hook, which
searches for INSTRUMENTATION_MANAGER. If all finders return None, standard
ClassManager instrumentation is used.
"""
class ExtendedInstrumentationRegistry(InstrumentationFactory):
"""Extends :class:`.InstrumentationFactory` with additional
bookkeeping, to accommodate multiple types of
class managers.
"""
_manager_finders = weakref.WeakKeyDictionary()
_state_finders = weakref.WeakKeyDictionary()
_dict_finders = weakref.WeakKeyDictionary()
_extended = False
def _locate_extended_factory(self, class_):
for finder in instrumentation_finders:
factory = finder(class_)
if factory is not None:
manager = self._extended_class_manager(class_, factory)
return manager, factory
else:
return None, None
def _check_conflicts(self, class_, factory):
existing_factories = self._collect_management_factories_for(
class_
).difference([factory])
if existing_factories:
raise TypeError(
"multiple instrumentation implementations specified "
"in %s inheritance hierarchy: %r"
% (class_.__name__, list(existing_factories))
)
def _extended_class_manager(self, class_, factory):
manager = factory(class_)
if not isinstance(manager, ClassManager):
manager = _ClassInstrumentationAdapter(class_, manager)
if factory != ClassManager and not self._extended:
# somebody invoked a custom ClassManager.
# reinstall global "getter" functions with the more
# expensive ones.
self._extended = True
_install_instrumented_lookups()
self._manager_finders[class_] = manager.manager_getter()
self._state_finders[class_] = manager.state_getter()
self._dict_finders[class_] = manager.dict_getter()
return manager
def _collect_management_factories_for(self, cls):
"""Return a collection of factories in play or specified for a
hierarchy.
Traverses the entire inheritance graph of a cls and returns a
collection of instrumentation factories for those classes. Factories
are extracted from active ClassManagers, if available, otherwise
instrumentation_finders is consulted.
"""
hierarchy = util.class_hierarchy(cls)
factories = set()
for member in hierarchy:
manager = self.manager_of_class(member)
if manager is not None:
factories.add(manager.factory)
else:
for finder in instrumentation_finders:
factory = finder(member)
if factory is not None:
break
else:
factory = None
factories.add(factory)
factories.discard(None)
return factories
def unregister(self, class_):
super(ExtendedInstrumentationRegistry, self).unregister(class_)
if class_ in self._manager_finders:
del self._manager_finders[class_]
del self._state_finders[class_]
del self._dict_finders[class_]
def manager_of_class(self, cls):
if cls is None:
return None
try:
finder = self._manager_finders.get(cls, _default_manager_getter)
except TypeError:
# due to weakref lookup on invalid object
return None
else:
return finder(cls)
def state_of(self, instance):
if instance is None:
raise AttributeError("None has no persistent state.")
return self._state_finders.get(
instance.__class__, _default_state_getter
)(instance)
def dict_of(self, instance):
if instance is None:
raise AttributeError("None has no persistent state.")
return self._dict_finders.get(
instance.__class__, _default_dict_getter
)(instance)
orm_instrumentation._instrumentation_factory = (
_instrumentation_factory
) = ExtendedInstrumentationRegistry()
orm_instrumentation.instrumentation_finders = instrumentation_finders
class InstrumentationManager(object):
"""User-defined class instrumentation extension.
:class:`.InstrumentationManager` can be subclassed in order
to change
how class instrumentation proceeds. This class exists for
the purposes of integration with other object management
frameworks which would like to entirely modify the
instrumentation methodology of the ORM, and is not intended
for regular usage. For interception of class instrumentation
events, see :class:`.InstrumentationEvents`.
The API for this class should be considered as semi-stable,
and may change slightly with new releases.
"""
# r4361 added a mandatory (cls) constructor to this interface.
# given that, perhaps class_ should be dropped from all of these
# signatures.
def __init__(self, class_):
pass
def manage(self, class_, manager):
setattr(class_, "_default_class_manager", manager)
def unregister(self, class_, manager):
delattr(class_, "_default_class_manager")
def manager_getter(self, class_):
def get(cls):
return cls._default_class_manager
return get
def instrument_attribute(self, class_, key, inst):
pass
def post_configure_attribute(self, class_, key, inst):
pass
def install_descriptor(self, class_, key, inst):
setattr(class_, key, inst)
def uninstall_descriptor(self, class_, key):
delattr(class_, key)
def install_member(self, class_, key, implementation):
setattr(class_, key, implementation)
def uninstall_member(self, class_, key):
delattr(class_, key)
def instrument_collection_class(self, class_, key, collection_class):
return collections.prepare_instrumentation(collection_class)
def get_instance_dict(self, class_, instance):
return instance.__dict__
def initialize_instance_dict(self, class_, instance):
pass
def install_state(self, class_, instance, state):
setattr(instance, "_default_state", state)
def remove_state(self, class_, instance):
delattr(instance, "_default_state")
def state_getter(self, class_):
return lambda instance: getattr(instance, "_default_state")
def dict_getter(self, class_):
return lambda inst: self.get_instance_dict(class_, inst)
class _ClassInstrumentationAdapter(ClassManager):
"""Adapts a user-defined InstrumentationManager to a ClassManager."""
def __init__(self, class_, override):
self._adapted = override
self._get_state = self._adapted.state_getter(class_)
self._get_dict = self._adapted.dict_getter(class_)
ClassManager.__init__(self, class_)
def manage(self):
self._adapted.manage(self.class_, self)
def unregister(self):
self._adapted.unregister(self.class_, self)
def manager_getter(self):
return self._adapted.manager_getter(self.class_)
def instrument_attribute(self, key, inst, propagated=False):
ClassManager.instrument_attribute(self, key, inst, propagated)
if not propagated:
self._adapted.instrument_attribute(self.class_, key, inst)
def post_configure_attribute(self, key):
super(_ClassInstrumentationAdapter, self).post_configure_attribute(key)
self._adapted.post_configure_attribute(self.class_, key, self[key])
def install_descriptor(self, key, inst):
self._adapted.install_descriptor(self.class_, key, inst)
def uninstall_descriptor(self, key):
self._adapted.uninstall_descriptor(self.class_, key)
def install_member(self, key, implementation):
self._adapted.install_member(self.class_, key, implementation)
def uninstall_member(self, key):
self._adapted.uninstall_member(self.class_, key)
def instrument_collection_class(self, key, collection_class):
return self._adapted.instrument_collection_class(
self.class_, key, collection_class
)
def initialize_collection(self, key, state, factory):
delegate = getattr(self._adapted, "initialize_collection", None)
if delegate:
return delegate(key, state, factory)
else:
return ClassManager.initialize_collection(
self, key, state, factory
)
def new_instance(self, state=None):
instance = self.class_.__new__(self.class_)
self.setup_instance(instance, state)
return instance
def _new_state_if_none(self, instance):
"""Install a default InstanceState if none is present.
A private convenience method used by the __init__ decorator.
"""
if self.has_state(instance):
return False
else:
return self.setup_instance(instance)
def setup_instance(self, instance, state=None):
self._adapted.initialize_instance_dict(self.class_, instance)
if state is None:
state = self._state_constructor(instance, self)
# the given instance is assumed to have no state
self._adapted.install_state(self.class_, instance, state)
return state
def teardown_instance(self, instance):
self._adapted.remove_state(self.class_, instance)
def has_state(self, instance):
try:
self._get_state(instance)
except orm_exc.NO_STATE:
return False
else:
return True
def state_getter(self):
return self._get_state
def dict_getter(self):
return self._get_dict
def _install_instrumented_lookups():
"""Replace global class/object management functions
with ExtendedInstrumentationRegistry implementations, which
allow multiple types of class managers to be present,
at the cost of performance.
This function is called only by ExtendedInstrumentationRegistry
and unit tests specific to this behavior.
The _reinstall_default_lookups() function can be called
after this one to re-establish the default functions.
"""
_install_lookups(
dict(
instance_state=_instrumentation_factory.state_of,
instance_dict=_instrumentation_factory.dict_of,
manager_of_class=_instrumentation_factory.manager_of_class,
)
)
def _reinstall_default_lookups():
"""Restore simplified lookups."""
_install_lookups(
dict(
instance_state=_default_state_getter,
instance_dict=_default_dict_getter,
manager_of_class=_default_manager_getter,
)
)
_instrumentation_factory._extended = False
def _install_lookups(lookups):
global instance_state, instance_dict, manager_of_class
instance_state = lookups["instance_state"]
instance_dict = lookups["instance_dict"]
manager_of_class = lookups["manager_of_class"]
orm_base.instance_state = (
attributes.instance_state
) = orm_instrumentation.instance_state = instance_state
orm_base.instance_dict = (
attributes.instance_dict
) = orm_instrumentation.instance_dict = instance_dict
orm_base.manager_of_class = (
attributes.manager_of_class
) = orm_instrumentation.manager_of_class = manager_of_class

View File

@@ -0,0 +1,958 @@
# ext/mutable.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""Provide support for tracking of in-place changes to scalar values,
which are propagated into ORM change events on owning parent objects.
.. _mutable_scalars:
Establishing Mutability on Scalar Column Values
===============================================
A typical example of a "mutable" structure is a Python dictionary.
Following the example introduced in :ref:`types_toplevel`, we
begin with a custom type that marshals Python dictionaries into
JSON strings before being persisted::
from sqlalchemy.types import TypeDecorator, VARCHAR
import json
class JSONEncodedDict(TypeDecorator):
"Represents an immutable structure as a json-encoded string."
impl = VARCHAR
def process_bind_param(self, value, dialect):
if value is not None:
value = json.dumps(value)
return value
def process_result_value(self, value, dialect):
if value is not None:
value = json.loads(value)
return value
The usage of ``json`` is only for the purposes of example. The
:mod:`sqlalchemy.ext.mutable` extension can be used
with any type whose target Python type may be mutable, including
:class:`.PickleType`, :class:`_postgresql.ARRAY`, etc.
When using the :mod:`sqlalchemy.ext.mutable` extension, the value itself
tracks all parents which reference it. Below, we illustrate a simple
version of the :class:`.MutableDict` dictionary object, which applies
the :class:`.Mutable` mixin to a plain Python dictionary::
from sqlalchemy.ext.mutable import Mutable
class MutableDict(Mutable, dict):
@classmethod
def coerce(cls, key, value):
"Convert plain dictionaries to MutableDict."
if not isinstance(value, MutableDict):
if isinstance(value, dict):
return MutableDict(value)
# this call will raise ValueError
return Mutable.coerce(key, value)
else:
return value
def __setitem__(self, key, value):
"Detect dictionary set events and emit change events."
dict.__setitem__(self, key, value)
self.changed()
def __delitem__(self, key):
"Detect dictionary del events and emit change events."
dict.__delitem__(self, key)
self.changed()
The above dictionary class takes the approach of subclassing the Python
built-in ``dict`` to produce a dict
subclass which routes all mutation events through ``__setitem__``. There are
variants on this approach, such as subclassing ``UserDict.UserDict`` or
``collections.MutableMapping``; the part that's important to this example is
that the :meth:`.Mutable.changed` method is called whenever an in-place
change to the datastructure takes place.
We also redefine the :meth:`.Mutable.coerce` method which will be used to
convert any values that are not instances of ``MutableDict``, such
as the plain dictionaries returned by the ``json`` module, into the
appropriate type. Defining this method is optional; we could just as well
created our ``JSONEncodedDict`` such that it always returns an instance
of ``MutableDict``, and additionally ensured that all calling code
uses ``MutableDict`` explicitly. When :meth:`.Mutable.coerce` is not
overridden, any values applied to a parent object which are not instances
of the mutable type will raise a ``ValueError``.
Our new ``MutableDict`` type offers a class method
:meth:`~.Mutable.as_mutable` which we can use within column metadata
to associate with types. This method grabs the given type object or
class and associates a listener that will detect all future mappings
of this type, applying event listening instrumentation to the mapped
attribute. Such as, with classical table metadata::
from sqlalchemy import Table, Column, Integer
my_data = Table('my_data', metadata,
Column('id', Integer, primary_key=True),
Column('data', MutableDict.as_mutable(JSONEncodedDict))
)
Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict``
(if the type object was not an instance already), which will intercept any
attributes which are mapped against this type. Below we establish a simple
mapping against the ``my_data`` table::
from sqlalchemy import mapper
class MyDataClass(object):
pass
# associates mutation listeners with MyDataClass.data
mapper(MyDataClass, my_data)
The ``MyDataClass.data`` member will now be notified of in place changes
to its value.
There's no difference in usage when using declarative::
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
class MyDataClass(Base):
__tablename__ = 'my_data'
id = Column(Integer, primary_key=True)
data = Column(MutableDict.as_mutable(JSONEncodedDict))
Any in-place changes to the ``MyDataClass.data`` member
will flag the attribute as "dirty" on the parent object::
>>> from sqlalchemy.orm import Session
>>> sess = Session()
>>> m1 = MyDataClass(data={'value1':'foo'})
>>> sess.add(m1)
>>> sess.commit()
>>> m1.data['value1'] = 'bar'
>>> assert m1 in sess.dirty
True
The ``MutableDict`` can be associated with all future instances
of ``JSONEncodedDict`` in one step, using
:meth:`~.Mutable.associate_with`. This is similar to
:meth:`~.Mutable.as_mutable` except it will intercept all occurrences
of ``MutableDict`` in all mappings unconditionally, without
the need to declare it individually::
MutableDict.associate_with(JSONEncodedDict)
class MyDataClass(Base):
__tablename__ = 'my_data'
id = Column(Integer, primary_key=True)
data = Column(JSONEncodedDict)
Supporting Pickling
--------------------
The key to the :mod:`sqlalchemy.ext.mutable` extension relies upon the
placement of a ``weakref.WeakKeyDictionary`` upon the value object, which
stores a mapping of parent mapped objects keyed to the attribute name under
which they are associated with this value. ``WeakKeyDictionary`` objects are
not picklable, due to the fact that they contain weakrefs and function
callbacks. In our case, this is a good thing, since if this dictionary were
picklable, it could lead to an excessively large pickle size for our value
objects that are pickled by themselves outside of the context of the parent.
The developer responsibility here is only to provide a ``__getstate__`` method
that excludes the :meth:`~MutableBase._parents` collection from the pickle
stream::
class MyMutableType(Mutable):
def __getstate__(self):
d = self.__dict__.copy()
d.pop('_parents', None)
return d
With our dictionary example, we need to return the contents of the dict itself
(and also restore them on __setstate__)::
class MutableDict(Mutable, dict):
# ....
def __getstate__(self):
return dict(self)
def __setstate__(self, state):
self.update(state)
In the case that our mutable value object is pickled as it is attached to one
or more parent objects that are also part of the pickle, the :class:`.Mutable`
mixin will re-establish the :attr:`.Mutable._parents` collection on each value
object as the owning parents themselves are unpickled.
Receiving Events
----------------
The :meth:`.AttributeEvents.modified` event handler may be used to receive
an event when a mutable scalar emits a change event. This event handler
is called when the :func:`.attributes.flag_modified` function is called
from within the mutable extension::
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import event
Base = declarative_base()
class MyDataClass(Base):
__tablename__ = 'my_data'
id = Column(Integer, primary_key=True)
data = Column(MutableDict.as_mutable(JSONEncodedDict))
@event.listens_for(MyDataClass.data, "modified")
def modified_json(instance):
print("json value modified:", instance.data)
.. _mutable_composites:
Establishing Mutability on Composites
=====================================
Composites are a special ORM feature which allow a single scalar attribute to
be assigned an object value which represents information "composed" from one
or more columns from the underlying mapped table. The usual example is that of
a geometric "point", and is introduced in :ref:`mapper_composite`.
As is the case with :class:`.Mutable`, the user-defined composite class
subclasses :class:`.MutableComposite` as a mixin, and detects and delivers
change events to its parents via the :meth:`.MutableComposite.changed` method.
In the case of a composite class, the detection is usually via the usage of
Python descriptors (i.e. ``@property``), or alternatively via the special
Python method ``__setattr__()``. Below we expand upon the ``Point`` class
introduced in :ref:`mapper_composite` to subclass :class:`.MutableComposite`
and to also route attribute set events via ``__setattr__`` to the
:meth:`.MutableComposite.changed` method::
from sqlalchemy.ext.mutable import MutableComposite
class Point(MutableComposite):
def __init__(self, x, y):
self.x = x
self.y = y
def __setattr__(self, key, value):
"Intercept set events"
# set the attribute
object.__setattr__(self, key, value)
# alert all parents to the change
self.changed()
def __composite_values__(self):
return self.x, self.y
def __eq__(self, other):
return isinstance(other, Point) and \
other.x == self.x and \
other.y == self.y
def __ne__(self, other):
return not self.__eq__(other)
The :class:`.MutableComposite` class uses a Python metaclass to automatically
establish listeners for any usage of :func:`_orm.composite` that specifies our
``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` class,
listeners are established which will route change events from ``Point``
objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes::
from sqlalchemy.orm import composite, mapper
from sqlalchemy import Table, Column
vertices = Table('vertices', metadata,
Column('id', Integer, primary_key=True),
Column('x1', Integer),
Column('y1', Integer),
Column('x2', Integer),
Column('y2', Integer),
)
class Vertex(object):
pass
mapper(Vertex, vertices, properties={
'start': composite(Point, vertices.c.x1, vertices.c.y1),
'end': composite(Point, vertices.c.x2, vertices.c.y2)
})
Any in-place changes to the ``Vertex.start`` or ``Vertex.end`` members
will flag the attribute as "dirty" on the parent object::
>>> from sqlalchemy.orm import Session
>>> sess = Session()
>>> v1 = Vertex(start=Point(3, 4), end=Point(12, 15))
>>> sess.add(v1)
>>> sess.commit()
>>> v1.end.x = 8
>>> assert v1 in sess.dirty
True
Coercing Mutable Composites
---------------------------
The :meth:`.MutableBase.coerce` method is also supported on composite types.
In the case of :class:`.MutableComposite`, the :meth:`.MutableBase.coerce`
method is only called for attribute set operations, not load operations.
Overriding the :meth:`.MutableBase.coerce` method is essentially equivalent
to using a :func:`.validates` validation routine for all attributes which
make use of the custom composite type::
class Point(MutableComposite):
# other Point methods
# ...
def coerce(cls, key, value):
if isinstance(value, tuple):
value = Point(*value)
elif not isinstance(value, Point):
raise ValueError("tuple or Point expected")
return value
Supporting Pickling
--------------------
As is the case with :class:`.Mutable`, the :class:`.MutableComposite` helper
class uses a ``weakref.WeakKeyDictionary`` available via the
:meth:`MutableBase._parents` attribute which isn't picklable. If we need to
pickle instances of ``Point`` or its owning class ``Vertex``, we at least need
to define a ``__getstate__`` that doesn't include the ``_parents`` dictionary.
Below we define both a ``__getstate__`` and a ``__setstate__`` that package up
the minimal form of our ``Point`` class::
class Point(MutableComposite):
# ...
def __getstate__(self):
return self.x, self.y
def __setstate__(self, state):
self.x, self.y = state
As with :class:`.Mutable`, the :class:`.MutableComposite` augments the
pickling process of the parent's object-relational state so that the
:meth:`MutableBase._parents` collection is restored to all ``Point`` objects.
"""
from collections import defaultdict
import weakref
from .. import event
from .. import inspect
from .. import types
from ..orm import Mapper
from ..orm import mapper
from ..orm.attributes import flag_modified
from ..sql.base import SchemaEventTarget
from ..util import memoized_property
class MutableBase(object):
"""Common base class to :class:`.Mutable`
and :class:`.MutableComposite`.
"""
@memoized_property
def _parents(self):
"""Dictionary of parent object's :class:`.InstanceState`->attribute
name on the parent.
This attribute is a so-called "memoized" property. It initializes
itself with a new ``weakref.WeakKeyDictionary`` the first time
it is accessed, returning the same object upon subsequent access.
.. versionchanged:: 1.4 the :class:`.InstanceState` is now used
as the key in the weak dictionary rather than the instance
itself.
"""
return weakref.WeakKeyDictionary()
@classmethod
def coerce(cls, key, value):
"""Given a value, coerce it into the target type.
Can be overridden by custom subclasses to coerce incoming
data into a particular type.
By default, raises ``ValueError``.
This method is called in different scenarios depending on if
the parent class is of type :class:`.Mutable` or of type
:class:`.MutableComposite`. In the case of the former, it is called
for both attribute-set operations as well as during ORM loading
operations. For the latter, it is only called during attribute-set
operations; the mechanics of the :func:`.composite` construct
handle coercion during load operations.
:param key: string name of the ORM-mapped attribute being set.
:param value: the incoming value.
:return: the method should return the coerced value, or raise
``ValueError`` if the coercion cannot be completed.
"""
if value is None:
return None
msg = "Attribute '%s' does not accept objects of type %s"
raise ValueError(msg % (key, type(value)))
@classmethod
def _get_listen_keys(cls, attribute):
"""Given a descriptor attribute, return a ``set()`` of the attribute
keys which indicate a change in the state of this attribute.
This is normally just ``set([attribute.key])``, but can be overridden
to provide for additional keys. E.g. a :class:`.MutableComposite`
augments this set with the attribute keys associated with the columns
that comprise the composite value.
This collection is consulted in the case of intercepting the
:meth:`.InstanceEvents.refresh` and
:meth:`.InstanceEvents.refresh_flush` events, which pass along a list
of attribute names that have been refreshed; the list is compared
against this set to determine if action needs to be taken.
.. versionadded:: 1.0.5
"""
return {attribute.key}
@classmethod
def _listen_on_attribute(cls, attribute, coerce, parent_cls):
"""Establish this type as a mutation listener for the given
mapped descriptor.
"""
key = attribute.key
if parent_cls is not attribute.class_:
return
# rely on "propagate" here
parent_cls = attribute.class_
listen_keys = cls._get_listen_keys(attribute)
def load(state, *args):
"""Listen for objects loaded or refreshed.
Wrap the target data member's value with
``Mutable``.
"""
val = state.dict.get(key, None)
if val is not None:
if coerce:
val = cls.coerce(key, val)
state.dict[key] = val
val._parents[state] = key
def load_attrs(state, ctx, attrs):
if not attrs or listen_keys.intersection(attrs):
load(state)
def set_(target, value, oldvalue, initiator):
"""Listen for set/replace events on the target
data member.
Establish a weak reference to the parent object
on the incoming value, remove it for the one
outgoing.
"""
if value is oldvalue:
return value
if not isinstance(value, cls):
value = cls.coerce(key, value)
if value is not None:
value._parents[target] = key
if isinstance(oldvalue, cls):
oldvalue._parents.pop(inspect(target), None)
return value
def pickle(state, state_dict):
val = state.dict.get(key, None)
if val is not None:
if "ext.mutable.values" not in state_dict:
state_dict["ext.mutable.values"] = defaultdict(list)
state_dict["ext.mutable.values"][key].append(val)
def unpickle(state, state_dict):
if "ext.mutable.values" in state_dict:
collection = state_dict["ext.mutable.values"]
if isinstance(collection, list):
# legacy format
for val in collection:
val._parents[state] = key
else:
for val in state_dict["ext.mutable.values"][key]:
val._parents[state] = key
event.listen(parent_cls, "load", load, raw=True, propagate=True)
event.listen(
parent_cls, "refresh", load_attrs, raw=True, propagate=True
)
event.listen(
parent_cls, "refresh_flush", load_attrs, raw=True, propagate=True
)
event.listen(
attribute, "set", set_, raw=True, retval=True, propagate=True
)
event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True)
event.listen(
parent_cls, "unpickle", unpickle, raw=True, propagate=True
)
class Mutable(MutableBase):
"""Mixin that defines transparent propagation of change
events to a parent object.
See the example in :ref:`mutable_scalars` for usage information.
"""
def changed(self):
"""Subclasses should call this method whenever change events occur."""
for parent, key in self._parents.items():
flag_modified(parent.obj(), key)
@classmethod
def associate_with_attribute(cls, attribute):
"""Establish this type as a mutation listener for the given
mapped descriptor.
"""
cls._listen_on_attribute(attribute, True, attribute.class_)
@classmethod
def associate_with(cls, sqltype):
"""Associate this wrapper with all future mapped columns
of the given type.
This is a convenience method that calls
``associate_with_attribute`` automatically.
.. warning::
The listeners established by this method are *global*
to all mappers, and are *not* garbage collected. Only use
:meth:`.associate_with` for types that are permanent to an
application, not with ad-hoc types else this will cause unbounded
growth in memory usage.
"""
def listen_for_type(mapper, class_):
if mapper.non_primary:
return
for prop in mapper.column_attrs:
if isinstance(prop.columns[0].type, sqltype):
cls.associate_with_attribute(getattr(class_, prop.key))
event.listen(mapper, "mapper_configured", listen_for_type)
@classmethod
def as_mutable(cls, sqltype):
"""Associate a SQL type with this mutable Python type.
This establishes listeners that will detect ORM mappings against
the given type, adding mutation event trackers to those mappings.
The type is returned, unconditionally as an instance, so that
:meth:`.as_mutable` can be used inline::
Table('mytable', metadata,
Column('id', Integer, primary_key=True),
Column('data', MyMutableType.as_mutable(PickleType))
)
Note that the returned type is always an instance, even if a class
is given, and that only columns which are declared specifically with
that type instance receive additional instrumentation.
To associate a particular mutable type with all occurrences of a
particular type, use the :meth:`.Mutable.associate_with` classmethod
of the particular :class:`.Mutable` subclass to establish a global
association.
.. warning::
The listeners established by this method are *global*
to all mappers, and are *not* garbage collected. Only use
:meth:`.as_mutable` for types that are permanent to an application,
not with ad-hoc types else this will cause unbounded growth
in memory usage.
"""
sqltype = types.to_instance(sqltype)
# a SchemaType will be copied when the Column is copied,
# and we'll lose our ability to link that type back to the original.
# so track our original type w/ columns
if isinstance(sqltype, SchemaEventTarget):
@event.listens_for(sqltype, "before_parent_attach")
def _add_column_memo(sqltyp, parent):
parent.info["_ext_mutable_orig_type"] = sqltyp
schema_event_check = True
else:
schema_event_check = False
def listen_for_type(mapper, class_):
if mapper.non_primary:
return
for prop in mapper.column_attrs:
if (
schema_event_check
and hasattr(prop.expression, "info")
and prop.expression.info.get("_ext_mutable_orig_type")
is sqltype
) or (prop.columns[0].type is sqltype):
cls.associate_with_attribute(getattr(class_, prop.key))
event.listen(mapper, "mapper_configured", listen_for_type)
return sqltype
class MutableComposite(MutableBase):
"""Mixin that defines transparent propagation of change
events on a SQLAlchemy "composite" object to its
owning parent or parents.
See the example in :ref:`mutable_composites` for usage information.
"""
@classmethod
def _get_listen_keys(cls, attribute):
return {attribute.key}.union(attribute.property._attribute_keys)
def changed(self):
"""Subclasses should call this method whenever change events occur."""
for parent, key in self._parents.items():
prop = parent.mapper.get_property(key)
for value, attr_name in zip(
self.__composite_values__(), prop._attribute_keys
):
setattr(parent.obj(), attr_name, value)
def _setup_composite_listener():
def _listen_for_type(mapper, class_):
for prop in mapper.iterate_properties:
if (
hasattr(prop, "composite_class")
and isinstance(prop.composite_class, type)
and issubclass(prop.composite_class, MutableComposite)
):
prop.composite_class._listen_on_attribute(
getattr(class_, prop.key), False, class_
)
if not event.contains(Mapper, "mapper_configured", _listen_for_type):
event.listen(Mapper, "mapper_configured", _listen_for_type)
_setup_composite_listener()
class MutableDict(Mutable, dict):
"""A dictionary type that implements :class:`.Mutable`.
The :class:`.MutableDict` object implements a dictionary that will
emit change events to the underlying mapping when the contents of
the dictionary are altered, including when values are added or removed.
Note that :class:`.MutableDict` does **not** apply mutable tracking to the
*values themselves* inside the dictionary. Therefore it is not a sufficient
solution for the use case of tracking deep changes to a *recursive*
dictionary structure, such as a JSON structure. To support this use case,
build a subclass of :class:`.MutableDict` that provides appropriate
coercion to the values placed in the dictionary so that they too are
"mutable", and emit events up to their parent structure.
.. seealso::
:class:`.MutableList`
:class:`.MutableSet`
"""
def __setitem__(self, key, value):
"""Detect dictionary set events and emit change events."""
dict.__setitem__(self, key, value)
self.changed()
def setdefault(self, key, value):
result = dict.setdefault(self, key, value)
self.changed()
return result
def __delitem__(self, key):
"""Detect dictionary del events and emit change events."""
dict.__delitem__(self, key)
self.changed()
def update(self, *a, **kw):
dict.update(self, *a, **kw)
self.changed()
def pop(self, *arg):
result = dict.pop(self, *arg)
self.changed()
return result
def popitem(self):
result = dict.popitem(self)
self.changed()
return result
def clear(self):
dict.clear(self)
self.changed()
@classmethod
def coerce(cls, key, value):
"""Convert plain dictionary to instance of this class."""
if not isinstance(value, cls):
if isinstance(value, dict):
return cls(value)
return Mutable.coerce(key, value)
else:
return value
def __getstate__(self):
return dict(self)
def __setstate__(self, state):
self.update(state)
class MutableList(Mutable, list):
"""A list type that implements :class:`.Mutable`.
The :class:`.MutableList` object implements a list that will
emit change events to the underlying mapping when the contents of
the list are altered, including when values are added or removed.
Note that :class:`.MutableList` does **not** apply mutable tracking to the
*values themselves* inside the list. Therefore it is not a sufficient
solution for the use case of tracking deep changes to a *recursive*
mutable structure, such as a JSON structure. To support this use case,
build a subclass of :class:`.MutableList` that provides appropriate
coercion to the values placed in the dictionary so that they too are
"mutable", and emit events up to their parent structure.
.. versionadded:: 1.1
.. seealso::
:class:`.MutableDict`
:class:`.MutableSet`
"""
def __reduce_ex__(self, proto):
return (self.__class__, (list(self),))
# needed for backwards compatibility with
# older pickles
def __setstate__(self, state):
self[:] = state
def __setitem__(self, index, value):
"""Detect list set events and emit change events."""
list.__setitem__(self, index, value)
self.changed()
def __setslice__(self, start, end, value):
"""Detect list set events and emit change events."""
list.__setslice__(self, start, end, value)
self.changed()
def __delitem__(self, index):
"""Detect list del events and emit change events."""
list.__delitem__(self, index)
self.changed()
def __delslice__(self, start, end):
"""Detect list del events and emit change events."""
list.__delslice__(self, start, end)
self.changed()
def pop(self, *arg):
result = list.pop(self, *arg)
self.changed()
return result
def append(self, x):
list.append(self, x)
self.changed()
def extend(self, x):
list.extend(self, x)
self.changed()
def __iadd__(self, x):
self.extend(x)
return self
def insert(self, i, x):
list.insert(self, i, x)
self.changed()
def remove(self, i):
list.remove(self, i)
self.changed()
def clear(self):
list.clear(self)
self.changed()
def sort(self, **kw):
list.sort(self, **kw)
self.changed()
def reverse(self):
list.reverse(self)
self.changed()
@classmethod
def coerce(cls, index, value):
"""Convert plain list to instance of this class."""
if not isinstance(value, cls):
if isinstance(value, list):
return cls(value)
return Mutable.coerce(index, value)
else:
return value
class MutableSet(Mutable, set):
"""A set type that implements :class:`.Mutable`.
The :class:`.MutableSet` object implements a set that will
emit change events to the underlying mapping when the contents of
the set are altered, including when values are added or removed.
Note that :class:`.MutableSet` does **not** apply mutable tracking to the
*values themselves* inside the set. Therefore it is not a sufficient
solution for the use case of tracking deep changes to a *recursive*
mutable structure. To support this use case,
build a subclass of :class:`.MutableSet` that provides appropriate
coercion to the values placed in the dictionary so that they too are
"mutable", and emit events up to their parent structure.
.. versionadded:: 1.1
.. seealso::
:class:`.MutableDict`
:class:`.MutableList`
"""
def update(self, *arg):
set.update(self, *arg)
self.changed()
def intersection_update(self, *arg):
set.intersection_update(self, *arg)
self.changed()
def difference_update(self, *arg):
set.difference_update(self, *arg)
self.changed()
def symmetric_difference_update(self, *arg):
set.symmetric_difference_update(self, *arg)
self.changed()
def __ior__(self, other):
self.update(other)
return self
def __iand__(self, other):
self.intersection_update(other)
return self
def __ixor__(self, other):
self.symmetric_difference_update(other)
return self
def __isub__(self, other):
self.difference_update(other)
return self
def add(self, elem):
set.add(self, elem)
self.changed()
def remove(self, elem):
set.remove(self, elem)
self.changed()
def discard(self, elem):
set.discard(self, elem)
self.changed()
def pop(self, *arg):
result = set.pop(self, *arg)
self.changed()
return result
def clear(self):
set.clear(self)
self.changed()
@classmethod
def coerce(cls, index, value):
"""Convert plain set to instance of this class."""
if not isinstance(value, cls):
if isinstance(value, set):
return cls(value)
return Mutable.coerce(index, value)
else:
return value
def __getstate__(self):
return set(self)
def __setstate__(self, state):
self.update(state)
def __reduce_ex__(self, proto):
return (self.__class__, (list(self),))

View File

View File

@@ -0,0 +1,299 @@
# ext/mypy/apply.py
# Copyright (C) 2021 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from typing import List
from typing import Optional
from typing import Union
from mypy.nodes import ARG_NAMED_OPT
from mypy.nodes import Argument
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import MDEF
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
from mypy.nodes import SymbolTableNode
from mypy.nodes import TempNode
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneTyp
from mypy.types import ProperType
from mypy.types import TypeOfAny
from mypy.types import UnboundType
from mypy.types import UnionType
from . import infer
from . import util
from .names import NAMED_TYPE_SQLA_MAPPED
def apply_mypy_mapped_attr(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
item: Union[NameExpr, StrExpr],
attributes: List[util.SQLAlchemyAttribute],
) -> None:
if isinstance(item, NameExpr):
name = item.name
elif isinstance(item, StrExpr):
name = item.value
else:
return None
for stmt in cls.defs.body:
if (
isinstance(stmt, AssignmentStmt)
and isinstance(stmt.lvalues[0], NameExpr)
and stmt.lvalues[0].name == name
):
break
else:
util.fail(api, "Can't find mapped attribute {}".format(name), cls)
return None
if stmt.type is None:
util.fail(
api,
"Statement linked from _mypy_mapped_attrs has no "
"typing information",
stmt,
)
return None
left_hand_explicit_type = get_proper_type(stmt.type)
assert isinstance(
left_hand_explicit_type, (Instance, UnionType, UnboundType)
)
attributes.append(
util.SQLAlchemyAttribute(
name=name,
line=item.line,
column=item.column,
typ=left_hand_explicit_type,
info=cls.info,
)
)
apply_type_to_mapped_statement(
api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
)
def re_apply_declarative_assignments(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""For multiple class passes, re-apply our left-hand side types as mypy
seems to reset them in place.
"""
mapped_attr_lookup = {attr.name: attr for attr in attributes}
update_cls_metadata = False
for stmt in cls.defs.body:
# for a re-apply, all of our statements are AssignmentStmt;
# @declared_attr calls will have been converted and this
# currently seems to be preserved by mypy (but who knows if this
# will change).
if (
isinstance(stmt, AssignmentStmt)
and isinstance(stmt.lvalues[0], NameExpr)
and stmt.lvalues[0].name in mapped_attr_lookup
and isinstance(stmt.lvalues[0].node, Var)
):
left_node = stmt.lvalues[0].node
python_type_for_type = mapped_attr_lookup[
stmt.lvalues[0].name
].type
left_node_proper_type = get_proper_type(left_node.type)
# if we have scanned an UnboundType and now there's a more
# specific type than UnboundType, call the re-scan so we
# can get that set up correctly
if (
isinstance(python_type_for_type, UnboundType)
and not isinstance(left_node_proper_type, UnboundType)
and (
isinstance(stmt.rvalue, CallExpr)
and isinstance(stmt.rvalue.callee, MemberExpr)
and isinstance(stmt.rvalue.callee.expr, NameExpr)
and stmt.rvalue.callee.expr.node is not None
and stmt.rvalue.callee.expr.node.fullname
== NAMED_TYPE_SQLA_MAPPED
and stmt.rvalue.callee.name == "_empty_constructor"
and isinstance(stmt.rvalue.args[0], CallExpr)
and isinstance(stmt.rvalue.args[0].callee, RefExpr)
)
):
python_type_for_type = (
infer.infer_type_from_right_hand_nameexpr(
api,
stmt,
left_node,
left_node_proper_type,
stmt.rvalue.args[0].callee,
)
)
if python_type_for_type is None or isinstance(
python_type_for_type, UnboundType
):
continue
# update the SQLAlchemyAttribute with the better information
mapped_attr_lookup[
stmt.lvalues[0].name
].type = python_type_for_type
update_cls_metadata = True
if python_type_for_type is not None:
left_node.type = api.named_type(
NAMED_TYPE_SQLA_MAPPED, [python_type_for_type]
)
if update_cls_metadata:
util.set_mapped_attributes(cls.info, attributes)
def apply_type_to_mapped_statement(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
lvalue: NameExpr,
left_hand_explicit_type: Optional[ProperType],
python_type_for_type: Optional[ProperType],
) -> None:
"""Apply the Mapped[<type>] annotation and right hand object to a
declarative assignment statement.
This converts a Python declarative class statement such as::
class User(Base):
# ...
attrname = Column(Integer)
To one that describes the final Python behavior to Mypy::
class User(Base):
# ...
attrname : Mapped[Optional[int]] = <meaningless temp node>
"""
left_node = lvalue.node
assert isinstance(left_node, Var)
if left_hand_explicit_type is not None:
left_node.type = api.named_type(
NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
)
else:
lvalue.is_inferred_def = False
left_node.type = api.named_type(
NAMED_TYPE_SQLA_MAPPED,
[] if python_type_for_type is None else [python_type_for_type],
)
# so to have it skip the right side totally, we can do this:
# stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
# however, if we instead manufacture a new node that uses the old
# one, then we can still get type checking for the call itself,
# e.g. the Column, relationship() call, etc.
# rewrite the node as:
# <attr> : Mapped[<typ>] =
# _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
# the original right-hand side is maintained so it gets type checked
# internally
stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue)
def add_additional_orm_attributes(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Apply __init__, __table__ and other attributes to the mapped class."""
info = util.info_for_cls(cls, api)
if info is None:
return
is_base = util.get_is_base(info)
if "__init__" not in info.names and not is_base:
mapped_attr_names = {attr.name: attr.type for attr in attributes}
for base in info.mro[1:-1]:
if "sqlalchemy" not in info.metadata:
continue
base_cls_attributes = util.get_mapped_attributes(base, api)
if base_cls_attributes is None:
continue
for attr in base_cls_attributes:
mapped_attr_names.setdefault(attr.name, attr.type)
arguments = []
for name, typ in mapped_attr_names.items():
if typ is None:
typ = AnyType(TypeOfAny.special_form)
arguments.append(
Argument(
variable=Var(name, typ),
type_annotation=typ,
initializer=TempNode(typ),
kind=ARG_NAMED_OPT,
)
)
add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
if "__table__" not in info.names and util.get_has_table(info):
_apply_placeholder_attr_to_class(
api, cls, "sqlalchemy.sql.schema.Table", "__table__"
)
if not is_base:
_apply_placeholder_attr_to_class(
api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
)
def _apply_placeholder_attr_to_class(
api: SemanticAnalyzerPluginInterface,
cls: ClassDef,
qualified_name: str,
attrname: str,
) -> None:
sym = api.lookup_fully_qualified_or_none(qualified_name)
if sym:
assert isinstance(sym.node, TypeInfo)
type_: ProperType = Instance(sym.node, [])
else:
type_ = AnyType(TypeOfAny.special_form)
var = Var(attrname)
var._fullname = cls.fullname + "." + attrname
var.info = cls.info
var.type = type_
cls.info.names[attrname] = SymbolTableNode(MDEF, var)

View File

@@ -0,0 +1,516 @@
# ext/mypy/decl_class.py
# Copyright (C) 2021 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from typing import List
from typing import Optional
from typing import Union
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import Decorator
from mypy.nodes import LambdaExpr
from mypy.nodes import ListExpr
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import PlaceholderNode
from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
from mypy.nodes import SymbolNode
from mypy.nodes import SymbolTableNode
from mypy.nodes import TempNode
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.types import AnyType
from mypy.types import CallableType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
from mypy.types import ProperType
from mypy.types import Type
from mypy.types import TypeOfAny
from mypy.types import UnboundType
from mypy.types import UnionType
from . import apply
from . import infer
from . import names
from . import util
def scan_declarative_assignments_and_apply_types(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
is_mixin_scan: bool = False,
) -> Optional[List[util.SQLAlchemyAttribute]]:
info = util.info_for_cls(cls, api)
if info is None:
# this can occur during cached passes
return None
elif cls.fullname.startswith("builtins"):
return None
mapped_attributes: Optional[
List[util.SQLAlchemyAttribute]
] = util.get_mapped_attributes(info, api)
# used by assign.add_additional_orm_attributes among others
util.establish_as_sqlalchemy(info)
if mapped_attributes is not None:
# ensure that a class that's mapped is always picked up by
# its mapped() decorator or declarative metaclass before
# it would be detected as an unmapped mixin class
if not is_mixin_scan:
# mypy can call us more than once. it then *may* have reset the
# left hand side of everything, but not the right that we removed,
# removing our ability to re-scan. but we have the types
# here, so lets re-apply them, or if we have an UnboundType,
# we can re-scan
apply.re_apply_declarative_assignments(cls, api, mapped_attributes)
return mapped_attributes
mapped_attributes = []
if not cls.defs.body:
# when we get a mixin class from another file, the body is
# empty (!) but the names are in the symbol table. so use that.
for sym_name, sym in info.names.items():
_scan_symbol_table_entry(
cls, api, sym_name, sym, mapped_attributes
)
else:
for stmt in util.flatten_typechecking(cls.defs.body):
if isinstance(stmt, AssignmentStmt):
_scan_declarative_assignment_stmt(
cls, api, stmt, mapped_attributes
)
elif isinstance(stmt, Decorator):
_scan_declarative_decorator_stmt(
cls, api, stmt, mapped_attributes
)
_scan_for_mapped_bases(cls, api)
if not is_mixin_scan:
apply.add_additional_orm_attributes(cls, api, mapped_attributes)
util.set_mapped_attributes(info, mapped_attributes)
return mapped_attributes
def _scan_symbol_table_entry(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
name: str,
value: SymbolTableNode,
attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Extract mapping information from a SymbolTableNode that's in the
type.names dictionary.
"""
value_type = get_proper_type(value.type)
if not isinstance(value_type, Instance):
return
left_hand_explicit_type = None
type_id = names.type_id_for_named_node(value_type.type)
# type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
err = False
# TODO: this is nearly the same logic as that of
# _scan_declarative_decorator_stmt, likely can be merged
if type_id in {
names.MAPPED,
names.RELATIONSHIP,
names.COMPOSITE_PROPERTY,
names.MAPPER_PROPERTY,
names.SYNONYM_PROPERTY,
names.COLUMN_PROPERTY,
}:
if value_type.args:
left_hand_explicit_type = get_proper_type(value_type.args[0])
else:
err = True
elif type_id is names.COLUMN:
if not value_type.args:
err = True
else:
typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type(
value_type.args[0]
)
if isinstance(typeengine_arg, Instance):
typeengine_arg = typeengine_arg.type
if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
if sym is not None and isinstance(sym.node, TypeInfo):
if names.has_base_type_id(sym.node, names.TYPEENGINE):
left_hand_explicit_type = UnionType(
[
infer.extract_python_type_from_typeengine(
api, sym.node, []
),
NoneType(),
]
)
else:
util.fail(
api,
"Column type should be a TypeEngine "
"subclass not '{}'".format(sym.node.fullname),
value_type,
)
if err:
msg = (
"Can't infer type from attribute {} on class {}. "
"please specify a return type from this function that is "
"one of: Mapped[<python type>], relationship[<target class>], "
"Column[<TypeEngine>], MapperProperty[<python type>]"
)
util.fail(api, msg.format(name, cls.name), cls)
left_hand_explicit_type = AnyType(TypeOfAny.special_form)
if left_hand_explicit_type is not None:
assert value.node is not None
attributes.append(
util.SQLAlchemyAttribute(
name=name,
line=value.node.line,
column=value.node.column,
typ=left_hand_explicit_type,
info=cls.info,
)
)
def _scan_declarative_decorator_stmt(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
stmt: Decorator,
attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Extract mapping information from a @declared_attr in a declarative
class.
E.g.::
@reg.mapped
class MyClass:
# ...
@declared_attr
def updated_at(cls) -> Column[DateTime]:
return Column(DateTime)
Will resolve in mypy as::
@reg.mapped
class MyClass:
# ...
updated_at: Mapped[Optional[datetime.datetime]]
"""
for dec in stmt.decorators:
if (
isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
and names.type_id_for_named_node(dec) is names.DECLARED_ATTR
):
break
else:
return
dec_index = cls.defs.body.index(stmt)
left_hand_explicit_type: Optional[ProperType] = None
if util.name_is_dunder(stmt.name):
# for dunder names like __table_args__, __tablename__,
# __mapper_args__ etc., rewrite these as simple assignment
# statements; otherwise mypy doesn't like if the decorated
# function has an annotation like ``cls: Type[Foo]`` because
# it isn't @classmethod
any_ = AnyType(TypeOfAny.special_form)
left_node = NameExpr(stmt.var.name)
left_node.node = stmt.var
new_stmt = AssignmentStmt([left_node], TempNode(any_))
new_stmt.type = left_node.node.type
cls.defs.body[dec_index] = new_stmt
return
elif isinstance(stmt.func.type, CallableType):
func_type = stmt.func.type.ret_type
if isinstance(func_type, UnboundType):
type_id = names.type_id_for_unbound_type(func_type, cls, api)
else:
# this does not seem to occur unless the type argument is
# incorrect
return
if (
type_id
in {
names.MAPPED,
names.RELATIONSHIP,
names.COMPOSITE_PROPERTY,
names.MAPPER_PROPERTY,
names.SYNONYM_PROPERTY,
names.COLUMN_PROPERTY,
}
and func_type.args
):
left_hand_explicit_type = get_proper_type(func_type.args[0])
elif type_id is names.COLUMN and func_type.args:
typeengine_arg = func_type.args[0]
if isinstance(typeengine_arg, UnboundType):
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
if sym is not None and isinstance(sym.node, TypeInfo):
if names.has_base_type_id(sym.node, names.TYPEENGINE):
left_hand_explicit_type = UnionType(
[
infer.extract_python_type_from_typeengine(
api, sym.node, []
),
NoneType(),
]
)
else:
util.fail(
api,
"Column type should be a TypeEngine "
"subclass not '{}'".format(sym.node.fullname),
func_type,
)
if left_hand_explicit_type is None:
# no type on the decorated function. our option here is to
# dig into the function body and get the return type, but they
# should just have an annotation.
msg = (
"Can't infer type from @declared_attr on function '{}'; "
"please specify a return type from this function that is "
"one of: Mapped[<python type>], relationship[<target class>], "
"Column[<TypeEngine>], MapperProperty[<python type>]"
)
util.fail(api, msg.format(stmt.var.name), stmt)
left_hand_explicit_type = AnyType(TypeOfAny.special_form)
left_node = NameExpr(stmt.var.name)
left_node.node = stmt.var
# totally feeling around in the dark here as I don't totally understand
# the significance of UnboundType. It seems to be something that is
# not going to do what's expected when it is applied as the type of
# an AssignmentStatement. So do a feeling-around-in-the-dark version
# of converting it to the regular Instance/TypeInfo/UnionType structures
# we see everywhere else.
if isinstance(left_hand_explicit_type, UnboundType):
left_hand_explicit_type = get_proper_type(
util.unbound_to_instance(api, left_hand_explicit_type)
)
left_node.node.type = api.named_type(
names.NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
)
# this will ignore the rvalue entirely
# rvalue = TempNode(AnyType(TypeOfAny.special_form))
# rewrite the node as:
# <attr> : Mapped[<typ>] =
# _sa_Mapped._empty_constructor(lambda: <function body>)
# the function body is maintained so it gets type checked internally
rvalue = util.expr_to_mapped_constructor(
LambdaExpr(stmt.func.arguments, stmt.func.body)
)
new_stmt = AssignmentStmt([left_node], rvalue)
new_stmt.type = left_node.node.type
attributes.append(
util.SQLAlchemyAttribute(
name=left_node.name,
line=stmt.line,
column=stmt.column,
typ=left_hand_explicit_type,
info=cls.info,
)
)
cls.defs.body[dec_index] = new_stmt
def _scan_declarative_assignment_stmt(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Extract mapping information from an assignment statement in a
declarative class.
"""
lvalue = stmt.lvalues[0]
if not isinstance(lvalue, NameExpr):
return
sym = cls.info.names.get(lvalue.name)
# this establishes that semantic analysis has taken place, which
# means the nodes are populated and we are called from an appropriate
# hook.
assert sym is not None
node = sym.node
if isinstance(node, PlaceholderNode):
return
assert node is lvalue.node
assert isinstance(node, Var)
if node.name == "__abstract__":
if api.parse_bool(stmt.rvalue) is True:
util.set_is_base(cls.info)
return
elif node.name == "__tablename__":
util.set_has_table(cls.info)
elif node.name.startswith("__"):
return
elif node.name == "_mypy_mapped_attrs":
if not isinstance(stmt.rvalue, ListExpr):
util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt)
else:
for item in stmt.rvalue.items:
if isinstance(item, (NameExpr, StrExpr)):
apply.apply_mypy_mapped_attr(cls, api, item, attributes)
left_hand_mapped_type: Optional[Type] = None
left_hand_explicit_type: Optional[ProperType] = None
if node.is_inferred or node.type is None:
if isinstance(stmt.type, UnboundType):
# look for an explicit Mapped[] type annotation on the left
# side with nothing on the right
# print(stmt.type)
# Mapped?[Optional?[A?]]
left_hand_explicit_type = stmt.type
if stmt.type.name == "Mapped":
mapped_sym = api.lookup_qualified("Mapped", cls)
if (
mapped_sym is not None
and mapped_sym.node is not None
and names.type_id_for_named_node(mapped_sym.node)
is names.MAPPED
):
left_hand_explicit_type = get_proper_type(
stmt.type.args[0]
)
left_hand_mapped_type = stmt.type
# TODO: do we need to convert from unbound for this case?
# left_hand_explicit_type = util._unbound_to_instance(
# api, left_hand_explicit_type
# )
else:
node_type = get_proper_type(node.type)
if (
isinstance(node_type, Instance)
and names.type_id_for_named_node(node_type.type) is names.MAPPED
):
# print(node.type)
# sqlalchemy.orm.attributes.Mapped[<python type>]
left_hand_explicit_type = get_proper_type(node_type.args[0])
left_hand_mapped_type = node_type
else:
# print(node.type)
# <python type>
left_hand_explicit_type = node_type
left_hand_mapped_type = None
if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
# annotation without assignment and Mapped is present
# as type annotation
# equivalent to using _infer_type_from_left_hand_type_only.
python_type_for_type = left_hand_explicit_type
elif isinstance(stmt.rvalue, CallExpr) and isinstance(
stmt.rvalue.callee, RefExpr
):
python_type_for_type = infer.infer_type_from_right_hand_nameexpr(
api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
)
if python_type_for_type is None:
return
else:
return
assert python_type_for_type is not None
attributes.append(
util.SQLAlchemyAttribute(
name=node.name,
line=stmt.line,
column=stmt.column,
typ=python_type_for_type,
info=cls.info,
)
)
apply.apply_type_to_mapped_statement(
api,
stmt,
lvalue,
left_hand_explicit_type,
python_type_for_type,
)
def _scan_for_mapped_bases(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
) -> None:
"""Given a class, iterate through its superclass hierarchy to find
all other classes that are considered as ORM-significant.
Locates non-mapped mixins and scans them for mapped attributes to be
applied to subclasses.
"""
info = util.info_for_cls(cls, api)
if info is None:
return
for base_info in info.mro[1:-1]:
if base_info.fullname.startswith("builtins"):
continue
# scan each base for mapped attributes. if they are not already
# scanned (but have all their type info), that means they are unmapped
# mixins
scan_declarative_assignments_and_apply_types(
base_info.defn, api, is_mixin_scan=True
)

View File

@@ -0,0 +1,556 @@
# ext/mypy/infer.py
# Copyright (C) 2021 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from typing import Optional
from typing import Sequence
from mypy.maptype import map_instance_to_supertype
from mypy.messages import format_type
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
from mypy.nodes import Expression
from mypy.nodes import FuncDef
from mypy.nodes import LambdaExpr
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.subtypes import is_subtype
from mypy.types import AnyType
from mypy.types import CallableType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
from mypy.types import ProperType
from mypy.types import TypeOfAny
from mypy.types import UnionType
from . import names
from . import util
def infer_type_from_right_hand_nameexpr(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
left_hand_explicit_type: Optional[ProperType],
infer_from_right_side: RefExpr,
) -> Optional[ProperType]:
type_id = names.type_id_for_callee(infer_from_right_side)
if type_id is None:
return None
elif type_id is names.COLUMN:
python_type_for_type = _infer_type_from_decl_column(
api, stmt, node, left_hand_explicit_type
)
elif type_id is names.RELATIONSHIP:
python_type_for_type = _infer_type_from_relationship(
api, stmt, node, left_hand_explicit_type
)
elif type_id is names.COLUMN_PROPERTY:
python_type_for_type = _infer_type_from_decl_column_property(
api, stmt, node, left_hand_explicit_type
)
elif type_id is names.SYNONYM_PROPERTY:
python_type_for_type = infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
elif type_id is names.COMPOSITE_PROPERTY:
python_type_for_type = _infer_type_from_decl_composite_property(
api, stmt, node, left_hand_explicit_type
)
else:
return None
return python_type_for_type
def _infer_type_from_relationship(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
"""Infer the type of mapping from a relationship.
E.g.::
@reg.mapped
class MyClass:
# ...
addresses = relationship(Address, uselist=True)
order: Mapped["Order"] = relationship("Order")
Will resolve in mypy as::
@reg.mapped
class MyClass:
# ...
addresses: Mapped[List[Address]]
order: Mapped["Order"]
"""
assert isinstance(stmt.rvalue, CallExpr)
target_cls_arg = stmt.rvalue.args[0]
python_type_for_type: Optional[ProperType] = None
if isinstance(target_cls_arg, NameExpr) and isinstance(
target_cls_arg.node, TypeInfo
):
# type
related_object_type = target_cls_arg.node
python_type_for_type = Instance(related_object_type, [])
# other cases not covered - an error message directs the user
# to set an explicit type annotation
#
# node.type == str, it's a string
# if isinstance(target_cls_arg, NameExpr) and isinstance(
# target_cls_arg.node, Var
# )
# points to a type
# isinstance(target_cls_arg, NameExpr) and isinstance(
# target_cls_arg.node, TypeAlias
# )
# string expression
# isinstance(target_cls_arg, StrExpr)
uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist")
collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg(
stmt.rvalue, "collection_class"
)
type_is_a_collection = False
# this can be used to determine Optional for a many-to-one
# in the same way nullable=False could be used, if we start supporting
# that.
# innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin")
if (
uselist_arg is not None
and api.parse_bool(uselist_arg) is True
and collection_cls_arg is None
):
type_is_a_collection = True
if python_type_for_type is not None:
python_type_for_type = api.named_type(
names.NAMED_TYPE_BUILTINS_LIST, [python_type_for_type]
)
elif (
uselist_arg is None or api.parse_bool(uselist_arg) is True
) and collection_cls_arg is not None:
type_is_a_collection = True
if isinstance(collection_cls_arg, CallExpr):
collection_cls_arg = collection_cls_arg.callee
if isinstance(collection_cls_arg, NameExpr) and isinstance(
collection_cls_arg.node, TypeInfo
):
if python_type_for_type is not None:
# this can still be overridden by the left hand side
# within _infer_Type_from_left_and_inferred_right
python_type_for_type = Instance(
collection_cls_arg.node, [python_type_for_type]
)
elif (
isinstance(collection_cls_arg, NameExpr)
and isinstance(collection_cls_arg.node, FuncDef)
and collection_cls_arg.node.type is not None
):
if python_type_for_type is not None:
# this can still be overridden by the left hand side
# within _infer_Type_from_left_and_inferred_right
# TODO: handle mypy.types.Overloaded
if isinstance(collection_cls_arg.node.type, CallableType):
rt = get_proper_type(collection_cls_arg.node.type.ret_type)
if isinstance(rt, CallableType):
callable_ret_type = get_proper_type(rt.ret_type)
if isinstance(callable_ret_type, Instance):
python_type_for_type = Instance(
callable_ret_type.type,
[python_type_for_type],
)
else:
util.fail(
api,
"Expected Python collection type for "
"collection_class parameter",
stmt.rvalue,
)
python_type_for_type = None
elif uselist_arg is not None and api.parse_bool(uselist_arg) is False:
if collection_cls_arg is not None:
util.fail(
api,
"Sending uselist=False and collection_class at the same time "
"does not make sense",
stmt.rvalue,
)
if python_type_for_type is not None:
python_type_for_type = UnionType(
[python_type_for_type, NoneType()]
)
else:
if left_hand_explicit_type is None:
msg = (
"Can't infer scalar or collection for ORM mapped expression "
"assigned to attribute '{}' if both 'uselist' and "
"'collection_class' arguments are absent from the "
"relationship(); please specify a "
"type annotation on the left hand side."
)
util.fail(api, msg.format(node.name), node)
if python_type_for_type is None:
return infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
elif left_hand_explicit_type is not None:
if type_is_a_collection:
assert isinstance(left_hand_explicit_type, Instance)
assert isinstance(python_type_for_type, Instance)
return _infer_collection_type_from_left_and_inferred_right(
api, node, left_hand_explicit_type, python_type_for_type
)
else:
return _infer_type_from_left_and_inferred_right(
api,
node,
left_hand_explicit_type,
python_type_for_type,
)
else:
return python_type_for_type
def _infer_type_from_decl_composite_property(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
"""Infer the type of mapping from a CompositeProperty."""
assert isinstance(stmt.rvalue, CallExpr)
target_cls_arg = stmt.rvalue.args[0]
python_type_for_type = None
if isinstance(target_cls_arg, NameExpr) and isinstance(
target_cls_arg.node, TypeInfo
):
related_object_type = target_cls_arg.node
python_type_for_type = Instance(related_object_type, [])
else:
python_type_for_type = None
if python_type_for_type is None:
return infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
elif left_hand_explicit_type is not None:
return _infer_type_from_left_and_inferred_right(
api, node, left_hand_explicit_type, python_type_for_type
)
else:
return python_type_for_type
def _infer_type_from_decl_column_property(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
"""Infer the type of mapping from a ColumnProperty.
This includes mappings against ``column_property()`` as well as the
``deferred()`` function.
"""
assert isinstance(stmt.rvalue, CallExpr)
if stmt.rvalue.args:
first_prop_arg = stmt.rvalue.args[0]
if isinstance(first_prop_arg, CallExpr):
type_id = names.type_id_for_callee(first_prop_arg.callee)
# look for column_property() / deferred() etc with Column as first
# argument
if type_id is names.COLUMN:
return _infer_type_from_decl_column(
api,
stmt,
node,
left_hand_explicit_type,
right_hand_expression=first_prop_arg,
)
if isinstance(stmt.rvalue, CallExpr):
type_id = names.type_id_for_callee(stmt.rvalue.callee)
# this is probably not strictly necessary as we have to use the left
# hand type for query expression in any case. any other no-arg
# column prop objects would go here also
if type_id is names.QUERY_EXPRESSION:
return _infer_type_from_decl_column(
api,
stmt,
node,
left_hand_explicit_type,
)
return infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
def _infer_type_from_decl_column(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
left_hand_explicit_type: Optional[ProperType],
right_hand_expression: Optional[CallExpr] = None,
) -> Optional[ProperType]:
"""Infer the type of mapping from a Column.
E.g.::
@reg.mapped
class MyClass:
# ...
a = Column(Integer)
b = Column("b", String)
c: Mapped[int] = Column(Integer)
d: bool = Column(Boolean)
Will resolve in MyPy as::
@reg.mapped
class MyClass:
# ...
a : Mapped[int]
b : Mapped[str]
c: Mapped[int]
d: Mapped[bool]
"""
assert isinstance(node, Var)
callee = None
if right_hand_expression is None:
if not isinstance(stmt.rvalue, CallExpr):
return None
right_hand_expression = stmt.rvalue
for column_arg in right_hand_expression.args[0:2]:
if isinstance(column_arg, CallExpr):
if isinstance(column_arg.callee, RefExpr):
# x = Column(String(50))
callee = column_arg.callee
type_args: Sequence[Expression] = column_arg.args
break
elif isinstance(column_arg, (NameExpr, MemberExpr)):
if isinstance(column_arg.node, TypeInfo):
# x = Column(String)
callee = column_arg
type_args = ()
break
else:
# x = Column(some_name, String), go to next argument
continue
elif isinstance(column_arg, (StrExpr,)):
# x = Column("name", String), go to next argument
continue
elif isinstance(column_arg, (LambdaExpr,)):
# x = Column("name", String, default=lambda: uuid.uuid4())
# go to next argument
continue
else:
assert False
if callee is None:
return None
if isinstance(callee.node, TypeInfo) and names.mro_has_id(
callee.node.mro, names.TYPEENGINE
):
python_type_for_type = extract_python_type_from_typeengine(
api, callee.node, type_args
)
if left_hand_explicit_type is not None:
return _infer_type_from_left_and_inferred_right(
api, node, left_hand_explicit_type, python_type_for_type
)
else:
return UnionType([python_type_for_type, NoneType()])
else:
# it's not TypeEngine, it's typically implicitly typed
# like ForeignKey. we can't infer from the right side.
return infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
def _infer_type_from_left_and_inferred_right(
api: SemanticAnalyzerPluginInterface,
node: Var,
left_hand_explicit_type: ProperType,
python_type_for_type: ProperType,
orig_left_hand_type: Optional[ProperType] = None,
orig_python_type_for_type: Optional[ProperType] = None,
) -> Optional[ProperType]:
"""Validate type when a left hand annotation is present and we also
could infer the right hand side::
attrname: SomeType = Column(SomeDBType)
"""
if orig_left_hand_type is None:
orig_left_hand_type = left_hand_explicit_type
if orig_python_type_for_type is None:
orig_python_type_for_type = python_type_for_type
if not is_subtype(left_hand_explicit_type, python_type_for_type):
effective_type = api.named_type(
names.NAMED_TYPE_SQLA_MAPPED, [orig_python_type_for_type]
)
msg = (
"Left hand assignment '{}: {}' not compatible "
"with ORM mapped expression of type {}"
)
util.fail(
api,
msg.format(
node.name,
format_type(orig_left_hand_type),
format_type(effective_type),
),
node,
)
return orig_left_hand_type
def _infer_collection_type_from_left_and_inferred_right(
api: SemanticAnalyzerPluginInterface,
node: Var,
left_hand_explicit_type: Instance,
python_type_for_type: Instance,
) -> Optional[ProperType]:
orig_left_hand_type = left_hand_explicit_type
orig_python_type_for_type = python_type_for_type
if left_hand_explicit_type.args:
left_hand_arg = get_proper_type(left_hand_explicit_type.args[0])
python_type_arg = get_proper_type(python_type_for_type.args[0])
else:
left_hand_arg = left_hand_explicit_type
python_type_arg = python_type_for_type
assert isinstance(left_hand_arg, (Instance, UnionType))
assert isinstance(python_type_arg, (Instance, UnionType))
return _infer_type_from_left_and_inferred_right(
api,
node,
left_hand_arg,
python_type_arg,
orig_left_hand_type=orig_left_hand_type,
orig_python_type_for_type=orig_python_type_for_type,
)
def infer_type_from_left_hand_type_only(
api: SemanticAnalyzerPluginInterface,
node: Var,
left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
"""Determine the type based on explicit annotation only.
if no annotation were present, note that we need one there to know
the type.
"""
if left_hand_explicit_type is None:
msg = (
"Can't infer type from ORM mapped expression "
"assigned to attribute '{}'; please specify a "
"Python type or "
"Mapped[<python type>] on the left hand side."
)
util.fail(api, msg.format(node.name), node)
return api.named_type(
names.NAMED_TYPE_SQLA_MAPPED, [AnyType(TypeOfAny.special_form)]
)
else:
# use type from the left hand side
return left_hand_explicit_type
def extract_python_type_from_typeengine(
api: SemanticAnalyzerPluginInterface,
node: TypeInfo,
type_args: Sequence[Expression],
) -> ProperType:
if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
first_arg = type_args[0]
if isinstance(first_arg, RefExpr) and isinstance(
first_arg.node, TypeInfo
):
for base_ in first_arg.node.mro:
if base_.fullname == "enum.Enum":
return Instance(first_arg.node, [])
# TODO: support other pep-435 types here
else:
return api.named_type(names.NAMED_TYPE_BUILTINS_STR, [])
assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), (
"could not extract Python type from node: %s" % node
)
type_engine_sym = api.lookup_fully_qualified_or_none(
"sqlalchemy.sql.type_api.TypeEngine"
)
assert type_engine_sym is not None and isinstance(
type_engine_sym.node, TypeInfo
)
type_engine = map_instance_to_supertype(
Instance(node, []),
type_engine_sym.node,
)
return get_proper_type(type_engine.args[-1])

View File

@@ -0,0 +1,253 @@
# ext/mypy/names.py
# Copyright (C) 2021 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
from mypy.nodes import ClassDef
from mypy.nodes import Expression
from mypy.nodes import FuncDef
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import SymbolNode
from mypy.nodes import TypeAlias
from mypy.nodes import TypeInfo
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.types import CallableType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import UnboundType
from ... import util
COLUMN: int = util.symbol("COLUMN") # type: ignore
RELATIONSHIP: int = util.symbol("RELATIONSHIP") # type: ignore
REGISTRY: int = util.symbol("REGISTRY") # type: ignore
COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY") # type: ignore
TYPEENGINE: int = util.symbol("TYPEENGNE") # type: ignore
MAPPED: int = util.symbol("MAPPED") # type: ignore
DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE") # type: ignore
DECLARATIVE_META: int = util.symbol("DECLARATIVE_META") # type: ignore
MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR") # type: ignore
COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY") # type: ignore
SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY") # type: ignore
COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY") # type: ignore
DECLARED_ATTR: int = util.symbol("DECLARED_ATTR") # type: ignore
MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY") # type: ignore
AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE") # type: ignore
AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE") # type: ignore
DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN") # type: ignore
QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION") # type: ignore
# names that must succeed with mypy.api.named_type
NAMED_TYPE_BUILTINS_OBJECT = "builtins.object"
NAMED_TYPE_BUILTINS_STR = "builtins.str"
NAMED_TYPE_BUILTINS_LIST = "builtins.list"
NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.attributes.Mapped"
_lookup: Dict[str, Tuple[int, Set[str]]] = {
"Column": (
COLUMN,
{
"sqlalchemy.sql.schema.Column",
"sqlalchemy.sql.Column",
},
),
"RelationshipProperty": (
RELATIONSHIP,
{
"sqlalchemy.orm.relationships.RelationshipProperty",
"sqlalchemy.orm.RelationshipProperty",
},
),
"registry": (
REGISTRY,
{
"sqlalchemy.orm.decl_api.registry",
"sqlalchemy.orm.registry",
},
),
"ColumnProperty": (
COLUMN_PROPERTY,
{
"sqlalchemy.orm.properties.ColumnProperty",
"sqlalchemy.orm.ColumnProperty",
},
),
"SynonymProperty": (
SYNONYM_PROPERTY,
{
"sqlalchemy.orm.descriptor_props.SynonymProperty",
"sqlalchemy.orm.SynonymProperty",
},
),
"CompositeProperty": (
COMPOSITE_PROPERTY,
{
"sqlalchemy.orm.descriptor_props.CompositeProperty",
"sqlalchemy.orm.CompositeProperty",
},
),
"MapperProperty": (
MAPPER_PROPERTY,
{
"sqlalchemy.orm.interfaces.MapperProperty",
"sqlalchemy.orm.MapperProperty",
},
),
"TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}),
"Mapped": (MAPPED, {"sqlalchemy.orm.attributes.Mapped"}),
"declarative_base": (
DECLARATIVE_BASE,
{
"sqlalchemy.ext.declarative.declarative_base",
"sqlalchemy.orm.declarative_base",
"sqlalchemy.orm.decl_api.declarative_base",
},
),
"DeclarativeMeta": (
DECLARATIVE_META,
{
"sqlalchemy.ext.declarative.DeclarativeMeta",
"sqlalchemy.orm.DeclarativeMeta",
"sqlalchemy.orm.decl_api.DeclarativeMeta",
},
),
"mapped": (
MAPPED_DECORATOR,
{
"sqlalchemy.orm.decl_api.registry.mapped",
"sqlalchemy.orm.registry.mapped",
},
),
"as_declarative": (
AS_DECLARATIVE,
{
"sqlalchemy.ext.declarative.as_declarative",
"sqlalchemy.orm.decl_api.as_declarative",
"sqlalchemy.orm.as_declarative",
},
),
"as_declarative_base": (
AS_DECLARATIVE_BASE,
{
"sqlalchemy.orm.decl_api.registry.as_declarative_base",
"sqlalchemy.orm.registry.as_declarative_base",
},
),
"declared_attr": (
DECLARED_ATTR,
{
"sqlalchemy.orm.decl_api.declared_attr",
"sqlalchemy.orm.declared_attr",
},
),
"declarative_mixin": (
DECLARATIVE_MIXIN,
{
"sqlalchemy.orm.decl_api.declarative_mixin",
"sqlalchemy.orm.declarative_mixin",
},
),
"query_expression": (
QUERY_EXPRESSION,
{"sqlalchemy.orm.query_expression"},
),
}
def has_base_type_id(info: TypeInfo, type_id: int) -> bool:
for mr in info.mro:
check_type_id, fullnames = _lookup.get(mr.name, (None, None))
if check_type_id == type_id:
break
else:
return False
if fullnames is None:
return False
return mr.fullname in fullnames
def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
for mr in mro:
check_type_id, fullnames = _lookup.get(mr.name, (None, None))
if check_type_id == type_id:
break
else:
return False
if fullnames is None:
return False
return mr.fullname in fullnames
def type_id_for_unbound_type(
type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
) -> Optional[int]:
sym = api.lookup_qualified(type_.name, type_)
if sym is not None:
if isinstance(sym.node, TypeAlias):
target_type = get_proper_type(sym.node.target)
if isinstance(target_type, Instance):
return type_id_for_named_node(target_type.type)
elif isinstance(sym.node, TypeInfo):
return type_id_for_named_node(sym.node)
return None
def type_id_for_callee(callee: Expression) -> Optional[int]:
if isinstance(callee, (MemberExpr, NameExpr)):
if isinstance(callee.node, FuncDef):
if callee.node.type and isinstance(callee.node.type, CallableType):
ret_type = get_proper_type(callee.node.type.ret_type)
if isinstance(ret_type, Instance):
return type_id_for_fullname(ret_type.type.fullname)
return None
elif isinstance(callee.node, TypeAlias):
target_type = get_proper_type(callee.node.target)
if isinstance(target_type, Instance):
return type_id_for_fullname(target_type.type.fullname)
elif isinstance(callee.node, TypeInfo):
return type_id_for_named_node(callee)
return None
def type_id_for_named_node(
node: Union[NameExpr, MemberExpr, SymbolNode]
) -> Optional[int]:
type_id, fullnames = _lookup.get(node.name, (None, None))
if type_id is None or fullnames is None:
return None
elif node.fullname in fullnames:
return type_id
else:
return None
def type_id_for_fullname(fullname: str) -> Optional[int]:
tokens = fullname.split(".")
immediate = tokens[-1]
type_id, fullnames = _lookup.get(immediate, (None, None))
if type_id is None or fullnames is None:
return None
elif fullname in fullnames:
return type_id
else:
return None

View File

@@ -0,0 +1,284 @@
# ext/mypy/plugin.py
# Copyright (C) 2021 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
Mypy plugin for SQLAlchemy ORM.
"""
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type as TypingType
from typing import Union
from mypy import nodes
from mypy.mro import calculate_mro
from mypy.mro import MroError
from mypy.nodes import Block
from mypy.nodes import ClassDef
from mypy.nodes import GDEF
from mypy.nodes import MypyFile
from mypy.nodes import NameExpr
from mypy.nodes import SymbolTable
from mypy.nodes import SymbolTableNode
from mypy.nodes import TypeInfo
from mypy.plugin import AttributeContext
from mypy.plugin import ClassDefContext
from mypy.plugin import DynamicClassDefContext
from mypy.plugin import Plugin
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import Type
from . import decl_class
from . import names
from . import util
class SQLAlchemyPlugin(Plugin):
def get_dynamic_class_hook(
self, fullname: str
) -> Optional[Callable[[DynamicClassDefContext], None]]:
if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
return _dynamic_class_hook
return None
def get_customize_class_mro_hook(
self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
return _fill_in_decorators
def get_class_decorator_hook(
self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
sym = self.lookup_fully_qualified(fullname)
if sym is not None and sym.node is not None:
type_id = names.type_id_for_named_node(sym.node)
if type_id is names.MAPPED_DECORATOR:
return _cls_decorator_hook
elif type_id in (
names.AS_DECLARATIVE,
names.AS_DECLARATIVE_BASE,
):
return _base_cls_decorator_hook
elif type_id is names.DECLARATIVE_MIXIN:
return _declarative_mixin_hook
return None
def get_metaclass_hook(
self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META:
# Set any classes that explicitly have metaclass=DeclarativeMeta
# as declarative so the check in `get_base_class_hook()` works
return _metaclass_cls_hook
return None
def get_base_class_hook(
self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
sym = self.lookup_fully_qualified(fullname)
if (
sym
and isinstance(sym.node, TypeInfo)
and util.has_declarative_base(sym.node)
):
return _base_cls_hook
return None
def get_attribute_hook(
self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
if fullname.startswith(
"sqlalchemy.orm.attributes.QueryableAttribute."
):
return _queryable_getattr_hook
return None
def get_additional_deps(
self, file: MypyFile
) -> List[Tuple[int, str, int]]:
return [
(10, "sqlalchemy.orm.attributes", -1),
(10, "sqlalchemy.orm.decl_api", -1),
]
def plugin(version: str) -> TypingType[SQLAlchemyPlugin]:
return SQLAlchemyPlugin
def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
"""Generate a declarative Base class when the declarative_base() function
is encountered."""
_add_globals(ctx)
cls = ClassDef(ctx.name, Block([]))
cls.fullname = ctx.api.qualified_name(ctx.name)
info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
cls.info = info
_set_declarative_metaclass(ctx.api, cls)
cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
util.set_is_base(cls_arg.node)
decl_class.scan_declarative_assignments_and_apply_types(
cls_arg.node.defn, ctx.api, is_mixin_scan=True
)
info.bases = [Instance(cls_arg.node, [])]
else:
obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
info.bases = [obj]
try:
calculate_mro(info)
except MroError:
util.fail(
ctx.api, "Not able to calculate MRO for declarative base", ctx.call
)
obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
info.bases = [obj]
info.fallback_to_any = True
ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
util.set_is_base(info)
def _fill_in_decorators(ctx: ClassDefContext) -> None:
for decorator in ctx.cls.decorators:
# set the ".fullname" attribute of a class decorator
# that is a MemberExpr. This causes the logic in
# semanal.py->apply_class_plugin_hooks to invoke the
# get_class_decorator_hook for our "registry.map_class()"
# and "registry.as_declarative_base()" methods.
# this seems like a bug in mypy that these decorators are otherwise
# skipped.
if (
isinstance(decorator, nodes.CallExpr)
and isinstance(decorator.callee, nodes.MemberExpr)
and decorator.callee.name == "as_declarative_base"
):
target = decorator.callee
elif (
isinstance(decorator, nodes.MemberExpr)
and decorator.name == "mapped"
):
target = decorator
else:
continue
assert isinstance(target.expr, NameExpr)
sym = ctx.api.lookup_qualified(
target.expr.name, target, suppress_errors=True
)
if sym and sym.node:
sym_type = get_proper_type(sym.type)
if isinstance(sym_type, Instance):
target.fullname = f"{sym_type.type.fullname}.{target.name}"
else:
# if the registry is in the same file as where the
# decorator is used, it might not have semantic
# symbols applied and we can't get a fully qualified
# name or an inferred type, so we are actually going to
# flag an error in this case that they need to annotate
# it. The "registry" is declared just
# once (or few times), so they have to just not use
# type inference for its assignment in this one case.
util.fail(
ctx.api,
"Class decorator called %s(), but we can't "
"tell if it's from an ORM registry. Please "
"annotate the registry assignment, e.g. "
"my_registry: registry = registry()" % target.name,
sym.node,
)
def _cls_decorator_hook(ctx: ClassDefContext) -> None:
_add_globals(ctx)
assert isinstance(ctx.reason, nodes.MemberExpr)
expr = ctx.reason.expr
assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var)
node_type = get_proper_type(expr.node.type)
assert (
isinstance(node_type, Instance)
and names.type_id_for_named_node(node_type.type) is names.REGISTRY
)
decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
def _base_cls_decorator_hook(ctx: ClassDefContext) -> None:
_add_globals(ctx)
cls = ctx.cls
_set_declarative_metaclass(ctx.api, cls)
util.set_is_base(ctx.cls.info)
decl_class.scan_declarative_assignments_and_apply_types(
cls, ctx.api, is_mixin_scan=True
)
def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
_add_globals(ctx)
util.set_is_base(ctx.cls.info)
decl_class.scan_declarative_assignments_and_apply_types(
ctx.cls, ctx.api, is_mixin_scan=True
)
def _metaclass_cls_hook(ctx: ClassDefContext) -> None:
util.set_is_base(ctx.cls.info)
def _base_cls_hook(ctx: ClassDefContext) -> None:
_add_globals(ctx)
decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
# how do I....tell it it has no attribute of a certain name?
# can't find any Type that seems to match that
return ctx.default_attr_type
def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
"""Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space
for all class defs
"""
util.add_global(ctx, "sqlalchemy.orm.attributes", "Mapped", "__sa_Mapped")
def _set_declarative_metaclass(
api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
) -> None:
info = target_cls.info
sym = api.lookup_fully_qualified_or_none(
"sqlalchemy.orm.decl_api.DeclarativeMeta"
)
assert sym is not None and isinstance(sym.node, TypeInfo)
info.declared_metaclass = info.metaclass_type = Instance(sym.node, [])

View File

@@ -0,0 +1,305 @@
import re
from typing import Any
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import overload
from typing import Tuple
from typing import Type as TypingType
from typing import TypeVar
from typing import Union
from mypy.nodes import ARG_POS
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import CLASSDEF_NO_INFO
from mypy.nodes import Context
from mypy.nodes import Expression
from mypy.nodes import IfStmt
from mypy.nodes import JsonDict
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import Statement
from mypy.nodes import SymbolTableNode
from mypy.nodes import TypeInfo
from mypy.plugin import ClassDefContext
from mypy.plugin import DynamicClassDefContext
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import deserialize_and_fixup_type
from mypy.typeops import map_type_from_supertype
from mypy.types import Instance
from mypy.types import NoneType
from mypy.types import Type
from mypy.types import TypeVarType
from mypy.types import UnboundType
from mypy.types import UnionType
_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
class SQLAlchemyAttribute:
def __init__(
self,
name: str,
line: int,
column: int,
typ: Optional[Type],
info: TypeInfo,
) -> None:
self.name = name
self.line = line
self.column = column
self.type = typ
self.info = info
def serialize(self) -> JsonDict:
assert self.type
return {
"name": self.name,
"line": self.line,
"column": self.column,
"type": self.type.serialize(),
}
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
"""Expands type vars in the context of a subtype when an attribute is
inherited from a generic super type.
"""
if not isinstance(self.type, TypeVarType):
return
self.type = map_type_from_supertype(self.type, sub_type, self.info)
@classmethod
def deserialize(
cls,
info: TypeInfo,
data: JsonDict,
api: SemanticAnalyzerPluginInterface,
) -> "SQLAlchemyAttribute":
data = data.copy()
typ = deserialize_and_fixup_type(data.pop("type"), api)
return cls(typ=typ, info=info, **data)
def name_is_dunder(name):
return bool(re.match(r"^__.+?__$", name))
def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
info.metadata.setdefault("sqlalchemy", {})[key] = data
def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
return info.metadata.get("sqlalchemy", {}).get(key, None)
def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
if info.mro:
for base in info.mro:
metadata = _get_info_metadata(base, key)
if metadata is not None:
return metadata
return None
def establish_as_sqlalchemy(info: TypeInfo) -> None:
info.metadata.setdefault("sqlalchemy", {})
def set_is_base(info: TypeInfo) -> None:
_set_info_metadata(info, "is_base", True)
def get_is_base(info: TypeInfo) -> bool:
is_base = _get_info_metadata(info, "is_base")
return is_base is True
def has_declarative_base(info: TypeInfo) -> bool:
is_base = _get_info_mro_metadata(info, "is_base")
return is_base is True
def set_has_table(info: TypeInfo) -> None:
_set_info_metadata(info, "has_table", True)
def get_has_table(info: TypeInfo) -> bool:
is_base = _get_info_metadata(info, "has_table")
return is_base is True
def get_mapped_attributes(
info: TypeInfo, api: SemanticAnalyzerPluginInterface
) -> Optional[List[SQLAlchemyAttribute]]:
mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata(
info, "mapped_attributes"
)
if mapped_attributes is None:
return None
attributes: List[SQLAlchemyAttribute] = []
for data in mapped_attributes:
attr = SQLAlchemyAttribute.deserialize(info, data, api)
attr.expand_typevar_from_subtype(info)
attributes.append(attr)
return attributes
def set_mapped_attributes(
info: TypeInfo, attributes: List[SQLAlchemyAttribute]
) -> None:
_set_info_metadata(
info,
"mapped_attributes",
[attribute.serialize() for attribute in attributes],
)
def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
msg = "[SQLAlchemy Mypy plugin] %s" % msg
return api.fail(msg, ctx)
def add_global(
ctx: Union[ClassDefContext, DynamicClassDefContext],
module: str,
symbol_name: str,
asname: str,
) -> None:
module_globals = ctx.api.modules[ctx.api.cur_mod_id].names
if asname not in module_globals:
lookup_sym: SymbolTableNode = ctx.api.modules[module].names[
symbol_name
]
module_globals[asname] = lookup_sym
@overload
def get_callexpr_kwarg(
callexpr: CallExpr, name: str, *, expr_types: None = ...
) -> Optional[Union[CallExpr, NameExpr]]:
...
@overload
def get_callexpr_kwarg(
callexpr: CallExpr,
name: str,
*,
expr_types: Tuple[TypingType[_TArgType], ...]
) -> Optional[_TArgType]:
...
def get_callexpr_kwarg(
callexpr: CallExpr,
name: str,
*,
expr_types: Optional[Tuple[TypingType[Any], ...]] = None
) -> Optional[Any]:
try:
arg_idx = callexpr.arg_names.index(name)
except ValueError:
return None
kwarg = callexpr.args[arg_idx]
if isinstance(
kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
):
return kwarg
return None
def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
for stmt in stmts:
if (
isinstance(stmt, IfStmt)
and isinstance(stmt.expr[0], NameExpr)
and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
):
for substmt in stmt.body[0].body:
yield substmt
else:
yield stmt
def unbound_to_instance(
api: SemanticAnalyzerPluginInterface, typ: Type
) -> Type:
"""Take the UnboundType that we seem to get as the ret_type from a FuncDef
and convert it into an Instance/TypeInfo kind of structure that seems
to work as the left-hand type of an AssignmentStatement.
"""
if not isinstance(typ, UnboundType):
return typ
# TODO: figure out a more robust way to check this. The node is some
# kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
# but I cant figure out how to get them to match up
if typ.name == "Optional":
# convert from "Optional?" to the more familiar
# UnionType[..., NoneType()]
return unbound_to_instance(
api,
UnionType(
[unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
+ [NoneType()]
),
)
node = api.lookup_qualified(typ.name, typ)
if (
node is not None
and isinstance(node, SymbolTableNode)
and isinstance(node.node, TypeInfo)
):
bound_type = node.node
return Instance(
bound_type,
[
unbound_to_instance(api, arg)
if isinstance(arg, UnboundType)
else arg
for arg in typ.args
],
)
else:
return typ
def info_for_cls(
cls: ClassDef, api: SemanticAnalyzerPluginInterface
) -> Optional[TypeInfo]:
if cls.info is CLASSDEF_NO_INFO:
sym = api.lookup_qualified(cls.name, cls)
if sym is None:
return None
assert sym and isinstance(sym.node, TypeInfo)
return sym.node
return cls.info
def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
column_descriptor = NameExpr("__sa_Mapped")
column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
member_expr = MemberExpr(column_descriptor, "_empty_constructor")
return CallExpr(
member_expr,
[expr],
[ARG_POS],
["arg1"],
)

View File

@@ -0,0 +1,388 @@
# ext/orderinglist.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""A custom list that manages index/position information for contained
elements.
:author: Jason Kirtland
``orderinglist`` is a helper for mutable ordered relationships. It will
intercept list operations performed on a :func:`_orm.relationship`-managed
collection and
automatically synchronize changes in list position onto a target scalar
attribute.
Example: A ``slide`` table, where each row refers to zero or more entries
in a related ``bullet`` table. The bullets within a slide are
displayed in order based on the value of the ``position`` column in the
``bullet`` table. As entries are reordered in memory, the value of the
``position`` attribute should be updated to reflect the new sort order::
Base = declarative_base()
class Slide(Base):
__tablename__ = 'slide'
id = Column(Integer, primary_key=True)
name = Column(String)
bullets = relationship("Bullet", order_by="Bullet.position")
class Bullet(Base):
__tablename__ = 'bullet'
id = Column(Integer, primary_key=True)
slide_id = Column(Integer, ForeignKey('slide.id'))
position = Column(Integer)
text = Column(String)
The standard relationship mapping will produce a list-like attribute on each
``Slide`` containing all related ``Bullet`` objects,
but coping with changes in ordering is not handled automatically.
When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position``
attribute will remain unset until manually assigned. When the ``Bullet``
is inserted into the middle of the list, the following ``Bullet`` objects
will also need to be renumbered.
The :class:`.OrderingList` object automates this task, managing the
``position`` attribute on all ``Bullet`` objects in the collection. It is
constructed using the :func:`.ordering_list` factory::
from sqlalchemy.ext.orderinglist import ordering_list
Base = declarative_base()
class Slide(Base):
__tablename__ = 'slide'
id = Column(Integer, primary_key=True)
name = Column(String)
bullets = relationship("Bullet", order_by="Bullet.position",
collection_class=ordering_list('position'))
class Bullet(Base):
__tablename__ = 'bullet'
id = Column(Integer, primary_key=True)
slide_id = Column(Integer, ForeignKey('slide.id'))
position = Column(Integer)
text = Column(String)
With the above mapping the ``Bullet.position`` attribute is managed::
s = Slide()
s.bullets.append(Bullet())
s.bullets.append(Bullet())
s.bullets[1].position
>>> 1
s.bullets.insert(1, Bullet())
s.bullets[2].position
>>> 2
The :class:`.OrderingList` construct only works with **changes** to a
collection, and not the initial load from the database, and requires that the
list be sorted when loaded. Therefore, be sure to specify ``order_by`` on the
:func:`_orm.relationship` against the target ordering attribute, so that the
ordering is correct when first loaded.
.. warning::
:class:`.OrderingList` only provides limited functionality when a primary
key column or unique column is the target of the sort. Operations
that are unsupported or are problematic include:
* two entries must trade values. This is not supported directly in the
case of a primary key or unique constraint because it means at least
one row would need to be temporarily removed first, or changed to
a third, neutral value while the switch occurs.
* an entry must be deleted in order to make room for a new entry.
SQLAlchemy's unit of work performs all INSERTs before DELETEs within a
single flush. In the case of a primary key, it will trade
an INSERT/DELETE of the same primary key for an UPDATE statement in order
to lessen the impact of this limitation, however this does not take place
for a UNIQUE column.
A future feature will allow the "DELETE before INSERT" behavior to be
possible, alleviating this limitation, though this feature will require
explicit configuration at the mapper level for sets of columns that
are to be handled in this way.
:func:`.ordering_list` takes the name of the related object's ordering
attribute as an argument. By default, the zero-based integer index of the
object's position in the :func:`.ordering_list` is synchronized with the
ordering attribute: index 0 will get position 0, index 1 position 1, etc. To
start numbering at 1 or some other integer, provide ``count_from=1``.
"""
from ..orm.collections import collection
from ..orm.collections import collection_adapter
__all__ = ["ordering_list"]
def ordering_list(attr, count_from=None, **kw):
"""Prepares an :class:`OrderingList` factory for use in mapper definitions.
Returns an object suitable for use as an argument to a Mapper
relationship's ``collection_class`` option. e.g.::
from sqlalchemy.ext.orderinglist import ordering_list
class Slide(Base):
__tablename__ = 'slide'
id = Column(Integer, primary_key=True)
name = Column(String)
bullets = relationship("Bullet", order_by="Bullet.position",
collection_class=ordering_list('position'))
:param attr:
Name of the mapped attribute to use for storage and retrieval of
ordering information
:param count_from:
Set up an integer-based ordering, starting at ``count_from``. For
example, ``ordering_list('pos', count_from=1)`` would create a 1-based
list in SQL, storing the value in the 'pos' column. Ignored if
``ordering_func`` is supplied.
Additional arguments are passed to the :class:`.OrderingList` constructor.
"""
kw = _unsugar_count_from(count_from=count_from, **kw)
return lambda: OrderingList(attr, **kw)
# Ordering utility functions
def count_from_0(index, collection):
"""Numbering function: consecutive integers starting at 0."""
return index
def count_from_1(index, collection):
"""Numbering function: consecutive integers starting at 1."""
return index + 1
def count_from_n_factory(start):
"""Numbering function: consecutive integers starting at arbitrary start."""
def f(index, collection):
return index + start
try:
f.__name__ = "count_from_%i" % start
except TypeError:
pass
return f
def _unsugar_count_from(**kw):
"""Builds counting functions from keyword arguments.
Keyword argument filter, prepares a simple ``ordering_func`` from a
``count_from`` argument, otherwise passes ``ordering_func`` on unchanged.
"""
count_from = kw.pop("count_from", None)
if kw.get("ordering_func", None) is None and count_from is not None:
if count_from == 0:
kw["ordering_func"] = count_from_0
elif count_from == 1:
kw["ordering_func"] = count_from_1
else:
kw["ordering_func"] = count_from_n_factory(count_from)
return kw
class OrderingList(list):
"""A custom list that manages position information for its children.
The :class:`.OrderingList` object is normally set up using the
:func:`.ordering_list` factory function, used in conjunction with
the :func:`_orm.relationship` function.
"""
def __init__(
self, ordering_attr=None, ordering_func=None, reorder_on_append=False
):
"""A custom list that manages position information for its children.
``OrderingList`` is a ``collection_class`` list implementation that
syncs position in a Python list with a position attribute on the
mapped objects.
This implementation relies on the list starting in the proper order,
so be **sure** to put an ``order_by`` on your relationship.
:param ordering_attr:
Name of the attribute that stores the object's order in the
relationship.
:param ordering_func: Optional. A function that maps the position in
the Python list to a value to store in the
``ordering_attr``. Values returned are usually (but need not be!)
integers.
An ``ordering_func`` is called with two positional parameters: the
index of the element in the list, and the list itself.
If omitted, Python list indexes are used for the attribute values.
Two basic pre-built numbering functions are provided in this module:
``count_from_0`` and ``count_from_1``. For more exotic examples
like stepped numbering, alphabetical and Fibonacci numbering, see
the unit tests.
:param reorder_on_append:
Default False. When appending an object with an existing (non-None)
ordering value, that value will be left untouched unless
``reorder_on_append`` is true. This is an optimization to avoid a
variety of dangerous unexpected database writes.
SQLAlchemy will add instances to the list via append() when your
object loads. If for some reason the result set from the database
skips a step in the ordering (say, row '1' is missing but you get
'2', '3', and '4'), reorder_on_append=True would immediately
renumber the items to '1', '2', '3'. If you have multiple sessions
making changes, any of whom happen to load this collection even in
passing, all of the sessions would try to "clean up" the numbering
in their commits, possibly causing all but one to fail with a
concurrent modification error.
Recommend leaving this with the default of False, and just call
``reorder()`` if you're doing ``append()`` operations with
previously ordered instances or when doing some housekeeping after
manual sql operations.
"""
self.ordering_attr = ordering_attr
if ordering_func is None:
ordering_func = count_from_0
self.ordering_func = ordering_func
self.reorder_on_append = reorder_on_append
# More complex serialization schemes (multi column, e.g.) are possible by
# subclassing and reimplementing these two methods.
def _get_order_value(self, entity):
return getattr(entity, self.ordering_attr)
def _set_order_value(self, entity, value):
setattr(entity, self.ordering_attr, value)
def reorder(self):
"""Synchronize ordering for the entire collection.
Sweeps through the list and ensures that each object has accurate
ordering information set.
"""
for index, entity in enumerate(self):
self._order_entity(index, entity, True)
# As of 0.5, _reorder is no longer semi-private
_reorder = reorder
def _order_entity(self, index, entity, reorder=True):
have = self._get_order_value(entity)
# Don't disturb existing ordering if reorder is False
if have is not None and not reorder:
return
should_be = self.ordering_func(index, self)
if have != should_be:
self._set_order_value(entity, should_be)
def append(self, entity):
super(OrderingList, self).append(entity)
self._order_entity(len(self) - 1, entity, self.reorder_on_append)
def _raw_append(self, entity):
"""Append without any ordering behavior."""
super(OrderingList, self).append(entity)
_raw_append = collection.adds(1)(_raw_append)
def insert(self, index, entity):
super(OrderingList, self).insert(index, entity)
self._reorder()
def remove(self, entity):
super(OrderingList, self).remove(entity)
adapter = collection_adapter(self)
if adapter and adapter._referenced_by_owner:
self._reorder()
def pop(self, index=-1):
entity = super(OrderingList, self).pop(index)
self._reorder()
return entity
def __setitem__(self, index, entity):
if isinstance(index, slice):
step = index.step or 1
start = index.start or 0
if start < 0:
start += len(self)
stop = index.stop or len(self)
if stop < 0:
stop += len(self)
for i in range(start, stop, step):
self.__setitem__(i, entity[i])
else:
self._order_entity(index, entity, True)
super(OrderingList, self).__setitem__(index, entity)
def __delitem__(self, index):
super(OrderingList, self).__delitem__(index)
self._reorder()
def __setslice__(self, start, end, values):
super(OrderingList, self).__setslice__(start, end, values)
self._reorder()
def __delslice__(self, start, end):
super(OrderingList, self).__delslice__(start, end)
self._reorder()
def __reduce__(self):
return _reconstitute, (self.__class__, self.__dict__, list(self))
for func_name, func in list(locals().items()):
if (
callable(func)
and func.__name__ == func_name
and not func.__doc__
and hasattr(list, func_name)
):
func.__doc__ = getattr(list, func_name).__doc__
del func_name, func
def _reconstitute(cls, dict_, items):
"""Reconstitute an :class:`.OrderingList`.
This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for
unpickling :class:`.OrderingList` objects.
"""
obj = cls.__new__(cls)
obj.__dict__.update(dict_)
list.extend(obj, items)
return obj

View File

@@ -0,0 +1,177 @@
# ext/serializer.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Serializer/Deserializer objects for usage with SQLAlchemy query structures,
allowing "contextual" deserialization.
Any SQLAlchemy query structure, either based on sqlalchemy.sql.*
or sqlalchemy.orm.* can be used. The mappers, Tables, Columns, Session
etc. which are referenced by the structure are not persisted in serialized
form, but are instead re-associated with the query structure
when it is deserialized.
Usage is nearly the same as that of the standard Python pickle module::
from sqlalchemy.ext.serializer import loads, dumps
metadata = MetaData(bind=some_engine)
Session = scoped_session(sessionmaker())
# ... define mappers
query = Session.query(MyClass).
filter(MyClass.somedata=='foo').order_by(MyClass.sortkey)
# pickle the query
serialized = dumps(query)
# unpickle. Pass in metadata + scoped_session
query2 = loads(serialized, metadata, Session)
print query2.all()
Similar restrictions as when using raw pickle apply; mapped classes must be
themselves be pickleable, meaning they are importable from a module-level
namespace.
The serializer module is only appropriate for query structures. It is not
needed for:
* instances of user-defined classes. These contain no references to engines,
sessions or expression constructs in the typical case and can be serialized
directly.
* Table metadata that is to be loaded entirely from the serialized structure
(i.e. is not already declared in the application). Regular
pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object,
typically one which was reflected from an existing database at some previous
point in time. The serializer module is specifically for the opposite case,
where the Table metadata is already present in memory.
"""
import re
from .. import Column
from .. import Table
from ..engine import Engine
from ..orm import class_mapper
from ..orm.interfaces import MapperProperty
from ..orm.mapper import Mapper
from ..orm.session import Session
from ..util import b64decode
from ..util import b64encode
from ..util import byte_buffer
from ..util import pickle
from ..util import text_type
__all__ = ["Serializer", "Deserializer", "dumps", "loads"]
def Serializer(*args, **kw):
pickler = pickle.Pickler(*args, **kw)
def persistent_id(obj):
# print "serializing:", repr(obj)
if isinstance(obj, Mapper) and not obj.non_primary:
id_ = "mapper:" + b64encode(pickle.dumps(obj.class_))
elif isinstance(obj, MapperProperty) and not obj.parent.non_primary:
id_ = (
"mapperprop:"
+ b64encode(pickle.dumps(obj.parent.class_))
+ ":"
+ obj.key
)
elif isinstance(obj, Table):
if "parententity" in obj._annotations:
id_ = "mapper_selectable:" + b64encode(
pickle.dumps(obj._annotations["parententity"].class_)
)
else:
id_ = "table:" + text_type(obj.key)
elif isinstance(obj, Column) and isinstance(obj.table, Table):
id_ = (
"column:" + text_type(obj.table.key) + ":" + text_type(obj.key)
)
elif isinstance(obj, Session):
id_ = "session:"
elif isinstance(obj, Engine):
id_ = "engine:"
else:
return None
return id_
pickler.persistent_id = persistent_id
return pickler
our_ids = re.compile(
r"(mapperprop|mapper|mapper_selectable|table|column|"
r"session|attribute|engine):(.*)"
)
def Deserializer(file, metadata=None, scoped_session=None, engine=None):
unpickler = pickle.Unpickler(file)
def get_engine():
if engine:
return engine
elif scoped_session and scoped_session().bind:
return scoped_session().bind
elif metadata and metadata.bind:
return metadata.bind
else:
return None
def persistent_load(id_):
m = our_ids.match(text_type(id_))
if not m:
return None
else:
type_, args = m.group(1, 2)
if type_ == "attribute":
key, clsarg = args.split(":")
cls = pickle.loads(b64decode(clsarg))
return getattr(cls, key)
elif type_ == "mapper":
cls = pickle.loads(b64decode(args))
return class_mapper(cls)
elif type_ == "mapper_selectable":
cls = pickle.loads(b64decode(args))
return class_mapper(cls).__clause_element__()
elif type_ == "mapperprop":
mapper, keyname = args.split(":")
cls = pickle.loads(b64decode(mapper))
return class_mapper(cls).attrs[keyname]
elif type_ == "table":
return metadata.tables[args]
elif type_ == "column":
table, colname = args.split(":")
return metadata.tables[table].c[colname]
elif type_ == "session":
return scoped_session()
elif type_ == "engine":
return get_engine()
else:
raise Exception("Unknown token: %s" % type_)
unpickler.persistent_load = persistent_load
return unpickler
def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL):
buf = byte_buffer()
pickler = Serializer(buf, protocol)
pickler.dump(obj)
return buf.getvalue()
def loads(data, metadata=None, scoped_session=None, engine=None):
buf = byte_buffer(data)
unpickler = Deserializer(buf, metadata, scoped_session, engine)
return unpickler.load()