Add certwatcher to webhook server
Signed-off-by: Grant Griffiths <grant@portworx.com>
This commit is contained in:
164
pkg/validation-webhook/certwatcher.go
Normal file
164
pkg/validation-webhook/certwatcher.go
Normal file
@@ -0,0 +1,164 @@
|
||||
/*
|
||||
Copyright 2020 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"sync"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"k8s.io/klog"
|
||||
)
|
||||
|
||||
// This file originated from github.com/kubernetes-sigs/controller-runtime/pkg/webhook/internal/certwatcher.
|
||||
// We cannot import this package as it's an internal one. In addition, we cannot yet easily integrate
|
||||
// with controller-runtime/pkg/webhook directly, as it would require extensive rework:
|
||||
// https://github.com/kubernetes-csi/external-snapshotter/issues/422
|
||||
|
||||
// CertWatcher watches certificate and key files for changes. When either file
|
||||
// changes, it reads and parses both and calls an optional callback with the new
|
||||
// certificate.
|
||||
type CertWatcher struct {
|
||||
sync.Mutex
|
||||
|
||||
currentCert *tls.Certificate
|
||||
watcher *fsnotify.Watcher
|
||||
|
||||
certPath string
|
||||
keyPath string
|
||||
}
|
||||
|
||||
// NewCertWatcher returns a new CertWatcher watching the given certificate and key.
|
||||
func NewCertWatcher(certPath, keyPath string) (*CertWatcher, error) {
|
||||
var err error
|
||||
|
||||
cw := &CertWatcher{
|
||||
certPath: certPath,
|
||||
keyPath: keyPath,
|
||||
}
|
||||
|
||||
// Initial read of certificate and key.
|
||||
if err := cw.ReadCertificate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cw.watcher, err = fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cw, nil
|
||||
}
|
||||
|
||||
// GetCertificate fetches the currently loaded certificate, which may be nil.
|
||||
func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
cw.Lock()
|
||||
defer cw.Unlock()
|
||||
return cw.currentCert, nil
|
||||
}
|
||||
|
||||
// Start starts the watch on the certificate and key files.
|
||||
func (cw *CertWatcher) Start(ctx context.Context) error {
|
||||
files := []string{cw.certPath, cw.keyPath}
|
||||
|
||||
for _, f := range files {
|
||||
if err := cw.watcher.Add(f); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
go cw.Watch()
|
||||
|
||||
// Block until the context is done.
|
||||
<-ctx.Done()
|
||||
|
||||
return cw.watcher.Close()
|
||||
}
|
||||
|
||||
// Watch reads events from the watcher's channel and reacts to changes.
|
||||
func (cw *CertWatcher) Watch() {
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-cw.watcher.Events:
|
||||
// Channel is closed.
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
cw.handleEvent(event)
|
||||
|
||||
case err, ok := <-cw.watcher.Errors:
|
||||
// Channel is closed.
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
klog.Error(err, "certificate watch error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadCertificate reads the certificate and key files from disk, parses them,
|
||||
// and updates the current certificate on the watcher. If a callback is set, it
|
||||
// is invoked with the new certificate.
|
||||
func (cw *CertWatcher) ReadCertificate() error {
|
||||
cert, err := tls.LoadX509KeyPair(cw.certPath, cw.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cw.Lock()
|
||||
cw.currentCert = &cert
|
||||
cw.Unlock()
|
||||
|
||||
klog.Info("Updated current TLS certificate")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cw *CertWatcher) handleEvent(event fsnotify.Event) {
|
||||
// Only care about events which may modify the contents of the file.
|
||||
if !(isWrite(event) || isRemove(event) || isCreate(event)) {
|
||||
return
|
||||
}
|
||||
|
||||
klog.V(1).Info("certificate event", "event", event)
|
||||
|
||||
// If the file was removed, re-add the watch.
|
||||
if isRemove(event) {
|
||||
if err := cw.watcher.Add(event.Name); err != nil {
|
||||
klog.Error(err, "error re-watching file")
|
||||
}
|
||||
}
|
||||
|
||||
if err := cw.ReadCertificate(); err != nil {
|
||||
klog.Error(err, "error re-reading certificate")
|
||||
}
|
||||
}
|
||||
|
||||
func isWrite(event fsnotify.Event) bool {
|
||||
return event.Op&fsnotify.Write == fsnotify.Write
|
||||
}
|
||||
|
||||
func isCreate(event fsnotify.Event) bool {
|
||||
return event.Op&fsnotify.Create == fsnotify.Create
|
||||
}
|
||||
|
||||
func isRemove(event fsnotify.Event) bool {
|
||||
return event.Op&fsnotify.Remove == fsnotify.Remove
|
||||
}
|
@@ -17,6 +17,8 @@ limitations under the License.
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
@@ -175,21 +177,45 @@ func serveSnapshotRequest(w http.ResponseWriter, r *http.Request) {
|
||||
serve(w, r, newDelegateToV1AdmitHandler(admitSnapshot))
|
||||
}
|
||||
|
||||
func main(cmd *cobra.Command, args []string) {
|
||||
func startServer(ctx context.Context, tlsConfig *tls.Config, cw *CertWatcher) error {
|
||||
go func() {
|
||||
klog.Info("Starting certificate watcher")
|
||||
if err := cw.Start(ctx); err != nil {
|
||||
klog.Errorf("certificate watcher error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
fmt.Println("Starting webhook server")
|
||||
config := Config{
|
||||
CertFile: certFile,
|
||||
KeyFile: keyFile,
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/volumesnapshot", serveSnapshotRequest)
|
||||
mux.HandleFunc("/readyz", func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("ok")) })
|
||||
srv := &http.Server{
|
||||
Handler: mux,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
|
||||
http.HandleFunc("/volumesnapshot", serveSnapshotRequest)
|
||||
http.HandleFunc("/readyz", func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("ok")) })
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", port),
|
||||
TLSConfig: configTLS(config),
|
||||
}
|
||||
err := server.ListenAndServeTLS("", "")
|
||||
// listener is always closed by srv.Serve
|
||||
listener, err := tls.Listen("tcp", fmt.Sprintf(":%d", port), tlsConfig)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return err
|
||||
}
|
||||
|
||||
return srv.Serve(listener)
|
||||
}
|
||||
|
||||
func main(cmd *cobra.Command, args []string) {
|
||||
// Create new cert watcher
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel() // stops certwatcher
|
||||
cw, err := NewCertWatcher(certFile, keyFile)
|
||||
if err != nil {
|
||||
klog.Fatalf("failed to initialize new cert watcher: %v", err)
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
GetCertificate: cw.GetCertificate,
|
||||
}
|
||||
|
||||
if err := startServer(ctx, tlsConfig, cw); err != nil {
|
||||
klog.Fatalf("server stopped: %v", err)
|
||||
}
|
||||
}
|
||||
|
169
pkg/validation-webhook/webhook_test.go
Normal file
169
pkg/validation-webhook/webhook_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWebhookCertReload(t *testing.T) {
|
||||
// Initialize test space
|
||||
tmpDir := os.TempDir() + "/webhook-cert-tests"
|
||||
certFile = tmpDir + "/tls.crt"
|
||||
keyFile = tmpDir + "/tls.key"
|
||||
port = 30443
|
||||
err := os.Mkdir(tmpDir, 0777)
|
||||
if err != nil && err != os.ErrExist {
|
||||
t.Errorf("unexpected error occurred while creating tmp dir: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
err := os.RemoveAll(tmpDir)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error occurred while deleting certs: %v", err)
|
||||
}
|
||||
}()
|
||||
generateTestCertKeyPair(t, certFile, keyFile)
|
||||
|
||||
// Start test server
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
cw, err := NewCertWatcher(certFile, keyFile)
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize new cert watcher: %v", err)
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
GetCertificate: cw.GetCertificate,
|
||||
}
|
||||
go func() {
|
||||
if err := startServer(ctx, tlsConfig, cw); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
time.Sleep(250 * time.Millisecond) // Give some time for watcher to start
|
||||
|
||||
// TC: Original cert should not change with no file changes
|
||||
originalCert, err := tlsConfig.GetCertificate(nil)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error occurred while getting cert: %v", err)
|
||||
}
|
||||
originalCertStr := string(originalCert.Certificate[0])
|
||||
originalKey := originalCert.PrivateKey.(*rsa.PrivateKey)
|
||||
|
||||
newCert, err := tlsConfig.GetCertificate(nil) // get certificate again
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error occurred while getting newcert: %v", err)
|
||||
}
|
||||
if string(newCert.Certificate[0]) != originalCertStr {
|
||||
t.Error("new cert was updated when it should not have been")
|
||||
}
|
||||
newKey := newCert.PrivateKey.(*rsa.PrivateKey)
|
||||
if !newKey.Equal(originalKey) {
|
||||
t.Error("new key was updated when it should not have been")
|
||||
}
|
||||
|
||||
// TC: Certificate should consistently change with a file change
|
||||
for i := 0; i < 5; i++ {
|
||||
// Generate new key/cert
|
||||
generateTestCertKeyPair(t, certFile, keyFile)
|
||||
|
||||
// Wait for certwatcher to update
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
newCert, err = tlsConfig.GetCertificate(nil)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error occurred while getting newcert: %v", err)
|
||||
}
|
||||
if string(newCert.Certificate[0]) == originalCertStr {
|
||||
t.Errorf("new cert was not updated")
|
||||
}
|
||||
|
||||
newKey = newCert.PrivateKey.(*rsa.PrivateKey)
|
||||
if newKey.Equal(originalKey) {
|
||||
t.Error("new key was not updated")
|
||||
}
|
||||
|
||||
originalCertStr = string(newCert.Certificate[0])
|
||||
originalKey = newKey
|
||||
}
|
||||
}
|
||||
|
||||
// generateTestCertKeyPair generates a new random test key/crt and writes it to tmpDir
|
||||
// based on https://golang.org/src/crypto/tls/generate_cert.go
|
||||
func generateTestCertKeyPair(t *testing.T, certPath, keyPath string) error {
|
||||
notBefore := time.Now()
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to generate serial number: %v", err)
|
||||
}
|
||||
|
||||
var priv interface{}
|
||||
priv, err = rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to generate key: %v", err)
|
||||
}
|
||||
keyUsage := x509.KeyUsageDigitalSignature
|
||||
if _, isRSA := priv.(*rsa.PrivateKey); isRSA {
|
||||
keyUsage |= x509.KeyUsageKeyEncipherment
|
||||
}
|
||||
randomOrganizationStr := time.Now().String()
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{randomOrganizationStr},
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: time.Now().Add(1 * time.Hour),
|
||||
|
||||
KeyUsage: keyUsage,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
DNSNames: []string{"127.0.0.1"},
|
||||
}
|
||||
|
||||
rk := priv.(*rsa.PrivateKey)
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &rk.PublicKey, priv)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
certOut, err := os.Create(certPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to open tls.crt for writing: %v", err)
|
||||
}
|
||||
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
|
||||
return fmt.Errorf("Failed to write data to tls.crt: %v", err)
|
||||
}
|
||||
if err := certOut.Close(); err != nil {
|
||||
return fmt.Errorf("Error closing tls.crt: %v", err)
|
||||
}
|
||||
fmt.Printf("wrote new cert: %s\n", certPath)
|
||||
|
||||
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to open tls.key for writing: %v", err)
|
||||
}
|
||||
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to marshal private key: %v", err)
|
||||
}
|
||||
if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
|
||||
return fmt.Errorf("Failed to write data to tls.key: %v", err)
|
||||
}
|
||||
if err := keyOut.Close(); err != nil {
|
||||
return fmt.Errorf("Error closing tls.key: %v", err)
|
||||
}
|
||||
fmt.Printf("wrote new key: %s\n", keyPath)
|
||||
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user