gemini: basic gemini API wrapper
diff --git a/llm/gem/gemini/gemini.go b/llm/gem/gemini/gemini.go
new file mode 100644
index 0000000..a6b83e4
--- /dev/null
+++ b/llm/gem/gemini/gemini.go
@@ -0,0 +1,179 @@
+package gemini
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+)
+
+// Request is the JSON body sent by GenerateContent; see https://ai.google.dev/api/generate-content#request-body
+type Request struct {
+ Contents []Content `json:"contents"` // always serialized (no omitempty)
+ Tools []Tool `json:"tools,omitempty"`
+ SystemInstruction *Content `json:"systemInstruction,omitempty"` // left out of the JSON when nil
+ GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"` // left out of the JSON when nil
+ CachedContent string `json:"cachedContent,omitempty"` // format: "cachedContents/{name}"
+ // ToolConfig has been left out because it does not appear to be useful.
+}
+
+// Response is the JSON body decoded by GenerateContent; see https://ai.google.dev/api/generate-content#response-body
+type Response struct {
+ Candidates []Candidate `json:"candidates"` // the only key decoded; any other keys in the reply are ignored
+}
+
+type Candidate struct { // Candidate is one entry of Response.Candidates.
+ Content Content `json:"content"` // the only key of a candidate that is decoded
+}
+
+type Content struct { // Content is a list of parts; used by Request.Contents, Request.SystemInstruction and Candidate.Content.
+ Parts []Part `json:"parts"` // always serialized (no omitempty)
+}
+
+// Part is a single piece of a Content.
+// It is a union data structure: only one of the fields should be set (the type itself does not enforce this).
+type Part struct {
+ Text string `json:"text,omitempty"` // omitempty also drops an empty string
+ FunctionCall *FunctionCall `json:"functionCall,omitempty"`
+ FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
+ ExecutableCode *ExecutableCode `json:"executableCode,omitempty"`
+ CodeExecutionResult *CodeExecutionResult `json:"codeExecutionResult,omitempty"`
+ // TODO inlineData
+ // TODO fileData
+}
+
+type FunctionCall struct { // FunctionCall is the Part.FunctionCall variant.
+ Name string `json:"name"` // presumably matches a FunctionDeclaration.Name — confirm against the API docs
+ Args map[string]any `json:"args"` // untyped: values take whatever Go types encoding/json picks for any
+}
+
+type FunctionResponse struct { // FunctionResponse is the Part.FunctionResponse variant.
+ Name string `json:"name"` // presumably the Name of the FunctionCall being answered — confirm against the API docs
+ Response map[string]any `json:"response"` // untyped JSON object; no omitempty, so a nil map is sent as null
+}
+
+type ExecutableCode struct { // ExecutableCode is the Part.ExecutableCode variant.
+ Language Language `json:"language"` // always serialized (no omitempty)
+ Code string `json:"code"` // presumably source text written in Language — confirm against the API docs
+}
+
+type Language string // Language is ExecutableCode's enum; the API sends the enum name, which an int type cannot decode. The zero value "" is not a valid name, so set it explicitly in requests.
+
+const (
+	LanguageUnspecified Language = "LANGUAGE_UNSPECIFIED"
+	LanguagePython      Language = "PYTHON" // python >= 3.10 with numpy and simpy
+)
+
+type CodeExecutionResult struct { // CodeExecutionResult is the Part.CodeExecutionResult variant.
+ Outcome Outcome `json:"outcome"` // one of the Outcome* constants
+ Output string `json:"output"` // NOTE(review): presumably the text the executed code produced — confirm whether that is stdout, stderr or both
+}
+
+type Outcome string // Outcome is CodeExecutionResult's enum; the API sends the enum name, which an int type cannot decode.
+
+const (
+	OutcomeUnspecified      Outcome = "OUTCOME_UNSPECIFIED"
+	OutcomeOK               Outcome = "OUTCOME_OK"
+	OutcomeFailed           Outcome = "OUTCOME_FAILED"
+	OutcomeDeadlineExceeded Outcome = "OUTCOME_DEADLINE_EXCEEDED"
+)
+
+// GenerationConfig is the optional Request.GenerationConfig; see https://ai.google.dev/api/generate-content#v1beta.GenerationConfig
+type GenerationConfig struct {
+ ResponseMimeType string `json:"responseMimeType,omitempty"` // text/plain, application/json, or text/x.enum
+ ResponseSchema *Schema `json:"responseSchema,omitempty"` // for JSON (presumably: the output schema when ResponseMimeType is application/json)
+}
+
+// Tool is one entry of Request.Tools, offering functions and/or code execution to the model; see https://ai.google.dev/api/caching#Tool
+type Tool struct {
+	FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations,omitempty"` // omitted when empty, so a code-execution-only tool does not send null
+	CodeExecution        *struct{}             `json:"codeExecution,omitempty"`        // if present, enables the model to execute code
+	// TODO googleSearchRetrieval https://ai.google.dev/api/caching#GoogleSearchRetrieval
+}
+
+// FunctionDeclaration is one entry of Tool.FunctionDeclarations; see https://ai.google.dev/api/caching#FunctionDeclaration
+type FunctionDeclaration struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Parameters Schema `json:"parameters"` // a value, not a pointer: always serialized, even when left at its zero value
+}
+
+// Schema describes a value's shape for FunctionDeclaration.Parameters and GenerationConfig.ResponseSchema; see https://ai.google.dev/api/caching#Schema
+type Schema struct {
+	Type        DataType          `json:"type"`
+	Format      string            `json:"format,omitempty"` // for NUMBER type: float, double; for INTEGER type: int32, int64; for STRING type: enum
+	Description string            `json:"description,omitempty"`
+	Nullable    *bool             `json:"nullable,omitempty"`
+	Enum        []string          `json:"enum,omitempty"`
+	MaxItems    string            `json:"maxItems,omitempty"`   // for ARRAY
+	MinItems    string            `json:"minItems,omitempty"`   // for ARRAY
+	Properties  map[string]Schema `json:"properties,omitempty"` // for OBJECT
+	Required    []string          `json:"required,omitempty"`   // for OBJECT
+	Items       *Schema           `json:"items,omitempty"`      // for ARRAY
+}
+
+type DataType int // DataType is the Schema.Type enum; as a plain int it is marshaled as a JSON number. NOTE(review): confirm the API accepts numeric enum values.
+
+const (
+ DataTypeUNSPECIFIED = DataType(0) // Not specified, should not be used.
+ DataTypeSTRING = DataType(1)
+ DataTypeNUMBER = DataType(2)
+ DataTypeINTEGER = DataType(3)
+ DataTypeBOOLEAN = DataType(4)
+ DataTypeARRAY = DataType(5) // Schema.Items, MinItems and MaxItems are marked "for ARRAY"
+ DataTypeOBJECT = DataType(6) // Schema.Properties and Required are marked "for OBJECT"
+)
+
+const defaultEndpoint = "https://generativelanguage.googleapis.com/v1beta" // returned by Model.endpoint when Model.Endpoint is empty
+
+type Model struct { // Model holds what GenerateContent needs to call one model.
+ Model string // e.g. "models/gemini-1.5-flash"
+ APIKey string // sent with every GenerateContent request
+ HTTPC *http.Client // if nil, http.DefaultClient is used
+ Endpoint string // if empty, defaultEndpoint is used
+}
+
+func (m Model) GenerateContent(ctx context.Context, req *Request) (*Response, error) { // POSTs req as JSON to {endpoint}/{Model}:generateContent and decodes the reply; any non-200 status is returned as an error.
+	reqBytes, err := json.Marshal(req)
+	if err != nil {
+		return nil, fmt.Errorf("marshaling request: %w", err)
+	}
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/%s:generateContent", m.endpoint(), m.Model), bytes.NewReader(reqBytes))
+	if err != nil {
+		return nil, fmt.Errorf("creating HTTP request: %w", err)
+	}
+	httpReq.Header = http.Header{"Content-Type": {"application/json"}, "X-Goog-Api-Key": {m.APIKey}} // key goes in a header, not ?key=: net/http embeds the URL in its errors, which would leak it
+	httpResp, err := m.httpc().Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("GenerateContent: do: %w", err)
+	}
+	defer httpResp.Body.Close()
+	body, err := io.ReadAll(httpResp.Body) // read fully up front: the raw body is quoted in both error paths below
+	if err != nil {
+		return nil, fmt.Errorf("GenerateContent: reading response body: %w", err)
+	}
+	if httpResp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("GenerateContent: HTTP status: %d, %s", httpResp.StatusCode, string(body))
+	}
+	var res Response
+	if err := json.Unmarshal(body, &res); err != nil {
+		return nil, fmt.Errorf("GenerateContent: unmarshaling response: %w, %s", err, string(body))
+	}
+	return &res, nil
+}
+
+func (m Model) endpoint() string { // endpoint picks the base URL that requests are sent to.
+	if m.Endpoint == "" {
+		return defaultEndpoint // nothing configured: fall back to the package default
+	}
+	return m.Endpoint
+}
+
+func (m Model) httpc() *http.Client { // httpc picks the HTTP client that requests are sent with.
+	if m.HTTPC == nil {
+		return http.DefaultClient // nothing configured: fall back to the shared default client
+	}
+	return m.HTTPC
+}
diff --git a/llm/gem/gemini/gemini_test.go b/llm/gem/gemini/gemini_test.go
new file mode 100644
index 0000000..18ced45
--- /dev/null
+++ b/llm/gem/gemini/gemini_test.go
@@ -0,0 +1,33 @@
+package gemini
+
+import (
+ "context"
+ "os"
+ "testing"
+)
+
+// TestGenerateContent sends one prompt through Model.GenerateContent and logs the reply.
+// It is skipped in -short mode and when GEMINI_API_KEY is unset.
+func TestGenerateContent(t *testing.T) {
+	// TODO replace with local replay endpoint
+	if testing.Short() {
+		t.Skip("skipping test in short mode")
+	}
+	apiKey := os.Getenv("GEMINI_API_KEY")
+	if apiKey == "" {
+		t.Skip("skipping test without API key")
+	}
+
+	m := Model{Model: "models/gemini-1.5-flash", APIKey: apiKey}
+	req := &Request{
+		Contents: []Content{{
+			Parts: []Part{{Text: "What is the capital of France?"}},
+		}},
+	}
+	res, err := m.GenerateContent(context.Background(), req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Nothing is asserted about the reply; it is only logged for inspection.
+	t.Logf("res: %+v", res)
+}