cmd/sketch: expand ~ in host path in mounts
diff --git a/cmd/sketch/main.go b/cmd/sketch/main.go
index 8067af7..1dd4802 100644
--- a/cmd/sketch/main.go
+++ b/cmd/sketch/main.go
@@ -129,6 +129,25 @@
}
}
+// expandTilde expands ~ in the given path to the user's home directory
+func expandTilde(path string) (string, error) {
+ if path == "~" {
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ return path, err
+ }
+ return homeDir, nil
+ }
+ if strings.HasPrefix(path, "~/") {
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ return path, err
+ }
+ return strings.Replace(path, "~", homeDir, 1), nil
+ }
+ return path, nil
+}
+
// CLIFlags holds all command-line arguments
// StringSliceFlag is a custom flag type that allows for repeated flag values.
// It collects all values into a slice.
@@ -283,6 +302,20 @@
flags.openBrowser = !flags.oneShot && os.Getenv("SSH_CONNECTION") == ""
}
+ // expand ~ in mounts
+ for i, mount := range flags.mounts {
+ host, container, ok := strings.Cut(mount, ":")
+ if !ok {
+ continue
+ }
+ expanded, err := expandTilde(host)
+ if err != nil {
+ slog.Warn("failed to expand tilde in mount path", "path", host, "error", err)
+ continue
+ }
+ flags.mounts[i] = expanded + ":" + container
+ }
+
return flags
}
diff --git a/cmd/sketch/main_test.go b/cmd/sketch/main_test.go
new file mode 100644
index 0000000..3ebd05d
--- /dev/null
+++ b/cmd/sketch/main_test.go
@@ -0,0 +1,37 @@
+package main
+
+import (
+ "os"
+ "testing"
+)
+
+func TestExpandTilde(t *testing.T) {
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ t.Fatalf("Failed to get home directory: %v", err)
+ }
+
+ tests := []struct {
+ name string
+ input string
+ expected string
+ }{
+ {"tilde only", "~", homeDir},
+ {"tilde with path", "~/Documents", homeDir + "/Documents"},
+ {"no tilde", "/absolute/path", "/absolute/path"},
+ {"tilde in middle", "/path/~/middle", "/path/~/middle"},
+ {"relative path", "relative/path", "relative/path"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := expandTilde(tt.input)
+ if err != nil {
+ t.Errorf("expandTilde(%q) returned error: %v", tt.input, err)
+ }
+ if result != tt.expected {
+ t.Errorf("expandTilde(%q) = %q, want %q", tt.input, result, tt.expected)
+ }
+ })
+ }
+}