blob: b4db54554d28a68a2e9bde79652d486e2ea36c2b [file] [log] [blame]
package llm
import (
"context"
"time"
)
// LLMProvider defines the interface that all LLM providers must implement.
// All methods accept a context.Context for cancellation and deadlines.
type LLMProvider interface {
	// ChatCompletion creates a chat completion for the given request.
	ChatCompletion(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionResponse, error)
	// CreateEmbeddings generates embeddings for the given input.
	CreateEmbeddings(ctx context.Context, req EmbeddingRequest) (*EmbeddingResponse, error)
	// Close performs any necessary cleanup; call it when the provider is no longer needed.
	Close() error
}
// Provider represents different LLM service providers.
type Provider string

// Known provider identifiers. These values are also the keys of DefaultConfigs.
const (
	ProviderOpenAI Provider = "openai"
	ProviderXAI    Provider = "xai"
	ProviderClaude Provider = "claude"
	ProviderGemini Provider = "gemini"
	ProviderLocal  Provider = "local" // local HTTP server (DefaultConfigs points at localhost:11434)
	ProviderFake   Provider = "fake"  // test double; see DefaultConfigs[ProviderFake]
)
// Role represents the role of a message participant.
type Role string

// Message roles, mirroring the OpenAI-style chat schema.
const (
	RoleSystem    Role = "system"    // system / instruction message
	RoleUser      Role = "user"      // end-user message
	RoleAssistant Role = "assistant" // model-generated message
	RoleTool      Role = "tool"      // tool result; pair with Message.ToolCallID
)
// Message represents a single message in a conversation.
type Message struct {
	Role    Role   `json:"role"`    // author of the message (system/user/assistant/tool)
	Content string `json:"content"` // message text
	// ToolCalls is set on assistant messages that request tool invocations.
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
	// ToolCallID links a RoleTool message back to the ToolCall.ID it answers.
	ToolCallID string `json:"tool_call_id,omitempty"`
	// Name optionally identifies the message author or tool.
	Name string `json:"name,omitempty"`
}
// ToolCall represents a function/tool call request emitted by the model.
type ToolCall struct {
	ID       string   `json:"id"`       // provider-assigned identifier; echoed back via Message.ToolCallID
	Type     string   `json:"type"`     // call kind; presumably "function" — confirm against providers
	Function Function `json:"function"` // the function being invoked
}
// Function represents a function definition and is used in two directions:
// inside a Tool when advertising callable functions to the model (Name,
// Description, Parameters as a JSON-Schema-style map), and inside a ToolCall
// when the model invokes one. In the OpenAI wire format a tool call carries
// its arguments as a JSON-encoded string in "arguments"; without that field
// the call's arguments were silently dropped on unmarshal.
type Function struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	Parameters  map[string]interface{} `json:"parameters,omitempty"`
	// Arguments holds the JSON-encoded argument object of a tool call
	// response. It is left empty in tool definitions.
	Arguments string `json:"arguments,omitempty"`
}
// Tool represents a tool that can be called by the model; it is advertised
// to the model via ChatCompletionRequest.Tools.
type Tool struct {
	Type     string   `json:"type"`     // tool kind; presumably "function" — confirm against providers
	Function Function `json:"function"` // the function definition being exposed
}
// ChatCompletionRequest represents a request to complete a chat conversation.
// Pointer fields distinguish "unset" (nil, omitted from the JSON body) from an
// explicit zero value, letting provider-side defaults apply when nil.
type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []Message       `json:"messages"`
	MaxTokens        *int            `json:"max_tokens,omitempty"`
	Temperature      *float64        `json:"temperature,omitempty"`
	TopP             *float64        `json:"top_p,omitempty"`
	N                *int            `json:"n,omitempty"` // number of choices to generate
	Stream           *bool           `json:"stream,omitempty"`
	Stop             []string        `json:"stop,omitempty"` // stop sequences
	PresencePenalty  *float64        `json:"presence_penalty,omitempty"`
	FrequencyPenalty *float64        `json:"frequency_penalty,omitempty"`
	LogitBias        map[string]int  `json:"logit_bias,omitempty"` // presumably token-ID (as string) -> bias; confirm per provider
	User             string          `json:"user,omitempty"`       // end-user identifier for abuse tracking
	Tools            []Tool          `json:"tools,omitempty"`
	ToolChoice       interface{}     `json:"tool_choice,omitempty"` // string mode or structured selector — shape varies per provider; confirm
	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
	Seed             *int64          `json:"seed,omitempty"`
	// ExtraParams carries provider-specific parameters; json:"-" keeps it out
	// of the generic serialization.
	ExtraParams map[string]interface{} `json:"-"` // For provider-specific parameters
}
// ResponseFormat specifies the format of the response.
type ResponseFormat struct {
	Type string `json:"type"` // "text" or "json_object"
}
// ChatCompletionResponse represents a response from a chat completion request.
type ChatCompletionResponse struct {
	ID                string                 `json:"id"`
	Object            string                 `json:"object"`
	Created           int64                  `json:"created"` // presumably a Unix timestamp in seconds — confirm per provider
	Model             string                 `json:"model"`
	SystemFingerprint string                 `json:"system_fingerprint,omitempty"`
	Choices           []ChatCompletionChoice `json:"choices"`
	Usage             Usage                  `json:"usage"`
	Provider          Provider               `json:"provider"` // which backend produced this response
}
// ChatCompletionChoice represents a single choice in a chat completion response.
type ChatCompletionChoice struct {
	Index        int       `json:"index"`
	Message      Message   `json:"message"`
	Logprobs     *Logprobs `json:"logprobs,omitempty"`
	FinishReason string    `json:"finish_reason"` // why generation stopped
	// Delta carries the incremental message fragment in streaming chunks
	// (see StreamResponse); non-streaming responses use Message instead.
	Delta *Message `json:"delta,omitempty"` // For streaming
	// ExtraData holds provider-specific data; json:"-" keeps it unserialized.
	ExtraData map[string]interface{} `json:"-"` // For provider-specific data
}
// Logprobs represents log probability information for a choice's content.
type Logprobs struct {
	Content []LogprobContent `json:"content,omitempty"` // one entry per generated token
}
// LogprobContent represents a single token with its log probability.
type LogprobContent struct {
	Token       string       `json:"token"`
	Logprob     float64      `json:"logprob"`
	Bytes       []int        `json:"bytes,omitempty"`        // raw byte values of the token, when provided
	TopLogprobs []TopLogprob `json:"top_logprobs,omitempty"` // most likely alternatives at this position
}
// TopLogprob represents one alternative token and its log probability.
type TopLogprob struct {
	Token   string  `json:"token"`
	Logprob float64 `json:"logprob"`
	Bytes   []int   `json:"bytes,omitempty"` // raw byte values of the token, when provided
}
// Usage represents token usage information for a request/response pair.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`     // tokens consumed by the input
	CompletionTokens int `json:"completion_tokens"` // tokens generated by the model
	TotalTokens      int `json:"total_tokens"`      // prompt + completion
}
// EmbeddingRequest represents a request to generate embeddings.
type EmbeddingRequest struct {
	// Input is the content to embed: a string, []string, or []int (token IDs).
	Input          interface{} `json:"input"` // string, []string, or []int
	Model          string      `json:"model"`
	EncodingFormat string      `json:"encoding_format,omitempty"`
	Dimensions     *int        `json:"dimensions,omitempty"` // nil = provider default dimensionality
	User           string      `json:"user,omitempty"`       // end-user identifier
	// ExtraParams carries provider-specific parameters; never serialized.
	ExtraParams map[string]interface{} `json:"-"` // For provider-specific parameters
}
// EmbeddingResponse represents a response from an embedding request.
type EmbeddingResponse struct {
	Object   string      `json:"object"`
	Data     []Embedding `json:"data"` // one entry per input item
	Usage    Usage       `json:"usage"`
	Model    string      `json:"model"`
	Provider Provider    `json:"provider"` // which backend produced this response
}
// Embedding represents a single embedding vector.
type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"` // the vector itself
	Index     int       `json:"index"`     // position of the corresponding input item
}
// ModelInfo represents information about an available model.
type ModelInfo struct {
	ID         string            `json:"id"`
	Object     string            `json:"object"`
	Created    int64             `json:"created"` // presumably a Unix timestamp in seconds — confirm per provider
	OwnedBy    string            `json:"owned_by"`
	Permission []ModelPermission `json:"permission"`
	Root       string            `json:"root"`
	Parent     string            `json:"parent"`
	Provider   Provider          `json:"provider"` // which backend reported this model
	// ExtraData holds provider-specific data; never serialized.
	ExtraData map[string]interface{} `json:"-"` // For provider-specific data
}
// ModelPermission represents permissions for a model.
type ModelPermission struct {
	ID                 string `json:"id"`
	Object             string `json:"object"`
	Created            int64  `json:"created"` // presumably a Unix timestamp in seconds — confirm per provider
	AllowCreateEngine  bool   `json:"allow_create_engine"`
	AllowSampling      bool   `json:"allow_sampling"`
	AllowLogprobs      bool   `json:"allow_logprobs"`
	AllowSearchIndices bool   `json:"allow_search_indices"`
	AllowView          bool   `json:"allow_view"`
	AllowFineTuning    bool   `json:"allow_fine_tuning"`
	Organization       string `json:"organization"`
	Group              string `json:"group"`
	IsBlocking         bool   `json:"is_blocking"`
}
// Error represents an error response payload from an LLM provider.
//
// NOTE(review): this type does not implement Go's builtin error interface,
// and cannot gain an Error() string method while it has a field named Error
// (the names would collide). It also stutters as llm.Error. Renaming either
// would break existing callers, so both are flagged rather than changed;
// callers must read the nested struct's fields directly.
type Error struct {
	Error struct {
		Message string `json:"message"`
		Type    string `json:"type"`
		Code    string `json:"code,omitempty"`
		Param   string `json:"param,omitempty"`
	} `json:"error"`
}
// Config represents configuration for an LLM provider.
type Config struct {
	Provider   Provider      `json:"provider"`
	APIKey     string        `json:"api_key"`
	BaseURL    string        `json:"base_url,omitempty"` // API endpoint root; see DefaultConfigs for per-provider defaults
	Timeout    time.Duration `json:"timeout,omitempty"`  // note: encoding/json serializes this as integer nanoseconds
	MaxRetries int           `json:"max_retries,omitempty"`
	// ExtraConfig carries provider-specific settings not covered above.
	ExtraConfig map[string]interface{} `json:"extra_config,omitempty"`
}
// StreamResponse represents a streaming response chunk. Choices in a chunk
// typically populate ChatCompletionChoice.Delta rather than Message; Usage is
// a pointer because only some chunks report it.
type StreamResponse struct {
	ID                string                 `json:"id"`
	Object            string                 `json:"object"`
	Created           int64                  `json:"created"` // presumably a Unix timestamp in seconds — confirm per provider
	Model             string                 `json:"model"`
	SystemFingerprint string                 `json:"system_fingerprint,omitempty"`
	Choices           []ChatCompletionChoice `json:"choices"`
	Usage             *Usage                 `json:"usage,omitempty"` // nil unless this chunk carries usage totals
	Provider          Provider               `json:"provider"`        // which backend produced this chunk
}
// DefaultConfigs provides default configurations for different providers.
// APIKey is intentionally left empty and must be supplied by the caller.
// Because the map holds Config values (not pointers), indexing returns a
// copy, so mutating a retrieved Config does not alter these defaults —
// though the map itself is package-level mutable state.
var DefaultConfigs = map[Provider]Config{
	ProviderOpenAI: {
		Provider:   ProviderOpenAI,
		BaseURL:    "https://api.openai.com/v1",
		Timeout:    30 * time.Second,
		MaxRetries: 3,
	},
	ProviderXAI: {
		Provider:   ProviderXAI,
		BaseURL:    "https://api.x.ai/v1",
		Timeout:    30 * time.Second,
		MaxRetries: 3,
	},
	ProviderClaude: {
		Provider:   ProviderClaude,
		BaseURL:    "https://api.anthropic.com/v1",
		Timeout:    30 * time.Second,
		MaxRetries: 3,
	},
	ProviderGemini: {
		Provider:   ProviderGemini,
		BaseURL:    "https://generativelanguage.googleapis.com/v1",
		Timeout:    30 * time.Second,
		MaxRetries: 3,
	},
	// Local server gets a longer timeout and fewer retries.
	ProviderLocal: {
		Provider:   ProviderLocal,
		BaseURL:    "http://localhost:11434", // presumably Ollama's default port — confirm
		Timeout:    60 * time.Second,
		MaxRetries: 1,
	},
	// Fake provider for tests; the fake:// URL is a placeholder.
	ProviderFake: {
		Provider:   ProviderFake,
		BaseURL:    "fake://test",
		Timeout:    1 * time.Second,
		MaxRetries: 0,
	},
}