Coverage for sqlalchemy_crud_plus\crud.py: 97%
218 statements
« prev ^ index » next coverage.py v7.11.1, created at 2025-11-26 16:47 +0800
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3from datetime import datetime, timezone
4from typing import Any, Generic, Sequence, cast
6from sqlalchemy import (
7 Column,
8 ColumnExpressionArgument,
9 CursorResult,
10 Row,
11 Select,
12 delete,
13 func,
14 insert,
15 inspect,
16 select,
17 update,
18)
19from sqlalchemy.ext.asyncio import AsyncSession
21from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, ModelColumnError, MultipleResultsError
22from sqlalchemy_crud_plus.types import (
23 CreateSchema,
24 JoinConditions,
25 LoadOptions,
26 LoadStrategies,
27 Model,
28 SortColumns,
29 SortOrders,
30 UpdateSchema,
31)
32from sqlalchemy_crud_plus.utils import (
33 apply_join_conditions,
34 apply_sorting,
35 build_load_strategies,
36 has_join_fill_result,
37 parse_filters,
38)
class CRUDPlus(Generic[Model]):
    """
    Async CRUD helper bound to a single SQLAlchemy ORM model.

    Query methods accept ``field__operator=value`` keyword filters, which are translated by
    ``parse_filters``. Write methods can optionally flush and/or commit the given session.
    """

    def __init__(self, model: type[Model]):
        self.model = model
        # Keys of the model's table columns; used to check logical-deletion columns exist.
        self.model_column_names = [column.key for column in model.__table__.columns]
        # A single Column for a simple primary key, or a list of Columns for a composite one.
        self.primary_key = self._get_primary_key()

    def _get_primary_key(self) -> Column | list[Column]:
        """
        Dynamically retrieve the primary key column(s) for the model.

        :return: The single primary key column, or a list of columns for a composite primary key
        """
        mapper = inspect(self.model)
        primary_key = mapper.primary_key
        if len(primary_key) == 1:
            return primary_key[0]
        return list(primary_key)

    def _get_pk_filter(self, pk: Any | Sequence[Any]) -> list[ColumnExpressionArgument[bool]]:
        """
        Build the equality filter(s) that match the given primary key value(s).

        :param pk: Single value for simple primary key, or a sequence (e.g. tuple) holding one
            value per column for composite primary key
        :raises CompositePrimaryKeysError: If the model has a composite primary key and ``pk`` is
            not a sequence with exactly one value per primary key column
        :return: List of ``column == value`` expressions
        """
        if not isinstance(self.primary_key, list):
            return [self.primary_key == pk]

        expected = len(self.primary_key)
        message = f'Expected {expected} values for composite primary key'
        # Strings/bytes are iterable but never a valid composite key; without this check a
        # string of matching length would be matched character by character.
        if isinstance(pk, (str, bytes)):
            raise CompositePrimaryKeysError(message)
        try:
            values = tuple(pk)
        except TypeError:
            # A scalar was passed where a composite key was expected.
            raise CompositePrimaryKeysError(message) from None
        if len(values) != expected:
            raise CompositePrimaryKeysError(message)
        return [column == value for column, value in zip(self.primary_key, values)]

    async def create_model(
        self,
        session: AsyncSession,
        obj: CreateSchema,
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> Model:
        """
        Create a new instance of a model.

        :param session: The SQLAlchemy async session
        :param obj: The Pydantic schema containing data to be saved
        :param flush: If `True`, flush all object changes to the database
        :param commit: If `True`, commits the transaction immediately
        :param kwargs: Additional model data not included in the pydantic schema; these values
            override schema fields with the same name
        :return: The new model instance (added to the session)
        """
        obj_data = obj.model_dump()
        if kwargs:
            obj_data.update(kwargs)

        ins = self.model(**obj_data)
        session.add(ins)

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return ins

    async def create_models(
        self,
        session: AsyncSession,
        objs: list[CreateSchema],
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> list[Model]:
        """
        Create new instances of a model.

        :param session: The SQLAlchemy async session
        :param objs: The Pydantic schema list containing data to be saved
        :param flush: If `True`, flush all object changes to the database
        :param commit: If `True`, commits the transaction immediately
        :param kwargs: Additional model data not included in the pydantic schemas; applied to
            every instance and overriding schema fields with the same name
        :return: The new model instances, in the same order as ``objs``
        """
        ins_list = [self.model(**{**obj.model_dump(), **kwargs}) for obj in objs]
        session.add_all(ins_list)

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return ins_list

    async def bulk_create_models(
        self,
        session: AsyncSession,
        objs: list[dict[str, Any]],
        render_nulls: bool = False,
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> Sequence[Model]:
        """
        Create new instances of a model with a single bulk INSERT ... RETURNING statement.

        :param session: The SQLAlchemy async session
        :param objs: The dict list containing data to be saved; the dict keys should match the
            model columns
        :param render_nulls: Render null values instead of ignoring them
        :param flush: If `True`, flush all object changes to the database
        :param commit: If `True`, commits the transaction immediately
        :param kwargs: Additional model data not included in the dicts
        :return: The inserted model instances returned by the database
        """
        stmt = insert(self.model).values(**kwargs).execution_options(render_nulls=render_nulls).returning(self.model)
        result = await session.execute(stmt, objs)

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return result.scalars().all()

    async def count(
        self,
        session: AsyncSession,
        *whereclause: ColumnExpressionArgument[bool],
        join_conditions: JoinConditions | None = None,
        **kwargs,
    ) -> int:
        """
        Count records that match specified filters.

        :param session: SQLAlchemy async session
        :param whereclause: Additional WHERE clauses
        :param join_conditions: JOIN conditions for relationships
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: The number of matching rows (0 if the database returns NULL)
        """
        filters = list(whereclause)

        if kwargs:
            filters.extend(parse_filters(self.model, **kwargs))

        # Composite keys cannot be passed to COUNT() as a single column, so fall back to COUNT(*).
        if isinstance(self.primary_key, list):
            stmt = select(func.count()).select_from(self.model)
        else:
            stmt = select(func.count(self.primary_key)).select_from(self.model)

        if filters:
            stmt = stmt.where(*filters)

        if join_conditions:
            stmt = apply_join_conditions(self.model, stmt.select_from(self.model), join_conditions)

        query = await session.execute(stmt)
        total_count = query.scalar()
        return total_count if total_count is not None else 0

    async def exists(
        self,
        session: AsyncSession,
        *whereclause: ColumnExpressionArgument[bool],
        join_conditions: JoinConditions | None = None,
        **kwargs,
    ) -> bool:
        """
        Check whether records that match the specified filters exist.

        :param session: SQLAlchemy async session
        :param whereclause: Additional WHERE clauses
        :param join_conditions: JOIN conditions for relationships
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: `True` if at least one matching row exists
        """
        filters = list(whereclause)

        if kwargs:
            filters.extend(parse_filters(self.model, **kwargs))

        stmt = select(self.model).where(*filters).limit(1)

        if join_conditions:
            stmt = apply_join_conditions(self.model, stmt.select_from(self.model), join_conditions)

        query = await session.execute(stmt)
        return query.scalars().first() is not None

    async def select_model(
        self,
        session: AsyncSession,
        pk: Any | Sequence[Any],
        *whereclause: ColumnExpressionArgument[bool],
        load_options: LoadOptions | None = None,
        load_strategies: LoadStrategies | None = None,
        join_conditions: JoinConditions | None = None,
        **kwargs: Any,
    ) -> Sequence[Row[tuple[Model, ...]] | None] | Model | None:
        """
        Query by primary key(s) with optional relationship loading and joins.

        :param session: SQLAlchemy async session
        :param pk: Primary key value(s) - single value or tuple for composite keys
        :param whereclause: Additional WHERE clauses
        :param load_options: SQLAlchemy loading options
        :param load_strategies: Relationship loading strategies
        :param join_conditions: JOIN conditions for relationships
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: The first matching model instance, or the first full result row when the join
            conditions request filled results; `None` if nothing matches
        """
        filters = list(whereclause)
        filters.extend(self._get_pk_filter(pk))

        if kwargs:
            filters.extend(parse_filters(self.model, **kwargs))

        stmt = select(self.model).where(*filters)

        if load_options:
            stmt = stmt.options(*load_options)

        if join_conditions:
            stmt = apply_join_conditions(self.model, stmt.select_from(self.model), join_conditions)

        if load_strategies:
            rel_options = build_load_strategies(self.model, load_strategies)
            if rel_options:
                stmt = stmt.options(*rel_options)

        query = await session.execute(stmt)

        if join_conditions and has_join_fill_result(join_conditions):
            return query.first()

        return query.scalars().first()

    async def select_model_by_column(
        self,
        session: AsyncSession,
        *whereclause: ColumnExpressionArgument[bool],
        load_options: LoadOptions | None = None,
        load_strategies: LoadStrategies | None = None,
        join_conditions: JoinConditions | None = None,
        **kwargs: Any,
    ) -> Sequence[Row[tuple[Model, ...]] | None] | Model | None:
        """
        Query by column with optional relationship loading and joins.

        :param session: SQLAlchemy async session
        :param whereclause: Additional WHERE clauses
        :param load_options: SQLAlchemy loading options
        :param load_strategies: Relationship loading strategies
        :param join_conditions: JOIN conditions for relationships
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: The first matching model instance, or the first full result row when the join
            conditions request filled results; `None` if nothing matches
        """
        stmt = await self.select(
            *whereclause,
            load_options=load_options,
            load_strategies=load_strategies,
            join_conditions=join_conditions,
            **kwargs,
        )

        query = await session.execute(stmt)

        if join_conditions and has_join_fill_result(join_conditions):
            return query.first()

        return query.scalars().first()

    async def select(
        self,
        *whereclause: ColumnExpressionArgument[bool],
        load_options: LoadOptions | None = None,
        load_strategies: LoadStrategies | None = None,
        join_conditions: JoinConditions | None = None,
        **kwargs,
    ) -> Select:
        """
        Construct the SQLAlchemy selection.

        This coroutine only builds the statement; it does not touch the database.

        :param whereclause: WHERE clauses to apply to the query
        :param load_options: SQLAlchemy loading options
        :param load_strategies: Relationship loading strategies
        :param join_conditions: JOIN conditions for relationships
        :param kwargs: Query expressions
        :return: The unexecuted SELECT statement
        """
        filters = list(whereclause)
        filters.extend(parse_filters(self.model, **kwargs))
        stmt = select(self.model).where(*filters)

        if join_conditions:
            stmt = apply_join_conditions(self.model, stmt.select_from(self.model), join_conditions)

        if load_options:
            stmt = stmt.options(*load_options)

        if load_strategies:
            rel_options = build_load_strategies(self.model, load_strategies)
            if rel_options:
                stmt = stmt.options(*rel_options)

        return stmt

    async def select_order(
        self,
        sort_columns: SortColumns,
        sort_orders: SortOrders = None,
        *whereclause: ColumnExpressionArgument[bool],
        load_options: LoadOptions | None = None,
        load_strategies: LoadStrategies | None = None,
        join_conditions: JoinConditions | None = None,
        **kwargs: Any,
    ) -> Select:
        """
        Construct SQLAlchemy selection with sorting.

        :param sort_columns: Column names to sort by
        :param sort_orders: Sort orders ('asc' or 'desc')
        :param whereclause: WHERE clauses to apply to the query
        :param load_options: SQLAlchemy loading options
        :param load_strategies: Relationship loading strategies
        :param join_conditions: JOIN conditions for relationships
        :param kwargs: Query expressions
        :return: The unexecuted, sorted SELECT statement
        """
        stmt = await self.select(
            *whereclause,
            load_options=load_options,
            load_strategies=load_strategies,
            join_conditions=join_conditions,
            **kwargs,
        )
        return apply_sorting(self.model, stmt, sort_columns, sort_orders)

    async def select_models(
        self,
        session: AsyncSession,
        *whereclause: ColumnExpressionArgument[bool],
        load_options: LoadOptions | None = None,
        load_strategies: LoadStrategies | None = None,
        join_conditions: JoinConditions | None = None,
        limit: int | None = None,
        offset: int | None = None,
        **kwargs: Any,
    ) -> Sequence[Row[tuple[Model, ...] | Any] | Model]:
        """
        Query all rows that match the specified filters with optional relationship loading and joins.

        :param session: SQLAlchemy async session
        :param whereclause: Additional WHERE clauses
        :param load_options: SQLAlchemy loading options
        :param load_strategies: Relationship loading strategies
        :param join_conditions: JOIN conditions for relationships
        :param limit: Maximum number of results to return
        :param offset: Number of results to skip
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: Matching model instances, or full result rows when the join conditions request
            filled results
        """
        stmt = await self.select(
            *whereclause,
            load_options=load_options,
            load_strategies=load_strategies,
            join_conditions=join_conditions,
            **kwargs,
        )

        if limit is not None:
            stmt = stmt.limit(limit)
        if offset is not None:
            stmt = stmt.offset(offset)

        query = await session.execute(stmt)

        if join_conditions and has_join_fill_result(join_conditions):
            return query.all()

        return query.scalars().all()

    async def select_models_order(
        self,
        session: AsyncSession,
        sort_columns: SortColumns,
        sort_orders: SortOrders = None,
        *whereclause: ColumnExpressionArgument[bool],
        load_options: LoadOptions | None = None,
        load_strategies: LoadStrategies | None = None,
        join_conditions: JoinConditions | None = None,
        limit: int | None = None,
        offset: int | None = None,
        **kwargs: Any,
    ) -> Sequence[Row[tuple[Model, ...] | Any] | Model]:
        """
        Query all rows that match the specified filters and sort by columns
        with optional relationship loading and joins.

        :param session: SQLAlchemy async session
        :param sort_columns: Column names to sort by
        :param sort_orders: Sort orders ('asc' or 'desc')
        :param whereclause: Additional WHERE clauses
        :param load_options: SQLAlchemy loading options
        :param load_strategies: Relationship loading strategies
        :param join_conditions: JOIN conditions for relationships
        :param limit: Maximum number of results to return
        :param offset: Number of results to skip
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: Matching model instances, or full result rows when the join conditions request
            filled results
        """
        stmt = await self.select_order(
            sort_columns,
            sort_orders,
            *whereclause,
            load_options=load_options,
            load_strategies=load_strategies,
            join_conditions=join_conditions,
            **kwargs,
        )

        if limit is not None:
            stmt = stmt.limit(limit)
        if offset is not None:
            stmt = stmt.offset(offset)

        query = await session.execute(stmt)

        if join_conditions and has_join_fill_result(join_conditions):
            return query.all()

        return query.scalars().all()

    async def update_model(
        self,
        session: AsyncSession,
        pk: Any | Sequence[Any],
        obj: UpdateSchema | dict[str, Any],
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> int:
        """
        Update an instance by model's primary key.

        :param session: The SQLAlchemy async session.
        :param pk: Single value for simple primary key, or tuple for composite primary key.
        :param obj: A pydantic schema or dictionary containing the update data
        :param flush: If `True`, flush all object changes to the database. Default is `False`.
        :param commit: If `True`, commits the transaction immediately. Default is `False`.
        :param kwargs: Additional model data not included in the pydantic schema.
        :return: The number of rows matched by the UPDATE
        """
        filters = self._get_pk_filter(pk)
        # Copy the dict so merging ``kwargs`` never mutates a dict owned by the caller.
        data = dict(obj) if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
        data.update(kwargs)
        stmt = update(self.model).where(*filters).values(**data)
        result = cast(CursorResult[Any], await session.execute(stmt))

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return result.rowcount

    async def update_model_by_column(
        self,
        session: AsyncSession,
        obj: UpdateSchema | dict[str, Any],
        allow_multiple: bool = False,
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> int:
        """
        Update records by model column filters.

        :param session: The SQLAlchemy async session
        :param obj: A Pydantic schema or dictionary containing the update data
        :param allow_multiple: If `True`, allows updating multiple records that match the filters
        :param flush: If `True`, flush all object changes to the database
        :param commit: If `True`, commits the transaction immediately
        :param kwargs: Filter expressions using field__operator=value syntax
        :raises ValueError: If no filter is given (guards against updating the whole table)
        :raises MultipleResultsError: If more than one row matches and ``allow_multiple`` is `False`
        :return: The number of rows matched by the UPDATE
        """
        filters = parse_filters(self.model, **kwargs)

        if not filters:
            raise ValueError('At least one filter condition must be provided for update operation')

        if not allow_multiple:
            total_count = await self.count(session, *filters)
            if total_count > 1:
                raise MultipleResultsError(f'Only one record is expected to be updated, found {total_count} records.')

        data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
        stmt = update(self.model).where(*filters).values(**data)
        result = cast(CursorResult[Any], await session.execute(stmt))

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return result.rowcount

    async def bulk_update_models(
        self,
        session: AsyncSession,
        objs: list[UpdateSchema | dict[str, Any]],
        pk_mode: bool = True,
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> int:
        """
        Bulk update multiple instances with different data for each record.

        :param session: The SQLAlchemy async session
        :param objs: A list of Pydantic schemas or dicts holding the data to save
        :param pk_mode: Primary key mode; when enabled, each item must contain the primary key data
        :param flush: If `True`, flush all object changes to the database
        :param commit: If `True`, commits the transaction immediately
        :param kwargs: Filter expressions using field__operator=value syntax; only used when
            ``pk_mode`` is `False`
        :raises ValueError: If ``pk_mode`` is `False` and no filter is given
        :return: The number of items submitted for update
        """
        if not pk_mode:
            filters = parse_filters(self.model, **kwargs)

            if not filters:
                raise ValueError('At least one filter condition must be provided for update operation')

        datas = [obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) for obj in objs]

        # Nothing to update for an empty input; skip issuing a statement rather than handing
        # SQLAlchemy an empty parameter list.
        if datas:
            if pk_mode:
                # ORM bulk UPDATE: each dict identifies its row through the primary key value(s).
                await session.execute(update(self.model), datas)
            else:
                # The same filtered UPDATE executed once per parameter set.
                conn = await session.connection()
                await conn.execute(update(self.model).where(*filters), datas)

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return len(datas)

    async def delete_model(
        self,
        session: AsyncSession,
        pk: Any | Sequence[Any],
        flush: bool = False,
        commit: bool = False,
    ) -> int:
        """
        Delete an instance by model's primary key.

        :param session: The SQLAlchemy async session.
        :param pk: Single value for simple primary key, or tuple for composite primary key.
        :param flush: If `True`, flush all object changes to the database. Default is `False`.
        :param commit: If `True`, commits the transaction immediately. Default is `False`.
        :return: The number of rows matched by the DELETE
        """
        filters = self._get_pk_filter(pk)

        stmt = delete(self.model).where(*filters)
        result = cast(CursorResult[Any], await session.execute(stmt))

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return result.rowcount

    async def delete_model_by_column(
        self,
        session: AsyncSession,
        allow_multiple: bool = False,
        logical_deletion: bool = False,
        deleted_flag_column: str = 'is_deleted',
        deleted_at_column: str = 'deleted_at',
        deleted_at_factory: datetime | None = None,
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> int:
        """
        Delete records by model column filters.

        :param session: The SQLAlchemy async session
        :param allow_multiple: If `True`, allows deleting multiple records that match the filters
        :param logical_deletion: If `True`, mark rows as deleted (UPDATE) instead of removing them
        :param deleted_flag_column: Column name for the logical deletion flag
        :param deleted_at_column: Column name for the deletion time; it is only written when the
            model actually has this column
        :param deleted_at_factory: Timestamp written to ``deleted_at_column`` during logical
            deletion; when `None` (default), the current UTC time at call time is used
        :param flush: If `True`, flush all object changes to the database
        :param commit: If `True`, commits the transaction immediately
        :param kwargs: Filter expressions using field__operator=value syntax
        :raises ModelColumnError: If logical deletion is requested but the flag column is missing
        :raises ValueError: If no filter is given (guards against deleting the whole table)
        :raises MultipleResultsError: If more than one row matches and ``allow_multiple`` is `False`
        :return: The number of rows matched by the DELETE/UPDATE
        """
        if logical_deletion and deleted_flag_column not in self.model_column_names:
            raise ModelColumnError(f'Column {deleted_flag_column} is not found in {self.model}')

        filters = parse_filters(self.model, **kwargs)

        if not filters:
            raise ValueError('At least one filter condition must be provided for delete operation')

        if not allow_multiple:
            total_count = await self.count(session, *filters)
            if total_count > 1:
                raise MultipleResultsError(f'Only one record is expected to be deleted, found {total_count} records.')

        if logical_deletion:
            data: dict[str, Any] = {deleted_flag_column: True}
            if deleted_at_column in self.model_column_names:
                # Resolve the timestamp per call: a default evaluated at function-definition time
                # would stamp every deletion with the module import time.
                data[deleted_at_column] = (
                    deleted_at_factory if deleted_at_factory is not None else datetime.now(timezone.utc)
                )
            stmt = update(self.model).where(*filters).values(**data)
        else:
            stmt = delete(self.model).where(*filters)

        result = cast(CursorResult[Any], await session.execute(stmt))

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return result.rowcount