summaryrefslogtreecommitdiff
path: root/archivist/peewee_ext.py
blob: 80bb6f4b045d60400ce432744246e16e17096222 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from playhouse.sqlite_ext import VirtualModel, VirtualIntegerField, VirtualCharField
from peewee import Field

from itertools import starmap
from functools import reduce
import operator as op

def sqlite_tuple_in(fields, values):
    """Build a row-value ``IN`` test that SQLite can execute.

    SQLite does not support ``(foo, bar) IN ((1, 2), (3, 4))``, so this
    constructs the equivalent
    ``(foo = 1 AND bar = 2) OR (foo = 3 AND bar = 4)`` expression by
    combining ``fields`` with each value tuple via ``==``, ``&`` and ``|``.

    :param fields: iterable of field expressions (any iterable, including a
        one-shot generator, is accepted).
    :param values: iterable of tuples, each holding exactly one value per
        field.
    :returns: the combined OR-of-ANDs expression.
    :raises ValueError: if ``fields`` or ``values`` is empty, or if a value
        tuple does not have the same length as ``fields``.
    """
    # Materialize once: ``fields`` is zipped against every value tuple, so a
    # one-shot iterator would be exhausted after the first tuple.
    fields = tuple(fields)
    if not fields:
        raise ValueError("sqlite_tuple_in() requires at least one field")

    def _row_matches(value_tuple):
        """Return the AND of ``field == value`` for one value tuple."""
        value_tuple = tuple(value_tuple)
        # zip() would silently drop surplus items and build a query that
        # matches the wrong rows, so reject length mismatches loudly.
        if len(value_tuple) != len(fields):
            raise ValueError(
                "expected %d values per tuple, got %d: %r"
                % (len(fields), len(value_tuple), value_tuple))
        return reduce(op.and_, starmap(op.eq, zip(fields, value_tuple)))

    subqueries = [_row_matches(value_tuple) for value_tuple in values]
    if not subqueries:
        # reduce() would otherwise fail with an opaque TypeError.
        raise ValueError("sqlite_tuple_in() requires at least one value tuple")
    return reduce(op.or_, subqueries)

class EnumField(Field):
    """Peewee field holding members of a Python ``Enum``.

    On the Python side values are members of ``enum_class``; on the database
    side the member's ``.value`` is stored (``db_field`` is declared as
    ``'enum'``).
    """
    db_field = 'enum'

    def __init__(self, enum_class, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Enum type used to validate and convert values in both directions.
        self.enum_class = enum_class

    def _enum_value(self, value):
        """Coerce ``value`` to a member of ``self.enum_class``.

        Tried in order:

        * a member of the enum itself (returned unchanged);
        * a member name, looked up after ``str.upper()``;
        * a member value, looked up directly (this lets non-integer values
          written by ``db_value`` round-trip);
        * anything ``int()`` can turn into a member value (e.g. ``"3"``).

        :raises ValueError: if ``value`` matches no member.
        """
        if isinstance(value, self.enum_class):
            return value

        if isinstance(value, str):
            try:
                return self.enum_class[value.upper()]
            except KeyError:
                pass

        try:
            return self.enum_class(value)
        except (TypeError, ValueError):
            pass

        try:
            return self.enum_class(int(value))
        except (TypeError, ValueError) as err:
            # int() raises TypeError for non-numeric objects; report every
            # unconvertible input uniformly as ValueError.
            raise ValueError("%r is not a valid %s" % (value, self.enum_class.__name__)) from err

    def db_value(self, value):
        """Convert a Python-side value into the stored ``.value``.

        ``None`` passes through; anything else is validated via
        ``_enum_value`` (raising ``ValueError`` if it is not a member).
        """
        if value is None:
            return value
        return self._enum_value(value).value

    def python_value(self, value):
        """Convert a database value back into an enum member (``None`` stays ``None``)."""
        if value is None:
            return value
        return self._enum_value(value)

def ClosureTable(model_class, referencing_class=None, foreign_key=None, id_column=None):
    """Model factory for the SQLite ``transitive_closure`` extension.

    Builds a ``VirtualModel`` subclass named ``<ModelName>Closure``, backed by
    the ``transitive_closure`` virtual-table module. It has class-method query
    helpers for walking the hierarchy whose parent links live in
    ``referencing_class``'s table.

    :param model_class: model whose rows are returned by the query helpers.
    :param referencing_class: model whose table holds the parent links;
        defaults to ``model_class`` (a self-referential tree).
    :param foreign_key: field pointing at the parent node. When omitted, the
        first foreign key on ``model_class`` whose ``rel_model`` is
        ``model_class`` itself is used.
    :param id_column: field identifying a node in ``referencing_class``'s
        table; defaults to ``model_class``'s primary key.
    :returns: the generated closure-table model class.
    :raises ValueError: if ``foreign_key`` is omitted and ``model_class`` has
        no self-referential foreign key.
    """
    if referencing_class is None:
        referencing_class = model_class

    if foreign_key is None:
        # NOTE(review): only ``model_class`` is searched; when a separate
        # ``referencing_class`` holds the parent links, pass ``foreign_key``
        # explicitly.
        for field_obj in model_class._meta.rel.values():
            if field_obj.rel_model is model_class:
                foreign_key = field_obj
                break
        else:
            raise ValueError('Unable to find self-referential foreign key.')

    primary_key = model_class._meta.primary_key

    if id_column is None:
        id_column = primary_key

    class BaseClosureTable(VirtualModel):
        # Columns exposed by the ``transitive_closure`` virtual table.
        depth = VirtualIntegerField()
        id = VirtualIntegerField()
        idcolumn = VirtualCharField()
        parentcolumn = VirtualCharField()
        root = VirtualIntegerField()
        tablename = VirtualCharField()

        class Meta:
            extension_module = 'transitive_closure'

        @classmethod
        def descendants(cls, node, depth=None, include_node=False):
            """Select ``model_class`` rows below ``node`` in the hierarchy.

            Each row carries the closure depth, aliased as ``depth``.

            :param node: value matched against the closure ``root`` column.
            :param depth: if not None, return only rows at exactly this depth.
            :param include_node: when ``depth`` is None, also keep the
                depth-0 row (``node`` itself); otherwise only depth > 0.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(primary_key == cls.id))
                     .where(cls.root == node)
                     .naive())
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def ancestors(cls, node, depth=None, include_node=False):
            """Select ``model_class`` rows above ``node`` in the hierarchy.

            Each row carries the closure depth, aliased as ``depth``.

            :param node: value matched against the closure ``id`` column.
            :param depth: if not None, return only rows at exactly this depth.
            :param include_node: when ``depth`` is None, also keep the
                depth-0 row (``node`` itself); otherwise only depth > 0.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(primary_key == cls.root))
                     .where(cls.id == node)
                     .naive())
            # Compare against None, as ``descendants`` does. A plain truth
            # test would treat depth=0 as "no depth filter".
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def siblings(cls, node, include_node=False):
            """Select ``model_class`` rows that share ``node``'s parent.

            :param node: in the self-referential case this must be a model
                instance, because its foreign-key value is read from
                ``node._data``.
            :param include_node: whether ``node`` itself is included.
            """
            if referencing_class is model_class:
                # Self-join: siblings are rows with the same parent FK value.
                fk_value = node._data.get(foreign_key.name)
                query = model_class.select().where(foreign_key == fk_value)
            else:
                # Closure rows with id == node at depth 1 have the parent as
                # ``root``. Collect every id in referencing_class pointing at
                # that parent.
                sibling_ids = (referencing_class
                               .select(id_column)
                               .join(cls, on=(foreign_key == cls.root))
                               .where((cls.id == node) & (cls.depth == 1)))

                # Load the corresponding model rows.
                query = (model_class
                         .select()
                         .where(primary_key << sibling_ids)
                         .naive())

            if not include_node:
                query = query.where(primary_key != node)

            return query

    class Meta:
        database = referencing_class._meta.database
        # Options handed to the virtual table: which table/columns encode the
        # parent links.
        extension_options = {
            'tablename': referencing_class._meta.db_table,
            'idcolumn': id_column.db_column,
            'parentcolumn': foreign_key.db_column}
        primary_key = False

    name = '%sClosure' % model_class.__name__
    return type(name, (BaseClosureTable,), {'Meta': Meta, '__module__': __name__})