1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
|
from playhouse.sqlite_ext import VirtualModel, VirtualIntegerField, VirtualCharField
def ClosureTable(model_class, referencing_class=None, foreign_key=None,
                 id_column=None):
    """Model factory for the transitive closure extension.

    Builds a ``VirtualModel`` subclass named ``'<ModelName>Closure'`` whose
    ``Meta`` configures SQLite's ``transitive_closure`` virtual table over
    ``referencing_class``'s table.

    :param model_class: model whose rows are returned by the query helpers
        (``descendants``, ``ancestors``, ``siblings``).
    :param referencing_class: model whose table, database, id column and
        parent column are handed to the extension. Defaults to
        ``model_class`` (a plain self-referential adjacency list).
    :param foreign_key: field whose ``db_column`` is used as the parent
        column. When omitted, the first foreign key in
        ``model_class._meta.rel`` whose ``rel_model`` is ``model_class`` is
        used. NOTE: auto-detection only inspects ``model_class``, so pass
        ``foreign_key`` explicitly when ``referencing_class`` is a separate
        model.
    :param id_column: field whose ``db_column`` is used as the id column.
        Defaults to ``model_class``'s primary key.
    :returns: the generated closure-table model class.
    :raises ValueError: if ``foreign_key`` is omitted and ``model_class``
        has no self-referential foreign key.
    """
    if referencing_class is None:
        referencing_class = model_class

    if foreign_key is None:
        # for/else: the else clause runs only if no matching field was found.
        for field_obj in model_class._meta.rel.values():
            if field_obj.rel_model is model_class:
                foreign_key = field_obj
                break
        else:
            raise ValueError('Unable to find self-referential foreign key.')

    primary_key = model_class._meta.primary_key
    if id_column is None:
        id_column = primary_key

    class BaseClosureTable(VirtualModel):
        # Columns of the transitive_closure virtual table. ``root`` is the
        # starting node, ``id`` a node reached from it, ``depth`` the number
        # of hops; the remaining columns are configuration passed via Meta.
        depth = VirtualIntegerField()
        id = VirtualIntegerField()
        idcolumn = VirtualCharField()
        parentcolumn = VirtualCharField()
        root = VirtualIntegerField()
        tablename = VirtualCharField()

        class Meta:
            extension_module = 'transitive_closure'

        @classmethod
        def descendants(cls, node, depth=None, include_node=False):
            """Return ``model_class`` rows reachable from ``node``.

            Each row is annotated with its distance as ``depth``. If
            ``depth`` is given, only rows exactly that many levels below
            ``node`` are returned; otherwise ``node`` itself (depth 0) is
            excluded unless ``include_node`` is true.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(primary_key == cls.id))
                     .where(cls.root == node)
                     .naive())
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def ancestors(cls, node, depth=None, include_node=False):
            """Return ``model_class`` rows from which ``node`` is reachable.

            Each row is annotated with its distance as ``depth``. If
            ``depth`` is given, only rows exactly that many levels above
            ``node`` are returned; otherwise ``node`` itself (depth 0) is
            excluded unless ``include_node`` is true.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(primary_key == cls.root))
                     .where(cls.id == node)
                     .naive())
            # ``is not None`` (not truthiness) so that depth=0 is honoured,
            # matching descendants().
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def siblings(cls, node, include_node=False):
            """Return ``model_class`` rows that share ``node``'s parent.

            ``node`` is excluded unless ``include_node`` is true. In the
            self-referential case ``node`` must be a model instance, since
            its foreign-key value is read from ``node._data``.
            """
            if referencing_class is model_class:
                # Self-join: rows whose parent FK equals node's parent FK.
                fk_value = node._data.get(foreign_key.name)
                query = model_class.select().where(foreign_key == fk_value)
            else:
                # Ids (from referencing_class) of entries whose parent is the
                # closure root exactly one level above ``node``.
                sibling_ids = (referencing_class
                               .select(id_column)
                               .join(cls, on=(foreign_key == cls.root))
                               .where((cls.id == node) & (cls.depth == 1)))
                # ...then the matching model_class rows.
                query = (model_class
                         .select()
                         .where(primary_key << sibling_ids)
                         .naive())
            if not include_node:
                query = query.where(primary_key != node)
            return query

    class Meta:
        database = referencing_class._meta.database
        # Configuration consumed by the transitive_closure extension.
        extension_options = {
            'tablename': referencing_class._meta.db_table,
            'idcolumn': id_column.db_column,
            'parentcolumn': foreign_key.db_column}
        primary_key = False

    name = '%sClosure' % model_class.__name__
    return type(name, (BaseClosureTable,), {
        'Meta': Meta,
        '__module__': __name__})
|