Coverage for sqlalchemy_crud_plus\utils.py: 98%

186 statements  

« prev     ^ index     » next       coverage.py v7.11.1, created at 2025-11-26 11:02 +0800

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

3from __future__ import annotations 

4 

5import warnings 

6 

7from typing import Any, Callable 

8 

9from sqlalchemy import ColumnElement, Select, and_, asc, desc, or_ 

10from sqlalchemy.orm import ( 

11 contains_eager, 

12 defaultload, 

13 defer, 

14 immediateload, 

15 joinedload, 

16 lazyload, 

17 load_only, 

18 noload, 

19 raiseload, 

20 selectinload, 

21 subqueryload, 

22 undefer, 

23 undefer_group, 

24) 

25from sqlalchemy.orm.util import AliasedClass 

26from sqlalchemy.sql.base import ExecutableOption 

27from sqlalchemy.sql.operators import ColumnOperators 

28from sqlalchemy.sql.schema import Column 

29 

30from sqlalchemy_crud_plus.errors import ( 

31 ColumnSortError, 

32 JoinConditionError, 

33 LoadingStrategyError, 

34 ModelColumnError, 

35 SelectOperatorError, 

36) 

37from sqlalchemy_crud_plus.types import JoinConditions, JoinConfig, LoadStrategies, Model 

38 

39_SUPPORTED_FILTERS = { 

40 # Comparison: https://docs.sqlalchemy.org/en/20/core/operators.html#comparison-operators 

41 'gt': lambda column: column.__gt__, 

42 'lt': lambda column: column.__lt__, 

43 'ge': lambda column: column.__ge__, 

44 'le': lambda column: column.__le__, 

45 'eq': lambda column: column.__eq__, 

46 'ne': lambda column: column.__ne__, 

47 'between': lambda column: column.between, 

48 # IN: https://docs.sqlalchemy.org/en/20/core/operators.html#in-comparisons 

49 'in': lambda column: column.in_, 

50 'not_in': lambda column: column.not_in, 

51 # Identity: https://docs.sqlalchemy.org/en/20/core/operators.html#identity-comparisons 

52 'is': lambda column: column.is_, 

53 'is_not': lambda column: column.is_not, 

54 'is_distinct_from': lambda column: column.is_distinct_from, 

55 'is_not_distinct_from': lambda column: column.is_not_distinct_from, 

56 # String: https://docs.sqlalchemy.org/en/20/core/operators.html#string-comparisons 

57 'like': lambda column: column.like, 

58 'not_like': lambda column: column.not_like, 

59 'ilike': lambda column: column.ilike, 

60 'not_ilike': lambda column: column.not_ilike, 

61 # String Containment: https://docs.sqlalchemy.org/en/20/core/operators.html#string-containment 

62 'startswith': lambda column: column.startswith, 

63 'endswith': lambda column: column.endswith, 

64 'contains': lambda column: column.contains, 

65 # String matching: https://docs.sqlalchemy.org/en/20/core/operators.html#string-matching 

66 'match': lambda column: column.match, 

67 # String Alteration: https://docs.sqlalchemy.org/en/20/core/operators.html#string-alteration 

68 'concat': lambda column: column.concat, 

69 # Arithmetic: https://docs.sqlalchemy.org/en/20/core/operators.html#arithmetic-operators 

70 'add': lambda column: column.__add__, 

71 'radd': lambda column: column.__radd__, 

72 'sub': lambda column: column.__sub__, 

73 'rsub': lambda column: column.__rsub__, 

74 'mul': lambda column: column.__mul__, 

75 'rmul': lambda column: column.__rmul__, 

76 'truediv': lambda column: column.__truediv__, 

77 'rtruediv': lambda column: column.__rtruediv__, 

78 'floordiv': lambda column: column.__floordiv__, 

79 'rfloordiv': lambda column: column.__rfloordiv__, 

80 'mod': lambda column: column.__mod__, 

81 'rmod': lambda column: column.__rmod__, 

82} 

83 

84_DYNAMIC_OPERATORS = [ 

85 'concat', 

86 'add', 

87 'radd', 

88 'sub', 

89 'rsub', 

90 'mul', 

91 'rmul', 

92 'truediv', 

93 'rtruediv', 

94 'floordiv', 

95 'rfloordiv', 

96 'mod', 

97 'rmod', 

98] 

99 

100 

def get_sqlalchemy_filter(operator: str, value: Any, allow_arithmetic: bool = True) -> Callable[..., Any] | None:
    """
    Look up the operator factory for a filter suffix.

    :param operator: The operator suffix taken from a ``field__operator`` key
    :param value: The filter value; must be a tuple, list or set for ``in``, ``not_in`` and ``between``
    :param allow_arithmetic: When False, the arithmetic/concat operators are rejected
        (used when validating the conditions nested inside an arithmetic filter)
    :return: A callable that takes a column and returns its bound operator method,
        or None for ``'or'`` (handled by the caller) and for unsupported operators
    :raises SelectOperatorError: If a collection operator gets a non-collection value,
        or an arithmetic operator is used while ``allow_arithmetic`` is False
    """
    expects_collection = operator in ('in', 'not_in', 'between')
    if expects_collection and not isinstance(value, (tuple, list, set)):
        raise SelectOperatorError(f'The value of the <{operator}> filter must be tuple, list or set')

    if not allow_arithmetic and operator in _DYNAMIC_OPERATORS:
        raise SelectOperatorError(f'Nested arithmetic operations are not allowed: {operator}')

    factory = _SUPPORTED_FILTERS.get(operator)
    if factory is None and operator != 'or':
        # Unknown operators are skipped with a warning rather than raising.
        warnings.warn(
            f'The operator <{operator}> is not yet supported, only {", ".join(_SUPPORTED_FILTERS.keys())}.',
            SyntaxWarning,
        )
    return factory

118 

119 

def get_column(model: type[Model] | AliasedClass, field_name: str) -> Column:
    """
    Resolve ``field_name`` on ``model`` and check that it names a column.

    :param model: The SQLAlchemy model class or aliased class
    :param field_name: The column name to retrieve
    :return: The attribute found on ``model`` under ``field_name``
    :raises ModelColumnError: If the attribute is missing (or is None), or if the
        model is table-backed and the attribute's ``property`` has no ``columns``
    """
    attribute = getattr(model, field_name, None)
    if attribute is None:
        raise ModelColumnError(f'Column {field_name} is not found in {model}')

    # For a table-backed model, an attribute exposing ``.property`` must be
    # column-based (its property carries ``columns``) to be usable here.
    is_mapped_attribute = hasattr(model, '__table__') and hasattr(attribute, 'property')
    if is_mapped_attribute and not hasattr(attribute.property, 'columns'):
        raise ModelColumnError(f'{field_name} is not a valid column in {model}')

    return attribute

137 

138 

139def _create_or_filters(column: Column, op: str, value: dict[str, Any]) -> list[ColumnOperators | None]: 

140 """ 

141 Create OR filter expressions. 

142 

143 :param column: The SQLAlchemy column 

144 :param op: The operator (should be 'or') 

145 :param value: Dictionary of operator-value pairs 

146 :return: 

147 """ 

148 or_filters = [] 

149 if op == 'or': 

150 for or_op, or_value in value.items(): 

151 sqlalchemy_filter = get_sqlalchemy_filter(or_op, or_value) 

152 if sqlalchemy_filter is not None: 

153 or_filters.append(sqlalchemy_filter(column)(or_value)) 

154 return or_filters 

155 

156 

157def _create_arithmetic_filters(column: Column, op: str, value: dict[str, Any]) -> list[ColumnOperators | None]: 

158 """ 

159 Create arithmetic filter expressions. 

160 

161 :param column: The SQLAlchemy column 

162 :param op: The arithmetic operator 

163 :param value: Dictionary containing 'value' and 'condition' keys 

164 :return: 

165 """ 

166 arithmetic_filters = [] 

167 if isinstance(value, dict) and {'value', 'condition'}.issubset(value): 

168 arithmetic_value = value['value'] 

169 condition = value['condition'] 

170 sqlalchemy_filter = get_sqlalchemy_filter(op, arithmetic_value) 

171 if sqlalchemy_filter is not None: 

172 for cond_op, cond_value in condition.items(): 

173 arithmetic_filter = get_sqlalchemy_filter(cond_op, cond_value, allow_arithmetic=False) 

174 if arithmetic_filter is not None: 

175 arithmetic_filters.append( 

176 arithmetic_filter(sqlalchemy_filter(column)(arithmetic_value))(cond_value) 

177 if cond_op != 'between' 

178 else arithmetic_filter(sqlalchemy_filter(column)(arithmetic_value))(*cond_value) 

179 ) 

180 return arithmetic_filters 

181 

182 

def _create_and_filters(column: Column, op: str, value: Any) -> list[ColumnElement[Any] | None]:
    """
    Create the expression for a single ``column <op> value`` filter.

    :param column: The SQLAlchemy column
    :param op: The filter operator
    :param value: The filter value (a pair of bounds for ``between``)
    :return: A one-element list with the expression, or an empty list when the
        operator is not supported
    """
    sqlalchemy_filter = get_sqlalchemy_filter(op, value)
    if sqlalchemy_filter is None:
        return []

    bound_operator = sqlalchemy_filter(column)
    expression = bound_operator(*value) if op == 'between' else bound_operator(value)
    return [expression]

197 

198 

def _build_or_group_filters(model: type[Model] | AliasedClass, value: Any) -> list[ColumnElement[Any]]:
    """
    Build the member expressions of a ``__or__={...}`` filter group.

    Keys use the same ``field`` / ``field__operator`` syntax as ``parse_filters``.
    A list value for a plain ``field`` or for a scalar operator expands into one
    expression per element. The caller combines the result with ``or_``.

    :param model: The SQLAlchemy model class or aliased class
    :param value: The ``__or__`` filter mapping
    :return: The OR-group member expressions (may be empty)
    :raises SelectOperatorError: If ``value`` is not a dictionary
    """
    if not isinstance(value, dict):
        raise SelectOperatorError('__or__ filter value must be a dictionary')

    group = []
    for key, item in value.items():
        if '__' not in key:
            column = get_column(model, key)
            candidates = item if isinstance(item, list) else [item]
            group.extend(column == candidate for candidate in candidates)
            continue

        field_name, op = key.rsplit('__', 1)
        column = get_column(model, field_name)

        if isinstance(item, list) and op not in ['in', 'not_in', 'between']:
            # A list for a scalar operator means "matches any of these values".
            for single_value in item:
                group.extend(_create_and_filters(column, op, single_value))
        elif op == 'or':
            group.extend(_create_or_filters(column, op, item))
        elif op in _DYNAMIC_OPERATORS:
            group.extend(_create_arithmetic_filters(column, op, item))
        else:
            group.extend(_create_and_filters(column, op, item))
    return group


def _build_field_filters(column: Column, op: str, value: Any) -> list[ColumnElement[Any]]:
    """
    Build the expressions for one ``field__operator=value`` filter.

    :param column: The resolved column
    :param op: The operator suffix
    :param value: The filter value
    :return: For ``or``/arithmetic operators, zero or one combined expression;
        otherwise the plain operator expressions
    """
    if op == 'or':
        or_filters = _create_or_filters(column, op, value)
        # Only emit an OR group when it has members: calling or_() with no
        # arguments is deprecated in SQLAlchemy 2.0.
        return [or_(*or_filters)] if or_filters else []
    if op in _DYNAMIC_OPERATORS:
        arithmetic_filters = _create_arithmetic_filters(column, op, value)
        return [and_(*arithmetic_filters)] if arithmetic_filters else []
    return _create_and_filters(column, op, value)


def parse_filters(model: type[Model] | AliasedClass, **kwargs) -> list[ColumnElement[Any]]:
    """
    Parse filter expressions from keyword arguments.

    Supported key forms:

    * ``field=value`` -- equality;
    * ``field__op=value`` -- any operator from the supported filter table;
    * ``field__or={op: value, ...}`` -- OR of several operators on one field;
    * ``field__<arith>={'value': v, 'condition': {op: value, ...}}`` -- comparisons
      on an arithmetic/concat expression, combined with AND;
    * ``__or__={key: value, ...}`` -- OR of arbitrary filters using the forms above.

    :param model: The SQLAlchemy model class or aliased class
    :param kwargs: Filter expressions using field__operator=value syntax
    :return: The list of filter expressions
    :raises ModelColumnError: If a referenced field is not a column of ``model``
    :raises SelectOperatorError: If an ``__or__`` value is not a dictionary or an operator value is invalid
    """
    filters = []

    for key, value in kwargs.items():
        if '__' not in key:
            filters.append(get_column(model, key) == value)
            continue

        field_name, op = key.rsplit('__', 1)

        # The key ``__or__`` splits into ('__or', '').
        if field_name == '__or' and op == '':
            or_group = _build_or_group_filters(model, value)
            if or_group:
                filters.append(or_(*or_group))
        else:
            column = get_column(model, field_name)
            filters.extend(_build_field_filters(column, op, value))

    return filters

261 

262 

def apply_sorting(
    model: type[Model] | AliasedClass,
    stmt: Select,
    sort_columns: str | list[str],
    sort_orders: str | list[str] | None = None,
) -> Select:
    """
    Add ORDER BY clauses to ``stmt`` for the given columns.

    A single column name is treated as a one-element list. A single order string
    applies to every column, and when no orders are given every column sorts
    ascending. Clauses are appended in column order.

    :param model: The SQLAlchemy model (or aliased class) the columns belong to
    :param stmt: The SQLAlchemy Select statement to which sorting will be applied
    :param sort_columns: Column name or list of column names to sort by
    :param sort_orders: ``"asc"``/``"desc"`` or a list of them, one per column
    :return: The statement with ORDER BY applied (unchanged if no columns are given)
    :raises ValueError: If orders are given without any columns
    :raises ColumnSortError: If the number of orders differs from the number of columns
    :raises SelectOperatorError: If an order is neither ``"asc"`` nor ``"desc"``
    :raises ModelColumnError: If a column name cannot be resolved on ``model``
    """
    if sort_orders and not sort_columns:
        raise ValueError('Sort orders provided without corresponding sort columns.')
    if not sort_columns:
        return stmt

    columns = sort_columns if isinstance(sort_columns, list) else [sort_columns]

    if sort_orders:
        orders = sort_orders if isinstance(sort_orders, list) else [sort_orders] * len(columns)
        if len(columns) != len(orders):
            raise ColumnSortError('The length of sort_columns and sort_orders must match.')
        for direction in orders:
            if direction not in ('asc', 'desc'):
                raise SelectOperatorError(
                    f'Select sort operator {direction} is not supported, only supports `asc`, `desc`'
                )
    else:
        orders = ['asc'] * len(columns)

    for column_name, direction in zip(columns, orders):
        column = get_column(model, column_name)
        stmt = stmt.order_by(asc(column) if direction == 'asc' else desc(column))

    return stmt

306 

307 

def build_load_strategies(model: type[Model], load_strategies: LoadStrategies | None) -> list[ExecutableOption]:
    """
    Build relationship/column loading strategy options.

    ``load_strategies`` may be a list of attribute names, each loaded with
    ``selectinload``, or a dict mapping attribute names to the name of a loader
    option from ``strategies_map`` below. ``None`` (or any other type) yields
    no options.

    :param model: SQLAlchemy model class
    :param load_strategies: Loading strategies configuration
    :return: The loader options, in input order
    :raises LoadingStrategyError: If a dict entry names an unsupported strategy
    :raises ModelColumnError: If an attribute name does not exist on ``model``
    """
    if load_strategies is None:
        return []

    strategies_map = {
        'contains_eager': contains_eager,
        'defaultload': defaultload,
        'immediateload': immediateload,
        'joinedload': joinedload,
        'lazyload': lazyload,
        'noload': noload,
        'raiseload': raiseload,
        'selectinload': selectinload,
        'subqueryload': subqueryload,
        # Load
        'defer': defer,
        'load_only': load_only,
        # NOTE: selectin_polymorphic and with_expression are intentionally not exposed.
        'undefer': undefer,
        'undefer_group': undefer_group,
    }

    def _resolve_attr(name: str) -> Any:
        # Only a missing attribute means a bad column name; keep the try narrow so
        # errors raised by the loader functions are not relabelled as column errors.
        try:
            return getattr(model, name)
        except AttributeError as e:
            raise ModelColumnError(f'Invalid relationship column: {name}') from e

    default_strategy = 'selectinload'
    options = []

    if isinstance(load_strategies, list):
        for column in load_strategies:
            options.append(strategies_map[default_strategy](_resolve_attr(column)))

    elif isinstance(load_strategies, dict):
        for column, strategy_name in load_strategies.items():
            if strategy_name not in strategies_map:
                raise LoadingStrategyError(
                    f'Invalid loading strategy: {strategy_name}, only supports {list(strategies_map.keys())}'
                )
            options.append(strategies_map[strategy_name](_resolve_attr(column)))

    return options

365 

366 

def has_join_fill_result(join_conditions: JoinConditions) -> bool:
    """
    Report whether any ``JoinConfig`` in ``join_conditions`` requests ``fill_result``.

    Only the list form of join conditions is inspected; any other form
    (for example the dict form) yields False.

    :param join_conditions: JOIN conditions configuration
    :return: True if at least one ``JoinConfig`` item has a truthy ``fill_result``
    """
    if not isinstance(join_conditions, list):
        return False
    return any(isinstance(item, JoinConfig) and item.fill_result for item in join_conditions)

380 

381 

def apply_join_conditions(model: type[Model], stmt: Select, join_conditions: JoinConditions) -> Select:
    """
    Apply JOIN conditions to the query statement.

    ``join_conditions`` may be:

    * a list whose items are attribute names on ``model`` (inner-joined) or
      ``JoinConfig`` objects (joined to ``v.model`` on ``v.join_on`` according to
      ``v.join_type``; when ``v.fill_result`` is set, ``v.model`` is also added to
      the selected columns unless already present);
    * a dict mapping attribute names on ``model`` to ``'inner'``, ``'left'`` or ``'full'``.

    Any other value leaves ``stmt`` unchanged.

    :param model: SQLAlchemy model class
    :param stmt: SQLAlchemy Select statement
    :param join_conditions: JOIN conditions configuration
    :return: The statement with the joins applied
    :raises ModelColumnError: If a named attribute does not exist on ``model``
    :raises JoinConditionError: If a dict entry uses an unsupported join type
    """

    def _resolve_attr(name: str) -> Any:
        # Keep the try narrow: only a missing attribute is an invalid column name.
        try:
            return getattr(model, name)
        except AttributeError as e:
            raise ModelColumnError(f'Invalid model column: {name}') from e

    if isinstance(join_conditions, list):
        for v in join_conditions:
            if isinstance(v, str):
                stmt = stmt.join(_resolve_attr(v))
            elif isinstance(v, JoinConfig):
                if v.join_type == 'inner':
                    stmt = stmt.join(v.model, v.join_on)
                elif v.join_type == 'left':
                    stmt = stmt.join(v.model, v.join_on, isouter=True)
                elif v.join_type == 'full':
                    stmt = stmt.join(v.model, v.join_on, full=True)
                # NOTE(review): any other join_type is silently not joined here;
                # presumably JoinConfig validates join_type -- confirm.

                # Test the match predicate itself, not the truthiness of the column
                # object (SQL clause objects do not define a plain boolean value).
                if v.fill_result and not any(
                    hasattr(col, 'class_') and col.class_ == v.model for col in stmt.selected_columns
                ):
                    stmt = stmt.add_columns(v.model)

    elif isinstance(join_conditions, dict):
        allowed_join_types = ['inner', 'left', 'full']  # SQLAlchemy doesn't support right join
        for column, join_type in join_conditions.items():
            if join_type not in allowed_join_types:
                raise JoinConditionError(f'Invalid join type: {join_type}, only supports {allowed_join_types}')
            attr = _resolve_attr(column)
            if join_type == 'left':
                stmt = stmt.join(attr, isouter=True)
            elif join_type == 'full':
                stmt = stmt.join(attr, full=True)
            else:  # 'inner' -- the only remaining allowed type
                stmt = stmt.join(attr)

    return stmt