Files
2023-05-08 18:25:23 +05:30

181 lines
5.3 KiB
Go

package webhook
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"os"
"testing"
"time"
)
func TestWebhookCertReload(t *testing.T) {
// Initialize test space
tmpDir := os.TempDir() + "/webhook-cert-tests"
certFile = tmpDir + "/tls.crt"
keyFile = tmpDir + "/tls.key"
port = 30443
err := os.Mkdir(tmpDir, 0o777)
if err != nil && err != os.ErrExist {
t.Errorf("unexpected error occurred while creating tmp dir: %v", err)
}
defer func() {
err := os.RemoveAll(tmpDir)
if err != nil {
t.Errorf("unexpected error occurred while deleting certs: %v", err)
}
}()
err = generateTestCertKeyPair(t, certFile, keyFile)
if err != nil {
t.Errorf("unexpected error occurred while generating test certs: %v", err)
}
// Start test server
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cw, err := NewCertWatcher(certFile, keyFile)
if err != nil {
t.Errorf("failed to initialize new cert watcher: %v", err)
}
tlsConfig := &tls.Config{
GetCertificate: cw.GetCertificate,
}
go func() {
err := startServer(ctx,
tlsConfig,
cw,
&fakeSnapshotLister{},
&fakeGroupSnapshotLister{})
if err != nil {
panic(err)
}
}()
time.Sleep(250 * time.Millisecond) // Give some time for watcher to start
// TC: Original cert should not change with no file changes
originalCert, err := tlsConfig.GetCertificate(nil)
if err != nil {
t.Errorf("unexpected error occurred while getting cert: %v", err)
}
originalCertStr := string(originalCert.Certificate[0])
originalKey := originalCert.PrivateKey.(*rsa.PrivateKey)
newCert, err := tlsConfig.GetCertificate(nil) // get certificate again
if err != nil {
t.Errorf("unexpected error occurred while getting newcert: %v", err)
}
if string(newCert.Certificate[0]) != originalCertStr {
t.Error("new cert was updated when it should not have been")
}
newKey := newCert.PrivateKey.(*rsa.PrivateKey)
if !newKey.Equal(originalKey) {
t.Error("new key was updated when it should not have been")
}
// TC: Certificate should consistently change with a file change
for i := 0; i < 5; i++ {
// Generate new key/cert
err = generateTestCertKeyPair(t, certFile, keyFile)
if err != nil {
t.Errorf("unexpected error occurred while generating test certs: %v", err)
}
// Wait for certwatcher to update
time.Sleep(250 * time.Millisecond)
newCert, err = tlsConfig.GetCertificate(nil)
if err != nil {
t.Errorf("unexpected error occurred while getting newcert: %v", err)
}
if string(newCert.Certificate[0]) == originalCertStr {
t.Errorf("new cert was not updated")
}
newKey = newCert.PrivateKey.(*rsa.PrivateKey)
if newKey.Equal(originalKey) {
t.Error("new key was not updated")
}
originalCertStr = string(newCert.Certificate[0])
originalKey = newKey
}
}
// generateTestCertKeyPair generates a new random self-signed RSA certificate
// and private key for 127.0.0.1 and writes them PEM-encoded to certPath and
// keyPath, overwriting any existing files. It is based on
// https://golang.org/src/crypto/tls/generate_cert.go.
//
// Each call uses a fresh key, a random 128-bit serial number and the current
// time as the subject Organization, so consecutive calls produce distinct
// certificates. The certificate is valid for one hour from the time of the
// call. The certificate file is written before the key file, so an observer of
// the files may briefly see a new certificate next to the old key.
func generateTestCertKeyPair(t *testing.T, certPath, keyPath string) error {
	t.Helper()

	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return fmt.Errorf("generating serial number: %w", err)
	}

	// 2048 bits is ample for a throwaway test key; 4096-bit generation is
	// several times slower and this helper is called repeatedly per test.
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return fmt.Errorf("generating key: %w", err)
	}

	now := time.Now()
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			// A timestamp makes the subject differ between calls.
			Organization: []string{now.String()},
		},
		NotBefore: now,
		NotAfter:  now.Add(1 * time.Hour),
		// The key is always RSA, so KeyEncipherment is included alongside
		// DigitalSignature (as generate_cert.go does for RSA keys).
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
		DNSNames:              []string{"127.0.0.1"},
	}

	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return fmt.Errorf("creating certificate: %w", err)
	}
	privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
	if err != nil {
		return fmt.Errorf("marshaling private key: %w", err)
	}

	// Encode everything in memory first, then write with os.WriteFile. It
	// truncates existing files and always closes its handle. The previous
	// Create/Encode/Close sequence leaked the file on encode errors.
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})

	if err := os.WriteFile(certPath, certPEM, 0o644); err != nil {
		return fmt.Errorf("writing tls.crt: %w", err)
	}
	t.Logf("wrote new cert: %s", certPath)

	// The permission only applies when the file is created; an existing
	// key file keeps its mode.
	if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
		return fmt.Errorf("writing tls.key: %w", err)
	}
	t.Logf("wrote new key: %s", keyPath)
	return nil
}