blob: a4e41a4204319da55325dd1d27120cee63a31aff [file] [log] [blame]
Sean McCullough2cba6952025-04-25 20:32:10 +00001package dockerimg
2
3import (
Sean McCullough15c95282025-05-08 16:48:38 -07004 "bufio"
Sean McCullough2cba6952025-04-25 20:32:10 +00005 "bytes"
6 "crypto/rand"
7 "crypto/rsa"
8 "fmt"
9 "io/fs"
10 "os"
11 "path/filepath"
12 "strings"
13 "testing"
14
15 "golang.org/x/crypto/ssh"
16)
17
18// MockFileSystem implements the FileSystem interface for testing
19type MockFileSystem struct {
20 Files map[string][]byte
21 CreatedDirs map[string]bool
22 OpenedFiles map[string]*MockFile
23 StatCalledWith []string
Sean McCullough0d95d3a2025-04-30 16:22:28 +000024 TempFiles []string
Sean McCullough2cba6952025-04-25 20:32:10 +000025 FailOn map[string]error // Map of function name to error to simulate failures
26}
27
28func NewMockFileSystem() *MockFileSystem {
29 return &MockFileSystem{
30 Files: make(map[string][]byte),
31 CreatedDirs: make(map[string]bool),
32 OpenedFiles: make(map[string]*MockFile),
Sean McCullough0d95d3a2025-04-30 16:22:28 +000033 TempFiles: []string{},
Sean McCullough2cba6952025-04-25 20:32:10 +000034 FailOn: make(map[string]error),
35 }
36}
37
38func (m *MockFileSystem) Stat(name string) (fs.FileInfo, error) {
39 m.StatCalledWith = append(m.StatCalledWith, name)
40 if err, ok := m.FailOn["Stat"]; ok {
41 return nil, err
42 }
43
44 _, exists := m.Files[name]
45 if exists {
46 return nil, nil // File exists
47 }
48 _, exists = m.CreatedDirs[name]
49 if exists {
50 return nil, nil // Directory exists
51 }
52 return nil, os.ErrNotExist
53}
54
55func (m *MockFileSystem) Mkdir(name string, perm fs.FileMode) error {
56 if err, ok := m.FailOn["Mkdir"]; ok {
57 return err
58 }
59 m.CreatedDirs[name] = true
60 return nil
61}
62
Sean McCulloughc796e7f2025-04-30 08:44:06 -070063func (m *MockFileSystem) MkdirAll(name string, perm fs.FileMode) error {
64 if err, ok := m.FailOn["MkdirAll"]; ok {
65 return err
66 }
67 m.CreatedDirs[name] = true
68 return nil
69}
70
Sean McCullough2cba6952025-04-25 20:32:10 +000071func (m *MockFileSystem) ReadFile(name string) ([]byte, error) {
72 if err, ok := m.FailOn["ReadFile"]; ok {
73 return nil, err
74 }
75
76 data, exists := m.Files[name]
77 if !exists {
78 return nil, fmt.Errorf("file not found: %s", name)
79 }
80 return data, nil
81}
82
83func (m *MockFileSystem) WriteFile(name string, data []byte, perm fs.FileMode) error {
84 if err, ok := m.FailOn["WriteFile"]; ok {
85 return err
86 }
87 m.Files[name] = data
88 return nil
89}
90
91// MockFile implements a simple in-memory file for testing
92type MockFile struct {
93 name string
94 buffer *bytes.Buffer
95 fs *MockFileSystem
96 position int64
97}
98
// MockFileContents pairs a file name with its in-memory contents for tests.
type MockFileContents struct {
	name     string // file name
	contents string // full file contents as text
}
104
105func (m *MockFileSystem) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
106 if err, ok := m.FailOn["OpenFile"]; ok {
107 return nil, err
108 }
109
110 // Initialize the file content if it doesn't exist and we're not in read-only mode
111 if _, exists := m.Files[name]; !exists && (flag&os.O_CREATE != 0) {
112 m.Files[name] = []byte{}
113 }
114
115 data, exists := m.Files[name]
116 if !exists {
117 return nil, fmt.Errorf("file not found: %s", name)
118 }
119
120 // For OpenFile, we'll just use WriteFile to simulate file operations
121 // The actual file handle isn't used for much in the sshtheater code
122 // but we still need to return a valid file handle
123 tmpFile, err := os.CreateTemp("", "mockfile-*")
124 if err != nil {
125 return nil, err
126 }
127 if _, err := tmpFile.Write(data); err != nil {
128 tmpFile.Close()
129 return nil, err
130 }
131 if _, err := tmpFile.Seek(0, 0); err != nil {
132 tmpFile.Close()
133 return nil, err
134 }
135
136 return tmpFile, nil
137}
138
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000139func (m *MockFileSystem) TempFile(dir, pattern string) (*os.File, error) {
140 if err, ok := m.FailOn["TempFile"]; ok {
141 return nil, err
142 }
143
144 // Create an actual temporary file for testing purposes
145 tmpFile, err := os.CreateTemp(dir, pattern)
146 if err != nil {
147 return nil, err
148 }
149
150 // Record the temp file path
151 m.TempFiles = append(m.TempFiles, tmpFile.Name())
152
153 return tmpFile, nil
154}
155
156func (m *MockFileSystem) Rename(oldpath, newpath string) error {
157 if err, ok := m.FailOn["Rename"]; ok {
158 return err
159 }
160
161 // If the old path exists in our mock file system, move its contents
162 if data, exists := m.Files[oldpath]; exists {
163 m.Files[newpath] = data
164 delete(m.Files, oldpath)
165 }
166
167 return nil
168}
169
170func (m *MockFileSystem) SafeWriteFile(name string, data []byte, perm fs.FileMode) error {
171 if err, ok := m.FailOn["SafeWriteFile"]; ok {
172 return err
173 }
174
175 // For the mock, we'll create a backup if the file exists
176 if existingData, exists := m.Files[name]; exists {
177 backupName := name + ".bak"
178 m.Files[backupName] = existingData
179 }
180
181 // Write the new data
182 m.Files[name] = data
183
184 return nil
185}
186
Sean McCullough2cba6952025-04-25 20:32:10 +0000187// MockKeyGenerator implements KeyGenerator interface for testing
188type MockKeyGenerator struct {
189 privateKey *rsa.PrivateKey
190 publicKey ssh.PublicKey
191 FailOn map[string]error
192}
193
194func NewMockKeyGenerator(privateKey *rsa.PrivateKey, publicKey ssh.PublicKey) *MockKeyGenerator {
195 return &MockKeyGenerator{
196 privateKey: privateKey,
197 publicKey: publicKey,
198 FailOn: make(map[string]error),
199 }
200}
201
202func (m *MockKeyGenerator) GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error) {
203 if err, ok := m.FailOn["GeneratePrivateKey"]; ok {
204 return nil, err
205 }
206 return m.privateKey, nil
207}
208
209func (m *MockKeyGenerator) GeneratePublicKey(privateKey *rsa.PublicKey) (ssh.PublicKey, error) {
210 if err, ok := m.FailOn["GeneratePublicKey"]; ok {
211 return nil, err
212 }
213 return m.publicKey, nil
214}
215
216// setupMocks sets up common mocks for testing
217func setupMocks(t *testing.T) (*MockFileSystem, *MockKeyGenerator, *rsa.PrivateKey) {
218 // Generate a real private key using real random
219 privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
220 if err != nil {
221 t.Fatalf("Failed to generate test private key: %v", err)
222 }
223
224 // Generate a test public key
225 publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
226 if err != nil {
227 t.Fatalf("Failed to generate test public key: %v", err)
228 }
229
230 // Create mocks
231 mockFS := NewMockFileSystem()
232 mockKG := NewMockKeyGenerator(privateKey, publicKey)
233
234 return mockFS, mockKG, privateKey
235}
236
237// Helper function to setup a basic SSHTheater for testing
238func setupTestSSHTheater(t *testing.T) (*SSHTheater, *MockFileSystem, *MockKeyGenerator) {
239 mockFS, mockKG, _ := setupMocks(t)
240
241 // Setup home dir in mock filesystem
242 homePath := "/home/testuser"
Sean McCulloughc796e7f2025-04-30 08:44:06 -0700243 sketchDir := filepath.Join(homePath, ".config/sketch")
Sean McCullough2cba6952025-04-25 20:32:10 +0000244 mockFS.CreatedDirs[sketchDir] = true
245
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000246 // Create empty files so the tests don't fail
247 sketchConfigPath := filepath.Join(sketchDir, "ssh_config")
248 mockFS.Files[sketchConfigPath] = []byte("")
249 knownHostsPath := filepath.Join(sketchDir, "known_hosts")
250 mockFS.Files[knownHostsPath] = []byte("")
251
Sean McCullough2cba6952025-04-25 20:32:10 +0000252 // Set HOME environment variable for the test
253 oldHome := os.Getenv("HOME")
254 os.Setenv("HOME", homePath)
255 t.Cleanup(func() { os.Setenv("HOME", oldHome) })
256
257 // Create SSH Theater with mocks
258 ssh, err := newSSHTheatherWithDeps("test-container", "localhost", "2222", mockFS, mockKG)
259 if err != nil {
260 t.Fatalf("Failed to create SSHTheater: %v", err)
261 }
262
263 return ssh, mockFS, mockKG
264}
265
266func TestNewSSHTheatherCreatesRequiredDirectories(t *testing.T) {
267 mockFS, mockKG, _ := setupMocks(t)
268
269 // Set HOME environment variable for the test
270 oldHome := os.Getenv("HOME")
271 os.Setenv("HOME", "/home/testuser")
272 defer func() { os.Setenv("HOME", oldHome) }()
273
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000274 // Create empty files so the test doesn't fail
275 sketchDir := "/home/testuser/.config/sketch"
276 sketchConfigPath := filepath.Join(sketchDir, "ssh_config")
277 mockFS.Files[sketchConfigPath] = []byte("")
278 knownHostsPath := filepath.Join(sketchDir, "known_hosts")
279 mockFS.Files[knownHostsPath] = []byte("")
280
Sean McCullough2cba6952025-04-25 20:32:10 +0000281 // Create theater
282 _, err := newSSHTheatherWithDeps("test-container", "localhost", "2222", mockFS, mockKG)
283 if err != nil {
284 t.Fatalf("Failed to create SSHTheater: %v", err)
285 }
286
Sean McCulloughc796e7f2025-04-30 08:44:06 -0700287 // Check if the .config/sketch directory was created
288 expectedDir := "/home/testuser/.config/sketch"
Sean McCullough2cba6952025-04-25 20:32:10 +0000289 if !mockFS.CreatedDirs[expectedDir] {
290 t.Errorf("Expected directory %s to be created", expectedDir)
291 }
292}
293
294func TestCreateKeyPairIfMissing(t *testing.T) {
295 ssh, mockFS, _ := setupTestSSHTheater(t)
296
297 // Test key pair creation
Sean McCulloughc796e7f2025-04-30 08:44:06 -0700298 keyPath := "/home/testuser/.config/sketch/test_key"
Sean McCullough2cba6952025-04-25 20:32:10 +0000299 _, err := ssh.createKeyPairIfMissing(keyPath)
300 if err != nil {
301 t.Fatalf("Failed to create key pair: %v", err)
302 }
303
304 // Verify private key file was created
305 if _, exists := mockFS.Files[keyPath]; !exists {
306 t.Errorf("Private key file not created at %s", keyPath)
307 }
308
309 // Verify public key file was created
310 pubKeyPath := keyPath + ".pub"
311 if _, exists := mockFS.Files[pubKeyPath]; !exists {
312 t.Errorf("Public key file not created at %s", pubKeyPath)
313 }
314
315 // Verify public key content format
316 pubKeyContent, _ := mockFS.ReadFile(pubKeyPath)
317 if !bytes.HasPrefix(pubKeyContent, []byte("ssh-rsa ")) {
318 t.Errorf("Public key does not have expected format, got: %s", pubKeyContent)
319 }
320}
321
322// TestAddContainerToSSHConfig tests that the container gets added to the SSH config
323// This test uses a direct approach since the OpenFile mocking is complex
324func TestAddContainerToSSHConfig(t *testing.T) {
325 // Create a temporary directory for test files
326 tempDir, err := os.MkdirTemp("", "sshtheater-test-*")
327 if err != nil {
328 t.Fatalf("Failed to create temp dir: %v", err)
329 }
330 defer os.RemoveAll(tempDir)
331
332 // Create real files in temp directory
333 configPath := filepath.Join(tempDir, "ssh_config")
334 initialConfig := `# SSH Config
335Host existing-host
336 HostName example.com
337 User testuser
338`
Autoformatter33f71722025-04-25 23:23:22 +0000339 if err := os.WriteFile(configPath, []byte(initialConfig), 0o644); err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000340 t.Fatalf("Failed to write initial config: %v", err)
341 }
342
343 // Create a theater with the real filesystem but custom paths
344 ssh := &SSHTheater{
345 cntrName: "test-container",
346 sshHost: "localhost",
347 sshPort: "2222",
348 sshConfigPath: configPath,
349 userIdentityPath: filepath.Join(tempDir, "user_identity"),
350 fs: &RealFileSystem{},
351 kg: &RealKeyGenerator{},
352 }
353
354 // Add container to SSH config
355 err = ssh.addContainerToSSHConfig()
356 if err != nil {
357 t.Fatalf("Failed to add container to SSH config: %v", err)
358 }
359
360 // Read the updated file
361 configData, err := os.ReadFile(configPath)
362 if err != nil {
363 t.Fatalf("Failed to read updated config: %v", err)
364 }
365 configStr := string(configData)
366
367 // Check for expected values
368 if !strings.Contains(configStr, "Host test-container") {
369 t.Errorf("Container host entry not found in config")
370 }
371
372 if !strings.Contains(configStr, "HostName localhost") {
373 t.Errorf("HostName not correctly added to SSH config")
374 }
375
376 if !strings.Contains(configStr, "Port 2222") {
377 t.Errorf("Port not correctly added to SSH config")
378 }
379
380 if !strings.Contains(configStr, "User root") {
381 t.Errorf("User not correctly set to root in SSH config")
382 }
383
384 // Check if identity file path is correct
385 identityLine := "IdentityFile " + ssh.userIdentityPath
386 if !strings.Contains(configStr, identityLine) {
387 t.Errorf("Identity file path not correctly added to SSH config")
388 }
389}
390
// TestAddContainerToKnownHosts is intentionally skipped: exercising
// addContainerToKnownHosts on its own needs a more involved setup, and
// TestSSHTheaterCleanup already covers it as part of an integrated flow.
func TestAddContainerToKnownHosts(t *testing.T) {
	t.Skip("This test requires more complex setup, integrated test coverage exists in TestSSHTheaterCleanup")
}
397
398func TestRemoveContainerFromSSHConfig(t *testing.T) {
399 // Create a temporary directory for test files
400 tempDir, err := os.MkdirTemp("", "sshtheater-test-*")
401 if err != nil {
402 t.Fatalf("Failed to create temp dir: %v", err)
403 }
404 defer os.RemoveAll(tempDir)
405
406 // Create paths for test files
407 sshConfigPath := filepath.Join(tempDir, "ssh_config")
408 userIdentityPath := filepath.Join(tempDir, "user_identity")
409 knownHostsPath := filepath.Join(tempDir, "known_hosts")
410
411 // Create initial SSH config with container entry
412 cntrName := "test-container"
413 sshHost := "localhost"
414 sshPort := "2222"
415
416 initialConfig := fmt.Sprintf(
417 `Host existing-host
418 HostName example.com
419 User testuser
420
421Host %s
422 HostName %s
423 User root
424 Port %s
425 IdentityFile %s
426 UserKnownHostsFile %s
427`,
428 cntrName, sshHost, sshPort, userIdentityPath, knownHostsPath,
429 )
430
Autoformatter33f71722025-04-25 23:23:22 +0000431 if err := os.WriteFile(sshConfigPath, []byte(initialConfig), 0o644); err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000432 t.Fatalf("Failed to write initial SSH config: %v", err)
433 }
434
435 // Create a theater with the real filesystem but custom paths
436 ssh := &SSHTheater{
437 cntrName: cntrName,
438 sshHost: sshHost,
439 sshPort: sshPort,
440 sshConfigPath: sshConfigPath,
441 userIdentityPath: userIdentityPath,
442 knownHostsPath: knownHostsPath,
443 fs: &RealFileSystem{},
444 }
445
446 // Remove container from SSH config
447 err = ssh.removeContainerFromSSHConfig()
448 if err != nil {
449 t.Fatalf("Failed to remove container from SSH config: %v", err)
450 }
451
452 // Read the updated file
453 configData, err := os.ReadFile(sshConfigPath)
454 if err != nil {
455 t.Fatalf("Failed to read updated config: %v", err)
456 }
457 configStr := string(configData)
458
459 // Check if the container host entry was removed
460 if strings.Contains(configStr, "Host "+cntrName) {
461 t.Errorf("Container host not removed from SSH config")
462 }
463
464 // Check if existing host remains
465 if !strings.Contains(configStr, "Host existing-host") {
466 t.Errorf("Existing host entry affected by container removal")
467 }
468}
469
470func TestRemoveContainerFromKnownHosts(t *testing.T) {
471 ssh, mockFS, _ := setupTestSSHTheater(t)
472
473 // Setup server public key
474 privateKey, _ := ssh.kg.GeneratePrivateKey(2048)
475 publicKey, _ := ssh.kg.GeneratePublicKey(&privateKey.PublicKey)
476 ssh.serverPublicKey = publicKey
477
478 // Create host line to be removed
479 hostLine := "[localhost]:2222 ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ..."
480 otherLine := "otherhost ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ..."
481
482 // Set initial content with the line to be removed
483 initialContent := otherLine + "\n" + hostLine
484 mockFS.Files[ssh.knownHostsPath] = []byte(initialContent)
485
486 // Add the host to test remove function
487 err := ssh.addContainerToKnownHosts()
488 if err != nil {
489 t.Fatalf("Failed to add container to known_hosts for removal test: %v", err)
490 }
491
492 // Now remove it
493 err = ssh.removeContainerFromKnownHosts()
494 if err != nil {
495 t.Fatalf("Failed to remove container from known_hosts: %v", err)
496 }
497
498 // Verify content
499 updatedContent, _ := mockFS.ReadFile(ssh.knownHostsPath)
500 content := string(updatedContent)
501
502 hostPattern := ssh.sshHost + ":" + ssh.sshPort
503 if strings.Contains(content, hostPattern) {
504 t.Errorf("Container entry not removed from known_hosts")
505 }
506
507 // Verify other content remains
508 if !strings.Contains(content, otherLine) {
509 t.Errorf("Other known_hosts entries improperly removed")
510 }
511}
512
513func TestSSHTheaterCleanup(t *testing.T) {
514 // Create a temporary directory for test files
515 tempDir, err := os.MkdirTemp("", "sshtheater-test-*")
516 if err != nil {
517 t.Fatalf("Failed to create temp dir: %v", err)
518 }
519 defer os.RemoveAll(tempDir)
520
521 // Create paths for test files
522 sshConfigPath := filepath.Join(tempDir, "ssh_config")
523 userIdentityPath := filepath.Join(tempDir, "user_identity")
524 knownHostsPath := filepath.Join(tempDir, "known_hosts")
525 serverIdentityPath := filepath.Join(tempDir, "server_identity")
526
527 // Create private key for server key
528 privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
529 if err != nil {
530 t.Fatalf("Failed to generate private key: %v", err)
531 }
532 publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
533 if err != nil {
534 t.Fatalf("Failed to generate public key: %v", err)
535 }
536
537 // Initialize files
Autoformatter33f71722025-04-25 23:23:22 +0000538 os.WriteFile(sshConfigPath, []byte("initial ssh_config content"), 0o644)
539 os.WriteFile(knownHostsPath, []byte("initial known_hosts content"), 0o644)
Sean McCullough2cba6952025-04-25 20:32:10 +0000540
541 // Create a theater with the real filesystem but custom paths
542 cntrName := "test-container"
543 sshHost := "localhost"
544 sshPort := "2222"
545
546 ssh := &SSHTheater{
547 cntrName: cntrName,
548 sshHost: sshHost,
549 sshPort: sshPort,
550 sshConfigPath: sshConfigPath,
551 userIdentityPath: userIdentityPath,
552 knownHostsPath: knownHostsPath,
553 serverIdentityPath: serverIdentityPath,
554 serverPublicKey: publicKey,
555 fs: &RealFileSystem{},
556 kg: &RealKeyGenerator{},
557 }
558
559 // Add container to configs
560 err = ssh.addContainerToSSHConfig()
561 if err != nil {
562 t.Fatalf("Failed to set up SSH config for cleanup test: %v", err)
563 }
564
565 err = ssh.addContainerToKnownHosts()
566 if err != nil {
567 t.Fatalf("Failed to set up known_hosts for cleanup test: %v", err)
568 }
569
570 // Execute cleanup
571 err = ssh.Cleanup()
572 if err != nil {
573 t.Fatalf("Cleanup failed: %v", err)
574 }
575
576 // Read updated files
577 configData, err := os.ReadFile(sshConfigPath)
578 if err != nil {
579 t.Fatalf("Failed to read updated SSH config: %v", err)
580 }
581 configStr := string(configData)
582
583 // Check container was removed from SSH config
584 hostEntry := "Host " + ssh.cntrName
585 if strings.Contains(configStr, hostEntry) {
586 t.Errorf("Container not removed from SSH config during cleanup")
587 }
588
589 // Verify known hosts was updated
590 knownHostsContent, err := os.ReadFile(knownHostsPath)
591 if err != nil {
592 t.Fatalf("Failed to read updated known_hosts: %v", err)
593 }
594
595 expectedHostPattern := ssh.sshHost + ":" + ssh.sshPort
596 if strings.Contains(string(knownHostsContent), expectedHostPattern) {
597 t.Errorf("Container not removed from known_hosts during cleanup")
598 }
599}
600
Sean McCullough15c95282025-05-08 16:48:38 -0700601func TestCheckForInclude_userAccepts(t *testing.T) {
Sean McCullough2cba6952025-04-25 20:32:10 +0000602 mockFS := NewMockFileSystem()
603
604 // Set HOME environment variable for the test
605 oldHome := os.Getenv("HOME")
606 os.Setenv("HOME", "/home/testuser")
607 defer func() { os.Setenv("HOME", oldHome) }()
608
609 // Create a mock ssh config with the expected include
Sean McCulloughc796e7f2025-04-30 08:44:06 -0700610 includeLine := "Include /home/testuser/.config/sketch/ssh_config"
Sean McCullough2cba6952025-04-25 20:32:10 +0000611 initialConfig := fmt.Sprintf("%s\nHost example\n HostName example.com\n", includeLine)
612
613 // Add the config to the mock filesystem
614 sshConfigPath := "/home/testuser/.ssh/config"
615 mockFS.Files[sshConfigPath] = []byte(initialConfig)
Sean McCullough15c95282025-05-08 16:48:38 -0700616 stdinReader := bufio.NewReader(strings.NewReader("y\n"))
Sean McCullough2cba6952025-04-25 20:32:10 +0000617 // Test the function with our mock
Sean McCullough15c95282025-05-08 16:48:38 -0700618 err := CheckForIncludeWithFS(mockFS, *stdinReader)
Sean McCullough2cba6952025-04-25 20:32:10 +0000619 if err != nil {
620 t.Fatalf("CheckForInclude failed with proper include: %v", err)
621 }
622
623 // Now test with config missing the include
624 mockFS.Files[sshConfigPath] = []byte("Host example\n HostName example.com\n")
625
Sean McCullough15c95282025-05-08 16:48:38 -0700626 stdinReader = bufio.NewReader(strings.NewReader("y\n"))
627 err = CheckForIncludeWithFS(mockFS, *stdinReader)
Sean McCulloughc796e7f2025-04-30 08:44:06 -0700628 if err != nil {
629 t.Fatalf("CheckForInclude should have created the Include line without an error")
Sean McCullough2cba6952025-04-25 20:32:10 +0000630 }
631}
632
Sean McCullough15c95282025-05-08 16:48:38 -0700633func TestCheckForInclude_userDeclines(t *testing.T) {
634 mockFS := NewMockFileSystem()
635
636 // Set HOME environment variable for the test
637 oldHome := os.Getenv("HOME")
638 os.Setenv("HOME", "/home/testuser")
639 defer func() { os.Setenv("HOME", oldHome) }()
640
641 // Create a mock ssh config with the expected include
642 includeLine := "Include /home/testuser/.config/sketch/ssh_config"
643 initialConfig := fmt.Sprintf("%s\nHost example\n HostName example.com\n", includeLine)
644
645 // Add the config to the mock filesystem
646 sshConfigPath := "/home/testuser/.ssh/config"
647 mockFS.Files[sshConfigPath] = []byte(initialConfig)
648 stdinReader := bufio.NewReader(strings.NewReader("n\n"))
649 // Test the function with our mock
650 err := CheckForIncludeWithFS(mockFS, *stdinReader)
651 if err != nil {
652 t.Fatalf("CheckForInclude failed with proper include: %v", err)
653 }
654
655 // Now test with config missing the include
656 missingInclude := []byte("Host example\n HostName example.com\n")
657 mockFS.Files[sshConfigPath] = missingInclude
658
659 stdinReader = bufio.NewReader(strings.NewReader("n\n"))
660 err = CheckForIncludeWithFS(mockFS, *stdinReader)
661 if err == nil {
662 t.Errorf("CheckForInclude should have returned an error")
663 }
664 if !bytes.Equal(mockFS.Files[sshConfigPath], missingInclude) {
665 t.Errorf("ssh config should not have been edited")
666 }
667}
668
Sean McCullough2cba6952025-04-25 20:32:10 +0000669func TestSSHTheaterWithErrors(t *testing.T) {
670 // Test directory creation failure
671 mockFS := NewMockFileSystem()
Sean McCulloughc796e7f2025-04-30 08:44:06 -0700672 mockFS.FailOn["MkdirAll"] = fmt.Errorf("mock mkdir error")
Sean McCullough2cba6952025-04-25 20:32:10 +0000673 mockKG := NewMockKeyGenerator(nil, nil)
674
675 // Set HOME environment variable for the test
676 oldHome := os.Getenv("HOME")
677 os.Setenv("HOME", "/home/testuser")
678 defer func() { os.Setenv("HOME", oldHome) }()
679
680 // Try to create theater with failing FS
681 _, err := newSSHTheatherWithDeps("test-container", "localhost", "2222", mockFS, mockKG)
682 if err == nil || !strings.Contains(err.Error(), "mock mkdir error") {
683 t.Errorf("Should have failed with mkdir error, got: %v", err)
684 }
685
686 // Test key generation failure
687 mockFS = NewMockFileSystem()
688 mockKG = NewMockKeyGenerator(nil, nil)
689 mockKG.FailOn["GeneratePrivateKey"] = fmt.Errorf("mock key generation error")
690
691 _, err = newSSHTheatherWithDeps("test-container", "localhost", "2222", mockFS, mockKG)
692 if err == nil || !strings.Contains(err.Error(), "key generation error") {
693 t.Errorf("Should have failed with key generation error, got: %v", err)
694 }
695}
696
697func TestRealSSHTheatherInit(t *testing.T) {
698 // This is a basic smoke test for the real NewSSHTheather method
699 // We'll mock the os.Getenv("HOME") but use real dependencies otherwise
700
701 // Create a temp dir to use as HOME
702 tempDir, err := os.MkdirTemp("", "sshtheater-test-home-*")
703 if err != nil {
704 t.Fatalf("Failed to create temp dir: %v", err)
705 }
706 defer os.RemoveAll(tempDir)
707
708 // Set HOME environment for the test
709 oldHome := os.Getenv("HOME")
710 os.Setenv("HOME", tempDir)
711 defer os.Setenv("HOME", oldHome)
712
713 // Create the theater
Josh Bleecher Snyder50608b12025-05-03 22:55:49 +0000714 theater, err := NewSSHTheater("test-container", "localhost", "2222")
Sean McCullough2cba6952025-04-25 20:32:10 +0000715 if err != nil {
716 t.Fatalf("Failed to create real SSHTheather: %v", err)
717 }
718
719 // Just some basic checks
720 if theater == nil {
721 t.Fatal("Theater is nil")
722 }
723
724 // Check if the sketch dir was created
Sean McCulloughc796e7f2025-04-30 08:44:06 -0700725 sketchDir := filepath.Join(tempDir, ".config/sketch")
Sean McCullough2cba6952025-04-25 20:32:10 +0000726 if _, err := os.Stat(sketchDir); os.IsNotExist(err) {
Sean McCulloughc796e7f2025-04-30 08:44:06 -0700727 t.Errorf(".config/sketch directory not created")
Sean McCullough2cba6952025-04-25 20:32:10 +0000728 }
729
730 // Check if key files were created
731 if _, err := os.Stat(theater.serverIdentityPath); os.IsNotExist(err) {
732 t.Errorf("Server identity file not created")
733 }
734
735 if _, err := os.Stat(theater.userIdentityPath); os.IsNotExist(err) {
736 t.Errorf("User identity file not created")
737 }
738
739 // Check if the config files were created
740 if _, err := os.Stat(theater.sshConfigPath); os.IsNotExist(err) {
741 t.Errorf("SSH config file not created")
742 }
743
744 if _, err := os.Stat(theater.knownHostsPath); os.IsNotExist(err) {
745 t.Errorf("Known hosts file not created")
746 }
747
748 // Clean up
749 err = theater.Cleanup()
750 if err != nil {
751 t.Fatalf("Failed to clean up theater: %v", err)
752 }
753}