Coverage for src/pullapprove/config.py: 79%
314 statements
« prev ^ index » next coverage.py v7.8.2, created at 2026-03-16 10:17 -0500
« prev ^ index » next coverage.py v7.8.2, created at 2026-03-16 10:17 -0500
1from __future__ import annotations
3import re
4import tomllib
5import warnings
7with warnings.catch_warnings():
8 warnings.simplefilter("ignore", DeprecationWarning)
9 import sre_parse
10from collections.abc import Generator
11from enum import Enum
12from pathlib import Path
13from typing import Any
15from pydantic import (
16 BaseModel,
17 ConfigDict,
18 Field,
19 RootModel,
20 field_validator,
21 model_validator,
22)
23from wcmatch import glob
25from .checklists import Checklist
# Basename prefix that every config file (including extended ones) must carry.
CONFIG_FILENAME_PREFIX = "CODEREVIEW"

# sre_parse opcodes that represent a repetition (greedy and lazy variants).
_REPEAT_OPS = {
    sre_parse.MAX_REPEAT,
    sre_parse.MIN_REPEAT,
}
def _has_nested_quantifiers(data: Any) -> bool:
    """Report whether a parsed regex repeats something that itself repeats.

    ``data`` is an iterable of ``(opcode, argument)`` pairs as produced by
    ``sre_parse.parse``. Patterns shaped like ``(a+)+`` are what this looks
    for, since they can trigger catastrophic backtracking.

    Only repeat, sub-pattern (group) and branch (alternation) nodes are
    descended into; every other opcode is ignored.
    """
    for opcode, argument in data:
        if opcode in _REPEAT_OPS:
            # The third element of a repeat argument is the repeated body.
            nested = _contains_quantifier(argument[2])
        elif opcode == sre_parse.SUBPATTERN:
            # The group's content is the last element of its argument.
            nested = _has_nested_quantifiers(argument[-1])
        elif opcode == sre_parse.BRANCH:
            # Check each alternative of the alternation.
            nested = any(map(_has_nested_quantifiers, argument[1]))
        else:
            nested = False

        if nested:
            return True

    return False
def _contains_quantifier(data: Any) -> bool:
    """Report whether a parsed regex fragment contains any repetition.

    ``data`` is an iterable of ``(opcode, argument)`` pairs. The search looks
    at the fragment's own nodes and descends through sub-patterns (groups)
    and branches (alternations); other opcodes are ignored.
    """
    for opcode, argument in data:
        if opcode in _REPEAT_OPS:
            return True

        if opcode == sre_parse.SUBPATTERN:
            found = _contains_quantifier(argument[-1])
        elif opcode == sre_parse.BRANCH:
            found = any(map(_contains_quantifier, argument[1]))
        else:
            found = False

        if found:
            return True

    return False
# File name looked up in each directory when locating a config.
CONFIG_FILENAME = "CODEREVIEW.toml"
def _expand_aliases(
    values: list[str],
    aliases: dict[str, list[str]],
    _seen: set[str] | None = None,
    _path: list[str] | None = None,
) -> list[str]:
    """Return ``values`` with every alias reference replaced by its entries.

    A reference is written ``$name``; ``!$name`` is the negated form and
    yields each expanded entry prefixed with ``!`` (``"!$team"`` becomes
    ``["!alice", "!bob"]``). Aliases may reference other aliases and are
    expanded recursively. A reference to a name that is not in ``aliases``
    contributes nothing to the result. Duplicates are removed, keeping the
    first occurrence.

    ``_seen`` and ``_path`` are internal bookkeeping for the recursion (the
    aliases currently being expanded, as a set and in order).

    Raises:
        ValueError: if an alias refers back to itself, directly or through
            other aliases. The message spells out the cycle.
    """
    active: set[str] = set() if _seen is None else _seen
    trail: list[str] = [] if _path is None else _path

    result: list[str] = []
    for item in values:
        if item.startswith("!$"):
            negation, alias_name = "!", item[2:]
        elif item.startswith("$"):
            negation, alias_name = "", item[1:]
        else:
            # Plain value, nothing to expand.
            result.append(item)
            continue

        if alias_name in active:
            # We are already in the middle of expanding this alias.
            loop = [*trail[trail.index(alias_name) :], alias_name]
            raise ValueError(
                f"Circular reference detected in aliases: {' -> '.join(loop)}"
            )

        if alias_name not in aliases:
            # Unknown aliases are dropped rather than kept verbatim.
            continue

        active.add(alias_name)
        trail.append(alias_name)
        resolved = _expand_aliases(aliases[alias_name], aliases, active, trail)
        trail.pop()
        active.remove(alias_name)

        # An empty negation prefix leaves the entries unchanged.
        result.extend(negation + entry for entry in resolved)

    # dict keys keep insertion order, which gives an order-preserving dedupe.
    unique = dict.fromkeys(result)
    return list(unique)
class ReviewedForChoices(str, Enum):
    """Allowed values for a scope's ``reviewed_for`` setting.

    Members are also ``str`` instances, so they compare equal to their raw
    string values.
    """

    # The "not set" value.
    EMPTY = ""
    REQUIRED = "required"
    IGNORED = "ignored"
class OwnershipChoices(str, Enum):
    """Allowed values for a scope's ``ownership`` setting.

    Members are also ``str`` instances, so they compare equal to their raw
    string values.
    """

    # The "not set" value.
    EMPTY = ""
    APPEND = "append"
    GLOBAL = "global"
122class ScopeModel(BaseModel):
123 model_config = ConfigDict(extra="forbid")
125 # Required fields
126 name: str = Field(min_length=1)
127 paths: list[str] = Field(min_length=1)
129 # Optional fields
131 # Expanded version of lines could be dict
132 # with fnmatch, regex, exclude patterns, etc?
133 code: list[str] = []
135 # This only filtering field that can't be used with raw diff/files...
136 # If we get into that, the others are:
137 # - labels
138 # - ref (have branches at the root level...)
139 # - statuses
140 # - dates
141 # - body
142 # - title
143 # - other scopes
144 # (this is how I ended up with expressions...
145 # I'm not trying to build a general purpose workflow tool,
146 # but I do need to support the legit use cases and AI/bot review is one, so is team hierarchy)
147 authors: list[str] = []
149 # (defaults should be the "empty" values)
150 description: str = ""
151 reviewers: list[str] = []
152 alternates: list[str] = []
153 cc: list[str] = []
155 # Review scoring
156 require: int = 0
157 reviewed_for: ReviewedForChoices = ReviewedForChoices.EMPTY
158 author_value: int = 0
160 # How scopes are combined
161 ownership: OwnershipChoices = OwnershipChoices.EMPTY
163 # Actionable items
164 request: int = 0
165 labels: list[str] = []
166 instructions: str = ""
168 # Approval checklist
169 checklist: Checklist | None = None
171 @field_validator("name", mode="after")
172 @classmethod
173 def validate_name(cls, name: str) -> str:
174 if "," in name:
175 raise ValueError("Scope name cannot contain commas")
176 return name
178 @field_validator("code", mode="after")
179 @classmethod
180 def validate_code_patterns(cls, code: list[str]) -> list[str]:
181 for pattern in code:
182 try:
183 parsed = sre_parse.parse(pattern)
184 except re.error as e:
185 raise ValueError(f"Invalid regex pattern '{pattern}': {e}") from None
186 if _has_nested_quantifiers(parsed):
187 raise ValueError(
188 f"Regex pattern '{pattern}' contains nested quantifiers, "
189 "which can cause catastrophic backtracking."
190 )
191 return code
193 @model_validator(mode="after")
194 def validate_reviewers_for_require(self) -> ScopeModel:
195 all_reviewers = self.reviewers + self.alternates
197 # Skip if wildcard - anyone can review
198 if "*" in all_reviewers:
199 return self
201 # Skip if aliases not yet expanded (will validate again after compilation)
202 if any(r.startswith("$") for r in all_reviewers):
203 return self
205 if len(all_reviewers) < self.require:
206 raise ValueError(
207 f"has require={self.require} but only {len(all_reviewers)} reviewers/alternates specified"
208 )
209 return self
211 @model_validator(mode="after")
212 def validate_checklist_reviewed_for(self) -> ScopeModel:
213 if self.checklist and self.reviewed_for == ReviewedForChoices.REQUIRED:
214 raise ValueError(
215 "checklist and reviewed_for='required' cannot be used together. "
216 "The checklist already requires explicit scope acknowledgment."
217 )
218 return self
220 def printed_name(self) -> str:
221 match self.ownership:
222 case OwnershipChoices.APPEND:
223 return "+" + self.name
224 case OwnershipChoices.GLOBAL:
225 return "*" + self.name
227 return self.name
229 def __eq__(self, other: Any) -> bool:
230 return self.name == other.name
232 def matches_path(self, path: Path) -> bool:
233 # TODO paths shouldn't start with /
234 return glob.globmatch(
235 path,
236 self.paths,
237 flags=glob.GLOBSTAR
238 | glob.BRACE
239 | glob.NEGATE
240 | glob.IGNORECASE
241 | glob.DOTGLOB,
242 )
244 def matches_code(self, code: str) -> Generator[dict[str, int]]:
245 patterns = getattr(self, "_code_regex_patterns", [])
246 if not patterns:
247 patterns = [re.compile(pattern, re.MULTILINE) for pattern in self.code]
248 self._code_regex_patterns = patterns
250 for pattern in patterns:
251 for match in pattern.finditer(code):
252 start_index = match.start()
253 end_index = match.end()
255 start_line = code.count("\n", 0, start_index) + 1
256 start_col = start_index - code.rfind("\n", 0, start_index)
258 end_line = code.count("\n", 0, end_index) + 1
259 end_col = end_index - code.rfind("\n", 0, end_index)
261 yield {
262 "start_line": start_line,
263 "start_col": start_col,
264 "end_line": end_line,
265 "end_col": end_col,
266 }
268 def matches_author(self, author_username: str) -> bool:
269 if not self.authors:
270 # No authors specified, so assume it matches
271 return True
273 author_username_lower = author_username.lower()
275 negated_authors = [a[1:].lower() for a in self.authors if a.startswith("!")]
276 authors = [a.lower() for a in self.authors if not a.startswith("!")]
278 if author_username_lower in negated_authors:
279 # If the author is in the negated list, return False
280 return False
282 if not authors:
283 # Negation-only: everyone not negated matches
284 return True
286 if author_username_lower in authors:
287 # If the author is in the authors list, return True
288 return True
290 return False
292 def enabled_for_pullrequest(self, author_username: str) -> bool:
293 # Paths/code are matched during diff parsing,
294 # but we also consider authors in the context of a pull request so do that here.
295 return self.matches_author(author_username)
class LargeScaleChangeModel(BaseModel):
    """Review settings for a large scale change (LSC).

    Unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Note: an LSC only applies to diffs, not raw files,
    # because we have to know what *changed*.

    # Pretty similar to a scope, but more manual.
    # There has to be at least one reviewer, so if an LSC config is not
    # defined, an LSC PR will error until you add one.
    require: int = 1
    reviewers: list[str] = []  # Field(min_length=1)
    # min_paths: int = 300
    # min_lines: int = 3000
    labels: list[str] = []
    # really need author value too...?
class ConfigModel(BaseModel):
    """Contents of a single config file.

    Unknown keys are rejected (``extra="forbid"``). ``compiled_config``
    produces the effective config with ``extends`` merged in and alias
    references expanded.
    """

    model_config = ConfigDict(extra="forbid")

    # Nothing is technically required
    extends: list[str] = []
    template: bool = False
    branches: list[str] = []
    aliases: dict[str, list[str]] = {}
    large_scale_change: LargeScaleChangeModel | None = None
    scopes: list[ScopeModel] = []

    @field_validator("scopes", mode="after")
    @classmethod
    def validate_unique_scope_names(cls, scopes: list[ScopeModel]) -> list[ScopeModel]:
        """Reject scopes whose names collide, ignoring case."""
        seen: set[str] = set()
        for scope in scopes:
            if scope.name.lower() in seen:
                raise ValueError(f"Duplicate scope name: {scope.name}")
            seen.add(scope.name.lower())

        return scopes

    @field_validator("extends", mode="before")
    @classmethod
    def validate_extends(cls, extends: list[str]) -> list[str]:
        """Require every extended file's basename to start with the config prefix."""
        for path in extends:
            basename = Path(path).name
            if not basename.startswith(CONFIG_FILENAME_PREFIX):
                raise ValueError(
                    f"Invalid extends path: {path}. It should start with '{CONFIG_FILENAME_PREFIX}'."
                )
        return extends

    def compiled_config(
        self, config_path: Path, other_configs: ConfigModels
    ) -> ConfigModel:
        """Merge extends, replace aliases.

        For each entry in ``extends`` the extended config's raw data is
        merged into this one: its scopes and branches are placed before the
        ones accumulated so far, its aliases are overridden by ours, and its
        ``large_scale_change`` is used only when we have none. Alias
        references are then expanded in the root and scope list fields and a
        new ``ConfigModel`` is built from the result. The result is cached
        on the instance and returned on later calls.

        Args:
            config_path: path of this config; passed through to ``from_data``
                and used in error messages.
            other_configs: all known configs, keyed by path, used to resolve
                ``extends``.

        Raises:
            ValueError: if an extended config is not in ``other_configs``, if
                the ``extends`` chain is circular, or if the merged data fails
                validation.
        """

        cached = getattr(self, "_compiled_config", None)
        if cached is not None:
            return cached

        # Re-entering while still compiling means the extends chain loops back
        # to this config; without this guard the recursion would never end.
        if getattr(self, "_compiling", False):
            raise ValueError(
                f"Circular extends detected while compiling {config_path}"
            )

        self._compiling = True
        try:
            # Create a copy of the data from what we have currently
            compiled_data = self.model_dump()

            for extend_path in self.extends:
                if extend_path not in other_configs:
                    raise ValueError(f"Config {extend_path} not found")

                extended_config = other_configs[extend_path]
                # The compiled result is not used here (the raw data below is
                # what gets merged); compiling surfaces errors in the extended
                # config and is where a circular chain gets detected.
                extended_config.compiled_config(Path(extend_path), other_configs)

                extended_config_dumped = extended_config.model_dump(
                    include={"branches", "aliases", "scopes", "large_scale_change"}
                )

                compiled_data["scopes"] = (
                    extended_config_dumped["scopes"] + compiled_data["scopes"]
                )
                compiled_data["large_scale_change"] = (
                    compiled_data["large_scale_change"]
                    or extended_config_dumped["large_scale_change"]
                )
                # Right-hand side wins: our aliases override the extended ones.
                compiled_data["aliases"] = (
                    extended_config_dumped["aliases"] | compiled_data["aliases"]
                )
                compiled_data["branches"] = (
                    extended_config_dumped["branches"] + compiled_data["branches"]
                )

            # Root aliases
            for field in ["extends", "branches"]:
                if field in compiled_data:
                    compiled_data[field] = _expand_aliases(
                        compiled_data[field], compiled_data["aliases"]
                    )

            # Expand aliases for any aliasable list fields
            for scope in compiled_data["scopes"]:
                for field in [
                    "paths",
                    "code",
                    "authors",
                    "reviewers",
                    "alternates",
                    "cc",
                    "labels",
                ]:
                    if field in scope:
                        scope[field] = _expand_aliases(
                            scope[field], compiled_data["aliases"]
                        )

            if large_scale_change := compiled_data.get("large_scale_change"):
                for field in ["reviewers", "labels"]:
                    large_scale_change[field] = _expand_aliases(
                        large_scale_change[field],
                        compiled_data["aliases"],
                    )

            # Create a new config from the merged data
            self._compiled_config = ConfigModel.from_data(
                data=compiled_data,
                path=config_path,
            )
        finally:
            # Always clear the marker so a failed compile can be retried.
            self._compiling = False

        return self._compiled_config

    @classmethod
    def from_filesystem(cls, path: Path | str) -> ConfigModel:
        """Load and validate the TOML config file at ``path``."""
        with open(path, "rb") as f:
            return cls.from_data(tomllib.load(f), path)

    @classmethod
    def from_content(cls, content: str, path: Path | str) -> ConfigModel:
        """Parse ``content`` as TOML and validate it as a config."""
        return cls.from_data(tomllib.loads(content), path)

    @classmethod
    def from_data(cls, data: dict[str, Any], path: Path | str) -> ConfigModel:
        """Validate already-parsed config ``data``.

        ``path`` is accepted for symmetry with the other constructors but is
        not used here.
        """
        return cls(**data)

    def matches_branches(self, base_branch: str, head_branch: str) -> bool:
        """Return whether the branch pair matches any ``branches`` pattern.

        A pattern is ``base`` or ``base..head`` / ``base...head``, each side
        being a glob; an empty or missing side matches anything. With no
        patterns configured, every branch pair matches.
        """
        if not self.branches:
            # No branches specified, so assume it matches
            return True

        for pattern in self.branches:
            splitter = "..." if "..." in pattern else ".."
            parts = pattern.split(splitter)
            base_pattern = parts[0]
            head_pattern = parts[1] if len(parts) > 1 else None

            base_match = (
                glob.globmatch(base_branch, base_pattern) if base_pattern else True
            )
            head_match = (
                glob.globmatch(head_branch, head_pattern) if head_pattern else True
            )

            if base_match and head_match:
                return True

        return False

    def enabled_for_pullrequest(self, base_branch: str, head_branch: str) -> bool:
        """Return whether this config applies to a pull request on these branches."""
        return self.matches_branches(base_branch, head_branch)
467 # Kinda want the original toml if you dump? with comments etc
468 # def as_toml(self) -> str:
469 # """
470 # Convert the config to a TOML string.
471 # """
472 # return tomllib.dumps(self.model_dump())
474 # def matches_branch(self, base_branch: str, head_branch: str) -> bool:
475 # for pattern in self.branches:
476 # splitter = "..." if "..." in pattern else ".."
477 # parts = pattern.split(splitter)
478 # base_pattern = parts[0]
479 # head_pattern = parts[1] if len(parts) > 1 else None
481 # base_match = fnmatch.fnmatch(base_branch, base_pattern) if base_pattern else True
482 # head_match = fnmatch.fnmatch(head_branch, head_pattern) if head_pattern else True
484 # if base_match and head_match:
485 # return True
487 # return False
class ConfigModels(RootModel):
    """A collection of configs keyed by their path (as a string)."""

    root: dict[str, ConfigModel]

    @classmethod
    def from_configs_data(cls, data: dict[str, Any]) -> ConfigModels:
        """Build a collection from raw config data, keyed by path."""
        collection = cls(root={})

        for raw_path, raw_config in data.items():
            location = Path(raw_path)
            collection.add_config(ConfigModel.from_data(raw_config, location), location)

        return collection

    @classmethod
    def from_config_models(cls, models: dict[str, ConfigModel]) -> ConfigModels:
        """Build a collection from already-validated models, keyed by path."""
        collection = cls(root={})

        for raw_path, model in models.items():
            collection.add_config(model, Path(raw_path))

        return collection

    def get_config_models(self) -> dict[str, ConfigModel]:
        """Return a shallow copy of the path -> config mapping."""
        return {**self.root}

    def add_config(self, config: ConfigModel, path: Path) -> None:
        """Store ``config`` under the string form of ``path``."""
        self.root[str(path)] = config

    def get_default_large_scale_change(self) -> LargeScaleChangeModel:
        """Return the large scale change settings of the top-level config.

        The top-level config is the one stored under ``CONFIG_FILENAME``.
        When it is absent or defines no large scale change, a default
        ``LargeScaleChangeModel`` is returned.
        """
        root_config = self.root.get(CONFIG_FILENAME)

        if root_config is not None and root_config.large_scale_change:
            return root_config.large_scale_change

        return LargeScaleChangeModel()

    def __bool__(self) -> bool:
        return bool(self.root)

    def __getitem__(self, key: str) -> ConfigModel:
        return self.root[key]

    def __contains__(self, key: str) -> bool:
        return key in self.root

    def __len__(self) -> int:
        return len(self.root)

    def compile_closest_config(self, file_path: Path) -> ConfigModel:
        """Compile the nearest non-template config above ``file_path``.

        Walks ``file_path.parents`` from the closest directory outwards.

        Raises:
            ValueError: if no directory on the way up holds a usable config.
        """
        for directory in file_path.parents:
            candidate = str(directory / CONFIG_FILENAME)
            config = self.root.get(candidate)

            # Nothing here, or only a template: keep walking up.
            if config is None or config.template:
                continue

            return config.compiled_config(Path(candidate), self)

        raise ValueError(f"No config found for {file_path}")

    def iter_compiled_configs(self) -> Generator[tuple[str, ConfigModel]]:
        """Yield ``(path, compiled config)`` for every non-template config."""
        for location, config in self.root.items():
            if not config.template:
                yield location, config.compiled_config(Path(location), self)

    def num_scopes(self) -> int:
        """
        Count the total number of scopes across all configs.
        """
        total = 0
        for config in self.root.values():
            total += len(config.scopes)
        return total

    def num_reviewers(self) -> int:
        """
        Count the total number of reviewers across all configs.
        """
        total = 0
        for config in self.root.values():
            for scope in config.scopes:
                total += len(scope.reviewers)
        return total

    def filter_for_pullrequest(
        self,
        base_branch: str,
        head_branch: str,
        author_username: str,
    ) -> ConfigModels:
        """
        Look at all configs (including templates) and filter out
        configs and scopes based on branches, authors, etc.

        Decisions are made on each compiled config, but the returned
        collection is rebuilt from the original (uncompiled) data with the
        disabled scopes removed by name.
        """
        kept: dict[str, Any] = {}

        for location, config in self.root.items():
            compiled = config.compiled_config(Path(location), self)

            # Drop the whole config when its branches don't match.
            if not compiled.enabled_for_pullrequest(base_branch, head_branch):
                continue

            # Names of scopes that don't apply to this author.
            disabled = {
                scope.name
                for scope in compiled.scopes
                if not scope.enabled_for_pullrequest(author_username)
            }

            raw = config.model_dump()
            if "scopes" in raw:
                raw["scopes"] = [
                    entry for entry in raw["scopes"] if entry["name"] not in disabled
                ]

            # Rebuild using the original/modified raw data
            kept[location] = raw

        return ConfigModels.from_configs_data(kept)
634 # for config in configs.iter_compiled_configs():
635 # if config.enabled_for_pullrequest(self):
636 # filtered_configs.add_config(config)
638 # # Paths/code are similar, but we need to iterate the diff to process them,
639 # # so this is almost a pre-process for metadata of the PR
640 # scopes_to_remove = []
641 # for scope in config.scopes:
642 # if not scope.enabled_for_pullrequest(self):
643 # scopes_to_remove.append(scope)
644 # for scope in scopes_to_remove:
645 # config.scopes.remove(scope)