# batchata

Unified Python API for AI batch requests with cost tracking, Pydantic responses, and parallel execution.

**Why AI-batching?**

AI providers offer batch APIs that process requests asynchronously at 50% reduced cost compared to real-time APIs. This is ideal for workloads like document processing, data analysis, and content generation where immediate responses aren't required.
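
At half price, the savings scale linearly with volume; a quick back-of-envelope sketch (the per-request rate below is purely illustrative, not actual provider pricing — only the 50% discount comes from the providers):

```python
# Illustrative numbers only: the per-request rate is made up;
# the 50% batch discount is the one providers advertise.
requests = 10_000
realtime_rate_usd = 0.01   # hypothetical cost per real-time request
batch_discount = 0.5       # batch APIs bill at half the real-time price

realtime_cost = requests * realtime_rate_usd
batch_cost = realtime_cost * batch_discount
print(f"real-time: ${realtime_cost:.2f}, batch: ${batch_cost:.2f}")
# → real-time: $100.00, batch: $50.00
```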

## Quick Start

### Installation

```bash
pip install batchata
```

### Basic Usage

```python
from batchata import Batch

# Simple batch processing
batch = (
    Batch(results_dir="./output")
    .set_default_params(model="claude-sonnet-4-20250514")
    .add_cost_limit(usd=5.0)
)

# Add jobs
for file in files:
    batch.add_job(file=file, prompt="Summarize this document")

# Execute
run = batch.run()
results = run.results()
```

### Structured Output with Pydantic

```python
from batchata import Batch
from pydantic import BaseModel

class DocumentAnalysis(BaseModel):
    title: str
    summary: str
    key_points: list[str]

batch = (
    Batch(results_dir="./results")
    .set_default_params(model="claude-sonnet-4-20250514")
)

batch.add_job(
    file="document.pdf",
    prompt="Analyze this document",
    response_model=DocumentAnalysis,
    enable_citations=True  # Anthropic only
)

run = batch.run()
for result in run.results()["completed"]:
    analysis = result.parsed_response  # DocumentAnalysis object
    citations = result.citation_mappings  # Field -> Citation mapping
```

## Key Features

- **50% Cost Savings**: Native batch processing via provider APIs
- **Cost Limits**: Cap spend per batch with `.add_cost_limit(usd=...)`
- **Time Limits**: Control execution time with `.add_time_limit()`
- **State Persistence**: Resume interrupted batches automatically
- **Structured Output**: Pydantic models with automatic validation
- **Citations**: Extract and map citations to response fields (Anthropic)
- **Multiple Providers**: Anthropic Claude and OpenAI GPT models
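
Time limits accept mixed units that are summed into seconds — `.add_time_limit()` computes `seconds + 60 * minutes + 3600 * hours` internally. A stdlib sketch of the same conversion:

```python
def total_seconds(seconds: float = 0.0, minutes: float = 0.0, hours: float = 0.0) -> float:
    """Combine mixed time units into a single value in seconds."""
    return seconds + minutes * 60 + hours * 3600

print(total_seconds(hours=1, minutes=30, seconds=15))  # → 5415.0
```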

## Supported Providers

| Feature | Anthropic | OpenAI |
|---------|-----------|--------|
| Models | [All Claude models](https://github.com/agamm/batchata/blob/main/batchata/providers/anthropic/models.py) | [All GPT models](https://github.com/agamm/batchata/blob/main/batchata/providers/openai/models.py) |
| Citations | ✅ | ❌ |
| Structured Output | ✅ | ✅ |
| File Types | PDF, TXT, DOCX, Images | PDF, Images |

## Configuration

Set API keys as environment variables:

```bash
export ANTHROPIC_API_KEY="your-key"
export OPENAI_API_KEY="your-key"
```

Or use a `.env` file with python-dotenv.
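
To confirm the keys are actually visible to the Python process before launching a batch, a small stdlib check can help (the helper name is illustrative; the variable names are the ones above):

```python
import os

def missing_api_keys(required=("ANTHROPIC_API_KEY", "OPENAI_API_KEY")):
    """Return names of required API keys that are unset or empty."""
    return [key for key in required if not os.environ.get(key)]

# An empty list means both keys are exported.
print(missing_api_keys())
```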

from .core import Batch, BatchRun, Job, JobResult
from .exceptions import (
    BatchataError,
    CostLimitExceededError,
    ProviderError,
    ProviderNotFoundError,
    ValidationError,
)
from .types import Citation

__version__ = "0.3.0"

__all__ = [
    "Batch",
    "BatchRun",
    "Job",
    "JobResult",
    "Citation",
    "BatchataError",
    "CostLimitExceededError",
    "ProviderError",
    "ProviderNotFoundError",
    "ValidationError",
]
class Batch:
    """Builder for batch job configuration.

    Provides a fluent interface for configuring batch jobs with sensible defaults
    and validation. The batch can be configured with cost limits, default parameters,
    and progress callbacks.

    Example:
        ```python
        batch = (
            Batch("./results", max_parallel_batches=10, items_per_batch=10)
            .set_state(file="./state.json", reuse_state=True)
            .set_default_params(model="claude-sonnet-4-20250514", temperature=0.7)
            .add_cost_limit(usd=15.0)
            .add_job(messages=[{"role": "user", "content": "Hello"}])
            .add_job(file="./path/to/file.pdf", prompt="Generate summary of file")
        )

        run = batch.run()
        ```
    """

    def __init__(self, results_dir: str, max_parallel_batches: int = 10, items_per_batch: int = 10, raw_files: Optional[bool] = None):
        """Initialize batch configuration.

        Args:
            results_dir: Directory to store results
            max_parallel_batches: Maximum parallel batch requests
            items_per_batch: Number of jobs per provider batch
            raw_files: Whether to save debug files (raw responses, JSONL files) from providers (default: True if results_dir is set, False otherwise)
        """
        # Auto-determine raw_files based on results_dir if not explicitly set
        if raw_files is None:
            raw_files = bool(results_dir and results_dir.strip())

        self.config = BatchParams(
            state_file=None,
            results_dir=results_dir,
            max_parallel_batches=max_parallel_batches,
            items_per_batch=items_per_batch,
            reuse_state=True,
            raw_files=raw_files
        )
        self.jobs: List[Job] = []

    def set_default_params(self, **kwargs) -> 'Batch':
        """Set default parameters for all jobs.

        These defaults will be applied to all jobs unless overridden
        by job-specific parameters.

        Args:
            **kwargs: Default parameters (model, temperature, max_tokens, etc.)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_default_params(model="claude-3-sonnet", temperature=0.7)
            ```
        """
        # Validate if model is provided
        if "model" in kwargs:
            self.config.validate_default_params(kwargs["model"])

        self.config.default_params.update(kwargs)
        return self

    def set_state(self, file: Optional[str] = None, reuse_state: bool = True) -> 'Batch':
        """Set state file configuration.

        Args:
            file: Path to state file for persistence (default: None)
            reuse_state: Whether to resume from existing state file (default: True)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_state(file="./state.json", reuse_state=True)
            ```
        """
        self.config.state_file = file
        self.config.reuse_state = reuse_state
        return self

    def add_cost_limit(self, usd: float) -> 'Batch':
        """Add cost limit for the batch.

        The batch will stop accepting new jobs once the cost limit is reached.
        Active jobs will be allowed to complete.

        Args:
            usd: Cost limit in USD

        Returns:
            Self for chaining

        Example:
            ```python
            batch.add_cost_limit(usd=50.0)
            ```
        """
        if usd <= 0:
            raise ValueError("Cost limit must be positive")
        self.config.cost_limit_usd = usd
        return self

    def raw_files(self, enabled: bool = True) -> 'Batch':
        """Enable or disable saving debug files from providers.

        When enabled, debug files (raw API responses, JSONL files) will be saved
        in a 'raw_files' subdirectory within the results directory.
        This is useful for debugging, auditing, or accessing provider-specific metadata.

        Args:
            enabled: Whether to save debug files (default: True)

        Returns:
            Self for chaining

        Example:
            ```python
            batch.raw_files(True)
            ```
        """
        self.config.raw_files = enabled
        return self

    def set_verbosity(self, level: str) -> 'Batch':
        """Set logging verbosity level.

        Args:
            level: Verbosity level ("debug", "info", "warn", "error")

        Returns:
            Self for chaining

        Example:
            ```python
            batch.set_verbosity("error")  # For production
            batch.set_verbosity("debug")  # For debugging
            ```
        """
        valid_levels = {"debug", "info", "warn", "error"}
        if level.lower() not in valid_levels:
            raise ValueError(f"Invalid verbosity level: {level}. Must be one of {valid_levels}")
        self.config.verbosity = level.lower()
        return self

    def add_time_limit(self, seconds: Optional[float] = None, minutes: Optional[float] = None, hours: Optional[float] = None) -> 'Batch':
        """Add time limit for the entire batch execution.

        When time limit is reached, all active provider batches are cancelled and
        remaining unprocessed jobs are marked as failed. The batch execution
        completes normally without throwing exceptions.

        Args:
            seconds: Time limit in seconds (optional)
            minutes: Time limit in minutes (optional)
            hours: Time limit in hours (optional)

        Returns:
            Self for chaining

        Raises:
            ValueError: If no time units specified, or if total time is outside
                        valid range (min: 10 seconds, max: 24 hours)

        Note:
            - Can combine multiple time units
            - Time limit is checked every second by a background watchdog thread
            - Jobs that exceed time limit appear in results()["failed"] with time limit error message
            - No exceptions are thrown when time limit is reached

        Example:
            ```python
            batch.add_time_limit(seconds=30)  # 30 seconds
            batch.add_time_limit(minutes=5)   # 5 minutes
            batch.add_time_limit(hours=2)     # 2 hours
            batch.add_time_limit(hours=1, minutes=30, seconds=15)  # 5415 seconds total
            ```
        """
        time_limit_seconds = 0.0

        if seconds is not None:
            time_limit_seconds += seconds
        if minutes is not None:
            time_limit_seconds += minutes * 60
        if hours is not None:
            time_limit_seconds += hours * 3600

        if time_limit_seconds == 0:
            raise ValueError("Must specify at least one of seconds, minutes, or hours")

        self.config.time_limit_seconds = time_limit_seconds
        return self

    def add_job(
        self,
        messages: Optional[List[Message]] = None,
        file: Optional[Union[str, Path]] = None,
        prompt: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        response_model: Optional[Type[BaseModel]] = None,
        enable_citations: bool = False,
        **kwargs
    ) -> 'Batch':
        """Add a job to the batch.

        Either provide messages OR file+prompt, not both. Parameters not provided
        will use the defaults set via the set_default_params() method.

        Args:
            messages: Chat messages for direct input
            file: File path for file-based input
            prompt: Prompt to use with file input
            model: Model to use (overrides default)
            temperature: Sampling temperature (overrides default)
            max_tokens: Max tokens to generate (overrides default)
            response_model: Pydantic model for structured output
            enable_citations: Whether to extract citations
            **kwargs: Additional parameters

        Returns:
            Self for chaining

        Example:
            ```python
            batch.add_job(
                messages=[{"role": "user", "content": "Hello"}],
                model="gpt-4"
            )
            ```
        """
        # Generate unique job ID
        job_id = f"job-{uuid.uuid4().hex[:8]}"

        # Merge with defaults
        params = self.config.default_params.copy()

        # Update with provided parameters
        if model is not None:
            params["model"] = model
        if temperature is not None:
            params["temperature"] = temperature
        if max_tokens is not None:
            params["max_tokens"] = max_tokens

        # Add other kwargs
        params.update(kwargs)

        # Ensure model is provided
        if "model" not in params:
            raise ValueError("Model must be provided either in defaults or job parameters")

        # Validate parameters
        provider = get_provider(params["model"])
        # Extract params without model to avoid duplicate
        param_subset = {k: v for k, v in params.items() if k != "model"}
        provider.validate_params(params["model"], **param_subset)

        # Convert file path if string
        if isinstance(file, str):
            file = Path(file)

        # Warn about temporary file paths that may not persist
        if file:
            file_str = str(file)
            if "/tmp/" in file_str or "/var/folders/" in file_str or "temp" in file_str.lower():
                logger = logging.getLogger("batchata")
                logger.debug(f"File path appears to be in a temporary directory: {file}")
                logger.debug("This may cause issues when resuming from state if temp files are cleaned up")

        # Create job
        job = Job(
            id=job_id,
            messages=messages,
            file=file,
            prompt=prompt,
            response_model=response_model,
            enable_citations=enable_citations,
            **params
        )

        # Validate citation compatibility
        if response_model and enable_citations:
            from ..utils.validation import validate_flat_model
            validate_flat_model(response_model)

        # Validate job with provider (includes PDF validation for Anthropic)
        provider.validate_job(job)

        self.jobs.append(job)
        return self

    def run(self, on_progress: Optional[Callable[[Dict, float, Dict], None]] = None, progress_interval: float = 1.0, print_status: bool = False, dry_run: bool = False) -> 'BatchRun':
        """Execute the batch.

        Creates a BatchRun instance and executes the jobs synchronously.

        Args:
            on_progress: Optional progress callback function that receives
                        (stats_dict, elapsed_time_seconds, batch_data)
            progress_interval: Interval in seconds between progress updates (default: 1.0)
            print_status: Whether to show rich progress display (default: False)
            dry_run: If True, only show cost estimation without executing (default: False)

        Returns:
            BatchRun instance with completed results

        Raises:
            ValueError: If no jobs have been added
        """
        if not self.jobs:
            raise ValueError("No jobs added to batch")

        # Import here to avoid circular dependency
        from .batch_run import BatchRun

        # Create and start the run
        run = BatchRun(self.config, self.jobs)

        # Handle dry run mode
        if dry_run:
            return run.dry_run()

        # Set progress callback - either rich display or custom callback
        if print_status:
            return self._run_with_rich_display(run, progress_interval)
        else:
            return self._run_with_custom_callback(run, on_progress, progress_interval)

    def _run_with_rich_display(self, run: 'BatchRun', progress_interval: float) -> 'BatchRun':
        """Execute batch run with rich progress display.

        Args:
            run: BatchRun instance to execute
            progress_interval: Interval between progress updates

        Returns:
            Completed BatchRun instance
        """
        from ..utils.rich_progress import RichBatchProgressDisplay
        display = RichBatchProgressDisplay()

        def rich_progress_callback(stats, elapsed_time, batch_data):
            # Start display on first call
            if not hasattr(rich_progress_callback, '_started'):
                config_dict = {
                    'results_dir': self.config.results_dir,
                    'state_file': self.config.state_file,
                    'items_per_batch': self.config.items_per_batch,
                    'max_parallel_batches': self.config.max_parallel_batches
                }
                display.start(stats, config_dict)
                rich_progress_callback._started = True

            # Update display
            display.update(stats, batch_data, elapsed_time)

        run.set_on_progress(rich_progress_callback, interval=progress_interval)

        # Execute with proper cleanup
        try:
            run.execute()

            # Show final status with all batches completed
            stats = run.status()
            display.update(stats, run.batch_tracking, (datetime.now() - run._start_time).total_seconds())

            # Small delay to ensure display updates
            import time
            time.sleep(0.2)

        except KeyboardInterrupt:
            # Update batch tracking to show cancelled status for pending/running batches
            with run._state_lock:
                for batch_id, batch_info in run.batch_tracking.items():
                    if batch_info['status'] in ('running', 'pending'):
                        batch_info['status'] = 'cancelled'

            # Show final status with cancelled batches
            stats = run.status()
            display.update(stats, run.batch_tracking, 0.0)

            # Add a small delay to ensure the display updates
            import time
            time.sleep(0.1)

            display.stop()
            raise
        finally:
            if display.live:  # Only stop if not already stopped
                display.stop()

        return run

    def _run_with_custom_callback(self, run: 'BatchRun', on_progress: Optional[Callable[[Dict, float, Dict], None]], progress_interval: float) -> 'BatchRun':
        """Execute batch run with custom progress callback.

        Args:
            run: BatchRun instance to execute
            on_progress: Optional custom progress callback
            progress_interval: Interval between progress updates

        Returns:
            Completed BatchRun instance
        """
        # Use custom progress callback if provided
        if on_progress:
            run.set_on_progress(on_progress, interval=progress_interval)

        run.execute()
        return run

    def __len__(self) -> int:
        """Get the number of jobs in the batch."""
        return len(self.jobs)

    def __repr__(self) -> str:
        """String representation of the batch."""
        return (
            f"Batch(jobs={len(self.jobs)}, "
            f"max_parallel_batches={self.config.max_parallel_batches}, "
            f"cost_limit=${self.config.cost_limit_usd or 'None'})"
        )

Builder for batch job configuration.

Provides a fluent interface for configuring batch jobs with sensible defaults and validation. The batch can be configured with cost limits, default parameters, and progress callbacks.

Example:

batch = Batch("./results", max_parallel_batches=10, items_per_batch=10)
    .set_state(file="./state.json", reuse_state=True)
    .set_default_params(model="claude-sonnet-4-20250514", temperature=0.7)
    .add_cost_limit(usd=15.0)
    .add_job(messages=[{"role": "user", "content": "Hello"}])
    .add_job(file="./path/to/file.pdf", prompt="Generate summary of file")

run = batch.run()
Batch( results_dir: str, max_parallel_batches: int = 10, items_per_batch: int = 10, raw_files: Optional[bool] = None)
39    def __init__(self, results_dir: str, max_parallel_batches: int = 10, items_per_batch: int = 10, raw_files: Optional[bool] = None):
40        """Initialize batch configuration.
41        
42        Args:
43            results_dir: Directory to store results
44            max_parallel_batches: Maximum parallel batch requests
45            items_per_batch: Number of jobs per provider batch
46            raw_files: Whether to save debug files (raw responses, JSONL files) from providers (default: True if results_dir is set, False otherwise)
47        """
48        # Auto-determine raw_files based on results_dir if not explicitly set
49        if raw_files is None:
50            raw_files = bool(results_dir and results_dir.strip())
51        
52        self.config = BatchParams(
53            state_file=None,
54            results_dir=results_dir,
55            max_parallel_batches=max_parallel_batches,
56            items_per_batch=items_per_batch,
57            reuse_state=True,
58            raw_files=raw_files
59        )
60        self.jobs: List[Job] = []

Initialize batch configuration.

Args: results_dir: Directory to store results max_parallel_batches: Maximum parallel batch requests items_per_batch: Number of jobs per provider batch raw_files: Whether to save debug files (raw responses, JSONL files) from providers (default: True if results_dir is set, False otherwise)

config
jobs: List[Job]
def set_default_params(self, **kwargs) -> Batch:
62    def set_default_params(self, **kwargs) -> 'Batch':
63        """Set default parameters for all jobs.
64        
65        These defaults will be applied to all jobs unless overridden
66        by job-specific parameters.
67        
68        Args:
69            **kwargs: Default parameters (model, temperature, max_tokens, etc.)
70            
71        Returns:
72            Self for chaining
73            
74        Example:
75            ```python
76            batch.set_default_params(model="claude-3-sonnet", temperature=0.7)
77            ```
78        """
79        # Validate if model is provided
80        if "model" in kwargs:
81            self.config.validate_default_params(kwargs["model"])
82        
83        self.config.default_params.update(kwargs)
84        return self

Set default parameters for all jobs.

These defaults will be applied to all jobs unless overridden by job-specific parameters.

Args: **kwargs: Default parameters (model, temperature, max_tokens, etc.)

Returns: Self for chaining

Example:

batch.set_default_params(model="claude-3-sonnet", temperature=0.7)
def set_state( self, file: Optional[str] = None, reuse_state: bool = True) -> Batch:
 86    def set_state(self, file: Optional[str] = None, reuse_state: bool = True) -> 'Batch':
 87        """Set state file configuration.
 88        
 89        Args:
 90            file: Path to state file for persistence (default: None)
 91            reuse_state: Whether to resume from existing state file (default: True)
 92            
 93        Returns:
 94            Self for chaining
 95            
 96        Example:
 97            ```python
 98            batch.set_state(file="./state.json", reuse_state=True)
 99            ```
100        """
101        self.config.state_file = file
102        self.config.reuse_state = reuse_state
103        return self

Set state file configuration.

Args: file: Path to state file for persistence (default: None) reuse_state: Whether to resume from existing state file (default: True)

Returns: Self for chaining

Example:

batch.set_state(file="./state.json", reuse_state=True)
def add_cost_limit(self, usd: float) -> Batch:
105    def add_cost_limit(self, usd: float) -> 'Batch':
106        """Add cost limit for the batch.
107        
108        The batch will stop accepting new jobs once the cost limit is reached.
109        Active jobs will be allowed to complete.
110        
111        Args:
112            usd: Cost limit in USD
113            
114        Returns:
115            Self for chaining
116            
117        Example:
118            ```python
119            batch.add_cost_limit(usd=50.0)
120            ```
121        """
122        if usd <= 0:
123            raise ValueError("Cost limit must be positive")
124        self.config.cost_limit_usd = usd
125        return self

Add cost limit for the batch.

The batch will stop accepting new jobs once the cost limit is reached. Active jobs will be allowed to complete.

Args: usd: Cost limit in USD

Returns: Self for chaining

Example:

batch.add_cost_limit(usd=50.0)
def raw_files(self, enabled: bool = True) -> Batch:
127    def raw_files(self, enabled: bool = True) -> 'Batch':
128        """Enable or disable saving debug files from providers.
129        
130        When enabled, debug files (raw API responses, JSONL files) will be saved
131        in a 'raw_files' subdirectory within the results directory.
132        This is useful for debugging, auditing, or accessing provider-specific metadata.
133        
134        Args:
135            enabled: Whether to save debug files (default: True)
136            
137        Returns:
138            Self for chaining
139            
140        Example:
141            ```python
142            batch.raw_files(True)
143            ```
144        """
145        self.config.raw_files = enabled
146        return self

Enable or disable saving debug files from providers.

When enabled, debug files (raw API responses, JSONL files) will be saved in a 'raw_files' subdirectory within the results directory. This is useful for debugging, auditing, or accessing provider-specific metadata.

Args: enabled: Whether to save debug files (default: True)

Returns: Self for chaining

Example:

batch.raw_files(True)
def set_verbosity(self, level: str) -> Batch:
148    def set_verbosity(self, level: str) -> 'Batch':
149        """Set logging verbosity level.
150        
151        Args:
152            level: Verbosity level ("debug", "info", "warn", "error")
153            
154        Returns:
155            Self for chaining
156            
157        Example:
158            ```python
159            batch.set_verbosity("error")  # For production
160            batch.set_verbosity("debug")  # For debugging
161            ```
162        """
163        valid_levels = {"debug", "info", "warn", "error"}
164        if level.lower() not in valid_levels:
165            raise ValueError(f"Invalid verbosity level: {level}. Must be one of {valid_levels}")
166        self.config.verbosity = level.lower()
167        return self

Set logging verbosity level.

Args: level: Verbosity level ("debug", "info", "warn", "error")

Returns: Self for chaining

Example:

batch.set_verbosity("error")  # For production
batch.set_verbosity("debug")  # For debugging
def add_time_limit( self, seconds: Optional[float] = None, minutes: Optional[float] = None, hours: Optional[float] = None) -> Batch:
    def add_time_limit(self, seconds: Optional[float] = None, minutes: Optional[float] = None, hours: Optional[float] = None) -> 'Batch':
        """Add a time limit for the entire batch execution.

        When the time limit is reached, all active provider batches are cancelled and
        remaining unprocessed jobs are marked as failed. The batch execution
        completes normally without raising exceptions.

        Args:
            seconds: Time limit in seconds (optional)
            minutes: Time limit in minutes (optional)
            hours: Time limit in hours (optional)

        Returns:
            Self for chaining

        Raises:
            ValueError: If no time units are specified, or if the total time is outside
                the valid range (min: 10 seconds, max: 24 hours)

        Note:
            - Multiple time units can be combined
            - The time limit is checked every second by a background watchdog thread
            - Jobs that exceed the time limit appear in results()["failed"] with a time limit error message
            - No exceptions are raised when the time limit is reached

        Example:
            ```python
            batch.add_time_limit(seconds=30)  # 30 seconds
            batch.add_time_limit(minutes=5)   # 5 minutes
            batch.add_time_limit(hours=2)     # 2 hours
            batch.add_time_limit(hours=1, minutes=30, seconds=15)  # 5415 seconds total
            ```
        """
        time_limit_seconds = 0.0

        if seconds is not None:
            time_limit_seconds += seconds
        if minutes is not None:
            time_limit_seconds += minutes * 60
        if hours is not None:
            time_limit_seconds += hours * 3600

        if time_limit_seconds == 0:
            raise ValueError("Must specify at least one of seconds, minutes, or hours")

        self.config.time_limit_seconds = time_limit_seconds
        return self
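The unit handling above is a straight sum. A standalone sketch of the same arithmetic (`total_seconds` is illustrative, not a library helper):

```python
def total_seconds(seconds=None, minutes=None, hours=None):
    """Mirror the accumulation in add_time_limit: sum every provided unit."""
    total = (seconds or 0.0) + (minutes or 0.0) * 60 + (hours or 0.0) * 3600
    if total == 0:
        raise ValueError("Must specify at least one of seconds, minutes, or hours")
    return total

print(total_seconds(hours=1, minutes=30, seconds=15))  # 5415.0
print(total_seconds(minutes=5))                        # 300.0
```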

    def add_job(
        self,
        messages: Optional[List[Message]] = None,
        file: Optional[Union[str, Path]] = None,
        prompt: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        response_model: Optional[Type[BaseModel]] = None,
        enable_citations: bool = False,
        **kwargs
    ) -> 'Batch':
        """Add a job to the batch.

        Either provide messages OR file+prompt, not both. Parameters not provided
        will use the defaults set via the set_default_params() method.

        Args:
            messages: Chat messages for direct input
            file: File path for file-based input
            prompt: Prompt to use with file input
            model: Model to use (overrides default)
            temperature: Sampling temperature (overrides default)
            max_tokens: Max tokens to generate (overrides default)
            response_model: Pydantic model for structured output
            enable_citations: Whether to extract citations
            **kwargs: Additional parameters

        Returns:
            Self for chaining

        Example:
            ```python
            batch.add_job(
                messages=[{"role": "user", "content": "Hello"}],
                model="gpt-4"
            )
            ```
        """
        # Generate a unique job ID
        job_id = f"job-{uuid.uuid4().hex[:8]}"

        # Merge with defaults
        params = self.config.default_params.copy()

        # Update with provided parameters
        if model is not None:
            params["model"] = model
        if temperature is not None:
            params["temperature"] = temperature
        if max_tokens is not None:
            params["max_tokens"] = max_tokens

        # Add other kwargs
        params.update(kwargs)

        # Ensure a model is provided
        if "model" not in params:
            raise ValueError("Model must be provided either in defaults or job parameters")

        # Validate parameters
        provider = get_provider(params["model"])
        # Extract params without model to avoid a duplicate keyword
        param_subset = {k: v for k, v in params.items() if k != "model"}
        provider.validate_params(params["model"], **param_subset)

        # Convert the file path if given as a string
        if isinstance(file, str):
            file = Path(file)

        # Warn about temporary file paths that may not persist
        if file:
            file_str = str(file)
            if "/tmp/" in file_str or "/var/folders/" in file_str or "temp" in file_str.lower():
                logger = logging.getLogger("batchata")
                logger.debug(f"File path appears to be in a temporary directory: {file}")
                logger.debug("This may cause issues when resuming from state if temp files are cleaned up")

        # Create the job
        job = Job(
            id=job_id,
            messages=messages,
            file=file,
            prompt=prompt,
            response_model=response_model,
            enable_citations=enable_citations,
            **params
        )

        # Validate citation compatibility
        if response_model and enable_citations:
            from ..utils.validation import validate_flat_model
            validate_flat_model(response_model)

        # Validate the job with the provider (includes PDF validation for Anthropic)
        provider.validate_job(job)

        self.jobs.append(job)
        return self
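The merge order in `add_job` is: copy the batch defaults, then overlay any per-job arguments that were actually provided. A standalone sketch with hypothetical parameter values:

```python
# Hypothetical defaults and per-job overrides, mirroring the merge in add_job.
defaults = {"model": "claude-sonnet-4-20250514", "temperature": 0.0}
per_job = {"temperature": 0.7, "max_tokens": 512}  # model not overridden

params = defaults.copy()
params.update({k: v for k, v in per_job.items() if v is not None})

print(params["model"])        # claude-sonnet-4-20250514 (default survives)
print(params["temperature"])  # 0.7 (per-job override wins)
```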

    def run(self, on_progress: Optional[Callable[[Dict, float, Dict], None]] = None, progress_interval: float = 1.0, print_status: bool = False, dry_run: bool = False) -> 'BatchRun':
        """Execute the batch.

        Creates a BatchRun instance and executes the jobs synchronously.

        Args:
            on_progress: Optional progress callback function that receives
                        (stats_dict, elapsed_time_seconds, batch_data)
            progress_interval: Interval in seconds between progress updates (default: 1.0)
            print_status: Whether to show a rich progress display (default: False)
            dry_run: If True, only show the cost estimate without executing (default: False)

        Returns:
            BatchRun instance with completed results

        Raises:
            ValueError: If no jobs have been added
        """
        if not self.jobs:
            raise ValueError("No jobs added to batch")

        # Import here to avoid a circular dependency
        from .batch_run import BatchRun

        # Create and start the run
        run = BatchRun(self.config, self.jobs)

        # Handle dry-run mode
        if dry_run:
            return run.dry_run()

        # Set progress callback - either rich display or custom callback
        if print_status:
            return self._run_with_rich_display(run, progress_interval)
        else:
            return self._run_with_custom_callback(run, on_progress, progress_interval)
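The `on_progress` signature can be exercised without running a batch; a minimal callback, invoked here with hypothetical sample values (real stats dicts come from `BatchRun.status()`):

```python
# Callback shape expected by run(): (stats_dict, elapsed_seconds, batch_data).
def on_progress(stats, elapsed, batch_data):
    # stats["completed"] / stats["total"] follow the keys used in the
    # set_on_progress example; the values below are illustrative only.
    return f"Progress: {stats['completed']}/{stats['total']}, {elapsed:.1f}s"

print(on_progress({"completed": 3, "total": 10}, 12.34, {}))  # Progress: 3/10, 12.3s
```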

class BatchRun:
    """Manages the execution of a batch job synchronously.

    Processes jobs in batches based on the items_per_batch configuration.
    Synchronous execution keeps logging and debugging simple.

    Example:
        ```python
        config = BatchParams(...)
        run = BatchRun(config, jobs)
        run.execute()
        results = run.results()
        ```
    """

    def __init__(self, config: BatchParams, jobs: List[Job]):
        """Initialize the batch run.

        Args:
            config: Batch configuration
            jobs: List of jobs to execute
        """
        self.config = config
        self.jobs = {job.id: job for job in jobs}

        # Set the logging level based on config
        set_log_level(level=config.verbosity.upper())

        # Initialize components
        self.cost_tracker = CostTracker(limit_usd=config.cost_limit_usd)

        # Use a temp file for state if none is provided
        state_file = config.state_file
        if not state_file:
            state_file = create_temp_state_file(config)
            config.reuse_state = False
            logger.info(f"Created temporary state file: {state_file}")

        self.state_manager = StateManager(state_file)

        # State tracking
        self.pending_jobs: List[Job] = []
        self.completed_results: Dict[str, JobResult] = {}  # job_id -> result
        self.failed_jobs: Dict[str, str] = {}  # job_id -> error
        self.cancelled_jobs: Dict[str, str] = {}  # job_id -> reason

        # Batch tracking
        self.total_batches = 0
        self.completed_batches = 0
        self.current_batch_index = 0
        self.current_batch_size = 0

        # Execution control
        self._started = False
        self._start_time: Optional[datetime] = None
        self._time_limit_exceeded = False
        self._progress_callback: Optional[Callable[[Dict, float, Dict], None]] = None
        self._progress_interval: float = 1.0  # Default to 1 second

        # Threading primitives
        self._state_lock = threading.Lock()
        self._shutdown_event = threading.Event()
        self._progress_lock = threading.Lock()
        self._last_progress_update = 0.0

        # Batch tracking for progress display
        self.batch_tracking: Dict[str, Dict] = {}  # batch_id -> batch_info

        # Active batch tracking for cancellation
        self._active_batches: Dict[str, object] = {}  # batch_id -> provider
        self._active_batches_lock = threading.Lock()

        # Results directory
        self.results_dir = Path(config.results_dir)

        # If not reusing state, clear the results directory
        if not config.reuse_state and self.results_dir.exists():
            import shutil
            shutil.rmtree(self.results_dir)

        self.results_dir.mkdir(parents=True, exist_ok=True)

        # Raw files directory (if enabled)
        self.raw_files_dir = None
        if config.raw_files:
            self.raw_files_dir = self.results_dir / "raw_files"
            self.raw_files_dir.mkdir(parents=True, exist_ok=True)

        # Try to resume from saved state
        self._resume_from_state()

    def _resume_from_state(self):
        """Resume from saved state if available."""
        # Check whether we should reuse state
        if not self.config.reuse_state:
            # Clear any existing state and start fresh
            self.state_manager.clear()
            self.pending_jobs = list(self.jobs.values())
            return

        state = self.state_manager.load()
        if state is None:
            # No saved state; use the jobs passed to the constructor
            self.pending_jobs = list(self.jobs.values())
            return

        logger.info("Resuming batch run from saved state")

        # Restore pending jobs
        self.pending_jobs = []
        for job_data in state.pending_jobs:
            job = Job.from_dict(job_data)
            # Check that the file exists (if the job has one)
            if job.file and not job.file.exists():
                logger.error(f"File not found for job {job.id}: {job.file}")
                logger.error("This may happen if files were in temporary directories that were cleaned up")
                self.failed_jobs[job.id] = f"File not found: {job.file}"
            else:
                self.pending_jobs.append(job)

        # Restore completed results from file references
        for result_ref in state.completed_results:
            job_id = result_ref["job_id"]
            file_path = result_ref["file_path"]
            try:
                with open(file_path, 'r') as f:
                    result_data = json.load(f)
                result = JobResult.from_dict(result_data)
                self.completed_results[job_id] = result
            except Exception as e:
                logger.error(f"Failed to load result for {job_id} from {file_path}: {e}")
                # Move to failed jobs if we can't load the result
                self.failed_jobs[job_id] = f"Failed to load result file: {e}"

        # Restore failed jobs
        for job_data in state.failed_jobs:
            self.failed_jobs[job_data["id"]] = job_data.get("error", "Unknown error")

        # Restore cancelled jobs (if they exist in state)
        for job_data in getattr(state, 'cancelled_jobs', []):
            self.cancelled_jobs[job_data["id"]] = job_data.get("reason", "Cancelled")

        # Restore the cost tracker
        self.cost_tracker.used_usd = state.total_cost_usd

        logger.info(
            f"Resumed with {len(self.pending_jobs)} pending, "
            f"{len(self.completed_results)} completed, "
            f"{len(self.failed_jobs)} failed, "
            f"{len(self.cancelled_jobs)} cancelled"
        )

    def to_json(self) -> Dict:
        """Convert the current state to a JSON-serializable dict."""
        return {
            "created_at": datetime.now().isoformat(),
            "pending_jobs": [job.to_dict() for job in self.pending_jobs],
            "completed_results": [
                {"job_id": job_id, "file_path": str(self.results_dir / f"{job_id}.json")}
                for job_id in self.completed_results.keys()
            ],
            "failed_jobs": [
                {
                    "id": job_id,
                    "error": error,
                    "timestamp": datetime.now().isoformat()
                } for job_id, error in self.failed_jobs.items()
            ],
            "cancelled_jobs": [
                {
                    "id": job_id,
                    "reason": reason,
                    "timestamp": datetime.now().isoformat()
                } for job_id, reason in self.cancelled_jobs.items()
            ],
            "total_cost_usd": self.cost_tracker.used_usd,
            "config": {
                "state_file": self.config.state_file,
                "results_dir": self.config.results_dir,
                "max_parallel_batches": self.config.max_parallel_batches,
                "items_per_batch": self.config.items_per_batch,
                "cost_limit_usd": self.config.cost_limit_usd,
                "default_params": self.config.default_params,
                "raw_files": self.config.raw_files
            }
        }

    def execute(self):
        """Execute the batch run synchronously and wait for completion."""
        if self._started:
            raise RuntimeError("Batch run already started")

        self._started = True
        self._start_time = datetime.now()

        # Register a signal handler for graceful shutdown
        def signal_handler(signum, frame):
            logger.warning("Received interrupt signal, shutting down gracefully...")
            self._shutdown_event.set()

        # Store the original handler to restore later
        original_handler = signal.signal(signal.SIGINT, signal_handler)

        try:
            logger.info("Starting batch run")

            # Start the time limit watchdog if configured
            self._start_time_limit_watchdog()

            # Report initial progress
            if self._progress_callback:
                with self._progress_lock:
                    with self._state_lock:
                        stats = self.status()
                        batch_data = dict(self.batch_tracking)
                    self._progress_callback(stats, 0.0, batch_data)
                    self._last_progress_update = time.time()

            # Process all jobs synchronously
            self._process_all_jobs()

            logger.info("Batch run completed")
        finally:
            # Restore the original signal handler
            signal.signal(signal.SIGINT, original_handler)

    def set_on_progress(self, callback: Callable[[Dict, float, Dict], None], interval: float = 1.0) -> 'BatchRun':
        """Set a progress callback for execution monitoring.

        The callback is called periodically with progress statistics
        including completed jobs, total jobs, current cost, etc.

        Args:
            callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data)
                     - stats_dict: Progress statistics dictionary
                     - elapsed_time_seconds: Time elapsed since the batch started (float)
                     - batch_data: Dictionary mapping batch_id to batch information
            interval: Interval in seconds between progress updates (default: 1.0)

        Returns:
            Self for chaining

        Example:
            ```python
            run.set_on_progress(
                lambda stats, time, batch_data: print(
                    f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s"
                )
            )
            ```
        """
        self._progress_callback = callback
        self._progress_interval = interval
        return self

    def _start_time_limit_watchdog(self):
        """Start a background thread that checks the time limit every second."""
        if not self.config.time_limit_seconds:
            return

        def time_limit_watchdog():
            """Check the time limit every second and trigger shutdown if exceeded."""
            while not self._shutdown_event.is_set():
                if self._check_time_limit():
                    logger.warning("Batch execution time limit exceeded")
                    with self._state_lock:
                        self._time_limit_exceeded = True
                    self._shutdown_event.set()
                    break
                time.sleep(1.0)

        # Start the watchdog as a daemon thread
        watchdog_thread = threading.Thread(target=time_limit_watchdog, daemon=True)
        watchdog_thread.start()
        logger.debug(f"Started time limit watchdog thread (time limit: {self.config.time_limit_seconds}s)")
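Reduced to its essentials, the watchdog is a daemon thread that polls elapsed time and sets a shared shutdown event. A standalone sketch with a toy 0.2 s limit and a fast poll interval (the library polls every second):

```python
import threading
import time

shutdown = threading.Event()
start = time.monotonic()
time_limit = 0.2  # toy limit for illustration

def watchdog():
    # Poll elapsed time until the limit is hit, then signal shutdown.
    while not shutdown.is_set():
        if time.monotonic() - start >= time_limit:
            shutdown.set()
            break
        time.sleep(0.05)

threading.Thread(target=watchdog, daemon=True).start()
shutdown.wait(timeout=2.0)  # workers block on the same event
print(shutdown.is_set())  # True
```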
    def _check_time_limit(self) -> bool:
        """Check whether batch execution has exceeded the time limit."""
        if not self.config.time_limit_seconds or not self._start_time:
            return False

        elapsed = (datetime.now() - self._start_time).total_seconds()
        return elapsed >= self.config.time_limit_seconds

    def _process_all_jobs(self):
        """Process all jobs with parallel execution."""
        # Prepare all batches
        batches = self._prepare_batches()
        self.total_batches = len(batches)

        # Process batches in parallel
        with ThreadPoolExecutor(max_workers=self.config.max_parallel_batches) as executor:
            futures = [executor.submit(self._execute_batch_wrapped, provider, batch_jobs)
                      for _, provider, batch_jobs in batches]

            try:
                for future in as_completed(futures):
                    # Stop if a shutdown event is detected (includes time limit)
                    if self._shutdown_event.is_set():
                        break
                    future.result()  # Re-raise any exceptions
            except KeyboardInterrupt:
                self._shutdown_event.set()
                # Cancel remaining futures
                for future in futures:
                    future.cancel()
                raise
            finally:
                # Handle time limit or cancellation - mark remaining jobs appropriately
                with self._state_lock:
                    if self._shutdown_event.is_set():
                        # If the time limit was exceeded, cancel all active batches
                        if self._time_limit_exceeded:
                            self._cancel_all_active_batches()

                        # Mark any unprocessed jobs based on the reason for shutdown
                        for _, _, batch_jobs in batches:
                            for job in batch_jobs:
                                # Skip jobs already processed
                                if (job.id in self.completed_results or
                                    job.id in self.failed_jobs or
                                    job.id in self.cancelled_jobs):
                                    continue

                                # Mark based on shutdown reason
                                if self._time_limit_exceeded:
                                    self.failed_jobs[job.id] = "Time limit exceeded: batch execution time limit exceeded"
                                else:
                                    self.cancelled_jobs[job.id] = "Cancelled by user"

                                if job in self.pending_jobs:
                                    self.pending_jobs.remove(job)

                        # Save state
                        self.state_manager.save(self)

    def _cancel_all_active_batches(self):
        """Cancel all active batches at the provider level."""
        with self._active_batches_lock:
            active_batch_items = list(self._active_batches.items())

        logger.info(f"Cancelling {len(active_batch_items)} active batches due to time limit exceeded")

        # Cancel outside the lock to avoid blocking
        for batch_id, provider in active_batch_items:
            try:
                provider.cancel_batch(batch_id)
                logger.info(f"Cancelled batch {batch_id} due to time limit exceeded")
            except Exception as e:
                logger.warning(f"Failed to cancel batch {batch_id}: {e}")

        # Clear the tracking after cancellation attempts
        with self._active_batches_lock:
            self._active_batches.clear()

    def _execute_batch_wrapped(self, provider, batch_jobs):
        """Thread-safe wrapper for _execute_batch."""
        try:
            result = self._execute_batch(provider, batch_jobs)
            with self._state_lock:
                self._update_batch_results(result)
                # Remove jobs from pending_jobs if specified
                jobs_to_remove = result.get("jobs_to_remove", [])
                for job in jobs_to_remove:
                    if job in self.pending_jobs:
                        self.pending_jobs.remove(job)
        except TimeoutError:
            # Handle time limit exceeded - mark jobs as failed
            with self._state_lock:
                for job in batch_jobs:
                    self.failed_jobs[job.id] = "Time limit exceeded: batch execution time limit exceeded"
                    if job in self.pending_jobs:
                        self.pending_jobs.remove(job)
                self.state_manager.save(self)
            # Don't re-raise, just return
            return
        except KeyboardInterrupt:
            self._shutdown_event.set()
            # Handle user cancellation
            with self._state_lock:
                for job in batch_jobs:
                    self.cancelled_jobs[job.id] = "Cancelled by user"
                    if job in self.pending_jobs:
                        self.pending_jobs.remove(job)
                self.state_manager.save(self)
            raise

    def _group_jobs_by_provider(self) -> Dict[str, List[Job]]:
        """Group jobs by provider."""
        jobs_by_provider = {}

        for job in self.pending_jobs[:]:  # Copy to avoid modification during iteration
            try:
                provider = get_provider(job.model)
                provider_name = provider.__class__.__name__

                if provider_name not in jobs_by_provider:
                    jobs_by_provider[provider_name] = []

                jobs_by_provider[provider_name].append(job)

            except Exception as e:
                logger.error(f"Failed to get provider for job {job.id}: {e}")
                with self._state_lock:
                    self.failed_jobs[job.id] = str(e)
                    self.pending_jobs.remove(job)

        return jobs_by_provider

    def _split_into_batches(self, jobs: List[Job]) -> List[List[Job]]:
        """Split jobs into batches based on items_per_batch."""
        batches = []
        batch_size = self.config.items_per_batch

        for i in range(0, len(jobs), batch_size):
            batch = jobs[i:i + batch_size]
            batches.append(batch)

        return batches
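The slicing in `_split_into_batches` can be seen on plain lists; a standalone sketch:

```python
def split_into_batches(items, batch_size):
    # Same stepping as _split_into_batches: fixed-size chunks,
    # with the final chunk possibly shorter.
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

print(split_into_batches(list(range(7)), 3))  # [[0, 1, 2], [3, 4, 5], [6]]
```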
    def _prepare_batches(self) -> List[Tuple[str, object, List[Job]]]:
        """Prepare all batches as a simple list of (provider_name, provider, jobs)."""
        batches = []
        jobs_by_provider = self._group_jobs_by_provider()

        for provider_name, provider_jobs in jobs_by_provider.items():
            provider = get_provider(provider_jobs[0].model)
            job_batches = self._split_into_batches(provider_jobs)

            for batch_jobs in job_batches:
                batches.append((provider_name, provider, batch_jobs))

                # Pre-populate batch tracking for pending batches
                batch_id = f"pending_{len(self.batch_tracking)}"
                estimated_cost = provider.estimate_cost(batch_jobs)
                self.batch_tracking[batch_id] = {
                    'start_time': None,
                    'status': 'pending',
                    'total': len(batch_jobs),
                    'completed': 0,
                    'cost': 0.0,
                    'estimated_cost': estimated_cost,
                    'provider': provider_name,
                    'jobs': batch_jobs
                }

        return batches

    def _poll_batch_status(self, provider, batch_id: str) -> Tuple[str, Optional[Dict]]:
        """Poll until the batch completes."""
        status, error_details = provider.get_batch_status(batch_id)
        logger.info(f"Initial batch status: {status}")
        poll_count = 0

        # Use the provider-specific polling interval
        provider_polling_interval = provider.get_polling_interval()
        logger.debug(f"Using {provider_polling_interval}s polling interval for {provider.__class__.__name__}")

        while status not in ["complete", "failed"]:
            poll_count += 1
            logger.debug(f"Polling attempt {poll_count}, current status: {status}")

            # Interruptible wait - wakes up immediately if the shutdown event is set (includes time limit)
            if self._shutdown_event.wait(provider_polling_interval):
                # Check whether it was the time limit or a user cancellation
                with self._state_lock:
                    is_time_limit_exceeded = self._time_limit_exceeded

                if is_time_limit_exceeded:
                    logger.info(f"Batch {batch_id} polling interrupted by time limit exceeded")
                    raise TimeoutError("Batch cancelled due to time limit exceeded")
                else:
                    logger.info(f"Batch {batch_id} polling interrupted by user")
                    raise KeyboardInterrupt("Batch cancelled by user")

            status, error_details = provider.get_batch_status(batch_id)

            if self._progress_callback:
                # Rate-limit progress updates and synchronize calls to prevent duplicate printing
                current_time = time.time()
                should_update = current_time - self._last_progress_update >= self._progress_interval

                if should_update:
                    with self._progress_lock:
                        # Double-check the timing inside the lock to avoid a race condition
                        if current_time - self._last_progress_update >= self._progress_interval:
                            with self._state_lock:
                                stats = self.status()
                                elapsed_time = (datetime.now() - self._start_time).total_seconds()
                                batch_data = dict(self.batch_tracking)
                            self._progress_callback(stats, elapsed_time, batch_data)
                            self._last_progress_update = current_time

            elapsed_seconds = poll_count * provider_polling_interval

        return status, error_details
524    def _update_batch_results(self, batch_result: Dict):
525        """Update state from batch results."""
526        results = batch_result.get("results", [])
527        failed = batch_result.get("failed", {})
528                
529        # Update completed results
530        for result in results:
531            if result.is_success:
532                self.completed_results[result.job_id] = result
533                self._save_result_to_file(result)
534                logger.info(f"✓ Job {result.job_id} completed successfully")
535            else:
536                error_message = result.error or "Unknown error"
537                self.failed_jobs[result.job_id] = error_message
538                self._save_result_to_file(result)
539                logger.error(f"✗ Job {result.job_id} failed: {result.error}")
540            
541            # Remove completed/failed job from pending
542            self.pending_jobs = [job for job in self.pending_jobs if job.id != result.job_id]
543        
544        # Update failed jobs
545        for job_id, error in failed.items():
546            self.failed_jobs[job_id] = error
547            # Remove failed job from pending
548            self.pending_jobs = [job for job in self.pending_jobs if job.id != job_id]
549            logger.error(f"✗ Job {job_id} failed: {error}")
550        
551        # Update batch tracking
552        self.completed_batches += 1
553        
554        # Save state
555        self.state_manager.save(self)
556    
557    def _execute_batch(self, provider, batch_jobs: List[Job]) -> Dict:
558        """Execute one batch, return results dict with jobs/costs/errors."""
559        if not batch_jobs:
560            return {"results": [], "failed": {}, "cost": 0.0}
561        
562        # Reserve cost limit
563        logger.info(f"Estimating cost for batch of {len(batch_jobs)} jobs...")
564        estimated_cost = provider.estimate_cost(batch_jobs)
565        remaining = self.cost_tracker.remaining()
566        remaining_str = f"${remaining:.4f}" if remaining is not None else "unlimited"
567        logger.info(f"Total estimated cost: ${estimated_cost:.4f}, remaining budget: {remaining_str}")
568        
569        if not self.cost_tracker.reserve_cost(estimated_cost):
570            logger.warning(f"Cost limit would be exceeded, skipping batch of {len(batch_jobs)} jobs")
571            failed = {}
572            for job in batch_jobs:
573                failed[job.id] = "Cost limit exceeded"
574            return {"results": [], "failed": failed, "cost": 0.0, "jobs_to_remove": list(batch_jobs)}
575        
576        batch_id = None
577        job_mapping = None
578        try:
579            # Create batch
580            logger.info(f"Creating batch with {len(batch_jobs)} jobs...")
581            raw_files_path = str(self.raw_files_dir) if self.raw_files_dir else None
582            batch_id, job_mapping = provider.create_batch(batch_jobs, raw_files_path)
583            
584            # Track active batch for cancellation
585            with self._active_batches_lock:
586                self._active_batches[batch_id] = provider
587            
588            # Track batch creation
589            with self._state_lock:
590                # Remove pending entry if it exists
591                pending_keys = [k for k in self.batch_tracking.keys() if k.startswith('pending_')]
592                for pending_key in pending_keys:
593                    if self.batch_tracking[pending_key]['jobs'] == batch_jobs:
594                        del self.batch_tracking[pending_key]
595                        break
596                
597                # Add actual batch tracking
598                self.batch_tracking[batch_id] = {
599                    'start_time': datetime.now(),
600                    'status': 'running',
601                    'total': len(batch_jobs),
602                    'completed': 0,
603                    'cost': 0.0,
604                    'estimated_cost': estimated_cost,
605                    'provider': provider.__class__.__name__,
606                    'jobs': batch_jobs
607                }
608            
609            # Poll for completion
610            logger.info(f"Polling for batch {batch_id} completion...")
611            status, error_details = self._poll_batch_status(provider, batch_id)
612            
613            if status == "failed":
614                if error_details:
615                    logger.error(f"Batch {batch_id} failed with details: {error_details}")
616                else:
617                    logger.error(f"Batch {batch_id} failed")
618                
619                # Save error details if configured
620                if self.raw_files_dir and error_details:
621                    self._save_batch_error_details(batch_id, error_details)
622                
623                # Continue to get individual results - some jobs might have succeeded
624            
625            # Get results
626            logger.info(f"Getting results for batch {batch_id}")
627            raw_files_path = str(self.raw_files_dir) if self.raw_files_dir else None
628            results = provider.get_batch_results(batch_id, job_mapping, raw_files_path)
629            
630            # Calculate actual cost and adjust reservation
631            actual_cost = sum(r.cost_usd for r in results)
632            self.cost_tracker.adjust_reserved_cost(estimated_cost, actual_cost)
633            
634            # Update batch tracking for completion
635            success_count = len([r for r in results if r.is_success])
636            failed_count = len([r for r in results if not r.is_success])
637            batch_status = 'complete' if failed_count == 0 else 'failed'
638            
639            with self._state_lock:
640                if batch_id in self.batch_tracking:
641                    self.batch_tracking[batch_id]['status'] = batch_status
642                    self.batch_tracking[batch_id]['completed'] = success_count
643                    self.batch_tracking[batch_id]['failed'] = failed_count
644                    self.batch_tracking[batch_id]['cost'] = actual_cost
645                    self.batch_tracking[batch_id]['completion_time'] = datetime.now()
646                    if batch_status == 'failed' and failed_count > 0:
647                        # Use the first job's error as the batch error summary
648                        first_error = next((r.error for r in results if not r.is_success), 'Some jobs failed')
649                        self.batch_tracking[batch_id]['error'] = first_error
650            
651            # Remove from active batches tracking
652            with self._active_batches_lock:
653                self._active_batches.pop(batch_id, None)
654            
655            status_symbol = "✓" if batch_status == 'complete' else "⚠"
656            logger.info(
657                f"{status_symbol} Batch {batch_id} completed: "
658                f"{success_count} success, "
659                f"{failed_count} failed, "
660                f"cost: ${actual_cost:.6f}"
661            )
662            
663            return {"results": results, "failed": {}, "cost": actual_cost, "jobs_to_remove": list(batch_jobs)}
664            
665        except TimeoutError:
666            logger.info(f"Time limit exceeded for batch{f' {batch_id}' if batch_id else ''}")
667            if batch_id:
668                # Update batch tracking for time limit exceeded
669                with self._state_lock:
670                    if batch_id in self.batch_tracking:
671                        self.batch_tracking[batch_id]['status'] = 'failed'
672                        self.batch_tracking[batch_id]['error'] = 'Time limit exceeded: batch execution time limit exceeded'
673                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
674                # NOTE: Don't remove from _active_batches - let centralized cancellation handle it
675            # Release the reservation since batch exceeded time limit
676            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
677            # Re-raise to be handled by wrapper
678            raise
679            
680        except KeyboardInterrupt:
681            logger.warning(f"\nCancelling batch{f' {batch_id}' if batch_id else ''}...")
682            if batch_id:
683                # Update batch tracking for cancellation
684                with self._state_lock:
685                    if batch_id in self.batch_tracking:
686                        self.batch_tracking[batch_id]['status'] = 'cancelled'
687                        self.batch_tracking[batch_id]['error'] = 'Cancelled by user'
688                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
689                # Remove from active batches tracking
690                with self._active_batches_lock:
691                    self._active_batches.pop(batch_id, None)
692            # Release the reservation since batch was cancelled
693            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
694            # Handle cancellation in the wrapper with proper locking
695            raise
696            
697        except Exception as e:
698            logger.error(f"✗ Batch execution failed: {e}")
699            # Update batch tracking for exception
700            if batch_id:
701                with self._state_lock:
702                    if batch_id in self.batch_tracking:
703                        self.batch_tracking[batch_id]['status'] = 'failed'
704                        self.batch_tracking[batch_id]['error'] = str(e)
705                        self.batch_tracking[batch_id]['completion_time'] = datetime.now()
706                # Remove from active batches tracking
707                with self._active_batches_lock:
708                    self._active_batches.pop(batch_id, None)
709            # Release the reservation since batch failed
710            self.cost_tracker.adjust_reserved_cost(estimated_cost, 0.0)
711            failed = {}
712            for job in batch_jobs:
713                failed[job.id] = str(e)
714            return {"results": [], "failed": failed, "cost": 0.0, "jobs_to_remove": list(batch_jobs)}
715    
716    
717    def _save_result_to_file(self, result: JobResult):
718        """Save individual result to file."""
719        result_file = self.results_dir / f"{result.job_id}.json"
720        
721        try:
722            with open(result_file, 'w') as f:
723                json.dump(result.to_dict(), f, indent=2)
724        except Exception as e:
725            logger.error(f"Failed to save result for {result.job_id}: {e}")
726    
727    def _save_batch_error_details(self, batch_id: str, error_details: Dict):
728        """Save batch error details to debug files directory."""
729        try:
730            error_file = self.raw_files_dir / f"batch_{batch_id}_error.json"
731            with open(error_file, 'w') as f:
732                json.dump({
733                    "batch_id": batch_id,
734                    "timestamp": datetime.now().isoformat(),
735                    "error_details": error_details
736                }, f, indent=2)
737            logger.info(f"Saved batch error details to {error_file}")
738        except Exception as e:
739            logger.error(f"Failed to save batch error details: {e}")
740    
741    @property
742    def is_complete(self) -> bool:
743        """Whether all jobs are complete."""
744        total_jobs = len(self.jobs)
745        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
746        return len(self.pending_jobs) == 0 and completed_count == total_jobs
747
748    
749    def status(self, print_status: bool = False) -> Dict:
750        """Get current execution statistics."""
751        total_jobs = len(self.jobs)
752        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
753        remaining_count = total_jobs - completed_count
754        
755        stats = {
756            "total": total_jobs,
757            "pending": remaining_count,
758            "active": 0,  # Always 0 for synchronous execution
759            "completed": len(self.completed_results),
760            "failed": len(self.failed_jobs),
761            "cancelled": len(self.cancelled_jobs),
762            "cost_usd": self.cost_tracker.used_usd,
763            "cost_limit_usd": self.cost_tracker.limit_usd,
764            "is_complete": self.is_complete,
765            "batches_total": self.total_batches,
766            "batches_completed": self.completed_batches,
767            "batches_pending": self.total_batches - self.completed_batches,
768            "current_batch_index": self.current_batch_index,
769            "current_batch_size": self.current_batch_size,
770            "items_per_batch": self.config.items_per_batch
771        }
772        
773        if print_status:
774            logger.info("\nBatch Run Status:")
775            logger.info(f"  Total jobs: {stats['total']}")
776            logger.info(f"  Pending: {stats['pending']}")
777            logger.info(f"  Active: {stats['active']}")
778            logger.info(f"  Completed: {stats['completed']}")
779            logger.info(f"  Failed: {stats['failed']}")
780            logger.info(f"  Cancelled: {stats['cancelled']}")
781            logger.info(f"  Cost: ${stats['cost_usd']:.6f}")
782            if stats['cost_limit_usd']:
783                logger.info(f"  Cost limit: ${stats['cost_limit_usd']:.2f}")
784            logger.info(f"  Complete: {stats['is_complete']}")
785        
786        return stats
787    
788    def results(self) -> Dict[str, List[JobResult]]:
789        """Get all results organized by status.
790        
791        Returns:
792            {
793                "completed": [JobResult],
794                "failed": [JobResult],
795                "cancelled": [JobResult]
796            }
797        """
798        return {
799            "completed": list(self.completed_results.values()),
800            "failed": self._create_failed_results(),
801            "cancelled": self._create_cancelled_results()
802        }
803    
804    def get_failed_jobs(self) -> Dict[str, str]:
805        """Get failed jobs with error messages.
806        
807        Note: This method is deprecated. Use results()['failed'] instead.
808        """
809        return dict(self.failed_jobs)
810    
811    def _create_failed_results(self) -> List[JobResult]:
812        """Convert failed jobs to JobResult objects."""
813        failed_results = []
814        for job_id, error_msg in self.failed_jobs.items():
815            failed_results.append(JobResult(
816                job_id=job_id,
817                raw_response=None,
818                parsed_response=None,
819                error=error_msg,
820                cost_usd=0.0,
821                input_tokens=0,
822                output_tokens=0
823            ))
824        return failed_results
825    
826    def _create_cancelled_results(self) -> List[JobResult]:
827        """Convert cancelled jobs to JobResult objects."""
828        cancelled_results = []
829        for job_id, reason in self.cancelled_jobs.items():
830            cancelled_results.append(JobResult(
831                job_id=job_id,
832                raw_response=None,
833                parsed_response=None,
834                error=reason,
835                cost_usd=0.0,
836                input_tokens=0,
837                output_tokens=0
838            ))
839        return cancelled_results
840    
841    def shutdown(self):
842        """Shutdown (no-op for synchronous execution)."""
843        pass
844    
845    def dry_run(self) -> 'BatchRun':
846        """Perform a dry run - show cost estimation and job details without executing.
847        
848        Returns:
849            Self for chaining (doesn't actually execute jobs)
850        """
851        logger.info("=== DRY RUN MODE ===")
852        logger.info("This will show cost estimates without executing jobs")
853        
854        # Load existing state if reuse_state=True
855        if self.config.reuse_state:
856            self.state_manager.load_state(self)
857        
858        # Filter out completed jobs from previous runs
859        self.pending_jobs = [job for job in self.jobs.values() if job.id not in self.completed_results]
860        
861        if not self.pending_jobs:
862            logger.info("No pending jobs to analyze (all jobs already completed)")
863            return self
864        
865        logger.info(f"Analyzing {len(self.pending_jobs)} pending jobs...")
866        
867        # Group jobs by provider and analyze costs
868        provider_groups = self._group_jobs_by_provider()
869        total_estimated_cost = 0.0
870        
871        logger.info(f"\nJob breakdown:")
872        for provider_name, jobs in provider_groups.items():
873            provider = get_provider(jobs[0].model)
874            logger.info(f"\n{provider_name} ({len(jobs)} jobs):")
875            
876            job_batches = [jobs[i:i + self.config.items_per_batch] 
877                          for i in range(0, len(jobs), self.config.items_per_batch)]
878            
879            for batch_idx, batch_jobs in enumerate(job_batches, 1):
880                estimated_cost = provider.estimate_cost(batch_jobs)
881                total_estimated_cost += estimated_cost
882                
883                logger.info(f"  Batch {batch_idx}: {len(batch_jobs)} jobs, estimated cost: ${estimated_cost:.4f}")
884                for job in batch_jobs:
885                    if job.file:
886                        logger.info(f"    - {job.id}: {job.file.name} (citations: {job.enable_citations})")
887                    else:
888                        logger.info(f"    - {job.id}: direct messages (citations: {job.enable_citations})")
889        
890        # Show cost summary
891        logger.info(f"\n=== COST SUMMARY ===")
892        logger.info(f"Total estimated cost: ${total_estimated_cost:.4f}")
893        
894        if self.config.cost_limit_usd:
895            logger.info(f"Cost limit: ${self.config.cost_limit_usd:.2f}")
896            if total_estimated_cost > self.config.cost_limit_usd:
897                excess = total_estimated_cost - self.config.cost_limit_usd
898                logger.warning(f"⚠️ Estimated cost exceeds limit by ${excess:.4f}")
899            else:
900                remaining = self.config.cost_limit_usd - total_estimated_cost
901                logger.info(f"✅ Within cost limit (${remaining:.4f} remaining)")
902        else:
903            logger.info("No cost limit set")
904        
905        # Show execution plan
906        logger.info(f"\n=== EXECUTION PLAN ===")
907        total_batches = sum(
908            len(jobs) // self.config.items_per_batch + (1 if len(jobs) % self.config.items_per_batch else 0)
909            for jobs in provider_groups.values()
910        )
911        logger.info(f"Total batches to process: {total_batches}")
912        logger.info(f"Max parallel batches: {self.config.max_parallel_batches}")
913        logger.info(f"Items per batch: {self.config.items_per_batch}")
914        logger.info(f"Results directory: {self.config.results_dir}")
915        
916        logger.info("\n=== DRY RUN COMPLETE ===")
917        logger.info("To execute for real, call execute() instead of dry_run()")
918        
919        return self

Manages the execution of a batch job synchronously.

Processes jobs in batches based on the items_per_batch configuration. Execution is synchronous, which keeps logging and debugging straightforward.

Example:

config = BatchParams(...)
run = BatchRun(config, jobs)
run.execute()
results = run.results()
BatchRun(config: batchata.core.batch_params.BatchParams, jobs: List[Job])
 40    def __init__(self, config: BatchParams, jobs: List[Job]):
 41        """Initialize batch run.
 42        
 43        Args:
 44            config: Batch configuration
 45            jobs: List of jobs to execute
 46        """
 47        self.config = config
 48        self.jobs = {job.id: job for job in jobs}
 49        
 50        # Set logging level based on config
 51        set_log_level(level=config.verbosity.upper())
 52        
 53        # Initialize components
 54        self.cost_tracker = CostTracker(limit_usd=config.cost_limit_usd)
 55        
 56        # Use temp file for state if not provided
 57        state_file = config.state_file
 58        if not state_file:
 59            state_file = create_temp_state_file(config)
 60            config.reuse_state = False
 61            logger.info(f"Created temporary state file: {state_file}")
 62        
 63        self.state_manager = StateManager(state_file)
 64        
 65        # State tracking
 66        self.pending_jobs: List[Job] = []
 67        self.completed_results: Dict[str, JobResult] = {}  # job_id -> result
 68        self.failed_jobs: Dict[str, str] = {}  # job_id -> error
 69        self.cancelled_jobs: Dict[str, str] = {}  # job_id -> reason
 70        
 71        # Batch tracking
 72        self.total_batches = 0
 73        self.completed_batches = 0
 74        self.current_batch_index = 0
 75        self.current_batch_size = 0
 76        
 77        # Execution control
 78        self._started = False
 79        self._start_time: Optional[datetime] = None
 80        self._time_limit_exceeded = False
 81        self._progress_callback: Optional[Callable[[Dict, float, Dict], None]] = None
 82        self._progress_interval: float = 1.0  # Default to 1 second
 83        
 84        # Threading primitives
 85        self._state_lock = threading.Lock()
 86        self._shutdown_event = threading.Event()
 87        self._progress_lock = threading.Lock()
 88        self._last_progress_update = 0.0
 89        
 90        # Batch tracking for progress display
 91        self.batch_tracking: Dict[str, Dict] = {}  # batch_id -> batch_info
 92        
 93        # Active batch tracking for cancellation
 94        self._active_batches: Dict[str, object] = {}  # batch_id -> provider
 95        self._active_batches_lock = threading.Lock()
 96        
 97        # Results directory
 98        self.results_dir = Path(config.results_dir)
 99        
100        # If not reusing state, clear the results directory
101        if not config.reuse_state and self.results_dir.exists():
102            import shutil
103            shutil.rmtree(self.results_dir)
104        
105        self.results_dir.mkdir(parents=True, exist_ok=True)
106        
107        # Raw files directory (if enabled)
108        self.raw_files_dir = None
109        if config.raw_files:
110            self.raw_files_dir = self.results_dir / "raw_files"
111            self.raw_files_dir.mkdir(parents=True, exist_ok=True)
112        
113        # Try to resume from saved state
114        self._resume_from_state()

Initialize batch run.

Args:
    config: Batch configuration
    jobs: List of jobs to execute

config: Batch configuration (BatchParams)
jobs: Dict[str, Job] mapping job id to Job
cost_tracker: CostTracker enforcing the configured cost limit
state_manager: StateManager used to persist and resume run state
pending_jobs: List[Job] of jobs not yet processed
completed_results: Dict[str, JobResult] mapping job_id to result
failed_jobs: Dict[str, str] mapping job_id to error message
cancelled_jobs: Dict[str, str] mapping job_id to cancellation reason
total_batches: total number of batches to process
completed_batches: number of batches finished so far
current_batch_index: index of the batch currently executing
current_batch_size: number of jobs in the current batch
batch_tracking: Dict[str, Dict] mapping batch_id to batch info for progress display
results_dir: directory where per-job result files are written
raw_files_dir: directory for raw provider files (None unless raw_files is enabled)
def to_json(self) -> Dict:
178    def to_json(self) -> Dict:
179        """Convert current state to JSON-serializable dict."""
180        return {
181            "created_at": datetime.now().isoformat(),
182            "pending_jobs": [job.to_dict() for job in self.pending_jobs],
183            "completed_results": [
184                {"job_id": job_id, "file_path": str(self.results_dir / f"{job_id}.json")}
185                for job_id in self.completed_results.keys()
186            ],
187            "failed_jobs": [
188                {
189                    "id": job_id, 
190                    "error": error,
191                    "timestamp": datetime.now().isoformat()
192                } for job_id, error in self.failed_jobs.items()
193            ],
194            "cancelled_jobs": [
195                {
196                    "id": job_id, 
197                    "reason": reason,
198                    "timestamp": datetime.now().isoformat()
199                } for job_id, reason in self.cancelled_jobs.items()
200            ],
201            "total_cost_usd": self.cost_tracker.used_usd,
202            "config": {
203                "state_file": self.config.state_file,
204                "results_dir": self.config.results_dir,
205                "max_parallel_batches": self.config.max_parallel_batches,
206                "items_per_batch": self.config.items_per_batch,
207                "cost_limit_usd": self.config.cost_limit_usd,
208                "default_params": self.config.default_params,
209                "raw_files": self.config.raw_files
210            }
211        }

Convert current state to JSON-serializable dict.
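Because every value in the returned dict is JSON-native, the snapshot survives a plain `json.dump`/`json.load` round trip. A minimal sketch with hypothetical placeholder values, mirroring the top-level keys emitted in the source above:

```python
import json
from datetime import datetime

# Hypothetical snapshot mirroring the keys to_json() emits (see the source above).
state = {
    "created_at": datetime(2024, 1, 1).isoformat(),
    "pending_jobs": [],
    "completed_results": [{"job_id": "job-1", "file_path": "./results/job-1.json"}],
    "failed_jobs": [],
    "cancelled_jobs": [],
    "total_cost_usd": 0.0123,
    "config": {"results_dir": "./results", "items_per_batch": 10},
}

# JSON-serializable throughout, so a dump/load round trip is lossless.
restored = json.loads(json.dumps(state))
print(restored["completed_results"][0]["job_id"])
```

This is what lets StateManager persist the run to a state file and resume it later.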

def execute(self):
213    def execute(self):
214        """Execute synchronous batch run and wait for completion."""
215        if self._started:
216            raise RuntimeError("Batch run already started")
217        
218        self._started = True
219        self._start_time = datetime.now()
220        
221        # Register signal handler for graceful shutdown
222        def signal_handler(signum, frame):
223            logger.warning("Received interrupt signal, shutting down gracefully...")
224            self._shutdown_event.set()
225        
226        # Store original handler to restore later
227        original_handler = signal.signal(signal.SIGINT, signal_handler)
228        
229        try:
230            logger.info("Starting batch run")
231            
232            # Start time limit watchdog if configured
233            self._start_time_limit_watchdog()
234            
235            # Call initial progress
236            if self._progress_callback:
237                with self._progress_lock:
238                    with self._state_lock:
239                        stats = self.status()
240                        batch_data = dict(self.batch_tracking)
241                    self._progress_callback(stats, 0.0, batch_data)
242                    self._last_progress_update = time.time()
243            
244            # Process all jobs synchronously
245            self._process_all_jobs()
246            
247            logger.info("Batch run completed")
248        finally:
249            # Restore original signal handler
250            signal.signal(signal.SIGINT, original_handler)

Execute synchronous batch run and wait for completion.
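Note how execute() installs a SIGINT handler that sets a `threading.Event`, and the polling loop waits on that event rather than sleeping: `Event.wait(timeout)` returns immediately (with `True`) when the event is set, so cancellation does not have to wait out the polling interval. A minimal sketch of that pattern in isolation (not the library code itself):

```python
import threading
import time

shutdown = threading.Event()

def interruptible_poll(interval: float, max_polls: int) -> str:
    """Wait `interval` seconds per poll, but wake immediately if `shutdown` is set."""
    for _ in range(max_polls):
        if shutdown.wait(interval):  # returns True as soon as the event is set
            return "cancelled"
        # ... a real loop would check provider.get_batch_status(batch_id) here ...
    return "timed_out"

# Simulate a cancellation arriving from another thread (e.g. a SIGINT handler).
threading.Timer(0.05, shutdown.set).start()
start = time.monotonic()
result = interruptible_poll(interval=5.0, max_polls=3)
elapsed = time.monotonic() - start
print(result, elapsed < 1.0)
```

Even with a 5-second polling interval, the loop exits within a fraction of a second of the event being set.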

def set_on_progress(self, callback: Callable[[Dict, float, Dict], None], interval: float = 1.0) -> BatchRun:
252    def set_on_progress(self, callback: Callable[[Dict, float, Dict], None], interval: float = 1.0) -> 'BatchRun':
253        """Set progress callback for execution monitoring.
254        
255        The callback will be called periodically with progress statistics
256        including completed jobs, total jobs, current cost, etc.
257        
258        Args:
259            callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data)
260                     - stats_dict: Progress statistics dictionary
261                     - elapsed_time_seconds: Time elapsed since batch started (float)
262                     - batch_data: Dictionary mapping batch_id to batch information
263            interval: Interval in seconds between progress updates (default: 1.0)
264            
265        Returns:
266            Self for chaining
267            
268        Example:
269            ```python
270            run.set_on_progress(
271                lambda stats, time, batch_data: print(
272                    f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s"
273                )
274            )
275            ```
276        """
277        self._progress_callback = callback
278        self._progress_interval = interval
279        return self

Set progress callback for execution monitoring.

The callback will be called periodically with progress statistics including completed jobs, total jobs, current cost, etc.

Args:
    callback: Function that receives (stats_dict, elapsed_time_seconds, batch_data)
        - stats_dict: Progress statistics dictionary
        - elapsed_time_seconds: Time elapsed since batch started (float)
        - batch_data: Dictionary mapping batch_id to batch information
    interval: Interval in seconds between progress updates (default: 1.0)

Returns: Self for chaining

Example:

run.set_on_progress(
    lambda stats, time, batch_data: print(
        f"Progress: {stats['completed']}/{stats['total']}, {time:.1f}s"
    )
)
is_complete: bool
741    @property
742    def is_complete(self) -> bool:
743        """Whether all jobs are complete."""
744        total_jobs = len(self.jobs)
745        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
746        return len(self.pending_jobs) == 0 and completed_count == total_jobs

Whether all jobs are complete.

def status(self, print_status: bool = False) -> Dict:
749    def status(self, print_status: bool = False) -> Dict:
750        """Get current execution statistics."""
751        total_jobs = len(self.jobs)
752        completed_count = len(self.completed_results) + len(self.failed_jobs) + len(self.cancelled_jobs)
753        remaining_count = total_jobs - completed_count
754        
755        stats = {
756            "total": total_jobs,
757            "pending": remaining_count,
758            "active": 0,  # Always 0 for synchronous execution
759            "completed": len(self.completed_results),
760            "failed": len(self.failed_jobs),
761            "cancelled": len(self.cancelled_jobs),
762            "cost_usd": self.cost_tracker.used_usd,
763            "cost_limit_usd": self.cost_tracker.limit_usd,
764            "is_complete": self.is_complete,
765            "batches_total": self.total_batches,
766            "batches_completed": self.completed_batches,
767            "batches_pending": self.total_batches - self.completed_batches,
768            "current_batch_index": self.current_batch_index,
769            "current_batch_size": self.current_batch_size,
770            "items_per_batch": self.config.items_per_batch
771        }
772        
773        if print_status:
774            logger.info("\nBatch Run Status:")
775            logger.info(f"  Total jobs: {stats['total']}")
776            logger.info(f"  Pending: {stats['pending']}")
777            logger.info(f"  Active: {stats['active']}")
778            logger.info(f"  Completed: {stats['completed']}")
779            logger.info(f"  Failed: {stats['failed']}")
780            logger.info(f"  Cancelled: {stats['cancelled']}")
781            logger.info(f"  Cost: ${stats['cost_usd']:.6f}")
782            if stats['cost_limit_usd']:
783                logger.info(f"  Cost limit: ${stats['cost_limit_usd']:.2f}")
784            logger.info(f"  Complete: {stats['is_complete']}")
785        
786        return stats

Get current execution statistics.
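The stats dict is plain data, so derived metrics are easy to compute in a progress callback or monitoring loop. A sketch using a hypothetical stats dict with the keys documented above:

```python
# Hypothetical stats dict with the keys status() is documented to return.
stats = {
    "total": 10, "pending": 3, "active": 0,
    "completed": 5, "failed": 1, "cancelled": 1,
    "cost_usd": 0.42, "cost_limit_usd": 5.0,
    "is_complete": False,
    "batches_total": 2, "batches_completed": 1, "batches_pending": 1,
}

# Jobs in any terminal state count toward progress.
done = stats["completed"] + stats["failed"] + stats["cancelled"]
progress_pct = 100.0 * done / stats["total"]
budget_left = stats["cost_limit_usd"] - stats["cost_usd"]
print(f"{progress_pct:.0f}% done, ${budget_left:.2f} of budget left")
```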

def results(self) -> Dict[str, List[JobResult]]:
788    def results(self) -> Dict[str, List[JobResult]]:
789        """Get all results organized by status.
790        
791        Returns:
792            {
793                "completed": [JobResult],
794                "failed": [JobResult],
795                "cancelled": [JobResult]
796            }
797        """
798        return {
799            "completed": list(self.completed_results.values()),
800            "failed": self._create_failed_results(),
801            "cancelled": self._create_cancelled_results()
802        }

Get all results organized by status.

Returns: { "completed": [JobResult], "failed": [JobResult], "cancelled": [JobResult] }
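The dict returned by `results()` can be consumed directly. Below is a self-contained sketch using a minimal stand-in for `JobResult` (only the fields used here), showing how results are typically split by status:

```python
from dataclasses import dataclass
from typing import Optional

# Minimal stand-in for batchata's JobResult, mirroring only the fields used below.
@dataclass
class JobResult:
    job_id: str
    raw_response: Optional[str] = None
    error: Optional[str] = None

# Shape returned by run.results(): status -> list of JobResult.
results = {
    "completed": [JobResult("job-1", raw_response="A summary.")],
    "failed": [JobResult("job-2", error="rate limited")],
    "cancelled": [],
}

# Index successful responses and failure reasons by job id.
summaries = {r.job_id: r.raw_response for r in results["completed"]}
errors = {r.job_id: r.error for r in results["failed"]}
```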

def get_failed_jobs(self) -> Dict[str, str]:
804    def get_failed_jobs(self) -> Dict[str, str]:
805        """Get failed jobs with error messages.
806        
807        Note: This method is deprecated. Use results()['failed'] instead.
808        """
809        return dict(self.failed_jobs)

Get failed jobs with error messages.

Note: This method is deprecated. Use results()['failed'] instead.

def shutdown(self):
841    def shutdown(self):
842        """Shutdown (no-op for synchronous execution)."""
843        pass

Shutdown (no-op for synchronous execution).

def dry_run(self) -> BatchRun:
845    def dry_run(self) -> 'BatchRun':
846        """Perform a dry run - show cost estimation and job details without executing.
847        
848        Returns:
849            Self for chaining (doesn't actually execute jobs)
850        """
851        logger.info("=== DRY RUN MODE ===")
852        logger.info("This will show cost estimates without executing jobs")
853        
854        # Load existing state if reuse_state=True
855        if self.config.reuse_state:
856            self.state_manager.load_state(self)
857        
858        # Filter out completed jobs from previous runs
859        self.pending_jobs = [job for job in self.jobs.values() if job.id not in self.completed_results]
860        
861        if not self.pending_jobs:
862            logger.info("No pending jobs to analyze (all jobs already completed)")
863            return self
864        
865        logger.info(f"Analyzing {len(self.pending_jobs)} pending jobs...")
866        
867        # Group jobs by provider and analyze costs
868        provider_groups = self._group_jobs_by_provider()
869        total_estimated_cost = 0.0
870        
871        logger.info(f"\nJob breakdown:")
872        for provider_name, jobs in provider_groups.items():
873            provider = get_provider(jobs[0].model)
874            logger.info(f"\n{provider_name} ({len(jobs)} jobs):")
875            
876            job_batches = [jobs[i:i + self.config.items_per_batch] 
877                          for i in range(0, len(jobs), self.config.items_per_batch)]
878            
879            for batch_idx, batch_jobs in enumerate(job_batches, 1):
880                estimated_cost = provider.estimate_cost(batch_jobs)
881                total_estimated_cost += estimated_cost
882                
883                logger.info(f"  Batch {batch_idx}: {len(batch_jobs)} jobs, estimated cost: ${estimated_cost:.4f}")
884                for job in batch_jobs:
885                    if job.file:
886                        logger.info(f"    - {job.id}: {job.file.name} (citations: {job.enable_citations})")
887                    else:
888                        logger.info(f"    - {job.id}: direct messages (citations: {job.enable_citations})")
889        
890        # Show cost summary
891        logger.info(f"\n=== COST SUMMARY ===")
892        logger.info(f"Total estimated cost: ${total_estimated_cost:.4f}")
893        
894        if self.config.cost_limit_usd:
895            logger.info(f"Cost limit: ${self.config.cost_limit_usd:.2f}")
896            if total_estimated_cost > self.config.cost_limit_usd:
897                excess = total_estimated_cost - self.config.cost_limit_usd
898                logger.warning(f"⚠️ Estimated cost exceeds limit by ${excess:.4f}")
899            else:
900                remaining = self.config.cost_limit_usd - total_estimated_cost
901                logger.info(f"✅ Within cost limit (${remaining:.4f} remaining)")
902        else:
903            logger.info("No cost limit set")
904        
905        # Show execution plan
906        logger.info(f"\n=== EXECUTION PLAN ===")
907        total_batches = sum(
908            len(jobs) // self.config.items_per_batch + (1 if len(jobs) % self.config.items_per_batch else 0)
909            for jobs in provider_groups.values()
910        )
911        logger.info(f"Total batches to process: {total_batches}")
912        logger.info(f"Max parallel batches: {self.config.max_parallel_batches}")
913        logger.info(f"Items per batch: {self.config.items_per_batch}")
914        logger.info(f"Results directory: {self.config.results_dir}")
915        
916        logger.info("\n=== DRY RUN COMPLETE ===")
917        logger.info("To execute for real, call run() instead of dry_run()")
918        
919        return self

Perform a dry run - show cost estimation and job details without executing.

Returns: Self for chaining (doesn't actually execute jobs)
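The batch count reported by the execution plan uses ceiling division, as in the `total_batches` expression in the source above: a floor division plus one extra batch when there is a remainder.

```python
# Batches needed per provider group, mirroring the total_batches
# computation in dry_run(): ceiling division without math.ceil.
def batch_count(num_jobs: int, items_per_batch: int) -> int:
    return num_jobs // items_per_batch + (1 if num_jobs % items_per_batch else 0)

# 10 jobs at 4 items per batch -> 3 batches (4 + 4 + 2).
print(batch_count(10, 4))  # 3
```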

@dataclass
class Job:
12@dataclass
13class Job:
14    """Configuration for a single AI job.
15    
16    Either provide messages OR prompt (with optional file), not both.
17    
18    Attributes:
19        id: Unique identifier for the job
20        messages: Chat messages for direct message input
21        file: Optional file path for file-based input
22        prompt: Prompt text (can be used alone or with file)
23        model: Model name (e.g., "claude-3-sonnet")
24        temperature: Sampling temperature (0.0-1.0)
25        max_tokens: Maximum tokens to generate
26        response_model: Pydantic model for structured output
27        enable_citations: Whether to extract citations from response
28    """
29    
30    id: str  # Unique identifier
31    model: str  # Model name (e.g., "claude-3-sonnet")
32    messages: Optional[List[Message]] = None  # Chat messages
33    file: Optional[Path] = None  # File input
34    prompt: Optional[str] = None  # Prompt for file
35    temperature: float = 0.7
36    max_tokens: int = 1000
37    response_model: Optional[Type[BaseModel]] = None  # For structured output
38    enable_citations: bool = False
39    
40    def __post_init__(self):
41        """Validate job configuration."""
42        if self.messages and (self.file or self.prompt):
43            raise ValueError("Provide either messages OR file+prompt, not both")
44        
45        if self.file and not self.prompt:
46            raise ValueError("File input requires a prompt")
47        
48        if not self.messages and not self.prompt:
49            raise ValueError("Must provide either messages or prompt")
50    
51    def to_dict(self) -> Dict[str, Any]:
52        """Serialize for state persistence."""
53        return {
54            "id": self.id,
55            "model": self.model,
56            "messages": self.messages,
57            "file": str(self.file) if self.file else None,
58            "prompt": self.prompt,
59            "temperature": self.temperature,
60            "max_tokens": self.max_tokens,
61            "response_model": self.response_model.__name__ if self.response_model else None,
62            "enable_citations": self.enable_citations
63        }
64    
65    @classmethod
66    def from_dict(cls, data: Dict[str, Any]) -> 'Job':
67        """Deserialize from state."""
68        # Convert file string back to Path if present
69        file_path = None
70        if data.get("file"):
71            file_path = Path(data["file"])
72        
73        # Note: response_model reconstruction would need additional logic
74        # For now, we'll set it to None during deserialization
75        return cls(
76            id=data["id"],
77            model=data["model"],
78            messages=data.get("messages"),
79            file=file_path,
80            prompt=data.get("prompt"),
81            temperature=data.get("temperature", 0.7),
82            max_tokens=data.get("max_tokens", 1000),
83            response_model=None,  # Cannot reconstruct from string
84            enable_citations=data.get("enable_citations", False)
85        )

Configuration for a single AI job.

Either provide messages OR prompt (with optional file), not both.

Attributes:
    id: Unique identifier for the job
    messages: Chat messages for direct message input
    file: Optional file path for file-based input
    prompt: Prompt text (can be used alone or with file)
    model: Model name (e.g., "claude-3-sonnet")
    temperature: Sampling temperature (0.0-1.0)
    max_tokens: Maximum tokens to generate
    response_model: Pydantic model for structured output
    enable_citations: Whether to extract citations from response

Job(id: str, model: str, messages: Optional[List[Dict[str, Any]]] = None, file: Optional[pathlib.Path] = None, prompt: Optional[str] = None, temperature: float = 0.7, max_tokens: int = 1000, response_model: Optional[Type[pydantic.main.BaseModel]] = None, enable_citations: bool = False)
id: str
model: str
messages: Optional[List[Dict[str, Any]]] = None
file: Optional[pathlib.Path] = None
prompt: Optional[str] = None
temperature: float = 0.7
max_tokens: int = 1000
response_model: Optional[Type[pydantic.main.BaseModel]] = None
enable_citations: bool = False
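The input rules enforced in `__post_init__` can be seen in isolation with a minimal sketch of the dataclass (validation fields only, mirroring the source above):

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

# Minimal sketch of Job's input validation, mirroring __post_init__ above.
@dataclass
class Job:
    id: str
    model: str
    messages: Optional[List[Dict[str, Any]]] = None
    file: Optional[Path] = None
    prompt: Optional[str] = None

    def __post_init__(self):
        if self.messages and (self.file or self.prompt):
            raise ValueError("Provide either messages OR file+prompt, not both")
        if self.file and not self.prompt:
            raise ValueError("File input requires a prompt")
        if not self.messages and not self.prompt:
            raise ValueError("Must provide either messages or prompt")

# Valid: a file paired with a prompt.
job = Job(id="j1", model="claude-3-sonnet", file=Path("doc.pdf"), prompt="Summarize")

# Invalid: a file without a prompt raises at construction time.
try:
    Job(id="j2", model="claude-3-sonnet", file=Path("doc.pdf"))
except ValueError as e:
    print(e)  # File input requires a prompt
```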
def to_dict(self) -> Dict[str, Any]:

Serialize for state persistence.

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Job:

Deserialize from state.

@dataclass
class JobResult:
 11@dataclass
 12class JobResult:
 13    """Result from a completed AI job.
 14    
 15    Attributes:
 16        job_id: ID of the job this result is for
 17        raw_response: Raw text response from the model (None for failed jobs)
 18        parsed_response: Structured output (if response_model was used)
 19        citations: Extracted citations (if enable_citations was True)
 20        citation_mappings: Maps field names to relevant citations (if response_model used)
 21        input_tokens: Number of input tokens used
 22        output_tokens: Number of output tokens generated
 23        cost_usd: Total cost in USD
 24        error: Error message if job failed
 25        batch_id: ID of the batch this job was part of (for mapping to raw files)
 26    """
 27    
 28    job_id: str
 29    raw_response: Optional[str] = None  # Raw text response (None for failed jobs)
 30    parsed_response: Optional[Union[BaseModel, Dict]] = None  # Structured output or error dict
 31    citations: Optional[List[Citation]] = None  # Extracted citations
 32    citation_mappings: Optional[Dict[str, List[Citation]]] = None  # Field -> citations mapping
 33    input_tokens: int = 0
 34    output_tokens: int = 0
 35    cost_usd: float = 0.0
 36    error: Optional[str] = None  # Error message if failed
 37    batch_id: Optional[str] = None  # Batch ID for mapping to raw files
 38    
 39    @property
 40    def is_success(self) -> bool:
 41        """Whether the job completed successfully."""
 42        return self.error is None
 43    
 44    @property
 45    def total_tokens(self) -> int:
 46        """Total tokens used (input + output)."""
 47        return self.input_tokens + self.output_tokens
 48    
 49    def to_dict(self) -> Dict[str, Any]:
 50        """Serialize for state persistence."""
 51        # Handle parsed_response serialization
 52        parsed_response = None
 53        if self.parsed_response is not None:
 54            if isinstance(self.parsed_response, dict):
 55                parsed_response = self.parsed_response
 56            elif isinstance(self.parsed_response, BaseModel):
 57                parsed_response = self.parsed_response.model_dump()
 58            else:
 59                parsed_response = str(self.parsed_response)
 60        
 61        # Handle citation_mappings serialization
 62        citation_mappings = None
 63        if self.citation_mappings:
 64            citation_mappings = {
 65                field: [asdict(c) for c in citations]
 66                for field, citations in self.citation_mappings.items()
 67            }
 68        
 69        return {
 70            "job_id": self.job_id,
 71            "raw_response": self.raw_response,
 72            "parsed_response": parsed_response,
 73            "citations": [asdict(c) for c in self.citations] if self.citations else None,
 74            "citation_mappings": citation_mappings,
 75            "input_tokens": self.input_tokens,
 76            "output_tokens": self.output_tokens,
 77            "cost_usd": self.cost_usd,
 78            "error": self.error,
 79            "batch_id": self.batch_id
 80        }
 81    
 82    @classmethod
 83    def from_dict(cls, data: Dict[str, Any]) -> 'JobResult':
 84        """Deserialize from state."""
 85        # Reconstruct citations if present
 86        citations = None
 87        if data.get("citations"):
 88            citations = [Citation(**c) for c in data["citations"]]
 89        
 90        # Reconstruct citation_mappings if present
 91        citation_mappings = None
 92        if data.get("citation_mappings"):
 93            citation_mappings = {
 94                field: [Citation(**c) for c in cite_list]
 95                for field, cite_list in data["citation_mappings"].items()
 96            }
 97        
 98        return cls(
 99            job_id=data["job_id"],
100            raw_response=data["raw_response"],
101            parsed_response=data.get("parsed_response"),
102            citations=citations,
103            citation_mappings=citation_mappings,
104            input_tokens=data.get("input_tokens", 0),
105            output_tokens=data.get("output_tokens", 0),
106            cost_usd=data.get("cost_usd", 0.0),
107            error=data.get("error"),
108            batch_id=data.get("batch_id")
109        )

Result from a completed AI job.

Attributes:
    job_id: ID of the job this result is for
    raw_response: Raw text response from the model (None for failed jobs)
    parsed_response: Structured output (if response_model was used)
    citations: Extracted citations (if enable_citations was True)
    citation_mappings: Maps field names to relevant citations (if response_model used)
    input_tokens: Number of input tokens used
    output_tokens: Number of output tokens generated
    cost_usd: Total cost in USD
    error: Error message if job failed
    batch_id: ID of the batch this job was part of (for mapping to raw files)

JobResult(job_id: str, raw_response: Optional[str] = None, parsed_response: Union[pydantic.main.BaseModel, Dict, NoneType] = None, citations: Optional[List[Citation]] = None, citation_mappings: Optional[Dict[str, List[Citation]]] = None, input_tokens: int = 0, output_tokens: int = 0, cost_usd: float = 0.0, error: Optional[str] = None, batch_id: Optional[str] = None)
job_id: str
raw_response: Optional[str] = None
parsed_response: Union[pydantic.main.BaseModel, Dict, NoneType] = None
citations: Optional[List[Citation]] = None
citation_mappings: Optional[Dict[str, List[Citation]]] = None
input_tokens: int = 0
output_tokens: int = 0
cost_usd: float = 0.0
error: Optional[str] = None
batch_id: Optional[str] = None
is_success: bool

Whether the job completed successfully.

total_tokens: int

Total tokens used (input + output).
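Both properties derive from stored fields, so they need no extra state. A minimal sketch of the dataclass (derived-property fields only, mirroring the source above):

```python
from dataclasses import dataclass
from typing import Optional

# Minimal sketch of JobResult's derived properties.
@dataclass
class JobResult:
    job_id: str
    input_tokens: int = 0
    output_tokens: int = 0
    error: Optional[str] = None

    @property
    def is_success(self) -> bool:
        # Success is simply the absence of an error message.
        return self.error is None

    @property
    def total_tokens(self) -> int:
        return self.input_tokens + self.output_tokens

ok = JobResult("j1", input_tokens=120, output_tokens=80)
bad = JobResult("j2", error="timeout")
```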

def to_dict(self) -> Dict[str, Any]:

Serialize for state persistence.

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> JobResult:

Deserialize from state.

@dataclass
class Citation:
 8@dataclass
 9class Citation:
10    """Represents a citation extracted from an AI response."""
11    
12    text: str  # The cited text
13    source: str  # Source identifier (e.g., page number, section)
14    page: Optional[int] = None  # Page number if applicable
15    metadata: Optional[Dict[str, Any]] = None  # Additional metadata

Represents a citation extracted from an AI response.

Citation(text: str, source: str, page: Optional[int] = None, metadata: Optional[Dict[str, Any]] = None)
text: str
source: str
page: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
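Because `Citation` is a plain dataclass, it round-trips through `dataclasses.asdict` and keyword construction — the same mechanism `JobResult.to_dict()`/`from_dict()` use to persist citations, as shown in their source above:

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

@dataclass
class Citation:
    text: str  # The cited text
    source: str  # Source identifier (e.g., page number, section)
    page: Optional[int] = None
    metadata: Optional[Dict[str, Any]] = None

c = Citation(text="Revenue grew 12%", source="section 2", page=4)

# Serialize to a plain dict, then reconstruct — equality holds.
restored = Citation(**asdict(c))
```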
class BatchataError(builtins.Exception):
5class BatchataError(Exception):
6    """Base exception for all Batchata errors."""
7    pass

Base exception for all Batchata errors.

class CostLimitExceededError(batchata.BatchataError):
35class CostLimitExceededError(BatchataError):
36    """Raised when cost limit would be exceeded."""
37    pass

Raised when cost limit would be exceeded.

class ProviderError(batchata.BatchataError):
15class ProviderError(BatchataError):
16    """Base exception for provider-related errors."""
17    pass

Base exception for provider-related errors.

class ProviderNotFoundError(batchata.ProviderError):
20class ProviderNotFoundError(ProviderError):
21    """Raised when no provider is found for a model."""
22    pass

Raised when no provider is found for a model.

class ValidationError(batchata.BatchataError):
10class ValidationError(BatchataError):
11    """Raised when job or configuration validation fails."""
12    pass

Raised when job or configuration validation fails.
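Because every error derives from `BatchataError`, one handler can cover the whole library while more specific handlers intercept subtrees first. A self-contained sketch of the hierarchy above:

```python
# Sketch of the exception hierarchy above: except clauses match subclasses,
# so ProviderError catches ProviderNotFoundError, and BatchataError catches the rest.
class BatchataError(Exception): pass
class ValidationError(BatchataError): pass
class ProviderError(BatchataError): pass
class ProviderNotFoundError(ProviderError): pass
class CostLimitExceededError(BatchataError): pass

def classify(exc: Exception) -> str:
    try:
        raise exc
    except ProviderError:
        return "provider"
    except BatchataError:
        return "batchata"
    except Exception:
        return "other"
```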