package webhook

import (
	"bytes"
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"math/big"
	"net"
	"os"
	"path/filepath"
	"testing"
	"time"
)

// TestWebhookCertReload verifies that the certificate served through the
// CertWatcher stays stable while the files on disk are untouched, and that it
// is replaced each time a new key pair is written to certFile/keyFile.
func TestWebhookCertReload(t *testing.T) {
	// reloadTimeout bounds how long we wait for the watcher to observe a
	// rewritten key pair before declaring the reload failed.
	const reloadTimeout = 10 * time.Second

	// Initialize test space. t.TempDir creates a directory unique to this test
	// and removes it automatically when the test (and its cleanups) finish, so
	// leftovers from a previous run can never interfere.
	tmpDir := t.TempDir()
	certFile = filepath.Join(tmpDir, "tls.crt")
	keyFile = filepath.Join(tmpDir, "tls.key")
	port = 30443

	if err := generateTestCertKeyPair(t, certFile, keyFile); err != nil {
		t.Fatalf("failed to generate initial cert/key pair: %v", err)
	}

	// Start test server
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	cw, err := NewCertWatcher(certFile, keyFile)
	if err != nil {
		t.Fatalf("failed to initialize new cert watcher: %v", err)
	}

	tlsConfig := &tls.Config{
		GetCertificate: cw.GetCertificate,
	}

	// The server goroutine must not panic (that would abort the whole test
	// binary) nor call t.Fatal (only legal on the test goroutine). It reports
	// its result on a buffered channel, so the send never blocks, and the test
	// goroutine inspects it via checkServer.
	serverErr := make(chan error, 1)
	go func() {
		serverErr <- startServer(ctx, tlsConfig, cw, &fakeSnapshotLister{})
	}()
	checkServer := func() {
		t.Helper()
		select {
		case err := <-serverErr:
			if err != nil {
				t.Fatalf("webhook server exited with error: %v", err)
			}
		default:
		}
	}

	time.Sleep(250 * time.Millisecond) // Give some time for watcher to start
	checkServer()

	// current returns the leaf certificate DER bytes and RSA private key that
	// the watcher is serving right now, failing the test if either is missing
	// or of an unexpected type.
	current := func() ([]byte, *rsa.PrivateKey) {
		t.Helper()
		cert, err := tlsConfig.GetCertificate(nil)
		if err != nil {
			t.Fatalf("unexpected error occurred while getting cert: %v", err)
		}
		if cert == nil || len(cert.Certificate) == 0 {
			t.Fatal("cert watcher returned no certificate")
		}
		key, ok := cert.PrivateKey.(*rsa.PrivateKey)
		if !ok {
			t.Fatalf("expected *rsa.PrivateKey, got %T", cert.PrivateKey)
		}
		return cert.Certificate[0], key
	}

	// TC: Original cert should not change with no file changes
	originalDER, originalKey := current()
	newDER, newKey := current() // get certificate again
	if !bytes.Equal(newDER, originalDER) {
		t.Error("new cert was updated when it should not have been")
	}
	if !newKey.Equal(originalKey) {
		t.Error("new key was updated when it should not have been")
	}

	// TC: Certificate should consistently change with a file change
	for i := 0; i < 5; i++ {
		// Generate new key/cert
		if err := generateTestCertKeyPair(t, certFile, keyFile); err != nil {
			t.Fatalf("iteration %d: failed to generate cert/key pair: %v", i, err)
		}

		// Wait for the cert watcher to pick up the new pair. Polling against a
		// deadline instead of one fixed sleep keeps the test quick when the
		// reload is fast and tolerant when the machine is slow.
		deadline := time.Now().Add(reloadTimeout)
		for {
			newDER, newKey = current()
			if !bytes.Equal(newDER, originalDER) || time.Now().After(deadline) {
				break
			}
			time.Sleep(50 * time.Millisecond)
		}
		checkServer()

		if bytes.Equal(newDER, originalDER) {
			t.Fatalf("iteration %d: new cert was not updated", i)
		}
		if newKey.Equal(originalKey) {
			t.Errorf("iteration %d: new key was not updated", i)
		}
		originalDER, originalKey = newDER, newKey
	}
}

// generateTestCertKeyPair generates a new random self-signed RSA certificate
// for 127.0.0.1 (valid for one hour, server-auth usage) and writes it
// PEM-encoded to certPath, with its PKCS#8 private key written to keyPath
// (mode 0600). Existing files are truncated.
//
// Errors are returned rather than reported on t so callers decide how to
// react; t is used only for logging.
//
// Based on https://golang.org/src/crypto/tls/generate_cert.go
func generateTestCertKeyPair(t *testing.T, certPath, keyPath string) error {
	t.Helper()

	notBefore := time.Now()
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return fmt.Errorf("Failed to generate serial number: %w", err)
	}

	// 2048 bits is ample for a throwaway test certificate and is several times
	// faster to generate than 4096, which matters because tests regenerate the
	// pair repeatedly.
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return fmt.Errorf("Failed to generate key: %w", err)
	}

	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			// A timestamp makes each generated certificate distinguishable.
			Organization: []string{time.Now().String()},
		},
		NotBefore: notBefore,
		NotAfter:  time.Now().Add(1 * time.Hour),
		// RSA keys may be used for key encipherment in addition to signing.
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
		DNSNames:              []string{"127.0.0.1"},
	}

	// Self-signed: the template is both the certificate and its issuer.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return fmt.Errorf("Failed to create certificate: %w", err)
	}

	// Marshal the key before touching any file so a failure here cannot leave
	// a truncated key file behind.
	privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
	if err != nil {
		return fmt.Errorf("Unable to marshal private key: %w", err)
	}

	// writePEM writes block to path (creating or truncating it with perm) and
	// always closes the file, even when encoding fails.
	writePEM := func(path, name string, perm os.FileMode, block *pem.Block) error {
		f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
		if err != nil {
			return fmt.Errorf("Failed to open %s for writing: %w", name, err)
		}
		if err := pem.Encode(f, block); err != nil {
			_ = f.Close() // best effort; the encode error is the one worth reporting
			return fmt.Errorf("Failed to write data to %s: %w", name, err)
		}
		if err := f.Close(); err != nil {
			return fmt.Errorf("Error closing %s: %w", name, err)
		}
		return nil
	}

	if err := writePEM(certPath, "tls.crt", 0666, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
		return err
	}
	t.Logf("wrote new cert: %s", certPath)

	if err := writePEM(keyPath, "tls.key", 0600, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
		return err
	}
	t.Logf("wrote new key: %s", keyPath)

	return nil
}