import sqlalchemy as sa
import warnings
from collections import defaultdict
from py_meta_utils import McsArgs, McsInitArgs, Singleton, deep_getattr
from sqlalchemy.exc import SAWarning
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm.interfaces import MapperProperty
from typing import *
class ModelRegistry(metaclass=Singleton):
    """
    The SQLAlchemy Unchained model registry.

    Tracks model classes through three stages:

    - discovery (:meth:`register_new`), before ``type.__new__`` has been called
    - creation (:meth:`register`), after ``type.__new__`` but before
      ``type.__init__``
    - initialization (:meth:`finalize_mappings`), where the deferred
      ``__init__`` of the metaclass chain is finally run
    """

    # when False, every model passed to ``register`` is marked as initialized
    # immediately, regardless of its ``Meta.lazy_mapped`` setting
    enable_lazy_mapping: bool = False

    # not read anywhere in this class; presumably consumed by the model
    # metaclass / meta options -- TODO confirm
    default_primary_key_column: str = 'id'

    def __init__(self):
        from .base_model import BaseModel as Model

        # keyed by: full.base.model.module.name.BaseModelClassName
        # values are the base classes themselves
        # ordered by registration/discovery order, so the last class to be
        # inserted into this lookup is the correct base class to use
        self._base_model_classes: Dict[str, Type[Model]] = {}

        # all discovered models "classes", before type.__new__ has been called:
        # - keyed by model class name
        # - order of keys signifies model class discovery order at import time
        # - values are a lookup of:
        #   - keys: module name of this particular model class
        #   - values: McsArgs(model_mcs, name, bases, clsdict)
        # this dict is used for inspecting base classes when __new__ is
        # called on a model class that extends another of the same name
        self._registry: Dict[str, Dict[str, McsArgs]] = defaultdict(dict)

        # actual model classes awaiting initialization (after type.__new__ but
        # before type.__init__):
        # - keyed by model class name
        # - values are McsInitArgs(model_cls, name, bases, clsdict)
        # this lookup contains the knowledge of which version of a model class
        # should maybe get mapped (BaseModelMetaclass populates this dict
        # via the register method - insertion order of the correct version of a
        # model class by name is therefore determined by the import order of
        # bundles' models modules (essentially, by the RegisterModelsHook))
        self._models: Dict[str, McsInitArgs] = {}

        # like self._models, except its values are the relationships each model
        # class name expects on the other side
        # - keyed by model class name
        # - values are a dict:
        #   - keyed by the model name on the other side
        #   - value is the attribute expected to exist
        self._relationships: Dict[str, Dict[str, str]] = {}

        # which keys in self._models have already been initialized
        self._initialized: Set[str] = set()

    def register_base_model_class(self, model):
        """
        Record ``model`` as a base model class, keyed by its fully qualified
        name (``module.ClassName``).

        The most recently registered class is the one
        :meth:`_ensure_correct_base_model` treats as the correct base.
        """
        self._base_model_classes[model.__module__ + '.' + model.__name__] = model

    def _reset(self):
        """
        This method is for use by tests only!

        Replaces every internal lookup with a fresh, empty container.
        """
        self._base_model_classes = {}
        self._registry = defaultdict(dict)
        self._models = {}
        self._initialized = set()
        self._relationships = {}

    def register_new(self, mcs_args: McsArgs) -> None:
        """
        Record a newly discovered model class (before ``type.__new__``).

        If the class extends, or shares its name with, an already discovered
        model (and is not polymorphic), its bases are first rewritten in place
        into mixins by :meth:`_convert_bases_to_mixins`.
        """
        if self._should_convert_bases_to_mixins(mcs_args):
            self._convert_bases_to_mixins(mcs_args)
        self._registry[mcs_args.name][mcs_args.module] = mcs_args

    def register(self, mcs_init_args: McsInitArgs) -> None:
        """
        Record a created model class that is awaiting initialization.

        A later registration under the same name replaces the earlier one. The
        model is marked as already initialized unless lazy mapping is enabled
        on the registry *and* the model's ``Meta.lazy_mapped`` is truthy. Any
        non-empty ``Meta.relationships`` is stored for use by
        :meth:`should_initialize`.
        """
        self._models[mcs_init_args.name] = mcs_init_args
        if not self.enable_lazy_mapping or not mcs_init_args.cls.Meta.lazy_mapped:
            self._initialized.add(mcs_init_args.name)

        relationships = mcs_init_args.cls.Meta.relationships
        if relationships:
            self._relationships[mcs_init_args.name] = relationships

    def finalize_mappings(self) -> Dict[str, object]:
        """
        Returns a dictionary of the model classes that were finalized.
        Keyed by the names of the model classes, values are the classes themselves.

        :raises KeyError: if a discovered model name was never passed to
                          :meth:`register`, or from :meth:`should_initialize`.
        """
        from sqlalchemy_unchained.base_model_metaclass import DeclarativeMeta

        # this outer loop is needed to perform initializations in the order the
        # classes were originally discovered at import time
        for name in self._registry:
            mcs_init_args = self._models[name]
            if not self.should_initialize(mcs_init_args):
                continue

            # unpack into a distinct local so the loop variable is not rebound
            # (register() keys self._models by this very name)
            model_cls, model_name, bases, clsdict = mcs_init_args
            model_cls._pre_mcs_init()
            # run the __init__ that comes after DeclarativeMeta in the
            # metaclass MRO, which was deferred when the class was created
            super(DeclarativeMeta, model_cls).__init__(model_name, bases, clsdict)
            model_cls._post_mcs_init()
            self._initialized.add(model_name)

        return {name: self._models[name].cls for name in self._initialized}

    def should_initialize(self, mcs_init_args: McsInitArgs) -> bool:
        """
        Whether or not the model represented by ``mcs_init_args`` should be initialized.

        :return: ``False`` if the model is already initialized. ``True`` if it
                 declared no relationships, or if at least one related model
                 already has the attribute it declared for this model.
                 ``False`` otherwise.
        :raises KeyError: if a related model has not been registered, or if it
                          declared no ``relationships`` of its own.
        """
        model_name = mcs_init_args.name
        if model_name in self._initialized:
            return False

        if model_name not in self._relationships:
            return True

        with warnings.catch_warnings():
            # not all related classes will have been initialized yet, ie they
            # might still be non-mapped from SQLAlchemy's perspective, which is
            # safe to ignore here
            filter_re = r'Unmanaged access of declarative attribute \w+ from ' \
                        r'non-mapped class \w+'
            warnings.filterwarnings('ignore', filter_re, SAWarning)

            for related_model_name in self._relationships[model_name]:
                related_model = self._models[related_model_name].cls

                try:
                    other_side_relationships = \
                        self._relationships[related_model_name]
                except KeyError:
                    related_model_module = \
                        self._models[related_model_name].cls.__module__
                    raise KeyError(
                        'Incomplete `relationships` Meta declaration for '
                        f'{related_model_module}.{related_model_name} '
                        f'(missing {model_name})')

                if model_name not in other_side_relationships:
                    continue

                related_attr = other_side_relationships[model_name]
                if hasattr(related_model, related_attr):
                    return True

        # no related model exposes the expected attribute (yet); be explicit
        # rather than implicitly returning None from a ``-> bool`` method
        return False

    def _ensure_correct_base_model(self, mcs_args: McsArgs) -> None:
        """
        Makes sure the given ``mcs_args`` uses the correct BaseModel class.

        The correct base is the most recently registered base model class. If
        none of ``mcs_args.bases`` is a subclass of it, ``mcs_args`` is mutated
        in place: ``Meta`` is resolved from the existing bases into the
        clsdict, and the correct base is prepended to the bases. Does nothing
        when no base model class has been registered.
        """
        if not self._base_model_classes:
            return

        correct_base = list(self._base_model_classes.values())[-1]
        if any(issubclass(b, correct_base) for b in mcs_args.bases):
            return

        mcs_args.clsdict['Meta'] = \
            deep_getattr({}, mcs_args.bases, 'Meta', None)
        mcs_args.bases = tuple([correct_base] + list(mcs_args.bases))

    def _should_convert_bases_to_mixins(self, mcs_args: McsArgs) -> bool:
        """
        Figures out whether the base classes for the given ``mcs_args`` should be
        converted to mixins (as opposed to extending BaseModel)

        :return: ``False`` for polymorphic models. Otherwise ``True`` when any
                 base, or the class's own name, is already in the registry.
        """
        if mcs_args.Meta.polymorphic:
            return False

        if any(b.__name__ in self._registry for b in mcs_args.bases):
            return True

        return mcs_args.name in self._registry

    def _convert_bases_to_mixins(self, mcs_args: McsArgs) -> None:
        """
        For each base class in bases that the ModelRegistry knows about, create
        a replacement class containing the methods and attributes from the base
        class:

        - the mixin should only extend object (not db.Model)
        - if any of the attributes are MapperProperty instances (relationship,
          association_proxy, etc), then turn them into @declared_attr props

        Columns carrying foreign keys are wrapped in ``@declared_attr`` as
        well. ``mcs_args.bases`` is replaced in place with the new bases.
        """
        def _mixin_name(name):
            return name + '_FSQLAConvertedMixin'

        new_base_names = set()
        new_bases = []
        # walk the bases back-to-front and reverse again at the end, so the
        # resulting tuple keeps the original left-to-right precedence
        for b in reversed(mcs_args.bases):
            if b.__name__ not in self._registry:
                # not a registered model: keep the base as-is (once)
                if b not in new_bases:
                    new_bases.append(b)
                continue

            _, base_name, base_bases, base_clsdict = \
                self._registry[b.__name__][b.__module__]

            # carry over the registered model's own bases
            for bb in reversed(base_bases):
                if bb.__module__ + '.' + bb.__name__ in self._base_model_classes:
                    if bb not in new_bases:
                        new_bases.append(bb)
                elif (bb.__name__ not in new_base_names
                        and _mixin_name(bb.__name__) not in new_base_names):
                    new_base_names.add(bb.__name__)
                    new_bases.append(bb)

            clsdict = {}
            for attr, value in base_clsdict.items():
                if attr in {'__name__', '__qualname__'}:
                    continue

                has_fk = isinstance(value, sa.Column) and value.foreign_keys
                if has_fk or isinstance(value, MapperProperty):
                    # programmatically add a method wrapped with declared_attr
                    # to the new mixin class; exec is used so the generated
                    # function is really named ``attr`` (``attr`` is a key of
                    # a class body dict, never external input)
                    exec("""\
@declared_attr
def {attr}(self):
    return value
""".format(attr=attr),
                         {'value': value, 'declared_attr': declared_attr},
                         clsdict)
                else:
                    clsdict[attr] = value

            mixin_name = _mixin_name(base_name)
            new_bases.append(type(mixin_name, (object,), clsdict))
            new_base_names.add(mixin_name)

        mcs_args.bases = tuple(reversed(new_bases))
# public API of this module
__all__ = ['ModelRegistry']