blob: 40524f3a65838d7ee99cbe2c0f4e792b6814b465 [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package oai
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "math/rand/v2"
11 "net/http"
Philip Zeyliger72252cb2025-05-10 17:00:08 -070012 "strings"
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070013 "time"
14
15 "github.com/sashabaranov/go-openai"
16 "sketch.dev/llm"
17)
18
const (
	// DefaultMaxTokens is the completion token budget used when
	// Service.MaxTokens is zero.
	DefaultMaxTokens = 8192

	// Base URLs for the OpenAI-compatible chat completion endpoints
	// supported by this package.
	OpenAIURL    = "https://api.openai.com/v1"
	FireworksURL = "https://api.fireworks.ai/inference/v1"
	LlamaCPPURL  = "http://localhost:8080/v1"
	TogetherURL  = "https://api.together.xyz/v1"
	GeminiURL    = "https://generativelanguage.googleapis.com/v1beta/openai/"
	MistralURL   = "https://api.mistral.ai/v1"

	// Environment variable names for API keys, one per provider.
	OpenAIAPIKeyEnv    = "OPENAI_API_KEY"
	FireworksAPIKeyEnv = "FIREWORKS_API_KEY"
	TogetherAPIKeyEnv  = "TOGETHER_API_KEY"
	GeminiAPIKeyEnv    = "GEMINI_API_KEY"
	MistralAPIKeyEnv   = "MISTRAL_API_KEY"
)
36
// Model describes a chat completion model reachable through an
// OpenAI-compatible API endpoint, along with its pricing and the
// environment variable holding its API key.
type Model struct {
	UserName         string // provided by the user to identify this model (e.g. "gpt4.1")
	ModelName        string // provided to the service provider to specify which model to use (e.g. "gpt-4.1-2025-04-14")
	URL              string // base URL of the OpenAI-compatible API serving this model
	Cost             ModelCost
	APIKeyEnv        string // environment variable name for the API key
	IsReasoningModel bool   // whether this model is a reasoning model (e.g. O3, O4-mini)
}
45
// ModelCost holds a model's pricing. All values are in cents per
// million tokens; the zero value means free (see LlamaCPP).
type ModelCost struct {
	Input       uint64 // in cents per million tokens
	CachedInput uint64 // in cents per million tokens, for prompt tokens served from cache
	Output      uint64 // in cents per million tokens
}
51
// Known models. Each entry is addressable by its UserName via
// ModelsRegistry / ModelByUserName. Prices are in cents per million
// tokens; see ModelCost.
var (
	// DefaultModel is used when Service.Model is left as the zero value.
	DefaultModel = GPT41

	GPT41 = Model{
		UserName:  "gpt4.1",
		ModelName: "gpt-4.1-2025-04-14",
		URL:       OpenAIURL,
		Cost:      ModelCost{Input: 200, CachedInput: 50, Output: 800},
		APIKeyEnv: OpenAIAPIKeyEnv,
	}

	GPT4o = Model{
		UserName:  "gpt4o",
		ModelName: "gpt-4o-2024-08-06",
		URL:       OpenAIURL,
		Cost:      ModelCost{Input: 250, CachedInput: 125, Output: 1000},
		APIKeyEnv: OpenAIAPIKeyEnv,
	}

	GPT4oMini = Model{
		UserName:  "gpt4o-mini",
		ModelName: "gpt-4o-mini-2024-07-18",
		URL:       OpenAIURL,
		Cost:      ModelCost{Input: 15, CachedInput: 8, Output: 60}, // 8 is actually 7.5 GRRR round up for now oh well
		APIKeyEnv: OpenAIAPIKeyEnv,
	}

	GPT41Mini = Model{
		UserName:  "gpt4.1-mini",
		ModelName: "gpt-4.1-mini-2025-04-14",
		URL:       OpenAIURL,
		Cost:      ModelCost{Input: 40, CachedInput: 10, Output: 160},
		APIKeyEnv: OpenAIAPIKeyEnv,
	}

	GPT41Nano = Model{
		UserName:  "gpt4.1-nano",
		ModelName: "gpt-4.1-nano-2025-04-14",
		URL:       OpenAIURL,
		Cost:      ModelCost{Input: 10, CachedInput: 3, Output: 40}, // 3 is actually 2.5 GRRR round up for now oh well
		APIKeyEnv: OpenAIAPIKeyEnv,
	}

	// O3 and O4Mini are reasoning models: Do sends MaxCompletionTokens
	// instead of MaxTokens for them.
	O3 = Model{
		UserName:         "o3",
		ModelName:        "o3-2025-04-16",
		URL:              OpenAIURL,
		Cost:             ModelCost{Input: 1000, CachedInput: 250, Output: 4000},
		APIKeyEnv:        OpenAIAPIKeyEnv,
		IsReasoningModel: true,
	}

	O4Mini = Model{
		UserName:         "o4-mini",
		ModelName:        "o4-mini-2025-04-16",
		URL:              OpenAIURL,
		Cost:             ModelCost{Input: 110, CachedInput: 28, Output: 440}, // 28 is actually 27.5 GRRR round up for now oh well
		APIKeyEnv:        OpenAIAPIKeyEnv,
		IsReasoningModel: true,
	}

	Gemini25Flash = Model{
		UserName:  "gemini-flash-2.5",
		ModelName: "gemini-2.5-flash-preview-04-17",
		URL:       GeminiURL,
		Cost:      ModelCost{Input: 15, Output: 60},
		APIKeyEnv: GeminiAPIKeyEnv,
	}

	Gemini25Pro = Model{
		UserName:  "gemini-pro-2.5",
		ModelName: "gemini-2.5-pro-preview-03-25",
		URL:       GeminiURL,
		// Gemini's real pricing is tiered by prompt size:
		// Input is: $1.25, prompts <= 200k tokens, $2.50, prompts > 200k tokens
		// Output is: $10.00, prompts <= 200k tokens, $15.00, prompts > 200k
		// Caching is: $0.31, prompts <= 200k tokens, $0.625, prompts > 200k, $4.50 / 1,000,000 tokens per hour
		// ModelCost cannot express tiers, so we use the small-prompt rates.
		Cost:      ModelCost{Input: 125, Output: 1000},
		APIKeyEnv: GeminiAPIKeyEnv,
	}

	TogetherDeepseekV3 = Model{
		UserName:  "together-deepseek-v3",
		ModelName: "deepseek-ai/DeepSeek-V3",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 125, Output: 125},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherLlama4Maverick = Model{
		UserName:  "together-llama4-maverick",
		ModelName: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 27, Output: 85},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	FireworksLlama4Maverick = Model{
		UserName:  "fireworks-llama4-maverick",
		ModelName: "accounts/fireworks/models/llama4-maverick-instruct-basic",
		URL:       FireworksURL,
		Cost:      ModelCost{Input: 22, Output: 88},
		APIKeyEnv: FireworksAPIKeyEnv,
	}

	TogetherLlama3_3_70B = Model{
		UserName:  "together-llama3-70b",
		ModelName: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 88, Output: 88},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherMistralSmall = Model{
		UserName:  "together-mistral-small",
		ModelName: "mistralai/Mistral-Small-24B-Instruct-2501",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 80, Output: 80},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherQwen3 = Model{
		UserName:  "together-qwen3",
		ModelName: "Qwen/Qwen3-235B-A22B-fp8-tput",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 20, Output: 60},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherGemma2 = Model{
		UserName:  "together-gemma2",
		ModelName: "google/gemma-2-27b-it",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 80, Output: 80},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	// LlamaCPP points at a locally hosted llama.cpp server; no API key needed.
	LlamaCPP = Model{
		UserName:  "llama.cpp",
		ModelName: "llama.cpp local model",
		URL:       LlamaCPPURL,
		// zero cost
		Cost: ModelCost{},
	}

	FireworksDeepseekV3 = Model{
		UserName:  "fireworks-deepseek-v3",
		ModelName: "accounts/fireworks/models/deepseek-v3-0324",
		URL:       FireworksURL,
		Cost:      ModelCost{Input: 90, Output: 90}, // not entirely sure about this, they don't list pricing anywhere convenient
		APIKeyEnv: FireworksAPIKeyEnv,
	}

	MistralMedium = Model{
		UserName:  "mistral-medium-3",
		ModelName: "mistral-medium-latest",
		URL:       MistralURL,
		Cost:      ModelCost{Input: 40, Output: 200},
		APIKeyEnv: MistralAPIKeyEnv,
	}

	DevstralSmall = Model{
		UserName:  "devstral-small",
		ModelName: "devstral-small-latest",
		URL:       MistralURL,
		Cost:      ModelCost{Input: 100, Output: 300},
		APIKeyEnv: MistralAPIKeyEnv,
	}
)
224
// Service provides chat completions against any OpenAI-compatible endpoint.
// Fields should not be altered concurrently with calling any method on Service.
type Service struct {
	HTTPC     *http.Client // defaults to http.DefaultClient if nil
	APIKey    string       // optional, if not set will try to load from env var
	Model     Model        // defaults to DefaultModel if zero value
	MaxTokens int          // defaults to DefaultMaxTokens if zero
	Org       string       // optional - organization ID
}

// Compile-time check that Service satisfies llm.Service.
var _ llm.Service = (*Service)(nil)
236
// ModelsRegistry is a registry of all known models with their user-friendly names.
// ListModels and ModelByUserName iterate it in this order.
var ModelsRegistry = []Model{
	GPT41,
	GPT41Mini,
	GPT41Nano,
	GPT4o,
	GPT4oMini,
	O3,
	O4Mini,
	Gemini25Flash,
	Gemini25Pro,
	TogetherDeepseekV3,
	TogetherLlama4Maverick,
	TogetherLlama3_3_70B,
	TogetherMistralSmall,
	TogetherQwen3,
	TogetherGemma2,
	LlamaCPP,
	FireworksDeepseekV3,
	FireworksLlama4Maverick,
	MistralMedium,
	DevstralSmall,
}
260
261// ListModels returns a list of all available models with their user-friendly names.
262func ListModels() []string {
263 var names []string
264 for _, model := range ModelsRegistry {
265 if model.UserName != "" {
266 names = append(names, model.UserName)
267 }
268 }
269 return names
270}
271
272// ModelByUserName returns a model by its user-friendly name.
273// Returns nil if no model with the given name is found.
274func ModelByUserName(name string) *Model {
275 for _, model := range ModelsRegistry {
276 if model.UserName == name {
277 return &model
278 }
279 }
280 return nil
281}
282
// Static translation tables between llm-package enums and the strings
// used by the OpenAI chat completions API, in both directions.
var (
	fromLLMRole = map[llm.MessageRole]string{
		llm.MessageRoleAssistant: "assistant",
		llm.MessageRoleUser:      "user",
	}
	fromLLMContentType = map[llm.ContentType]string{
		llm.ContentTypeText:             "text",
		llm.ContentTypeToolUse:          "function", // OpenAI uses function instead of tool_call
		llm.ContentTypeToolResult:       "tool_result",
		llm.ContentTypeThinking:         "text", // Map thinking to text since OpenAI doesn't have thinking
		llm.ContentTypeRedactedThinking: "text", // Map redacted_thinking to text
	}
	fromLLMToolChoiceType = map[llm.ToolChoiceType]string{
		llm.ToolChoiceTypeAuto: "auto",
		llm.ToolChoiceTypeAny:  "any",
		llm.ToolChoiceTypeNone: "none",
		llm.ToolChoiceTypeTool: "function", // OpenAI uses "function" instead of "tool"
	}
	toLLMRole = map[string]llm.MessageRole{
		"assistant": llm.MessageRoleAssistant,
		"user":      llm.MessageRoleUser,
	}
	toLLMStopReason = map[string]llm.StopReason{
		"stop":           llm.StopReasonStopSequence,
		"length":         llm.StopReasonMaxTokens,
		"tool_calls":     llm.StopReasonToolUse,
		"function_call":  llm.StopReasonToolUse,      // Map both to ToolUse
		"content_filter": llm.StopReasonStopSequence, // No direct equivalent
	}
)
313
314// fromLLMContent converts llm.Content to the format expected by OpenAI.
315func fromLLMContent(c llm.Content) (string, []openai.ToolCall) {
316 switch c.Type {
317 case llm.ContentTypeText:
318 return c.Text, nil
319 case llm.ContentTypeToolUse:
320 // For OpenAI, tool use is sent as a null content with tool_calls in the message
321 return "", []openai.ToolCall{
322 {
323 Type: openai.ToolTypeFunction,
324 ID: c.ID, // Use the content ID if provided
325 Function: openai.FunctionCall{
326 Name: c.ToolName,
327 Arguments: string(c.ToolInput),
328 },
329 },
330 }
331 case llm.ContentTypeToolResult:
332 // Tool results in OpenAI are sent as a separate message with tool_call_id
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700333 // OpenAI doesn't support multiple content items or images in tool results
334 // Combine all text content into a single string
335 var resultText string
336 if len(c.ToolResult) > 0 {
337 // Collect all text from content objects
338 texts := make([]string, 0, len(c.ToolResult))
339 for _, result := range c.ToolResult {
340 if result.Text != "" {
341 texts = append(texts, result.Text)
342 }
343 }
344 resultText = strings.Join(texts, "\n")
345 }
346 return resultText, nil
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700347 default:
348 // For thinking or other types, convert to text
349 return c.Text, nil
350 }
351}
352
353// fromLLMMessage converts llm.Message to OpenAI ChatCompletionMessage format
354func fromLLMMessage(msg llm.Message) []openai.ChatCompletionMessage {
355 // For OpenAI, we need to handle tool results differently than regular messages
356 // Each tool result becomes its own message with role="tool"
357
358 var messages []openai.ChatCompletionMessage
359
360 // Check if this is a regular message or contains tool results
361 var regularContent []llm.Content
362 var toolResults []llm.Content
363
364 for _, c := range msg.Content {
365 if c.Type == llm.ContentTypeToolResult {
366 toolResults = append(toolResults, c)
367 } else {
368 regularContent = append(regularContent, c)
369 }
370 }
371
372 // Process tool results as separate messages, but first
373 for _, tr := range toolResults {
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700374 // Convert toolresult array to a string for OpenAI
375 var toolResultContent string
376 if len(tr.ToolResult) > 0 {
377 // For now, just use the first text content in the array
378 toolResultContent = tr.ToolResult[0].Text
379 }
380
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700381 m := openai.ChatCompletionMessage{
382 Role: "tool",
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700383 Content: cmp.Or(toolResultContent, " "), // Use empty space if empty to avoid omitempty issues
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700384 ToolCallID: tr.ToolUseID,
385 }
386 messages = append(messages, m)
387 }
388 // Process regular content second
389 if len(regularContent) > 0 {
390 m := openai.ChatCompletionMessage{
391 Role: fromLLMRole[msg.Role],
392 }
393
394 // For assistant messages that contain tool calls
395 var toolCalls []openai.ToolCall
396 var textContent string
397
398 for _, c := range regularContent {
399 content, tools := fromLLMContent(c)
400 if len(tools) > 0 {
401 toolCalls = append(toolCalls, tools...)
402 } else if content != "" {
403 if textContent != "" {
404 textContent += "\n"
405 }
406 textContent += content
407 }
408 }
409
410 m.Content = textContent
411 m.ToolCalls = toolCalls
412
413 messages = append(messages, m)
414 }
415
416 return messages
417}
418
419// fromLLMToolChoice converts llm.ToolChoice to the format expected by OpenAI.
420func fromLLMToolChoice(tc *llm.ToolChoice) any {
421 if tc == nil {
422 return nil
423 }
424
425 if tc.Type == llm.ToolChoiceTypeTool && tc.Name != "" {
426 return openai.ToolChoice{
427 Type: openai.ToolTypeFunction,
428 Function: openai.ToolFunction{
429 Name: tc.Name,
430 },
431 }
432 }
433
434 // For non-specific tool choice, just use the string
435 return fromLLMToolChoiceType[tc.Type]
436}
437
438// fromLLMTool converts llm.Tool to the format expected by OpenAI.
439func fromLLMTool(t *llm.Tool) openai.Tool {
440 return openai.Tool{
441 Type: openai.ToolTypeFunction,
442 Function: &openai.FunctionDefinition{
443 Name: t.Name,
444 Description: t.Description,
445 Parameters: t.InputSchema,
446 },
447 }
448}
449
450// fromLLMSystem converts llm.SystemContent to an OpenAI system message.
451func fromLLMSystem(systemContent []llm.SystemContent) []openai.ChatCompletionMessage {
452 if len(systemContent) == 0 {
453 return nil
454 }
455
456 // Combine all system content into a single system message
457 var systemText string
458 for i, content := range systemContent {
459 if i > 0 && systemText != "" && content.Text != "" {
460 systemText += "\n"
461 }
462 systemText += content.Text
463 }
464
465 if systemText == "" {
466 return nil
467 }
468
469 return []openai.ChatCompletionMessage{
470 {
471 Role: "system",
472 Content: systemText,
473 },
474 }
475}
476
477// toRawLLMContent converts a raw content string from OpenAI to llm.Content.
478func toRawLLMContent(content string) llm.Content {
479 return llm.Content{
480 Type: llm.ContentTypeText,
481 Text: content,
482 }
483}
484
485// toToolCallLLMContent converts a tool call from OpenAI to llm.Content.
486func toToolCallLLMContent(toolCall openai.ToolCall) llm.Content {
487 // Generate a content ID if needed
488 id := toolCall.ID
489 if id == "" {
490 // Create a deterministic ID based on the function name if no ID is provided
491 id = "tc_" + toolCall.Function.Name
492 }
493
494 return llm.Content{
495 ID: id,
496 Type: llm.ContentTypeToolUse,
497 ToolName: toolCall.Function.Name,
498 ToolInput: json.RawMessage(toolCall.Function.Arguments),
499 }
500}
501
502// toToolResultLLMContent converts a tool result message from OpenAI to llm.Content.
503func toToolResultLLMContent(msg openai.ChatCompletionMessage) llm.Content {
504 return llm.Content{
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700505 Type: llm.ContentTypeToolResult,
506 ToolUseID: msg.ToolCallID,
507 ToolResult: []llm.Content{{
508 Type: llm.ContentTypeText,
509 Text: msg.Content,
510 }},
511 ToolError: false, // OpenAI doesn't specify errors explicitly
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700512 }
513}
514
515// toLLMContents converts message content from OpenAI to []llm.Content.
516func toLLMContents(msg openai.ChatCompletionMessage) []llm.Content {
517 var contents []llm.Content
518
519 // If this is a tool response, handle it separately
520 if msg.Role == "tool" && msg.ToolCallID != "" {
521 return []llm.Content{toToolResultLLMContent(msg)}
522 }
523
524 // If there's text content, add it
525 if msg.Content != "" {
526 contents = append(contents, toRawLLMContent(msg.Content))
527 }
528
529 // If there are tool calls, add them
530 for _, tc := range msg.ToolCalls {
531 contents = append(contents, toToolCallLLMContent(tc))
532 }
533
534 // If empty, add an empty text content
535 if len(contents) == 0 {
536 contents = append(contents, llm.Content{
537 Type: llm.ContentTypeText,
538 Text: "",
539 })
540 }
541
542 return contents
543}
544
545// toLLMUsage converts usage information from OpenAI to llm.Usage.
Josh Bleecher Snyder66439b02025-05-02 18:35:32 -0700546func (s *Service) toLLMUsage(au openai.Usage) llm.Usage {
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700547 // fmt.Printf("raw usage: %+v / %v / %v\n", au, au.PromptTokensDetails, au.CompletionTokensDetails)
548 in := uint64(au.PromptTokens)
549 var inc uint64
550 if au.PromptTokensDetails != nil {
551 inc = uint64(au.PromptTokensDetails.CachedTokens)
552 }
553 out := uint64(au.CompletionTokens)
554 u := llm.Usage{
555 InputTokens: in,
556 CacheReadInputTokens: inc,
557 CacheCreationInputTokens: in,
558 OutputTokens: out,
559 }
560 u.CostUSD = s.calculateCostFromTokens(u)
561 return u
562}
563
564// toLLMResponse converts the OpenAI response to llm.Response.
565func (s *Service) toLLMResponse(r *openai.ChatCompletionResponse) *llm.Response {
566 // fmt.Printf("Raw response\n")
567 // enc := json.NewEncoder(os.Stdout)
568 // enc.SetIndent("", " ")
569 // enc.Encode(r)
570 // fmt.Printf("\n")
571
572 if len(r.Choices) == 0 {
573 return &llm.Response{
574 ID: r.ID,
575 Model: r.Model,
576 Role: llm.MessageRoleAssistant,
Josh Bleecher Snyder66439b02025-05-02 18:35:32 -0700577 Usage: s.toLLMUsage(r.Usage),
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700578 }
579 }
580
581 // Process the primary choice
582 choice := r.Choices[0]
583
584 return &llm.Response{
585 ID: r.ID,
586 Model: r.Model,
587 Role: toRoleFromString(choice.Message.Role),
588 Content: toLLMContents(choice.Message),
589 StopReason: toStopReason(string(choice.FinishReason)),
Josh Bleecher Snyder66439b02025-05-02 18:35:32 -0700590 Usage: s.toLLMUsage(r.Usage),
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700591 }
592}
593
594// toRoleFromString converts a role string to llm.MessageRole.
595func toRoleFromString(role string) llm.MessageRole {
596 if role == "tool" || role == "system" || role == "function" {
597 return llm.MessageRoleAssistant // Map special roles to assistant for consistency
598 }
599 if mr, ok := toLLMRole[role]; ok {
600 return mr
601 }
602 return llm.MessageRoleUser // Default to user if unknown
603}
604
605// toStopReason converts a finish reason string to llm.StopReason.
606func toStopReason(reason string) llm.StopReason {
607 if sr, ok := toLLMStopReason[reason]; ok {
608 return sr
609 }
610 return llm.StopReasonStopSequence // Default
611}
612
// calculateCostFromTokens calculates the cost in dollars for the given model and token counts.
//
// ModelCost prices are in cents per million tokens, so each token*price
// product below is in "mega-cents"; dividing by 1e6 yields cents, and by
// a further 100 yields dollars.
//
// NOTE(review): this bills CacheCreationInputTokens at the full input rate
// and CacheReadInputTokens at the cached rate. toLLMUsage stores the total
// prompt token count in CacheCreationInputTokens, which would double-count
// cached tokens — confirm which field is meant to hold the non-cached
// remainder.
func (s *Service) calculateCostFromTokens(u llm.Usage) float64 {
	cost := s.Model.Cost

	// TODO: check this for correctness, i am skeptical
	// Calculate cost in cents
	megaCents := u.CacheCreationInputTokens*cost.Input +
		u.CacheReadInputTokens*cost.CachedInput +
		u.OutputTokens*cost.Output

	cents := float64(megaCents) / 1_000_000
	// Convert to dollars
	dollars := cents / 100.0
	// fmt.Printf("in_new=%d, in_cached=%d, out=%d, cost=%.2f\n", u.CacheCreationInputTokens, u.CacheReadInputTokens, u.OutputTokens, dollars)
	return dollars
}
629
630// Do sends a request to OpenAI using the go-openai package.
631func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
632 // Configure the OpenAI client
633 httpc := cmp.Or(s.HTTPC, http.DefaultClient)
634 model := cmp.Or(s.Model, DefaultModel)
635
636 // TODO: do this one during Service setup? maybe with a constructor instead?
637 config := openai.DefaultConfig(s.APIKey)
638 if model.URL != "" {
639 config.BaseURL = model.URL
640 }
641 if s.Org != "" {
642 config.OrgID = s.Org
643 }
644 config.HTTPClient = httpc
645
646 client := openai.NewClientWithConfig(config)
647
648 // Start with system messages if provided
649 var allMessages []openai.ChatCompletionMessage
650 if len(ir.System) > 0 {
651 sysMessages := fromLLMSystem(ir.System)
652 allMessages = append(allMessages, sysMessages...)
653 }
654
655 // Add regular and tool messages
656 for _, msg := range ir.Messages {
657 msgs := fromLLMMessage(msg)
658 allMessages = append(allMessages, msgs...)
659 }
660
661 // Convert tools
662 var tools []openai.Tool
663 for _, t := range ir.Tools {
664 tools = append(tools, fromLLMTool(t))
665 }
666
667 // Create the OpenAI request
668 req := openai.ChatCompletionRequest{
669 Model: model.ModelName,
670 Messages: allMessages,
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700671 Tools: tools,
672 ToolChoice: fromLLMToolChoice(ir.ToolChoice), // TODO: make fromLLMToolChoice return an error when a perfect translation is not possible
673 }
Josh Bleecher Snyder8236cbc2025-05-09 09:57:57 -0700674 if model.IsReasoningModel {
675 req.MaxCompletionTokens = cmp.Or(s.MaxTokens, DefaultMaxTokens)
676 } else {
677 req.MaxTokens = cmp.Or(s.MaxTokens, DefaultMaxTokens)
678 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700679 // fmt.Printf("Sending request to OpenAI\n")
680 // enc := json.NewEncoder(os.Stdout)
681 // enc.SetIndent("", " ")
682 // enc.Encode(req)
683 // fmt.Printf("\n")
684
685 // Retry mechanism
Josh Bleecher Snyder38411992025-05-16 17:51:03 +0000686 backoff := []time.Duration{1 * time.Second, 2 * time.Second, 5 * time.Second, 10 * time.Second, 15 * time.Second}
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700687
688 // retry loop
Josh Bleecher Snyder38411992025-05-16 17:51:03 +0000689 var errs error // accumulated errors across all attempts
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700690 for attempts := 0; ; attempts++ {
Josh Bleecher Snyder38411992025-05-16 17:51:03 +0000691 if attempts > 10 {
692 return nil, fmt.Errorf("openai request failed after %d attempts: %w", attempts, errs)
693 }
694 if attempts > 0 {
695 sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
696 slog.WarnContext(ctx, "openai request sleep before retry", "sleep", sleep, "attempts", attempts)
697 time.Sleep(sleep)
698 }
699
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700700 resp, err := client.CreateChatCompletion(ctx, req)
701
702 // Handle successful response
703 if err == nil {
704 return s.toLLMResponse(&resp), nil
705 }
706
707 // Handle errors
708 var apiErr *openai.APIError
709 if ok := errors.As(err, &apiErr); !ok {
Josh Bleecher Snyder38411992025-05-16 17:51:03 +0000710 // Not an OpenAI API error, return immediately with accumulated errors
711 return nil, errors.Join(errs, err)
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700712 }
713
714 switch {
715 case apiErr.HTTPStatusCode >= 500:
716 // Server error, try again with backoff
Josh Bleecher Snyder38411992025-05-16 17:51:03 +0000717 slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode)
718 errs = errors.Join(errs, fmt.Errorf("status %d: %s", apiErr.HTTPStatusCode, apiErr.Error()))
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700719 continue
720
721 case apiErr.HTTPStatusCode == 429:
Josh Bleecher Snyder38411992025-05-16 17:51:03 +0000722 // Rate limited, accumulate error and retry
723 slog.WarnContext(ctx, "openai_request_rate_limited", "error", apiErr.Error())
724 errs = errors.Join(errs, fmt.Errorf("status %d (rate limited): %s", apiErr.HTTPStatusCode, apiErr.Error()))
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700725 continue
726
Josh Bleecher Snyder38411992025-05-16 17:51:03 +0000727 case apiErr.HTTPStatusCode >= 400 && apiErr.HTTPStatusCode < 500:
728 // Client error, probably unrecoverable
729 slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode)
730 return nil, errors.Join(errs, fmt.Errorf("status %d: %s", apiErr.HTTPStatusCode, apiErr.Error()))
731
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700732 default:
Josh Bleecher Snyder38411992025-05-16 17:51:03 +0000733 // Other error, accumulate and retry
734 slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode)
735 errs = errors.Join(errs, fmt.Errorf("status %d: %s", apiErr.HTTPStatusCode, apiErr.Error()))
736 continue
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700737 }
738 }
739}