# batchata

Batchata - Unified Python API for AI batch requests with cost tracking, Pydantic responses, and parallel execution.

**Why AI-batching?**
AI providers offer batch APIs that process requests asynchronously at 50% reduced cost compared to real-time APIs. This is ideal for workloads like document processing, data analysis, and content generation where immediate responses aren't required.
## Quick Start

### Installation

```bash
pip install batchata
```
### Basic Usage

```python
from batchata import Batch

# Simple batch processing
batch = (
    Batch(results_dir="./output")
    .set_default_params(model="claude-sonnet-4-20250514")
    .add_cost_limit(usd=5.0)
)

# Add jobs (files: your list of document paths)
for file in files:
    batch.add_job(file=file, prompt="Summarize this document")

# Execute
run = batch.run()
results = run.results()
```
### Structured Output with Pydantic

```python
from batchata import Batch
from pydantic import BaseModel

class DocumentAnalysis(BaseModel):
    title: str
    summary: str
    key_points: list[str]

batch = (
    Batch(results_dir="./results")
    .set_default_params(model="claude-sonnet-4-20250514")
)

batch.add_job(
    file="document.pdf",
    prompt="Analyze this document",
    response_model=DocumentAnalysis,
    enable_citations=True  # Anthropic only
)

run = batch.run()
for result in run.results()["completed"]:
    analysis = result.parsed_response    # DocumentAnalysis object
    citations = result.citation_mappings # Field -> Citation mapping
```
## Key Features

- **50% Cost Savings**: Native batch processing via provider APIs
- **Cost Limits**: Set `max_cost_usd` limits for batch requests
- **Time Limits**: Control execution time with `.add_time_limit()`
- **State Persistence**: Resume interrupted batches automatically
- **Structured Output**: Pydantic models with automatic validation
- **Citations**: Extract and map citations to response fields (Anthropic)
- **Multiple Providers**: Anthropic Claude and OpenAI GPT models
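The cost-limit behavior (stop admitting new jobs once the limit is reached, let in-flight jobs finish) can be sketched independently of the library. The `CostLimiter` below is an illustrative stand-in, not batchata's internal cost tracker:

```python
# Illustrative sketch of a cost gate -- not batchata's internal CostTracker,
# just the admission rule the "Cost Limits" feature describes.
class CostLimiter:
    def __init__(self, limit_usd=None):
        self.limit_usd = limit_usd  # None means unlimited
        self.spent_usd = 0.0

    def can_start(self, estimated_cost_usd):
        """True if a new job may start without exceeding the limit."""
        if self.limit_usd is None:
            return True
        return self.spent_usd + estimated_cost_usd <= self.limit_usd

    def record(self, actual_cost_usd):
        """Record the cost of a finished job."""
        self.spent_usd += actual_cost_usd


limiter = CostLimiter(limit_usd=5.0)
assert limiter.can_start(3.0)       # within budget
limiter.record(3.0)
assert not limiter.can_start(2.5)   # 3.0 + 2.5 would exceed the 5.0 limit
```

Jobs already submitted before the limit is hit still run to completion, which is why actual spend can land slightly above the configured limit.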
## Supported Providers

| Feature | Anthropic | OpenAI |
|---------|-----------|--------|
| Models | [All Claude models](https://github.com/agamm/batchata/blob/main/batchata/providers/anthropic/models.py) | [All GPT models](https://github.com/agamm/batchata/blob/main/batchata/providers/openai/models.py) | 
| Citations | ✅ | ❌ |
| Structured Output | ✅ | ✅ |
| File Types | PDF, TXT, DOCX, Images | PDF, Images |
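Internally, a job's model string determines which provider handles it (the library's `get_provider` helper, visible in `Batch.add_job`). A minimal sketch of that routing idea follows; the prefix rules here are an assumption for illustration, since the real implementation presumably consults the per-provider model lists linked above:

```python
# Hypothetical model-to-provider routing, sketched with name prefixes.
# batchata's actual get_provider() is not guaranteed to work this way.
def route_provider(model: str) -> str:
    if model.startswith("claude"):
        return "anthropic"
    if model.startswith("gpt"):
        return "openai"
    raise ValueError(f"No provider found for model: {model}")


assert route_provider("claude-sonnet-4-20250514") == "anthropic"
assert route_provider("gpt-4") == "openai"
```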
## Configuration

Set API keys as environment variables:

```bash
export ANTHROPIC_API_KEY="your-key"
export OPENAI_API_KEY="your-key"
```

Or use a `.env` file with python-dotenv.
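python-dotenv's `load_dotenv()` reads `KEY=VALUE` lines from a `.env` file into the process environment. If you'd rather avoid the dependency, a hand-rolled loader is a few lines; this is a simplified sketch that skips comments and blank lines but ignores the quoting and interpolation edge cases python-dotenv handles:

```python
import os

def load_env(path: str = ".env") -> dict:
    """Load KEY=VALUE pairs from a dotenv-style file into os.environ."""
    loaded = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue  # skip blanks, comments, malformed lines
            key, _, value = line.partition("=")
            loaded[key.strip()] = value.strip().strip('"')
    os.environ.update(loaded)
    return loaded
```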
```python
"""Batchata - Unified Python API for AI Batch requests with cost tracking, Pydantic responses, and parallel execution."""

from .core import Batch, BatchRun, Job, JobResult
from .exceptions import (
    BatchataError,
    CostLimitExceededError,
    ProviderError,
    ProviderNotFoundError,
    ValidationError,
)
from .types import Citation

__version__ = "0.3.0"

__all__ = [
    "Batch",
    "BatchRun",
    "Job",
    "JobResult",
    "Citation",
    "BatchataError",
    "CostLimitExceededError",
    "ProviderError",
    "ProviderNotFoundError",
    "ValidationError",
]
```
````python
class Batch:
    """Builder for batch job configuration.

    Provides a fluent interface for configuring batch jobs with sensible defaults
    and validation. The batch can be configured with cost limits, default parameters,
    and progress callbacks.

    Example:
        ```python
        batch = (
            Batch("./results", max_parallel_batches=10, items_per_batch=10)
            .set_state(file="./state.json", reuse_state=True)
            .set_default_params(model="claude-sonnet-4-20250514", temperature=0.7)
            .add_cost_limit(usd=15.0)
            .add_job(messages=[{"role": "user", "content": "Hello"}])
            .add_job(file="./path/to/file.pdf", prompt="Generate summary of file")
        )

        run = batch.run()
        ```
    """

    def __init__(self, results_dir: str, max_parallel_batches: int = 10, items_per_batch: int = 10, raw_files: Optional[bool] = None):
        """Initialize batch configuration.

        Args:
            results_dir: Directory to store results
            max_parallel_batches: Maximum parallel batch requests
            items_per_batch: Number of jobs per provider batch
            raw_files: Whether to save debug files (raw responses, JSONL files)
                from providers (default: True if results_dir is set, False otherwise)
        """
        # Auto-determine raw_files based on results_dir if not explicitly set
        if raw_files is None:
            raw_files = bool(results_dir and results_dir.strip())

        self.config = BatchParams(
            state_file=None,
            results_dir=results_dir,
            max_parallel_batches=max_parallel_batches,
            items_per_batch=items_per_batch,
            reuse_state=True,
            raw_files=raw_files
        )
        self.jobs: List[Job] = []

    def set_default_params(self, **kwargs) -> 'Batch':
        """Set default parameters for all jobs.

        These defaults will be applied to all jobs unless overridden
        by job-specific parameters.

        Args:
            **kwargs: Default parameters (model, temperature, max_tokens, etc.)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_default_params(model="claude-3-sonnet", temperature=0.7)
            ```
        """
        # Validate if model is provided
        if "model" in kwargs:
            self.config.validate_default_params(kwargs["model"])

        self.config.default_params.update(kwargs)
        return self

    def set_state(self, file: Optional[str] = None, reuse_state: bool = True) -> 'Batch':
        """Set state file configuration.

        Args:
            file: Path to state file for persistence (default: None)
            reuse_state: Whether to resume from existing state file (default: True)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_state(file="./state.json", reuse_state=True)
            ```
        """
        self.config.state_file = file
        self.config.reuse_state = reuse_state
        return self

    def add_cost_limit(self, usd: float) -> 'Batch':
        """Add cost limit for the batch.

        The batch will stop accepting new jobs once the cost limit is reached.
        Active jobs will be allowed to complete.

        Args:
            usd: Cost limit in USD

        Returns:
            Self for chaining

        Example:
            ```python
            batch.add_cost_limit(usd=50.0)
            ```
        """
        if usd <= 0:
            raise ValueError("Cost limit must be positive")
        self.config.cost_limit_usd = usd
        return self

    def raw_files(self, enabled: bool = True) -> 'Batch':
        """Enable or disable saving debug files from providers.

        When enabled, debug files (raw API responses, JSONL files) will be saved
        in a 'raw_files' subdirectory within the results directory.
        This is useful for debugging, auditing, or accessing provider-specific metadata.

        Args:
            enabled: Whether to save debug files (default: True)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.raw_files(True)
            ```
        """
        self.config.raw_files = enabled
        return self

    def set_verbosity(self, level: str) -> 'Batch':
        """Set logging verbosity level.

        Args:
            level: Verbosity level ("debug", "info", "warn", "error")

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_verbosity("error")  # For production
            batch.set_verbosity("debug")  # For debugging
            ```
        """
        valid_levels = {"debug", "info", "warn", "error"}
        if level.lower() not in valid_levels:
            raise ValueError(f"Invalid verbosity level: {level}. Must be one of {valid_levels}")
        self.config.verbosity = level.lower()
        return self

    def add_time_limit(self, seconds: Optional[float] = None, minutes: Optional[float] = None, hours: Optional[float] = None) -> 'Batch':
        """Add time limit for the entire batch execution.

        When time limit is reached, all active provider batches are cancelled and
        remaining unprocessed jobs are marked as failed. The batch execution
        completes normally without throwing exceptions.

        Args:
            seconds: Time limit in seconds (optional)
            minutes: Time limit in minutes (optional)
            hours: Time limit in hours (optional)

        Returns:
            Self for chaining

        Raises:
            ValueError: If no time units specified, or if total time is outside
                valid range (min: 10 seconds, max: 24 hours)

        Note:
            - Can combine multiple time units
            - Time limit is checked every second by a background watchdog thread
            - Jobs that exceed time limit appear in results()["failed"] with time limit error message
            - No exceptions are thrown when time limit is reached

        Example:
            ```python
            batch.add_time_limit(seconds=30)  # 30 seconds
            batch.add_time_limit(minutes=5)   # 5 minutes
            batch.add_time_limit(hours=2)     # 2 hours
            batch.add_time_limit(hours=1, minutes=30, seconds=15)  # 5415 seconds total
            ```
        """
        time_limit_seconds = 0.0

        if seconds is not None:
            time_limit_seconds += seconds
        if minutes is not None:
            time_limit_seconds += minutes * 60
        if hours is not None:
            time_limit_seconds += hours * 3600

        if time_limit_seconds == 0:
            raise ValueError("Must specify at least one of seconds, minutes, or hours")

        self.config.time_limit_seconds = time_limit_seconds
        return self

    def add_job(
        self,
        messages: Optional[List[Message]] = None,
        file: Optional[Union[str, Path]] = None,
        prompt: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        response_model: Optional[Type[BaseModel]] = None,
        enable_citations: bool = False,
        **kwargs
    ) -> 'Batch':
        """Add a job to the batch.

        Either provide messages OR file+prompt, not both. Parameters not provided
        will use the defaults set via set_default_params().

        Args:
            messages: Chat messages for direct input
            file: File path for file-based input
            prompt: Prompt to use with file input
            model: Model to use (overrides default)
            temperature: Sampling temperature (overrides default)
            max_tokens: Max tokens to generate (overrides default)
            response_model: Pydantic model for structured output
            enable_citations: Whether to extract citations
            **kwargs: Additional parameters

        Returns:
            Self for chaining

        Example:
            ```python
            batch.add_job(
                messages=[{"role": "user", "content": "Hello"}],
                model="gpt-4"
            )
            ```
        """
        # Generate unique job ID
        job_id = f"job-{uuid.uuid4().hex[:8]}"

        # Merge with defaults
        params = self.config.default_params.copy()

        # Update with provided parameters
        if model is not None:
            params["model"] = model
        if temperature is not None:
            params["temperature"] = temperature
        if max_tokens is not None:
            params["max_tokens"] = max_tokens

        # Add other kwargs
        params.update(kwargs)

        # Ensure model is provided
        if "model" not in params:
            raise ValueError("Model must be provided either in defaults or job parameters")

        # Validate parameters
        provider = get_provider(params["model"])
        # Extract params without model to avoid duplicate
        param_subset = {k: v for k, v in params.items() if k != "model"}
        provider.validate_params(params["model"], **param_subset)

        # Convert file path if string
        if isinstance(file, str):
            file = Path(file)

        # Warn about temporary file paths that may not persist
        if file:
            file_str = str(file)
            if "/tmp/" in file_str or "/var/folders/" in file_str or "temp" in file_str.lower():
                logger = logging.getLogger("batchata")
                logger.debug(f"File path appears to be in a temporary directory: {file}")
                logger.debug("This may cause issues when resuming from state if temp files are cleaned up")

        # Create job
        job = Job(
            id=job_id,
            messages=messages,
            file=file,
            prompt=prompt,
            response_model=response_model,
            enable_citations=enable_citations,
            **params
        )

        # Validate citation compatibility
        if response_model and enable_citations:
            from ..utils.validation import validate_flat_model
            validate_flat_model(response_model)

        # Validate job with provider (includes PDF validation for Anthropic)
        provider.validate_job(job)

        self.jobs.append(job)
        return self

    def run(self, on_progress: Optional[Callable[[Dict, float, Dict], None]] = None, progress_interval: float = 1.0, print_status: bool = False, dry_run: bool = False) -> 'BatchRun':
        """Execute the batch.

        Creates a BatchRun instance and executes the jobs synchronously.

        Args:
            on_progress: Optional progress callback function that receives
                (stats_dict, elapsed_time_seconds, batch_data)
            progress_interval: Interval in seconds between progress updates (default: 1.0)
            print_status: Whether to show rich progress display (default: False)
            dry_run: If True, only show cost estimation without executing (default: False)

        Returns:
            BatchRun instance with completed results

        Raises:
            ValueError: If no jobs have been added
        """
        if not self.jobs:
            raise ValueError("No jobs added to batch")

        # Import here to avoid circular dependency
        from .batch_run import BatchRun

        # Create and start the run
        run = BatchRun(self.config, self.jobs)

        # Handle dry run mode
        if dry_run:
            return run.dry_run()

        # Set progress callback - either rich display or custom callback
        if print_status:
            return self._run_with_rich_display(run, progress_interval)
        else:
            return self._run_with_custom_callback(run, on_progress, progress_interval)

    def _run_with_rich_display(self, run: 'BatchRun', progress_interval: float) -> 'BatchRun':
        """Execute batch run with rich progress display.

        Args:
            run: BatchRun instance to execute
            progress_interval: Interval between progress updates

        Returns:
            Completed BatchRun instance
        """
        from ..utils.rich_progress import RichBatchProgressDisplay
        display = RichBatchProgressDisplay()

        def rich_progress_callback(stats, elapsed_time, batch_data):
            # Start display on first call
            if not hasattr(rich_progress_callback, '_started'):
                config_dict = {
                    'results_dir': self.config.results_dir,
                    'state_file': self.config.state_file,
                    'items_per_batch': self.config.items_per_batch,
                    'max_parallel_batches': self.config.max_parallel_batches
                }
                display.start(stats, config_dict)
                rich_progress_callback._started = True

            # Update display
            display.update(stats, batch_data, elapsed_time)

        run.set_on_progress(rich_progress_callback, interval=progress_interval)

        # Execute with proper cleanup
        try:
            run.execute()

            # Show final status with all batches completed
            stats = run.status()
            display.update(stats, run.batch_tracking, (datetime.now() - run._start_time).total_seconds())

            # Small delay to ensure display updates
            import time
            time.sleep(0.2)

        except KeyboardInterrupt:
            # Update batch tracking to show cancelled status for pending/running batches
            with run._state_lock:
                for batch_id, batch_info in run.batch_tracking.items():
                    if batch_info['status'] in ('running', 'pending'):
                        batch_info['status'] = 'cancelled'

            # Show final status with cancelled batches
            stats = run.status()
            display.update(stats, run.batch_tracking, 0.0)

            # Add a small delay to ensure the display updates
            import time
            time.sleep(0.1)

            display.stop()
            raise
        finally:
            if display.live:  # Only stop if not already stopped
                display.stop()

        return run

    def _run_with_custom_callback(self, run: 'BatchRun', on_progress: Optional[Callable[[Dict, float, Dict], None]], progress_interval: float) -> 'BatchRun':
        """Execute batch run with custom progress callback.

        Args:
            run: BatchRun instance to execute
            on_progress: Optional custom progress callback
            progress_interval: Interval between progress updates

        Returns:
            Completed BatchRun instance
        """
        # Use custom progress callback if provided
        if on_progress:
            run.set_on_progress(on_progress, interval=progress_interval)

        run.execute()
        return run

    def __len__(self) -> int:
        """Get the number of jobs in the batch."""
        return len(self.jobs)

    def __repr__(self) -> str:
        """String representation of the batch."""
        return (
            f"Batch(jobs={len(self.jobs)}, "
            f"max_parallel_batches={self.config.max_parallel_batches}, "
            f"cost_limit=${self.config.cost_limit_usd or 'None'})"
        )
````
Builder for batch job configuration.
Provides a fluent interface for configuring batch jobs with sensible defaults and validation. The batch can be configured with cost limits, default parameters, and progress callbacks.
Example:
batch = Batch("./results", max_parallel_batches=10, items_per_batch=10)
.set_state(file="./state.json", reuse_state=True)
.set_default_params(model="claude-sonnet-4-20250514", temperature=0.7)
.add_cost_limit(usd=15.0)
.add_job(messages=[{"role": "user", "content": "Hello"}])
.add_job(file="./path/to/file.pdf", prompt="Generate summary of file")
run = batch.run()
39 def __init__(self, results_dir: str, max_parallel_batches: int = 10, items_per_batch: int = 10, raw_files: Optional[bool] = None): 40 """Initialize batch configuration. 41 42 Args: 43 results_dir: Directory to store results 44 max_parallel_batches: Maximum parallel batch requests 45 items_per_batch: Number of jobs per provider batch 46 raw_files: Whether to save debug files (raw responses, JSONL files) from providers (default: True if results_dir is set, False otherwise) 47 """ 48 # Auto-determine raw_files based on results_dir if not explicitly set 49 if raw_files is None: 50 raw_files = bool(results_dir and results_dir.strip()) 51 52 self.config = BatchParams( 53 state_file=None, 54 results_dir=results_dir, 55 max_parallel_batches=max_parallel_batches, 56 items_per_batch=items_per_batch, 57 reuse_state=True, 58 raw_files=raw_files 59 ) 60 self.jobs: List[Job] = []
Initialize batch configuration.
Args: results_dir: Directory to store results max_parallel_batches: Maximum parallel batch requests items_per_batch: Number of jobs per provider batch raw_files: Whether to save debug files (raw responses, JSONL files) from providers (default: True if results_dir is set, False otherwise)
62 def set_default_params(self, **kwargs) -> 'Batch': 63 """Set default parameters for all jobs. 64 65 These defaults will be applied to all jobs unless overridden 66 by job-specific parameters. 67 68 Args: 69 **kwargs: Default parameters (model, temperature, max_tokens, etc.) 70 71 Returns: 72 Self for chaining 73 74 Example: 75 ```python 76 batch.set_default_params(model="claude-3-sonnet", temperature=0.7) 77 ``` 78 """ 79 # Validate if model is provided 80 if "model" in kwargs: 81 self.config.validate_default_params(kwargs["model"]) 82 83 self.config.default_params.update(kwargs) 84 return self
Set default parameters for all jobs.
These defaults will be applied to all jobs unless overridden by job-specific parameters.
Args: **kwargs: Default parameters (model, temperature, max_tokens, etc.)
Returns: Self for chaining
Example:
batch.set_default_params(model="claude-3-sonnet", temperature=0.7)
86 def set_state(self, file: Optional[str] = None, reuse_state: bool = True) -> 'Batch': 87 """Set state file configuration. 88 89 Args: 90 file: Path to state file for persistence (default: None) 91 reuse_state: Whether to resume from existing state file (default: True) 92 93 Returns: 94 Self for chaining 95 96 Example: 97 ```python 98 batch.set_state(file="./state.json", reuse_state=True) 99 ``` 100 """ 101 self.config.state_file = file 102 self.config.reuse_state = reuse_state 103 return self
Set state file configuration.
Args: file: Path to state file for persistence (default: None) reuse_state: Whether to resume from existing state file (default: True)
Returns: Self for chaining
Example:
batch.set_state(file="./state.json", reuse_state=True)
105 def add_cost_limit(self, usd: float) -> 'Batch': 106 """Add cost limit for the batch. 107 108 The batch will stop accepting new jobs once the cost limit is reached. 109 Active jobs will be allowed to complete. 110 111 Args: 112 usd: Cost limit in USD 113 114 Returns: 115 Self for chaining 116 117 Example: 118 ```python 119 batch.add_cost_limit(usd=50.0) 120 ``` 121 """ 122 if usd <= 0: 123 raise ValueError("Cost limit must be positive") 124 self.config.cost_limit_usd = usd 125 return self
Add cost limit for the batch.
The batch will stop accepting new jobs once the cost limit is reached. Active jobs will be allowed to complete.
Args: usd: Cost limit in USD
Returns: Self for chaining
Example:
batch.add_cost_limit(usd=50.0)
127 def raw_files(self, enabled: bool = True) -> 'Batch': 128 """Enable or disable saving debug files from providers. 129 130 When enabled, debug files (raw API responses, JSONL files) will be saved 131 in a 'raw_files' subdirectory within the results directory. 132 This is useful for debugging, auditing, or accessing provider-specific metadata. 133 134 Args: 135 enabled: Whether to save debug files (default: True) 136 137 Returns: 138 Self for chaining 139 140 Example: 141 ```python 142 batch.raw_files(True) 143 ``` 144 """ 145 self.config.raw_files = enabled 146 return self
Enable or disable saving debug files from providers.
When enabled, debug files (raw API responses, JSONL files) will be saved in a 'raw_files' subdirectory within the results directory. This is useful for debugging, auditing, or accessing provider-specific metadata.
Args: enabled: Whether to save debug files (default: True)
Returns: Self for chaining
Example:
batch.raw_files(True)
148 def set_verbosity(self, level: str) -> 'Batch': 149 """Set logging verbosity level. 150 151 Args: 152 level: Verbosity level ("debug", "info", "warn", "error") 153 154 Returns: 155 Self for chaining 156 157 Example: 158 ```python 159 batch.set_verbosity("error") # For production 160 batch.set_verbosity("debug") # For debugging 161 ``` 162 """ 163 valid_levels = {"debug", "info", "warn", "error"} 164 if level.lower() not in valid_levels: 165 raise ValueError(f"Invalid verbosity level: {level}. Must be one of {valid_levels}") 166 self.config.verbosity = level.lower() 167 return self
Set logging verbosity level.
Args: level: Verbosity level ("debug", "info", "warn", "error")
Returns: Self for chaining
Example:
batch.set_verbosity("error") # For production
batch.set_verbosity("debug") # For debugging
169 def add_time_limit(self, seconds: Optional[float] = None, minutes: Optional[float] = None, hours: Optional[float] = None) -> 'Batch': 170 """Add time limit for the entire batch execution. 171 172 When time limit is reached, all active provider batches are cancelled and 173 remaining unprocessed jobs are marked as failed. The batch execution 174 completes normally without throwing exceptions. 175 176 Args: 177 seconds: Time limit in seconds (optional) 178 minutes: Time limit in minutes (optional) 179 hours: Time limit in hours (optional) 180 181 Returns: 182 Self for chaining 183 184 Raises: 185 ValueError: If no time units specified, or if total time is outside 186 valid range (min: 10 seconds, max: 24 hours) 187 188 Note: 189 - Can combine multiple time units 190 - Time limit is checked every second by a background watchdog thread 191 - Jobs that exceed time limit appear in results()["failed"] with time limit error message 192 - No exceptions are thrown when time limit is reached 193 194 Example: 195 ```python 196 batch.add_time_limit(seconds=30) # 30 seconds 197 batch.add_time_limit(minutes=5) # 5 minutes 198 batch.add_time_limit(hours=2) # 2 hours 199 batch.add_time_limit(hours=1, minutes=30, seconds=15) # 5415 seconds total 200 ``` 201 """ 202 time_limit_seconds = 0.0 203 204 if seconds is not None: 205 time_limit_seconds += seconds 206 if minutes is not None: 207 time_limit_seconds += minutes * 60 208 if hours is not None: 209 time_limit_seconds += hours * 3600 210 211 if time_limit_seconds == 0: 212 raise ValueError("Must specify at least one of seconds, minutes, or hours") 213 214 self.config.time_limit_seconds = time_limit_seconds 215 return self
Add time limit for the entire batch execution.
When time limit is reached, all active provider batches are cancelled and remaining unprocessed jobs are marked as failed. The batch execution completes normally without throwing exceptions.
Args: seconds: Time limit in seconds (optional) minutes: Time limit in minutes (optional) hours: Time limit in hours (optional)
Returns: Self for chaining
Raises: ValueError: If no time units specified, or if total time is outside valid range (min: 10 seconds, max: 24 hours)
Note: - Can combine multiple time units - Time limit is checked every second by a background watchdog thread - Jobs that exceed time limit appear in results()["failed"] with time limit error message - No exceptions are thrown when time limit is reached
Example:
batch.add_time_limit(seconds=30) # 30 seconds
batch.add_time_limit(minutes=5) # 5 minutes
batch.add_time_limit(hours=2) # 2 hours
batch.add_time_limit(hours=1, minutes=30, seconds=15) # 5415 seconds total
217 def add_job( 218 self, 219 messages: Optional[List[Message]] = None, 220 file: Optional[Union[str, Path]] = None, 221 prompt: Optional[str] = None, 222 model: Optional[str] = None, 223 temperature: Optional[float] = None, 224 max_tokens: Optional[int] = None, 225 response_model: Optional[Type[BaseModel]] = None, 226 enable_citations: bool = False, 227 **kwargs 228 ) -> 'Batch': 229 """Add a job to the batch. 230 231 Either provide messages OR file+prompt, not both. Parameters not provided 232 will use the defaults set via the defaults() method. 233 234 Args: 235 messages: Chat messages for direct input 236 file: File path for file-based input 237 prompt: Prompt to use with file input 238 model: Model to use (overrides default) 239 temperature: Sampling temperature (overrides default) 240 max_tokens: Max tokens to generate (overrides default) 241 response_model: Pydantic model for structured output 242 enable_citations: Whether to extract citations 243 **kwargs: Additional parameters 244 245 Returns: 246 Self for chaining 247 248 Example: 249 ```python 250 batch.add_job( 251 messages=[{"role": "user", "content": "Hello"}], 252 model="gpt-4" 253 ) 254 ``` 255 """ 256 # Generate unique job ID 257 job_id = f"job-{uuid.uuid4().hex[:8]}" 258 259 # Merge with defaults 260 params = self.config.default_params.copy() 261 262 # Update with provided parameters 263 if model is not None: 264 params["model"] = model 265 if temperature is not None: 266 params["temperature"] = temperature 267 if max_tokens is not None: 268 params["max_tokens"] = max_tokens 269 270 # Add other kwargs 271 params.update(kwargs) 272 273 # Ensure model is provided 274 if "model" not in params: 275 raise ValueError("Model must be provided either in defaults or job parameters") 276 277 # Validate parameters 278 provider = get_provider(params["model"]) 279 # Extract params without model to avoid duplicate 280 param_subset = {k: v for k, v in params.items() if k != "model"} 281 
provider.validate_params(params["model"], **param_subset) 282 283 # Convert file path if string 284 if isinstance(file, str): 285 file = Path(file) 286 287 # Warn about temporary file paths that may not persist 288 if file: 289 file_str = str(file) 290 if "/tmp/" in file_str or "/var/folders/" in file_str or "temp" in file_str.lower(): 291 logger = logging.getLogger("batchata") 292 logger.debug(f"File path appears to be in a temporary directory: {file}") 293 logger.debug("This may cause issues when resuming from state if temp files are cleaned up") 294 295 # Create job 296 job = Job( 297 id=job_id, 298 messages=messages, 299 file=file, 300 prompt=prompt, 301 response_model=response_model, 302 enable_citations=enable_citations, 303 **params 304 ) 305 306 # Validate citation compatibility 307 if response_model and enable_citations: 308 from ..utils.validation import validate_flat_model 309 validate_flat_model(response_model) 310 311 # Validate job with provider (includes PDF validation for Anthropic) 312 provider.validate_job(job) 313 314 315 self.jobs.append(job) 316 return self
Add a job to the batch.
Either provide messages OR file+prompt, not both. Parameters not provided will use the defaults set via the defaults() method.
Args: messages: Chat messages for direct input file: File path for file-based input prompt: Prompt to use with file input model: Model to use (overrides default) temperature: Sampling temperature (overrides default) max_tokens: Max tokens to generate (overrides default) response_model: Pydantic model for structured output enable_citations: Whether to extract citations **kwargs: Additional parameters
Returns: Self for chaining
Example:
batch.add_job(
messages=[{"role": "user", "content": "Hello"}],
model="gpt-4"
)
318 def run(self, on_progress: Optional[Callable[[Dict, float, Dict], None]] = None, progress_interval: float = 1.0, print_status: bool = False, dry_run: bool = False) -> 'BatchRun': 319 """Execute the batch. 320 321 Creates a BatchRun instance and executes the jobs synchronously. 322 323 Args: 324 on_progress: Optional progress callback function that receives 325 (stats_dict, elapsed_time_seconds, batch_data) 326 progress_interval: Interval in seconds between progress updates (default: 1.0) 327 print_status: Whether to show rich progress display (default: False) 328 dry_run: If True, only show cost estimation without executing (default: False) 329 330 Returns: 331 BatchRun instance with completed results 332 333 Raises: 334 ValueError: If no jobs have been added 335 """ 336 if not self.jobs: 337 raise ValueError("No jobs added to batch") 338 339 # Import here to avoid circular dependency 340 from .batch_run import BatchRun 341 342 # Create and start the run 343 run = BatchRun(self.config, self.jobs) 344 345 # Handle dry run mode 346 if dry_run: 347 return run.dry_run() 348 349 # Set progress callback - either rich display or custom callback 350 if print_status: 351 return self._run_with_rich_display(run, progress_interval) 352 else: 353 return self._run_with_custom_callback(run, on_progress, progress_interval)
Execute the batch.
Creates a BatchRun instance and executes the jobs synchronously.
Args:
- on_progress: Optional progress callback function that receives (stats_dict, elapsed_time_seconds, batch_data)
- progress_interval: Interval in seconds between progress updates (default: 1.0)
- print_status: Whether to show rich progress display (default: False)
- dry_run: If True, only show cost estimation without executing (default: False)
Returns: BatchRun instance with completed results
Raises: ValueError: If no jobs have been added
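The `on_progress` callback is rate-limited to `progress_interval` seconds, and the implementation double-checks the elapsed time inside a lock so concurrent pollers never fire duplicate updates. A minimal standalone sketch of that gating pattern (the `ProgressGate` name is illustrative, not part of the batchata API):

```python
import threading
import time

class ProgressGate:
    """Invoke a callback at most once per `interval` seconds (thread-safe)."""

    def __init__(self, callback, interval=1.0):
        self.callback = callback
        self.interval = interval
        self._lock = threading.Lock()
        self._last_update = 0.0

    def maybe_fire(self, stats, elapsed):
        now = time.monotonic()
        # Cheap check outside the lock: skip obviously-too-soon updates.
        if now - self._last_update < self.interval:
            return False
        with self._lock:
            # Double-check inside the lock so two threads that both passed
            # the first check don't both fire for the same window.
            if now - self._last_update < self.interval:
                return False
            self.callback(stats, elapsed)
            self._last_update = now
            return True

calls = []
gate = ProgressGate(lambda stats, t: calls.append(stats["completed"]), interval=0.5)
gate.maybe_fire({"completed": 1}, 0.0)   # fires
gate.maybe_fire({"completed": 2}, 0.0)   # suppressed: within the interval
time.sleep(0.6)
gate.maybe_fire({"completed": 3}, 0.6)   # interval elapsed, fires again
```

The occasional skipped update is the price of never printing the same progress line twice from two polling threads.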
````python
class BatchRun:
    """Manages the execution of a batch job synchronously.

    Processes jobs in batches based on items_per_batch configuration.
    Simpler synchronous execution for clear logging and debugging.

    Example:
        ```python
        config = BatchParams(...)
        run = BatchRun(config, jobs)
        run.execute()
        results = run.results()
        ```
    """

    def __init__(self, config: BatchParams, jobs: List[Job]):
        """Initialize batch run.

        Args:
            config: Batch configuration
            jobs: List of jobs to execute
        """
        self.config = config
        self.jobs = {job.id: job for job in jobs}

        # Set logging level based on config
        set_log_level(level=config.verbosity.upper())

        # Initialize components
        self.cost_tracker = CostTracker(limit_usd=config.cost_limit_usd)

        # Use temp file for state if not provided
        state_file = config.state_file
        if not state_file:
            state_file = create_temp_state_file(config)
            config.reuse_state = False
            logger.info(f"Created temporary state file: {state_file}")

        self.state_manager = StateManager(state_file)

        # State tracking
        self.pending_jobs: List[Job] = []
        self.completed_results: Dict[str, JobResult] = {}  # job_id -> result
        self.failed_jobs: Dict[str, str] = {}  # job_id -> error
        self.cancelled_jobs: Dict[str, str] = {}  # job_id -> reason

        # Batch tracking
        self.total_batches = 0
        self.completed_batches = 0
        self.current_batch_index = 0
        self.current_batch_size = 0

        # Execution control
        self._started = False
        self._start_time: Optional[datetime] = None
        self._time_limit_exceeded = False
        self._progress_callback: Optional[Callable[[Dict, float], None]] = None
        self._progress_interval: float = 1.0  # Default to 1 second

        # Threading primitives
        self._state_lock = threading.Lock()
        self._shutdown_event = threading.Event()
        self._progress_lock = threading.Lock()
        self._last_progress_update = 0.0

        # Batch tracking for progress display
        self.batch_tracking: Dict[str, Dict] = {}  # batch_id -> batch_info

        # Active batch tracking for cancellation
        self._active_batches: Dict[str, object] = {}  # batch_id -> provider
        self._active_batches_lock = threading.Lock()

        # Results directory
        self.results_dir = Path(config.results_dir)

        # If not reusing state, clear the results directory
        if not config.reuse_state and self.results_dir.exists():
            import shutil
            shutil.rmtree(self.results_dir)

        self.results_dir.mkdir(parents=True, exist_ok=True)

        # Raw files directory (if enabled)
        self.raw_files_dir = None
        if config.raw_files:
            self.raw_files_dir = self.results_dir / "raw_files"
            self.raw_files_dir.mkdir(parents=True, exist_ok=True)

        # Try to resume from saved state
        self._resume_from_state()

    def _resume_from_state(self):
        """Resume from saved state if available."""
        # Check if we should reuse state
        if not self.config.reuse_state:
            # Clear any existing state and start fresh
            self.state_manager.clear()
            self.pending_jobs = list(self.jobs.values())
            return

        state = self.state_manager.load()
        if state is None:
            # No saved state, use jobs passed to constructor
            self.pending_jobs = list(self.jobs.values())
            return

        logger.info("Resuming batch run from saved state")

        # Restore pending jobs
        self.pending_jobs = []
        for job_data in state.pending_jobs:
            job = Job.from_dict(job_data)
            # Check if file exists (if job has a file)
            if job.file and not job.file.exists():
                logger.error(f"File not found for job {job.id}: {job.file}")
                logger.error("This may happen if files were in temporary directories that were cleaned up")
                self.failed_jobs[job.id] = f"File not found: {job.file}"
            else:
                self.pending_jobs.append(job)

        # Restore completed results from file references
        for result_ref in state.completed_results:
            job_id = result_ref["job_id"]
            file_path = result_ref["file_path"]
            try:
                with open(file_path, 'r') as f:
                    result_data = json.load(f)
                result = JobResult.from_dict(result_data)
                self.completed_results[job_id] = result
            except Exception as e:
                logger.error(f"Failed to load result for {job_id} from {file_path}: {e}")
                # Move to failed jobs if we can't load the result
                self.failed_jobs[job_id] = f"Failed to load result file: {e}"

        # Restore failed jobs
        for job_data in state.failed_jobs:
            self.failed_jobs[job_data["id"]] = job_data.get("error", "Unknown error")

        # Restore cancelled jobs (if they exist in state)
        for job_data in getattr(state, 'cancelled_jobs', []):
            self.cancelled_jobs[job_data["id"]] = job_data.get("reason", "Cancelled")

        # Restore cost tracker
        self.cost_tracker.used_usd = state.total_cost_usd

        logger.info(
            f"Resumed with {len(self.pending_jobs)} pending, "
            f"{len(self.completed_results)} completed, "
            f"{len(self.failed_jobs)} failed, "
            f"{len(self.cancelled_jobs)} cancelled"
        )

    def to_json(self) -> Dict:
        """Convert current state to JSON-serializable dict."""
        return {
            "created_at": datetime.now().isoformat(),
            "pending_jobs": [job.to_dict() for job in self.pending_jobs],
            "completed_results": [
                {"job_id": job_id, "file_path": str(self.results_dir / f"{job_id}.json")}
                for job_id in self.completed_results.keys()
            ],
            "failed_jobs": [
                {
                    "id": job_id,
                    "error": error,
                    "timestamp": datetime.now().isoformat()
                } for job_id, error in self.failed_jobs.items()
            ],
            "cancelled_jobs": [
                {
                    "id": job_id,
                    "reason": reason,
                    "timestamp": datetime.now().isoformat()
                } for job_id, reason in self.cancelled_jobs.items()
            ],
            "total_cost_usd": self.cost_tracker.used_usd,
            "config": {
                "state_file": self.config.state_file,
                "results_dir": self.config.results_dir,
                "max_parallel_batches": self.config.max_parallel_batches,
                "items_per_batch": self.config.items_per_batch,
                "cost_limit_usd": self.config.cost_limit_usd,
                "default_params": self.config.default_params,
                "raw_files": self.config.raw_files
            }
        }

    def execute(self):
        """Execute synchronous batch run and wait for completion."""
        if self._started:
            raise RuntimeError("Batch run already started")

        self._started = True
        self._start_time = datetime.now()

        # Register signal handler for graceful shutdown
        def signal_handler(signum, frame):
            logger.warning("Received interrupt signal, shutting down gracefully...")
            self._shutdown_event.set()

        # Store original handler to restore later
        original_handler = signal.signal(signal.SIGINT, signal_handler)

        try:
            logger.info("Starting batch run")

            # Start time limit watchdog if configured
            self._start_time_limit_watchdog()

            # Call initial progress
            if self._progress_callback:
                with self._progress_lock:
                    with self._state_lock:
                        stats = self.status()
                        batch_data = dict(self.batch_tracking)
                    self._progress_callback(stats, 0.0, batch_data)
                    self._last_progress_update = time.time()

            # Process all jobs synchronously
            self._process_all_jobs()

            logger.info("Batch run completed")
        finally:
            # Restore original signal handler
            signal.signal(signal.SIGINT, original_handler)

    def set_on_progress(self, callback: Callable[[Dict, float, Dict], None], interval: float = 1.0) -> 'BatchRun':
        """Set progress callback for execution monitoring.

        The callback will be called periodically with progress statistics
        including completed jobs, total jobs, current cost, etc.

        Args:
            callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data)
                - stats_dict: Progress statistics dictionary
                - elapsed_time_seconds: Time elapsed since batch started (float)
                - batch_data: Dictionary mapping batch_id to batch information
            interval: Interval in seconds between progress updates (default: 1.0)

        Returns:
            Self for chaining

        Example:
            ```python
            run.set_on_progress(
                lambda stats, time, batch_data: print(
                    f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s"
                )
            )
            ```
        """
        self._progress_callback = callback
        self._progress_interval = interval
        return self

    def _start_time_limit_watchdog(self):
        """Start a background thread to check for time limit every second."""
        if not self.config.time_limit_seconds:
            return

        def time_limit_watchdog():
            """Check for time limit every second and trigger shutdown if exceeded."""
            while not self._shutdown_event.is_set():
                if self._check_time_limit():
                    logger.warning("Batch execution time limit exceeded")
                    with self._state_lock:
                        self._time_limit_exceeded = True
                    self._shutdown_event.set()
                    break
                time.sleep(1.0)

        # Start watchdog as daemon thread
        watchdog_thread = threading.Thread(target=time_limit_watchdog, daemon=True)
        watchdog_thread.start()
        logger.debug(f"Started time limit watchdog thread (time limit: {self.config.time_limit_seconds}s)")

    def _check_time_limit(self) -> bool:
        """Check if batch execution has exceeded time limit."""
        if not self.config.time_limit_seconds or not self._start_time:
            return False

        elapsed = (datetime.now() - self._start_time).total_seconds()
        return elapsed >= self.config.time_limit_seconds

    def _process_all_jobs(self):
        """Process all jobs with parallel execution."""
        # Prepare all batches
        batches = self._prepare_batches()
        self.total_batches = len(batches)

        # Process batches in parallel
        with ThreadPoolExecutor(max_workers=self.config.max_parallel_batches) as executor:
            futures = [executor.submit(self._execute_batch_wrapped, provider, batch_jobs)
                       for _, provider, batch_jobs in batches]

            try:
                for future in as_completed(futures):
                    # Stop if shutdown event detected (includes time limit)
                    if self._shutdown_event.is_set():
                        break
                    future.result()  # Re-raise any exceptions
            except KeyboardInterrupt:
                self._shutdown_event.set()
                # Cancel remaining futures
                for future in futures:
                    future.cancel()
                raise
            finally:
                # Handle time limit or cancellation - mark remaining jobs appropriately
                with self._state_lock:
                    if self._shutdown_event.is_set():
                        # If time limit exceeded, cancel all active batches
                        if self._time_limit_exceeded:
                            self._cancel_all_active_batches()

                        # Mark any unprocessed jobs based on reason for shutdown
                        for _, _, batch_jobs in batches:
                            for job in batch_jobs:
                                # Skip jobs already processed
                                if (job.id in self.completed_results or
                                        job.id in self.failed_jobs or
                                        job.id in self.cancelled_jobs):
                                    continue

                                # Mark based on shutdown reason
                                if self._time_limit_exceeded:
                                    self.failed_jobs[job.id] = "Time limit exceeded: batch execution time limit exceeded"
                                else:
                                    self.cancelled_jobs[job.id] = "Cancelled by user"

                                if job in self.pending_jobs:
                                    self.pending_jobs.remove(job)

                    # Save state
                    self.state_manager.save(self)

    def _cancel_all_active_batches(self):
        """Cancel all active batches at the provider level."""
        with self._active_batches_lock:
            active_batch_items = list(self._active_batches.items())

        logger.info(f"Cancelling {len(active_batch_items)} active batches due to time limit exceeded")

        # Cancel outside the lock to avoid blocking
        for batch_id, provider in active_batch_items:
            try:
                provider.cancel_batch(batch_id)
                logger.info(f"Cancelled batch {batch_id} due to time limit exceeded")
            except Exception as e:
                logger.warning(f"Failed to cancel batch {batch_id}: {e}")

        # Clear the tracking after cancellation attempts
        with self._active_batches_lock:
            self._active_batches.clear()

    def _execute_batch_wrapped(self, provider, batch_jobs):
        """Thread-safe wrapper for _execute_batch."""
        try:
            result = self._execute_batch(provider, batch_jobs)
            with self._state_lock:
                self._update_batch_results(result)
                # Remove jobs from pending_jobs if specified
                jobs_to_remove = result.get("jobs_to_remove", [])
                for job in jobs_to_remove:
                    if job in self.pending_jobs:
                        self.pending_jobs.remove(job)
        except TimeoutError:
            # Handle time limit exceeded - mark jobs as failed
            with self._state_lock:
                for job in batch_jobs:
                    self.failed_jobs[job.id] = "Time limit exceeded: batch execution time limit exceeded"
                    if job in self.pending_jobs:
                        self.pending_jobs.remove(job)
                self.state_manager.save(self)
            # Don't re-raise, just return the result
            return
        except KeyboardInterrupt:
            self._shutdown_event.set()
            # Handle user cancellation
            with self._state_lock:
                for job in batch_jobs:
                    self.cancelled_jobs[job.id] = "Cancelled by user"
                    if job in self.pending_jobs:
                        self.pending_jobs.remove(job)
                self.state_manager.save(self)
            raise

    def _group_jobs_by_provider(self) -> Dict[str, List[Job]]:
        """Group jobs by provider."""
        jobs_by_provider = {}

        for job in self.pending_jobs[:]:  # Copy to avoid modification during iteration
            try:
                provider = get_provider(job.model)
                provider_name = provider.__class__.__name__

                if provider_name not in jobs_by_provider:
                    jobs_by_provider[provider_name] = []

                jobs_by_provider[provider_name].append(job)

            except Exception as e:
                logger.error(f"Failed to get provider for job {job.id}: {e}")
                with self._state_lock:
                    self.failed_jobs[job.id] = str(e)
                    self.pending_jobs.remove(job)

        return jobs_by_provider

    def _split_into_batches(self, jobs: List[Job]) -> List[List[Job]]:
        """Split jobs into batches based on items_per_batch."""
        batches = []
        batch_size = self.config.items_per_batch

        for i in range(0, len(jobs), batch_size):
            batch = jobs[i:i + batch_size]
            batches.append(batch)

        return batches

    def _prepare_batches(self) -> List[Tuple[str, object, List[Job]]]:
        """Prepare all batches as simple list of (provider_name, provider, jobs)."""
        batches = []
        jobs_by_provider = self._group_jobs_by_provider()

        for provider_name, provider_jobs in jobs_by_provider.items():
            provider = get_provider(provider_jobs[0].model)
            job_batches = self._split_into_batches(provider_jobs)

            for batch_jobs in job_batches:
                batches.append((provider_name, provider, batch_jobs))

                # Pre-populate batch tracking for pending batches
                batch_id = f"pending_{len(self.batch_tracking)}"
                estimated_cost = provider.estimate_cost(batch_jobs)
                self.batch_tracking[batch_id] = {
                    'start_time': None,
                    'status': 'pending',
                    'total': len(batch_jobs),
                    'completed': 0,
                    'cost': 0.0,
                    'estimated_cost': estimated_cost,
                    'provider': provider_name,
                    'jobs': batch_jobs
                }

        return batches

    def _poll_batch_status(self, provider, batch_id: str) -> Tuple[str, Optional[Dict]]:
        """Poll until batch completes."""
        status, error_details = provider.get_batch_status(batch_id)
        logger.info(f"Initial batch status: {status}")
        poll_count = 0

        # Use provider-specific polling interval
        provider_polling_interval = provider.get_polling_interval()
        logger.debug(f"Using {provider_polling_interval}s polling interval for {provider.__class__.__name__}")

        while status not in ["complete", "failed"]:
            poll_count += 1
            logger.debug(f"Polling attempt {poll_count}, current status: {status}")

            # Interruptible wait - will wake up immediately if shutdown event is set (includes time limit)
            if self._shutdown_event.wait(provider_polling_interval):
                # Check if it's time limit exceeded or user cancellation
                with self._state_lock:
                    is_time_limit_exceeded = self._time_limit_exceeded

                if is_time_limit_exceeded:
                    logger.info(f"Batch {batch_id} polling interrupted by time limit exceeded")
                    raise TimeoutError("Batch cancelled due to time limit exceeded")
                else:
                    logger.info(f"Batch {batch_id} polling interrupted by user")
                    raise KeyboardInterrupt("Batch cancelled by user")

            status, error_details = provider.get_batch_status(batch_id)

            if self._progress_callback:
                # Rate limit progress updates and synchronize calls to prevent duplicate printing
                current_time = time.time()
                should_update = current_time - self._last_progress_update >= self._progress_interval

                if should_update:
                    with self._progress_lock:
                        # Double-check timing inside the lock to avoid race condition
                        if current_time - self._last_progress_update >= self._progress_interval:
                            with self._state_lock:
                                stats = self.status()
                            elapsed_time = (datetime.now() - self._start_time).total_seconds()
                            batch_data = dict(self.batch_tracking)
                            self._progress_callback(stats, elapsed_time, batch_data)
                            self._last_progress_update = current_time

            elapsed_seconds = poll_count * provider_polling_interval

        return status, error_details

    def _update_batch_results(self, batch_result: Dict):
        """Update state from batch results."""
        results = batch_result.get("results", [])
        failed = batch_result.get("failed", {})

        # Update completed results
        for result in results:
            if result.is_success:
                self.completed_results[result.job_id] = result
                self._save_result_to_file(result)
                logger.info(f"✓ Job {result.job_id} completed successfully")
            else:
                error_message = result.error or "Unknown error"
                self.failed_jobs[result.job_id] = error_message
                self._save_result_to_file(result)
                logger.error(f"✗ Job {result.job_id} failed: {result.error}")

            # Remove completed/failed job from pending
            self.pending_jobs = [job for job in self.pending_jobs if job.id != result.job_id]

        # Update failed jobs
        for job_id, error in failed.items():
            self.failed_jobs[job_id] = error
            # Remove failed job from pending
            self.pending_jobs = [job for job in self.pending_jobs if job.id != job_id]
            logger.error(f"✗ Job {job_id} failed: {error}")

        # Update batch tracking
        self.completed_batches += 1

        # Save state
        self.state_manager.save(self)

    def _execute_batch(self, provider, batch_jobs: List[Job]) -> Dict:
        """Execute one batch, return results dict with jobs/costs/errors."""
        if not batch_jobs:
            return {"results": [], "failed": {}, "cost": 0.0}

        # Reserve cost limit
        logger.info(f"Estimating cost for batch of {len(batch_jobs)} jobs...")
        estimated_cost = provider.estimate_cost(batch_jobs)
        remaining = self.cost_tracker.remaining()
        remaining_str = f"${remaining:.4f}" if remaining is not None else "unlimited"
        logger.info(f"Total estimated cost: ${estimated_cost:.4f}, remaining budget: {remaining_str}")

        if not self.cost_tracker.reserve_cost(estimated_cost):
            logger.warning(f"Cost limit would be exceeded, skipping batch of {len(batch_jobs)} jobs")
            failed = {}
            for job in batch_jobs:
                failed[job.id] = "Cost limit exceeded"
            return {"results": [], "failed": failed, "cost": 0.0, "jobs_to_remove": list(batch_jobs)}

        batch_id = None
        job_mapping = None
        try:
            # Create batch
            logger.info(f"Creating batch with {len(batch_jobs)} jobs...")
            raw_files_path = str(self.raw_files_dir) if self.raw_files_dir else None
            batch_id, job_mapping = provider.create_batch(batch_jobs, raw_files_path)

            # Track active batch for cancellation
            with self._active_batches_lock:
                self._active_batches[batch_id] = provider

            # Track batch creation
            with self._state_lock:
                # Remove pending entry if it exists
                pending_keys = [k for k in self.batch_tracking.keys() if k.startswith('pending_')]
                for pending_key in pending_keys:
                    if self.batch_tracking[pending_key]['jobs'] == batch_jobs:
                        del self.batch_tracking[pending_key]
                        break

                # Add actual batch tracking
                self.batch_tracking[batch_id] = {
                    'start_time': datetime.now(),
                    'status': 'running',
                    'total': len(batch_jobs),
                    'completed': 0,
                    'cost': 0.0,
                    'estimated_cost': estimated_cost,
                    'provider': provider.__class__.__name__,
                    'jobs': batch_jobs
                }

            # Poll for completion
            logger.info(f"Polling for batch {batch_id} completion...")
            status, error_details = self._poll_batch_status(provider, batch_id)

            if status == "failed":
                if error_details:
                    logger.error(f"Batch {batch_id} failed with details: {error_details}")
                else:
                    logger.error(f"Batch {batch_id} failed")

                # Save error details if configured
                if self.raw_files_dir and error_details:
                    self._save_batch_error_details(batch_id, error_details)

                # Continue to get individual results - some jobs might have succeeded

            # Get results
            logger.info(f"Getting results for batch {batch_id}")
            raw_files_path = str(self.raw_files_dir) if self.raw_files_dir else None
            results = provider.get_batch_results(batch_id, job_mapping, raw_files_path)

            # Calculate actual cost and adjust reservation
            actual_cost = sum(r.cost_usd for r in results)
            self.cost_tracker.adjust_reserved_cost(estimated_cost, actual_cost)

            # Update batch tracking for completion
            success_count = len([r for r in results if r.is_success])
            failed_count = len([r for r in results if not r.is_success])
            batch_status = 'complete' if failed_count == 0 else 'failed'

            with self._state_lock:
                if batch_id in self.batch_tracking:
                    self.batch_tracking[batch_id]['status'] = batch_status
                    self.batch_tracking[batch_id]['completed'] = success_count
                    self.batch_tracking[batch_id]['failed'] = failed_count
                    self.batch_tracking[batch_id]['cost'] = actual_cost
                    self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                    if batch_status == 'failed' and failed_count > 0:
                        # Use the first job's error as the batch error summary
                        first_error = next((r.error for r in results if not r.is_success), 'Some jobs failed')
                        self.batch_tracking[batch_id]['error'] = first_error

            # Remove from active batches tracking
            with self._active_batches_lock:
                self._active_batches.pop(batch_id, None)

            status_symbol = "✓" if batch_status == 'complete' else "⚠"
            logger.info(
                f"{status_symbol} Batch {batch_id} completed: "
                f"{success_count} success, "
                f"{failed_count} failed, "
                f"cost: ${actual_cost:.6f}"
            )

            return {"results": results, "failed": {}, "cost": actual_cost, "jobs_to_remove": list(batch_jobs)}

        except TimeoutError:
            logger.info(f"Time limit exceeded for batch{f' {batch_id}' if batch_id else ''}")
            if batch_id:
                # Update batch tracking for time limit exceeded
                with self._state_lock:
                    if batch_id in self.batch_tracking:
                        self.batch_tracking[batch_id]['status'] = 'failed'
                        self.batch_tracking[batch_id]['error'] = 'Time limit exceeded: batch execution time limit exceeded'
                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                # NOTE: Don't remove from _active_batches - let centralized cancellation handle it
            # Release the reservation since batch exceeded time limit
            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
            # Re-raise to be handled by wrapper
            raise

        except KeyboardInterrupt:
            logger.warning(f"\nCancelling batch{f' {batch_id}' if batch_id else ''}...")
            if batch_id:
                # Update batch tracking for cancellation
                with self._state_lock:
                    if batch_id in self.batch_tracking:
                        self.batch_tracking[batch_id]['status'] = 'cancelled'
                        self.batch_tracking[batch_id]['error'] = 'Cancelled by user'
                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                # Remove from active batches tracking
                with self._active_batches_lock:
                    self._active_batches.pop(batch_id, None)
            # Release the reservation since batch was cancelled
            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
            # Handle cancellation in the wrapper with proper locking
            raise

        except Exception as e:
            logger.error(f"✗ Batch execution failed: {e}")
            # Update batch tracking for exception
            if batch_id:
                with self._state_lock:
                    if batch_id in self.batch_tracking:
                        self.batch_tracking[batch_id]['status'] = 'failed'
                        self.batch_tracking[batch_id]['error'] = str(e)
                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
                # Remove from active batches tracking
                with self._active_batches_lock:
                    self._active_batches.pop(batch_id, None)
            # Release the reservation since batch failed
            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
            failed = {}
            for job in batch_jobs:
                failed[job.id] = str(e)
            return {"results": [], "failed": failed, "cost": 0.0, "jobs_to_remove": list(batch_jobs)}

    def _save_result_to_file(self, result: JobResult):
        """Save individual result to file."""
        result_file = self.results_dir / f"{result.job_id}.json"

        try:
            with open(result_file, 'w') as f:
                json.dump(result.to_dict(), f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save result for {result.job_id}: {e}")

    def _save_batch_error_details(self, batch_id: str, error_details: Dict):
        """Save batch error details to debug files directory."""
        try:
            error_file = self.raw_files_dir / f"batch_{batch_id}_error.json"
            with open(error_file, 'w') as f:
                json.dump({
                    "batch_id": batch_id,
                    "timestamp": datetime.now().isoformat(),
                    "error_details": error_details
                }, f, indent=2)
            logger.info(f"Saved batch error details to {error_file}")
        except Exception as e:
            logger.error(f"Failed to save batch error details: {e}")

    @property
    def is_complete(self) -> bool:
        """Whether all jobs are complete."""
        total_jobs = len(self.jobs)
        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
        return len(self.pending_jobs) == 0 and completed_count == total_jobs

    def status(self, print_status: bool = False) -> Dict:
        """Get current execution statistics."""
        total_jobs = len(self.jobs)
        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
        remaining_count = total_jobs - completed_count

        stats = {
            "total": total_jobs,
            "pending": remaining_count,
            "active": 0,  # Always 0 for synchronous execution
            "completed": len(self.completed_results),
            "failed": len(self.failed_jobs),
            "cancelled": len(self.cancelled_jobs),
            "cost_usd": self.cost_tracker.used_usd,
            "cost_limit_usd": self.cost_tracker.limit_usd,
            "is_complete": self.is_complete,
            "batches_total": self.total_batches,
            "batches_completed": self.completed_batches,
            "batches_pending": self.total_batches - self.completed_batches,
            "current_batch_index": self.current_batch_index,
            "current_batch_size": self.current_batch_size,
            "items_per_batch": self.config.items_per_batch
        }

        if print_status:
            logger.info("\nBatch Run Status:")
            logger.info(f"  Total jobs: {stats['total']}")
            logger.info(f"  Pending: {stats['pending']}")
            logger.info(f"  Active: {stats['active']}")
            logger.info(f"  Completed: {stats['completed']}")
            logger.info(f"  Failed: {stats['failed']}")
            logger.info(f"  Cancelled: {stats['cancelled']}")
            logger.info(f"  Cost: ${stats['cost_usd']:.6f}")
            if stats['cost_limit_usd']:
                logger.info(f"  Cost limit: ${stats['cost_limit_usd']:.2f}")
            logger.info(f"  Complete: {stats['is_complete']}")

        return stats

    def results(self) -> Dict[str, List[JobResult]]:
        """Get all results organized by status.

        Returns:
            {
                "completed": [JobResult],
                "failed": [JobResult],
                "cancelled": [JobResult]
            }
        """
        return {
            "completed": list(self.completed_results.values()),
            "failed": self._create_failed_results(),
            "cancelled": self._create_cancelled_results()
        }

    def get_failed_jobs(self) -> Dict[str, str]:
        """Get failed jobs with error messages.

        Note: This method is deprecated. Use results()['failed'] instead.
        """
        return dict(self.failed_jobs)

    def _create_failed_results(self) -> List[JobResult]:
        """Convert failed jobs to JobResult objects."""
        failed_results = []
        for job_id, error_msg in self.failed_jobs.items():
            failed_results.append(JobResult(
                job_id=job_id,
                raw_response=None,
                parsed_response=None,
                error=error_msg,
                cost_usd=0.0,
                input_tokens=0,
                output_tokens=0
            ))
        return failed_results

    def _create_cancelled_results(self) -> List[JobResult]:
        """Convert cancelled jobs to JobResult objects."""
        cancelled_results = []
        for job_id, reason in self.cancelled_jobs.items():
            cancelled_results.append(JobResult(
                job_id=job_id,
                raw_response=None,
                parsed_response=None,
                error=reason,
                cost_usd=0.0,
                input_tokens=0,
                output_tokens=0
            ))
        return cancelled_results

    def shutdown(self):
        """Shutdown (no-op for synchronous execution)."""
        pass

    def dry_run(self) -> 'BatchRun':
        """Perform a dry run - show cost estimation and job details without executing.

        Returns:
            Self for chaining (doesn't actually execute jobs)
        """
        logger.info("=== DRY RUN MODE ===")
        logger.info("This will show cost estimates without executing jobs")

        # Load existing state if reuse_state=True
        if self.config.reuse_state:
            self.state_manager.load_state(self)

        # Filter out completed jobs from previous runs
        self.pending_jobs = [job for job in self.jobs.values() if job.id not in self.completed_results]

        if not self.pending_jobs:
            logger.info("No pending jobs to analyze (all jobs already completed)")
            return self

        logger.info(f"Analyzing {len(self.pending_jobs)} pending jobs...")

        # Group jobs by provider and analyze costs
        provider_groups = self._group_jobs_by_provider()
        total_estimated_cost = 0.0

        logger.info(f"\nJob breakdown:")
        for provider_name, jobs in provider_groups.items():
            provider = get_provider(jobs[0].model)
            logger.info(f"\n{provider_name} ({len(jobs)} jobs):")

            job_batches = [jobs[i:i + self.config.items_per_batch]
                           for i in range(0, len(jobs), self.config.items_per_batch)]

            for batch_idx, batch_jobs in enumerate(job_batches, 1):
                estimated_cost = provider.estimate_cost(batch_jobs)
                total_estimated_cost += estimated_cost

                logger.info(f"  Batch {batch_idx}: {len(batch_jobs)} jobs, estimated cost: ${estimated_cost:.4f}")
                for job in batch_jobs:
                    if job.file:
                        logger.info(f"    - {job.id}: {job.file.name} (citations: {job.enable_citations})")
                    else:
                        logger.info(f"    - {job.id}: direct messages (citations: {job.enable_citations})")

        # Show cost summary
        logger.info(f"\n=== COST SUMMARY ===")
        logger.info(f"Total estimated cost: ${total_estimated_cost:.4f}")

        if self.config.cost_limit_usd:
            logger.info(f"Cost limit: ${self.config.cost_limit_usd:.2f}")
            if total_estimated_cost > self.config.cost_limit_usd:
                excess = total_estimated_cost - self.config.cost_limit_usd
                logger.warning(f"⚠️ Estimated cost exceeds limit by ${excess:.4f}")
            else:
                remaining = self.config.cost_limit_usd - total_estimated_cost
                logger.info(f"✅ Within cost limit (${remaining:.4f} remaining)")
        else:
            logger.info("No cost limit set")

        # Show execution plan
        logger.info(f"\n=== EXECUTION PLAN ===")
        total_batches = sum(
            len(jobs) // self.config.items_per_batch + (1 if len(jobs) % self.config.items_per_batch else 0)
            for jobs in provider_groups.values()
        )
        logger.info(f"Total batches to process: {total_batches}")
        logger.info(f"Max parallel batches: {self.config.max_parallel_batches}")
        logger.info(f"Items per batch: {self.config.items_per_batch}")
        logger.info(f"Results directory: {self.config.results_dir}")

        logger.info("\n=== DRY RUN COMPLETE ===")
        logger.info("To execute for real, call run() without dry_run=True")

        return self
````
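The chunking done by `_split_into_batches` (and the batch count computed in `dry_run`) reduces to plain list slicing; a minimal standalone sketch:

```python
def split_into_batches(jobs, items_per_batch):
    """Chunk a job list into consecutive batches of at most items_per_batch items."""
    return [jobs[i:i + items_per_batch] for i in range(0, len(jobs), items_per_batch)]

jobs = list(range(7))
batches = split_into_batches(jobs, 3)
sizes = [len(b) for b in batches]  # 7 jobs with items_per_batch=3 -> [3, 3, 1]
```

The last batch simply carries the remainder, which is why `dry_run` counts `len(jobs) // items_per_batch + (1 if len(jobs) % items_per_batch else 0)` batches per provider.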
Manages the execution of a batch job synchronously.
Jobs are processed in batches of `items_per_batch`. Execution is synchronous, which keeps logging and debugging straightforward.
Example:
```python
config = BatchParams(...)
run = BatchRun(config, jobs)
run.execute()
results = run.results()
```
```python
def __init__(self, config: BatchParams, jobs: List[Job]):
    """Initialize batch run.

    Args:
        config: Batch configuration
        jobs: List of jobs to execute
    """
    self.config = config
    self.jobs = {job.id: job for job in jobs}

    # Set logging level based on config
    set_log_level(level=config.verbosity.upper())

    # Initialize components
    self.cost_tracker = CostTracker(limit_usd=config.cost_limit_usd)

    # Use temp file for state if not provided
    state_file = config.state_file
    if not state_file:
        state_file = create_temp_state_file(config)
        config.reuse_state = False
        logger.info(f"Created temporary state file: {state_file}")

    self.state_manager = StateManager(state_file)

    # State tracking
    self.pending_jobs: List[Job] = []
    self.completed_results: Dict[str, JobResult] = {}  # job_id -> result
    self.failed_jobs: Dict[str, str] = {}  # job_id -> error
    self.cancelled_jobs: Dict[str, str] = {}  # job_id -> reason

    # Batch tracking
    self.total_batches = 0
    self.completed_batches = 0
    self.current_batch_index = 0
    self.current_batch_size = 0

    # Execution control
    self._started = False
    self._start_time: Optional[datetime] = None
    self._time_limit_exceeded = False
    # Callback receives (stats, elapsed_seconds, batch_data) -- see set_on_progress()
    self._progress_callback: Optional[Callable[[Dict, float, Dict], None]] = None
    self._progress_interval: float = 1.0  # Default to 1 second

    # Threading primitives
    self._state_lock = threading.Lock()
    self._shutdown_event = threading.Event()
    self._progress_lock = threading.Lock()
    self._last_progress_update = 0.0

    # Batch tracking for progress display
    self.batch_tracking: Dict[str, Dict] = {}  # batch_id -> batch_info

    # Active batch tracking for cancellation
    self._active_batches: Dict[str, object] = {}  # batch_id -> provider
    self._active_batches_lock = threading.Lock()

    # Results directory
    self.results_dir = Path(config.results_dir)

    # If not reusing state, clear the results directory
    if not config.reuse_state and self.results_dir.exists():
        import shutil
        shutil.rmtree(self.results_dir)

    self.results_dir.mkdir(parents=True, exist_ok=True)

    # Raw files directory (if enabled)
    self.raw_files_dir = None
    if config.raw_files:
        self.raw_files_dir = self.results_dir / "raw_files"
        self.raw_files_dir.mkdir(parents=True, exist_ok=True)

    # Try to resume from saved state
    self._resume_from_state()
```
Initialize batch run.
Args:
- config: Batch configuration
- jobs: List of jobs to execute
```python
def to_json(self) -> Dict:
    """Convert current state to JSON-serializable dict."""
    return {
        "created_at": datetime.now().isoformat(),
        "pending_jobs": [job.to_dict() for job in self.pending_jobs],
        "completed_results": [
            {"job_id": job_id, "file_path": str(self.results_dir / f"{job_id}.json")}
            for job_id in self.completed_results.keys()
        ],
        "failed_jobs": [
            {
                "id": job_id,
                "error": error,
                "timestamp": datetime.now().isoformat()
            } for job_id, error in self.failed_jobs.items()
        ],
        "cancelled_jobs": [
            {
                "id": job_id,
                "reason": reason,
                "timestamp": datetime.now().isoformat()
            } for job_id, reason in self.cancelled_jobs.items()
        ],
        "total_cost_usd": self.cost_tracker.used_usd,
        "config": {
            "state_file": self.config.state_file,
            "results_dir": self.config.results_dir,
            "max_parallel_batches": self.config.max_parallel_batches,
            "items_per_batch": self.config.items_per_batch,
            "cost_limit_usd": self.config.cost_limit_usd,
            "default_params": self.config.default_params,
            "raw_files": self.config.raw_files
        }
    }
```
Convert current state to JSON-serializable dict.
```python
def execute(self):
    """Execute synchronous batch run and wait for completion."""
    if self._started:
        raise RuntimeError("Batch run already started")

    self._started = True
    self._start_time = datetime.now()

    # Register signal handler for graceful shutdown
    def signal_handler(signum, frame):
        logger.warning("Received interrupt signal, shutting down gracefully...")
        self._shutdown_event.set()

    # Store original handler to restore later
    original_handler = signal.signal(signal.SIGINT, signal_handler)

    try:
        logger.info("Starting batch run")

        # Start time limit watchdog if configured
        self._start_time_limit_watchdog()

        # Call initial progress
        if self._progress_callback:
            with self._progress_lock:
                with self._state_lock:
                    stats = self.status()
                    batch_data = dict(self.batch_tracking)
                self._progress_callback(stats, 0.0, batch_data)
                self._last_progress_update = time.time()

        # Process all jobs synchronously
        self._process_all_jobs()

        logger.info("Batch run completed")
    finally:
        # Restore original signal handler
        signal.signal(signal.SIGINT, original_handler)
```
Execute synchronous batch run and wait for completion.
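The interrupt handling above follows a standard pattern: install a SIGINT handler that merely flags a `threading.Event`, then restore the previous handler in a `finally` block so the interpreter's default behavior comes back even if execution fails. A minimal, self-contained sketch of that pattern (not batchata code):

```python
import signal
import threading

shutdown_event = threading.Event()

def handler(signum, frame):
    # Request a graceful stop instead of raising KeyboardInterrupt mid-work
    shutdown_event.set()

# Install the handler; signal.signal returns the previous handler
original = signal.signal(signal.SIGINT, handler)
try:
    # Long-running work would periodically check shutdown_event.is_set()
    # and drain in-flight batches before exiting.
    pass
finally:
    # Always restore the previous handler, even on error
    signal.signal(signal.SIGINT, original)

print(shutdown_event.is_set())  # False: no interrupt was delivered
```

Checking an event flag (rather than catching `KeyboardInterrupt`) lets worker loops finish the current batch and persist state before stopping.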
```python
def set_on_progress(self, callback: Callable[[Dict, float, Dict], None], interval: float = 1.0) -> 'BatchRun':
    """Set progress callback for execution monitoring.

    The callback will be called periodically with progress statistics
    including completed jobs, total jobs, current cost, etc.

    Args:
        callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data)
            - stats_dict: Progress statistics dictionary
            - elapsed_time_seconds: Time elapsed since batch started (float)
            - batch_data: Dictionary mapping batch_id to batch information
        interval: Interval in seconds between progress updates (default: 1.0)

    Returns:
        Self for chaining

    Example:
        ```python
        run.set_on_progress(
            lambda stats, time, batch_data: print(
                f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s"
            )
        )
        ```
    """
    self._progress_callback = callback
    self._progress_interval = interval
    return self
```
Set progress callback for execution monitoring.
The callback will be called periodically with progress statistics including completed jobs, total jobs, current cost, etc.
Args:
- callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data)
  - stats_dict: Progress statistics dictionary
  - elapsed_time_seconds: Time elapsed since batch started (float)
  - batch_data: Dictionary mapping batch_id to batch information
- interval: Interval in seconds between progress updates (default: 1.0)
Returns: Self for chaining
Example:
```python
run.set_on_progress(
    lambda stats, time, batch_data: print(
        f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s"
    )
)
```
```python
@property
def is_complete(self) -> bool:
    """Whether all jobs are complete."""
    total_jobs = len(self.jobs)
    completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
    return len(self.pending_jobs) == 0 and completed_count == total_jobs
```
Whether all jobs are complete.
```python
def status(self, print_status: bool = False) -> Dict:
    """Get current execution statistics."""
    total_jobs = len(self.jobs)
    completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
    remaining_count = total_jobs - completed_count

    stats = {
        "total": total_jobs,
        "pending": remaining_count,
        "active": 0,  # Always 0 for synchronous execution
        "completed": len(self.completed_results),
        "failed": len(self.failed_jobs),
        "cancelled": len(self.cancelled_jobs),
        "cost_usd": self.cost_tracker.used_usd,
        "cost_limit_usd": self.cost_tracker.limit_usd,
        "is_complete": self.is_complete,
        "batches_total": self.total_batches,
        "batches_completed": self.completed_batches,
        "batches_pending": self.total_batches - self.completed_batches,
        "current_batch_index": self.current_batch_index,
        "current_batch_size": self.current_batch_size,
        "items_per_batch": self.config.items_per_batch
    }

    if print_status:
        logger.info("\nBatch Run Status:")
        logger.info(f"  Total jobs: {stats['total']}")
        logger.info(f"  Pending: {stats['pending']}")
        logger.info(f"  Active: {stats['active']}")
        logger.info(f"  Completed: {stats['completed']}")
        logger.info(f"  Failed: {stats['failed']}")
        logger.info(f"  Cancelled: {stats['cancelled']}")
        logger.info(f"  Cost: ${stats['cost_usd']:.6f}")
        if stats['cost_limit_usd']:
            logger.info(f"  Cost limit: ${stats['cost_limit_usd']:.2f}")
        logger.info(f"  Complete: {stats['is_complete']}")

    return stats
```
Get current execution statistics.
```python
def results(self) -> Dict[str, List[JobResult]]:
    """Get all results organized by status.

    Returns:
        {
            "completed": [JobResult],
            "failed": [JobResult],
            "cancelled": [JobResult]
        }
    """
    return {
        "completed": list(self.completed_results.values()),
        "failed": self._create_failed_results(),
        "cancelled": self._create_cancelled_results()
    }
```
Get all results organized by status.
Returns a dict keyed by status:
- "completed": [JobResult]
- "failed": [JobResult]
- "cancelled": [JobResult]
```python
def get_failed_jobs(self) -> Dict[str, str]:
    """Get failed jobs with error messages.

    Note: This method is deprecated. Use results()['failed'] instead.
    """
    return dict(self.failed_jobs)
```
Get failed jobs with error messages.
Note: This method is deprecated. Use results()['failed'] instead.
```python
def dry_run(self) -> 'BatchRun':
    """Perform a dry run - show cost estimation and job details without executing.

    Returns:
        Self for chaining (doesn't actually execute jobs)
    """
    logger.info("=== DRY RUN MODE ===")
    logger.info("This will show cost estimates without executing jobs")

    # Load existing state if reuse_state=True
    if self.config.reuse_state:
        self.state_manager.load_state(self)

    # Filter out completed jobs from previous runs
    self.pending_jobs = [job for job in self.jobs.values() if job.id not in self.completed_results]

    if not self.pending_jobs:
        logger.info("No pending jobs to analyze (all jobs already completed)")
        return self

    logger.info(f"Analyzing {len(self.pending_jobs)} pending jobs...")

    # Group jobs by provider and analyze costs
    provider_groups = self._group_jobs_by_provider()
    total_estimated_cost = 0.0

    logger.info("\nJob breakdown:")
    for provider_name, jobs in provider_groups.items():
        provider = get_provider(jobs[0].model)
        logger.info(f"\n{provider_name} ({len(jobs)} jobs):")

        job_batches = [jobs[i:i + self.config.items_per_batch]
                       for i in range(0, len(jobs), self.config.items_per_batch)]

        for batch_idx, batch_jobs in enumerate(job_batches, 1):
            estimated_cost = provider.estimate_cost(batch_jobs)
            total_estimated_cost += estimated_cost

            logger.info(f"  Batch {batch_idx}: {len(batch_jobs)} jobs, estimated cost: ${estimated_cost:.4f}")
            for job in batch_jobs:
                if job.file:
                    logger.info(f"    - {job.id}: {job.file.name} (citations: {job.enable_citations})")
                else:
                    logger.info(f"    - {job.id}: direct messages (citations: {job.enable_citations})")

    # Show cost summary
    logger.info("\n=== COST SUMMARY ===")
    logger.info(f"Total estimated cost: ${total_estimated_cost:.4f}")

    if self.config.cost_limit_usd:
        logger.info(f"Cost limit: ${self.config.cost_limit_usd:.2f}")
        if total_estimated_cost > self.config.cost_limit_usd:
            excess = total_estimated_cost - self.config.cost_limit_usd
            logger.warning(f"⚠️ Estimated cost exceeds limit by ${excess:.4f}")
        else:
            remaining = self.config.cost_limit_usd - total_estimated_cost
            logger.info(f"✅ Within cost limit (${remaining:.4f} remaining)")
    else:
        logger.info("No cost limit set")

    # Show execution plan
    logger.info("\n=== EXECUTION PLAN ===")
    total_batches = sum(
        len(jobs) // self.config.items_per_batch + (1 if len(jobs) % self.config.items_per_batch else 0)
        for jobs in provider_groups.values()
    )
    logger.info(f"Total batches to process: {total_batches}")
    logger.info(f"Max parallel batches: {self.config.max_parallel_batches}")
    logger.info(f"Items per batch: {self.config.items_per_batch}")
    logger.info(f"Results directory: {self.config.results_dir}")

    logger.info("\n=== DRY RUN COMPLETE ===")
    logger.info("To execute for real, call run() without dry_run=True")

    return self
```
Perform a dry run - show cost estimation and job details without executing.
Returns: Self for chaining (doesn't actually execute jobs)
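The batching arithmetic in `dry_run` is ordinary chunking: slice the job list into groups of `items_per_batch`, and the batch count is the ceiling of `len(jobs) / items_per_batch`. A self-contained sketch of both (plain Python, no batchata imports):

```python
def chunk(jobs, items_per_batch):
    # Same slicing dry_run uses to build job_batches
    return [jobs[i:i + items_per_batch]
            for i in range(0, len(jobs), items_per_batch)]

jobs = list(range(10))
n = 4
batches = chunk(jobs, n)

print([len(b) for b in batches])  # [4, 4, 2]

# Ceiling division, written the way the execution plan computes it
total_batches = len(jobs) // n + (1 if len(jobs) % n else 0)
print(total_batches)  # 3
```

Both expressions agree for any non-empty job list, so the per-batch breakdown and the execution-plan total always match.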
```python
@dataclass
class Job:
    """Configuration for a single AI job.

    Either provide messages OR prompt (with optional file), not both.

    Attributes:
        id: Unique identifier for the job
        messages: Chat messages for direct message input
        file: Optional file path for file-based input
        prompt: Prompt text (can be used alone or with file)
        model: Model name (e.g., "claude-3-sonnet")
        temperature: Sampling temperature (0.0-1.0)
        max_tokens: Maximum tokens to generate
        response_model: Pydantic model for structured output
        enable_citations: Whether to extract citations from response
    """

    id: str  # Unique identifier
    model: str  # Model name (e.g., "claude-3-sonnet")
    messages: Optional[List[Message]] = None  # Chat messages
    file: Optional[Path] = None  # File input
    prompt: Optional[str] = None  # Prompt for file
    temperature: float = 0.7
    max_tokens: int = 1000
    response_model: Optional[Type[BaseModel]] = None  # For structured output
    enable_citations: bool = False

    def __post_init__(self):
        """Validate job configuration."""
        if self.messages and (self.file or self.prompt):
            raise ValueError("Provide either messages OR file+prompt, not both")

        if self.file and not self.prompt:
            raise ValueError("File input requires a prompt")

        if not self.messages and not self.prompt:
            raise ValueError("Must provide either messages or prompt")

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for state persistence."""
        return {
            "id": self.id,
            "model": self.model,
            "messages": self.messages,
            "file": str(self.file) if self.file else None,
            "prompt": self.prompt,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "response_model": self.response_model.__name__ if self.response_model else None,
            "enable_citations": self.enable_citations
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Job':
        """Deserialize from state."""
        # Convert file string back to Path if present
        file_path = None
        if data.get("file"):
            file_path = Path(data["file"])

        # Note: response_model reconstruction would need additional logic
        # For now, we'll set it to None during deserialization
        return cls(
            id=data["id"],
            model=data["model"],
            messages=data.get("messages"),
            file=file_path,
            prompt=data.get("prompt"),
            temperature=data.get("temperature", 0.7),
            max_tokens=data.get("max_tokens", 1000),
            response_model=None,  # Cannot reconstruct from string
            enable_citations=data.get("enable_citations", False)
        )
```
Configuration for a single AI job.
Either provide messages OR prompt (with optional file), not both.
Attributes:
- id: Unique identifier for the job
- messages: Chat messages for direct message input
- file: Optional file path for file-based input
- prompt: Prompt text (can be used alone or with file)
- model: Model name (e.g., "claude-3-sonnet")
- temperature: Sampling temperature (0.0-1.0)
- max_tokens: Maximum tokens to generate
- response_model: Pydantic model for structured output
- enable_citations: Whether to extract citations from response
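The three validation rules in `__post_init__` can be exercised with a stripped-down stand-in that reproduces only that logic (a sketch, not the real `Job` class, which also carries model parameters and Pydantic types):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class MiniJob:
    """Stand-in reproducing Job's input-validation rules only."""
    id: str
    model: str
    messages: Optional[List[dict]] = None
    file: Optional[str] = None
    prompt: Optional[str] = None

    def __post_init__(self):
        if self.messages and (self.file or self.prompt):
            raise ValueError("Provide either messages OR file+prompt, not both")
        if self.file and not self.prompt:
            raise ValueError("File input requires a prompt")
        if not self.messages and not self.prompt:
            raise ValueError("Must provide either messages or prompt")

# Valid: file + prompt
MiniJob(id="j1", model="claude-3-sonnet", file="doc.pdf", prompt="Summarize")

# Valid: prompt alone, or messages alone
MiniJob(id="j2", model="claude-3-sonnet", prompt="Hello")

# Invalid: a file with no prompt
try:
    MiniJob(id="j3", model="claude-3-sonnet", file="doc.pdf")
except ValueError as e:
    print(e)  # File input requires a prompt
```

Note that "messages OR prompt" is exclusive: mixing chat messages with a file or prompt raises immediately at construction time, before anything is sent to a provider.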
```python
def to_dict(self) -> Dict[str, Any]:
    """Serialize for state persistence."""
    return {
        "id": self.id,
        "model": self.model,
        "messages": self.messages,
        "file": str(self.file) if self.file else None,
        "prompt": self.prompt,
        "temperature": self.temperature,
        "max_tokens": self.max_tokens,
        "response_model": self.response_model.__name__ if self.response_model else None,
        "enable_citations": self.enable_citations
    }
```
Serialize for state persistence.
```python
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Job':
    """Deserialize from state."""
    # Convert file string back to Path if present
    file_path = None
    if data.get("file"):
        file_path = Path(data["file"])

    # Note: response_model reconstruction would need additional logic
    # For now, we'll set it to None during deserialization
    return cls(
        id=data["id"],
        model=data["model"],
        messages=data.get("messages"),
        file=file_path,
        prompt=data.get("prompt"),
        temperature=data.get("temperature", 0.7),
        max_tokens=data.get("max_tokens", 1000),
        response_model=None,  # Cannot reconstruct from string
        enable_citations=data.get("enable_citations", False)
    )
```
Deserialize from state.
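The `to_dict`/`from_dict` round trip above is lossy in two specific ways: `file` is stored as a string and rebuilt as a `Path`, while `response_model` is stored only as its class name and comes back as `None`. A small stand-alone sketch of those two conversions (the `state` dict here is an illustrative example, not real batchata output):

```python
from pathlib import Path

# Shape of the lossy fields as Job.to_dict would store them
state = {
    "id": "job-1",
    "model": "claude-3-sonnet",
    "file": str(Path("docs") / "report.pdf"),  # Path -> str for JSON
    "prompt": "Summarize",
    "response_model": "DocumentAnalysis",      # only the class *name* survives
}

# What from_dict can recover
file_path = Path(state["file"]) if state.get("file") else None
response_model = None  # a class cannot be rebuilt from its name alone

print(file_path)       # the Path round-trips cleanly
print(response_model)  # None
```

Practically, this means a resumed run re-parses raw responses only if the caller supplies the Pydantic model again; the persisted state alone cannot restore structured-output parsing.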
```python
@dataclass
class JobResult:
    """Result from a completed AI job.

    Attributes:
        job_id: ID of the job this result is for
        raw_response: Raw text response from the model (None for failed jobs)
        parsed_response: Structured output (if response_model was used)
        citations: Extracted citations (if enable_citations was True)
        citation_mappings: Maps field names to relevant citations (if response_model used)
        input_tokens: Number of input tokens used
        output_tokens: Number of output tokens generated
        cost_usd: Total cost in USD
        error: Error message if job failed
        batch_id: ID of the batch this job was part of (for mapping to raw files)
    """

    job_id: str
    raw_response: Optional[str] = None  # Raw text response (None for failed jobs)
    parsed_response: Optional[Union[BaseModel, Dict]] = None  # Structured output or error dict
    citations: Optional[List[Citation]] = None  # Extracted citations
    citation_mappings: Optional[Dict[str, List[Citation]]] = None  # Field -> citations mapping
    input_tokens: int = 0
    output_tokens: int = 0
    cost_usd: float = 0.0
    error: Optional[str] = None  # Error message if failed
    batch_id: Optional[str] = None  # Batch ID for mapping to raw files

    @property
    def is_success(self) -> bool:
        """Whether the job completed successfully."""
        return self.error is None

    @property
    def total_tokens(self) -> int:
        """Total tokens used (input + output)."""
        return self.input_tokens + self.output_tokens

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for state persistence."""
        # Handle parsed_response serialization
        parsed_response = None
        if self.parsed_response is not None:
            if isinstance(self.parsed_response, dict):
                parsed_response = self.parsed_response
            elif isinstance(self.parsed_response, BaseModel):
                parsed_response = self.parsed_response.model_dump()
            else:
                parsed_response = str(self.parsed_response)

        # Handle citation_mappings serialization
        citation_mappings = None
        if self.citation_mappings:
            citation_mappings = {
                field: [asdict(c) for c in citations]
                for field, citations in self.citation_mappings.items()
            }

        return {
            "job_id": self.job_id,
            "raw_response": self.raw_response,
            "parsed_response": parsed_response,
            "citations": [asdict(c) for c in self.citations] if self.citations else None,
            "citation_mappings": citation_mappings,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cost_usd": self.cost_usd,
            "error": self.error,
            "batch_id": self.batch_id
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'JobResult':
        """Deserialize from state."""
        # Reconstruct citations if present
        citations = None
        if data.get("citations"):
            citations = [Citation(**c) for c in data["citations"]]

        # Reconstruct citation_mappings if present
        citation_mappings = None
        if data.get("citation_mappings"):
            citation_mappings = {
                field: [Citation(**c) for c in citations]
                for field, citations in data["citation_mappings"].items()
            }

        return cls(
            job_id=data["job_id"],
            raw_response=data["raw_response"],
            parsed_response=data.get("parsed_response"),
            citations=citations,
            citation_mappings=citation_mappings,
            input_tokens=data.get("input_tokens", 0),
            output_tokens=data.get("output_tokens", 0),
            cost_usd=data.get("cost_usd", 0.0),
            error=data.get("error"),
            batch_id=data.get("batch_id")
        )
```
Result from a completed AI job.
Attributes:
- job_id: ID of the job this result is for
- raw_response: Raw text response from the model (None for failed jobs)
- parsed_response: Structured output (if response_model was used)
- citations: Extracted citations (if enable_citations was True)
- citation_mappings: Maps field names to relevant citations (if response_model used)
- input_tokens: Number of input tokens used
- output_tokens: Number of output tokens generated
- cost_usd: Total cost in USD
- error: Error message if job failed
- batch_id: ID of the batch this job was part of (for mapping to raw files)
```python
@property
def is_success(self) -> bool:
    """Whether the job completed successfully."""
    return self.error is None
```
Whether the job completed successfully.
```python
@property
def total_tokens(self) -> int:
    """Total tokens used (input + output)."""
    return self.input_tokens + self.output_tokens
```
Total tokens used (input + output).
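Both properties are pure derivations from stored fields: success is simply "no error recorded", and the token total is the sum of the two counters. A minimal stand-in demonstrating that behavior (a sketch mirroring only these two properties, not the full `JobResult`):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MiniResult:
    """Stand-in mirroring JobResult's two derived properties."""
    job_id: str
    input_tokens: int = 0
    output_tokens: int = 0
    error: Optional[str] = None

    @property
    def is_success(self) -> bool:
        return self.error is None

    @property
    def total_tokens(self) -> int:
        return self.input_tokens + self.output_tokens

ok = MiniResult(job_id="j1", input_tokens=1200, output_tokens=300)
failed = MiniResult(job_id="j2", error="rate limited")

print(ok.is_success, ok.total_tokens)  # True 1500
print(failed.is_success)               # False
```

Because `is_success` keys off `error` alone, a result with an empty response but no recorded error still counts as successful; callers who need stricter checks should also inspect `raw_response`.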
```python
def to_dict(self) -> Dict[str, Any]:
    """Serialize for state persistence."""
    # Handle parsed_response serialization
    parsed_response = None
    if self.parsed_response is not None:
        if isinstance(self.parsed_response, dict):
            parsed_response = self.parsed_response
        elif isinstance(self.parsed_response, BaseModel):
            parsed_response = self.parsed_response.model_dump()
        else:
            parsed_response = str(self.parsed_response)

    # Handle citation_mappings serialization
    citation_mappings = None
    if self.citation_mappings:
        citation_mappings = {
            field: [asdict(c) for c in citations]
            for field, citations in self.citation_mappings.items()
        }

    return {
        "job_id": self.job_id,
        "raw_response": self.raw_response,
        "parsed_response": parsed_response,
        "citations": [asdict(c) for c in self.citations] if self.citations else None,
        "citation_mappings": citation_mappings,
        "input_tokens": self.input_tokens,
        "output_tokens": self.output_tokens,
        "cost_usd": self.cost_usd,
        "error": self.error,
        "batch_id": self.batch_id
    }
```
Serialize for state persistence.
```python
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'JobResult':
    """Deserialize from state."""
    # Reconstruct citations if present
    citations = None
    if data.get("citations"):
        citations = [Citation(**c) for c in data["citations"]]

    # Reconstruct citation_mappings if present
    citation_mappings = None
    if data.get("citation_mappings"):
        citation_mappings = {
            field: [Citation(**c) for c in citations]
            for field, citations in data["citation_mappings"].items()
        }

    return cls(
        job_id=data["job_id"],
        raw_response=data["raw_response"],
        parsed_response=data.get("parsed_response"),
        citations=citations,
        citation_mappings=citation_mappings,
        input_tokens=data.get("input_tokens", 0),
        output_tokens=data.get("output_tokens", 0),
        cost_usd=data.get("cost_usd", 0.0),
        error=data.get("error"),
        batch_id=data.get("batch_id")
    )
```
Deserialize from state.
```python
@dataclass
class Citation:
    """Represents a citation extracted from an AI response."""

    text: str  # The cited text
    source: str  # Source identifier (e.g., page number, section)
    page: Optional[int] = None  # Page number if applicable
    metadata: Optional[Dict[str, Any]] = None  # Additional metadata
```
Represents a citation extracted from an AI response.
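Because `Citation` is a plain dataclass of JSON-friendly fields, it serializes with `dataclasses.asdict` and reconstructs with `Citation(**dict)` — exactly the pattern `JobResult.to_dict`/`from_dict` rely on. A runnable round-trip using the dataclass as defined above (the example values are illustrative):

```python
from dataclasses import dataclass, asdict
from typing import Any, Dict, Optional

@dataclass
class Citation:
    text: str
    source: str
    page: Optional[int] = None
    metadata: Optional[Dict[str, Any]] = None

c = Citation(text="Revenue grew 12%", source="section 2", page=4)

d = asdict(c)            # dataclass -> plain dict, ready for JSON
restored = Citation(**d)  # dict -> dataclass, field names must match exactly

print(d["page"])          # 4
print(restored == c)      # True: the round trip is lossless
```

The round trip is lossless here because every field is a primitive or a plain dict; a dataclass holding e.g. a `Path` would need the explicit str/Path conversion `Job` does.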
Base exception for all Batchata errors.
```python
class CostLimitExceededError(BatchataError):
    """Raised when cost limit would be exceeded."""
    pass
```
Raised when cost limit would be exceeded.
Base exception for provider-related errors.
```python
class ProviderNotFoundError(ProviderError):
    """Raised when no provider is found for a model."""
    pass
```
Raised when no provider is found for a model.
```python
class ValidationError(BatchataError):
    """Raised when job or configuration validation fails."""
    pass
```
Raised when job or configuration validation fails.
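The classes above form a small hierarchy rooted at `BatchataError`, so one `except BatchataError` clause covers every library error. A self-contained sketch (the base class of `ProviderError` is not shown in the source, so treating it as a `BatchataError` subclass is an assumption here):

```python
class BatchataError(Exception):
    """Base exception for all Batchata errors."""

class ValidationError(BatchataError):
    """Raised when job or configuration validation fails."""

class CostLimitExceededError(BatchataError):
    """Raised when cost limit would be exceeded."""

class ProviderError(BatchataError):  # assumption: shares the common base
    """Base exception for provider-related errors."""

class ProviderNotFoundError(ProviderError):
    """Raised when no provider is found for a model."""

# One handler for any library failure; the message is illustrative
try:
    raise CostLimitExceededError("estimated cost exceeds configured limit")
except BatchataError as e:
    caught = type(e).__name__
    print(caught)  # CostLimitExceededError
```

Catching the narrow classes (`CostLimitExceededError`, `ValidationError`) lets callers react differently to budget problems versus bad job definitions, while the base class remains a safe catch-all.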