|
| 1 | +// Copyright 2023 The Go Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a BSD-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +package test |
| 6 | + |
| 7 | +import ( |
| 8 | + "bytes" |
| 9 | + "fmt" |
| 10 | + "os" |
| 11 | + "os/exec" |
| 12 | + "path/filepath" |
| 13 | + "testing" |
| 14 | + |
| 15 | + "golang.org/x/crypto/internal/testenv" |
| 16 | + "golang.org/x/crypto/ssh" |
| 17 | + "golang.org/x/crypto/ssh/testdata" |
| 18 | +) |
| 19 | + |
| 20 | +func sshClient(t *testing.T) string { |
| 21 | + if testing.Short() { |
| 22 | + t.Skip("Skipping test that executes OpenSSH in -short mode") |
| 23 | + } |
| 24 | + sshCLI := os.Getenv("SSH_CLI_PATH") |
| 25 | + if sshCLI == "" { |
| 26 | + sshCLI = "ssh" |
| 27 | + } |
| 28 | + var err error |
| 29 | + sshCLI, err = exec.LookPath(sshCLI) |
| 30 | + if err != nil { |
| 31 | + t.Skipf("Can't find an ssh(1) client to test against: %v", err) |
| 32 | + } |
| 33 | + return sshCLI |
| 34 | +} |
| 35 | + |
| 36 | +func TestSSHCLIAuth(t *testing.T) { |
| 37 | + sshCLI := sshClient(t) |
| 38 | + dir := t.TempDir() |
| 39 | + keyPrivPath := filepath.Join(dir, "rsa") |
| 40 | + |
| 41 | + for fn, content := range map[string][]byte{ |
| 42 | + keyPrivPath: testdata.PEMBytes["rsa"], |
| 43 | + keyPrivPath + ".pub": ssh.MarshalAuthorizedKey(testPublicKeys["rsa"]), |
| 44 | + filepath.Join(dir, "rsa-cert.pub"): testdata.SSHCertificates["rsa-user-testcertificate"], |
| 45 | + } { |
| 46 | + if err := os.WriteFile(fn, content, 0600); err != nil { |
| 47 | + t.Fatalf("WriteFile(%q): %v", fn, err) |
| 48 | + } |
| 49 | + } |
| 50 | + |
| 51 | + certChecker := ssh.CertChecker{ |
| 52 | + IsUserAuthority: func(k ssh.PublicKey) bool { |
| 53 | + return bytes.Equal(k.Marshal(), testPublicKeys["ca"].Marshal()) |
| 54 | + }, |
| 55 | + UserKeyFallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { |
| 56 | + if conn.User() == "testpubkey" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { |
| 57 | + return nil, nil |
| 58 | + } |
| 59 | + |
| 60 | + return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) |
| 61 | + }, |
| 62 | + } |
| 63 | + |
| 64 | + config := &ssh.ServerConfig{ |
| 65 | + PublicKeyCallback: certChecker.Authenticate, |
| 66 | + } |
| 67 | + config.AddHostKey(testSigners["rsa"]) |
| 68 | + |
| 69 | + server, err := newTestServer(config) |
| 70 | + if err != nil { |
| 71 | + t.Fatalf("unable to start test server: %v", err) |
| 72 | + } |
| 73 | + defer server.Close() |
| 74 | + |
| 75 | + port, err := server.port() |
| 76 | + if err != nil { |
| 77 | + t.Fatalf("unable to get server port: %v", err) |
| 78 | + } |
| 79 | + |
| 80 | + // test public key authentication. |
| 81 | + cmd := testenv.Command(t, sshCLI, "-vvv", "-i", keyPrivPath, "-o", "StrictHostKeyChecking=no", |
| 82 | + "-p", port, "[email protected]", "true") |
| 83 | + out, err := cmd.CombinedOutput() |
| 84 | + if err != nil { |
| 85 | + t.Fatalf("public key authentication failed, error: %v, command output %q", err, string(out)) |
| 86 | + } |
| 87 | + // Test SSH user certificate authentication. |
| 88 | + // The username must match one of the principals included in the certificate. |
| 89 | + // The certificate "rsa-user-testcertificate" has "testcertificate" as principal. |
| 90 | + cmd = testenv.Command(t, sshCLI, "-vvv", "-i", keyPrivPath, "-o", "StrictHostKeyChecking=no", |
| 91 | + "-p", port, "[email protected]", "true") |
| 92 | + out, err = cmd.CombinedOutput() |
| 93 | + if err != nil { |
| 94 | + t.Fatalf("user certificate authentication failed, error: %v, command output %q", err, string(out)) |
| 95 | + } |
| 96 | +} |
0 commit comments