Coverage for sqlalchemy_crud_plus\utils.py: 98%
186 statements
« prev ^ index » next coverage.py v7.11.1, created at 2025-11-26 11:02 +0800
« prev ^ index » next coverage.py v7.11.1, created at 2025-11-26 11:02 +0800
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3from __future__ import annotations
5import warnings
7from typing import Any, Callable
9from sqlalchemy import ColumnElement, Select, and_, asc, desc, or_
10from sqlalchemy.orm import (
11 contains_eager,
12 defaultload,
13 defer,
14 immediateload,
15 joinedload,
16 lazyload,
17 load_only,
18 noload,
19 raiseload,
20 selectinload,
21 subqueryload,
22 undefer,
23 undefer_group,
24)
25from sqlalchemy.orm.util import AliasedClass
26from sqlalchemy.sql.base import ExecutableOption
27from sqlalchemy.sql.operators import ColumnOperators
28from sqlalchemy.sql.schema import Column
30from sqlalchemy_crud_plus.errors import (
31 ColumnSortError,
32 JoinConditionError,
33 LoadingStrategyError,
34 ModelColumnError,
35 SelectOperatorError,
36)
37from sqlalchemy_crud_plus.types import JoinConditions, JoinConfig, LoadStrategies, Model
def _column_method_getter(method_name):
    """Build a callable that returns the ``method_name`` attribute of the column it receives."""

    def getter(column):
        return getattr(column, method_name)

    return getter


# Maps every supported filter operator to a callable which, given a column, returns the bound
# column method implementing that operator. Key order is kept stable because it is rendered in
# the "not yet supported" warning message.
_SUPPORTED_FILTERS = {
    operator_name: _column_method_getter(method_name)
    for operator_name, method_name in (
        # Comparison: https://docs.sqlalchemy.org/en/20/core/operators.html#comparison-operators
        ('gt', '__gt__'),
        ('lt', '__lt__'),
        ('ge', '__ge__'),
        ('le', '__le__'),
        ('eq', '__eq__'),
        ('ne', '__ne__'),
        ('between', 'between'),
        # IN: https://docs.sqlalchemy.org/en/20/core/operators.html#in-comparisons
        ('in', 'in_'),
        ('not_in', 'not_in'),
        # Identity: https://docs.sqlalchemy.org/en/20/core/operators.html#identity-comparisons
        ('is', 'is_'),
        ('is_not', 'is_not'),
        ('is_distinct_from', 'is_distinct_from'),
        ('is_not_distinct_from', 'is_not_distinct_from'),
        # String: https://docs.sqlalchemy.org/en/20/core/operators.html#string-comparisons
        ('like', 'like'),
        ('not_like', 'not_like'),
        ('ilike', 'ilike'),
        ('not_ilike', 'not_ilike'),
        # String Containment: https://docs.sqlalchemy.org/en/20/core/operators.html#string-containment
        ('startswith', 'startswith'),
        ('endswith', 'endswith'),
        ('contains', 'contains'),
        # String matching: https://docs.sqlalchemy.org/en/20/core/operators.html#string-matching
        ('match', 'match'),
        # String Alteration: https://docs.sqlalchemy.org/en/20/core/operators.html#string-alteration
        ('concat', 'concat'),
        # Arithmetic: https://docs.sqlalchemy.org/en/20/core/operators.html#arithmetic-operators
        ('add', '__add__'),
        ('radd', '__radd__'),
        ('sub', '__sub__'),
        ('rsub', '__rsub__'),
        ('mul', '__mul__'),
        ('rmul', '__rmul__'),
        ('truediv', '__truediv__'),
        ('rtruediv', '__rtruediv__'),
        ('floordiv', '__floordiv__'),
        ('rfloordiv', '__rfloordiv__'),
        ('mod', '__mod__'),
        ('rmod', '__rmod__'),
    )
}
# Operators whose column method combines the column with an operand (string concatenation or
# arithmetic) rather than comparing it. ``get_sqlalchemy_filter`` rejects them when
# ``allow_arithmetic`` is false, and ``parse_filters`` routes them to ``_create_arithmetic_filters``.
_DYNAMIC_OPERATORS = [
    'concat',
    'add', 'radd',
    'sub', 'rsub',
    'mul', 'rmul',
    'truediv', 'rtruediv',
    'floordiv', 'rfloordiv',
    'mod', 'rmod',
]
def get_sqlalchemy_filter(operator: str, value: Any, allow_arithmetic: bool = True) -> Callable[..., Any] | None:
    """
    Resolve a filter operator name to its entry in the supported filters table.

    :param operator: The operator name (the part after ``__`` in a filter key)
    :param value: The filter value, only inspected for ``in``, ``not_in`` and ``between``
    :param allow_arithmetic: Whether operators listed in ``_DYNAMIC_OPERATORS`` are accepted
    :return: The callable registered for the operator, or ``None`` for ``or`` and unknown operators
    :raises SelectOperatorError: If a collection operator receives a non-collection value, or a
        dynamic operator is used while ``allow_arithmetic`` is false
    """
    expects_collection = operator in ('in', 'not_in', 'between')
    if expects_collection and not isinstance(value, (tuple, list, set)):
        raise SelectOperatorError(f'The value of the <{operator}> filter must be tuple, list or set')

    if not allow_arithmetic and operator in _DYNAMIC_OPERATORS:
        raise SelectOperatorError(f'Nested arithmetic operations are not allowed: {operator}')

    resolved = _SUPPORTED_FILTERS.get(operator)
    # 'or' has no table entry on purpose: it is handled by the caller, so it is returned as
    # ``None`` without the "not supported" warning.
    if resolved is not None or operator == 'or':
        return resolved

    warnings.warn(
        f'The operator <{operator}> is not yet supported, only {", ".join(_SUPPORTED_FILTERS.keys())}.',
        SyntaxWarning,
    )
    return None
def get_column(model: type[Model] | AliasedClass, field_name: str) -> Column:
    """
    Fetch an attribute from a model by name and check that it is usable as a column.

    :param model: The SQLAlchemy model class or aliased class
    :param field_name: Name of the attribute to look up on the model
    :return: The attribute found on the model
    :raises ModelColumnError: If the attribute is missing (or ``None``), or if the model has a
        ``__table__`` and the attribute's ``property`` exposes no ``columns``
    """
    attribute = getattr(model, field_name, None)
    if attribute is None:
        raise ModelColumnError(f'Column {field_name} is not found in {model}')

    # Only attributes carrying a ``property`` on a model with ``__table__`` get the extra check.
    needs_column_check = hasattr(model, '__table__') and hasattr(attribute, 'property')
    if needs_column_check and not hasattr(attribute.property, 'columns'):
        raise ModelColumnError(f'{field_name} is not a valid column in {model}')

    return attribute
def _create_or_filters(column: Column, op: str, value: dict[str, Any]) -> list[ColumnOperators | None]:
    """
    Create the expressions that make up an OR filter on a single column.

    :param column: The SQLAlchemy column
    :param op: The operator (expressions are only produced when it is 'or')
    :param value: Dictionary of operator-value pairs
    :return: One expression per supported operator in ``value``; empty when ``op`` is not 'or'
    """
    or_filters = []
    if op != 'or':
        return or_filters

    for or_op, or_value in value.items():
        sqlalchemy_filter = get_sqlalchemy_filter(or_op, or_value)
        if sqlalchemy_filter is None:
            continue
        bound_operator = sqlalchemy_filter(column)
        # ``between`` takes its lower and upper bound as two positional arguments, so the
        # collection has to be unpacked (same handling as in ``_create_and_filters``).
        if or_op == 'between':
            or_filters.append(bound_operator(*or_value))
        else:
            or_filters.append(bound_operator(or_value))
    return or_filters
def _create_arithmetic_filters(column: Column, op: str, value: dict[str, Any]) -> list[ColumnOperators | None]:
    """
    Create filter expressions that compare the result of an arithmetic operation on a column.

    :param column: The SQLAlchemy column
    :param op: The arithmetic operator applied to the column
    :param value: Dictionary with a 'value' key (the operand) and a 'condition' key
        (operator-value pairs applied to the arithmetic result)
    :return: One expression per supported condition; empty when ``value`` lacks the required keys
        or ``op`` is not supported
    """
    results = []
    has_required_keys = isinstance(value, dict) and {'value', 'condition'}.issubset(value)
    if not has_required_keys:
        return results

    operand = value['value']
    conditions = value['condition']
    arithmetic_operator = get_sqlalchemy_filter(op, operand)
    if arithmetic_operator is None:
        return results

    for cond_op, cond_value in conditions.items():
        # Conditions may not themselves be arithmetic operators.
        comparison = get_sqlalchemy_filter(cond_op, cond_value, allow_arithmetic=False)
        if comparison is None:
            continue
        arithmetic_expression = arithmetic_operator(column)(operand)
        bound_comparison = comparison(arithmetic_expression)
        if cond_op == 'between':
            results.append(bound_comparison(*cond_value))
        else:
            results.append(bound_comparison(cond_value))
    return results
def _create_and_filters(column: Column, op: str, value: Any) -> list[ColumnElement[Any] | None]:
    """
    Create the filter expression for a single operator applied to a column.

    :param column: The SQLAlchemy column
    :param op: The filter operator
    :param value: The filter value (unpacked into separate arguments for 'between')
    :return: A list holding the expression, or an empty list when the operator is not supported
    """
    sqlalchemy_filter = get_sqlalchemy_filter(op, value)
    if sqlalchemy_filter is None:
        return []

    bound_operator = sqlalchemy_filter(column)
    if op == 'between':
        expression = bound_operator(*value)
    else:
        expression = bound_operator(value)
    return [expression]
def parse_filters(model: type[Model] | AliasedClass, **kwargs) -> list[ColumnElement[Any]]:
    """
    Parse filter expressions from keyword arguments.

    Supported key shapes:

    - ``field=value``: equality on the column
    - ``field__operator=value``: the named operator on the column
    - ``field__or={operator: value, ...}``: the operators combined with OR
    - ``__or__={key: value, ...}``: several field filters combined with OR

    :param model: The SQLAlchemy model class or aliased class
    :param kwargs: Filter expressions using field__operator=value syntax
    :return: The list of filter expressions
    :raises SelectOperatorError: If the ``__or__`` value is not a dictionary
    """
    filters = []

    for key, value in kwargs.items():
        if '__' not in key:
            filters.append(get_column(model, key) == value)
            continue

        field_name, op = key.rsplit('__', 1)

        # The key '__or__' splits into ('__or', '').
        if field_name == '__or' and op == '':
            if not isinstance(value, dict):
                raise SelectOperatorError('__or__ filter value must be a dictionary')
            or_group = _parse_or_group(model, value)
            if or_group:
                filters.append(or_(*or_group))
            continue

        column = get_column(model, field_name)

        if op == 'or':
            or_filters = _create_or_filters(column, op, value)
            # Calling or_() without arguments is deprecated in SQLAlchemy, so skip empty groups
            # just like the other branches do.
            if or_filters:
                filters.append(or_(*or_filters))
        elif op in _DYNAMIC_OPERATORS:
            arithmetic_filters = _create_arithmetic_filters(column, op, value)
            if arithmetic_filters:
                filters.append(and_(*arithmetic_filters))
        else:
            filters.extend(_create_and_filters(column, op, value))

    return filters


def _parse_or_group(model: type[Model] | AliasedClass, group: dict[str, Any]) -> list[ColumnElement[Any]]:
    """
    Collect the expressions of a ``__or__`` group, to be joined with OR by the caller.

    :param model: The SQLAlchemy model class or aliased class
    :param group: Mapping of filter keys (``field`` or ``field__operator``) to values
    :return: The list of expressions built from the group
    """
    group_filters = []

    for key, value in group.items():
        if '__' not in key:
            column = get_column(model, key)
            # A list means "equal to any of these values".
            candidates = value if isinstance(value, list) else [value]
            group_filters.extend(column == candidate for candidate in candidates)
            continue

        field_name, op = key.rsplit('__', 1)
        column = get_column(model, field_name)

        if isinstance(value, list) and op not in ['in', 'not_in', 'between']:
            # For operators that do not take a collection, a list applies the operator per item.
            for single_value in value:
                group_filters.extend(_create_and_filters(column, op, single_value))
        elif op == 'or':
            group_filters.extend(_create_or_filters(column, op, value))
        elif op in _DYNAMIC_OPERATORS:
            group_filters.extend(_create_arithmetic_filters(column, op, value))
        else:
            group_filters.extend(_create_and_filters(column, op, value))

    return group_filters
def apply_sorting(
    model: type[Model] | AliasedClass,
    stmt: Select,
    sort_columns: str | list[str],
    sort_orders: str | list[str] | None = None,
) -> Select:
    """
    Add ORDER BY clauses to a Select statement for the given columns and directions.

    :param model: The SQLAlchemy model
    :param stmt: The SQLAlchemy Select statement to which sorting will be applied
    :param sort_columns: Column name or list of column names to sort by
    :param sort_orders: Sort order ("asc" or "desc") or list of sort orders; a single order is
        applied to every column, and every column sorts ascending when omitted
    :return: The statement with sorting applied (unchanged when no sort columns are given)
    :raises ValueError: If sort orders are given without sort columns
    :raises ColumnSortError: If the numbers of columns and orders differ
    :raises SelectOperatorError: If an order is neither "asc" nor "desc"
    """
    if not sort_columns:
        if sort_orders:
            raise ValueError('Sort orders provided without corresponding sort columns.')
        return stmt

    column_names = sort_columns if isinstance(sort_columns, list) else [sort_columns]

    if sort_orders:
        orders = sort_orders if isinstance(sort_orders, list) else [sort_orders] * len(column_names)

        if len(column_names) != len(orders):
            raise ColumnSortError('The length of sort_columns and sort_orders must match.')

        for order in orders:
            if order not in ['asc', 'desc']:
                raise SelectOperatorError(
                    f'Select sort operator {order} is not supported, only supports `asc`, `desc`'
                )
    else:
        orders = ['asc'] * len(column_names)

    for column_name, order in zip(column_names, orders):
        column = get_column(model, column_name)
        direction = asc if order == 'asc' else desc
        stmt = stmt.order_by(direction(column))

    return stmt
def build_load_strategies(model: type[Model], load_strategies: LoadStrategies | None) -> list[ExecutableOption]:
    """
    Build relationship loading strategy options.

    :param model: SQLAlchemy model class
    :param load_strategies: Either a list of attribute names (each loaded with ``selectinload``)
        or a dict mapping attribute names to strategy names
    :return: The loader options, empty when ``load_strategies`` is ``None``
    :raises LoadingStrategyError: If a strategy name is not supported
    :raises ModelColumnError: If building an option raises ``AttributeError`` (e.g. the model has
        no such attribute)
    """
    if load_strategies is None:
        return []

    strategies_map = {
        'contains_eager': contains_eager,
        'defaultload': defaultload,
        'immediateload': immediateload,
        'joinedload': joinedload,
        'lazyload': lazyload,
        'noload': noload,
        'raiseload': raiseload,
        'selectinload': selectinload,
        'subqueryload': subqueryload,
        # Load
        'defer': defer,
        'load_only': load_only,
        # 'selectin_polymorphic': selectin_polymorphic,
        'undefer': undefer,
        'undefer_group': undefer_group,
        # 'with_expression': with_expression,
    }
    default_strategy = 'selectinload'

    # Normalise both accepted shapes to (attribute name, strategy name) pairs so a single loop
    # handles them; pairs (not a dict) keep duplicate list entries intact.
    if isinstance(load_strategies, list):
        requested = [(column, default_strategy) for column in load_strategies]
    elif isinstance(load_strategies, dict):
        requested = list(load_strategies.items())
    else:
        requested = []

    options = []
    for column, strategy_name in requested:
        if strategy_name not in strategies_map:
            raise LoadingStrategyError(
                f'Invalid loading strategy: {strategy_name}, only supports {list(strategies_map.keys())}'
            )
        strategy_func = strategies_map[strategy_name]
        try:
            options.append(strategy_func(getattr(model, column)))
        except AttributeError as exc:
            # Keep the original AttributeError as the explicit cause for easier debugging.
            raise ModelColumnError(f'Invalid relationship column: {column}') from exc

    return options
def has_join_fill_result(join_conditions: JoinConditions) -> bool:
    """
    Tell whether at least one JoinConfig entry asks for its model to be filled into the result.

    :param join_conditions: JOIN conditions configuration
    :return: ``True`` when ``join_conditions`` is a list containing a JoinConfig whose
        ``fill_result`` is truthy, otherwise ``False``
    """
    if not isinstance(join_conditions, list):
        return False

    return any(isinstance(condition, JoinConfig) and condition.fill_result for condition in join_conditions)
def apply_join_conditions(model: type[Model], stmt: Select, join_conditions: JoinConditions) -> Select:
    """
    Apply JOIN conditions to the query statement.

    :param model: SQLAlchemy model class
    :param stmt: SQLAlchemy Select statement
    :param join_conditions: Either a list whose items are attribute names (inner join) or
        JoinConfig objects, or a dict mapping attribute names to a join type
    :return: The statement with the joins applied (unchanged for any other input)
    :raises ModelColumnError: If joining on a named attribute raises ``AttributeError``
    :raises JoinConditionError: If a dict entry uses a join type other than inner, left or full
    """
    if isinstance(join_conditions, list):
        for v in join_conditions:
            if isinstance(v, str):
                try:
                    stmt = stmt.join(getattr(model, v))
                except AttributeError as exc:
                    raise ModelColumnError(f'Invalid model column: {v}') from exc
            elif isinstance(v, JoinConfig):
                if v.join_type == 'inner':
                    stmt = stmt.join(v.model, v.join_on)
                elif v.join_type == 'left':
                    stmt = stmt.join(v.model, v.join_on, isouter=True)
                elif v.join_type == 'full':
                    stmt = stmt.join(v.model, v.join_on, full=True)

                if v.fill_result:
                    # Evaluate the match condition itself; testing the truthiness of the column
                    # object (as a filtered generator would) is not what is meant here.
                    already_selected = any(
                        hasattr(col, 'class_') and col.class_ == v.model for col in stmt.selected_columns
                    )
                    if not already_selected:
                        stmt = stmt.add_columns(v.model)

    elif isinstance(join_conditions, dict):
        allowed_join_types = ['inner', 'left', 'full']  # SQLAlchemy doesn't support right join
        for column, join_type in join_conditions.items():
            if join_type not in allowed_join_types:
                raise JoinConditionError(f'Invalid join type: {join_type}, only supports {allowed_join_types}')
            try:
                attr = getattr(model, column)
                if join_type == 'inner':
                    stmt = stmt.join(attr)
                elif join_type == 'left':
                    stmt = stmt.join(attr, isouter=True)
                else:
                    # Only 'full' remains after the validation above.
                    stmt = stmt.join(attr, full=True)
            except AttributeError as exc:
                raise ModelColumnError(f'Invalid model column: {column}') from exc

    return stmt