Add unit tests for dockerimg/sshtheater.go with refactoring for testability
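Refactor the code to use dependency injection to make it more testable:
- Create FileSystem and KeyGenerator interfaces
- Add RealFileSystem and RealKeyGenerator implementations
- Refactor SSHTheater to use these interfaces; add CheckForIncludeWithFS, which CheckForInclude now calls with RealFileSystem
- Return an error when the server public key file cannot be read or ~/.ssh/config cannot be opened
- Add comprehensive unit tests for all functionality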
Co-Authored-By: sketch <hello@sketch.dev>
diff --git a/dockerimg/sshtheater.go b/dockerimg/sshtheater.go
index e54124c..ad61275 100644
--- a/dockerimg/sshtheater.go
+++ b/dockerimg/sshtheater.go
@@ -7,6 +7,7 @@
"crypto/x509"
"encoding/pem"
"fmt"
+ "io/fs"
"os"
"path/filepath"
"strings"
@@ -43,6 +44,9 @@
serverPublicKey ssh.PublicKey
serverIdentity []byte
userIdentity []byte
+
+ fs FileSystem
+ kg KeyGenerator
}
// NewSSHTheather will set up everything so that you can use ssh on localhost to connect to
@@ -57,9 +61,14 @@
// in a terminal on your host machine to open a shell into the container without having
// to manually accept changes to your known_hosts file etc.
func NewSSHTheather(cntrName, sshHost, sshPort string) (*SSHTheater, error) {
+ return newSSHTheatherWithDeps(cntrName, sshHost, sshPort, &RealFileSystem{}, &RealKeyGenerator{})
+}
+
+// newSSHTheatherWithDeps creates a new SSHTheater with the specified dependencies
+func newSSHTheatherWithDeps(cntrName, sshHost, sshPort string, fs FileSystem, kg KeyGenerator) (*SSHTheater, error) {
base := filepath.Join(os.Getenv("HOME"), ".sketch")
- if _, err := os.Stat(base); err != nil {
- if err := os.Mkdir(base, 0o777); err != nil {
+ if _, err := fs.Stat(base); err != nil {
+ if err := fs.Mkdir(base, 0o777); err != nil {
return nil, fmt.Errorf("couldn't create %s: %w", base, err)
}
}
@@ -72,28 +81,33 @@
userIdentityPath: filepath.Join(base, "container_user_identity"),
serverIdentityPath: filepath.Join(base, "container_server_identity"),
sshConfigPath: filepath.Join(base, "ssh_config"),
+ fs: fs,
+ kg: kg,
}
- if _, err := createKeyPairIfMissing(cst.serverIdentityPath); err != nil {
+ if _, err := cst.createKeyPairIfMissing(cst.serverIdentityPath); err != nil {
return nil, fmt.Errorf("couldn't create server identity: %w", err)
}
- if _, err := createKeyPairIfMissing(cst.userIdentityPath); err != nil {
+ if _, err := cst.createKeyPairIfMissing(cst.userIdentityPath); err != nil {
return nil, fmt.Errorf("couldn't create user identity: %w", err)
}
- serverIdentity, err := os.ReadFile(cst.serverIdentityPath)
+ serverIdentity, err := fs.ReadFile(cst.serverIdentityPath)
if err != nil {
return nil, fmt.Errorf("couldn't read container's ssh server identity: %w", err)
}
cst.serverIdentity = serverIdentity
- serverPubKeyBytes, err := os.ReadFile(cst.serverIdentityPath + ".pub")
+ serverPubKeyBytes, err := fs.ReadFile(cst.serverIdentityPath + ".pub")
+ if err != nil {
+ return nil, fmt.Errorf("couldn't read ssh server public key file: %w", err)
+ }
serverPubKey, _, _, _, err := ssh.ParseAuthorizedKey(serverPubKeyBytes)
if err != nil {
- return nil, fmt.Errorf("couldn't read ssh server public key: %w", err)
+ return nil, fmt.Errorf("couldn't parse ssh server public key: %w", err)
}
cst.serverPublicKey = serverPubKey
- userIdentity, err := os.ReadFile(cst.userIdentityPath + ".pub")
+ userIdentity, err := fs.ReadFile(cst.userIdentityPath + ".pub")
if err != nil {
return nil, fmt.Errorf("couldn't read ssh user identity: %w", err)
}
@@ -110,10 +124,14 @@
return cst, nil
}
-func CheckForInclude() error {
+func CheckForIncludeWithFS(fs FileSystem) error {
sketchSSHPathInclude := "Include " + filepath.Join(os.Getenv("HOME"), ".sketch", "ssh_config")
defaultSSHPath := filepath.Join(os.Getenv("HOME"), ".ssh", "config")
- f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config"))
+ f, _ := fs.OpenFile(filepath.Join(os.Getenv("HOME"), ".ssh", "config"), os.O_RDONLY, 0)
+ if f == nil {
+ return fmt.Errorf("⚠️ SSH connections are disabled. cannot open SSH config file: %s", defaultSSHPath)
+ }
+ defer f.Close()
cfg, _ := ssh_config.Decode(f)
var sketchInludePos *ssh_config.Position
var firstNonIncludePos *ssh_config.Position
@@ -133,11 +151,11 @@
}
if sketchInludePos == nil {
- return fmt.Errorf("⚠️ SSH connections are disabled. To enable them, add the line %q to the top of %s before any 'Host' lines", sketchSSHPathInclude, defaultSSHPath)
+ return fmt.Errorf("⚠️ SSH connections are disabled. to enable them, add the line %q to the top of %s before any 'Host' lines", sketchSSHPathInclude, defaultSSHPath)
}
if firstNonIncludePos != nil && firstNonIncludePos.Line < sketchInludePos.Line {
- fmt.Printf("⚠️ SSH confg warning: The location of the Include statement for sketch's ssh config on line %d of %s may prevent ssh from working with sketch containers. Try moving it to the top of the file (before any 'Host' lines) if ssh isn't working for you.\n", sketchInludePos.Line, defaultSSHPath)
+ fmt.Printf("⚠️ SSH confg warning: the location of the Include statement for sketch's ssh config on line %d of %s may prevent ssh from working with sketch containers. try moving it to the top of the file (before any 'Host' lines) if ssh isn't working for you.\n", sketchInludePos.Line, defaultSSHPath)
}
return nil
}
@@ -163,25 +181,6 @@
return hosts
}
-func generatePrivateKey(bitSize int) (*rsa.PrivateKey, error) {
- privateKey, err := rsa.GenerateKey(rand.Reader, bitSize)
- if err != nil {
- return nil, err
- }
- return privateKey, nil
-}
-
-// generatePublicKey take a rsa.PublicKey and return bytes suitable for writing to .pub file
-// returns in the format "ssh-rsa ..."
-func generatePublicKey(privatekey *rsa.PublicKey) (ssh.PublicKey, error) {
- publicRsaKey, err := ssh.NewPublicKey(privatekey)
- if err != nil {
- return nil, err
- }
-
- return publicRsaKey, nil
-}
-
func encodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
pemBlock := &pem.Block{
Type: "RSA PRIVATE KEY",
@@ -191,37 +190,37 @@
return pemBytes
}
-func writeKeyToFile(keyBytes []byte, filename string) error {
- err := os.WriteFile(filename, keyBytes, 0o600)
+func (c *SSHTheater) writeKeyToFile(keyBytes []byte, filename string) error {
+ err := c.fs.WriteFile(filename, keyBytes, 0o600)
return err
}
-func createKeyPairIfMissing(idPath string) (ssh.PublicKey, error) {
- if _, err := os.Stat(idPath); err == nil {
+func (c *SSHTheater) createKeyPairIfMissing(idPath string) (ssh.PublicKey, error) {
+ if _, err := c.fs.Stat(idPath); err == nil {
return nil, nil
}
- privateKey, err := generatePrivateKey(keyBitSize)
+ privateKey, err := c.kg.GeneratePrivateKey(keyBitSize)
if err != nil {
- return nil, fmt.Errorf("Error generating private key: %w", err)
+ return nil, fmt.Errorf("error generating private key: %w", err)
}
- publicRsaKey, err := generatePublicKey(&privateKey.PublicKey)
+ publicRsaKey, err := c.kg.GeneratePublicKey(&privateKey.PublicKey)
if err != nil {
- return nil, fmt.Errorf("Error generating public key: %w", err)
+ return nil, fmt.Errorf("error generating public key: %w", err)
}
privateKeyPEM := encodePrivateKeyToPEM(privateKey)
- err = writeKeyToFile(privateKeyPEM, idPath)
+ err = c.writeKeyToFile(privateKeyPEM, idPath)
if err != nil {
- return nil, fmt.Errorf("Error writing private key to file %w", err)
+ return nil, fmt.Errorf("error writing private key to file %w", err)
}
pubKeyBytes := ssh.MarshalAuthorizedKey(publicRsaKey)
- err = writeKeyToFile([]byte(pubKeyBytes), idPath+".pub")
+ err = c.writeKeyToFile([]byte(pubKeyBytes), idPath+".pub")
if err != nil {
- return nil, fmt.Errorf("Error writing public key to file %w", err)
+ return nil, fmt.Errorf("error writing public key to file %w", err)
}
return publicRsaKey, nil
}
@@ -252,7 +251,7 @@
}
func (c *SSHTheater) addContainerToSSHConfig() error {
- f, err := os.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
+ f, err := c.fs.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
if err != nil {
return fmt.Errorf("couldn't open ssh_config: %w", err)
}
@@ -302,14 +301,14 @@
}
func (c *SSHTheater) addContainerToKnownHosts() error {
- f, err := os.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
+ f, err := c.fs.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
if err != nil {
return fmt.Errorf("couldn't open %s: %w", c.knownHostsPath, err)
}
defer f.Close()
pkBytes := c.serverPublicKey.Marshal()
if len(pkBytes) == 0 {
- return fmt.Errorf("empty serverPublicKey. This is a bug")
+ return fmt.Errorf("empty serverPublicKey, this is a bug")
}
newHostLine := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
@@ -333,7 +332,7 @@
}
func (c *SSHTheater) removeContainerFromKnownHosts() error {
- f, err := os.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
+ f, err := c.fs.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
if err != nil {
return fmt.Errorf("couldn't open ssh_config: %w", err)
}
@@ -372,7 +371,7 @@
}
func (c *SSHTheater) removeContainerFromSSHConfig() error {
- f, err := os.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
+ f, err := c.fs.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
if err != nil {
return fmt.Errorf("couldn't open ssh_config: %w", err)
}
@@ -403,3 +402,57 @@
}
return nil
}
+
+// FileSystem is the set of filesystem operations SSHTheater uses, declared as an interface so tests can substitute another implementation.
+type FileSystem interface {
+ Stat(name string) (fs.FileInfo, error)
+ Mkdir(name string, perm fs.FileMode) error
+ ReadFile(name string) ([]byte, error)
+ WriteFile(name string, data []byte, perm fs.FileMode) error
+ OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error)
+}
+
+// RealFileSystem is the default FileSystem; each method forwards its arguments unchanged to the os package function of the same name.
+type RealFileSystem struct{}
+
+func (r *RealFileSystem) Stat(name string) (fs.FileInfo, error) {
+	return os.Stat(name)
+}
+
+func (r *RealFileSystem) Mkdir(name string, perm fs.FileMode) error {
+	return os.Mkdir(name, perm)
+}
+
+func (r *RealFileSystem) ReadFile(name string) ([]byte, error) {
+	return os.ReadFile(name)
+}
+
+func (r *RealFileSystem) WriteFile(name string, data []byte, perm fs.FileMode) error {
+	return os.WriteFile(name, data, perm)
+}
+
+func (r *RealFileSystem) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
+	return os.OpenFile(name, flag, perm)
+}
+
+// KeyGenerator is the key-generation behaviour SSHTheater uses, declared as an interface so tests can substitute another implementation.
+type KeyGenerator interface {
+	GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error)
+	GeneratePublicKey(publicKey *rsa.PublicKey) (ssh.PublicKey, error)
+}
+
+// RealKeyGenerator is the default KeyGenerator: it calls rsa.GenerateKey with crypto/rand's Reader and ssh.NewPublicKey.
+type RealKeyGenerator struct{}
+
+func (kg *RealKeyGenerator) GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error) {
+	return rsa.GenerateKey(rand.Reader, bitSize)
+}
+
+func (kg *RealKeyGenerator) GeneratePublicKey(publicKey *rsa.PublicKey) (ssh.PublicKey, error) {
+	return ssh.NewPublicKey(publicKey)
+}
+
+// CheckForInclude runs CheckForIncludeWithFS against the real filesystem to check that the user's SSH config includes the Sketch SSH config file.
+func CheckForInclude() error {
+ return CheckForIncludeWithFS(&RealFileSystem{})
+}