blob: aa6151f1d7b6e59f1b8df159dfd1d7e363d5f082 [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package oai
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "math/rand/v2"
11 "net/http"
Philip Zeyliger72252cb2025-05-10 17:00:08 -070012 "strings"
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070013 "time"
14
15 "github.com/sashabaranov/go-openai"
16 "sketch.dev/llm"
17)
18
const (
	// DefaultMaxTokens is the completion-token cap applied when
	// Service.MaxTokens is zero.
	DefaultMaxTokens = 8192

	// Base URLs for OpenAI-compatible chat-completion endpoints.
	OpenAIURL    = "https://api.openai.com/v1"
	FireworksURL = "https://api.fireworks.ai/inference/v1"
	LlamaCPPURL  = "http://localhost:8080/v1" // local llama.cpp server; no API key required
	TogetherURL  = "https://api.together.xyz/v1"
	GeminiURL    = "https://generativelanguage.googleapis.com/v1beta/openai/"
	MistralURL   = "https://api.mistral.ai/v1"

	// Environment variable names for API keys
	OpenAIAPIKeyEnv    = "OPENAI_API_KEY"
	FireworksAPIKeyEnv = "FIREWORKS_API_KEY"
	TogetherAPIKeyEnv  = "TOGETHER_API_KEY"
	GeminiAPIKeyEnv    = "GEMINI_API_KEY"
	MistralAPIKeyEnv   = "MISTRAL_API_KEY"
)
36
// Model describes a chat model reachable through an OpenAI-compatible API:
// the provider endpoint to hit, the wire-level model identifier, and where
// to find credentials.
type Model struct {
	UserName         string // provided by the user to identify this model (e.g. "gpt4.1")
	ModelName        string // provided to the service provider to specify which model to use (e.g. "gpt-4.1-2025-04-14")
	URL              string // base URL of the OpenAI-compatible endpoint
	APIKeyEnv        string // environment variable name for the API key
	IsReasoningModel bool   // whether this model is a reasoning model (e.g. O3, O4-mini)
}
44
// Known models. Each entry pairs a user-facing name with the
// provider-specific model identifier, endpoint URL, and API key env var.
var (
	// DefaultModel is used when Service.Model is the zero value.
	DefaultModel = GPT41

	GPT41 = Model{
		UserName:  "gpt4.1",
		ModelName: "gpt-4.1-2025-04-14",
		URL:       OpenAIURL,
		APIKeyEnv: OpenAIAPIKeyEnv,
	}

	GPT4o = Model{
		UserName:  "gpt4o",
		ModelName: "gpt-4o-2024-08-06",
		URL:       OpenAIURL,
		APIKeyEnv: OpenAIAPIKeyEnv,
	}

	GPT4oMini = Model{
		UserName:  "gpt4o-mini",
		ModelName: "gpt-4o-mini-2024-07-18",
		URL:       OpenAIURL,
		APIKeyEnv: OpenAIAPIKeyEnv,
	}

	GPT41Mini = Model{
		UserName:  "gpt4.1-mini",
		ModelName: "gpt-4.1-mini-2025-04-14",
		URL:       OpenAIURL,
		APIKeyEnv: OpenAIAPIKeyEnv,
	}

	GPT41Nano = Model{
		UserName:  "gpt4.1-nano",
		ModelName: "gpt-4.1-nano-2025-04-14",
		URL:       OpenAIURL,
		APIKeyEnv: OpenAIAPIKeyEnv,
	}

	O3 = Model{
		UserName:         "o3",
		ModelName:        "o3-2025-04-16",
		URL:              OpenAIURL,
		APIKeyEnv:        OpenAIAPIKeyEnv,
		IsReasoningModel: true,
	}

	O4Mini = Model{
		UserName:         "o4-mini",
		ModelName:        "o4-mini-2025-04-16",
		URL:              OpenAIURL,
		APIKeyEnv:        OpenAIAPIKeyEnv,
		IsReasoningModel: true,
	}

	Gemini25Flash = Model{
		UserName:  "gemini-flash-2.5",
		ModelName: "gemini-2.5-flash-preview-04-17",
		URL:       GeminiURL,
		APIKeyEnv: GeminiAPIKeyEnv,
	}

	Gemini25Pro = Model{
		UserName:  "gemini-pro-2.5",
		ModelName: "gemini-2.5-pro-preview-03-25",
		URL:       GeminiURL,
		// Gemini 2.5 Pro pricing is tiered by prompt size:
		//   input:  $1.25/M tokens for prompts <= 200k, $2.50/M above
		//   output: $10.00/M tokens for prompts <= 200k, $15.00/M above
		//   cache:  $0.31/M (<=200k) / $0.625/M (>200k), plus $4.50/M tokens per hour
		// The tiering is deliberately not modeled here.
		APIKeyEnv: GeminiAPIKeyEnv,
	}

	TogetherDeepseekV3 = Model{
		UserName:  "together-deepseek-v3",
		ModelName: "deepseek-ai/DeepSeek-V3",
		URL:       TogetherURL,
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherDeepseekR1 = Model{
		UserName:  "together-deepseek-r1",
		ModelName: "deepseek-ai/DeepSeek-R1",
		URL:       TogetherURL,
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherLlama4Maverick = Model{
		UserName:  "together-llama4-maverick",
		ModelName: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
		URL:       TogetherURL,
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	FireworksLlama4Maverick = Model{
		UserName:  "fireworks-llama4-maverick",
		ModelName: "accounts/fireworks/models/llama4-maverick-instruct-basic",
		URL:       FireworksURL,
		APIKeyEnv: FireworksAPIKeyEnv,
	}

	TogetherLlama3_3_70B = Model{
		UserName:  "together-llama3-70b",
		ModelName: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
		URL:       TogetherURL,
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherMistralSmall = Model{
		UserName:  "together-mistral-small",
		ModelName: "mistralai/Mistral-Small-24B-Instruct-2501",
		URL:       TogetherURL,
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherQwen3 = Model{
		UserName:  "together-qwen3",
		ModelName: "Qwen/Qwen3-235B-A22B-fp8-tput",
		URL:       TogetherURL,
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherGemma2 = Model{
		UserName:  "together-gemma2",
		ModelName: "google/gemma-2-27b-it",
		URL:       TogetherURL,
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	// LlamaCPP talks to a local llama.cpp server; no API key env var is set.
	LlamaCPP = Model{
		UserName:  "llama.cpp",
		ModelName: "llama.cpp local model",
		URL:       LlamaCPPURL,
	}

	FireworksDeepseekV3 = Model{
		UserName:  "fireworks-deepseek-v3",
		ModelName: "accounts/fireworks/models/deepseek-v3-0324",
		URL:       FireworksURL,
		APIKeyEnv: FireworksAPIKeyEnv,
	}

	MistralMedium = Model{
		UserName:  "mistral-medium-3",
		ModelName: "mistral-medium-latest",
		URL:       MistralURL,
		APIKeyEnv: MistralAPIKeyEnv,
	}

	DevstralSmall = Model{
		UserName:  "devstral-small",
		ModelName: "devstral-small-latest",
		URL:       MistralURL,
		APIKeyEnv: MistralAPIKeyEnv,
	}
)
203
// Service provides chat completions.
// Fields should not be altered concurrently with calling any method on Service.
type Service struct {
	HTTPC     *http.Client // defaults to http.DefaultClient if nil
	APIKey    string       // optional, if not set will try to load from env var
	Model     Model        // defaults to DefaultModel if zero value
	MaxTokens int          // defaults to DefaultMaxTokens if zero
	Org       string       // optional - organization ID
}

// Compile-time assertion that *Service implements llm.Service.
var _ llm.Service = (*Service)(nil)
215
// ModelsRegistry is a registry of all known models with their user-friendly names.
// ListModels and ModelByUserName iterate over this slice, so a model must
// appear here to be selectable by user name.
var ModelsRegistry = []Model{
	GPT41,
	GPT41Mini,
	GPT41Nano,
	GPT4o,
	GPT4oMini,
	O3,
	O4Mini,
	Gemini25Flash,
	Gemini25Pro,
	TogetherDeepseekV3,
	TogetherDeepseekR1,
	TogetherLlama4Maverick,
	TogetherLlama3_3_70B,
	TogetherMistralSmall,
	TogetherQwen3,
	TogetherGemma2,
	LlamaCPP,
	FireworksDeepseekV3,
	FireworksLlama4Maverick,
	MistralMedium,
	DevstralSmall,
}
240
241// ListModels returns a list of all available models with their user-friendly names.
242func ListModels() []string {
243 var names []string
244 for _, model := range ModelsRegistry {
245 if model.UserName != "" {
246 names = append(names, model.UserName)
247 }
248 }
249 return names
250}
251
252// ModelByUserName returns a model by its user-friendly name.
253// Returns nil if no model with the given name is found.
254func ModelByUserName(name string) *Model {
255 for _, model := range ModelsRegistry {
256 if model.UserName == name {
257 return &model
258 }
259 }
260 return nil
261}
262
// Translation tables between llm package enums and OpenAI wire strings.
var (
	// fromLLMRole maps llm message roles onto OpenAI chat roles.
	fromLLMRole = map[llm.MessageRole]string{
		llm.MessageRoleAssistant: "assistant",
		llm.MessageRoleUser:      "user",
	}
	// fromLLMToolChoiceType maps tool-choice modes onto OpenAI strings.
	fromLLMToolChoiceType = map[llm.ToolChoiceType]string{
		llm.ToolChoiceTypeAuto: "auto",
		llm.ToolChoiceTypeAny:  "any",
		llm.ToolChoiceTypeNone: "none",
		llm.ToolChoiceTypeTool: "function", // OpenAI uses "function" instead of "tool"
	}
	// toLLMRole is the inverse of fromLLMRole.
	toLLMRole = map[string]llm.MessageRole{
		"assistant": llm.MessageRoleAssistant,
		"user":      llm.MessageRoleUser,
	}
	// toLLMStopReason maps OpenAI finish reasons onto llm stop reasons.
	toLLMStopReason = map[string]llm.StopReason{
		"stop":           llm.StopReasonStopSequence,
		"length":         llm.StopReasonMaxTokens,
		"tool_calls":     llm.StopReasonToolUse,
		"function_call":  llm.StopReasonToolUse,      // Map both to ToolUse
		"content_filter": llm.StopReasonStopSequence, // No direct equivalent
	}
)
286
287// fromLLMContent converts llm.Content to the format expected by OpenAI.
288func fromLLMContent(c llm.Content) (string, []openai.ToolCall) {
289 switch c.Type {
290 case llm.ContentTypeText:
291 return c.Text, nil
292 case llm.ContentTypeToolUse:
293 // For OpenAI, tool use is sent as a null content with tool_calls in the message
294 return "", []openai.ToolCall{
295 {
296 Type: openai.ToolTypeFunction,
297 ID: c.ID, // Use the content ID if provided
298 Function: openai.FunctionCall{
299 Name: c.ToolName,
300 Arguments: string(c.ToolInput),
301 },
302 },
303 }
304 case llm.ContentTypeToolResult:
305 // Tool results in OpenAI are sent as a separate message with tool_call_id
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700306 // OpenAI doesn't support multiple content items or images in tool results
307 // Combine all text content into a single string
308 var resultText string
309 if len(c.ToolResult) > 0 {
310 // Collect all text from content objects
311 texts := make([]string, 0, len(c.ToolResult))
312 for _, result := range c.ToolResult {
313 if result.Text != "" {
314 texts = append(texts, result.Text)
315 }
316 }
317 resultText = strings.Join(texts, "\n")
318 }
319 return resultText, nil
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700320 default:
321 // For thinking or other types, convert to text
322 return c.Text, nil
323 }
324}
325
326// fromLLMMessage converts llm.Message to OpenAI ChatCompletionMessage format
327func fromLLMMessage(msg llm.Message) []openai.ChatCompletionMessage {
328 // For OpenAI, we need to handle tool results differently than regular messages
329 // Each tool result becomes its own message with role="tool"
330
331 var messages []openai.ChatCompletionMessage
332
333 // Check if this is a regular message or contains tool results
334 var regularContent []llm.Content
335 var toolResults []llm.Content
336
337 for _, c := range msg.Content {
338 if c.Type == llm.ContentTypeToolResult {
339 toolResults = append(toolResults, c)
340 } else {
341 regularContent = append(regularContent, c)
342 }
343 }
344
345 // Process tool results as separate messages, but first
346 for _, tr := range toolResults {
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700347 // Convert toolresult array to a string for OpenAI
348 var toolResultContent string
349 if len(tr.ToolResult) > 0 {
350 // For now, just use the first text content in the array
351 toolResultContent = tr.ToolResult[0].Text
352 }
353
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700354 m := openai.ChatCompletionMessage{
355 Role: "tool",
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700356 Content: cmp.Or(toolResultContent, " "), // Use empty space if empty to avoid omitempty issues
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700357 ToolCallID: tr.ToolUseID,
358 }
359 messages = append(messages, m)
360 }
361 // Process regular content second
362 if len(regularContent) > 0 {
363 m := openai.ChatCompletionMessage{
364 Role: fromLLMRole[msg.Role],
365 }
366
367 // For assistant messages that contain tool calls
368 var toolCalls []openai.ToolCall
369 var textContent string
370
371 for _, c := range regularContent {
372 content, tools := fromLLMContent(c)
373 if len(tools) > 0 {
374 toolCalls = append(toolCalls, tools...)
375 } else if content != "" {
376 if textContent != "" {
377 textContent += "\n"
378 }
379 textContent += content
380 }
381 }
382
383 m.Content = textContent
384 m.ToolCalls = toolCalls
385
386 messages = append(messages, m)
387 }
388
389 return messages
390}
391
392// fromLLMToolChoice converts llm.ToolChoice to the format expected by OpenAI.
393func fromLLMToolChoice(tc *llm.ToolChoice) any {
394 if tc == nil {
395 return nil
396 }
397
398 if tc.Type == llm.ToolChoiceTypeTool && tc.Name != "" {
399 return openai.ToolChoice{
400 Type: openai.ToolTypeFunction,
401 Function: openai.ToolFunction{
402 Name: tc.Name,
403 },
404 }
405 }
406
407 // For non-specific tool choice, just use the string
408 return fromLLMToolChoiceType[tc.Type]
409}
410
411// fromLLMTool converts llm.Tool to the format expected by OpenAI.
412func fromLLMTool(t *llm.Tool) openai.Tool {
413 return openai.Tool{
414 Type: openai.ToolTypeFunction,
415 Function: &openai.FunctionDefinition{
416 Name: t.Name,
417 Description: t.Description,
418 Parameters: t.InputSchema,
419 },
420 }
421}
422
423// fromLLMSystem converts llm.SystemContent to an OpenAI system message.
424func fromLLMSystem(systemContent []llm.SystemContent) []openai.ChatCompletionMessage {
425 if len(systemContent) == 0 {
426 return nil
427 }
428
429 // Combine all system content into a single system message
430 var systemText string
431 for i, content := range systemContent {
432 if i > 0 && systemText != "" && content.Text != "" {
433 systemText += "\n"
434 }
435 systemText += content.Text
436 }
437
438 if systemText == "" {
439 return nil
440 }
441
442 return []openai.ChatCompletionMessage{
443 {
444 Role: "system",
445 Content: systemText,
446 },
447 }
448}
449
450// toRawLLMContent converts a raw content string from OpenAI to llm.Content.
451func toRawLLMContent(content string) llm.Content {
452 return llm.Content{
453 Type: llm.ContentTypeText,
454 Text: content,
455 }
456}
457
458// toToolCallLLMContent converts a tool call from OpenAI to llm.Content.
459func toToolCallLLMContent(toolCall openai.ToolCall) llm.Content {
460 // Generate a content ID if needed
461 id := toolCall.ID
462 if id == "" {
463 // Create a deterministic ID based on the function name if no ID is provided
464 id = "tc_" + toolCall.Function.Name
465 }
466
467 return llm.Content{
468 ID: id,
469 Type: llm.ContentTypeToolUse,
470 ToolName: toolCall.Function.Name,
471 ToolInput: json.RawMessage(toolCall.Function.Arguments),
472 }
473}
474
475// toToolResultLLMContent converts a tool result message from OpenAI to llm.Content.
476func toToolResultLLMContent(msg openai.ChatCompletionMessage) llm.Content {
477 return llm.Content{
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700478 Type: llm.ContentTypeToolResult,
479 ToolUseID: msg.ToolCallID,
480 ToolResult: []llm.Content{{
481 Type: llm.ContentTypeText,
482 Text: msg.Content,
483 }},
484 ToolError: false, // OpenAI doesn't specify errors explicitly
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700485 }
486}
487
488// toLLMContents converts message content from OpenAI to []llm.Content.
489func toLLMContents(msg openai.ChatCompletionMessage) []llm.Content {
490 var contents []llm.Content
491
492 // If this is a tool response, handle it separately
493 if msg.Role == "tool" && msg.ToolCallID != "" {
494 return []llm.Content{toToolResultLLMContent(msg)}
495 }
496
497 // If there's text content, add it
498 if msg.Content != "" {
499 contents = append(contents, toRawLLMContent(msg.Content))
500 }
501
502 // If there are tool calls, add them
503 for _, tc := range msg.ToolCalls {
504 contents = append(contents, toToolCallLLMContent(tc))
505 }
506
507 // If empty, add an empty text content
508 if len(contents) == 0 {
509 contents = append(contents, llm.Content{
510 Type: llm.ContentTypeText,
511 Text: "",
512 })
513 }
514
515 return contents
516}
517
// toLLMUsage converts usage information from OpenAI to llm.Usage.
// Cost is taken from provider-supplied response headers when available.
func (s *Service) toLLMUsage(au openai.Usage, headers http.Header) llm.Usage {
	// fmt.Printf("raw usage: %+v / %v / %v\n", au, au.PromptTokensDetails, au.CompletionTokensDetails)
	in := uint64(au.PromptTokens)
	// Cached prompt tokens are only reported when PromptTokensDetails is set.
	var inc uint64
	if au.PromptTokensDetails != nil {
		inc = uint64(au.PromptTokensDetails.CachedTokens)
	}
	out := uint64(au.CompletionTokens)
	u := llm.Usage{
		InputTokens:          in,
		CacheReadInputTokens: inc,
		// NOTE(review): CacheCreationInputTokens is set to the full prompt
		// token count; OpenAI does not report cache creation separately.
		// Confirm downstream cost accounting expects this.
		CacheCreationInputTokens: in,
		OutputTokens:             out,
	}
	u.CostUSD = llm.CostUSDFromResponse(headers)
	return u
}
536
537// toLLMResponse converts the OpenAI response to llm.Response.
538func (s *Service) toLLMResponse(r *openai.ChatCompletionResponse) *llm.Response {
539 // fmt.Printf("Raw response\n")
540 // enc := json.NewEncoder(os.Stdout)
541 // enc.SetIndent("", " ")
542 // enc.Encode(r)
543 // fmt.Printf("\n")
544
545 if len(r.Choices) == 0 {
546 return &llm.Response{
547 ID: r.ID,
548 Model: r.Model,
549 Role: llm.MessageRoleAssistant,
Josh Bleecher Snyder59bb27d2025-06-05 07:32:10 -0700550 Usage: s.toLLMUsage(r.Usage, r.Header()),
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700551 }
552 }
553
554 // Process the primary choice
555 choice := r.Choices[0]
556
557 return &llm.Response{
558 ID: r.ID,
559 Model: r.Model,
560 Role: toRoleFromString(choice.Message.Role),
561 Content: toLLMContents(choice.Message),
562 StopReason: toStopReason(string(choice.FinishReason)),
Josh Bleecher Snyder59bb27d2025-06-05 07:32:10 -0700563 Usage: s.toLLMUsage(r.Usage, r.Header()),
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700564 }
565}
566
567// toRoleFromString converts a role string to llm.MessageRole.
568func toRoleFromString(role string) llm.MessageRole {
569 if role == "tool" || role == "system" || role == "function" {
570 return llm.MessageRoleAssistant // Map special roles to assistant for consistency
571 }
572 if mr, ok := toLLMRole[role]; ok {
573 return mr
574 }
575 return llm.MessageRoleUser // Default to user if unknown
576}
577
578// toStopReason converts a finish reason string to llm.StopReason.
579func toStopReason(reason string) llm.StopReason {
580 if sr, ok := toLLMStopReason[reason]; ok {
581 return sr
582 }
583 return llm.StopReasonStopSequence // Default
584}
585
Philip Zeyligerb8a8f352025-06-02 07:39:37 -0700586// TokenContextWindow returns the maximum token context window size for this service
587func (s *Service) TokenContextWindow() int {
588 model := cmp.Or(s.Model, DefaultModel)
589
590 // OpenAI models generally have 128k context windows
591 // Some newer models have larger windows, but 128k is a safe default
592 switch model.ModelName {
593 case "gpt-4.1-2025-04-14", "gpt-4.1-mini-2025-04-14", "gpt-4.1-nano-2025-04-14":
594 return 200000 // 200k for newer GPT-4.1 models
595 case "gpt-4o-2024-08-06", "gpt-4o-mini-2024-07-18":
596 return 128000 // 128k for GPT-4o models
597 case "o3-2025-04-16", "o3-mini-2025-04-16":
598 return 200000 // 200k for O3 models
599 default:
600 // Default for unknown models
601 return 128000
602 }
603}
604
// Do sends a request to OpenAI using the go-openai package.
//
// It converts ir (system prompt, messages, tools, tool choice) into an
// OpenAI chat-completion request, then retries on server errors, rate
// limits, and other API errors with jittered backoff, accumulating all
// attempt errors. Client errors (4xx) and non-API errors abort
// immediately. At most 11 attempts are made.
func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
	// Configure the OpenAI client
	httpc := cmp.Or(s.HTTPC, http.DefaultClient)
	model := cmp.Or(s.Model, DefaultModel)

	// TODO: do this one during Service setup? maybe with a constructor instead?
	config := openai.DefaultConfig(s.APIKey)
	if model.URL != "" {
		config.BaseURL = model.URL
	}
	if s.Org != "" {
		config.OrgID = s.Org
	}
	config.HTTPClient = httpc

	client := openai.NewClientWithConfig(config)

	// Start with system messages if provided
	var allMessages []openai.ChatCompletionMessage
	if len(ir.System) > 0 {
		sysMessages := fromLLMSystem(ir.System)
		allMessages = append(allMessages, sysMessages...)
	}

	// Add regular and tool messages
	for _, msg := range ir.Messages {
		msgs := fromLLMMessage(msg)
		allMessages = append(allMessages, msgs...)
	}

	// Convert tools
	var tools []openai.Tool
	for _, t := range ir.Tools {
		tools = append(tools, fromLLMTool(t))
	}

	// Create the OpenAI request
	req := openai.ChatCompletionRequest{
		Model:      model.ModelName,
		Messages:   allMessages,
		Tools:      tools,
		ToolChoice: fromLLMToolChoice(ir.ToolChoice), // TODO: make fromLLMToolChoice return an error when a perfect translation is not possible
	}
	// Reasoning models take MaxCompletionTokens; the rest take MaxTokens.
	if model.IsReasoningModel {
		req.MaxCompletionTokens = cmp.Or(s.MaxTokens, DefaultMaxTokens)
	} else {
		req.MaxTokens = cmp.Or(s.MaxTokens, DefaultMaxTokens)
	}
	// fmt.Printf("Sending request to OpenAI\n")
	// enc := json.NewEncoder(os.Stdout)
	// enc.SetIndent("", " ")
	// enc.Encode(req)
	// fmt.Printf("\n")

	// Retry mechanism
	backoff := []time.Duration{1 * time.Second, 2 * time.Second, 5 * time.Second, 10 * time.Second, 15 * time.Second}

	// retry loop
	var errs error // accumulated errors across all attempts
	for attempts := 0; ; attempts++ {
		if attempts > 10 {
			return nil, fmt.Errorf("openai request failed after %d attempts: %w", attempts, errs)
		}
		if attempts > 0 {
			// Backoff with up to 1s of jitter.
			// NOTE(review): attempts >= 1 here, so backoff[0] (1s) is never
			// used — the first retry sleeps ~2s. Confirm that's intended.
			// NOTE(review): time.Sleep does not observe ctx cancellation;
			// a canceled ctx only takes effect at the next API call.
			sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
			slog.WarnContext(ctx, "openai request sleep before retry", "sleep", sleep, "attempts", attempts)
			time.Sleep(sleep)
		}

		resp, err := client.CreateChatCompletion(ctx, req)

		// Handle successful response
		if err == nil {
			return s.toLLMResponse(&resp), nil
		}

		// Handle errors
		var apiErr *openai.APIError
		if ok := errors.As(err, &apiErr); !ok {
			// Not an OpenAI API error, return immediately with accumulated errors
			return nil, errors.Join(errs, err)
		}

		switch {
		case apiErr.HTTPStatusCode >= 500:
			// Server error, try again with backoff
			slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode)
			errs = errors.Join(errs, fmt.Errorf("status %d: %s", apiErr.HTTPStatusCode, apiErr.Error()))
			continue

		case apiErr.HTTPStatusCode == 429:
			// Rate limited, accumulate error and retry
			slog.WarnContext(ctx, "openai_request_rate_limited", "error", apiErr.Error())
			errs = errors.Join(errs, fmt.Errorf("status %d (rate limited): %s", apiErr.HTTPStatusCode, apiErr.Error()))
			continue

		case apiErr.HTTPStatusCode >= 400 && apiErr.HTTPStatusCode < 500:
			// Client error, probably unrecoverable
			slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode)
			return nil, errors.Join(errs, fmt.Errorf("status %d: %s", apiErr.HTTPStatusCode, apiErr.Error()))

		default:
			// Other error, accumulate and retry
			slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode)
			errs = errors.Join(errs, fmt.Errorf("status %d: %s", apiErr.HTTPStatusCode, apiErr.Error()))
			continue
		}
	}
}