diff options
Diffstat (limited to 'archivist/closure.py')
-rw-r--r-- | archivist/closure.py | 92 |
1 files changed, 92 insertions, 0 deletions
diff --git a/archivist/closure.py b/archivist/closure.py new file mode 100644 index 0000000..01fdc75 --- /dev/null +++ b/archivist/closure.py @@ -0,0 +1,92 @@ +from playhouse.sqlite_ext import VirtualModel, VirtualIntegerField, VirtualCharField + +def ClosureTable(model_class, referencing_class = None, foreign_key=None, id_column = None): + """Model factory for the transitive closure extension.""" + if referencing_class is None: + referencing_class = model_class + + if foreign_key is None: + for field_obj in model_class._meta.rel.values(): + if field_obj.rel_model is model_class: + foreign_key = field_obj + break + else: + raise ValueError('Unable to find self-referential foreign key.') + + primary_key = model_class._meta.primary_key + + if id_column is None: + id_column = primary_key + + class BaseClosureTable(VirtualModel): + depth = VirtualIntegerField() + id = VirtualIntegerField() + idcolumn = VirtualCharField() + parentcolumn = VirtualCharField() + root = VirtualIntegerField() + tablename = VirtualCharField() + + class Meta: + extension_module = 'transitive_closure' + + @classmethod + def descendants(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(primary_key == cls.id)) + .where(cls.root == node) + .naive()) + if depth is not None: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def ancestors(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(primary_key == cls.root)) + .where(cls.id == node) + .naive()) + if depth: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def siblings(cls, node, include_node=False): + if referencing_class is model_class: + # self-join + fk_value = node._data.get(foreign_key.name) + query = model_class.select().where(foreign_key == fk_value) + else: + # siblings as given in reference_class + siblings = (referencing_class + .select(id_column) + .join(cls, on=(foreign_key == cls.root)) + .where((cls.id == node) & (cls.depth == 1))) + + # the according models + query = (model_class + .select() + .where(primary_key << siblings) + .naive()) + + if not include_node: + query = query.where(primary_key != node) + + return query + + class Meta: + database = referencing_class._meta.database + extension_options = { + 'tablename': referencing_class._meta.db_table, + 'idcolumn': id_column.db_column, + 'parentcolumn': foreign_key.db_column} + primary_key = False + + name = '%sClosure' % model_class.__name__ + return type(name, (BaseClosureTable,), {'Meta': Meta, '__module__': __name__}) + |