Coverage for src/pullapprove/config.py: 88%

318 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2026-03-16 16:39 -0500

1from __future__ import annotations 

2 

3import re 

4import tomllib 

5import warnings 

6 

7with warnings.catch_warnings(): 

8 warnings.simplefilter("ignore", DeprecationWarning) 

9 import sre_parse 

10from collections.abc import Generator 

11from enum import Enum 

12from pathlib import Path 

13from typing import Any 

14 

15from pydantic import ( 

16 BaseModel, 

17 ConfigDict, 

18 Field, 

19 RootModel, 

20 field_validator, 

21 model_validator, 

22) 

23from wcmatch import glob 

24 

25from .checklists import Checklist 

26 

27CONFIG_FILENAME_PREFIX = "CODEREVIEW" 

28 

29_REPEAT_OPS = {sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT} 

30 

31 

32def _has_nested_quantifiers(data: Any) -> bool: 

33 """Detect patterns like (a+)+ that cause catastrophic backtracking.""" 

34 for op, av in data: 

35 if op in _REPEAT_OPS: 

36 if _contains_quantifier(av[2]): 

37 return True 

38 elif op == sre_parse.SUBPATTERN: 

39 if _has_nested_quantifiers(av[-1]): 

40 return True 

41 elif op == sre_parse.BRANCH: 

42 if any(_has_nested_quantifiers(branch) for branch in av[1]): 

43 return True 

44 return False 

45 

46 

47def _contains_quantifier(data: Any) -> bool: 

48 for op, av in data: 

49 if op in _REPEAT_OPS: 

50 return True 

51 elif op == sre_parse.SUBPATTERN: 

52 if _contains_quantifier(av[-1]): 

53 return True 

54 elif op == sre_parse.BRANCH: 

55 if any(_contains_quantifier(branch) for branch in av[1]): 

56 return True 

57 return False 

58 

59 

60CONFIG_FILENAME = "CODEREVIEW.toml" 

61 

62 

63def _matches_branches(branches: list[str], base_branch: str, head_branch: str) -> bool: 

64 """Check if the given base/head branches match any of the branch patterns.""" 

65 if not branches: 

66 return True 

67 

68 for pattern in branches: 

69 splitter = "..." if "..." in pattern else ".." 

70 parts = pattern.split(splitter, 1) 

71 

72 base_pattern = parts[0] 

73 head_pattern = parts[1] if len(parts) > 1 else None 

74 

75 base_match = glob.globmatch(base_branch, base_pattern) if base_pattern else True 

76 head_match = glob.globmatch(head_branch, head_pattern) if head_pattern else True 

77 

78 if base_match and head_match: 

79 return True 

80 

81 return False 

82 

83 

84def _expand_aliases( 

85 values: list[str], 

86 aliases: dict[str, list[str]], 

87 _seen: set[str] | None = None, 

88 _path: list[str] | None = None, 

89) -> list[str]: 

90 """Replace alias references in a list with their mapped values recursively.""" 

91 if _seen is None: 

92 _seen = set() 

93 if _path is None: 

94 _path = [] 

95 

96 expanded: list[str] = [] 

97 for value in values: 

98 # Support negated aliases like "!$team" -> ["!alice", "!bob"] 

99 if value.startswith("!$"): 

100 prefix = "!" 

101 alias_ref = value[2:] 

102 elif value.startswith("$"): 

103 prefix = "" 

104 alias_ref = value[1:] 

105 else: 

106 expanded.append(value) 

107 continue 

108 

109 if alias_ref in _seen: 

110 # Cycle detected, raise an error with the cycle path 

111 cycle_path = _path[_path.index(alias_ref) :] + [alias_ref] 

112 raise ValueError( 

113 f"Circular reference detected in aliases: {' -> '.join(cycle_path)}" 

114 ) 

115 if alias_ref in aliases: 

116 _seen.add(alias_ref) 

117 _path.append(alias_ref) 

118 # Recursively expand the alias values 

119 nested_expanded = _expand_aliases(aliases[alias_ref], aliases, _seen, _path) 

120 if prefix: 

121 expanded.extend(prefix + v for v in nested_expanded) 

122 else: 

123 expanded.extend(nested_expanded) 

124 _path.pop() 

125 _seen.remove(alias_ref) 

126 

127 # Remove duplicates while preserving order 

128 return list(dict.fromkeys(expanded)) 

129 

130 

class ReviewedForChoices(str, Enum):
    """Allowed values for ``ScopeModel.reviewed_for``.

    The empty string is the default (unset) value.
    """

    EMPTY = ""  # default / unset
    REQUIRED = "required"
    IGNORED = "ignored"

135 

136 

class OwnershipChoices(str, Enum):
    """Allowed values for ``ScopeModel.ownership`` (how scopes are combined).

    The empty string is the default (unset) value. ``ScopeModel.printed_name``
    shows APPEND scopes with a ``+`` prefix and GLOBAL scopes with a ``*``
    prefix.
    """

    EMPTY = ""  # default / unset
    APPEND = "append"
    GLOBAL = "global"

141 

142 

143class ScopeModel(BaseModel): 

144 model_config = ConfigDict(extra="forbid") 

145 

146 # Required fields 

147 name: str = Field(min_length=1) 

148 paths: list[str] = Field(min_length=1) 

149 

150 # Optional fields 

151 

152 # Expanded version of lines could be dict 

153 # with fnmatch, regex, exclude patterns, etc? 

154 code: list[str] = [] 

155 

156 # This only filtering field that can't be used with raw diff/files... 

157 # If we get into that, the others are: 

158 # - labels 

159 # - ref (have branches at the root level...) 

160 # - statuses 

161 # - dates 

162 # - body 

163 # - title 

164 # - other scopes 

165 # (this is how I ended up with expressions... 

166 # I'm not trying to build a general purpose workflow tool, 

167 # but I do need to support the legit use cases and AI/bot review is one, so is team hierarchy) 

168 authors: list[str] = [] 

169 branches: list[str] = [] 

170 

171 # (defaults should be the "empty" values) 

172 description: str = "" 

173 reviewers: list[str] = [] 

174 alternates: list[str] = [] 

175 cc: list[str] = [] 

176 

177 # Review scoring 

178 require: int = 0 

179 reviewed_for: ReviewedForChoices = ReviewedForChoices.EMPTY 

180 author_value: int = 0 

181 

182 # How scopes are combined 

183 ownership: OwnershipChoices = OwnershipChoices.EMPTY 

184 

185 # Actionable items 

186 request: int = 0 

187 labels: list[str] = [] 

188 instructions: str = "" 

189 

190 # Approval checklist 

191 checklist: Checklist | None = None 

192 

193 @field_validator("name", mode="after") 

194 @classmethod 

195 def validate_name(cls, name: str) -> str: 

196 if "," in name: 

197 raise ValueError("Scope name cannot contain commas") 

198 return name 

199 

200 @field_validator("code", mode="after") 

201 @classmethod 

202 def validate_code_patterns(cls, code: list[str]) -> list[str]: 

203 for pattern in code: 

204 try: 

205 parsed = sre_parse.parse(pattern) 

206 except re.error as e: 

207 raise ValueError(f"Invalid regex pattern '{pattern}': {e}") from None 

208 if _has_nested_quantifiers(parsed): 

209 raise ValueError( 

210 f"Regex pattern '{pattern}' contains nested quantifiers, " 

211 "which can cause catastrophic backtracking." 

212 ) 

213 return code 

214 

215 @model_validator(mode="after") 

216 def validate_reviewers_for_require(self) -> ScopeModel: 

217 all_reviewers = self.reviewers + self.alternates 

218 

219 # Skip if wildcard - anyone can review 

220 if "*" in all_reviewers: 

221 return self 

222 

223 # Skip if aliases not yet expanded (will validate again after compilation) 

224 if any(r.startswith("$") for r in all_reviewers): 

225 return self 

226 

227 if len(all_reviewers) < self.require: 

228 raise ValueError( 

229 f"has require={self.require} but only {len(all_reviewers)} reviewers/alternates specified" 

230 ) 

231 return self 

232 

233 @model_validator(mode="after") 

234 def validate_checklist_reviewed_for(self) -> ScopeModel: 

235 if self.checklist and self.reviewed_for == ReviewedForChoices.REQUIRED: 

236 raise ValueError( 

237 "checklist and reviewed_for='required' cannot be used together. " 

238 "The checklist already requires explicit scope acknowledgment." 

239 ) 

240 return self 

241 

242 def printed_name(self) -> str: 

243 match self.ownership: 

244 case OwnershipChoices.APPEND: 

245 return "+" + self.name 

246 case OwnershipChoices.GLOBAL: 

247 return "*" + self.name 

248 

249 return self.name 

250 

251 def __eq__(self, other: Any) -> bool: 

252 return self.name == other.name 

253 

254 def matches_path(self, path: Path) -> bool: 

255 # TODO paths shouldn't start with / 

256 return glob.globmatch( 

257 path, 

258 self.paths, 

259 flags=glob.GLOBSTAR 

260 | glob.BRACE 

261 | glob.NEGATE 

262 | glob.IGNORECASE 

263 | glob.DOTGLOB, 

264 ) 

265 

266 def matches_code(self, code: str) -> Generator[dict[str, int]]: 

267 patterns = getattr(self, "_code_regex_patterns", []) 

268 if not patterns: 

269 patterns = [re.compile(pattern, re.MULTILINE) for pattern in self.code] 

270 self._code_regex_patterns = patterns 

271 

272 for pattern in patterns: 

273 for match in pattern.finditer(code): 

274 start_index = match.start() 

275 end_index = match.end() 

276 

277 start_line = code.count("\n", 0, start_index) + 1 

278 start_col = start_index - code.rfind("\n", 0, start_index) 

279 

280 end_line = code.count("\n", 0, end_index) + 1 

281 end_col = end_index - code.rfind("\n", 0, end_index) 

282 

283 yield { 

284 "start_line": start_line, 

285 "start_col": start_col, 

286 "end_line": end_line, 

287 "end_col": end_col, 

288 } 

289 

290 def matches_author(self, author_username: str) -> bool: 

291 if not self.authors: 

292 # No authors specified, so assume it matches 

293 return True 

294 

295 author_username_lower = author_username.lower() 

296 

297 negated_authors = [a[1:].lower() for a in self.authors if a.startswith("!")] 

298 authors = [a.lower() for a in self.authors if not a.startswith("!")] 

299 

300 if author_username_lower in negated_authors: 

301 # If the author is in the negated list, return False 

302 return False 

303 

304 if not authors: 

305 # Negation-only: everyone not negated matches 

306 return True 

307 

308 if author_username_lower in authors: 

309 # If the author is in the authors list, return True 

310 return True 

311 

312 return False 

313 

314 def matches_branches(self, base_branch: str, head_branch: str) -> bool: 

315 return _matches_branches(self.branches, base_branch, head_branch) 

316 

317 def enabled_for_pullrequest( 

318 self, author_username: str, base_branch: str, head_branch: str 

319 ) -> bool: 

320 # Paths/code are matched during diff parsing, 

321 # but we also consider authors and branches in the context of a pull request. 

322 return self.matches_author(author_username) and self.matches_branches( 

323 base_branch, head_branch 

324 ) 

325 

326 

class LargeScaleChangeModel(BaseModel):
    """Review settings for a large-scale change (LSC).

    Similar to a scope, but configured more manually. An LSC only applies to
    diffs, not raw files, because we have to know what *changed*.
    """

    model_config = ConfigDict(extra="forbid")

    # There has to be at least one reviewer, so if an LSC config is not
    # defined, an LSC PR errors until one is added.
    require: int = 1
    reviewers: list[str] = []  # consider Field(min_length=1)
    labels: list[str] = []

    # Ideas not implemented yet:
    #   min_paths: int = 300
    #   min_lines: int = 3000
    #   an author value setting (as scopes have)?

341 

342 

class ConfigModel(BaseModel):
    """One CODEREVIEW config file.

    Holds scopes plus settings that can be shared through ``extends``
    (branches, aliases, large_scale_change). ``compiled_config`` returns a
    copy with extended configs merged in and ``$alias`` references expanded.
    """

    model_config = ConfigDict(extra="forbid")

    # Nothing is technically required
    extends: list[str] = []
    template: bool = False
    branches: list[str] = []
    aliases: dict[str, list[str]] = {}
    large_scale_change: LargeScaleChangeModel | None = None
    scopes: list[ScopeModel] = []

    @field_validator("scopes", mode="after")
    @classmethod
    def validate_unique_scope_names(cls, scopes: list[ScopeModel]) -> list[ScopeModel]:
        """Reject scopes whose names collide case-insensitively."""
        seen: set[str] = set()
        for scope in scopes:
            key = scope.name.lower()
            if key in seen:
                raise ValueError(f"Duplicate scope name: {scope.name}")
            seen.add(key)

        return scopes

    @field_validator("extends", mode="after")
    @classmethod
    def validate_extends(cls, extends: list[str]) -> list[str]:
        """Require every extended file name to start with CONFIG_FILENAME_PREFIX.

        This runs after pydantic has validated ``extends`` as ``list[str]``. A
        malformed value (a bare string, a number, ...) therefore produces a
        normal validation error, rather than being iterated character by
        character or raising a raw TypeError.
        """
        for path in extends:
            basename = Path(path).name
            if not basename.startswith(CONFIG_FILENAME_PREFIX):
                raise ValueError(
                    f"Invalid extends path: {path}. It should start with '{CONFIG_FILENAME_PREFIX}'."
                )
        return extends

    def compiled_config(
        self,
        config_path: Path,
        other_configs: ConfigModels,
        _extends_chain: tuple[str, ...] = (),
    ) -> ConfigModel:
        """
        Merge extends, replace aliases.

        Each extended config's branches, aliases, scopes and large_scale_change
        are merged into a copy of this config's data:
        - inherited scopes and branches come first;
        - this config's aliases and large_scale_change take precedence.
        ``$alias`` references are then expanded using the merged aliases. The
        result is cached on the instance.

        ``_extends_chain`` is internal recursion state (the config paths being
        compiled, outermost first) used to detect ``extends`` cycles.

        Raises:
            ValueError: if an extended config is not in ``other_configs``, if
                ``extends`` forms a cycle, or if aliases are circular.
        """
        if getattr(self, "_compiled_config", None) is not None:
            return self._compiled_config

        # Without this, A -> B -> A (or a config extending itself) would
        # recurse until RecursionError.
        chain = (*_extends_chain, str(config_path))

        # Create a copy of the data from what we have currently
        compiled_data = self.model_dump()

        for extend_path in self.extends:
            if extend_path not in other_configs:
                raise ValueError(f"Config {extend_path} not found")

            normalized_extend_path = str(Path(extend_path))
            if normalized_extend_path in chain:
                cycle = " -> ".join((*chain, normalized_extend_path))
                raise ValueError(f"Circular extends detected: {cycle}")

            extended_config = other_configs[extend_path]
            # Compiling the extended config recurses through its own extends
            # and caches its result. Its *raw* fields are what get merged, so
            # aliases in inherited scopes are expanded below with this
            # config's merged aliases.
            # NOTE(review): because only the raw fields are merged, the
            # extended config's own `extends` are not carried over
            # transitively — confirm this is intended.
            extended_config.compiled_config(Path(extend_path), other_configs, chain)

            extended_config_dumped = extended_config.model_dump(
                include={"branches", "aliases", "scopes", "large_scale_change"}
            )

            compiled_data["scopes"] = (
                extended_config_dumped["scopes"] + compiled_data["scopes"]
            )
            compiled_data["large_scale_change"] = (
                compiled_data["large_scale_change"]
                or extended_config_dumped["large_scale_change"]
            )
            compiled_data["aliases"] = (
                extended_config_dumped["aliases"] | compiled_data["aliases"]
            )
            compiled_data["branches"] = (
                extended_config_dumped["branches"] + compiled_data["branches"]
            )

        # Root aliases
        for field in ["extends", "branches"]:
            if field in compiled_data:
                compiled_data[field] = _expand_aliases(
                    compiled_data[field], compiled_data["aliases"]
                )

        # Expand aliases for any aliasable list fields
        for scope in compiled_data["scopes"]:
            for field in [
                "paths",
                "code",
                "authors",
                "branches",
                "reviewers",
                "alternates",
                "cc",
                "labels",
            ]:
                if field in scope:
                    scope[field] = _expand_aliases(
                        scope[field], compiled_data["aliases"]
                    )

        if large_scale_change := compiled_data.get("large_scale_change"):
            for field in ["reviewers", "labels"]:
                large_scale_change[field] = _expand_aliases(
                    large_scale_change[field],
                    compiled_data["aliases"],
                )

        # Create a new config from the merged data (re-runs validation)
        self._compiled_config = ConfigModel.from_data(
            data=compiled_data,
            path=config_path,
        )

        return self._compiled_config

    @classmethod
    def from_filesystem(cls, path: Path | str) -> ConfigModel:
        """Load and validate a TOML config file from disk."""
        with open(path, "rb") as f:
            return cls.from_data(tomllib.load(f), path)

    @classmethod
    def from_content(cls, content: str, path: Path | str) -> ConfigModel:
        """Parse and validate TOML config text."""
        return cls.from_data(tomllib.loads(content), path)

    @classmethod
    def from_data(cls, data: dict[str, Any], path: Path | str) -> ConfigModel:
        """Validate already-parsed config data.

        ``path`` is accepted for consistency with the other constructors but
        is currently unused.
        """
        return cls(**data)

    def matches_branches(self, base_branch: str, head_branch: str) -> bool:
        """Return True if the base/head pair matches this config's ``branches``."""
        return _matches_branches(self.branches, base_branch, head_branch)

    def enabled_for_pullrequest(self, base_branch: str, head_branch: str) -> bool:
        """Return True if this config applies to a PR between these branches."""
        return self.matches_branches(base_branch, head_branch)

476 

477 

class ConfigModels(RootModel):
    """A collection of ConfigModel objects keyed by config path (as a string)."""

    root: dict[str, ConfigModel]

    @classmethod
    def from_configs_data(cls, data: dict[str, Any]) -> ConfigModels:
        """Load configs from a dict of ``{path: raw config data}``."""
        configs = cls(root={})

        for path, config_data in data.items():
            config = ConfigModel.from_data(config_data, Path(path))
            configs.add_config(config, Path(path))

        return configs

    @classmethod
    def from_config_models(cls, models: dict[str, ConfigModel]) -> ConfigModels:
        """Load configs from a dict of ``{path: ConfigModel}``."""
        configs = cls(root={})

        for path, config_model in models.items():
            configs.add_config(config_model, Path(path))

        return configs

    def get_config_models(self) -> dict[str, ConfigModel]:
        """Return a shallow copy of the path -> config mapping."""
        return dict(self.root.items())

    def add_config(self, config: ConfigModel, path: Path) -> None:
        """Register ``config`` under ``str(path)``, replacing any existing entry."""
        self.root[str(path)] = config

    def get_default_large_scale_change(self) -> LargeScaleChangeModel:
        """Return the large-scale-change settings of the root config.

        Uses the ``large_scale_change`` section of the top-level
        ``CONFIG_FILENAME`` config when that config exists and defines one.
        Otherwise returns a default ``LargeScaleChangeModel()``. The root
        config is read as written; its ``extends`` are not merged here.
        """
        root_config = self.root.get(CONFIG_FILENAME)
        if root_config is not None and root_config.large_scale_change:
            return root_config.large_scale_change

        return LargeScaleChangeModel()

    def __bool__(self) -> bool:
        return bool(self.root)

    def __getitem__(self, key: str) -> ConfigModel:
        return self.root[key]

    def __contains__(self, key: str) -> bool:
        return key in self.root

    def __len__(self) -> int:
        return len(self.root)

    def compile_closest_config(self, file_path: Path) -> ConfigModel:
        """Find the closest config file to this file and compile it.

        Walks ``file_path``'s parent directories from nearest to farthest and
        compiles the first non-template ``CONFIG_FILENAME`` found.

        Raises:
            ValueError: if no non-template config exists in any parent.
        """
        for parent in file_path.parents:
            parent_config_path = str(parent / CONFIG_FILENAME)

            if parent_config_path in self.root:
                config = self.root[parent_config_path]

                if config.template:
                    # Skip templates
                    continue

                return config.compiled_config(Path(parent_config_path), self)

        raise ValueError(f"No config found for {file_path}")

    def iter_compiled_configs(self) -> Generator[tuple[str, ConfigModel]]:
        """Yield ``(path, compiled config)`` for every non-template config."""
        for config_path, config in self.root.items():
            if config.template:
                # Skip templates
                continue

            yield config_path, config.compiled_config(Path(config_path), self)

    def num_scopes(self) -> int:
        """
        Count the total number of scopes across all configs.

        Counts each config's own scopes (not compiled, so inherited scopes are
        not included).
        """
        return sum(len(config.scopes) for config in self.root.values())

    def num_reviewers(self) -> int:
        """
        Count the total number of reviewers across all configs.

        Counts entries in each config's own scopes as written, so an alias
        reference counts as one entry.
        """
        return sum(
            len(scope.reviewers)
            for config in self.root.values()
            for scope in config.scopes
        )

    def filter_for_pullrequest(
        self,
        base_branch: str,
        head_branch: str,
        author_username: str,
    ) -> ConfigModels:
        """
        Look at all configs (including templates) and filter out
        configs and scopes based on branches, authors, etc.

        Returns a new collection built from each kept config's raw data, with
        the disabled scopes removed.
        """
        filtered_configs = {}

        for config_path, config in self.root.items():
            compiled_config = config.compiled_config(Path(config_path), self)

            if not compiled_config.enabled_for_pullrequest(base_branch, head_branch):
                # Remove the config from the list.
                # NOTE(review): this also drops configs (including templates)
                # that a kept config `extends`. Compiling that kept config from
                # the returned collection would then raise "Config ... not
                # found". Confirm whether extended configs should be retained.
                continue

            # Collect the names of scopes to remove (from the compiled config,
            # which has aliases expanded for the author/branch checks)
            scopes_to_remove = {
                scope.name
                for scope in compiled_config.scopes
                if not scope.enabled_for_pullrequest(
                    author_username, base_branch, head_branch
                )
            }

            # Create a filtered copy of this config's own (uncompiled) data
            filtered_config_data = config.model_dump()
            filtered_config_data["scopes"] = [
                scope
                for scope in filtered_config_data["scopes"]
                if scope["name"] not in scopes_to_remove
            ]

            # Rebuild using the original/modified raw data
            filtered_configs[config_path] = filtered_config_data

        return ConfigModels.from_configs_data(filtered_configs)