blob: 551880680d32f2cbf631dd72ecb229326522d74a [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package oai
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "math/rand/v2"
11 "net/http"
12 "time"
13
14 "github.com/sashabaranov/go-openai"
15 "sketch.dev/llm"
16)
17
const (
	// DefaultMaxTokens is the completion-token budget used when
	// Service.MaxTokens is zero.
	DefaultMaxTokens = 8192

	// Base URLs for the OpenAI-compatible endpoints this package knows about.
	OpenAIURL    = "https://api.openai.com/v1"
	FireworksURL = "https://api.fireworks.ai/inference/v1"
	LlamaCPPURL  = "http://localhost:8080/v1"
	TogetherURL  = "https://api.together.xyz/v1"
	GeminiURL    = "https://generativelanguage.googleapis.com/v1beta/openai/"

	// Environment variable names for API keys
	OpenAIAPIKeyEnv    = "OPENAI_API_KEY"
	FireworksAPIKeyEnv = "FIREWORKS_API_KEY"
	TogetherAPIKeyEnv  = "TOGETHER_API_KEY"
	GeminiAPIKeyEnv    = "GEMINI_API_KEY"
)
33
// Model describes one OpenAI-compatible model: the name users type, the
// name the provider expects, the endpoint to reach it at, its pricing,
// and which environment variable holds its API key.
type Model struct {
	UserName  string // provided by the user to identify this model (e.g. "gpt4.1")
	ModelName string // provided to the service provider to specify which model to use (e.g. "gpt-4.1-2025-04-14")
	URL       string // base URL of the provider's OpenAI-compatible API
	Cost      ModelCost
	APIKeyEnv string // environment variable name for the API key
}

// ModelCost holds a model's token pricing.
// All fields are expressed in cents per million tokens.
type ModelCost struct {
	Input       uint64 // in cents per million tokens
	CachedInput uint64 // in cents per million tokens
	Output      uint64 // in cents per million tokens
}
47
// Known models. Pricing fields are in cents per million tokens (see ModelCost).
var (
	DefaultModel = GPT41

	GPT41 = Model{
		UserName:  "gpt4.1",
		ModelName: "gpt-4.1-2025-04-14",
		URL:       OpenAIURL,
		Cost:      ModelCost{Input: 200, CachedInput: 50, Output: 800},
		APIKeyEnv: OpenAIAPIKeyEnv,
	}

	Gemini25Flash = Model{
		UserName:  "gemini-flash-2.5",
		ModelName: "gemini-2.5-flash-preview-04-17",
		URL:       GeminiURL,
		Cost:      ModelCost{Input: 15, Output: 60},
		APIKeyEnv: GeminiAPIKeyEnv,
	}

	Gemini25Pro = Model{
		UserName:  "gemini-pro-2.5",
		ModelName: "gemini-2.5-pro-preview-03-25",
		URL:       GeminiURL,
		// Gemini 2.5 Pro pricing is tiered by prompt size:
		//   Input:  $1.25/M tokens for prompts <= 200k tokens, $2.50/M above
		//   Output: $10.00/M tokens for prompts <= 200k tokens, $15.00/M above
		//   Caching: $0.31/M (<= 200k), $0.625/M (> 200k), plus $4.50/M tokens per hour of storage
		// We deliberately do not model the tiers or cache-storage charges;
		// the flat small-prompt rate below underestimates large-prompt costs.
		Cost:      ModelCost{Input: 125, Output: 1000},
		APIKeyEnv: GeminiAPIKeyEnv,
	}

	TogetherDeepseekV3 = Model{
		UserName:  "together-deepseek-v3",
		ModelName: "deepseek-ai/DeepSeek-V3",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 125, Output: 125},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherLlama4Maverick = Model{
		UserName:  "together-llama4-maverick",
		ModelName: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 27, Output: 85},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherLlama3_3_70B = Model{
		UserName:  "together-llama3-70b",
		ModelName: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 88, Output: 88},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherMistralSmall = Model{
		UserName:  "together-mistral-small",
		ModelName: "mistralai/Mistral-Small-24B-Instruct-2501",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 80, Output: 80},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	TogetherQwen3 = Model{
		UserName:  "together-qwen3",
		ModelName: "Qwen/Qwen3-235B-A22B-fp8-tput",
		URL:       TogetherURL,
		Cost:      ModelCost{Input: 20, Output: 60},
		APIKeyEnv: TogetherAPIKeyEnv,
	}

	LlamaCPP = Model{
		UserName:  "llama.cpp",
		ModelName: "llama.cpp local model",
		URL:       LlamaCPPURL,
		// zero cost: local inference, and no API key required
		Cost: ModelCost{},
	}

	FireworksDeepseekV3 = Model{
		UserName:  "fireworks-deepseek-v3",
		ModelName: "accounts/fireworks/models/deepseek-v3-0324",
		URL:       FireworksURL,
		Cost:      ModelCost{Input: 90, Output: 90}, // not entirely sure about this, they don't list pricing anywhere convenient
		APIKeyEnv: FireworksAPIKeyEnv,
	}
)
138
// Service provides chat completions via an OpenAI-compatible API.
// Fields should not be altered concurrently with calling any method on Service.
type Service struct {
	HTTPC     *http.Client // defaults to http.DefaultClient if nil
	APIKey    string       // optional, if not set will try to load from env var
	Model     Model        // defaults to DefaultModel if zero value
	MaxTokens int          // defaults to DefaultMaxTokens if zero
	Org       string       // optional - organization ID
}

// Compile-time check that *Service implements llm.Service.
var _ llm.Service = (*Service)(nil)
150
// ModelsRegistry is a registry of all known models with their user-friendly names.
// ListModels and ModelByUserName consult this slice.
var ModelsRegistry = []Model{
	GPT41,
	Gemini25Flash,
	Gemini25Pro,
	TogetherDeepseekV3,
	TogetherLlama4Maverick,
	TogetherLlama3_3_70B,
	TogetherMistralSmall,
	TogetherQwen3,
	LlamaCPP,
	FireworksDeepseekV3,
}
164
165// ListModels returns a list of all available models with their user-friendly names.
166func ListModels() []string {
167 var names []string
168 for _, model := range ModelsRegistry {
169 if model.UserName != "" {
170 names = append(names, model.UserName)
171 }
172 }
173 return names
174}
175
176// ModelByUserName returns a model by its user-friendly name.
177// Returns nil if no model with the given name is found.
178func ModelByUserName(name string) *Model {
179 for _, model := range ModelsRegistry {
180 if model.UserName == name {
181 return &model
182 }
183 }
184 return nil
185}
186
// Translation tables between llm-package enums and the string values
// used by the OpenAI chat-completions API, in both directions.
var (
	fromLLMRole = map[llm.MessageRole]string{
		llm.MessageRoleAssistant: "assistant",
		llm.MessageRoleUser:      "user",
	}
	fromLLMContentType = map[llm.ContentType]string{
		llm.ContentTypeText:             "text",
		llm.ContentTypeToolUse:          "function", // OpenAI uses function instead of tool_call
		llm.ContentTypeToolResult:       "tool_result",
		llm.ContentTypeThinking:         "text", // Map thinking to text since OpenAI doesn't have thinking
		llm.ContentTypeRedactedThinking: "text", // Map redacted_thinking to text
	}
	fromLLMToolChoiceType = map[llm.ToolChoiceType]string{
		llm.ToolChoiceTypeAuto: "auto",
		llm.ToolChoiceTypeAny:  "any",
		llm.ToolChoiceTypeNone: "none",
		llm.ToolChoiceTypeTool: "function", // OpenAI uses "function" instead of "tool"
	}
	toLLMRole = map[string]llm.MessageRole{
		"assistant": llm.MessageRoleAssistant,
		"user":      llm.MessageRoleUser,
	}
	toLLMStopReason = map[string]llm.StopReason{
		"stop":           llm.StopReasonStopSequence,
		"length":         llm.StopReasonMaxTokens,
		"tool_calls":     llm.StopReasonToolUse,
		"function_call":  llm.StopReasonToolUse,      // Map both to ToolUse
		"content_filter": llm.StopReasonStopSequence, // No direct equivalent
	}
)
217
218// fromLLMContent converts llm.Content to the format expected by OpenAI.
219func fromLLMContent(c llm.Content) (string, []openai.ToolCall) {
220 switch c.Type {
221 case llm.ContentTypeText:
222 return c.Text, nil
223 case llm.ContentTypeToolUse:
224 // For OpenAI, tool use is sent as a null content with tool_calls in the message
225 return "", []openai.ToolCall{
226 {
227 Type: openai.ToolTypeFunction,
228 ID: c.ID, // Use the content ID if provided
229 Function: openai.FunctionCall{
230 Name: c.ToolName,
231 Arguments: string(c.ToolInput),
232 },
233 },
234 }
235 case llm.ContentTypeToolResult:
236 // Tool results in OpenAI are sent as a separate message with tool_call_id
237 return c.ToolResult, nil
238 default:
239 // For thinking or other types, convert to text
240 return c.Text, nil
241 }
242}
243
244// fromLLMMessage converts llm.Message to OpenAI ChatCompletionMessage format
245func fromLLMMessage(msg llm.Message) []openai.ChatCompletionMessage {
246 // For OpenAI, we need to handle tool results differently than regular messages
247 // Each tool result becomes its own message with role="tool"
248
249 var messages []openai.ChatCompletionMessage
250
251 // Check if this is a regular message or contains tool results
252 var regularContent []llm.Content
253 var toolResults []llm.Content
254
255 for _, c := range msg.Content {
256 if c.Type == llm.ContentTypeToolResult {
257 toolResults = append(toolResults, c)
258 } else {
259 regularContent = append(regularContent, c)
260 }
261 }
262
263 // Process tool results as separate messages, but first
264 for _, tr := range toolResults {
265 m := openai.ChatCompletionMessage{
266 Role: "tool",
267 Content: cmp.Or(tr.ToolResult, " "), // TODO: remove omitempty upstream
268 ToolCallID: tr.ToolUseID,
269 }
270 messages = append(messages, m)
271 }
272 // Process regular content second
273 if len(regularContent) > 0 {
274 m := openai.ChatCompletionMessage{
275 Role: fromLLMRole[msg.Role],
276 }
277
278 // For assistant messages that contain tool calls
279 var toolCalls []openai.ToolCall
280 var textContent string
281
282 for _, c := range regularContent {
283 content, tools := fromLLMContent(c)
284 if len(tools) > 0 {
285 toolCalls = append(toolCalls, tools...)
286 } else if content != "" {
287 if textContent != "" {
288 textContent += "\n"
289 }
290 textContent += content
291 }
292 }
293
294 m.Content = textContent
295 m.ToolCalls = toolCalls
296
297 messages = append(messages, m)
298 }
299
300 return messages
301}
302
303// fromLLMToolChoice converts llm.ToolChoice to the format expected by OpenAI.
304func fromLLMToolChoice(tc *llm.ToolChoice) any {
305 if tc == nil {
306 return nil
307 }
308
309 if tc.Type == llm.ToolChoiceTypeTool && tc.Name != "" {
310 return openai.ToolChoice{
311 Type: openai.ToolTypeFunction,
312 Function: openai.ToolFunction{
313 Name: tc.Name,
314 },
315 }
316 }
317
318 // For non-specific tool choice, just use the string
319 return fromLLMToolChoiceType[tc.Type]
320}
321
322// fromLLMTool converts llm.Tool to the format expected by OpenAI.
323func fromLLMTool(t *llm.Tool) openai.Tool {
324 return openai.Tool{
325 Type: openai.ToolTypeFunction,
326 Function: &openai.FunctionDefinition{
327 Name: t.Name,
328 Description: t.Description,
329 Parameters: t.InputSchema,
330 },
331 }
332}
333
334// fromLLMSystem converts llm.SystemContent to an OpenAI system message.
335func fromLLMSystem(systemContent []llm.SystemContent) []openai.ChatCompletionMessage {
336 if len(systemContent) == 0 {
337 return nil
338 }
339
340 // Combine all system content into a single system message
341 var systemText string
342 for i, content := range systemContent {
343 if i > 0 && systemText != "" && content.Text != "" {
344 systemText += "\n"
345 }
346 systemText += content.Text
347 }
348
349 if systemText == "" {
350 return nil
351 }
352
353 return []openai.ChatCompletionMessage{
354 {
355 Role: "system",
356 Content: systemText,
357 },
358 }
359}
360
361// toRawLLMContent converts a raw content string from OpenAI to llm.Content.
362func toRawLLMContent(content string) llm.Content {
363 return llm.Content{
364 Type: llm.ContentTypeText,
365 Text: content,
366 }
367}
368
369// toToolCallLLMContent converts a tool call from OpenAI to llm.Content.
370func toToolCallLLMContent(toolCall openai.ToolCall) llm.Content {
371 // Generate a content ID if needed
372 id := toolCall.ID
373 if id == "" {
374 // Create a deterministic ID based on the function name if no ID is provided
375 id = "tc_" + toolCall.Function.Name
376 }
377
378 return llm.Content{
379 ID: id,
380 Type: llm.ContentTypeToolUse,
381 ToolName: toolCall.Function.Name,
382 ToolInput: json.RawMessage(toolCall.Function.Arguments),
383 }
384}
385
386// toToolResultLLMContent converts a tool result message from OpenAI to llm.Content.
387func toToolResultLLMContent(msg openai.ChatCompletionMessage) llm.Content {
388 return llm.Content{
389 Type: llm.ContentTypeToolResult,
390 ToolUseID: msg.ToolCallID,
391 ToolResult: msg.Content,
392 ToolError: false, // OpenAI doesn't specify errors explicitly
393 }
394}
395
396// toLLMContents converts message content from OpenAI to []llm.Content.
397func toLLMContents(msg openai.ChatCompletionMessage) []llm.Content {
398 var contents []llm.Content
399
400 // If this is a tool response, handle it separately
401 if msg.Role == "tool" && msg.ToolCallID != "" {
402 return []llm.Content{toToolResultLLMContent(msg)}
403 }
404
405 // If there's text content, add it
406 if msg.Content != "" {
407 contents = append(contents, toRawLLMContent(msg.Content))
408 }
409
410 // If there are tool calls, add them
411 for _, tc := range msg.ToolCalls {
412 contents = append(contents, toToolCallLLMContent(tc))
413 }
414
415 // If empty, add an empty text content
416 if len(contents) == 0 {
417 contents = append(contents, llm.Content{
418 Type: llm.ContentTypeText,
419 Text: "",
420 })
421 }
422
423 return contents
424}
425
426// toLLMUsage converts usage information from OpenAI to llm.Usage.
Josh Bleecher Snyder66439b02025-05-02 18:35:32 -0700427func (s *Service) toLLMUsage(au openai.Usage) llm.Usage {
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700428 // fmt.Printf("raw usage: %+v / %v / %v\n", au, au.PromptTokensDetails, au.CompletionTokensDetails)
429 in := uint64(au.PromptTokens)
430 var inc uint64
431 if au.PromptTokensDetails != nil {
432 inc = uint64(au.PromptTokensDetails.CachedTokens)
433 }
434 out := uint64(au.CompletionTokens)
435 u := llm.Usage{
436 InputTokens: in,
437 CacheReadInputTokens: inc,
438 CacheCreationInputTokens: in,
439 OutputTokens: out,
440 }
441 u.CostUSD = s.calculateCostFromTokens(u)
442 return u
443}
444
445// toLLMResponse converts the OpenAI response to llm.Response.
446func (s *Service) toLLMResponse(r *openai.ChatCompletionResponse) *llm.Response {
447 // fmt.Printf("Raw response\n")
448 // enc := json.NewEncoder(os.Stdout)
449 // enc.SetIndent("", " ")
450 // enc.Encode(r)
451 // fmt.Printf("\n")
452
453 if len(r.Choices) == 0 {
454 return &llm.Response{
455 ID: r.ID,
456 Model: r.Model,
457 Role: llm.MessageRoleAssistant,
Josh Bleecher Snyder66439b02025-05-02 18:35:32 -0700458 Usage: s.toLLMUsage(r.Usage),
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700459 }
460 }
461
462 // Process the primary choice
463 choice := r.Choices[0]
464
465 return &llm.Response{
466 ID: r.ID,
467 Model: r.Model,
468 Role: toRoleFromString(choice.Message.Role),
469 Content: toLLMContents(choice.Message),
470 StopReason: toStopReason(string(choice.FinishReason)),
Josh Bleecher Snyder66439b02025-05-02 18:35:32 -0700471 Usage: s.toLLMUsage(r.Usage),
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700472 }
473}
474
475// toRoleFromString converts a role string to llm.MessageRole.
476func toRoleFromString(role string) llm.MessageRole {
477 if role == "tool" || role == "system" || role == "function" {
478 return llm.MessageRoleAssistant // Map special roles to assistant for consistency
479 }
480 if mr, ok := toLLMRole[role]; ok {
481 return mr
482 }
483 return llm.MessageRoleUser // Default to user if unknown
484}
485
486// toStopReason converts a finish reason string to llm.StopReason.
487func toStopReason(reason string) llm.StopReason {
488 if sr, ok := toLLMStopReason[reason]; ok {
489 return sr
490 }
491 return llm.StopReasonStopSequence // Default
492}
493
// calculateCostFromTokens calculates the cost in dollars for the given model and token counts.
//
// ModelCost fields are cents per million tokens, so the sum below is in
// "cents * 1e6" units; dividing by 1e6 yields cents, and by a further
// 100 yields dollars.
//
// NOTE(review): CacheCreationInputTokens is billed at the full input rate
// and CacheReadInputTokens at the cached rate. Whether this is correct
// depends on how llm.Usage is populated upstream — confirm that cached
// tokens are not also included in CacheCreationInputTokens, or they are
// billed twice.
func (s *Service) calculateCostFromTokens(u llm.Usage) float64 {
	cost := s.Model.Cost

	// TODO: check this for correctness, i am skeptical
	// Calculate cost in cents
	megaCents := u.CacheCreationInputTokens*cost.Input +
		u.CacheReadInputTokens*cost.CachedInput +
		u.OutputTokens*cost.Output

	cents := float64(megaCents) / 1_000_000
	// Convert to dollars
	dollars := cents / 100.0
	// fmt.Printf("in_new=%d, in_cached=%d, out=%d, cost=%.2f\n", u.CacheCreationInputTokens, u.CacheReadInputTokens, u.OutputTokens, dollars)
	return dollars
}
510
511// Do sends a request to OpenAI using the go-openai package.
512func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
513 // Configure the OpenAI client
514 httpc := cmp.Or(s.HTTPC, http.DefaultClient)
515 model := cmp.Or(s.Model, DefaultModel)
516
517 // TODO: do this one during Service setup? maybe with a constructor instead?
518 config := openai.DefaultConfig(s.APIKey)
519 if model.URL != "" {
520 config.BaseURL = model.URL
521 }
522 if s.Org != "" {
523 config.OrgID = s.Org
524 }
525 config.HTTPClient = httpc
526
527 client := openai.NewClientWithConfig(config)
528
529 // Start with system messages if provided
530 var allMessages []openai.ChatCompletionMessage
531 if len(ir.System) > 0 {
532 sysMessages := fromLLMSystem(ir.System)
533 allMessages = append(allMessages, sysMessages...)
534 }
535
536 // Add regular and tool messages
537 for _, msg := range ir.Messages {
538 msgs := fromLLMMessage(msg)
539 allMessages = append(allMessages, msgs...)
540 }
541
542 // Convert tools
543 var tools []openai.Tool
544 for _, t := range ir.Tools {
545 tools = append(tools, fromLLMTool(t))
546 }
547
548 // Create the OpenAI request
549 req := openai.ChatCompletionRequest{
550 Model: model.ModelName,
551 Messages: allMessages,
552 MaxTokens: cmp.Or(s.MaxTokens, DefaultMaxTokens),
553 Tools: tools,
554 ToolChoice: fromLLMToolChoice(ir.ToolChoice), // TODO: make fromLLMToolChoice return an error when a perfect translation is not possible
555 }
556 // fmt.Printf("Sending request to OpenAI\n")
557 // enc := json.NewEncoder(os.Stdout)
558 // enc.SetIndent("", " ")
559 // enc.Encode(req)
560 // fmt.Printf("\n")
561
562 // Retry mechanism
563 backoff := []time.Duration{1 * time.Second, 2 * time.Second, 5 * time.Second}
564
565 // retry loop
566 for attempts := 0; ; attempts++ {
567 resp, err := client.CreateChatCompletion(ctx, req)
568
569 // Handle successful response
570 if err == nil {
571 return s.toLLMResponse(&resp), nil
572 }
573
574 // Handle errors
575 var apiErr *openai.APIError
576 if ok := errors.As(err, &apiErr); !ok {
577 // Not an OpenAI API error, return immediately
578 return nil, err
579 }
580
581 switch {
582 case apiErr.HTTPStatusCode >= 500:
583 // Server error, try again with backoff
584 sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
585 slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode, "sleep", sleep)
586 time.Sleep(sleep)
587 continue
588
589 case apiErr.HTTPStatusCode == 429:
590 // Rate limited, back off longer
591 sleep := 20*time.Second + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
592 slog.WarnContext(ctx, "openai_request_rate_limited", "error", apiErr.Error(), "sleep", sleep)
593 time.Sleep(sleep)
594 continue
595
596 default:
597 // Other error, return immediately
598 return nil, fmt.Errorf("OpenAI API error: %w", err)
599 }
600 }
601}