Coverage for src/chat_limiter/adapters.py: 93% (150 statements)
Generated by coverage.py v7.9.2, created at 2025-07-11 13:37 +0100
1"""
2Provider-specific adapters for converting between our unified types and provider APIs.
3"""
5import time
6from abc import ABC, abstractmethod
7from typing import Any
9from .providers import Provider
10from .types import (
11 ChatCompletionRequest,
12 ChatCompletionResponse,
13 Choice,
14 Message,
15 MessageRole,
16 Usage,
17)
class ProviderAdapter(ABC):
    """Interface that every provider-specific adapter implements.

    An adapter translates between the unified request/response types and
    the payload format of a single provider.
    """

    @abstractmethod
    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Build the provider-specific request payload from a unified request."""

    @abstractmethod
    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Turn a provider response payload into the unified response type."""

    @abstractmethod
    def get_endpoint(self) -> str:
        """Return the API endpoint used for this provider."""
class OpenAIAdapter(ProviderAdapter):
    """Adapter for the OpenAI chat-completions API."""

    def is_reasoning_model(self, model_name: str) -> bool:
        """Return True if *model_name* starts with "o1", "o3" or "o4".

        Requests for these models carry the token limit under
        ``max_completion_tokens`` instead of ``max_tokens``.
        """
        return model_name.startswith(("o1", "o3", "o4"))

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert a unified request into an OpenAI request payload.

        ``model`` and ``messages`` are always present. Optional parameters
        are included only when set on the request (``stream`` only when
        truthy).
        """
        openai_request: dict[str, Any] = {
            "model": request.model,
            "messages": [
                {"role": msg.role.value, "content": msg.content}
                for msg in request.messages
            ],
        }

        if request.max_tokens is not None:
            # Reasoning models (o1, o3, o4) take the limit under a different key.
            limit_key = (
                "max_completion_tokens"
                if self.is_reasoning_model(request.model)
                else "max_tokens"
            )
            openai_request[limit_key] = request.max_tokens
        if request.temperature is not None:
            openai_request["temperature"] = request.temperature
        if request.top_p is not None:
            openai_request["top_p"] = request.top_p
        if request.stop is not None:
            openai_request["stop"] = request.stop
        if request.stream:
            openai_request["stream"] = request.stream
        if request.frequency_penalty is not None:
            openai_request["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            openai_request["presence_penalty"] = request.presence_penalty

        return openai_request

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Convert an OpenAI response payload into the unified response.

        Null-valued ``error``, ``usage``, ``choices``, ``message``, ``role``
        and ``content`` fields are tolerated instead of raising
        ``AttributeError``. A null ``content`` becomes an empty string.

        Args:
            response_data: Decoded JSON body returned by the provider.
            original_request: The request that produced the response; its
                ``model`` is used when the response does not name one.

        Returns:
            The unified response, with ``has_error``/``error_message`` set
            when the payload carries a non-null ``error``.
        """
        # An explicit ``"error": null`` is not an error; only a non-null
        # value counts.
        error_data = response_data.get("error")
        has_error = error_data is not None
        error_message: str | None = None
        if has_error:
            if isinstance(error_data, dict):
                error_message = error_data.get("message", "Unknown error")
            else:
                # The error may be a bare string rather than an object.
                error_message = str(error_data)

        choices = []
        for choice_data in response_data.get("choices") or []:
            message_data = choice_data.get("message") or {}
            message = Message(
                role=MessageRole(message_data.get("role") or "assistant"),
                # ``content`` may be null in a response; normalise to "".
                content=message_data.get("content") or "",
            )
            choices.append(
                Choice(
                    index=choice_data.get("index", 0),
                    message=message,
                    finish_reason=choice_data.get("finish_reason"),
                )
            )

        usage = None
        usage_data = response_data.get("usage")
        if usage_data is not None:
            usage = Usage(
                prompt_tokens=usage_data.get("prompt_tokens", 0),
                completion_tokens=usage_data.get("completion_tokens", 0),
                total_tokens=usage_data.get("total_tokens", 0),
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=choices,
            usage=usage,
            created=response_data.get("created"),
            has_error=has_error,
            error_message=error_message,
            provider="openai",
            raw_response=response_data,
        )

    def get_endpoint(self) -> str:
        """Return the chat-completions endpoint."""
        return "/chat/completions"
class AnthropicAdapter(ProviderAdapter):
    """Adapter for the Anthropic messages API."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert a unified request into an Anthropic request payload.

        System messages are removed from ``messages`` and sent under the
        separate ``system`` key. When several non-empty system messages are
        present they are joined with a blank line, so none is silently
        dropped. ``max_tokens`` is always sent and falls back to 1024 when
        the request does not set it.
        """
        system_parts: list[str] = []
        messages: list[dict[str, Any]] = []
        for msg in request.messages:
            if msg.role == MessageRole.SYSTEM:
                # Anthropic takes system text outside the message list.
                if msg.content:
                    system_parts.append(msg.content)
            else:
                messages.append({"role": msg.role.value, "content": msg.content})

        anthropic_request: dict[str, Any] = {
            "model": request.model,
            "messages": messages,
            "max_tokens": request.max_tokens or 1024,  # Required for Anthropic
        }

        if system_parts:
            anthropic_request["system"] = "\n\n".join(system_parts)

        if request.temperature is not None:
            anthropic_request["temperature"] = request.temperature
        if request.top_p is not None:
            anthropic_request["top_p"] = request.top_p
        if request.stop is not None:
            # Anthropic expects a list of stop sequences.
            anthropic_request["stop_sequences"] = (
                [request.stop] if isinstance(request.stop, str) else request.stop
            )
        if request.stream:
            anthropic_request["stream"] = request.stream
        if request.top_k is not None:
            anthropic_request["top_k"] = request.top_k

        return anthropic_request

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Convert an Anthropic response payload into the unified response.

        The text of every ``"text"`` content block is concatenated into a
        single assistant message, returned as the only choice. Null-valued
        ``error``, ``usage`` and ``content`` fields are tolerated instead of
        raising ``AttributeError``.

        Args:
            response_data: Decoded JSON body returned by the provider.
            original_request: The request that produced the response; its
                ``model`` is used when the response does not name one.

        Returns:
            The unified response. ``created`` is the local time of parsing,
            and ``total_tokens`` is the sum of input and output tokens.
        """
        # An explicit ``"error": null`` is not an error; only a non-null
        # value counts.
        error_data = response_data.get("error")
        has_error = error_data is not None
        error_message: str | None = None
        if has_error:
            if isinstance(error_data, dict):
                error_message = error_data.get("message", "Unknown error")
            else:
                # The error may be a bare string rather than an object.
                error_message = str(error_data)

        # Only text blocks contribute to the message content.
        content = "".join(
            block.get("text") or ""
            for block in response_data.get("content") or []
            if block.get("type") == "text"
        )

        choice = Choice(
            index=0,
            message=Message(role=MessageRole.ASSISTANT, content=content),
            finish_reason=response_data.get("stop_reason"),
        )

        usage = None
        usage_data = response_data.get("usage")
        if usage_data is not None:
            input_tokens = usage_data.get("input_tokens", 0)
            output_tokens = usage_data.get("output_tokens", 0)
            usage = Usage(
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens,
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=[choice],
            usage=usage,
            created=int(time.time()),  # Anthropic doesn't provide created timestamp
            has_error=has_error,
            error_message=error_message,
            provider="anthropic",
            raw_response=response_data,
        )

    def get_endpoint(self) -> str:
        """Return the messages endpoint."""
        return "/messages"
class OpenRouterAdapter(ProviderAdapter):
    """Adapter for the OpenRouter API, which uses an OpenAI-compatible format."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert a unified request into an OpenRouter request payload.

        ``model`` and ``messages`` are always present. Optional parameters
        are included only when set on the request (``stream`` only when
        truthy).
        """
        openrouter_request: dict[str, Any] = {
            "model": request.model,
            "messages": [
                {"role": msg.role.value, "content": msg.content}
                for msg in request.messages
            ],
        }

        if request.max_tokens is not None:
            openrouter_request["max_tokens"] = request.max_tokens
        if request.temperature is not None:
            openrouter_request["temperature"] = request.temperature
        if request.top_p is not None:
            openrouter_request["top_p"] = request.top_p
        if request.stop is not None:
            openrouter_request["stop"] = request.stop
        if request.stream:
            openrouter_request["stream"] = request.stream
        if request.frequency_penalty is not None:
            openrouter_request["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            openrouter_request["presence_penalty"] = request.presence_penalty
        if request.top_k is not None:
            openrouter_request["top_k"] = request.top_k

        return openrouter_request

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Convert an OpenRouter response payload into the unified response.

        Null-valued ``error``, ``usage``, ``choices``, ``message``, ``role``
        and ``content`` fields are tolerated instead of raising
        ``AttributeError``. A null ``content`` becomes an empty string.

        Args:
            response_data: Decoded JSON body returned by the provider.
            original_request: The request that produced the response; its
                ``model`` is used when the response does not name one.

        Returns:
            The unified response, with ``has_error``/``error_message`` set
            when the payload carries a non-null ``error``.
        """
        # An explicit ``"error": null`` is not an error; only a non-null
        # value counts.
        error_data = response_data.get("error")
        has_error = error_data is not None
        error_message: str | None = None
        if has_error:
            if isinstance(error_data, dict):
                error_message = error_data.get("message", "Unknown error")
            else:
                # The error may be a bare string rather than an object.
                error_message = str(error_data)

        choices = []
        for choice_data in response_data.get("choices") or []:
            message_data = choice_data.get("message") or {}
            message = Message(
                role=MessageRole(message_data.get("role") or "assistant"),
                # ``content`` may be null in a response; normalise to "".
                content=message_data.get("content") or "",
            )
            choices.append(
                Choice(
                    index=choice_data.get("index", 0),
                    message=message,
                    finish_reason=choice_data.get("finish_reason"),
                )
            )

        usage = None
        usage_data = response_data.get("usage")
        if usage_data is not None:
            usage = Usage(
                prompt_tokens=usage_data.get("prompt_tokens", 0),
                completion_tokens=usage_data.get("completion_tokens", 0),
                total_tokens=usage_data.get("total_tokens", 0),
            )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=choices,
            usage=usage,
            created=response_data.get("created"),
            has_error=has_error,
            error_message=error_message,
            provider="openrouter",
            raw_response=response_data,
        )

    def get_endpoint(self) -> str:
        """Return the chat-completions endpoint."""
        return "/chat/completions"
# Registry mapping each provider to one adapter instance, created at import time.
PROVIDER_ADAPTERS: dict[Provider, ProviderAdapter] = {
    Provider.OPENAI: OpenAIAdapter(),
    Provider.ANTHROPIC: AnthropicAdapter(),
    Provider.OPENROUTER: OpenRouterAdapter(),
}
def get_adapter(provider: Provider) -> ProviderAdapter:
    """Look up the registered adapter for *provider*.

    Raises:
        KeyError: If *provider* has no entry in ``PROVIDER_ADAPTERS``.
    """
    adapter = PROVIDER_ADAPTERS[provider]
    return adapter