blob: fe395e69c313943e3497942664b53af5cc4e4609 [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package oai
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "math/rand/v2"
11 "net/http"
12 "time"
13
14 "github.com/sashabaranov/go-openai"
15 "sketch.dev/llm"
16)
17
// Defaults and provider endpoints for OpenAI-compatible chat APIs.
const (
	// DefaultMaxTokens is the completion token limit used when
	// Service.MaxTokens is zero.
	DefaultMaxTokens = 8192

	// Base URLs for the supported OpenAI-compatible providers.
	OpenAIURL    = "https://api.openai.com/v1"
	FireworksURL = "https://api.fireworks.ai/inference/v1"
	LlamaCPPURL  = "http://localhost:8080/v1" // local llama.cpp server
	TogetherURL  = "https://api.together.xyz/v1"
	GeminiURL    = "https://generativelanguage.googleapis.com/v1beta/openai/"
	MistralURL   = "https://api.mistral.ai/v1"

	// Environment variable names for API keys.
	OpenAIAPIKeyEnv    = "OPENAI_API_KEY"
	FireworksAPIKeyEnv = "FIREWORKS_API_KEY"
	TogetherAPIKeyEnv  = "TOGETHER_API_KEY"
	GeminiAPIKeyEnv    = "GEMINI_API_KEY"
	MistralAPIKeyEnv   = "MISTRAL_API_KEY"
)
35
// Model describes a chat model offered by one of the supported
// OpenAI-compatible providers: how users refer to it, how the provider
// refers to it, where to reach it, and what it costs.
type Model struct {
	UserName  string    // provided by the user to identify this model (e.g. "gpt4.1")
	ModelName string    // provided to the service provider to specify which model to use (e.g. "gpt-4.1-2025-04-14")
	URL       string    // base URL of the provider's OpenAI-compatible API
	Cost      ModelCost // per-token pricing, used to compute llm.Usage.CostUSD
	APIKeyEnv string    // environment variable name for the API key
}
43
// ModelCost holds a model's pricing, expressed in cents per million tokens.
type ModelCost struct {
	Input       uint64 // in cents per million tokens
	CachedInput uint64 // in cents per million tokens, for prompt tokens served from the provider's cache
	Output      uint64 // in cents per million tokens
}
49
// All known models. New entries should also be added to ModelsRegistry so
// that ListModels and ModelByUserName can find them.
var (
	// DefaultModel is used when Service.Model is the zero value.
	DefaultModel = GPT41

	GPT41 = Model{
		UserName:  "gpt4.1",
		ModelName: "gpt-4.1-2025-04-14",
		URL:       OpenAIURL,
		Cost:      ModelCost{Input: 200, CachedInput: 50, Output: 800},
		APIKeyEnv: OpenAIAPIKeyEnv,
	}

	Gemini25Flash = Model{
		UserName:  "gemini-flash-2.5",
		ModelName: "gemini-2.5-flash-preview-04-17",
		URL:       GeminiURL,
		Cost:      ModelCost{Input: 15, Output: 60},
		APIKeyEnv: GeminiAPIKeyEnv,
	}

	Gemini25Pro = Model{
		UserName:  "gemini-pro-2.5",
		ModelName: "gemini-2.5-pro-preview-03-25",
		URL:       GeminiURL,
		// Gemini 2.5 Pro pricing is tiered by prompt size:
		//   input:  $1.25/M tokens for prompts <= 200k tokens, $2.50/M above
		//   output: $10.00/M for prompts <= 200k tokens, $15.00/M above
		//   caching: $0.31/M (<= 200k) or $0.625/M, plus $4.50/M tokens per hour
		// The flat values below use the <= 200k tier and ignore caching, so
		// cost estimates run low for prompts over 200k tokens.
		Cost:      ModelCost{Input: 125, Output: 1000},
		APIKeyEnv: GeminiAPIKeyEnv,
	}

	TogetherDeepseekV3 = Model{
		UserName:  "together-deepseek-v3",
		ModelName: "deepseek-ai/DeepSeek-V3",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 125, Output: 125},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherLlama4Maverick = Model{
		UserName:  "together-llama4-maverick",
		ModelName: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 27, Output: 85},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherLlama3_3_70B = Model{
		UserName:  "together-llama3-70b",
		ModelName: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 88, Output: 88},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherMistralSmall = Model{
		UserName:  "together-mistral-small",
		ModelName: "mistralai/Mistral-Small-24B-Instruct-2501",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 80, Output: 80},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherQwen3 = Model{
		UserName:  "together-qwen3",
		ModelName: "Qwen/Qwen3-235B-A22B-fp8-tput",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 20, Output: 60},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	// LlamaCPP targets a locally running llama.cpp server; no API key and
	// zero cost.
	LlamaCPP = Model{
		UserName:  "llama.cpp",
		ModelName: "llama.cpp local model",
		URL:       LlamaCPPURL,
		// zero cost
		Cost: ModelCost{},
	}

	FireworksDeepseekV3 = Model{
		UserName:  "fireworks-deepseek-v3",
		ModelName: "accounts/fireworks/models/deepseek-v3-0324",
		URL:       FireworksURL,
		Cost:      ModelCost{Input: 90, Output: 90}, // not entirely sure about this, they don't list pricing anywhere convenient
		APIKeyEnv: FireworksAPIKeyEnv,
	}

	MistralMedium = Model{
		UserName:  "mistral-medium-3",
		ModelName: "mistral-medium-latest",
		URL:       MistralURL,
		Cost:      ModelCost{Input: 40, Output: 200},
		APIKeyEnv: MistralAPIKeyEnv,
	}
)
148
// Service provides chat completions via an OpenAI-compatible API.
// Fields should not be altered concurrently with calling any method on Service.
type Service struct {
	HTTPC *http.Client // defaults to http.DefaultClient if nil
	// APIKey is optional; if not set, it will try to load from env var.
	// NOTE(review): the env-var lookup does not happen in this file —
	// presumably a caller resolves Model.APIKeyEnv before use; confirm.
	APIKey    string
	Model     Model  // defaults to DefaultModel if zero value
	MaxTokens int    // defaults to DefaultMaxTokens if zero
	Org       string // optional - organization ID
}
158
// Compile-time check that *Service implements llm.Service.
var _ llm.Service = (*Service)(nil)
160
161// ModelsRegistry is a registry of all known models with their user-friendly names.
162var ModelsRegistry = []Model{
163 GPT41,
164 Gemini25Flash,
165 Gemini25Pro,
166 TogetherDeepseekV3,
167 TogetherLlama4Maverick,
168 TogetherLlama3_3_70B,
169 TogetherMistralSmall,
Josh Bleecher Snyder3e213082025-05-02 13:22:02 -0700170 TogetherQwen3,
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700171 LlamaCPP,
172 FireworksDeepseekV3,
173}
174
175// ListModels returns a list of all available models with their user-friendly names.
176func ListModels() []string {
177 var names []string
178 for _, model := range ModelsRegistry {
179 if model.UserName != "" {
180 names = append(names, model.UserName)
181 }
182 }
183 return names
184}
185
186// ModelByUserName returns a model by its user-friendly name.
187// Returns nil if no model with the given name is found.
188func ModelByUserName(name string) *Model {
189 for _, model := range ModelsRegistry {
190 if model.UserName == name {
191 return &model
192 }
193 }
194 return nil
195}
196
// Lookup tables translating between llm's vocabulary and OpenAI's wire
// strings for roles, content types, tool-choice modes, and stop reasons.
var (
	fromLLMRole = map[llm.MessageRole]string{
		llm.MessageRoleAssistant: "assistant",
		llm.MessageRoleUser:      "user",
	}
	fromLLMContentType = map[llm.ContentType]string{
		llm.ContentTypeText:             "text",
		llm.ContentTypeToolUse:          "function", // OpenAI uses function instead of tool_call
		llm.ContentTypeToolResult:       "tool_result",
		llm.ContentTypeThinking:         "text", // map thinking to text since OpenAI doesn't have thinking
		llm.ContentTypeRedactedThinking: "text", // map redacted_thinking to text
	}
	fromLLMToolChoiceType = map[llm.ToolChoiceType]string{
		llm.ToolChoiceTypeAuto: "auto",
		llm.ToolChoiceTypeAny:  "any",
		llm.ToolChoiceTypeNone: "none",
		llm.ToolChoiceTypeTool: "function", // OpenAI uses "function" instead of "tool"
	}
	toLLMRole = map[string]llm.MessageRole{
		"assistant": llm.MessageRoleAssistant,
		"user":      llm.MessageRoleUser,
	}
	toLLMStopReason = map[string]llm.StopReason{
		"stop":           llm.StopReasonStopSequence,
		"length":         llm.StopReasonMaxTokens,
		"tool_calls":     llm.StopReasonToolUse,
		"function_call":  llm.StopReasonToolUse,      // map both to ToolUse
		"content_filter": llm.StopReasonStopSequence, // no direct equivalent
	}
)
227
228// fromLLMContent converts llm.Content to the format expected by OpenAI.
229func fromLLMContent(c llm.Content) (string, []openai.ToolCall) {
230 switch c.Type {
231 case llm.ContentTypeText:
232 return c.Text, nil
233 case llm.ContentTypeToolUse:
234 // For OpenAI, tool use is sent as a null content with tool_calls in the message
235 return "", []openai.ToolCall{
236 {
237 Type: openai.ToolTypeFunction,
238 ID: c.ID, // Use the content ID if provided
239 Function: openai.FunctionCall{
240 Name: c.ToolName,
241 Arguments: string(c.ToolInput),
242 },
243 },
244 }
245 case llm.ContentTypeToolResult:
246 // Tool results in OpenAI are sent as a separate message with tool_call_id
247 return c.ToolResult, nil
248 default:
249 // For thinking or other types, convert to text
250 return c.Text, nil
251 }
252}
253
254// fromLLMMessage converts llm.Message to OpenAI ChatCompletionMessage format
255func fromLLMMessage(msg llm.Message) []openai.ChatCompletionMessage {
256 // For OpenAI, we need to handle tool results differently than regular messages
257 // Each tool result becomes its own message with role="tool"
258
259 var messages []openai.ChatCompletionMessage
260
261 // Check if this is a regular message or contains tool results
262 var regularContent []llm.Content
263 var toolResults []llm.Content
264
265 for _, c := range msg.Content {
266 if c.Type == llm.ContentTypeToolResult {
267 toolResults = append(toolResults, c)
268 } else {
269 regularContent = append(regularContent, c)
270 }
271 }
272
273 // Process tool results as separate messages, but first
274 for _, tr := range toolResults {
275 m := openai.ChatCompletionMessage{
276 Role: "tool",
277 Content: cmp.Or(tr.ToolResult, " "), // TODO: remove omitempty upstream
278 ToolCallID: tr.ToolUseID,
279 }
280 messages = append(messages, m)
281 }
282 // Process regular content second
283 if len(regularContent) > 0 {
284 m := openai.ChatCompletionMessage{
285 Role: fromLLMRole[msg.Role],
286 }
287
288 // For assistant messages that contain tool calls
289 var toolCalls []openai.ToolCall
290 var textContent string
291
292 for _, c := range regularContent {
293 content, tools := fromLLMContent(c)
294 if len(tools) > 0 {
295 toolCalls = append(toolCalls, tools...)
296 } else if content != "" {
297 if textContent != "" {
298 textContent += "\n"
299 }
300 textContent += content
301 }
302 }
303
304 m.Content = textContent
305 m.ToolCalls = toolCalls
306
307 messages = append(messages, m)
308 }
309
310 return messages
311}
312
313// fromLLMToolChoice converts llm.ToolChoice to the format expected by OpenAI.
314func fromLLMToolChoice(tc *llm.ToolChoice) any {
315 if tc == nil {
316 return nil
317 }
318
319 if tc.Type == llm.ToolChoiceTypeTool && tc.Name != "" {
320 return openai.ToolChoice{
321 Type: openai.ToolTypeFunction,
322 Function: openai.ToolFunction{
323 Name: tc.Name,
324 },
325 }
326 }
327
328 // For non-specific tool choice, just use the string
329 return fromLLMToolChoiceType[tc.Type]
330}
331
332// fromLLMTool converts llm.Tool to the format expected by OpenAI.
333func fromLLMTool(t *llm.Tool) openai.Tool {
334 return openai.Tool{
335 Type: openai.ToolTypeFunction,
336 Function: &openai.FunctionDefinition{
337 Name: t.Name,
338 Description: t.Description,
339 Parameters: t.InputSchema,
340 },
341 }
342}
343
344// fromLLMSystem converts llm.SystemContent to an OpenAI system message.
345func fromLLMSystem(systemContent []llm.SystemContent) []openai.ChatCompletionMessage {
346 if len(systemContent) == 0 {
347 return nil
348 }
349
350 // Combine all system content into a single system message
351 var systemText string
352 for i, content := range systemContent {
353 if i > 0 && systemText != "" && content.Text != "" {
354 systemText += "\n"
355 }
356 systemText += content.Text
357 }
358
359 if systemText == "" {
360 return nil
361 }
362
363 return []openai.ChatCompletionMessage{
364 {
365 Role: "system",
366 Content: systemText,
367 },
368 }
369}
370
371// toRawLLMContent converts a raw content string from OpenAI to llm.Content.
372func toRawLLMContent(content string) llm.Content {
373 return llm.Content{
374 Type: llm.ContentTypeText,
375 Text: content,
376 }
377}
378
379// toToolCallLLMContent converts a tool call from OpenAI to llm.Content.
380func toToolCallLLMContent(toolCall openai.ToolCall) llm.Content {
381 // Generate a content ID if needed
382 id := toolCall.ID
383 if id == "" {
384 // Create a deterministic ID based on the function name if no ID is provided
385 id = "tc_" + toolCall.Function.Name
386 }
387
388 return llm.Content{
389 ID: id,
390 Type: llm.ContentTypeToolUse,
391 ToolName: toolCall.Function.Name,
392 ToolInput: json.RawMessage(toolCall.Function.Arguments),
393 }
394}
395
396// toToolResultLLMContent converts a tool result message from OpenAI to llm.Content.
397func toToolResultLLMContent(msg openai.ChatCompletionMessage) llm.Content {
398 return llm.Content{
399 Type: llm.ContentTypeToolResult,
400 ToolUseID: msg.ToolCallID,
401 ToolResult: msg.Content,
402 ToolError: false, // OpenAI doesn't specify errors explicitly
403 }
404}
405
406// toLLMContents converts message content from OpenAI to []llm.Content.
407func toLLMContents(msg openai.ChatCompletionMessage) []llm.Content {
408 var contents []llm.Content
409
410 // If this is a tool response, handle it separately
411 if msg.Role == "tool" && msg.ToolCallID != "" {
412 return []llm.Content{toToolResultLLMContent(msg)}
413 }
414
415 // If there's text content, add it
416 if msg.Content != "" {
417 contents = append(contents, toRawLLMContent(msg.Content))
418 }
419
420 // If there are tool calls, add them
421 for _, tc := range msg.ToolCalls {
422 contents = append(contents, toToolCallLLMContent(tc))
423 }
424
425 // If empty, add an empty text content
426 if len(contents) == 0 {
427 contents = append(contents, llm.Content{
428 Type: llm.ContentTypeText,
429 Text: "",
430 })
431 }
432
433 return contents
434}
435
// toLLMUsage converts usage information from OpenAI to llm.Usage and
// computes its dollar cost via calculateCostFromTokens.
//
// NOTE(review): CacheCreationInputTokens is set to the *total* prompt token
// count (this OpenAI response shape has no separate cache-write figure),
// with the cached subset duplicated in CacheReadInputTokens. Cost math must
// account for that overlap — confirm downstream consumers of llm.Usage
// expect this convention as well.
func (s *Service) toLLMUsage(au openai.Usage) llm.Usage {
	// fmt.Printf("raw usage: %+v / %v / %v\n", au, au.PromptTokensDetails, au.CompletionTokensDetails)
	in := uint64(au.PromptTokens)
	// Cached-token details may be absent; treat that as zero cached tokens.
	var inc uint64
	if au.PromptTokensDetails != nil {
		inc = uint64(au.PromptTokensDetails.CachedTokens)
	}
	out := uint64(au.CompletionTokens)
	u := llm.Usage{
		InputTokens:              in,
		CacheReadInputTokens:     inc,
		CacheCreationInputTokens: in,
		OutputTokens:             out,
	}
	u.CostUSD = s.calculateCostFromTokens(u)
	return u
}
454
455// toLLMResponse converts the OpenAI response to llm.Response.
456func (s *Service) toLLMResponse(r *openai.ChatCompletionResponse) *llm.Response {
457 // fmt.Printf("Raw response\n")
458 // enc := json.NewEncoder(os.Stdout)
459 // enc.SetIndent("", " ")
460 // enc.Encode(r)
461 // fmt.Printf("\n")
462
463 if len(r.Choices) == 0 {
464 return &llm.Response{
465 ID: r.ID,
466 Model: r.Model,
467 Role: llm.MessageRoleAssistant,
Josh Bleecher Snyder66439b02025-05-02 18:35:32 -0700468 Usage: s.toLLMUsage(r.Usage),
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700469 }
470 }
471
472 // Process the primary choice
473 choice := r.Choices[0]
474
475 return &llm.Response{
476 ID: r.ID,
477 Model: r.Model,
478 Role: toRoleFromString(choice.Message.Role),
479 Content: toLLMContents(choice.Message),
480 StopReason: toStopReason(string(choice.FinishReason)),
Josh Bleecher Snyder66439b02025-05-02 18:35:32 -0700481 Usage: s.toLLMUsage(r.Usage),
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700482 }
483}
484
485// toRoleFromString converts a role string to llm.MessageRole.
486func toRoleFromString(role string) llm.MessageRole {
487 if role == "tool" || role == "system" || role == "function" {
488 return llm.MessageRoleAssistant // Map special roles to assistant for consistency
489 }
490 if mr, ok := toLLMRole[role]; ok {
491 return mr
492 }
493 return llm.MessageRoleUser // Default to user if unknown
494}
495
496// toStopReason converts a finish reason string to llm.StopReason.
497func toStopReason(reason string) llm.StopReason {
498 if sr, ok := toLLMStopReason[reason]; ok {
499 return sr
500 }
501 return llm.StopReasonStopSequence // Default
502}
503
504// calculateCostFromTokens calculates the cost in dollars for the given model and token counts.
505func (s *Service) calculateCostFromTokens(u llm.Usage) float64 {
506 cost := s.Model.Cost
507
508 // TODO: check this for correctness, i am skeptical
509 // Calculate cost in cents
510 megaCents := u.CacheCreationInputTokens*cost.Input +
511 u.CacheReadInputTokens*cost.CachedInput +
512 u.OutputTokens*cost.Output
513
514 cents := float64(megaCents) / 1_000_000
515 // Convert to dollars
516 dollars := cents / 100.0
517 // fmt.Printf("in_new=%d, in_cached=%d, out=%d, cost=%.2f\n", u.CacheCreationInputTokens, u.CacheReadInputTokens, u.OutputTokens, dollars)
518 return dollars
519}
520
521// Do sends a request to OpenAI using the go-openai package.
522func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
523 // Configure the OpenAI client
524 httpc := cmp.Or(s.HTTPC, http.DefaultClient)
525 model := cmp.Or(s.Model, DefaultModel)
526
527 // TODO: do this one during Service setup? maybe with a constructor instead?
528 config := openai.DefaultConfig(s.APIKey)
529 if model.URL != "" {
530 config.BaseURL = model.URL
531 }
532 if s.Org != "" {
533 config.OrgID = s.Org
534 }
535 config.HTTPClient = httpc
536
537 client := openai.NewClientWithConfig(config)
538
539 // Start with system messages if provided
540 var allMessages []openai.ChatCompletionMessage
541 if len(ir.System) > 0 {
542 sysMessages := fromLLMSystem(ir.System)
543 allMessages = append(allMessages, sysMessages...)
544 }
545
546 // Add regular and tool messages
547 for _, msg := range ir.Messages {
548 msgs := fromLLMMessage(msg)
549 allMessages = append(allMessages, msgs...)
550 }
551
552 // Convert tools
553 var tools []openai.Tool
554 for _, t := range ir.Tools {
555 tools = append(tools, fromLLMTool(t))
556 }
557
558 // Create the OpenAI request
559 req := openai.ChatCompletionRequest{
560 Model: model.ModelName,
561 Messages: allMessages,
562 MaxTokens: cmp.Or(s.MaxTokens, DefaultMaxTokens),
563 Tools: tools,
564 ToolChoice: fromLLMToolChoice(ir.ToolChoice), // TODO: make fromLLMToolChoice return an error when a perfect translation is not possible
565 }
566 // fmt.Printf("Sending request to OpenAI\n")
567 // enc := json.NewEncoder(os.Stdout)
568 // enc.SetIndent("", " ")
569 // enc.Encode(req)
570 // fmt.Printf("\n")
571
572 // Retry mechanism
573 backoff := []time.Duration{1 * time.Second, 2 * time.Second, 5 * time.Second}
574
575 // retry loop
576 for attempts := 0; ; attempts++ {
577 resp, err := client.CreateChatCompletion(ctx, req)
578
579 // Handle successful response
580 if err == nil {
581 return s.toLLMResponse(&resp), nil
582 }
583
584 // Handle errors
585 var apiErr *openai.APIError
586 if ok := errors.As(err, &apiErr); !ok {
587 // Not an OpenAI API error, return immediately
588 return nil, err
589 }
590
591 switch {
592 case apiErr.HTTPStatusCode >= 500:
593 // Server error, try again with backoff
594 sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
595 slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode, "sleep", sleep)
596 time.Sleep(sleep)
597 continue
598
599 case apiErr.HTTPStatusCode == 429:
600 // Rate limited, back off longer
601 sleep := 20*time.Second + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
602 slog.WarnContext(ctx, "openai_request_rate_limited", "error", apiErr.Error(), "sleep", sleep)
603 time.Sleep(sleep)
604 continue
605
606 default:
607 // Other error, return immediately
608 return nil, fmt.Errorf("OpenAI API error: %w", err)
609 }
610 }
611}