Coverage for src/chat_limiter/adapters.py: 93%

150 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-11 13:37 +0100

1""" 

2Provider-specific adapters for converting between our unified types and provider APIs. 

3""" 

4 

5import time 

6from abc import ABC, abstractmethod 

7from typing import Any 

8 

9from .providers import Provider 

10from .types import ( 

11 ChatCompletionRequest, 

12 ChatCompletionResponse, 

13 Choice, 

14 Message, 

15 MessageRole, 

16 Usage, 

17) 

18 

19 

class ProviderAdapter(ABC):
    """Interface for translating between unified chat types and one provider.

    Concrete adapters implement three hooks:
    - building the outgoing request body;
    - converting the provider's reply back into a ``ChatCompletionResponse``;
    - naming the endpoint path to call.
    """

    @abstractmethod
    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Return the provider-specific request body for ``request``."""

    @abstractmethod
    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Translate a provider response body into the unified response type.

        ``original_request`` is passed so implementations can fall back on
        request fields the provider omits from its reply.
        """

    @abstractmethod
    def get_endpoint(self) -> str:
        """Return the API endpoint path for this provider."""

41 

42 

class OpenAIAdapter(ProviderAdapter):
    """Adapter for the OpenAI Chat Completions API."""

    def is_reasoning_model(self, model_name: str) -> bool:
        """Return True for reasoning models (names starting with o1/o3/o4).

        These models take ``max_completion_tokens`` instead of ``max_tokens``.
        """
        return model_name.startswith(("o1", "o3", "o4"))

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert a unified request into an OpenAI request body.

        Optional parameters are only included when set on ``request``.
        ``stream`` is only sent when it is truthy.

        Args:
            request: Provider-neutral chat completion request.

        Returns:
            The JSON-serialisable body to POST to ``get_endpoint()``.
        """
        openai_request: dict[str, Any] = {
            "model": request.model,
            "messages": [
                {"role": msg.role.value, "content": msg.content}
                for msg in request.messages
            ],
        }

        if request.max_tokens is not None:
            # Reasoning models (o1, o3, o4) expect max_completion_tokens.
            token_key = (
                "max_completion_tokens"
                if self.is_reasoning_model(request.model)
                else "max_tokens"
            )
            openai_request[token_key] = request.max_tokens
        if request.temperature is not None:
            openai_request["temperature"] = request.temperature
        if request.top_p is not None:
            openai_request["top_p"] = request.top_p
        if request.stop is not None:
            openai_request["stop"] = request.stop
        if request.stream:
            openai_request["stream"] = request.stream
        if request.frequency_penalty is not None:
            openai_request["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            openai_request["presence_penalty"] = request.presence_penalty

        return openai_request

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Convert an OpenAI response body into a ``ChatCompletionResponse``.

        JSON ``null`` values for ``choices``, ``message``, ``content``,
        ``role``, ``usage`` and ``error`` are tolerated rather than crashing.

        Args:
            response_data: Decoded JSON body returned by the API.
            original_request: The request that produced this response. Its
                model is used if the response omits ``model``.

        Returns:
            The unified response. It carries ``has_error``/``error_message``
            when the body contains an ``error`` object, and always carries
            the raw body.
        """
        has_error, error_message = self._parse_error(response_data)

        choices = [
            self._parse_choice(choice_data)
            for choice_data in response_data.get("choices") or []
        ]

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=choices,
            usage=self._parse_usage(response_data),
            created=response_data.get("created"),
            has_error=has_error,
            error_message=error_message,
            provider="openai",
            raw_response=response_data,
        )

    def get_endpoint(self) -> str:
        """Return the Chat Completions endpoint path."""
        return "/chat/completions"

    @staticmethod
    def _parse_error(response_data: dict[str, Any]) -> tuple[bool, str | None]:
        """Return ``(has_error, error_message)`` extracted from a response body."""
        error_data = response_data.get("error")
        if error_data is None:
            return False, None
        if isinstance(error_data, dict):
            return True, error_data.get("message") or "Unknown error"
        # Tolerate a non-object error value (e.g. a bare string) instead of
        # failing with AttributeError on ``.get``.
        return True, str(error_data) or "Unknown error"

    @staticmethod
    def _parse_choice(choice_data: dict[str, Any]) -> Choice:
        """Build a ``Choice`` from one entry of the response's ``choices`` list."""
        message_data = choice_data.get("message") or {}
        message = Message(
            role=MessageRole(message_data.get("role") or "assistant"),
            # The API may return ``"content": null`` (for example on tool-call
            # replies); normalise that to an empty string.
            content=message_data.get("content") or "",
        )
        return Choice(
            index=choice_data.get("index", 0),
            message=message,
            finish_reason=choice_data.get("finish_reason"),
        )

    @staticmethod
    def _parse_usage(response_data: dict[str, Any]) -> Usage | None:
        """Return token usage, or None when the body has no usage object."""
        usage_data = response_data.get("usage")
        if not isinstance(usage_data, dict):
            return None
        return Usage(
            prompt_tokens=usage_data.get("prompt_tokens") or 0,
            completion_tokens=usage_data.get("completion_tokens") or 0,
            total_tokens=usage_data.get("total_tokens") or 0,
        )

141 

142 

class AnthropicAdapter(ProviderAdapter):
    """Adapter for the Anthropic Messages API."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert a unified request into an Anthropic Messages request body.

        System-role messages are moved to the top-level ``system`` field. If
        there are several, their non-empty contents are joined with a blank
        line so that none is silently dropped. ``stop`` is always sent as a
        list under ``stop_sequences``.

        Args:
            request: Provider-neutral chat completion request.

        Returns:
            The JSON-serialisable body to POST to ``get_endpoint()``.
        """
        messages: list[dict[str, Any]] = []
        system_parts: list[str] = []

        for msg in request.messages:
            if msg.role == MessageRole.SYSTEM:
                # Anthropic takes system prompts as a separate field, not as a message.
                system_parts.append(msg.content)
            else:
                messages.append({"role": msg.role.value, "content": msg.content})

        anthropic_request: dict[str, Any] = {
            "model": request.model,
            "messages": messages,
            # Required for Anthropic. Fall back to 1024 when unset (or 0).
            "max_tokens": request.max_tokens or 1024,
        }

        non_empty_system = [part for part in system_parts if part]
        if len(non_empty_system) == 1:
            # Pass a single system prompt through untouched.
            anthropic_request["system"] = non_empty_system[0]
        elif non_empty_system:
            anthropic_request["system"] = "\n\n".join(non_empty_system)

        if request.temperature is not None:
            anthropic_request["temperature"] = request.temperature
        if request.top_p is not None:
            anthropic_request["top_p"] = request.top_p
        if request.stop is not None:
            anthropic_request["stop_sequences"] = (
                [request.stop] if isinstance(request.stop, str) else request.stop
            )
        if request.stream:
            anthropic_request["stream"] = request.stream
        if request.top_k is not None:
            anthropic_request["top_k"] = request.top_k

        return anthropic_request

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Convert an Anthropic response body into a ``ChatCompletionResponse``.

        All ``text`` content blocks are concatenated into a single assistant
        message. ``stop_reason`` becomes the choice's ``finish_reason``.

        Args:
            response_data: Decoded JSON body returned by the API.
            original_request: The request that produced this response. Its
                model is used if the response omits ``model``.

        Returns:
            The unified response with exactly one choice. ``created`` is set to
            the local time of parsing because Anthropic sends no timestamp.
        """
        has_error, error_message = self._parse_error(response_data)

        content = "".join(
            block.get("text") or ""
            for block in response_data.get("content") or []
            if isinstance(block, dict) and block.get("type") == "text"
        )

        choice = Choice(
            index=0,
            message=Message(role=MessageRole.ASSISTANT, content=content),
            finish_reason=response_data.get("stop_reason"),
        )

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=[choice],
            usage=self._parse_usage(response_data),
            created=int(time.time()),  # Anthropic doesn't provide created timestamp
            has_error=has_error,
            error_message=error_message,
            provider="anthropic",
            raw_response=response_data,
        )

    def get_endpoint(self) -> str:
        """Return the Messages endpoint path."""
        return "/messages"

    @staticmethod
    def _parse_error(response_data: dict[str, Any]) -> tuple[bool, str | None]:
        """Return ``(has_error, error_message)`` extracted from a response body."""
        error_data = response_data.get("error")
        if error_data is None:
            return False, None
        if isinstance(error_data, dict):
            return True, error_data.get("message") or "Unknown error"
        # Tolerate a non-object error value (e.g. a bare string) instead of
        # failing with AttributeError on ``.get``.
        return True, str(error_data) or "Unknown error"

    @staticmethod
    def _parse_usage(response_data: dict[str, Any]) -> Usage | None:
        """Map Anthropic input/output token counts onto ``Usage``, or return None."""
        usage_data = response_data.get("usage")
        if not isinstance(usage_data, dict):
            return None
        # ``or 0`` also guards against explicit nulls, which would break the sum.
        input_tokens = usage_data.get("input_tokens") or 0
        output_tokens = usage_data.get("output_tokens") or 0
        return Usage(
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
        )

248 

249 

class OpenRouterAdapter(ProviderAdapter):
    """Adapter for the OpenRouter API, which uses an OpenAI-compatible format."""

    def format_request(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Convert a unified request into an OpenRouter request body.

        The body mirrors the OpenAI format and additionally forwards ``top_k``.
        Optional parameters are only included when set. ``stream`` is only
        sent when it is truthy.

        Args:
            request: Provider-neutral chat completion request.

        Returns:
            The JSON-serialisable body to POST to ``get_endpoint()``.
        """
        openrouter_request: dict[str, Any] = {
            "model": request.model,
            "messages": [
                {"role": msg.role.value, "content": msg.content}
                for msg in request.messages
            ],
        }

        if request.max_tokens is not None:
            openrouter_request["max_tokens"] = request.max_tokens
        if request.temperature is not None:
            openrouter_request["temperature"] = request.temperature
        if request.top_p is not None:
            openrouter_request["top_p"] = request.top_p
        if request.stop is not None:
            openrouter_request["stop"] = request.stop
        if request.stream:
            openrouter_request["stream"] = request.stream
        if request.frequency_penalty is not None:
            openrouter_request["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            openrouter_request["presence_penalty"] = request.presence_penalty
        if request.top_k is not None:
            openrouter_request["top_k"] = request.top_k

        return openrouter_request

    def parse_response(
        self,
        response_data: dict[str, Any],
        original_request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Convert an OpenRouter response body into a ``ChatCompletionResponse``.

        JSON ``null`` values for ``choices``, ``message``, ``content``,
        ``role``, ``usage`` and ``error`` are tolerated rather than crashing.

        Args:
            response_data: Decoded JSON body returned by the API.
            original_request: The request that produced this response. Its
                model is used if the response omits ``model``.

        Returns:
            The unified response. It carries ``has_error``/``error_message``
            when the body contains an ``error`` object, and always carries
            the raw body.
        """
        has_error, error_message = self._parse_error(response_data)

        choices = [
            self._parse_choice(choice_data)
            for choice_data in response_data.get("choices") or []
        ]

        return ChatCompletionResponse(
            id=response_data.get("id", ""),
            model=response_data.get("model", original_request.model),
            choices=choices,
            usage=self._parse_usage(response_data),
            created=response_data.get("created"),
            has_error=has_error,
            error_message=error_message,
            provider="openrouter",
            raw_response=response_data,
        )

    def get_endpoint(self) -> str:
        """Return the chat completions endpoint path."""
        return "/chat/completions"

    @staticmethod
    def _parse_error(response_data: dict[str, Any]) -> tuple[bool, str | None]:
        """Return ``(has_error, error_message)`` extracted from a response body."""
        error_data = response_data.get("error")
        if error_data is None:
            return False, None
        if isinstance(error_data, dict):
            return True, error_data.get("message") or "Unknown error"
        # Tolerate a non-object error value (e.g. a bare string) instead of
        # failing with AttributeError on ``.get``.
        return True, str(error_data) or "Unknown error"

    @staticmethod
    def _parse_choice(choice_data: dict[str, Any]) -> Choice:
        """Build a ``Choice`` from one entry of the response's ``choices`` list."""
        message_data = choice_data.get("message") or {}
        message = Message(
            role=MessageRole(message_data.get("role") or "assistant"),
            # ``"content": null`` is normalised to an empty string.
            content=message_data.get("content") or "",
        )
        return Choice(
            index=choice_data.get("index", 0),
            message=message,
            finish_reason=choice_data.get("finish_reason"),
        )

    @staticmethod
    def _parse_usage(response_data: dict[str, Any]) -> Usage | None:
        """Return token usage, or None when the body has no usage object."""
        usage_data = response_data.get("usage")
        if not isinstance(usage_data, dict):
            return None
        return Usage(
            prompt_tokens=usage_data.get("prompt_tokens") or 0,
            completion_tokens=usage_data.get("completion_tokens") or 0,
            total_tokens=usage_data.get("total_tokens") or 0,
        )

342 

343 

# Registry mapping each supported Provider to a single adapter instance.
# The same instance is returned on every lookup via get_adapter().
PROVIDER_ADAPTERS: dict[Provider, ProviderAdapter] = {
    Provider.OPENAI: OpenAIAdapter(),
    Provider.ANTHROPIC: AnthropicAdapter(),
    Provider.OPENROUTER: OpenRouterAdapter(),
}

350 

351 

def get_adapter(provider: Provider) -> ProviderAdapter:
    """Return the adapter registered for ``provider`` in ``PROVIDER_ADAPTERS``.

    Args:
        provider: The provider whose adapter is wanted.

    Returns:
        The shared adapter instance stored in the registry.

    Raises:
        KeyError: If no adapter is registered for ``provider``. The message
            lists the supported providers.
    """
    try:
        return PROVIDER_ADAPTERS[provider]
    except KeyError:
        supported = ", ".join(str(known) for known in PROVIDER_ADAPTERS)
        # Still a KeyError so existing ``except KeyError`` callers keep working.
        raise KeyError(
            f"No adapter registered for provider {provider!r}; "
            f"supported providers: {supported}"
        ) from None