blob: 66a99a3769e73b5f9e5ba01b9f128f99b3f6f0b3 [file] [log] [blame]
Sean McCullough2cba6952025-04-25 20:32:10 +00001package dockerimg
2
3import (
4 "bytes"
5 "crypto/rand"
6 "crypto/rsa"
7 "fmt"
8 "io/fs"
9 "os"
10 "path/filepath"
11 "strings"
12 "testing"
13
14 "golang.org/x/crypto/ssh"
15)
16
17// MockFileSystem implements the FileSystem interface for testing
18type MockFileSystem struct {
19 Files map[string][]byte
20 CreatedDirs map[string]bool
21 OpenedFiles map[string]*MockFile
22 StatCalledWith []string
23 FailOn map[string]error // Map of function name to error to simulate failures
24}
25
26func NewMockFileSystem() *MockFileSystem {
27 return &MockFileSystem{
28 Files: make(map[string][]byte),
29 CreatedDirs: make(map[string]bool),
30 OpenedFiles: make(map[string]*MockFile),
31 FailOn: make(map[string]error),
32 }
33}
34
35func (m *MockFileSystem) Stat(name string) (fs.FileInfo, error) {
36 m.StatCalledWith = append(m.StatCalledWith, name)
37 if err, ok := m.FailOn["Stat"]; ok {
38 return nil, err
39 }
40
41 _, exists := m.Files[name]
42 if exists {
43 return nil, nil // File exists
44 }
45 _, exists = m.CreatedDirs[name]
46 if exists {
47 return nil, nil // Directory exists
48 }
49 return nil, os.ErrNotExist
50}
51
52func (m *MockFileSystem) Mkdir(name string, perm fs.FileMode) error {
53 if err, ok := m.FailOn["Mkdir"]; ok {
54 return err
55 }
56 m.CreatedDirs[name] = true
57 return nil
58}
59
60func (m *MockFileSystem) ReadFile(name string) ([]byte, error) {
61 if err, ok := m.FailOn["ReadFile"]; ok {
62 return nil, err
63 }
64
65 data, exists := m.Files[name]
66 if !exists {
67 return nil, fmt.Errorf("file not found: %s", name)
68 }
69 return data, nil
70}
71
72func (m *MockFileSystem) WriteFile(name string, data []byte, perm fs.FileMode) error {
73 if err, ok := m.FailOn["WriteFile"]; ok {
74 return err
75 }
76 m.Files[name] = data
77 return nil
78}
79
80// MockFile implements a simple in-memory file for testing
81type MockFile struct {
82 name string
83 buffer *bytes.Buffer
84 fs *MockFileSystem
85 position int64
86}
87
// MockFileContents pairs a file name with its contents, both held as strings,
// for in-memory test fixtures.
type MockFileContents struct {
	name, contents string
}
93
94func (m *MockFileSystem) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
95 if err, ok := m.FailOn["OpenFile"]; ok {
96 return nil, err
97 }
98
99 // Initialize the file content if it doesn't exist and we're not in read-only mode
100 if _, exists := m.Files[name]; !exists && (flag&os.O_CREATE != 0) {
101 m.Files[name] = []byte{}
102 }
103
104 data, exists := m.Files[name]
105 if !exists {
106 return nil, fmt.Errorf("file not found: %s", name)
107 }
108
109 // For OpenFile, we'll just use WriteFile to simulate file operations
110 // The actual file handle isn't used for much in the sshtheater code
111 // but we still need to return a valid file handle
112 tmpFile, err := os.CreateTemp("", "mockfile-*")
113 if err != nil {
114 return nil, err
115 }
116 if _, err := tmpFile.Write(data); err != nil {
117 tmpFile.Close()
118 return nil, err
119 }
120 if _, err := tmpFile.Seek(0, 0); err != nil {
121 tmpFile.Close()
122 return nil, err
123 }
124
125 return tmpFile, nil
126}
127
128// MockKeyGenerator implements KeyGenerator interface for testing
129type MockKeyGenerator struct {
130 privateKey *rsa.PrivateKey
131 publicKey ssh.PublicKey
132 FailOn map[string]error
133}
134
135func NewMockKeyGenerator(privateKey *rsa.PrivateKey, publicKey ssh.PublicKey) *MockKeyGenerator {
136 return &MockKeyGenerator{
137 privateKey: privateKey,
138 publicKey: publicKey,
139 FailOn: make(map[string]error),
140 }
141}
142
143func (m *MockKeyGenerator) GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error) {
144 if err, ok := m.FailOn["GeneratePrivateKey"]; ok {
145 return nil, err
146 }
147 return m.privateKey, nil
148}
149
150func (m *MockKeyGenerator) GeneratePublicKey(privateKey *rsa.PublicKey) (ssh.PublicKey, error) {
151 if err, ok := m.FailOn["GeneratePublicKey"]; ok {
152 return nil, err
153 }
154 return m.publicKey, nil
155}
156
157// setupMocks sets up common mocks for testing
158func setupMocks(t *testing.T) (*MockFileSystem, *MockKeyGenerator, *rsa.PrivateKey) {
159 // Generate a real private key using real random
160 privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
161 if err != nil {
162 t.Fatalf("Failed to generate test private key: %v", err)
163 }
164
165 // Generate a test public key
166 publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
167 if err != nil {
168 t.Fatalf("Failed to generate test public key: %v", err)
169 }
170
171 // Create mocks
172 mockFS := NewMockFileSystem()
173 mockKG := NewMockKeyGenerator(privateKey, publicKey)
174
175 return mockFS, mockKG, privateKey
176}
177
178// Helper function to setup a basic SSHTheater for testing
179func setupTestSSHTheater(t *testing.T) (*SSHTheater, *MockFileSystem, *MockKeyGenerator) {
180 mockFS, mockKG, _ := setupMocks(t)
181
182 // Setup home dir in mock filesystem
183 homePath := "/home/testuser"
184 sketchDir := filepath.Join(homePath, ".sketch")
185 mockFS.CreatedDirs[sketchDir] = true
186
187 // Set HOME environment variable for the test
188 oldHome := os.Getenv("HOME")
189 os.Setenv("HOME", homePath)
190 t.Cleanup(func() { os.Setenv("HOME", oldHome) })
191
192 // Create SSH Theater with mocks
193 ssh, err := newSSHTheatherWithDeps("test-container", "localhost", "2222", mockFS, mockKG)
194 if err != nil {
195 t.Fatalf("Failed to create SSHTheater: %v", err)
196 }
197
198 return ssh, mockFS, mockKG
199}
200
201func TestNewSSHTheatherCreatesRequiredDirectories(t *testing.T) {
202 mockFS, mockKG, _ := setupMocks(t)
203
204 // Set HOME environment variable for the test
205 oldHome := os.Getenv("HOME")
206 os.Setenv("HOME", "/home/testuser")
207 defer func() { os.Setenv("HOME", oldHome) }()
208
209 // Create theater
210 _, err := newSSHTheatherWithDeps("test-container", "localhost", "2222", mockFS, mockKG)
211 if err != nil {
212 t.Fatalf("Failed to create SSHTheater: %v", err)
213 }
214
215 // Check if the .sketch directory was created
216 expectedDir := "/home/testuser/.sketch"
217 if !mockFS.CreatedDirs[expectedDir] {
218 t.Errorf("Expected directory %s to be created", expectedDir)
219 }
220}
221
222func TestCreateKeyPairIfMissing(t *testing.T) {
223 ssh, mockFS, _ := setupTestSSHTheater(t)
224
225 // Test key pair creation
226 keyPath := "/home/testuser/.sketch/test_key"
227 _, err := ssh.createKeyPairIfMissing(keyPath)
228 if err != nil {
229 t.Fatalf("Failed to create key pair: %v", err)
230 }
231
232 // Verify private key file was created
233 if _, exists := mockFS.Files[keyPath]; !exists {
234 t.Errorf("Private key file not created at %s", keyPath)
235 }
236
237 // Verify public key file was created
238 pubKeyPath := keyPath + ".pub"
239 if _, exists := mockFS.Files[pubKeyPath]; !exists {
240 t.Errorf("Public key file not created at %s", pubKeyPath)
241 }
242
243 // Verify public key content format
244 pubKeyContent, _ := mockFS.ReadFile(pubKeyPath)
245 if !bytes.HasPrefix(pubKeyContent, []byte("ssh-rsa ")) {
246 t.Errorf("Public key does not have expected format, got: %s", pubKeyContent)
247 }
248}
249
250// TestAddContainerToSSHConfig tests that the container gets added to the SSH config
251// This test uses a direct approach since the OpenFile mocking is complex
252func TestAddContainerToSSHConfig(t *testing.T) {
253 // Create a temporary directory for test files
254 tempDir, err := os.MkdirTemp("", "sshtheater-test-*")
255 if err != nil {
256 t.Fatalf("Failed to create temp dir: %v", err)
257 }
258 defer os.RemoveAll(tempDir)
259
260 // Create real files in temp directory
261 configPath := filepath.Join(tempDir, "ssh_config")
262 initialConfig := `# SSH Config
263Host existing-host
264 HostName example.com
265 User testuser
266`
Autoformatter33f71722025-04-25 23:23:22 +0000267 if err := os.WriteFile(configPath, []byte(initialConfig), 0o644); err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000268 t.Fatalf("Failed to write initial config: %v", err)
269 }
270
271 // Create a theater with the real filesystem but custom paths
272 ssh := &SSHTheater{
273 cntrName: "test-container",
274 sshHost: "localhost",
275 sshPort: "2222",
276 sshConfigPath: configPath,
277 userIdentityPath: filepath.Join(tempDir, "user_identity"),
278 fs: &RealFileSystem{},
279 kg: &RealKeyGenerator{},
280 }
281
282 // Add container to SSH config
283 err = ssh.addContainerToSSHConfig()
284 if err != nil {
285 t.Fatalf("Failed to add container to SSH config: %v", err)
286 }
287
288 // Read the updated file
289 configData, err := os.ReadFile(configPath)
290 if err != nil {
291 t.Fatalf("Failed to read updated config: %v", err)
292 }
293 configStr := string(configData)
294
295 // Check for expected values
296 if !strings.Contains(configStr, "Host test-container") {
297 t.Errorf("Container host entry not found in config")
298 }
299
300 if !strings.Contains(configStr, "HostName localhost") {
301 t.Errorf("HostName not correctly added to SSH config")
302 }
303
304 if !strings.Contains(configStr, "Port 2222") {
305 t.Errorf("Port not correctly added to SSH config")
306 }
307
308 if !strings.Contains(configStr, "User root") {
309 t.Errorf("User not correctly set to root in SSH config")
310 }
311
312 // Check if identity file path is correct
313 identityLine := "IdentityFile " + ssh.userIdentityPath
314 if !strings.Contains(configStr, identityLine) {
315 t.Errorf("Identity file path not correctly added to SSH config")
316 }
317}
318
// TestAddContainerToKnownHosts is a placeholder that always skips: a
// standalone test needs more complex setup, and the skip message points at
// TestSSHTheaterCleanup for integrated coverage of addContainerToKnownHosts.
func TestAddContainerToKnownHosts(t *testing.T) {
	t.Log("This test requires more complex setup, integrated test coverage exists in TestSSHTheaterCleanup")
	t.SkipNow()
}
325
326func TestRemoveContainerFromSSHConfig(t *testing.T) {
327 // Create a temporary directory for test files
328 tempDir, err := os.MkdirTemp("", "sshtheater-test-*")
329 if err != nil {
330 t.Fatalf("Failed to create temp dir: %v", err)
331 }
332 defer os.RemoveAll(tempDir)
333
334 // Create paths for test files
335 sshConfigPath := filepath.Join(tempDir, "ssh_config")
336 userIdentityPath := filepath.Join(tempDir, "user_identity")
337 knownHostsPath := filepath.Join(tempDir, "known_hosts")
338
339 // Create initial SSH config with container entry
340 cntrName := "test-container"
341 sshHost := "localhost"
342 sshPort := "2222"
343
344 initialConfig := fmt.Sprintf(
345 `Host existing-host
346 HostName example.com
347 User testuser
348
349Host %s
350 HostName %s
351 User root
352 Port %s
353 IdentityFile %s
354 UserKnownHostsFile %s
355`,
356 cntrName, sshHost, sshPort, userIdentityPath, knownHostsPath,
357 )
358
Autoformatter33f71722025-04-25 23:23:22 +0000359 if err := os.WriteFile(sshConfigPath, []byte(initialConfig), 0o644); err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000360 t.Fatalf("Failed to write initial SSH config: %v", err)
361 }
362
363 // Create a theater with the real filesystem but custom paths
364 ssh := &SSHTheater{
365 cntrName: cntrName,
366 sshHost: sshHost,
367 sshPort: sshPort,
368 sshConfigPath: sshConfigPath,
369 userIdentityPath: userIdentityPath,
370 knownHostsPath: knownHostsPath,
371 fs: &RealFileSystem{},
372 }
373
374 // Remove container from SSH config
375 err = ssh.removeContainerFromSSHConfig()
376 if err != nil {
377 t.Fatalf("Failed to remove container from SSH config: %v", err)
378 }
379
380 // Read the updated file
381 configData, err := os.ReadFile(sshConfigPath)
382 if err != nil {
383 t.Fatalf("Failed to read updated config: %v", err)
384 }
385 configStr := string(configData)
386
387 // Check if the container host entry was removed
388 if strings.Contains(configStr, "Host "+cntrName) {
389 t.Errorf("Container host not removed from SSH config")
390 }
391
392 // Check if existing host remains
393 if !strings.Contains(configStr, "Host existing-host") {
394 t.Errorf("Existing host entry affected by container removal")
395 }
396}
397
398func TestRemoveContainerFromKnownHosts(t *testing.T) {
399 ssh, mockFS, _ := setupTestSSHTheater(t)
400
401 // Setup server public key
402 privateKey, _ := ssh.kg.GeneratePrivateKey(2048)
403 publicKey, _ := ssh.kg.GeneratePublicKey(&privateKey.PublicKey)
404 ssh.serverPublicKey = publicKey
405
406 // Create host line to be removed
407 hostLine := "[localhost]:2222 ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ..."
408 otherLine := "otherhost ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ..."
409
410 // Set initial content with the line to be removed
411 initialContent := otherLine + "\n" + hostLine
412 mockFS.Files[ssh.knownHostsPath] = []byte(initialContent)
413
414 // Add the host to test remove function
415 err := ssh.addContainerToKnownHosts()
416 if err != nil {
417 t.Fatalf("Failed to add container to known_hosts for removal test: %v", err)
418 }
419
420 // Now remove it
421 err = ssh.removeContainerFromKnownHosts()
422 if err != nil {
423 t.Fatalf("Failed to remove container from known_hosts: %v", err)
424 }
425
426 // Verify content
427 updatedContent, _ := mockFS.ReadFile(ssh.knownHostsPath)
428 content := string(updatedContent)
429
430 hostPattern := ssh.sshHost + ":" + ssh.sshPort
431 if strings.Contains(content, hostPattern) {
432 t.Errorf("Container entry not removed from known_hosts")
433 }
434
435 // Verify other content remains
436 if !strings.Contains(content, otherLine) {
437 t.Errorf("Other known_hosts entries improperly removed")
438 }
439}
440
441func TestSSHTheaterCleanup(t *testing.T) {
442 // Create a temporary directory for test files
443 tempDir, err := os.MkdirTemp("", "sshtheater-test-*")
444 if err != nil {
445 t.Fatalf("Failed to create temp dir: %v", err)
446 }
447 defer os.RemoveAll(tempDir)
448
449 // Create paths for test files
450 sshConfigPath := filepath.Join(tempDir, "ssh_config")
451 userIdentityPath := filepath.Join(tempDir, "user_identity")
452 knownHostsPath := filepath.Join(tempDir, "known_hosts")
453 serverIdentityPath := filepath.Join(tempDir, "server_identity")
454
455 // Create private key for server key
456 privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
457 if err != nil {
458 t.Fatalf("Failed to generate private key: %v", err)
459 }
460 publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
461 if err != nil {
462 t.Fatalf("Failed to generate public key: %v", err)
463 }
464
465 // Initialize files
Autoformatter33f71722025-04-25 23:23:22 +0000466 os.WriteFile(sshConfigPath, []byte("initial ssh_config content"), 0o644)
467 os.WriteFile(knownHostsPath, []byte("initial known_hosts content"), 0o644)
Sean McCullough2cba6952025-04-25 20:32:10 +0000468
469 // Create a theater with the real filesystem but custom paths
470 cntrName := "test-container"
471 sshHost := "localhost"
472 sshPort := "2222"
473
474 ssh := &SSHTheater{
475 cntrName: cntrName,
476 sshHost: sshHost,
477 sshPort: sshPort,
478 sshConfigPath: sshConfigPath,
479 userIdentityPath: userIdentityPath,
480 knownHostsPath: knownHostsPath,
481 serverIdentityPath: serverIdentityPath,
482 serverPublicKey: publicKey,
483 fs: &RealFileSystem{},
484 kg: &RealKeyGenerator{},
485 }
486
487 // Add container to configs
488 err = ssh.addContainerToSSHConfig()
489 if err != nil {
490 t.Fatalf("Failed to set up SSH config for cleanup test: %v", err)
491 }
492
493 err = ssh.addContainerToKnownHosts()
494 if err != nil {
495 t.Fatalf("Failed to set up known_hosts for cleanup test: %v", err)
496 }
497
498 // Execute cleanup
499 err = ssh.Cleanup()
500 if err != nil {
501 t.Fatalf("Cleanup failed: %v", err)
502 }
503
504 // Read updated files
505 configData, err := os.ReadFile(sshConfigPath)
506 if err != nil {
507 t.Fatalf("Failed to read updated SSH config: %v", err)
508 }
509 configStr := string(configData)
510
511 // Check container was removed from SSH config
512 hostEntry := "Host " + ssh.cntrName
513 if strings.Contains(configStr, hostEntry) {
514 t.Errorf("Container not removed from SSH config during cleanup")
515 }
516
517 // Verify known hosts was updated
518 knownHostsContent, err := os.ReadFile(knownHostsPath)
519 if err != nil {
520 t.Fatalf("Failed to read updated known_hosts: %v", err)
521 }
522
523 expectedHostPattern := ssh.sshHost + ":" + ssh.sshPort
524 if strings.Contains(string(knownHostsContent), expectedHostPattern) {
525 t.Errorf("Container not removed from known_hosts during cleanup")
526 }
527}
528
529func TestCheckForInclude(t *testing.T) {
530 mockFS := NewMockFileSystem()
531
532 // Set HOME environment variable for the test
533 oldHome := os.Getenv("HOME")
534 os.Setenv("HOME", "/home/testuser")
535 defer func() { os.Setenv("HOME", oldHome) }()
536
537 // Create a mock ssh config with the expected include
538 includeLine := "Include /home/testuser/.sketch/ssh_config"
539 initialConfig := fmt.Sprintf("%s\nHost example\n HostName example.com\n", includeLine)
540
541 // Add the config to the mock filesystem
542 sshConfigPath := "/home/testuser/.ssh/config"
543 mockFS.Files[sshConfigPath] = []byte(initialConfig)
544
545 // Test the function with our mock
546 err := CheckForIncludeWithFS(mockFS)
547 if err != nil {
548 t.Fatalf("CheckForInclude failed with proper include: %v", err)
549 }
550
551 // Now test with config missing the include
552 mockFS.Files[sshConfigPath] = []byte("Host example\n HostName example.com\n")
553
554 err = CheckForIncludeWithFS(mockFS)
555 if err == nil {
556 t.Fatalf("CheckForInclude should have returned an error for missing include")
557 }
558}
559
560func TestSSHTheaterWithErrors(t *testing.T) {
561 // Test directory creation failure
562 mockFS := NewMockFileSystem()
563 mockFS.FailOn["Mkdir"] = fmt.Errorf("mock mkdir error")
564 mockKG := NewMockKeyGenerator(nil, nil)
565
566 // Set HOME environment variable for the test
567 oldHome := os.Getenv("HOME")
568 os.Setenv("HOME", "/home/testuser")
569 defer func() { os.Setenv("HOME", oldHome) }()
570
571 // Try to create theater with failing FS
572 _, err := newSSHTheatherWithDeps("test-container", "localhost", "2222", mockFS, mockKG)
573 if err == nil || !strings.Contains(err.Error(), "mock mkdir error") {
574 t.Errorf("Should have failed with mkdir error, got: %v", err)
575 }
576
577 // Test key generation failure
578 mockFS = NewMockFileSystem()
579 mockKG = NewMockKeyGenerator(nil, nil)
580 mockKG.FailOn["GeneratePrivateKey"] = fmt.Errorf("mock key generation error")
581
582 _, err = newSSHTheatherWithDeps("test-container", "localhost", "2222", mockFS, mockKG)
583 if err == nil || !strings.Contains(err.Error(), "key generation error") {
584 t.Errorf("Should have failed with key generation error, got: %v", err)
585 }
586}
587
588func TestRealSSHTheatherInit(t *testing.T) {
589 // This is a basic smoke test for the real NewSSHTheather method
590 // We'll mock the os.Getenv("HOME") but use real dependencies otherwise
591
592 // Create a temp dir to use as HOME
593 tempDir, err := os.MkdirTemp("", "sshtheater-test-home-*")
594 if err != nil {
595 t.Fatalf("Failed to create temp dir: %v", err)
596 }
597 defer os.RemoveAll(tempDir)
598
599 // Set HOME environment for the test
600 oldHome := os.Getenv("HOME")
601 os.Setenv("HOME", tempDir)
602 defer os.Setenv("HOME", oldHome)
603
604 // Create the theater
605 theater, err := NewSSHTheather("test-container", "localhost", "2222")
606 if err != nil {
607 t.Fatalf("Failed to create real SSHTheather: %v", err)
608 }
609
610 // Just some basic checks
611 if theater == nil {
612 t.Fatal("Theater is nil")
613 }
614
615 // Check if the sketch dir was created
616 sketchDir := filepath.Join(tempDir, ".sketch")
617 if _, err := os.Stat(sketchDir); os.IsNotExist(err) {
618 t.Errorf(".sketch directory not created")
619 }
620
621 // Check if key files were created
622 if _, err := os.Stat(theater.serverIdentityPath); os.IsNotExist(err) {
623 t.Errorf("Server identity file not created")
624 }
625
626 if _, err := os.Stat(theater.userIdentityPath); os.IsNotExist(err) {
627 t.Errorf("User identity file not created")
628 }
629
630 // Check if the config files were created
631 if _, err := os.Stat(theater.sshConfigPath); os.IsNotExist(err) {
632 t.Errorf("SSH config file not created")
633 }
634
635 if _, err := os.Stat(theater.knownHostsPath); os.IsNotExist(err) {
636 t.Errorf("Known hosts file not created")
637 }
638
639 // Clean up
640 err = theater.Cleanup()
641 if err != nil {
642 t.Fatalf("Failed to clean up theater: %v", err)
643 }
644}