Skip to content

validate cert #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions pkg/admission/cert/writer/certwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ package writer
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"log"
"net/url"
"time"

admissionregistrationv1beta1 "k8s.io/api/admissionregistration/v1beta1"
"k8s.io/apimachinery/pkg/runtime"
Expand Down Expand Up @@ -91,8 +94,13 @@ func handleCommon(webhook *admissionregistrationv1beta1.Webhook, ch certReadWrit
return err
}

dnsName, err := dnsNameForWebhook(&webhook.ClientConfig)
if err != nil {
return err
}
// Recreate the cert if it's invalid.
if !validCert(certs) {
valid := validCert(certs, dnsName)
if !valid {
log.Printf("cert is invalid or expiring, regenerating a new one")
certs, err = ch.overwrite(webhook.Name)
if err != nil {
Expand Down Expand Up @@ -141,18 +149,36 @@ type certReadWriter interface {
overwrite(webhookName string) (*generator.Artifacts, error)
}

func validCert(certs *generator.Artifacts) bool {
// TODO:
// 1) validate the key and the cert are valid pair e.g. call crypto/tls.X509KeyPair()
// 2) validate the cert with the CA cert
// 3) validate the cert is for a certain DNSName
// e.g.
// c, err := tls.X509KeyPair(cert, key)
// err := c.Verify(options)
func validCert(certs *generator.Artifacts, dnsName string) bool {
if certs == nil {
return false
}

// Verify key and cert are valid pair
_, err := tls.X509KeyPair(certs.Cert, certs.Key)
if err != nil {
return false
}

// Verify cert is good for desired DNS name and signed by CA and will be valid for desired period of time.
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(certs.CACert) {
return false
}
block, _ := pem.Decode([]byte(certs.Cert))
if block == nil {
return false
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return false
}
ops := x509.VerifyOptions{
DNSName: dnsName,
Roots: pool,
CurrentTime: time.Now().AddDate(0, 6, 0),
}
_, err = cert.Verify(ops)
return err == nil
}

Expand Down
102 changes: 70 additions & 32 deletions pkg/admission/cert/writer/certwriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

var certs1, certs2 *generator.Artifacts

func init() {
cn1 := "example.com"
cn2 := "test-service.test-svc-namespace.svc"
cp := generator.SelfSignedCertGenerator{}
certs1, _ = cp.Generate(cn1)
certs2, _ = cp.Generate(cn2)
}

var _ = Describe("NewProvider", func() {
var cl client.Client
var ops Options
Expand Down Expand Up @@ -73,33 +83,6 @@ var _ = Describe("NewProvider", func() {

})

var certPEM = `-----BEGIN CERTIFICATE-----
MIICRzCCAfGgAwIBAgIJALMb7ecMIk3MMA0GCSqGSIb3DQEBCwUAMH4xCzAJBgNV
BAYTAkdCMQ8wDQYDVQQIDAZMb25kb24xDzANBgNVBAcMBkxvbmRvbjEYMBYGA1UE
CgwPR2xvYmFsIFNlY3VyaXR5MRYwFAYDVQQLDA1JVCBEZXBhcnRtZW50MRswGQYD
VQQDDBJ0ZXN0LWNlcnRpZmljYXRlLTAwIBcNMTcwNDI2MjMyNjUyWhgPMjExNzA0
MDIyMzI2NTJaMH4xCzAJBgNVBAYTAkdCMQ8wDQYDVQQIDAZMb25kb24xDzANBgNV
BAcMBkxvbmRvbjEYMBYGA1UECgwPR2xvYmFsIFNlY3VyaXR5MRYwFAYDVQQLDA1J
VCBEZXBhcnRtZW50MRswGQYDVQQDDBJ0ZXN0LWNlcnRpZmljYXRlLTAwXDANBgkq
hkiG9w0BAQEFAANLADBIAkEAtBMa7NWpv3BVlKTCPGO/LEsguKqWHBtKzweMY2CV
tAL1rQm913huhxF9w+ai76KQ3MHK5IVnLJjYYA5MzP2H5QIDAQABo1AwTjAdBgNV
HQ4EFgQU22iy8aWkNSxv0nBxFxerfsvnZVMwHwYDVR0jBBgwFoAU22iy8aWkNSxv
0nBxFxerfsvnZVMwDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAANBAEOefGbV
NcHxklaW06w6OBYJPwpIhCVozC1qdxGX1dg8VkEKzjOzjgqVD30m59OFmSlBmHsl
nkVA6wyOSDYBf3o=
-----END CERTIFICATE-----`

var keyPEM = `-----BEGIN RSA PRIVATE KEY-----
MIIBUwIBADANBgkqhkiG9w0BAQEFAASCAT0wggE5AgEAAkEAtBMa7NWpv3BVlKTC
PGO/LEsguKqWHBtKzweMY2CVtAL1rQm913huhxF9w+ai76KQ3MHK5IVnLJjYYA5M
zP2H5QIDAQABAkAS9BfXab3OKpK3bIgNNyp+DQJKrZnTJ4Q+OjsqkpXvNltPJosf
G8GsiKu/vAt4HGqI3eU77NvRI+mL4MnHRmXBAiEA3qM4FAtKSRBbcJzPxxLEUSwg
XSCcosCktbkXvpYrS30CIQDPDxgqlwDEJQ0uKuHkZI38/SPWWqfUmkecwlbpXABK
iQIgZX08DA8VfvcA5/Xj1Zjdey9FVY6POLXen6RPiabE97UCICp6eUW7ht+2jjar
e35EltCRCjoejRHTuN9TC0uCoVipAiAXaJIx/Q47vGwiw6Y8KXsNU6y54gTbOSxX
54LzHNk/+Q==
-----END RSA PRIVATE KEY-----`

type fakeCertReadWriter struct {
numReadCalled int
readCertAndErr []certAndErr
Expand Down Expand Up @@ -154,11 +137,16 @@ var _ = Describe("handleCommon", func() {
var invalidCert *generator.Artifacts

BeforeEach(func(done Done) {
webhook = &admissionregistration.Webhook{}
url := "https://example.com/admission"
webhook = &admissionregistration.Webhook{
ClientConfig: admissionregistration.WebhookClientConfig{
URL: &url,
},
}
cert = &generator.Artifacts{
CACert: []byte(`CACertBytes`),
Cert: []byte(certPEM),
Key: []byte(keyPEM),
CACert: []byte(certs1.CACert),
Cert: []byte(certs1.Cert),
Key: []byte(certs1.Key),
}
invalidCert = &generator.Artifacts{
CACert: []byte(`CACertBytes`),
Expand Down Expand Up @@ -188,7 +176,11 @@ var _ = Describe("handleCommon", func() {
certrw := &fakeCertReadWriter{
readCertAndErr: []certAndErr{
{
err: notFoundError{errors.NewNotFound(schema.GroupResource{}, "foo")},
err: notFoundError{errors.NewNotFound(schema.GroupResource{}, "foo")},
},
},
writeCertAndErr: []certAndErr{
{
cert: cert,
},
},
Expand All @@ -198,6 +190,7 @@ var _ = Describe("handleCommon", func() {
Expect(err).NotTo(HaveOccurred())
Expect(certrw.numReadCalled).To(Equal(1))
Expect(certrw.numWriteCalled).To(Equal(1))
Expect(certrw.numOverwriteCalled).To(Equal(0))
})

It("should return the error on failed write", func() {
Expand All @@ -218,6 +211,7 @@ var _ = Describe("handleCommon", func() {
Expect(err).To(MatchError(goerrors.New("failed to write")))
Expect(certrw.numReadCalled).To(Equal(1))
Expect(certrw.numWriteCalled).To(Equal(1))
Expect(certrw.numOverwriteCalled).To(Equal(0))
})
})

Expand All @@ -234,6 +228,8 @@ var _ = Describe("handleCommon", func() {
err := handleCommon(webhook, certrw)
Expect(err).NotTo(HaveOccurred())
Expect(certrw.numReadCalled).To(Equal(1))
Expect(certrw.numWriteCalled).To(Equal(0))
Expect(certrw.numOverwriteCalled).To(Equal(0))
})

It("should return the error on failed read", func() {
Expand All @@ -248,6 +244,8 @@ var _ = Describe("handleCommon", func() {
err := handleCommon(webhook, certrw)
Expect(err).To(MatchError(goerrors.New("failed to read")))
Expect(certrw.numReadCalled).To(Equal(1))
Expect(certrw.numWriteCalled).To(Equal(0))
Expect(certrw.numOverwriteCalled).To(Equal(0))
})
})

Expand All @@ -269,6 +267,7 @@ var _ = Describe("handleCommon", func() {
err := handleCommon(webhook, certrw)
Expect(err).NotTo(HaveOccurred())
Expect(certrw.numReadCalled).To(Equal(1))
Expect(certrw.numWriteCalled).To(Equal(0))
Expect(certrw.numOverwriteCalled).To(Equal(1))
})

Expand All @@ -289,6 +288,7 @@ var _ = Describe("handleCommon", func() {
err := handleCommon(webhook, certrw)
Expect(err).NotTo(HaveOccurred())
Expect(certrw.numReadCalled).To(Equal(1))
Expect(certrw.numWriteCalled).To(Equal(0))
Expect(certrw.numOverwriteCalled).To(Equal(1))
})

Expand Down Expand Up @@ -413,3 +413,41 @@ var _ = Describe("dnsNameForWebhook", func() {
})
})
})

var _ = Describe("validate cert", func() {
Context("invalid pair", func() {
It("should detect it", func() {
certs := generator.Artifacts{
CACert: certs1.CACert,
Cert: certs1.Cert,
Key: certs2.Key,
}
valid := validCert(&certs, "example.com")
Expect(valid).To(BeFalse())
})
})

Context("CA not matching", func() {
It("should detect it", func() {
certs := generator.Artifacts{
CACert: certs2.CACert,
Cert: certs1.Cert,
Key: certs1.Key,
}
valid := validCert(&certs, "example.com")
Expect(valid).To(BeFalse())
})
})

Context("DNS name not matching", func() {
It("should detect it", func() {
certs := generator.Artifacts{
CACert: certs1.CACert,
Cert: certs1.Cert,
Key: certs1.Key,
}
valid := validCert(&certs, "foo.com")
Expect(valid).To(BeFalse())
})
})
})
48 changes: 24 additions & 24 deletions pkg/admission/cert/writer/fs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ var _ = Describe("FSCertWriter", func() {
CertGenerator: &fakegenerator.CertGenerator{
DNSNameToCertArtifacts: map[string]*generator.Artifacts{
"test-service.test-svc-namespace.svc": {
CACert: []byte(`CACertBytes`),
Cert: []byte(certPEM),
Key: []byte(keyPEM),
CACert: []byte(certs2.CACert),
Cert: []byte(certs2.Cert),
Key: []byte(certs2.Key),
},
},
},
Expand Down Expand Up @@ -187,13 +187,13 @@ var _ = Describe("FSCertWriter", func() {
Expect(err).NotTo(HaveOccurred())
caBytes, err := ioutil.ReadFile(path.Join(testingDir, CACertName))
Expect(err).NotTo(HaveOccurred())
Expect(caBytes).To(Equal([]byte(`CACertBytes`)))
Expect(caBytes).To(Equal([]byte(certs2.CACert)))
certBytes, err := ioutil.ReadFile(path.Join(testingDir, ServerCertName))
Expect(err).NotTo(HaveOccurred())
Expect(certBytes).To(Equal([]byte(certPEM)))
Expect(certBytes).To(Equal([]byte(certs2.Cert)))
keyBytes, err := ioutil.ReadFile(path.Join(testingDir, ServerKeyName))
Expect(err).NotTo(HaveOccurred())
Expect(keyBytes).To(Equal([]byte(keyPEM)))
Expect(keyBytes).To(Equal([]byte(certs2.Key)))
})
})

Expand All @@ -212,13 +212,13 @@ var _ = Describe("FSCertWriter", func() {
Expect(err).NotTo(HaveOccurred())
caBytes, err := ioutil.ReadFile(path.Join(testingDir, CACertName))
Expect(err).NotTo(HaveOccurred())
Expect(caBytes).To(Equal([]byte(`CACertBytes`)))
Expect(caBytes).To(Equal([]byte(certs2.CACert)))
certBytes, err := ioutil.ReadFile(path.Join(testingDir, ServerCertName))
Expect(err).NotTo(HaveOccurred())
Expect(certBytes).To(Equal([]byte(certPEM)))
Expect(certBytes).To(Equal([]byte(certs2.Cert)))
keyBytes, err := ioutil.ReadFile(path.Join(testingDir, ServerKeyName))
Expect(err).NotTo(HaveOccurred())
Expect(keyBytes).To(Equal([]byte(keyPEM)))
Expect(keyBytes).To(Equal([]byte(certs2.Key)))
})
})

Expand All @@ -241,13 +241,13 @@ var _ = Describe("FSCertWriter", func() {
Expect(err).NotTo(HaveOccurred())
caBytes, err := ioutil.ReadFile(path.Join(testingDir, CACertName))
Expect(err).NotTo(HaveOccurred())
Expect(caBytes).To(Equal([]byte(`CACertBytes`)))
Expect(caBytes).To(Equal([]byte(certs2.CACert)))
certBytes, err := ioutil.ReadFile(path.Join(testingDir, ServerCertName))
Expect(err).NotTo(HaveOccurred())
Expect(certBytes).To(Equal([]byte(certPEM)))
Expect(certBytes).To(Equal([]byte(certs2.Cert)))
keyBytes, err := ioutil.ReadFile(path.Join(testingDir, ServerKeyName))
Expect(err).NotTo(HaveOccurred())
Expect(keyBytes).To(Equal([]byte(keyPEM)))
Expect(keyBytes).To(Equal([]byte(certs2.Key)))
})
})
})
Expand All @@ -266,13 +266,13 @@ var _ = Describe("FSCertWriter", func() {
Expect(err).NotTo(HaveOccurred())
caBytes, err := ioutil.ReadFile(path.Join(testingDir, CACertName))
Expect(err).NotTo(HaveOccurred())
Expect(caBytes).To(Equal([]byte(`CACertBytes`)))
Expect(caBytes).To(Equal([]byte(certs2.CACert)))
certBytes, err := ioutil.ReadFile(path.Join(testingDir, ServerCertName))
Expect(err).NotTo(HaveOccurred())
Expect(certBytes).To(Equal([]byte(certPEM)))
Expect(certBytes).To(Equal([]byte(certs2.Cert)))
keyBytes, err := ioutil.ReadFile(path.Join(testingDir, ServerKeyName))
Expect(err).NotTo(HaveOccurred())
Expect(keyBytes).To(Equal([]byte(keyPEM)))
Expect(keyBytes).To(Equal([]byte(certs2.Key)))
})
})

Expand Down Expand Up @@ -301,13 +301,13 @@ var _ = Describe("FSCertWriter", func() {
Expect(err).NotTo(HaveOccurred())
caBytes, err := ioutil.ReadFile(path.Join(testingDir, CACertName))
Expect(err).NotTo(HaveOccurred())
Expect(caBytes).To(Equal([]byte(`CACertBytes`)))
Expect(caBytes).To(Equal([]byte(certs2.CACert)))
certBytes, err := ioutil.ReadFile(path.Join(testingDir, ServerCertName))
Expect(err).NotTo(HaveOccurred())
Expect(certBytes).To(Equal([]byte(certPEM)))
Expect(certBytes).To(Equal([]byte(certs2.Cert)))
keyBytes, err := ioutil.ReadFile(path.Join(testingDir, ServerKeyName))
Expect(err).NotTo(HaveOccurred())
Expect(keyBytes).To(Equal([]byte(keyPEM)))
Expect(keyBytes).To(Equal([]byte(certs2.Key)))
})
})
})
Expand All @@ -317,11 +317,11 @@ var _ = Describe("FSCertWriter", func() {
Context("cert is valid", func() {
Context("when not expiring", func() {
BeforeEach(func(done Done) {
err := ioutil.WriteFile(path.Join(testingDir, CACertName), []byte(`oldCACertBytes`), 0600)
err := ioutil.WriteFile(path.Join(testingDir, CACertName), []byte(certs2.CACert), 0600)
Expect(err).NotTo(HaveOccurred())
err = ioutil.WriteFile(path.Join(testingDir, ServerCertName), []byte(certPEM), 0600)
err = ioutil.WriteFile(path.Join(testingDir, ServerCertName), []byte(certs2.Cert), 0600)
Expect(err).NotTo(HaveOccurred())
err = ioutil.WriteFile(path.Join(testingDir, ServerKeyName), []byte(keyPEM), 0600)
err = ioutil.WriteFile(path.Join(testingDir, ServerKeyName), []byte(certs2.Key), 0600)
Expect(err).NotTo(HaveOccurred())
close(done)
})
Expand All @@ -330,13 +330,13 @@ var _ = Describe("FSCertWriter", func() {
Expect(err).NotTo(HaveOccurred())
caBytes, err := ioutil.ReadFile(path.Join(testingDir, CACertName))
Expect(err).NotTo(HaveOccurred())
Expect(caBytes).To(Equal([]byte(`oldCACertBytes`)))
Expect(caBytes).To(Equal([]byte(certs2.CACert)))
certBytes, err := ioutil.ReadFile(path.Join(testingDir, ServerCertName))
Expect(err).NotTo(HaveOccurred())
Expect(certBytes).To(Equal([]byte(certPEM)))
Expect(certBytes).To(Equal([]byte(certs2.Cert)))
keyBytes, err := ioutil.ReadFile(path.Join(testingDir, ServerKeyName))
Expect(err).NotTo(HaveOccurred())
Expect(keyBytes).To(Equal([]byte(keyPEM)))
Expect(keyBytes).To(Equal([]byte(certs2.Key)))
})
})

Expand Down
Loading