forked from Netflix/dispatch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdatabase.py
More file actions
163 lines (124 loc) · 5.2 KB
/
Copy pathdatabase.py
File metadata and controls
163 lines (124 loc) · 5.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import re
from typing import Any, List
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import Query, sessionmaker
from sqlalchemy_filters import apply_pagination, apply_sort, apply_filters
from sqlalchemy_searchable import make_searchable
from sqlalchemy_searchable import search as search_db
from starlette.requests import Request
from dispatch.common.utils.composite_search import CompositeSearch
from .config import SQLALCHEMY_DATABASE_URI
# Module-wide SQLAlchemy engine built from the configured database URI
# (the setting is coerced to ``str`` before being handed to create_engine).
engine = create_engine(str(SQLALCHEMY_DATABASE_URI))
# Session factory bound to the engine above.
SessionLocal = sessionmaker(bind=engine)
def resolve_table_name(name):
    """Turn a CamelCase name into its lower-case, underscore-separated form.

    The name is cut immediately before every capital letter, empty pieces
    are discarded, and the remaining pieces are lower-cased and joined
    with ``_``.  A name without capitals is returned lower-cased as-is.
    """
    # Zero-width split: the lookahead matches the position in front of
    # each upper-case letter without consuming it.
    pieces = re.split("(?=[A-Z])", name)  # noqa
    # filter(None, ...) drops the empty leading piece produced when the
    # name itself starts with a capital letter.
    return "_".join(map(str.lower, filter(None, pieces)))
class CustomBase:
    """Base class that derives ``__tablename__`` from the class name."""

    @declared_attr
    def __tablename__(cls):
        # The table name is the class name run through resolve_table_name,
        # e.g. a CamelCase class name becomes snake_case.
        model_name = cls.__name__
        return resolve_table_name(model_name)
# Declarative base for the models; CustomBase supplies the derived
# ``__tablename__`` for every subclass.
Base = declarative_base(cls=CustomBase)
# Hand the shared metadata to sqlalchemy-searchable.
# NOTE(review): presumably this installs the library's full-text search
# support for tables on this metadata — confirm against the library docs.
make_searchable(Base.metadata)
def get_db(request: Request):
    """Return the database session stored on the request's ``state.db``."""
    request_state = request.state
    return request_state.db
def get_model_name_by_tablename(table_fullname: str) -> str:
    """Return the ``__name__`` of the model class mapped to the given table.

    The lookup is delegated to ``get_class_by_tablename``; any exception it
    raises for an unknown table propagates unchanged.
    """
    model_cls = get_class_by_tablename(table_fullname=table_fullname)
    return model_cls.__name__
def get_class_by_tablename(table_fullname: str) -> Any:
    """Look up the mapped class whose table matches the given name.

    The supplied name is first normalised with ``resolve_table_name`` and
    then compared against ``__table__.fullname`` of every entry in the
    declarative class registry.

    Raises:
        Exception: if no registered class is mapped to the resolved name.
    """
    mapped_name = resolve_table_name(table_fullname)

    # Only registry entries that expose ``__table__`` are considered.
    table_backed = (
        entry for entry in Base._decl_class_registry.values() if hasattr(entry, "__table__")
    )
    for model_cls in table_backed:
        if model_cls.__table__.fullname == mapped_name:
            return model_cls

    raise Exception(f"Incorrect tablename '{mapped_name}'. Check the name of your model.")
def paginate(query: Query, page: int, items_per_page: int):
    """Fetch one page of results plus the total row count.

    Args:
        query: Query to slice.
        page: 1-based page number; zero or negative values are treated as
            the first page.
        items_per_page: Number of rows per page (used for LIMIT and to
            compute OFFSET).

    Returns:
        A ``(items, total)`` tuple: the rows of the requested page and the
        count of the query with its ordering removed.
    """
    # Clamp the zero-based page index at 0 so a non-positive page number
    # cannot produce a negative OFFSET.
    page_index = page - 1 if page > 0 else 0
    window = query.limit(items_per_page).offset(page_index * items_per_page)
    items = window.all()
    # Ordering is irrelevant for counting, so it is stripped first.
    total = query.order_by(None).count()
    return items, total
def composite_search(*, db_session, query_str: str, models: List[Base]):
    """Run a single search for ``query_str`` across several models.

    A ``CompositeSearch`` is created for the session and models, asked to
    build a sorted query from the search string, and then executed; its
    result is returned as-is.
    """
    searcher = CompositeSearch(db_session, models)
    built_query = searcher.build_query(query_str, sort=True)
    return searcher.search(query=built_query)
def search(*, db_session, query_str: str, model: str):
    """Search a single model for ``query_str``.

    ``model`` is a table/model name resolved through
    ``get_class_by_tablename``; the resulting query is passed to
    sqlalchemy-searchable's ``search`` with ``sort=True``.
    """
    model_cls = get_class_by_tablename(model)
    base_query = db_session.query(model_cls)
    return search_db(base_query, query_str, sort=True)
def create_filter_spec(model, fields, ops, values):
    """Build a filter spec from parallel lists of fields, operators and values.

    Args:
        model: Model name used for fields that carry no ``table.`` prefix.
        fields: Field names; ``"table.field"`` targets the model mapped to
            ``table`` instead of ``model``.
        ops: Operators, one per field.
        values: Comparison values, one per field.

    Returns:
        ``{"and": [...]}`` wrapping one clause per (field, op, value)
        triple, or an empty list when any of the inputs is empty/None.
    """
    if not (fields and ops and values):
        return []

    clauses = []
    for field, op, value in zip(fields, ops, values):
        target_model, target_field = model, field
        # A dotted field refers to another model, which may require a join.
        if "." in field:
            related_table, target_field = field.split(".")
            target_model = get_model_name_by_tablename(related_table)
        clauses.append(
            {"model": target_model, "field": target_field, "op": op, "value": value}
        )

    # NOTE we default to AND filters
    return {"and": clauses} if clauses else clauses
def create_sort_spec(model, sort_by, descending):
    """Build the sort spec that is handed to ``apply_sort``.

    Args:
        model: Model name used for fields that carry no ``table.`` prefix.
        sort_by: Field names to sort on; ``"table.field"`` targets the
            model mapped to ``table`` instead of ``model``.
        descending: Flags parallel to ``sort_by``; a truthy flag means
            descending.  Missing flags (``None``, an empty list, or a list
            shorter than ``sort_by``) default to ascending, so a requested
            sort is never silently dropped.

    Returns:
        A list with one ``{"model", "field", "direction"}`` dict per entry
        of ``sort_by``; an empty list when ``sort_by`` is empty or None.
    """
    if not sort_by:
        return []

    flags = list(descending or ())
    sort_spec = []
    for position, field in enumerate(sort_by):
        # Fields without a matching flag fall back to ascending order.
        is_descending = position < len(flags) and flags[position]
        direction = "desc" if is_descending else "asc"

        target_model, target_field = model, field
        # A dotted field refers to another model, which may require a join.
        if "." in field:
            related_table, target_field = field.split(".")
            target_model = get_model_name_by_tablename(related_table)

        sort_spec.append(
            {"model": target_model, "field": target_field, "direction": direction}
        )
    return sort_spec
def get_all(*, db_session, model):
    """Return an unfiltered query over the model mapped to ``model``.

    ``model`` is a table/model name resolved through
    ``get_class_by_tablename``; the query object itself is returned, not
    its rows.
    """
    model_cls = get_class_by_tablename(model)
    return db_session.query(model_cls)
def search_filter_sort_paginate(
    db_session,
    model,
    query_str: str = None,
    page: int = 1,
    items_per_page: int = 5,
    sort_by: List[str] = None,
    descending: List[bool] = None,
    fields: List[str] = None,
    ops: List[str] = None,
    values: List[str] = None,
):
    """Search, filter, sort and paginate a model in one call.

    Args:
        db_session: Session used to build the query.
        model: Table/model name to query.
        query_str: Optional search string; when empty, all rows are the
            starting point.
        page: Page number passed to ``apply_pagination``.
        items_per_page: Page size; ``-1`` is forwarded as ``None``.
        sort_by: Fields to sort on (see ``create_sort_spec``).
        descending: Sort-direction flags parallel to ``sort_by``.
        fields: Fields to filter on (see ``create_filter_spec``).
        ops: Filter operators parallel to ``fields``.
        values: Filter values parallel to ``fields``.

    Returns:
        A dict with the keys ``items``, ``itemsPerPage``, ``page`` and
        ``total`` taken from the paginated query and its pagination info.
    """
    # Start from a search when a query string was supplied, otherwise from
    # the plain model query.
    query = (
        search(db_session=db_session, query_str=query_str, model=model)
        if query_str
        else get_all(db_session=db_session, model=model)
    )

    query = apply_filters(query, create_filter_spec(model, fields, ops, values))
    query = apply_sort(query, create_sort_spec(model, sort_by, descending))

    # -1 is the caller's sentinel for "no page size"; pass None instead.
    page_size = None if items_per_page == -1 else items_per_page
    query, pagination = apply_pagination(query, page_number=page, page_size=page_size)

    return {
        "items": query.all(),
        "itemsPerPage": pagination.page_size,
        "page": pagination.page_number,
        "total": pagination.total_results,
    }