Skip to content

Commit 96b67f2

Browse files
authored
Merge pull request #421 from bison/tls-reload
✨ webhook: Handle TLS certificate rotation
2 parents 38483b2 + 2d952ac commit 96b67f2

File tree

3 files changed

+177
-4
lines changed

3 files changed

+177
-4
lines changed

Gopkg.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
Copyright 2019 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package certwatcher
18+
19+
import (
20+
"crypto/tls"
21+
"sync"
22+
23+
"gopkg.in/fsnotify.v1"
24+
logf "sigs.k8s.io/controller-runtime/pkg/internal/log"
25+
)
26+
27+
var log = logf.RuntimeLog.WithName("certwatcher")
28+
29+
// CertWatcher watches certificate and key files for changes. When either file
30+
// changes, it reads and parses both and calls an optional callback with the new
31+
// certificate.
32+
type CertWatcher struct {
33+
sync.Mutex
34+
35+
currentCert *tls.Certificate
36+
watcher *fsnotify.Watcher
37+
38+
certPath string
39+
keyPath string
40+
}
41+
42+
// New returns a new CertWatcher watching the given certificate and key.
43+
func New(certPath, keyPath string) (*CertWatcher, error) {
44+
var err error
45+
46+
cw := &CertWatcher{
47+
certPath: certPath,
48+
keyPath: keyPath,
49+
}
50+
51+
// Initial read of certificate and key.
52+
if err := cw.ReadCertificate(); err != nil {
53+
return nil, err
54+
}
55+
56+
cw.watcher, err = fsnotify.NewWatcher()
57+
if err != nil {
58+
return nil, err
59+
}
60+
61+
return cw, nil
62+
}
63+
64+
// GetCertificate fetches the currently loaded certificate, which may be nil.
65+
func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
66+
cw.Lock()
67+
defer cw.Unlock()
68+
return cw.currentCert, nil
69+
}
70+
71+
// Start starts the watch on the certificate and key files.
72+
func (cw *CertWatcher) Start(stopCh <-chan struct{}) error {
73+
files := []string{cw.certPath, cw.keyPath}
74+
75+
for _, f := range files {
76+
if err := cw.watcher.Add(f); err != nil {
77+
return err
78+
}
79+
}
80+
81+
go cw.Watch()
82+
83+
log.Info("Starting certificate watcher")
84+
85+
// Block until the stop channel is closed.
86+
<-stopCh
87+
88+
return cw.watcher.Close()
89+
}
90+
91+
// Watch reads events from the watcher's channel and reacts to changes.
92+
func (cw *CertWatcher) Watch() {
93+
for {
94+
select {
95+
case event, ok := <-cw.watcher.Events:
96+
// Channel is closed.
97+
if !ok {
98+
return
99+
}
100+
101+
cw.handleEvent(event)
102+
103+
case err, ok := <-cw.watcher.Errors:
104+
// Channel is closed.
105+
if !ok {
106+
return
107+
}
108+
109+
log.Error(err, "certificate watch error")
110+
}
111+
}
112+
}
113+
114+
// ReadCertificate reads the certificate and key files from disk, parses them,
115+
// and updates the current certificate on the watcher. If a callback is set, it
116+
// is invoked with the new certificate.
117+
func (cw *CertWatcher) ReadCertificate() error {
118+
cert, err := tls.LoadX509KeyPair(cw.certPath, cw.keyPath)
119+
if err != nil {
120+
return err
121+
}
122+
123+
cw.Lock()
124+
cw.currentCert = &cert
125+
cw.Unlock()
126+
127+
log.Info("Updated current TLS certiface")
128+
129+
return nil
130+
}
131+
132+
func (cw *CertWatcher) handleEvent(event fsnotify.Event) {
133+
// Only care about events which may modify the contents of the file.
134+
if !(isWrite(event) || isRemove(event) || isCreate(event)) {
135+
return
136+
}
137+
138+
log.V(1).Info("certificate event", "event", event)
139+
140+
// If the file was removed, re-add the watch.
141+
if isRemove(event) {
142+
if err := cw.watcher.Add(event.Name); err != nil {
143+
log.Error(err, "error re-watching file")
144+
}
145+
}
146+
147+
if err := cw.ReadCertificate(); err != nil {
148+
log.Error(err, "error re-reading certificate")
149+
}
150+
}
151+
152+
func isWrite(event fsnotify.Event) bool {
153+
return event.Op&fsnotify.Write == fsnotify.Write
154+
}
155+
156+
func isCreate(event fsnotify.Event) bool {
157+
return event.Op&fsnotify.Create == fsnotify.Create
158+
}
159+
160+
func isRemove(event fsnotify.Event) bool {
161+
return event.Op&fsnotify.Remove == fsnotify.Remove
162+
}

pkg/webhook/server.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@ import (
2323
"net"
2424
"net/http"
2525
"path"
26+
"path/filepath"
2627
"strconv"
2728
"sync"
2829
"time"
2930

3031
"sigs.k8s.io/controller-runtime/pkg/runtime/inject"
32+
"sigs.k8s.io/controller-runtime/pkg/webhook/internal/certwatcher"
3133
"sigs.k8s.io/controller-runtime/pkg/webhook/internal/metrics"
3234
)
3335

@@ -132,15 +134,23 @@ func (s *Server) Start(stop <-chan struct{}) error {
132134
}
133135
}
134136

135-
// TODO: watch the cert dir. Reload the cert if it changes
136-
cert, err := tls.LoadX509KeyPair(path.Join(s.CertDir, certName), path.Join(s.CertDir, keyName))
137+
certPath := filepath.Join(s.CertDir, certName)
138+
keyPath := filepath.Join(s.CertDir, keyName)
139+
140+
certWatcher, err := certwatcher.New(certPath, keyPath)
137141
if err != nil {
138142
return err
139143
}
140144

145+
go func() {
146+
if err := certWatcher.Start(stop); err != nil {
147+
log.Error(err, "certificate watcher error")
148+
}
149+
}()
150+
141151
cfg := &tls.Config{
142-
Certificates: []tls.Certificate{cert},
143-
NextProtos: []string{"h2"},
152+
NextProtos: []string{"h2"},
153+
GetCertificate: certWatcher.GetCertificate,
144154
}
145155

146156
listener, err := tls.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(int(s.Port))), cfg)

0 commit comments

Comments
 (0)