Coverage for sqlalchemy_crud_plus\crud.py: 97%

218 statements  

« prev     ^ index     » next       coverage.py v7.11.1, created at 2025-11-26 16:47 +0800

1#!/usr/bin/env python3 

2# -*- coding: utf-8 -*- 

from datetime import datetime, timezone
from typing import Any, Callable, Generic, Sequence, cast

from sqlalchemy import (
    Column,
    ColumnExpressionArgument,
    CursorResult,
    Row,
    Select,
    delete,
    func,
    insert,
    inspect,
    select,
    update,
)
from sqlalchemy.ext.asyncio import AsyncSession

from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, ModelColumnError, MultipleResultsError
from sqlalchemy_crud_plus.types import (
    CreateSchema,
    JoinConditions,
    LoadOptions,
    LoadStrategies,
    Model,
    SortColumns,
    SortOrders,
    UpdateSchema,
)
from sqlalchemy_crud_plus.utils import (
    apply_join_conditions,
    apply_sorting,
    build_load_strategies,
    has_join_fill_result,
    parse_filters,
)

39 

40 

class CRUDPlus(Generic[Model]):
    """
    Generic asynchronous CRUD helper bound to a single SQLAlchemy ORM model.

    The instance only caches model metadata (table column keys and primary key column(s));
    every database operation runs on the ``AsyncSession`` passed to the method.
    """

    def __init__(self, model: type[Model]):
        self.model = model
        # Keys of the mapped table's columns; used to validate soft-delete column names.
        self.model_column_names = [column.key for column in model.__table__.columns]
        self.primary_key = self._get_primary_key()

    def _get_primary_key(self) -> Column | list[Column]:
        """
        Dynamically retrieve the primary key column(s) for the model.

        :return: The single primary key column, or a list of columns for a composite key
        """
        mapper = inspect(self.model)
        primary_key = mapper.primary_key
        if len(primary_key) == 1:
            return primary_key[0]
        return list(primary_key)

    def _get_pk_filter(self, pk: Any | list[Any]) -> list[ColumnExpressionArgument[bool]]:
        """
        Get the primary key filter(s).

        :param pk: Single value for a simple primary key, or a sequence (e.g. tuple) of values,
            ordered like the model's primary key columns, for a composite primary key
        :return: One equality condition per primary key column
        :raises CompositePrimaryKeysError: If the model has a composite primary key and ``pk`` is
            not a sequence with exactly one value per key column
        """
        if not isinstance(self.primary_key, list):
            return [self.primary_key == pk]

        # Strings/bytes are sequences too, but are never a valid list of key values.
        is_value_sequence = isinstance(pk, Sequence) and not isinstance(pk, (str, bytes))
        if not is_value_sequence or len(pk) != len(self.primary_key):
            raise CompositePrimaryKeysError(f'Expected {len(self.primary_key)} values for composite primary key')
        return [column == value for column, value in zip(self.primary_key, pk)]

    @staticmethod
    def _returns_rows(join_conditions: JoinConditions | None) -> bool:
        """
        Tell whether query results must be returned as ``Row`` objects.

        True only when join conditions are given and ``has_join_fill_result`` reports true for
        them; otherwise results are returned as scalar model instances.
        """
        return bool(join_conditions) and bool(has_join_fill_result(join_conditions))

    @staticmethod
    def _paginate(stmt: Select, limit: int | None, offset: int | None) -> Select:
        """Apply LIMIT / OFFSET to ``stmt``; a ``None`` value leaves that clause unset."""
        if limit is not None:
            stmt = stmt.limit(limit)
        if offset is not None:
            stmt = stmt.offset(offset)
        return stmt

    @staticmethod
    def _dump_update_data(obj: UpdateSchema | dict[str, Any]) -> dict[str, Any]:
        """
        Convert update input into a fresh dict of column values.

        Dicts are shallow-copied so later updates never mutate the caller's object; Pydantic
        schemas are dumped with ``exclude_unset=True`` so only explicitly set fields are written.
        """
        if isinstance(obj, dict):
            return dict(obj)
        return obj.model_dump(exclude_unset=True)

    async def create_model(
        self,
        session: AsyncSession,
        obj: CreateSchema,
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> Model:
        """
        Create a new instance of a model.

        :param session: The SQLAlchemy async session
        :param obj: The Pydantic schema containing data to be saved
        :param flush: If `True`, flush all object changes to the database
        :param commit: If `True`, commits the transaction immediately
        :param kwargs: Additional model data not included in the pydantic schema
        :return: The new (session-added) model instance
        """
        obj_data = obj.model_dump()
        if kwargs:
            # Extra keyword data overrides schema fields with the same name.
            obj_data.update(kwargs)

        ins = self.model(**obj_data)
        session.add(ins)

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return ins

    async def create_models(
        self,
        session: AsyncSession,
        objs: list[CreateSchema],
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> list[Model]:
        """
        Create new instances of a model.

        :param session: The SQLAlchemy async session
        :param objs: The Pydantic schema list containing data to be saved
        :param flush: If `True`, flush all object changes to the database
        :param commit: If `True`, commits the transaction immediately
        :param kwargs: Additional model data not included in the pydantic schema
        :return: The new (session-added) model instances, in input order
        """
        ins_list = []
        for obj in objs:
            obj_data = obj.model_dump()
            if kwargs:
                obj_data.update(kwargs)
            ins_list.append(self.model(**obj_data))

        session.add_all(ins_list)

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return ins_list

    async def bulk_create_models(
        self,
        session: AsyncSession,
        objs: list[dict[str, Any]],
        render_nulls: bool = False,
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> Sequence[Model]:
        """
        Create new instances of a model with a single bulk INSERT ... RETURNING.

        :param session: The SQLAlchemy async session
        :param objs: The dict list containing data to be saved; the dict data should be aligned with the
            model columns
        :param render_nulls: render null values instead of ignoring them
        :param flush: If `True`, flush all object changes to the database
        :param commit: If `True`, commits the transaction immediately
        :param kwargs: Additional model data not included in the dict
        :return: The inserted model instances (an empty list when ``objs`` is empty)
        """
        result = None
        # An empty parameter list is not a meaningful bulk insert; there is nothing to create,
        # so skip the statement instead of executing it without row data.
        if objs:
            stmt = (
                insert(self.model)
                .values(**kwargs)
                .execution_options(render_nulls=render_nulls)
                .returning(self.model)
            )
            result = await session.execute(stmt, objs)

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return result.scalars().all() if result is not None else []

    async def count(
        self,
        session: AsyncSession,
        *whereclause: ColumnExpressionArgument[bool],
        join_conditions: JoinConditions | None = None,
        **kwargs,
    ) -> int:
        """
        Count records that match specified filters.

        :param session: SQLAlchemy async session
        :param whereclause: Additional WHERE clauses
        :param join_conditions: JOIN conditions for relationships
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: The number of matching records (0 if the query yields NULL)
        """
        filters = list(whereclause)

        if kwargs:
            filters.extend(parse_filters(self.model, **kwargs))

        if isinstance(self.primary_key, list):
            # Composite keys cannot be passed to COUNT() as one column; count rows instead.
            stmt = select(func.count()).select_from(self.model)
        else:
            stmt = select(func.count(self.primary_key)).select_from(self.model)

        if filters:
            stmt = stmt.where(*filters)

        if join_conditions:
            stmt = apply_join_conditions(self.model, stmt.select_from(self.model), join_conditions)

        query = await session.execute(stmt)
        total_count = query.scalar()
        return total_count if total_count is not None else 0

    async def exists(
        self,
        session: AsyncSession,
        *whereclause: ColumnExpressionArgument[bool],
        join_conditions: JoinConditions | None = None,
        **kwargs,
    ) -> bool:
        """
        Check whether records that match the specified filters exist.

        :param session: SQLAlchemy async session
        :param whereclause: Additional WHERE clauses
        :param join_conditions: JOIN conditions for relationships
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: `True` if at least one matching record exists
        """
        filters = list(whereclause)

        if kwargs:
            filters.extend(parse_filters(self.model, **kwargs))

        stmt = select(self.model).where(*filters).limit(1)

        if join_conditions:
            stmt = apply_join_conditions(self.model, stmt.select_from(self.model), join_conditions)

        query = await session.execute(stmt)
        return query.scalars().first() is not None

    async def select_model(
        self,
        session: AsyncSession,
        pk: Any | Sequence[Any],
        *whereclause: ColumnExpressionArgument[bool],
        load_options: LoadOptions | None = None,
        load_strategies: LoadStrategies | None = None,
        join_conditions: JoinConditions | None = None,
        **kwargs: Any,
    ) -> Sequence[Row[tuple[Model, ...]] | None] | Model | None:
        """
        Query by primary key(s) with optional relationship loading and joins.

        :param session: SQLAlchemy async session
        :param pk: Primary key value(s) - single value or tuple for composite keys
        :param whereclause: Additional WHERE clauses
        :param load_options: SQLAlchemy loading options
        :param load_strategies: Relationship loading strategies
        :param join_conditions: JOIN conditions for relationships
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: The first matching ``Row`` when the join conditions request filled results,
            otherwise the first matching model instance; ``None`` if nothing matches
        :raises CompositePrimaryKeysError: If ``pk`` does not fit a composite primary key
        """
        # Same statement construction as `select`, with the primary key filters appended
        # after the explicit WHERE clauses (and before the keyword filters).
        stmt = await self.select(
            *whereclause,
            *self._get_pk_filter(pk),
            load_options=load_options,
            load_strategies=load_strategies,
            join_conditions=join_conditions,
            **kwargs,
        )

        query = await session.execute(stmt)

        if self._returns_rows(join_conditions):
            return query.first()
        return query.scalars().first()

    async def select_model_by_column(
        self,
        session: AsyncSession,
        *whereclause: ColumnExpressionArgument[bool],
        load_options: LoadOptions | None = None,
        load_strategies: LoadStrategies | None = None,
        join_conditions: JoinConditions | None = None,
        **kwargs: Any,
    ) -> Sequence[Row[tuple[Model, ...]] | None] | Model | None:
        """
        Query by column with optional relationship loading and joins.

        :param session: SQLAlchemy async session
        :param whereclause: Additional WHERE clauses
        :param load_options: SQLAlchemy loading options
        :param load_strategies: Relationship loading strategies
        :param join_conditions: JOIN conditions for relationships
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: The first matching ``Row`` when the join conditions request filled results,
            otherwise the first matching model instance; ``None`` if nothing matches
        """
        stmt = await self.select(
            *whereclause,
            load_options=load_options,
            load_strategies=load_strategies,
            join_conditions=join_conditions,
            **kwargs,
        )

        query = await session.execute(stmt)

        if self._returns_rows(join_conditions):
            return query.first()
        return query.scalars().first()

    async def select(
        self,
        *whereclause: ColumnExpressionArgument[bool],
        load_options: LoadOptions | None = None,
        load_strategies: LoadStrategies | None = None,
        join_conditions: JoinConditions | None = None,
        **kwargs,
    ) -> Select:
        """
        Construct the SQLAlchemy selection.

        :param whereclause: WHERE clauses to apply to the query
        :param load_options: SQLAlchemy loading options
        :param load_strategies: Relationship loading strategies
        :param join_conditions: JOIN conditions for relationships
        :param kwargs: Query expressions
        :return: The (unexecuted) SELECT statement
        """
        filters = list(whereclause)
        filters.extend(parse_filters(self.model, **kwargs))
        stmt = select(self.model).where(*filters)

        if join_conditions:
            stmt = apply_join_conditions(self.model, stmt.select_from(self.model), join_conditions)

        if load_options:
            stmt = stmt.options(*load_options)

        if load_strategies:
            rel_options = build_load_strategies(self.model, load_strategies)
            if rel_options:
                stmt = stmt.options(*rel_options)

        return stmt

    async def select_order(
        self,
        sort_columns: SortColumns,
        sort_orders: SortOrders = None,
        *whereclause: ColumnExpressionArgument[bool],
        load_options: LoadOptions | None = None,
        load_strategies: LoadStrategies | None = None,
        join_conditions: JoinConditions | None = None,
        **kwargs: Any,
    ) -> Select:
        """
        Construct SQLAlchemy selection with sorting.

        :param sort_columns: Column names to sort by
        :param sort_orders: Sort orders ('asc' or 'desc')
        :param whereclause: WHERE clauses to apply to the query
        :param load_options: SQLAlchemy loading options
        :param load_strategies: Relationship loading strategies
        :param join_conditions: JOIN conditions for relationships
        :param kwargs: Query expressions
        :return: The (unexecuted) sorted SELECT statement
        """
        stmt = await self.select(
            *whereclause,
            load_options=load_options,
            load_strategies=load_strategies,
            join_conditions=join_conditions,
            **kwargs,
        )
        return apply_sorting(self.model, stmt, sort_columns, sort_orders)

    async def select_models(
        self,
        session: AsyncSession,
        *whereclause: ColumnExpressionArgument[bool],
        load_options: LoadOptions | None = None,
        load_strategies: LoadStrategies | None = None,
        join_conditions: JoinConditions | None = None,
        limit: int | None = None,
        offset: int | None = None,
        **kwargs: Any,
    ) -> Sequence[Row[tuple[Model, ...] | Any] | Model]:
        """
        Query all rows that match the specified filters with optional relationship loading and joins.

        :param session: SQLAlchemy async session
        :param whereclause: Additional WHERE clauses
        :param load_options: SQLAlchemy loading options
        :param load_strategies: Relationship loading strategies
        :param join_conditions: JOIN conditions for relationships
        :param limit: Maximum number of results to return
        :param offset: Number of results to skip
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: ``Row`` objects when the join conditions request filled results, otherwise
            model instances
        """
        stmt = await self.select(
            *whereclause,
            load_options=load_options,
            load_strategies=load_strategies,
            join_conditions=join_conditions,
            **kwargs,
        )
        stmt = self._paginate(stmt, limit, offset)

        query = await session.execute(stmt)

        if self._returns_rows(join_conditions):
            return query.all()
        return query.scalars().all()

    async def select_models_order(
        self,
        session: AsyncSession,
        sort_columns: SortColumns,
        sort_orders: SortOrders = None,
        *whereclause: ColumnExpressionArgument[bool],
        load_options: LoadOptions | None = None,
        load_strategies: LoadStrategies | None = None,
        join_conditions: JoinConditions | None = None,
        limit: int | None = None,
        offset: int | None = None,
        **kwargs: Any,
    ) -> Sequence[Row[tuple[Model, ...] | Any] | Model]:
        """
        Query all rows that match the specified filters and sort by columns
        with optional relationship loading and joins.

        :param session: SQLAlchemy async session
        :param sort_columns: Column names to sort by
        :param sort_orders: Sort orders ('asc' or 'desc')
        :param whereclause: Additional WHERE clauses
        :param load_options: SQLAlchemy loading options
        :param load_strategies: Relationship loading strategies
        :param join_conditions: JOIN conditions for relationships
        :param limit: Maximum number of results to return
        :param offset: Number of results to skip
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: ``Row`` objects when the join conditions request filled results, otherwise
            model instances
        """
        stmt = await self.select_order(
            sort_columns,
            sort_orders,
            *whereclause,
            load_options=load_options,
            load_strategies=load_strategies,
            join_conditions=join_conditions,
            **kwargs,
        )
        stmt = self._paginate(stmt, limit, offset)

        query = await session.execute(stmt)

        if self._returns_rows(join_conditions):
            return query.all()
        return query.scalars().all()

    async def update_model(
        self,
        session: AsyncSession,
        pk: Any | Sequence[Any],
        obj: UpdateSchema | dict[str, Any],
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> int:
        """
        Update an instance by model's primary key

        :param session: The SQLAlchemy async session.
        :param pk: Single value for simple primary key, or tuple for composite primary key.
        :param obj: A pydantic schema or dictionary containing the update data (a dict argument
            is not modified)
        :param flush: If `True`, flush all object changes to the database. Default is `False`.
        :param commit: If `True`, commits the transaction immediately. Default is `False`.
        :param kwargs: Additional model data not included in the pydantic schema.
        :return: The number of rows matched by the UPDATE
        """
        filters = self._get_pk_filter(pk)
        # Work on a copy: merging kwargs must not leak into the caller's dict.
        data = self._dump_update_data(obj)
        data.update(kwargs)
        stmt = update(self.model).where(*filters).values(**data)
        result = cast(CursorResult[Any], await session.execute(stmt))

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return result.rowcount

    async def update_model_by_column(
        self,
        session: AsyncSession,
        obj: UpdateSchema | dict[str, Any],
        allow_multiple: bool = False,
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> int:
        """
        Update records by model column filters.

        :param session: The SQLAlchemy async session
        :param obj: A Pydantic schema or dictionary containing the update data
        :param allow_multiple: If `True`, allows updating multiple records that match the filters
        :param flush: If `True`, flush all object changes to the database
        :param commit: If `True`, commits the transaction immediately
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: The number of rows matched by the UPDATE
        :raises ValueError: If no filter is given
        :raises MultipleResultsError: If more than one record matches and `allow_multiple` is `False`
        """
        filters = parse_filters(self.model, **kwargs)

        if not filters:
            raise ValueError('At least one filter condition must be provided for update operation')

        if not allow_multiple:
            total_count = await self.count(session, *filters)
            if total_count > 1:
                raise MultipleResultsError(f'Only one record is expected to be updated, found {total_count} records.')

        data = self._dump_update_data(obj)
        stmt = update(self.model).where(*filters).values(**data)
        result = cast(CursorResult[Any], await session.execute(stmt))

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return result.rowcount

    async def bulk_update_models(
        self,
        session: AsyncSession,
        objs: list[UpdateSchema | dict[str, Any]],
        pk_mode: bool = True,
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> int:
        """
        Bulk update multiple instances with different data for each record.

        :param session: The SQLAlchemy async session
        :param objs: To save a list of Pydantic schemas or dict for data
        :param pk_mode: Primary key mode, when enabled, the data must contain the primary key data
        :param flush: If `True`, flush all object changes to the database
        :param commit: If `True`, commits the transaction immediately
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: The number of parameter sets submitted (``len(objs)``)
        :raises ValueError: If `pk_mode` is `False` and no filter is given
        """
        filters = None
        if not pk_mode:
            filters = parse_filters(self.model, **kwargs)

            if not filters:
                raise ValueError('At least one filter condition must be provided for update operation')

        datas = [self._dump_update_data(obj) for obj in objs]

        # An empty parameter list is not a meaningful bulk update; skip the statement entirely.
        if datas:
            if pk_mode:
                # ORM bulk UPDATE by primary key: each dict carries its own primary key values.
                await session.execute(update(self.model), datas)
            else:
                # executemany on the session's connection, applying the same WHERE per parameter set.
                conn = await session.connection()
                await conn.execute(update(self.model).where(*filters), datas)

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return len(datas)

    async def delete_model(
        self,
        session: AsyncSession,
        pk: Any | Sequence[Any],
        flush: bool = False,
        commit: bool = False,
    ) -> int:
        """
        Delete an instance by model's primary key

        :param session: The SQLAlchemy async session.
        :param pk: Single value for simple primary key, or tuple for composite primary key.
        :param flush: If `True`, flush all object changes to the database. Default is `False`.
        :param commit: If `True`, commits the transaction immediately. Default is `False`.
        :return: The number of rows matched by the DELETE
        """
        filters = self._get_pk_filter(pk)

        stmt = delete(self.model).where(*filters)
        result = cast(CursorResult[Any], await session.execute(stmt))

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return result.rowcount

    async def delete_model_by_column(
        self,
        session: AsyncSession,
        allow_multiple: bool = False,
        logical_deletion: bool = False,
        deleted_flag_column: str = 'is_deleted',
        deleted_at_column: str = 'deleted_at',
        deleted_at_factory: datetime | Callable[[], datetime] | None = None,
        flush: bool = False,
        commit: bool = False,
        **kwargs,
    ) -> int:
        """
        Delete records by model column filters.

        :param session: The SQLAlchemy async session
        :param allow_multiple: If `True`, allows deleting multiple records that match the filters
        :param logical_deletion: If `True`, enable logical deletion instead of physical deletion
        :param deleted_flag_column: Column name for logical deletion flag
        :param deleted_at_column: Column name for delete time; only written if the model has it
        :param deleted_at_factory: Delete time value for logical deletion: a datetime, a zero-argument
            callable returning one, or `None` (default) for the current UTC time at call time
        :param flush: If `True`, flush all object changes to the database
        :param commit: If `True`, commits the transaction immediately
        :param kwargs: Filter expressions using field__operator=value syntax
        :return: The number of rows matched by the UPDATE (logical) or DELETE (physical)
        :raises ModelColumnError: If logical deletion is requested but the flag column is missing
        :raises ValueError: If no filter is given
        :raises MultipleResultsError: If more than one record matches and `allow_multiple` is `False`
        """
        if logical_deletion:
            if deleted_flag_column not in self.model_column_names:
                raise ModelColumnError(f'Column {deleted_flag_column} is not found in {self.model}')

        filters = parse_filters(self.model, **kwargs)

        if not filters:
            raise ValueError('At least one filter condition must be provided for delete operation')

        if not allow_multiple:
            total_count = await self.count(session, *filters)
            if total_count > 1:
                raise MultipleResultsError(f'Only one record is expected to be deleted, found {total_count} records.')

        if logical_deletion:
            data: dict[str, Any] = {deleted_flag_column: True}
            if deleted_at_column in self.model_column_names:
                # Resolve the timestamp per call; a default evaluated at definition time would
                # stamp every deletion with the module import time.
                if deleted_at_factory is None:
                    data[deleted_at_column] = datetime.now(timezone.utc)
                elif callable(deleted_at_factory):
                    data[deleted_at_column] = deleted_at_factory()
                else:
                    data[deleted_at_column] = deleted_at_factory
            stmt = update(self.model).where(*filters).values(**data)
        else:
            stmt = delete(self.model).where(*filters)

        result = cast(CursorResult[Any], await session.execute(stmt))

        if flush:
            await session.flush()
        if commit:
            await session.commit()

        return result.rowcount