Coverage for src/pullapprove/config.py: 79%

314 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2026-03-16 10:17 -0500

1from __future__ import annotations 

2 

3import re 

4import tomllib 

5import warnings 

6 

7with warnings.catch_warnings(): 

8 warnings.simplefilter("ignore", DeprecationWarning) 

9 import sre_parse 

10from collections.abc import Generator 

11from enum import Enum 

12from pathlib import Path 

13from typing import Any 

14 

15from pydantic import ( 

16 BaseModel, 

17 ConfigDict, 

18 Field, 

19 RootModel, 

20 field_validator, 

21 model_validator, 

22) 

23from wcmatch import glob 

24 

25from .checklists import Checklist 

26 

27CONFIG_FILENAME_PREFIX = "CODEREVIEW" 

28 

29_REPEAT_OPS = {sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT} 

30 

31 

32def _has_nested_quantifiers(data: Any) -> bool: 

33 """Detect patterns like (a+)+ that cause catastrophic backtracking.""" 

34 for op, av in data: 

35 if op in _REPEAT_OPS: 

36 if _contains_quantifier(av[2]): 

37 return True 

38 elif op == sre_parse.SUBPATTERN: 

39 if _has_nested_quantifiers(av[-1]): 

40 return True 

41 elif op == sre_parse.BRANCH: 

42 if any(_has_nested_quantifiers(branch) for branch in av[1]): 

43 return True 

44 return False 

45 

46 

47def _contains_quantifier(data: Any) -> bool: 

48 for op, av in data: 

49 if op in _REPEAT_OPS: 

50 return True 

51 elif op == sre_parse.SUBPATTERN: 

52 if _contains_quantifier(av[-1]): 

53 return True 

54 elif op == sre_parse.BRANCH: 

55 if any(_contains_quantifier(branch) for branch in av[1]): 

56 return True 

57 return False 

58 

59 

60CONFIG_FILENAME = "CODEREVIEW.toml" 

61 

62 

63def _expand_aliases( 

64 values: list[str], 

65 aliases: dict[str, list[str]], 

66 _seen: set[str] | None = None, 

67 _path: list[str] | None = None, 

68) -> list[str]: 

69 """Replace alias references in a list with their mapped values recursively.""" 

70 if _seen is None: 

71 _seen = set() 

72 if _path is None: 

73 _path = [] 

74 

75 expanded: list[str] = [] 

76 for value in values: 

77 # Support negated aliases like "!$team" -> ["!alice", "!bob"] 

78 if value.startswith("!$"): 

79 prefix = "!" 

80 alias_ref = value[2:] 

81 elif value.startswith("$"): 

82 prefix = "" 

83 alias_ref = value[1:] 

84 else: 

85 expanded.append(value) 

86 continue 

87 

88 if alias_ref in _seen: 

89 # Cycle detected, raise an error with the cycle path 

90 cycle_path = _path[_path.index(alias_ref) :] + [alias_ref] 

91 raise ValueError( 

92 f"Circular reference detected in aliases: {' -> '.join(cycle_path)}" 

93 ) 

94 if alias_ref in aliases: 

95 _seen.add(alias_ref) 

96 _path.append(alias_ref) 

97 # Recursively expand the alias values 

98 nested_expanded = _expand_aliases(aliases[alias_ref], aliases, _seen, _path) 

99 if prefix: 

100 expanded.extend(prefix + v for v in nested_expanded) 

101 else: 

102 expanded.extend(nested_expanded) 

103 _path.pop() 

104 _seen.remove(alias_ref) 

105 

106 # Remove duplicates while preserving order 

107 return list(dict.fromkeys(expanded)) 

108 

109 

class ReviewedForChoices(str, Enum):
    """Allowed values for a scope's ``reviewed_for`` setting.

    ``EMPTY`` ("") is the default and means the setting was not given.
    """

    EMPTY = ""
    REQUIRED = "required"
    IGNORED = "ignored"


class OwnershipChoices(str, Enum):
    """Allowed values for a scope's ``ownership`` setting (how scopes combine).

    ``EMPTY`` ("") is the default and means the setting was not given.
    """

    EMPTY = ""
    APPEND = "append"
    GLOBAL = "global"

120 

121 

122class ScopeModel(BaseModel): 

123 model_config = ConfigDict(extra="forbid") 

124 

125 # Required fields 

126 name: str = Field(min_length=1) 

127 paths: list[str] = Field(min_length=1) 

128 

129 # Optional fields 

130 

131 # Expanded version of lines could be dict 

132 # with fnmatch, regex, exclude patterns, etc? 

133 code: list[str] = [] 

134 

135 # This only filtering field that can't be used with raw diff/files... 

136 # If we get into that, the others are: 

137 # - labels 

138 # - ref (have branches at the root level...) 

139 # - statuses 

140 # - dates 

141 # - body 

142 # - title 

143 # - other scopes 

144 # (this is how I ended up with expressions... 

145 # I'm not trying to build a general purpose workflow tool, 

146 # but I do need to support the legit use cases and AI/bot review is one, so is team hierarchy) 

147 authors: list[str] = [] 

148 

149 # (defaults should be the "empty" values) 

150 description: str = "" 

151 reviewers: list[str] = [] 

152 alternates: list[str] = [] 

153 cc: list[str] = [] 

154 

155 # Review scoring 

156 require: int = 0 

157 reviewed_for: ReviewedForChoices = ReviewedForChoices.EMPTY 

158 author_value: int = 0 

159 

160 # How scopes are combined 

161 ownership: OwnershipChoices = OwnershipChoices.EMPTY 

162 

163 # Actionable items 

164 request: int = 0 

165 labels: list[str] = [] 

166 instructions: str = "" 

167 

168 # Approval checklist 

169 checklist: Checklist | None = None 

170 

171 @field_validator("name", mode="after") 

172 @classmethod 

173 def validate_name(cls, name: str) -> str: 

174 if "," in name: 

175 raise ValueError("Scope name cannot contain commas") 

176 return name 

177 

178 @field_validator("code", mode="after") 

179 @classmethod 

180 def validate_code_patterns(cls, code: list[str]) -> list[str]: 

181 for pattern in code: 

182 try: 

183 parsed = sre_parse.parse(pattern) 

184 except re.error as e: 

185 raise ValueError(f"Invalid regex pattern '{pattern}': {e}") from None 

186 if _has_nested_quantifiers(parsed): 

187 raise ValueError( 

188 f"Regex pattern '{pattern}' contains nested quantifiers, " 

189 "which can cause catastrophic backtracking." 

190 ) 

191 return code 

192 

193 @model_validator(mode="after") 

194 def validate_reviewers_for_require(self) -> ScopeModel: 

195 all_reviewers = self.reviewers + self.alternates 

196 

197 # Skip if wildcard - anyone can review 

198 if "*" in all_reviewers: 

199 return self 

200 

201 # Skip if aliases not yet expanded (will validate again after compilation) 

202 if any(r.startswith("$") for r in all_reviewers): 

203 return self 

204 

205 if len(all_reviewers) < self.require: 

206 raise ValueError( 

207 f"has require={self.require} but only {len(all_reviewers)} reviewers/alternates specified" 

208 ) 

209 return self 

210 

211 @model_validator(mode="after") 

212 def validate_checklist_reviewed_for(self) -> ScopeModel: 

213 if self.checklist and self.reviewed_for == ReviewedForChoices.REQUIRED: 

214 raise ValueError( 

215 "checklist and reviewed_for='required' cannot be used together. " 

216 "The checklist already requires explicit scope acknowledgment." 

217 ) 

218 return self 

219 

220 def printed_name(self) -> str: 

221 match self.ownership: 

222 case OwnershipChoices.APPEND: 

223 return "+" + self.name 

224 case OwnershipChoices.GLOBAL: 

225 return "*" + self.name 

226 

227 return self.name 

228 

229 def __eq__(self, other: Any) -> bool: 

230 return self.name == other.name 

231 

232 def matches_path(self, path: Path) -> bool: 

233 # TODO paths shouldn't start with / 

234 return glob.globmatch( 

235 path, 

236 self.paths, 

237 flags=glob.GLOBSTAR 

238 | glob.BRACE 

239 | glob.NEGATE 

240 | glob.IGNORECASE 

241 | glob.DOTGLOB, 

242 ) 

243 

244 def matches_code(self, code: str) -> Generator[dict[str, int]]: 

245 patterns = getattr(self, "_code_regex_patterns", []) 

246 if not patterns: 

247 patterns = [re.compile(pattern, re.MULTILINE) for pattern in self.code] 

248 self._code_regex_patterns = patterns 

249 

250 for pattern in patterns: 

251 for match in pattern.finditer(code): 

252 start_index = match.start() 

253 end_index = match.end() 

254 

255 start_line = code.count("\n", 0, start_index) + 1 

256 start_col = start_index - code.rfind("\n", 0, start_index) 

257 

258 end_line = code.count("\n", 0, end_index) + 1 

259 end_col = end_index - code.rfind("\n", 0, end_index) 

260 

261 yield { 

262 "start_line": start_line, 

263 "start_col": start_col, 

264 "end_line": end_line, 

265 "end_col": end_col, 

266 } 

267 

268 def matches_author(self, author_username: str) -> bool: 

269 if not self.authors: 

270 # No authors specified, so assume it matches 

271 return True 

272 

273 author_username_lower = author_username.lower() 

274 

275 negated_authors = [a[1:].lower() for a in self.authors if a.startswith("!")] 

276 authors = [a.lower() for a in self.authors if not a.startswith("!")] 

277 

278 if author_username_lower in negated_authors: 

279 # If the author is in the negated list, return False 

280 return False 

281 

282 if not authors: 

283 # Negation-only: everyone not negated matches 

284 return True 

285 

286 if author_username_lower in authors: 

287 # If the author is in the authors list, return True 

288 return True 

289 

290 return False 

291 

292 def enabled_for_pullrequest(self, author_username: str) -> bool: 

293 # Paths/code are matched during diff parsing, 

294 # but we also consider authors in the context of a pull request so do that here. 

295 return self.matches_author(author_username) 

296 

297 

# Review requirements for a large-scale change (LSC). Similar to a scope, but
# configured by hand. An LSC only applies to diffs, not raw files, because
# we have to know what *changed*.
class LargeScaleChangeModel(BaseModel):
    model_config = ConfigDict(extra="forbid")

    # At least one approval is required by default. Per the original note,
    # an LSC pull request therefore errors until an LSC config with reviewers
    # is added.
    require: int = 1
    # Intended to be non-empty, but a Field(min_length=1) constraint is not
    # enforced here.
    reviewers: list[str] = []
    # Possible future thresholds for what counts as "large":
    # min_paths: int = 300
    # min_lines: int = 3000
    labels: list[str] = []
    # TODO: an author value (as on scopes) may be needed here too.

312 

313 

# Validated contents of one CODEREVIEW*.toml file. Use compiled_config() to get
# the effective configuration with ``extends`` merged and aliases expanded.
class ConfigModel(BaseModel):
    model_config = ConfigDict(extra="forbid")

    # Nothing is technically required
    extends: list[str] = []
    template: bool = False
    branches: list[str] = []
    aliases: dict[str, list[str]] = {}
    large_scale_change: LargeScaleChangeModel | None = None
    scopes: list[ScopeModel] = []

    @field_validator("scopes", mode="after")
    @classmethod
    def validate_unique_scope_names(cls, scopes: list[ScopeModel]) -> list[ScopeModel]:
        """Reject scopes whose names collide case-insensitively."""
        seen: set[str] = set()
        for scope in scopes:
            if scope.name.lower() in seen:
                raise ValueError(f"Duplicate scope name: {scope.name}")
            seen.add(scope.name.lower())

        return scopes

    # "after" mode: pydantic first validates the value as list[str], so a bare
    # string or non-string items get a proper ValidationError instead of being
    # iterated character by character (or raising a raw TypeError).
    @field_validator("extends", mode="after")
    @classmethod
    def validate_extends(cls, extends: list[str]) -> list[str]:
        """Require every ``extends`` entry's file name to start with the config prefix."""
        for path in extends:
            if not Path(path).name.startswith(CONFIG_FILENAME_PREFIX):
                raise ValueError(
                    f"Invalid extends path: {path}. It should start with '{CONFIG_FILENAME_PREFIX}'."
                )
        return extends

    def compiled_config(
        self,
        config_path: Path,
        other_configs: ConfigModels,
        _extends_chain: tuple[str, ...] = (),
    ) -> ConfigModel:
        """
        Merge extends, replace aliases.

        For each ``extends`` entry (looked up in *other_configs*), its scopes
        and branches are prepended and its aliases and large_scale_change are
        used as fallbacks. ``$alias`` references are then expanded in the root
        ``extends``/``branches``, in each scope's list fields and in the
        large-scale-change reviewers/labels. The result is cached on the
        instance.

        ``_extends_chain`` is internal recursion state: paths of the configs
        currently being compiled, outermost first.

        Raises:
            ValueError: if an ``extends`` target is missing or ``extends``
                forms a cycle.
        """

        if getattr(self, "_compiled_config", None) is not None:
            return self._compiled_config

        chain = (*_extends_chain, str(config_path))

        # Create a copy of the data from what we have currently
        compiled_data = self.model_dump()

        for extend_path in self.extends:
            if extend_path not in other_configs:
                raise ValueError(f"Config {extend_path} not found")
            if extend_path in chain:
                # Without this check a cycle recurses until RecursionError.
                cycle = chain[chain.index(extend_path) :] + (extend_path,)
                raise ValueError(
                    f"Circular reference detected in extends: {' -> '.join(cycle)}"
                )

            extended_config = other_configs[extend_path]
            # NOTE(review): the compiled result is discarded and the *raw*
            # extended config is merged below, so anything the extended config
            # itself inherits through its own ``extends`` is not carried over.
            # Confirm multi-level extends is meant to be unsupported.
            extended_config.compiled_config(Path(extend_path), other_configs, chain)

            extended_config_dumped = extended_config.model_dump(
                include={"branches", "aliases", "scopes", "large_scale_change"}
            )

            compiled_data["scopes"] = (
                extended_config_dumped["scopes"] + compiled_data["scopes"]
            )
            compiled_data["large_scale_change"] = (
                compiled_data["large_scale_change"]
                or extended_config_dumped["large_scale_change"]
            )
            # Aliases defined by this config win over inherited ones.
            compiled_data["aliases"] = (
                extended_config_dumped["aliases"] | compiled_data["aliases"]
            )
            compiled_data["branches"] = (
                extended_config_dumped["branches"] + compiled_data["branches"]
            )

        # Root aliases
        for field in ["extends", "branches"]:
            if field in compiled_data:
                compiled_data[field] = _expand_aliases(
                    compiled_data[field], compiled_data["aliases"]
                )

        # Expand aliases for any aliasable list fields
        for scope in compiled_data["scopes"]:
            for field in [
                "paths",
                "code",
                "authors",
                "reviewers",
                "alternates",
                "cc",
                "labels",
            ]:
                if field in scope:
                    scope[field] = _expand_aliases(
                        scope[field], compiled_data["aliases"]
                    )

        if large_scale_change := compiled_data.get("large_scale_change"):
            for field in ["reviewers", "labels"]:
                large_scale_change[field] = _expand_aliases(
                    large_scale_change[field],
                    compiled_data["aliases"],
                )

        # Create a new config from the merged data (re-runs all validators)
        self._compiled_config = ConfigModel.from_data(
            data=compiled_data,
            path=config_path,
        )

        return self._compiled_config

    @classmethod
    def from_filesystem(cls, path: Path | str) -> ConfigModel:
        """Load and validate a config from a TOML file on disk."""
        with open(path, "rb") as f:
            return cls.from_data(tomllib.load(f), path)

    @classmethod
    def from_content(cls, content: str, path: Path | str) -> ConfigModel:
        """Load and validate a config from a TOML string."""
        return cls.from_data(tomllib.loads(content), path)

    @classmethod
    def from_data(cls, data: dict[str, Any], path: Path | str) -> ConfigModel:
        """Validate an already-parsed config dict.

        *path* is accepted for interface consistency but is currently unused.
        """
        return cls(**data)

    def matches_branches(self, base_branch: str, head_branch: str) -> bool:
        """Return True if any ``branches`` pattern matches this base/head pair.

        Patterns are ``base``, ``base..head`` or ``base...head`` globs; an
        omitted or empty side matches any branch. No patterns matches all.
        """
        if not self.branches:
            # No branches specified, so assume it matches
            return True

        for pattern in self.branches:
            splitter = "..." if "..." in pattern else ".."
            parts = pattern.split(splitter)
            base_pattern = parts[0]
            head_pattern = parts[1] if len(parts) > 1 else None

            base_match = (
                glob.globmatch(base_branch, base_pattern) if base_pattern else True
            )
            head_match = (
                glob.globmatch(head_branch, head_pattern) if head_pattern else True
            )

            if base_match and head_match:
                return True

        return False

    def enabled_for_pullrequest(self, base_branch: str, head_branch: str) -> bool:
        """Return True if this config applies to a pull request between these branches."""
        return self.matches_branches(base_branch, head_branch)

    # TODO: dumping back to TOML (ideally preserving the original comments)
    # would need a TOML writer; tomllib is read-only.

488 

489 

# All CODEREVIEW*.toml configs of a repository, keyed by ``str(path)``
# (for example "CODEREVIEW.toml" for the top-level config).
class ConfigModels(RootModel):
    root: dict[str, ConfigModel]

    @classmethod
    def from_configs_data(cls, data: dict[str, Any]) -> ConfigModels:
        """Load configs from a mapping of config path -> raw config dict."""
        configs = cls(root={})

        for path, config_data in data.items():
            config = ConfigModel.from_data(config_data, Path(path))
            configs.add_config(config, Path(path))

        return configs

    @classmethod
    def from_config_models(cls, models: dict[str, ConfigModel]) -> ConfigModels:
        """Load configs from a mapping of config path -> ConfigModel."""
        configs = cls(root={})

        for path, config_model in models.items():
            configs.add_config(config_model, Path(path))

        return configs

    def get_config_models(self) -> dict[str, ConfigModel]:
        """Return a shallow copy of the path -> ConfigModel mapping."""
        return dict(self.root.items())

    def add_config(self, config: ConfigModel, path: Path) -> None:
        """Store *config* under ``str(path)``, replacing any existing entry."""
        self.root[str(path)] = config

    def get_default_large_scale_change(self) -> LargeScaleChangeModel:
        """Return the large-scale-change settings of the top-level config.

        Reads ``large_scale_change`` from the ``CODEREVIEW.toml`` entry as
        written (``extends`` is not applied). Falls back to a default
        ``LargeScaleChangeModel()`` when that config is absent or does not
        define one.
        """
        primary_config = CONFIG_FILENAME

        if primary_config in self.root:
            if lsc := self.root[primary_config].large_scale_change:
                return lsc

        return LargeScaleChangeModel()

    def __bool__(self) -> bool:
        return bool(self.root)

    def __getitem__(self, key: str) -> ConfigModel:
        return self.root[key]

    def __contains__(self, key: str) -> bool:
        return key in self.root

    def __len__(self) -> int:
        return len(self.root)

    def compile_closest_config(self, file_path: Path) -> ConfigModel:
        """Compile the nearest non-template config above *file_path*.

        Walks ``file_path.parents`` from the innermost directory outwards and
        compiles the first ``CODEREVIEW.toml`` found that is not a template.

        Raises:
            ValueError: if no such config exists.
        """
        for parent in file_path.parents:
            parent_config_path = str(parent / CONFIG_FILENAME)

            if parent_config_path in self.root:
                config = self.root[parent_config_path]

                if config.template:
                    # Skip templates
                    continue

                compiled = config.compiled_config(Path(parent_config_path), self)

                return compiled

        raise ValueError(f"No config found for {file_path}")

    def iter_compiled_configs(self) -> Generator[tuple[str, ConfigModel]]:
        """Yield ``(path, compiled config)`` for every non-template config."""
        for config_path, config in self.root.items():
            if config.template:
                # Skip templates
                continue

            yield config_path, config.compiled_config(Path(config_path), self)

    def num_scopes(self) -> int:
        """
        Count the total number of scopes across all configs.

        Counts scopes as written in each config (templates included, extends
        not applied).
        """
        return sum(len(config.scopes) for config in self.root.values())

    def num_reviewers(self) -> int:
        """
        Count the total number of reviewers across all configs.

        Counts ``reviewers`` entries as written: ``$alias`` references count as
        one entry and duplicates across scopes are not merged.
        """
        return sum(
            len(scope.reviewers)
            for config in self.root.values()
            for scope in config.scopes
        )

    def filter_for_pullrequest(
        self,
        base_branch: str,
        head_branch: str,
        author_username: str,
    ) -> ConfigModels:
        """
        Look at all configs (including templates) and filter out
        configs and scopes based on branches, authors, etc.

        Decisions are made on each config's compiled form, but the returned
        collection is rebuilt from the raw (uncompiled) data, with disabled
        scopes removed by name.

        NOTE(review): a template whose own compiled branches don't match is
        dropped here, yet a config extending it may still match (its branches
        include its own). Compiling that config from the result would then
        raise "Config ... not found". Similarly, an inherited scope is only
        removed if it is disabled in the config that defines it. Confirm
        this is intended.
        """
        filtered_configs = {}

        for config_path, config in self.root.items():
            compiled_config = config.compiled_config(Path(config_path), self)

            if not compiled_config.enabled_for_pullrequest(base_branch, head_branch):
                # Remove the config from the list
                continue

            # Collect the names of scopes to remove
            scopes_to_remove = set()
            for scope in compiled_config.scopes:
                if not scope.enabled_for_pullrequest(author_username):
                    scopes_to_remove.add(scope.name)

            # Create a filtered copy of the config data
            filtered_config_data = config.model_dump()

            # Filter scopes in a simple, declarative way
            if "scopes" in filtered_config_data:
                filtered_config_data["scopes"] = [
                    scope
                    for scope in filtered_config_data["scopes"]
                    if scope["name"] not in scopes_to_remove
                ]

            # Rebuild using the original/modified raw data
            filtered_configs[config_path] = filtered_config_data

        return ConfigModels.from_configs_data(filtered_configs)