make -one-shot command line work with both -unsafe and regular
The "-one" option had atrophied in a variety of ways, against both
unsafe and dockerized environments. I resurrected it, and simplified
the flag handling slightly by using just one flag.
diff --git a/cmd/sketch/main.go b/cmd/sketch/main.go
index ef13053..8ea4b65 100644
--- a/cmd/sketch/main.go
+++ b/cmd/sketch/main.go
@@ -44,7 +44,7 @@
maxIterations := flag.Uint64("max-iterations", 0, "maximum number of iterations the agent should perform per turn, 0 to disable limit")
maxWallTime := flag.Duration("max-wall-time", 0, "maximum time the agent should run per turn, 0 to disable limit")
maxDollars := flag.Float64("max-dollars", 5.0, "maximum dollars the agent should spend per turn, 0 to disable limit")
- one := flag.Bool("one", false, "run a single iteration and exit without termui")
+ oneShot := flag.String("one-shot", "", "run a single iteration with the given prompt and exit without termui")
verbose := flag.Bool("verbose", false, "enable verbose output")
version := flag.Bool("version", false, "print the version and exit")
workingDir := flag.String("C", "", "when set, change to this directory before running")
@@ -73,8 +73,6 @@
return nil
}
- firstMessage := flag.Args()
-
// Add a global "session_id" to all logs using this context.
// A "session" is a single full run of the agent.
ctx := skribe.ContextWithAttr(context.Background(), slog.String("session_id", *sessionID))
@@ -82,7 +80,7 @@
var slogHandler slog.Handler
var err error
var logFile *os.File
- if !*one && !*verbose {
+ if *oneShot == "" && !*verbose {
// Log to a file
logFile, err = os.CreateTemp("", "sketch-cli-log-*")
if err != nil {
@@ -133,10 +131,6 @@
}
}
- if *one && len(firstMessage) == 0 {
- return fmt.Errorf("-one flag requires a message to send to the agent")
- }
-
var pubKey, antURL, apiKey string
if *skabandAddr == "" {
apiKey = os.Getenv("ANTHROPIC_API_KEY")
@@ -203,6 +197,7 @@
OutsideHostname: getHostname(),
OutsideOS: runtime.GOOS,
OutsideWorkingDir: cwd,
+ OneShot: *oneShot,
}
if err := dockerimg.LaunchContainer(ctx, stdout, stderr, config); err != nil {
if *verbose {
@@ -288,10 +283,6 @@
ps1URL = fmt.Sprintf("http://%s", ln.Addr())
}
- if len(firstMessage) > 0 {
- agent.UserMessage(ctx, strings.Join(firstMessage, " "))
- }
-
if inDocker {
<-agent.Ready()
if ps1URL == "" {
@@ -299,6 +290,10 @@
}
}
+ if *oneShot != "" {
+ agent.UserMessage(ctx, *oneShot)
+ }
+
// Open the web UI URL in the system browser if requested
if *openBrowser {
dockerimg.OpenBrowser(ctx, ps1URL)
@@ -306,15 +301,6 @@
// Create the termui instance
s := termui.New(agent, ps1URL)
- defer func() {
- r := recover()
- if err := s.RestoreOldState(); err != nil {
- fmt.Fprintf(os.Stderr, "couldn't restore old terminal state: %s\n", err)
- }
- if r != nil {
- panic(r)
- }
- }()
// Start skaband connection loop if needed
if *skabandAddr != "" {
@@ -330,7 +316,7 @@
go skabandclient.DialAndServeLoop(ctx, *skabandAddr, *sessionID, pubKey, srv, connectFn)
}
- if *one {
+ if *oneShot != "" {
for {
m := agent.WaitForMessage(ctx)
if m.Content != "" {
@@ -343,6 +329,15 @@
}
}
+ defer func() {
+ r := recover()
+ if err := s.RestoreOldState(); err != nil {
+ fmt.Fprintf(os.Stderr, "couldn't restore old terminal state: %s\n", err)
+ }
+ if r != nil {
+ panic(r)
+ }
+ }()
if err := s.Run(ctx); err != nil {
return err
}
diff --git a/dockerimg/dockerimg.go b/dockerimg/dockerimg.go
index a787114..e40a77b 100644
--- a/dockerimg/dockerimg.go
+++ b/dockerimg/dockerimg.go
@@ -76,6 +76,9 @@
OutsideHostname string
OutsideOS string
OutsideWorkingDir string
+
+ // If not empty, handle this message and exit
+ OneShot string
}
// LaunchContainer creates a docker container for a project, installs sketch and opens a connection to it.
@@ -439,6 +442,9 @@
if config.SkabandAddr != "" {
cmdArgs = append(cmdArgs, "-skaband-addr="+config.SkabandAddr)
}
+ if config.OneShot != "" {
+ cmdArgs = append(cmdArgs, "-one-shot", config.OneShot)
+ }
if out, err := combinedOutput(ctx, "docker", cmdArgs...); err != nil {
return fmt.Errorf("docker create: %s, %w", out, err)
}