blob: b2e73cf3da84edda79ce4596c4a515b66a0c635e [file] [log] [blame]
// Package experiment provides support for experimental features.
2package experiment
3
4import (
5 "fmt"
6 "io"
7 "strings"
8 "sync"
9)
10
// Experiment describes a single experimental feature.
// Experiments are global.
type Experiment struct {
	// Name is the name of the experiment used in the -x flag.
	Name string
	// Description is a short description of what the experiment does.
	Description string
	// Enabled records whether the experiment is enabled.
	Enabled bool
}
18
19var (
20 mu sync.Mutex
21 experiments = []Experiment{
22 {
23 Name: "list",
24 Description: "List all available experiments and exit",
25 },
26 {
27 Name: "all",
28 Description: "Enable all experiments",
29 },
Josh Bleecher Snyder503b5e32025-05-05 13:30:55 -070030 {
31 Name: "not_done",
32 Description: "Let the model backtrack halfway through a done tool call",
33 },
Josh Bleecher Snyderb4782142025-05-05 19:16:54 +000034 }
35 byName = map[string]*Experiment{}
36)
37
38func Enabled(name string) bool {
39 mu.Lock()
40 defer mu.Unlock()
41 return byName[name].Enabled
42}
43
44func init() {
45 for _, e := range experiments {
46 byName[e.Name] = &e
47 }
48}
49
50func (e Experiment) String() string {
51 return fmt.Sprintf("\t%-15s %s\n", e.Name, e.Description)
52}
53
54// Fprint writes a list of all available experiments to w.
55func Fprint(w io.Writer) {
56 mu.Lock()
57 defer mu.Unlock()
58
59 fmt.Fprintln(w, "Available experiments:")
60 for _, e := range experiments {
61 fmt.Fprintln(w, e)
62 }
63}
64
// Flag is a custom flag type that allows for comma-separated
// values and can be used multiple times.
type Flag struct {
	// Value holds every value passed to Set so far, joined by commas.
	Value string
}

// String returns the string representation of the flag value.
func (f *Flag) String() string {
	return f.Value
}

// Set adds a value to the flag.
//
// The first value is stored as is; later values are appended after a comma,
// so Value never starts with a separator. value may itself be a
// comma-separated list. Set always returns nil.
func (f *Flag) Set(value string) error {
	if f.Value == "" {
		f.Value = value
		return nil
	}
	f.Value += "," + value // quadratic, doesn't matter, tiny N
	return nil
}

// Get returns the flag values as the accumulated comma-separated string.
func (f *Flag) Get() any {
	return f.Value
}
86
87// Process handles all flag values, enabling the appropriate experiments.
88func (f *Flag) Process() error {
89 mu.Lock()
90 defer mu.Unlock()
91
92 for name := range strings.SplitSeq(f.Value, ",") {
93 name = strings.TrimSpace(name)
94 if name == "" {
95 continue
96 }
97 e, ok := byName[name]
98 if !ok {
99 return fmt.Errorf("unknown experiment: %q", name)
100 }
101 e.Enabled = true
102 }
103 if byName["all"].Enabled {
104 for _, e := range experiments {
105 if e.Name == "list" {
106 continue
107 }
108 e.Enabled = true
109 }
110 }
111 return nil
112}