Add certwatcher to webhook server

Signed-off-by: Grant Griffiths <grant@portworx.com>
This commit is contained in:
Grant Griffiths
2020-11-16 23:06:50 -08:00
parent 1e5a7ef7a7
commit b457b08ffc
24 changed files with 2775 additions and 12 deletions

View File

@@ -0,0 +1,164 @@
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package webhook
import (
"context"
"crypto/tls"
"sync"
"github.com/fsnotify/fsnotify"
"k8s.io/klog"
)
// This file originated from github.com/kubernetes-sigs/controller-runtime/pkg/webhook/internal/certwatcher.
// We cannot import this package as it's an internal one. In addition, we cannot yet easily integrate
// with controller-runtime/pkg/webhook directly, as it would require extensive rework:
// https://github.com/kubernetes-csi/external-snapshotter/issues/422
// CertWatcher watches certificate and key files for changes. When either file
// changes, it reads and parses both and calls an optional callback with the new
// certificate.
type CertWatcher struct {
sync.Mutex
currentCert *tls.Certificate
watcher *fsnotify.Watcher
certPath string
keyPath string
}
// NewCertWatcher returns a new CertWatcher watching the given certificate and key.
func NewCertWatcher(certPath, keyPath string) (*CertWatcher, error) {
var err error
cw := &CertWatcher{
certPath: certPath,
keyPath: keyPath,
}
// Initial read of certificate and key.
if err := cw.ReadCertificate(); err != nil {
return nil, err
}
cw.watcher, err = fsnotify.NewWatcher()
if err != nil {
return nil, err
}
return cw, nil
}
// GetCertificate fetches the currently loaded certificate, which may be nil.
func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
cw.Lock()
defer cw.Unlock()
return cw.currentCert, nil
}
// Start starts the watch on the certificate and key files.
func (cw *CertWatcher) Start(ctx context.Context) error {
files := []string{cw.certPath, cw.keyPath}
for _, f := range files {
if err := cw.watcher.Add(f); err != nil {
return err
}
}
go cw.Watch()
// Block until the context is done.
<-ctx.Done()
return cw.watcher.Close()
}
// Watch reads events from the watcher's channel and reacts to changes.
func (cw *CertWatcher) Watch() {
for {
select {
case event, ok := <-cw.watcher.Events:
// Channel is closed.
if !ok {
return
}
cw.handleEvent(event)
case err, ok := <-cw.watcher.Errors:
// Channel is closed.
if !ok {
return
}
klog.Error(err, "certificate watch error")
}
}
}
// ReadCertificate reads the certificate and key files from disk, parses them,
// and updates the current certificate on the watcher. If a callback is set, it
// is invoked with the new certificate.
func (cw *CertWatcher) ReadCertificate() error {
cert, err := tls.LoadX509KeyPair(cw.certPath, cw.keyPath)
if err != nil {
return err
}
cw.Lock()
cw.currentCert = &cert
cw.Unlock()
klog.Info("Updated current TLS certificate")
return nil
}
func (cw *CertWatcher) handleEvent(event fsnotify.Event) {
// Only care about events which may modify the contents of the file.
if !(isWrite(event) || isRemove(event) || isCreate(event)) {
return
}
klog.V(1).Info("certificate event", "event", event)
// If the file was removed, re-add the watch.
if isRemove(event) {
if err := cw.watcher.Add(event.Name); err != nil {
klog.Error(err, "error re-watching file")
}
}
if err := cw.ReadCertificate(); err != nil {
klog.Error(err, "error re-reading certificate")
}
}
func isWrite(event fsnotify.Event) bool {
return event.Op&fsnotify.Write == fsnotify.Write
}
func isCreate(event fsnotify.Event) bool {
return event.Op&fsnotify.Create == fsnotify.Create
}
func isRemove(event fsnotify.Event) bool {
return event.Op&fsnotify.Remove == fsnotify.Remove
}

View File

@@ -17,6 +17,8 @@ limitations under the License.
package webhook
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io/ioutil"
@@ -175,21 +177,45 @@ func serveSnapshotRequest(w http.ResponseWriter, r *http.Request) {
serve(w, r, newDelegateToV1AdmitHandler(admitSnapshot))
}
func main(cmd *cobra.Command, args []string) {
func startServer(ctx context.Context, tlsConfig *tls.Config, cw *CertWatcher) error {
go func() {
klog.Info("Starting certificate watcher")
if err := cw.Start(ctx); err != nil {
klog.Errorf("certificate watcher error: %v", err)
}
}()
fmt.Println("Starting webhook server")
config := Config{
CertFile: certFile,
KeyFile: keyFile,
mux := http.NewServeMux()
mux.HandleFunc("/volumesnapshot", serveSnapshotRequest)
mux.HandleFunc("/readyz", func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("ok")) })
srv := &http.Server{
Handler: mux,
TLSConfig: tlsConfig,
}
http.HandleFunc("/volumesnapshot", serveSnapshotRequest)
http.HandleFunc("/readyz", func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("ok")) })
server := &http.Server{
Addr: fmt.Sprintf(":%d", port),
TLSConfig: configTLS(config),
}
err := server.ListenAndServeTLS("", "")
// listener is always closed by srv.Serve
listener, err := tls.Listen("tcp", fmt.Sprintf(":%d", port), tlsConfig)
if err != nil {
panic(err)
return err
}
return srv.Serve(listener)
}
func main(cmd *cobra.Command, args []string) {
// Create new cert watcher
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel() // stops certwatcher
cw, err := NewCertWatcher(certFile, keyFile)
if err != nil {
klog.Fatalf("failed to initialize new cert watcher: %v", err)
}
tlsConfig := &tls.Config{
GetCertificate: cw.GetCertificate,
}
if err := startServer(ctx, tlsConfig, cw); err != nil {
klog.Fatalf("server stopped: %v", err)
}
}

View File

@@ -0,0 +1,169 @@
package webhook
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"os"
"testing"
"time"
)
func TestWebhookCertReload(t *testing.T) {
// Initialize test space
tmpDir := os.TempDir() + "/webhook-cert-tests"
certFile = tmpDir + "/tls.crt"
keyFile = tmpDir + "/tls.key"
port = 30443
err := os.Mkdir(tmpDir, 0777)
if err != nil && err != os.ErrExist {
t.Errorf("unexpected error occurred while creating tmp dir: %v", err)
}
defer func() {
err := os.RemoveAll(tmpDir)
if err != nil {
t.Errorf("unexpected error occurred while deleting certs: %v", err)
}
}()
generateTestCertKeyPair(t, certFile, keyFile)
// Start test server
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cw, err := NewCertWatcher(certFile, keyFile)
if err != nil {
t.Errorf("failed to initialize new cert watcher: %v", err)
}
tlsConfig := &tls.Config{
GetCertificate: cw.GetCertificate,
}
go func() {
if err := startServer(ctx, tlsConfig, cw); err != nil {
panic(err)
}
}()
time.Sleep(250 * time.Millisecond) // Give some time for watcher to start
// TC: Original cert should not change with no file changes
originalCert, err := tlsConfig.GetCertificate(nil)
if err != nil {
t.Errorf("unexpected error occurred while getting cert: %v", err)
}
originalCertStr := string(originalCert.Certificate[0])
originalKey := originalCert.PrivateKey.(*rsa.PrivateKey)
newCert, err := tlsConfig.GetCertificate(nil) // get certificate again
if err != nil {
t.Errorf("unexpected error occurred while getting newcert: %v", err)
}
if string(newCert.Certificate[0]) != originalCertStr {
t.Error("new cert was updated when it should not have been")
}
newKey := newCert.PrivateKey.(*rsa.PrivateKey)
if !newKey.Equal(originalKey) {
t.Error("new key was updated when it should not have been")
}
// TC: Certificate should consistently change with a file change
for i := 0; i < 5; i++ {
// Generate new key/cert
generateTestCertKeyPair(t, certFile, keyFile)
// Wait for certwatcher to update
time.Sleep(250 * time.Millisecond)
newCert, err = tlsConfig.GetCertificate(nil)
if err != nil {
t.Errorf("unexpected error occurred while getting newcert: %v", err)
}
if string(newCert.Certificate[0]) == originalCertStr {
t.Errorf("new cert was not updated")
}
newKey = newCert.PrivateKey.(*rsa.PrivateKey)
if newKey.Equal(originalKey) {
t.Error("new key was not updated")
}
originalCertStr = string(newCert.Certificate[0])
originalKey = newKey
}
}
// generateTestCertKeyPair generates a new random test key/crt and writes it to tmpDir
// based on https://golang.org/src/crypto/tls/generate_cert.go
func generateTestCertKeyPair(t *testing.T, certPath, keyPath string) error {
notBefore := time.Now()
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return fmt.Errorf("Failed to generate serial number: %v", err)
}
var priv interface{}
priv, err = rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return fmt.Errorf("Failed to generate key: %v", err)
}
keyUsage := x509.KeyUsageDigitalSignature
if _, isRSA := priv.(*rsa.PrivateKey); isRSA {
keyUsage |= x509.KeyUsageKeyEncipherment
}
randomOrganizationStr := time.Now().String()
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{randomOrganizationStr},
},
NotBefore: notBefore,
NotAfter: time.Now().Add(1 * time.Hour),
KeyUsage: keyUsage,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
DNSNames: []string{"127.0.0.1"},
}
rk := priv.(*rsa.PrivateKey)
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &rk.PublicKey, priv)
if err != nil {
return fmt.Errorf("Failed to create certificate: %v", err)
}
certOut, err := os.Create(certPath)
if err != nil {
return fmt.Errorf("Failed to open tls.crt for writing: %v", err)
}
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
return fmt.Errorf("Failed to write data to tls.crt: %v", err)
}
if err := certOut.Close(); err != nil {
return fmt.Errorf("Error closing tls.crt: %v", err)
}
fmt.Printf("wrote new cert: %s\n", certPath)
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("Failed to open tls.key for writing: %v", err)
}
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
return fmt.Errorf("Unable to marshal private key: %v", err)
}
if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
return fmt.Errorf("Failed to write data to tls.key: %v", err)
}
if err := keyOut.Close(); err != nil {
return fmt.Errorf("Error closing tls.key: %v", err)
}
fmt.Printf("wrote new key: %s\n", keyPath)
return nil
}