Coverage for src/pullapprove/config.py: 88%
318 statements
« prev ^ index » next coverage.py v7.8.2, created at 2026-03-16 16:39 -0500
« prev ^ index » next coverage.py v7.8.2, created at 2026-03-16 16:39 -0500
1from __future__ import annotations
3import re
4import tomllib
5import warnings
7with warnings.catch_warnings():
8 warnings.simplefilter("ignore", DeprecationWarning)
9 import sre_parse
10from collections.abc import Generator
11from enum import Enum
12from pathlib import Path
13from typing import Any
15from pydantic import (
16 BaseModel,
17 ConfigDict,
18 Field,
19 RootModel,
20 field_validator,
21 model_validator,
22)
23from wcmatch import glob
25from .checklists import Checklist
# Prefix that the basename of every path listed under ``extends`` must start
# with (checked in ConfigModel.validate_extends).
CONFIG_FILENAME_PREFIX = "CODEREVIEW"
# Opcodes sre_parse emits for greedy and lazy repetition (``*``, ``+``, ``?``,
# ``{m,n}`` and their ``?``-suffixed lazy forms). Their argument is
# ``(min, max, subpattern)``.
_REPEAT_OPS = {sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT}

# Lookahead / lookbehind assertions. Their argument is ``(direction, subpattern)``.
_ASSERT_OPS = {sre_parse.ASSERT, sre_parse.ASSERT_NOT}


def _conditional_branches(av: Any) -> list[Any]:
    """Return the branches that are present in a ``(?(group)yes|no)`` node."""
    # av is (group, yes_branch, no_branch); no_branch is None when omitted.
    return [branch for branch in av[1:] if branch is not None]


def _has_nested_quantifiers(data: Any) -> bool:
    """Detect patterns like (a+)+ that cause catastrophic backtracking.

    ``data`` is a parsed pattern as returned by ``sre_parse.parse`` (an
    iterable of ``(op, av)`` pairs). Returns True as soon as a repetition is
    found whose body itself contains a repetition. Besides groups and
    alternations, the search also descends into lookaround assertions and
    conditional groups, so ``(?=(a+)+$)`` is reported as well.
    """
    for op, av in data:
        if op in _REPEAT_OPS:
            # av[2] is the repeated body. It is also searched on its own
            # because _contains_quantifier does not look inside assertions,
            # which may still hide a nested quantifier of their own.
            if _contains_quantifier(av[2]) or _has_nested_quantifiers(av[2]):
                return True
        elif op == sre_parse.SUBPATTERN:
            # The group's content is the last element of av.
            if _has_nested_quantifiers(av[-1]):
                return True
        elif op == sre_parse.BRANCH:
            # av[1] is the list of alternatives.
            if any(_has_nested_quantifiers(branch) for branch in av[1]):
                return True
        elif op in _ASSERT_OPS:
            if _has_nested_quantifiers(av[1]):
                return True
        elif op == sre_parse.GROUPREF_EXISTS:
            if any(
                _has_nested_quantifiers(branch)
                for branch in _conditional_branches(av)
            ):
                return True
    return False


def _contains_quantifier(data: Any) -> bool:
    """Return True if the parsed pattern ``data`` contains any repetition.

    Groups, alternations and conditional groups are searched recursively.
    Assertion bodies are intentionally not searched here (as before); they
    are checked for nested quantifiers by ``_has_nested_quantifiers``.
    """
    for op, av in data:
        if op in _REPEAT_OPS:
            return True
        elif op == sre_parse.SUBPATTERN:
            if _contains_quantifier(av[-1]):
                return True
        elif op == sre_parse.BRANCH:
            if any(_contains_quantifier(branch) for branch in av[1]):
                return True
        elif op == sre_parse.GROUPREF_EXISTS:
            if any(
                _contains_quantifier(branch)
                for branch in _conditional_branches(av)
            ):
                return True
    return False
# File name looked up in each parent directory when searching for the closest
# config, and the key of the root config used for the default large-scale-change
# settings (see ConfigModels).
CONFIG_FILENAME = "CODEREVIEW.toml"
def _matches_branches(branches: list[str], base_branch: str, head_branch: str) -> bool:
    """Tell whether a base/head branch pair satisfies any branch pattern.

    Each pattern is ``base``, ``base..head`` or ``base...head``; both sides
    are glob patterns and an empty side matches any branch. An empty
    ``branches`` list matches everything.
    """
    if not branches:
        return True

    for pattern in branches:
        # Prefer the three-dot separator when it is present in the pattern.
        separator = "..." if "..." in pattern else ".."
        base_pattern, _, head_pattern = pattern.partition(separator)

        base_ok = not base_pattern or glob.globmatch(base_branch, base_pattern)
        head_ok = not head_pattern or glob.globmatch(head_branch, head_pattern)

        if base_ok and head_ok:
            return True

    return False
def _expand_aliases(
    values: list[str],
    aliases: dict[str, list[str]],
    _seen: set[str] | None = None,
    _path: list[str] | None = None,
) -> list[str]:
    """Recursively substitute ``$alias`` references with the alias contents.

    ``"$name"`` is replaced by the expanded values of ``aliases["name"]``;
    ``"!$name"`` does the same but prefixes every resulting value with ``"!"``.
    Plain values are kept. The result is de-duplicated, keeping the first
    occurrence of each value.

    ``_seen`` and ``_path`` track the aliases currently being expanded and are
    only meant to be passed by the recursive calls.

    Raises:
        ValueError: if an alias refers back to itself, directly or indirectly.
    """
    active = set() if _seen is None else _seen
    trail = [] if _path is None else _path

    # A dict doubles as an insertion-ordered set for de-duplication.
    collected: dict[str, None] = {}

    for item in values:
        negated = item.startswith("!$")
        if not negated and not item.startswith("$"):
            collected[item] = None
            continue

        # Strip "!$" or "$" to get the alias name.
        alias_name = item[2:] if negated else item[1:]

        if alias_name in active:
            loop = [*trail[trail.index(alias_name) :], alias_name]
            raise ValueError(
                f"Circular reference detected in aliases: {' -> '.join(loop)}"
            )

        if alias_name not in aliases:
            # NOTE(review): a reference to an undefined alias is dropped
            # silently -- confirm this is intended.
            continue

        active.add(alias_name)
        trail.append(alias_name)
        for member in _expand_aliases(aliases[alias_name], aliases, active, trail):
            collected["!" + member if negated else member] = None
        trail.pop()
        active.remove(alias_name)

    return list(collected)
class ReviewedForChoices(str, Enum):
    """Allowed values for a scope's ``reviewed_for`` setting."""

    # Default: the setting is not configured.
    EMPTY = ""
    REQUIRED = "required"
    IGNORED = "ignored"
class OwnershipChoices(str, Enum):
    """Allowed values for a scope's ``ownership`` setting (how scopes combine)."""

    # Default: the setting is not configured.
    EMPTY = ""
    APPEND = "append"
    GLOBAL = "global"
143class ScopeModel(BaseModel):
144 model_config = ConfigDict(extra="forbid")
146 # Required fields
147 name: str = Field(min_length=1)
148 paths: list[str] = Field(min_length=1)
150 # Optional fields
152 # Expanded version of lines could be dict
153 # with fnmatch, regex, exclude patterns, etc?
154 code: list[str] = []
156 # This only filtering field that can't be used with raw diff/files...
157 # If we get into that, the others are:
158 # - labels
159 # - ref (have branches at the root level...)
160 # - statuses
161 # - dates
162 # - body
163 # - title
164 # - other scopes
165 # (this is how I ended up with expressions...
166 # I'm not trying to build a general purpose workflow tool,
167 # but I do need to support the legit use cases and AI/bot review is one, so is team hierarchy)
168 authors: list[str] = []
169 branches: list[str] = []
171 # (defaults should be the "empty" values)
172 description: str = ""
173 reviewers: list[str] = []
174 alternates: list[str] = []
175 cc: list[str] = []
177 # Review scoring
178 require: int = 0
179 reviewed_for: ReviewedForChoices = ReviewedForChoices.EMPTY
180 author_value: int = 0
182 # How scopes are combined
183 ownership: OwnershipChoices = OwnershipChoices.EMPTY
185 # Actionable items
186 request: int = 0
187 labels: list[str] = []
188 instructions: str = ""
190 # Approval checklist
191 checklist: Checklist | None = None
193 @field_validator("name", mode="after")
194 @classmethod
195 def validate_name(cls, name: str) -> str:
196 if "," in name:
197 raise ValueError("Scope name cannot contain commas")
198 return name
200 @field_validator("code", mode="after")
201 @classmethod
202 def validate_code_patterns(cls, code: list[str]) -> list[str]:
203 for pattern in code:
204 try:
205 parsed = sre_parse.parse(pattern)
206 except re.error as e:
207 raise ValueError(f"Invalid regex pattern '{pattern}': {e}") from None
208 if _has_nested_quantifiers(parsed):
209 raise ValueError(
210 f"Regex pattern '{pattern}' contains nested quantifiers, "
211 "which can cause catastrophic backtracking."
212 )
213 return code
215 @model_validator(mode="after")
216 def validate_reviewers_for_require(self) -> ScopeModel:
217 all_reviewers = self.reviewers + self.alternates
219 # Skip if wildcard - anyone can review
220 if "*" in all_reviewers:
221 return self
223 # Skip if aliases not yet expanded (will validate again after compilation)
224 if any(r.startswith("$") for r in all_reviewers):
225 return self
227 if len(all_reviewers) < self.require:
228 raise ValueError(
229 f"has require={self.require} but only {len(all_reviewers)} reviewers/alternates specified"
230 )
231 return self
233 @model_validator(mode="after")
234 def validate_checklist_reviewed_for(self) -> ScopeModel:
235 if self.checklist and self.reviewed_for == ReviewedForChoices.REQUIRED:
236 raise ValueError(
237 "checklist and reviewed_for='required' cannot be used together. "
238 "The checklist already requires explicit scope acknowledgment."
239 )
240 return self
242 def printed_name(self) -> str:
243 match self.ownership:
244 case OwnershipChoices.APPEND:
245 return "+" + self.name
246 case OwnershipChoices.GLOBAL:
247 return "*" + self.name
249 return self.name
251 def __eq__(self, other: Any) -> bool:
252 return self.name == other.name
254 def matches_path(self, path: Path) -> bool:
255 # TODO paths shouldn't start with /
256 return glob.globmatch(
257 path,
258 self.paths,
259 flags=glob.GLOBSTAR
260 | glob.BRACE
261 | glob.NEGATE
262 | glob.IGNORECASE
263 | glob.DOTGLOB,
264 )
266 def matches_code(self, code: str) -> Generator[dict[str, int]]:
267 patterns = getattr(self, "_code_regex_patterns", [])
268 if not patterns:
269 patterns = [re.compile(pattern, re.MULTILINE) for pattern in self.code]
270 self._code_regex_patterns = patterns
272 for pattern in patterns:
273 for match in pattern.finditer(code):
274 start_index = match.start()
275 end_index = match.end()
277 start_line = code.count("\n", 0, start_index) + 1
278 start_col = start_index - code.rfind("\n", 0, start_index)
280 end_line = code.count("\n", 0, end_index) + 1
281 end_col = end_index - code.rfind("\n", 0, end_index)
283 yield {
284 "start_line": start_line,
285 "start_col": start_col,
286 "end_line": end_line,
287 "end_col": end_col,
288 }
290 def matches_author(self, author_username: str) -> bool:
291 if not self.authors:
292 # No authors specified, so assume it matches
293 return True
295 author_username_lower = author_username.lower()
297 negated_authors = [a[1:].lower() for a in self.authors if a.startswith("!")]
298 authors = [a.lower() for a in self.authors if not a.startswith("!")]
300 if author_username_lower in negated_authors:
301 # If the author is in the negated list, return False
302 return False
304 if not authors:
305 # Negation-only: everyone not negated matches
306 return True
308 if author_username_lower in authors:
309 # If the author is in the authors list, return True
310 return True
312 return False
314 def matches_branches(self, base_branch: str, head_branch: str) -> bool:
315 return _matches_branches(self.branches, base_branch, head_branch)
317 def enabled_for_pullrequest(
318 self, author_username: str, base_branch: str, head_branch: str
319 ) -> bool:
320 # Paths/code are matched during diff parsing,
321 # but we also consider authors and branches in the context of a pull request.
322 return self.matches_author(author_username) and self.matches_branches(
323 base_branch, head_branch
324 )
class LargeScaleChangeModel(BaseModel):
    """Review settings for large-scale changes (LSC).

    Unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Note, an LSC only applies to diffs, not raw files,
    # because we have to know what *changed*.

    # Pretty similar to a scope, but more manual.
    # There has to be at least one reviewer, so if an LSC config is not defined,
    # an LSC PR errors until you add one. (Not enforced by this model: the
    # ``min_length`` constraint below is commented out.)
    require: int = 1
    reviewers: list[str] = []  # Field(min_length=1)
    # min_paths: int = 300
    # min_lines: int = 3000
    labels: list[str] = []
    # really need author value too...?
class ConfigModel(BaseModel):
    """Contents of one config file.

    Unknown keys are rejected (``extra="forbid"``). ``compiled_config``
    produces the effective config with ``extends`` merged in and aliases
    expanded.
    """

    model_config = ConfigDict(extra="forbid")

    # Nothing is technically required
    extends: list[str] = []
    template: bool = False
    branches: list[str] = []
    aliases: dict[str, list[str]] = {}
    large_scale_change: LargeScaleChangeModel | None = None
    scopes: list[ScopeModel] = []

    @field_validator("scopes", mode="after")
    @classmethod
    def validate_unique_scope_names(cls, scopes: list[ScopeModel]) -> list[ScopeModel]:
        """Reject scope names that repeat (compared case-insensitively)."""
        seen: set[str] = set()
        for scope in scopes:
            if scope.name.lower() in seen:
                raise ValueError(f"Duplicate scope name: {scope.name}")
            seen.add(scope.name.lower())

        return scopes

    @field_validator("extends", mode="before")
    @classmethod
    def validate_extends(cls, extends: Any) -> Any:
        """Require every ``extends`` path to have a basename with the config prefix.

        This runs before type validation, so the input is still raw data.
        Anything that is not an iterable of strings is passed through
        unchanged so that the regular ``list[str]`` validation reports it,
        instead of failing here with a ``TypeError`` or a misleading message.

        Raises:
            ValueError: if a path's basename does not start with
                ``CONFIG_FILENAME_PREFIX``.
        """
        if isinstance(extends, (str, bytes)):
            # A bare string would otherwise be checked character by character.
            return extends

        try:
            entries = iter(extends)
        except TypeError:
            return extends

        for path in entries:
            if not isinstance(path, str):
                # Left for the field's own type validation to reject.
                continue
            basename = Path(path).name
            if not basename.startswith(CONFIG_FILENAME_PREFIX):
                raise ValueError(
                    f"Invalid extends path: {path}. It should start with '{CONFIG_FILENAME_PREFIX}'."
                )
        return extends

    def compiled_config(
        self, config_path: Path, other_configs: ConfigModels
    ) -> ConfigModel:
        """
        Merge extends, replace aliases.

        For every path in ``extends`` (looked up in ``other_configs``):
        its scopes and branches are prepended to the current ones, its
        aliases are used unless overridden here, and its
        ``large_scale_change`` is used when none has been set yet. Alias
        references are then expanded in the root ``extends``/``branches``,
        in the list fields of every scope and in ``large_scale_change``.

        The result is a new ``ConfigModel`` that is cached on this instance
        and returned directly by later calls.

        Raises:
            ValueError: if an extended config is not in ``other_configs``.
        """

        if getattr(self, "_compiled_config", None) is not None:
            return self._compiled_config

        # Create a copy of the data from what we have currently
        compiled_data = self.model_dump()

        for extend_path in self.extends:
            if extend_path not in other_configs:
                raise ValueError(f"Config {extend_path} not found")

            extended_config = other_configs[extend_path]
            # NOTE(review): the compiled result is discarded and the raw
            # extended config is merged below, so the extended config's own
            # ``extends`` are not carried over -- confirm this is intended.
            # NOTE(review): nothing here guards against configs that extend
            # each other in a cycle.
            extended_config.compiled_config(Path(extend_path), other_configs)

            extended_config_dumped = extended_config.model_dump(
                include={"branches", "aliases", "scopes", "large_scale_change"}
            )

            compiled_data["scopes"] = (
                extended_config_dumped["scopes"] + compiled_data["scopes"]
            )
            compiled_data["large_scale_change"] = (
                compiled_data["large_scale_change"]
                or extended_config_dumped["large_scale_change"]
            )
            # On key conflicts the aliases already collected win.
            compiled_data["aliases"] = (
                extended_config_dumped["aliases"] | compiled_data["aliases"]
            )
            compiled_data["branches"] = (
                extended_config_dumped["branches"] + compiled_data["branches"]
            )

        # Root aliases
        for field in ["extends", "branches"]:
            if field in compiled_data:
                compiled_data[field] = _expand_aliases(
                    compiled_data[field], compiled_data["aliases"]
                )

        # Expand aliases for any aliasable list fields
        for scope in compiled_data["scopes"]:
            for field in [
                "paths",
                "code",
                "authors",
                "branches",
                "reviewers",
                "alternates",
                "cc",
                "labels",
            ]:
                if field in scope:
                    scope[field] = _expand_aliases(
                        scope[field], compiled_data["aliases"]
                    )

        if large_scale_change := compiled_data.get("large_scale_change"):
            for field in ["reviewers", "labels"]:
                large_scale_change[field] = _expand_aliases(
                    large_scale_change[field],
                    compiled_data["aliases"],
                )

        # Create a new config from the merged data (this re-runs validation)
        self._compiled_config = ConfigModel.from_data(
            data=compiled_data,
            path=config_path,
        )

        return self._compiled_config

    @classmethod
    def from_filesystem(cls, path: Path | str) -> ConfigModel:
        """Load and validate the TOML config file at ``path``."""
        with open(path, "rb") as f:
            return cls.from_data(tomllib.load(f), path)

    @classmethod
    def from_content(cls, content: str, path: Path | str) -> ConfigModel:
        """Parse and validate a config from a TOML string."""
        return cls.from_data(tomllib.loads(content), path)

    @classmethod
    def from_data(cls, data: dict[str, Any], path: Path | str) -> ConfigModel:
        """Validate already-parsed config data.

        ``path`` is accepted for interface compatibility but is currently
        not used.
        """
        return cls(**data)

    def matches_branches(self, base_branch: str, head_branch: str) -> bool:
        """Return whether the base/head branches match this config's ``branches``."""
        return _matches_branches(self.branches, base_branch, head_branch)

    def enabled_for_pullrequest(self, base_branch: str, head_branch: str) -> bool:
        """Return whether this config applies to a PR with the given branches."""
        return self.matches_branches(base_branch, head_branch)
class ConfigModels(RootModel):
    """Collection of configs keyed by their path (as a string)."""

    root: dict[str, ConfigModel]

    @classmethod
    def from_configs_data(cls, data: dict[str, Any]) -> ConfigModels:
        """Build the collection from a mapping of path -> raw config data."""
        collection = cls(root={})

        for raw_path, raw_config in data.items():
            location = Path(raw_path)
            collection.add_config(ConfigModel.from_data(raw_config, location), location)

        return collection

    @classmethod
    def from_config_models(cls, models: dict[str, ConfigModel]) -> ConfigModels:
        """Build the collection from a mapping of path -> existing model."""
        collection = cls(root={})

        for raw_path, model in models.items():
            collection.add_config(model, Path(raw_path))

        return collection

    def get_config_models(self) -> dict[str, ConfigModel]:
        """Return a shallow copy of the path -> config mapping."""
        return dict(self.root)

    def add_config(self, config: ConfigModel, path: Path) -> None:
        """Store ``config`` under ``str(path)``, replacing any existing entry."""
        key = str(path)
        self.root[key] = config

    def get_default_large_scale_change(self) -> LargeScaleChangeModel:
        """Return the large-scale-change settings of the root config.

        The root config is the one stored under ``CONFIG_FILENAME``. When it
        is missing or defines no ``large_scale_change``, a default
        ``LargeScaleChangeModel`` is returned.
        """
        root_config = self.root.get(CONFIG_FILENAME)

        if root_config is not None and root_config.large_scale_change:
            return root_config.large_scale_change

        return LargeScaleChangeModel()

    def __bool__(self) -> bool:
        return len(self.root) > 0

    def __getitem__(self, key: str) -> ConfigModel:
        return self.root[key]

    def __contains__(self, key: str) -> bool:
        return key in self.root

    def __len__(self) -> int:
        return len(self.root)

    def compile_closest_config(self, file_path: Path) -> ConfigModel:
        """Compile the nearest non-template config above ``file_path``.

        Parent directories are searched from the closest one outwards.

        Raises:
            ValueError: if no parent directory has a usable config.
        """
        for directory in file_path.parents:
            candidate = str(directory / CONFIG_FILENAME)
            found = self.root.get(candidate)

            # Keep walking up when there is no config here or it is a template.
            if found is None or found.template:
                continue

            return found.compiled_config(Path(candidate), self)

        raise ValueError(f"No config found for {file_path}")

    def iter_compiled_configs(self) -> Generator[tuple[str, ConfigModel]]:
        """Yield ``(path, compiled config)`` for every non-template config."""
        for location, model in self.root.items():
            if not model.template:
                yield location, model.compiled_config(Path(location), self)

    def num_scopes(self) -> int:
        """
        Count the total number of scopes across all configs.
        """
        total = 0
        for model in self.root.values():
            total += len(model.scopes)
        return total

    def num_reviewers(self) -> int:
        """
        Count the total number of reviewers across all configs.
        """
        total = 0
        for model in self.root.values():
            for scope in model.scopes:
                total += len(scope.reviewers)
        return total

    def filter_for_pullrequest(
        self,
        base_branch: str,
        head_branch: str,
        author_username: str,
    ) -> ConfigModels:
        """
        Look at all configs (including templates) and filter out
        configs and scopes based on branches, authors, etc.

        Returns a new collection; the configs stored here are not modified.
        """
        kept: dict[str, Any] = {}

        for location, model in self.root.items():
            compiled = model.compiled_config(Path(location), self)

            # Drop the whole config when its branches do not match.
            if not compiled.enabled_for_pullrequest(base_branch, head_branch):
                continue

            # Names of the scopes that do not apply to this pull request
            disabled = {
                scope.name
                for scope in compiled.scopes
                if not scope.enabled_for_pullrequest(
                    author_username, base_branch, head_branch
                )
            }

            # Rebuild from the original (uncompiled) data, minus those scopes
            raw = model.model_dump()
            raw["scopes"] = [
                entry for entry in raw["scopes"] if entry["name"] not in disabled
            ]
            kept[location] = raw

        return ConfigModels.from_configs_data(kept)