Coverage for src/pullapprove/matches.py: 80%
172 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-08 23:14 -0500
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-08 23:14 -0500
1from __future__ import annotations
3import hashlib
4import json
5from collections.abc import Generator, Iterator
6from pathlib import Path
7from typing import Any
9from pydantic import BaseModel, ConfigDict, Field, model_validator
11from .config import (
12 ConfigModel,
13 ConfigModels,
14 LargeScaleChangeModel,
15 ScopeModel,
16)
17from .diff import DiffCode, DiffFile, iterate_diff_parts
18from .exceptions import LargeScaleChangeException
def match_path(
    *, path: Path, config: ConfigModel
) -> tuple[ScopePathMatch, list[ScopeModel]]:
    """Match a single path against the scopes of ``config``.

    Scopes whose path patterns match are split in two groups:

    - scopes without ``code`` are attached directly to the returned path match
      (via ``ScopePathMatch.add_scope``);
    - scopes with ``code`` are returned separately so the caller can run them
      against the file contents.

    Returns:
        ``(path_match, code_scopes)``.
    """
    path_match = ScopePathMatch(path=str(path), scopes=[])

    matching = [scope for scope in config.scopes if scope.matches_path(path)]

    code_scopes: list[ScopeModel] = []
    for candidate in matching:
        if candidate.code:
            # Needs the file contents before we know whether it really applies
            code_scopes.append(candidate)
        else:
            # Path-only scope: record it on the path itself
            path_match.add_scope(candidate)

    return path_match, code_scopes
def match_code(
    *, path: str, code: str, scopes: list[ScopeModel], line_offset: int = 0
) -> Generator[ScopeCodeMatch]:
    """Yield one ``ScopeCodeMatch`` per distinct code location matched by ``scopes``.

    Each scope's ``matches_code(code)`` results are turned into matches whose
    line numbers are shifted by ``line_offset``. When several scopes (or
    several hits) land on the same ``location_id``, they are merged into the
    first match via ``add_scope`` instead of producing a second match.

    Nothing runs until the generator is iterated; results are yielded in the
    order their locations were first seen.
    """
    by_location: dict[str, ScopeCodeMatch] = {}

    for scope in scopes:
        for found in scope.matches_code(code):
            candidate = ScopeCodeMatch(
                path=path,
                start_line=line_offset + found["start_line"],
                end_line=line_offset + found["end_line"],
                start_column=found["start_col"],
                end_column=found["end_col"],
                scopes=[scope.name],
                # Empty string lets the model validator derive the id
                location_id="",
            )
            candidate._scopes = [scope]

            existing = by_location.get(candidate.location_id)
            if existing is None:
                by_location[candidate.location_id] = candidate
            else:
                # Same span already recorded: just merge the scope in
                existing.add_scope(scope)

    yield from by_location.values()
def match_files(configs: ConfigModels, files: Iterator[str]) -> ChangeMatches:
    """Match repository files (by path and, where needed, by content).

    For every entry in ``files`` the closest compiled config is used to match
    the path. If any matching scope also has code rules, the file is read from
    disk as UTF-8 and scanned with ``match_code``. Files that cannot be decoded
    as UTF-8 (e.g. binary files) keep their path match but are skipped for code
    matching.

    Returns:
        The collected ``ChangeMatches`` for all files.
    """

    def _iterate() -> Generator[ScopePathMatch | ScopeCodeMatch]:
        for f in files:
            file_path = Path(f)

            config = configs.compile_closest_config(file_path)

            path_match, code_scopes = match_path(
                path=file_path,
                config=config,
            )

            # Yield the paths first
            yield path_match

            if not code_scopes:
                continue

            # Decode explicitly as UTF-8 so the result (and which files count
            # as "binary") doesn't depend on the platform's locale encoding.
            try:
                code = file_path.read_text(encoding="utf-8")
            except UnicodeDecodeError:
                # Skip binary files that can't be decoded as text
                continue

            # Then scan the content for scopes that match code
            yield from match_code(
                path=str(file_path),
                code=code,
                scopes=code_scopes,
            )

    return ChangeMatches.from_config_matches(configs, _iterate())
def iterate_diff(
    configs: ConfigModels, diff: Iterator[str] | str
) -> Generator[tuple[DiffFile | DiffCode, list[ScopePathMatch | ScopeCodeMatch]]]:
    """Walk a diff, pairing each parsed diff object with its scope matches.

    Yields ``(diff_obj, matches)`` tuples in diff order:

    - For a ``DiffFile`` header: ``matches`` is ``[path_match]`` for the
      header's ``new_path`` (matched against the closest compiled config).
    - For a ``DiffCode`` line: ``matches`` is ``[]`` when the current file has
      no code-matching scopes. Otherwise the line is buffered. Buffered lines
      are yielded, each with the code matches whose line range covers it, when
      the next ``DiffFile`` arrives or the diff ends.

    When ``configs`` is falsy, every object from ``iterate_diff_parts`` is
    yielded with ``[]``. When ``configs`` is truthy, objects that are neither
    ``DiffFile`` nor ``DiffCode`` are not yielded.
    """
    # We can still iterate a diff without configs, just by yielding the diff objs
    if not configs:
        for diff_obj in iterate_diff_parts(diff):
            yield diff_obj, []

        return

    # Keep track of these as we go and jump between file header
    # and raw code during iteration. yield_code_diffs() below reads the
    # *current* bindings of these names (closure), so they must be updated
    # before it is called for a new file.
    check_code_scopes: list[ScopeModel] = []
    current_code_path: str | None = None

    # All buffered code lines of the current file (reset only on a new
    # DiffFile, so lines from every hunk of the file end up here together).
    current_code_diffs: list[DiffCode] = []

    # TODO get root config here, check diff size as we go and raise exception?
    # or we need to keep track per LSC? should be a compiled value...

    def yield_code_diffs() -> Generator[
        tuple[DiffCode, list[ScopePathMatch | ScopeCodeMatch]]
    ]:
        """Match the buffered code lines as one chunk, then yield them one by one."""
        # We're passing the entire diff chunk to see if there's a match inside,
        # but if there is, it probably won't match EVERY line in the chunk
        assert current_code_path is not None, "current_code_path must be set"
        current_code_chunk = "\n".join([code.raw() for code in current_code_diffs])
        # Offset so that chunk line i (0-based) is reported as
        # first_line_number + i.
        # NOTE(review): the chunk may span several hunks and mix added/removed
        # lines, so reported line numbers are only exact if the buffered lines
        # are contiguous in one file version -- confirm against DiffCode.line_number.
        current_code_line_number = current_code_diffs[0].line_number - 1

        code_matches = match_code(
            path=current_code_path,
            code=current_code_chunk,
            scopes=check_code_scopes,
            line_offset=current_code_line_number,
        )
        # Materialize: the matches are scanned once per buffered line below.
        code_matches = list(code_matches)

        for diff_line_index, diff_code in enumerate(current_code_diffs):
            # Keep only the matches whose [start_line, end_line] range covers
            # this line's position in the chunk.
            subcode_matches: list[ScopePathMatch | ScopeCodeMatch] = [
                code_match
                for code_match in code_matches
                if code_match.start_line
                <= (current_code_line_number + diff_line_index + 1)
                <= code_match.end_line
            ]
            yield diff_code, subcode_matches

    for diff_obj in iterate_diff_parts(diff):
        if isinstance(diff_obj, DiffFile):
            # Yield a code chunk if we finished one (uses the previous file's
            # path and scopes, which have not been reassigned yet)
            if current_code_diffs:
                yield from yield_code_diffs()

            current_code_path = None
            current_code_diffs = []

            diff_file = diff_obj
            # Only the new side of the header is matched here.
            file_path = Path(diff_file.new_path)
            config = configs.compile_closest_config(file_path)

            path_match, code_scopes = match_path(
                path=file_path,
                config=config,
            )

            current_code_path = str(file_path)
            check_code_scopes = code_scopes

            yield diff_obj, [path_match]
        elif isinstance(diff_obj, DiffCode):
            if check_code_scopes:
                # It will be yielded later
                current_code_diffs.append(diff_obj)
            else:
                # Skip all code lines if we don't care about code
                yield diff_obj, []

    # Yield the last code chunk we saw
    if current_code_diffs:
        yield from yield_code_diffs()
def match_diff(configs: ConfigModels, diff: Iterator[str] | str) -> DiffResults:
    """Match a diff against ``configs`` and summarize it.

    The returned ``DiffResults`` holds:

    - the scope matches;
    - every config path touched by the diff (the new or old path of a file
      header that is present in ``configs``);
    - the number of added and deleted lines.

    If ``LargeScaleChangeException`` is raised while matching, the matches are
    replaced by the default large-scale-change result from ``configs``. The
    counters and the modified-config list then reflect only the part of the
    diff consumed before the exception.
    """
    touched_configs: set[str] = set()
    additions = 0
    deletions = 0

    def iterate() -> Generator[ScopePathMatch | ScopeCodeMatch]:
        # Counts lines and records config paths as a side effect while the
        # matches stream through to ChangeMatches.
        nonlocal additions, deletions
        for part, part_matches in iterate_diff(configs, diff):
            if isinstance(part, DiffCode):
                if part.is_addition():
                    additions += 1
                elif part.is_deletion():
                    deletions += 1

            if isinstance(part, DiffFile):
                # Check both the new and the old path of the file header
                for side in (part.new_path, part.old_path):
                    if side in configs:
                        touched_configs.add(side)

            yield from part_matches

    def build(matches: ChangeMatches) -> DiffResults:
        # Reads the counters at call time, i.e. after ``matches`` was built
        # (and the diff therefore consumed).
        return DiffResults(
            matches=matches,
            config_paths_modified=list(touched_configs),
            additions=additions,
            deletions=deletions,
        )

    try:
        return build(ChangeMatches.from_config_matches(configs, iterate()))
    except LargeScaleChangeException:
        # Get the large scale change config from CODEREVIEW.toml
        return build(
            ChangeMatches.from_large_scale_change(
                configs=configs,
                large_scale_change=configs.get_default_large_scale_change(),
            )
        )
class DiffResults(BaseModel):
    """Results from analyzing a diff against configs."""

    # Reject unknown keys on construction
    model_config = ConfigDict(extra="forbid")

    # Scope matches (or a large-scale-change result) for the diff
    matches: ChangeMatches
    # Config file paths the diff touched
    config_paths_modified: list[str] = Field(default_factory=list)
    # Count of added lines
    additions: int = Field(default=0)
    # Count of deleted lines
    deletions: int = Field(default=0)
class ChangeMatches(BaseModel):
    """
    The matches for a given diff or set of files.

    This knows nothing about a pull request (branches, commits, etc.)
    """

    model_config = ConfigDict(extra="forbid")

    # Design note: an alternative layout would expose scopes, config, paths and
    # code, and could also carry points, reviewers, etc. That would mix
    # concerns, though -- results built from raw files would just have empty
    # values for them.
    #
    # The three modes are:
    # - raw files
    # - raw diff
    # - pull request (has reviews)

    # Config models, as returned by ``ConfigModels.get_config_models()``
    configs: dict[str, ConfigModel] = {}

    # The matching LSC, if there is one.
    large_scale_change: LargeScaleChangeModel | None = None

    # All scopes found in the results, keyed by scope name
    scopes: dict[str, ScopeModel] = {}

    # All evaluated paths that matched at least one scope, keyed by path
    paths: dict[str, ScopePathMatch] = {}

    # All code matches, keyed by location_id
    code: dict[str, ScopeCodeMatch] = {}

    def as_dict(self) -> dict[str, Any]:
        """Return the model as a plain ``dict`` (``model_dump()``)."""
        return self.model_dump()

    def __bool__(self) -> bool:
        """True when at least one scope matched."""
        return bool(self.scopes)

    @classmethod
    def from_config_matches(
        cls, configs: ConfigModels, matches: Iterator[ScopePathMatch | ScopeCodeMatch]
    ) -> ChangeMatches:
        """Collect path and code matches into a ``ChangeMatches``.

        Consumes ``matches`` fully. Path matches without scopes are dropped.
        A later match for the same path or code location replaces an earlier
        one. Every scope attached to a kept match is recorded in ``scopes``.

        Raises:
            ValueError: if ``matches`` yields something other than a
                ``ScopePathMatch`` or ``ScopeCodeMatch``.
        """
        scopes: dict[str, ScopeModel] = {}
        paths: dict[str, ScopePathMatch] = {}
        code: dict[str, ScopeCodeMatch] = {}

        for match in matches:
            # Check the type before touching ``_scopes`` so an unknown object
            # gets the intended ValueError instead of an AttributeError.
            if isinstance(match, ScopePathMatch):
                if not match._scopes:
                    # Right now we don't care about storing anything that doesn't have scopes.
                    # This prevents an unnecessarily huge dump on big repos or PRs.
                    continue

                paths[match.path] = match

            elif isinstance(match, ScopeCodeMatch):
                # Store it in the code results
                code[match.location_id] = match

                # TODO: associate the code location with its path result
                # (e.g. a ``code`` list on ScopePathMatch) once that exists.

            else:
                raise ValueError(f"Unknown match type: {match}")

            # Store seen scopes as we go from all kept matches
            for scope in match._scopes:
                scopes[scope.name] = scope

        return cls(
            large_scale_change=None,
            scopes=scopes,
            paths=paths,
            code=code,
            # Should this be compiled configs? At this point they may be modified (branches, author, etc.)
            configs=configs.get_config_models(),
        )

    @classmethod
    def from_large_scale_change(
        cls,
        configs: ConfigModels,
        large_scale_change: LargeScaleChangeModel,
    ) -> ChangeMatches:
        """Build a result carrying only the large-scale-change config (no matches)."""
        return cls(
            configs=configs.get_config_models(),
            large_scale_change=large_scale_change,
            scopes={},
            paths={},
            code={},
        )
# A file path together with the names of the scopes that matched it.
class ScopePathMatch(BaseModel):
    model_config = ConfigDict(extra="forbid")

    path: str = Field(min_length=1)
    scopes: list[str]  # Field(min_length=1)

    # Store this internally during processing (full reference of scope models).
    # Kept in sync with ``scopes`` by add_scope().
    _scopes: list[ScopeModel] = []

    def add_scope(self, scope: ScopeModel) -> None:
        """Record ``scope`` as matching this path and refresh ``scopes``.

        Adding a scope without ``ownership`` drops every previously recorded
        scope that also lacks ``ownership``, so at most one such "primary"
        scope remains. Adding a scope with ``ownership`` whose name is already
        recorded is a no-op, so scope names never repeat.
        """
        if scope.ownership:
            if any(existing.name == scope.name for existing in self._scopes):
                # Already recorded; appending again would duplicate the name
                return
        else:
            # Remove any other scopes that don't have special ownership rules
            # (i.e. we only want one primary scope in the end)
            self._scopes = [s for s in self._scopes if s.ownership]

        self._scopes.append(scope)

        self.scopes = [s.name for s in self._scopes]
# A matched span of code in one file, with the names of the scopes that matched it.
class ScopeCodeMatch(BaseModel):
    model_config = ConfigDict(extra="forbid")

    # In a diff match, we could see both sides of the diff, i.e. repeated lines if the before and after both match...
    path: str = Field(min_length=1)
    start_line: int
    end_line: int
    start_column: int
    end_column: int
    scopes: list[str]  # Field(min_length=1)
    # Identifier of the span; pass "" to have compute_location_id derive it
    location_id: str

    # Store this internally during processing (full reference of scope models).
    # Kept in sync with ``scopes`` by add_scope().
    _scopes: list[ScopeModel] = []

    def printed_location(self) -> str:
        """Return ``"Ln L, Col a-b"`` for single-line spans, else ``"Ln a-b"``."""
        if self.start_line == self.end_line:
            return f"Ln {self.start_line}, Col {self.start_column}-{self.end_column}"
        else:
            return f"Ln {self.start_line}-{self.end_line}"

    def add_scope(self, scope: ScopeModel) -> None:
        """Record ``scope`` as matching this span and refresh ``scopes``.

        Adding a scope without ``ownership`` drops every previously recorded
        scope that also lacks ``ownership``, so at most one such "primary"
        scope remains. Adding a scope with ``ownership`` whose name is already
        recorded is a no-op. This can happen when one scope hits the same span
        more than once, and the no-op keeps names from repeating.
        """
        if scope.ownership:
            if any(existing.name == scope.name for existing in self._scopes):
                # Already recorded; appending again would duplicate the name
                return
        else:
            # Remove any other scopes that don't have special ownership rules
            # (i.e. we only want one primary scope in the end)
            self._scopes = [s for s in self._scopes if s.ownership]

        self._scopes.append(scope)

        self.scopes = [s.name for s in self._scopes]

    @model_validator(mode="after")
    def compute_location_id(self) -> ScopeCodeMatch:
        """Fill ``location_id`` with a hash of path + line/column span if empty."""
        # only compute if the caller didn't provide one
        if not self.location_id:
            loc = {
                "path": self.path,
                "start_line": self.start_line,
                "end_line": self.end_line,
                "start_column": self.start_column,
                "end_column": self.end_column,
            }
            # Canonical JSON (sorted keys, no whitespace) so equal spans hash equally
            raw = json.dumps(loc, sort_keys=True, separators=(",", ":")).encode()
            # md5 is only a non-cryptographic fingerprint here; declaring that
            # keeps it usable on FIPS-restricted OpenSSL builds.
            self.location_id = hashlib.md5(raw, usedforsecurity=False).hexdigest()
        return self
# TODO: decide how to store what was reviewed. Ideally this is fine-grained at
# some point: record who reviewed, which scopes, which paths, and which code
# locations (by location hash), so that everything can be cross-referenced.