diff --git a/cmd/csi-snapshotter/main.go b/cmd/csi-snapshotter/main.go index 529e8f7c..7aea51ad 100644 --- a/cmd/csi-snapshotter/main.go +++ b/cmd/csi-snapshotter/main.go @@ -24,14 +24,19 @@ import ( "os/signal" "time" + "google.golang.org/grpc" + "k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" "k8s.io/klog" - "github.com/kubernetes-csi/external-snapshotter/pkg/connection" + "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/kubernetes-csi/csi-lib-utils/connection" + csirpc "github.com/kubernetes-csi/csi-lib-utils/rpc" "github.com/kubernetes-csi/external-snapshotter/pkg/controller" + "github.com/kubernetes-csi/external-snapshotter/pkg/snapshotter" clientset "github.com/kubernetes-csi/external-snapshotter/pkg/client/clientset/versioned" snapshotscheme "github.com/kubernetes-csi/external-snapshotter/pkg/client/clientset/versioned/scheme" @@ -50,7 +55,7 @@ const ( // Command line flags var ( - snapshotter = flag.String("snapshotter", "", "This option is deprecated.") + snapshotterName = flag.String("snapshotter", "", "This option is deprecated.") kubeconfig = flag.String("kubeconfig", "", "Absolute path to the kubeconfig file. Required only when running out of cluster.") connectionTimeout = flag.Duration("connection-timeout", 0, "The --connection-timeout flag is deprecated") csiAddress = flag.String("csi-address", "/run/csi/socket", "Address of the CSI driver socket.") @@ -80,7 +85,8 @@ func main() { if *connectionTimeout != 0 { klog.Warning("--connection-timeout is deprecated and will have no effect") } - if *snapshotter != "" { + + if *snapshotterName != "" { klog.Warning("--snapshotter is deprecated and will have no effect") } @@ -124,9 +130,9 @@ func main() { snapshotscheme.AddToScheme(scheme.Scheme) // Connect to CSI. - csiConn, err := connection.New(*csiAddress) + csiConn, err := connection.Connect(*csiAddress) if err != nil { - klog.Error(err.Error()) + klog.Errorf("error connecting to CSI driver: %v", err) os.Exit(1) } @@ -135,27 +141,29 @@ func main() { defer cancel() // Find driver name - *snapshotter, err = csiConn.GetDriverName(ctx) + *snapshotterName, err = csirpc.GetDriverName(ctx, csiConn) if err != nil { - klog.Error(err.Error()) + klog.Errorf("error getting CSI driver name: %v", err) os.Exit(1) } - klog.V(2).Infof("CSI driver name: %q", *snapshotter) + + klog.V(2).Infof("CSI driver name: %q", *snapshotterName) // Check it's ready - if err = waitForDriverReady(csiConn, *connectionTimeout); err != nil { - klog.Error(err.Error()) + if err = csirpc.ProbeForever(csiConn, csiTimeout); err != nil { + klog.Errorf("error waiting for CSI driver to be ready: %v", err) os.Exit(1) + } // Find out if the driver supports create/delete snapshot. - supportsCreateSnapshot, err := csiConn.SupportsControllerCreateSnapshot(ctx) + supportsCreateSnapshot, err := supportsControllerCreateSnapshot(ctx, csiConn) if err != nil { - klog.Error(err.Error()) + klog.Errorf("error determining if driver supports create/delete snapshot operations: %v", err) os.Exit(1) } if !supportsCreateSnapshot { - klog.Errorf("CSI driver %s does not support ControllerCreateSnapshot", *snapshotter) + klog.Errorf("CSI driver %s does not support ControllerCreateSnapshot", *snapshotterName) os.Exit(1) } @@ -164,19 +172,20 @@ func main() { os.Exit(1) } - klog.V(2).Infof("Start NewCSISnapshotController with snapshotter [%s] kubeconfig [%s] connectionTimeout [%+v] csiAddress [%s] createSnapshotContentRetryCount [%d] createSnapshotContentInterval [%+v] resyncPeriod [%+v] snapshotNamePrefix [%s] snapshotNameUUIDLength [%d]", *snapshotter, *kubeconfig, *connectionTimeout, *csiAddress, createSnapshotContentRetryCount, *createSnapshotContentInterval, *resyncPeriod, *snapshotNamePrefix, snapshotNameUUIDLength) + klog.V(2).Infof("Start NewCSISnapshotController with snapshotter [%s] kubeconfig [%s] connectionTimeout [%+v] csiAddress [%s] createSnapshotContentRetryCount [%d] createSnapshotContentInterval [%+v] resyncPeriod [%+v] snapshotNamePrefix [%s] snapshotNameUUIDLength [%d]", *snapshotterName, *kubeconfig, *connectionTimeout, *csiAddress, createSnapshotContentRetryCount, *createSnapshotContentInterval, *resyncPeriod, *snapshotNamePrefix, snapshotNameUUIDLength) + snapShotter := snapshotter.NewSnapshotter(csiConn) ctrl := controller.NewCSISnapshotController( snapClient, kubeClient, - *snapshotter, + *snapshotterName, factory.Volumesnapshot().V1alpha1().VolumeSnapshots(), factory.Volumesnapshot().V1alpha1().VolumeSnapshotContents(), factory.Volumesnapshot().V1alpha1().VolumeSnapshotClasses(), coreFactory.Core().V1().PersistentVolumeClaims(), *createSnapshotContentRetryCount, *createSnapshotContentInterval, - csiConn, + snapShotter, *connectionTimeout, *resyncPeriod, *snapshotNamePrefix, @@ -203,24 +212,11 @@ func buildConfig(kubeconfig string) (*rest.Config, error) { return rest.InClusterConfig() } -func waitForDriverReady(csiConn connection.CSIConnection, timeout time.Duration) error { - now := time.Now() - finish := now.Add(timeout) - var err error - for { - ctx, cancel := context.WithTimeout(context.Background(), csiTimeout) - defer cancel() - err = csiConn.Probe(ctx) - if err == nil { - klog.V(2).Infof("Probe succeeded") - return nil - } - klog.V(2).Infof("Probe failed with %s", err) - - now := time.Now() - if now.After(finish) { - return fmt.Errorf("failed to probe the controller: %s", err) - } - time.Sleep(time.Second) +func supportsControllerCreateSnapshot(ctx context.Context, conn *grpc.ClientConn) (bool, error) { + capabilities, err := csirpc.GetControllerCapabilities(ctx, conn) + if err != nil { + return false, err } + + return capabilities[csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT], nil } diff --git a/cmd/csi-snapshotter/main_test.go b/cmd/csi-snapshotter/main_test.go new file mode 100644 index 00000000..f13aba72 --- /dev/null +++ b/cmd/csi-snapshotter/main_test.go @@ -0,0 +1,161 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "context" + "fmt" + "testing" + + "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/golang/mock/gomock" + "github.com/kubernetes-csi/csi-lib-utils/connection" + "github.com/kubernetes-csi/csi-test/driver" + + "google.golang.org/grpc" +) + +func Test_supportsControllerCreateSnapshot(t *testing.T) { + tests := []struct { + name string + output *csi.ControllerGetCapabilitiesResponse + injectError bool + expectError bool + expectResult bool + }{ + { + name: "success", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT, + }, + }, + }, + }, + }, + expectError: false, + expectResult: true, + }, + { + name: "gRPC error", + output: nil, + injectError: true, + expectError: true, + expectResult: false, + }, + { + name: "no create snapshot", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, + }, + }, + }, + }, + }, + expectError: false, + expectResult: false, + }, + { + name: "empty capability", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: nil, + }, + }, + }, + expectError: false, + expectResult: false, + }, + { + name: "no capabilities", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{}, + }, + expectError: false, + expectResult: false, + }, + } + + mockController, driver, _, controllerServer, csiConn, err := createMockServer(t) + if err != nil { + t.Fatal(err) + } + defer mockController.Finish() + defer driver.Stop() + defer csiConn.Close() + + for _, test := range tests { + + in := &csi.ControllerGetCapabilitiesRequest{} + + out := test.output + var injectedErr error + if test.injectError { + injectedErr = fmt.Errorf("mock error") + } + + // Setup expectation + controllerServer.EXPECT().ControllerGetCapabilities(gomock.Any(), in).Return(out, injectedErr).Times(1) + + ok, err := supportsControllerCreateSnapshot(context.Background(), csiConn) + if test.expectError && err == nil { + t.Errorf("test %q: Expected error, got none", test.name) + } + if !test.expectError && err != nil { + t.Errorf("test %q: got error: %v", test.name, err) + } + if err == nil && test.expectResult != ok { + t.Errorf("test fail expected result %t but got %t\n", test.expectResult, ok) + } + } +} + +func createMockServer(t *testing.T) (*gomock.Controller, *driver.MockCSIDriver, *driver.MockIdentityServer, *driver.MockControllerServer, *grpc.ClientConn, error) { + // Start the mock server + mockController := gomock.NewController(t) + identityServer := driver.NewMockIdentityServer(mockController) + controllerServer := driver.NewMockControllerServer(mockController) + drv := driver.NewMockCSIDriver(&driver.MockCSIDriverServers{ + Identity: identityServer, + Controller: controllerServer, + }) + drv.Start() + + // Create a client connection to it + addr := drv.Address() + csiConn, err := connection.Connect(addr) + if err != nil { + return nil, nil, nil, nil, nil, err + } + + return mockController, drv, identityServer, controllerServer, csiConn, nil +} diff --git a/pkg/connection/connection.go b/pkg/connection/connection.go deleted file mode 100644 index 760ae2cd..00000000 --- a/pkg/connection/connection.go +++ /dev/null @@ -1,251 +0,0 @@ -/* -Copyright 2018 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package connection - -import ( - "context" - "fmt" - - "github.com/container-storage-interface/spec/lib/go/csi" - "github.com/golang/protobuf/ptypes" - "github.com/golang/protobuf/ptypes/timestamp" - "github.com/kubernetes-csi/csi-lib-utils/connection" - "github.com/kubernetes-csi/csi-lib-utils/protosanitizer" - "google.golang.org/grpc" - "k8s.io/api/core/v1" - "k8s.io/klog" -) - -// CSIConnection is gRPC connection to a remote CSI driver and abstracts all -// CSI calls. -type CSIConnection interface { - // GetDriverName returns driver name as discovered by GetPluginInfo() - // gRPC call. - GetDriverName(ctx context.Context) (string, error) - - // SupportsControllerCreateSnapshot returns true if the CSI driver reports - // CREATE_DELETE_SNAPSHOT in ControllerGetCapabilities() gRPC call. - SupportsControllerCreateSnapshot(ctx context.Context) (bool, error) - - // SupportsControllerListSnapshots returns true if the CSI driver reports - // LIST_SNAPSHOTS in ControllerGetCapabilities() gRPC call. - SupportsControllerListSnapshots(ctx context.Context) (bool, error) - - // CreateSnapshot creates a snapshot for a volume - CreateSnapshot(ctx context.Context, snapshotName string, volume *v1.PersistentVolume, parameters map[string]string, snapshotterCredentials map[string]string) (driverName string, snapshotId string, timestamp int64, size int64, readyToUse bool, err error) - - // DeleteSnapshot deletes a snapshot from a volume - DeleteSnapshot(ctx context.Context, snapshotID string, snapshotterCredentials map[string]string) (err error) - - // GetSnapshotStatus returns if a snapshot is ready to use, creation time, and restore size. - GetSnapshotStatus(ctx context.Context, snapshotID string) (bool, int64, int64, error) - - // Probe checks that the CSI driver is ready to process requests - Probe(ctx context.Context) error - - // Close the connection - Close() error -} - -type csiConnection struct { - conn *grpc.ClientConn -} - -var ( - _ CSIConnection = &csiConnection{} -) - -// New returns a CSI connection object. -func New(address string) (CSIConnection, error) { - conn, err := connection.Connect(address) - if err != nil { - return nil, err - } - return &csiConnection{ - conn: conn, - }, nil -} - -func (c *csiConnection) GetDriverName(ctx context.Context) (string, error) { - client := csi.NewIdentityClient(c.conn) - - req := csi.GetPluginInfoRequest{} - - rsp, err := client.GetPluginInfo(ctx, &req) - if err != nil { - return "", err - } - name := rsp.GetName() - if name == "" { - return "", fmt.Errorf("name is empty") - } - return name, nil -} - -func (c *csiConnection) Probe(ctx context.Context) error { - client := csi.NewIdentityClient(c.conn) - - req := csi.ProbeRequest{} - - _, err := client.Probe(ctx, &req) - if err != nil { - return err - } - return nil -} - -func (c *csiConnection) SupportsControllerCreateSnapshot(ctx context.Context) (bool, error) { - client := csi.NewControllerClient(c.conn) - req := csi.ControllerGetCapabilitiesRequest{} - - rsp, err := client.ControllerGetCapabilities(ctx, &req) - if err != nil { - return false, err - } - caps := rsp.GetCapabilities() - for _, cap := range caps { - if cap == nil { - continue - } - rpc := cap.GetRpc() - if rpc == nil { - continue - } - if rpc.GetType() == csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT { - return true, nil - } - } - return false, nil -} - -func (c *csiConnection) SupportsControllerListSnapshots(ctx context.Context) (bool, error) { - client := csi.NewControllerClient(c.conn) - req := csi.ControllerGetCapabilitiesRequest{} - - rsp, err := client.ControllerGetCapabilities(ctx, &req) - if err != nil { - return false, err - } - caps := rsp.GetCapabilities() - for _, cap := range caps { - if cap == nil { - continue - } - rpc := cap.GetRpc() - if rpc == nil { - continue - } - if rpc.GetType() == csi.ControllerServiceCapability_RPC_LIST_SNAPSHOTS { - return true, nil - } - } - return false, nil -} - -func (c *csiConnection) CreateSnapshot(ctx context.Context, snapshotName string, volume *v1.PersistentVolume, parameters map[string]string, snapshotterCredentials map[string]string) (string, string, int64, int64, bool, error) { - klog.V(5).Infof("CSI CreateSnapshot: %s", snapshotName) - if volume.Spec.CSI == nil { - return "", "", 0, 0, false, fmt.Errorf("CSIPersistentVolumeSource not defined in spec") - } - - client := csi.NewControllerClient(c.conn) - - driverName, err := c.GetDriverName(ctx) - if err != nil { - return "", "", 0, 0, false, err - } - - req := csi.CreateSnapshotRequest{ - SourceVolumeId: volume.Spec.CSI.VolumeHandle, - Name: snapshotName, - Parameters: parameters, - Secrets: snapshotterCredentials, - } - - rsp, err := client.CreateSnapshot(ctx, &req) - if err != nil { - return "", "", 0, 0, false, err - } - - klog.V(5).Infof("CSI CreateSnapshot: %s driver name [%s] snapshot ID [%s] time stamp [%d] size [%d] readyToUse [%v]", snapshotName, driverName, rsp.Snapshot.SnapshotId, rsp.Snapshot.CreationTime, rsp.Snapshot.SizeBytes, rsp.Snapshot.ReadyToUse) - creationTime, err := timestampToUnixTime(rsp.Snapshot.CreationTime) - if err != nil { - return "", "", 0, 0, false, err - } - return driverName, rsp.Snapshot.SnapshotId, creationTime, rsp.Snapshot.SizeBytes, rsp.Snapshot.ReadyToUse, nil -} - -func (c *csiConnection) DeleteSnapshot(ctx context.Context, snapshotID string, snapshotterCredentials map[string]string) (err error) { - client := csi.NewControllerClient(c.conn) - - req := csi.DeleteSnapshotRequest{ - SnapshotId: snapshotID, - Secrets: snapshotterCredentials, - } - - if _, err := client.DeleteSnapshot(ctx, &req); err != nil { - return err - } - - return nil -} - -func (c *csiConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (bool, int64, int64, error) { - client := csi.NewControllerClient(c.conn) - - req := csi.ListSnapshotsRequest{ - SnapshotId: snapshotID, - } - - rsp, err := client.ListSnapshots(ctx, &req) - if err != nil { - return false, 0, 0, err - } - - if rsp.Entries == nil || len(rsp.Entries) == 0 { - return false, 0, 0, fmt.Errorf("can not find snapshot for snapshotID %s", snapshotID) - } - - creationTime, err := timestampToUnixTime(rsp.Entries[0].Snapshot.CreationTime) - if err != nil { - return false, 0, 0, err - } - return rsp.Entries[0].Snapshot.ReadyToUse, creationTime, rsp.Entries[0].Snapshot.SizeBytes, nil -} - -func (c *csiConnection) Close() error { - return c.conn.Close() -} - -func logGRPC(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - klog.V(5).Infof("GRPC call: %s", method) - klog.V(5).Infof("GRPC request: %s", protosanitizer.StripSecrets(req)) - err := invoker(ctx, method, req, reply, cc, opts...) - klog.V(5).Infof("GRPC response: %s", protosanitizer.StripSecrets(reply)) - klog.V(5).Infof("GRPC error: %v", err) - return err -} - -func timestampToUnixTime(t *timestamp.Timestamp) (int64, error) { - time, err := ptypes.Timestamp(t) - if err != nil { - return -1, err - } - // TODO: clean this up, we probably don't need this translation layer - // and can just use time.Time - return time.UnixNano(), nil -} diff --git a/pkg/controller/csi_handler.go b/pkg/controller/csi_handler.go index e0c7e7e3..2abee097 100644 --- a/pkg/controller/csi_handler.go +++ b/pkg/controller/csi_handler.go @@ -23,7 +23,8 @@ import ( "time" crdv1 "github.com/kubernetes-csi/external-snapshotter/pkg/apis/volumesnapshot/v1alpha1" - "github.com/kubernetes-csi/external-snapshotter/pkg/connection" + "github.com/kubernetes-csi/external-snapshotter/pkg/snapshotter" + "k8s.io/api/core/v1" ) @@ -36,7 +37,7 @@ type Handler interface { // csiHandler is a handler that calls CSI to create/delete volume snapshot. type csiHandler struct { - csiConnection connection.CSIConnection + snapshotter snapshotter.Snapshotter timeout time.Duration snapshotNamePrefix string snapshotNameUUIDLength int @@ -44,13 +45,13 @@ type csiHandler struct { // NewCSIHandler returns a handler which includes the csi connection and Snapshot name details func NewCSIHandler( - csiConnection connection.CSIConnection, + snapshotter snapshotter.Snapshotter, timeout time.Duration, snapshotNamePrefix string, snapshotNameUUIDLength int, ) Handler { return &csiHandler{ - csiConnection: csiConnection, + snapshotter: snapshotter, timeout: timeout, snapshotNamePrefix: snapshotNamePrefix, snapshotNameUUIDLength: snapshotNameUUIDLength, @@ -70,7 +71,7 @@ func (handler *csiHandler) CreateSnapshot(snapshot *crdv1.VolumeSnapshot, volume if err != nil { return "", "", 0, 0, false, fmt.Errorf("failed to remove CSI Parameters of prefixed keys: %v", err) } - return handler.csiConnection.CreateSnapshot(ctx, snapshotName, volume, newParameters, snapshotterCredentials) + return handler.snapshotter.CreateSnapshot(ctx, snapshotName, volume, newParameters, snapshotterCredentials) } func (handler *csiHandler) DeleteSnapshot(content *crdv1.VolumeSnapshotContent, snapshotterCredentials map[string]string) error { @@ -80,7 +81,7 @@ func (handler *csiHandler) DeleteSnapshot(content *crdv1.VolumeSnapshotContent, ctx, cancel := context.WithTimeout(context.Background(), handler.timeout) defer cancel() - err := handler.csiConnection.DeleteSnapshot(ctx, content.Spec.CSI.SnapshotHandle, snapshotterCredentials) + err := handler.snapshotter.DeleteSnapshot(ctx, content.Spec.CSI.SnapshotHandle, snapshotterCredentials) if err != nil { return fmt.Errorf("failed to delete snapshot content %s: %q", content.Name, err) } @@ -95,12 +96,12 @@ func (handler *csiHandler) GetSnapshotStatus(content *crdv1.VolumeSnapshotConten ctx, cancel := context.WithTimeout(context.Background(), handler.timeout) defer cancel() - csiSnapshotStatus, timestamp, size, err := handler.csiConnection.GetSnapshotStatus(ctx, content.Spec.CSI.SnapshotHandle) + csiSnapshotStatus, timestamp, size, err := handler.snapshotter.GetSnapshotStatus(ctx, content.Spec.CSI.SnapshotHandle) if err != nil { return false, 0, 0, fmt.Errorf("failed to list snapshot content %s: %q", content.Name, err) } - return csiSnapshotStatus, timestamp, size, nil + return csiSnapshotStatus, timestamp, size, nil } func makeSnapshotName(prefix, snapshotUID string, snapshotNameUUIDLength int) (string, error) { diff --git a/pkg/controller/framework_test.go b/pkg/controller/framework_test.go index 751f530a..a9ad4832 100644 --- a/pkg/controller/framework_test.go +++ b/pkg/controller/framework_test.go @@ -717,7 +717,7 @@ func newTestController(kubeClient kubernetes.Interface, clientset clientset.Inte coreFactory := coreinformers.NewSharedInformerFactory(kubeClient, NoResyncPeriodFunc()) // Construct controller - csiConnection := &fakeCSIConnection{ + fakeSnapshot := &fakeSnapshotter{ t: t, listCalls: test.expectedListCalls, createCalls: test.expectedCreateCalls, @@ -734,7 +734,7 @@ func newTestController(kubeClient kubernetes.Interface, clientset clientset.Inte coreFactory.Core().V1().PersistentVolumeClaims(), 3, 5*time.Millisecond, - csiConnection, + fakeSnapshot, 5*time.Millisecond, 60*time.Second, "snapshot", @@ -1152,9 +1152,9 @@ type createCall struct { err error } -// Fake CSIConnection implementation that check that Attach/Detach is called +// Fake SnapShotter implementation that check that Attach/Detach is called // with the right parameters and it returns proper error code and metadata. -type fakeCSIConnection struct { +type fakeSnapshotter struct { createCalls []createCall createCallCounter int deleteCalls []deleteCall @@ -1164,19 +1164,7 @@ type fakeCSIConnection struct { t *testing.T } -func (f *fakeCSIConnection) GetDriverName(ctx context.Context) (string, error) { - return mockDriverName, nil -} - -func (f *fakeCSIConnection) SupportsControllerCreateSnapshot(ctx context.Context) (bool, error) { - return false, fmt.Errorf("Not implemented") -} - -func (f *fakeCSIConnection) SupportsControllerListSnapshots(ctx context.Context) (bool, error) { - return false, fmt.Errorf("Not implemented") -} - -func (f *fakeCSIConnection) CreateSnapshot(ctx context.Context, snapshotName string, volume *v1.PersistentVolume, parameters map[string]string, snapshotterCredentials map[string]string) (string, string, int64, int64, bool, error) { +func (f *fakeSnapshotter) CreateSnapshot(ctx context.Context, snapshotName string, volume *v1.PersistentVolume, parameters map[string]string, snapshotterCredentials map[string]string) (string, string, int64, int64, bool, error) { if f.createCallCounter >= len(f.createCalls) { f.t.Errorf("Unexpected CSI Create Snapshot call: snapshotName=%s, volume=%v, index: %d, calls: %+v", snapshotName, volume.Name, f.createCallCounter, f.createCalls) return "", "", 0, 0, false, fmt.Errorf("unexpected call") @@ -1212,7 +1200,7 @@ func (f *fakeCSIConnection) CreateSnapshot(ctx context.Context, snapshotName str return call.driverName, call.snapshotId, call.timestamp, call.size, call.readyToUse, call.err } -func (f *fakeCSIConnection) DeleteSnapshot(ctx context.Context, snapshotID string, snapshotterCredentials map[string]string) error { +func (f *fakeSnapshotter) DeleteSnapshot(ctx context.Context, snapshotID string, snapshotterCredentials map[string]string) error { if f.deleteCallCounter >= len(f.deleteCalls) { f.t.Errorf("Unexpected CSI Delete Snapshot call: snapshotID=%s, index: %d, calls: %+v", snapshotID, f.createCallCounter, f.createCalls) return fmt.Errorf("unexpected call") @@ -1238,7 +1226,7 @@ func (f *fakeCSIConnection) DeleteSnapshot(ctx context.Context, snapshotID strin return call.err } -func (f *fakeCSIConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (bool, int64, int64, error) { +func (f *fakeSnapshotter) GetSnapshotStatus(ctx context.Context, snapshotID string) (bool, int64, int64, error) { if f.listCallCounter >= len(f.listCalls) { f.t.Errorf("Unexpected CSI list Snapshot call: snapshotID=%s, index: %d, calls: %+v", snapshotID, f.createCallCounter, f.createCalls) return false, 0, 0, fmt.Errorf("unexpected call") @@ -1258,11 +1246,3 @@ func (f *fakeCSIConnection) GetSnapshotStatus(ctx context.Context, snapshotID st return call.readyToUse, call.createTime, call.size, call.err } - -func (f *fakeCSIConnection) Close() error { - return fmt.Errorf("Not implemented") -} - -func (f *fakeCSIConnection) Probe(ctx context.Context) error { - return nil -} diff --git a/pkg/controller/snapshot_controller_base.go b/pkg/controller/snapshot_controller_base.go index 8c3f9a4e..5458b1bc 100644 --- a/pkg/controller/snapshot_controller_base.go +++ b/pkg/controller/snapshot_controller_base.go @@ -24,8 +24,9 @@ import ( clientset "github.com/kubernetes-csi/external-snapshotter/pkg/client/clientset/versioned" storageinformers "github.com/kubernetes-csi/external-snapshotter/pkg/client/informers/externalversions/volumesnapshot/v1alpha1" storagelisters "github.com/kubernetes-csi/external-snapshotter/pkg/client/listers/volumesnapshot/v1alpha1" - "github.com/kubernetes-csi/external-snapshotter/pkg/connection" - v1 "k8s.io/api/core/v1" + "github.com/kubernetes-csi/external-snapshotter/pkg/snapshotter" + + "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/util/wait" @@ -81,7 +82,7 @@ func NewCSISnapshotController( pvcInformer coreinformers.PersistentVolumeClaimInformer, createSnapshotContentRetryCount int, createSnapshotContentInterval time.Duration, - conn connection.CSIConnection, + snapshotter snapshotter.Snapshotter, timeout time.Duration, resyncPeriod time.Duration, snapshotNamePrefix string, @@ -98,7 +99,7 @@ func NewCSISnapshotController( client: client, snapshotterName: snapshotterName, eventRecorder: eventRecorder, - handler: NewCSIHandler(conn, timeout, snapshotNamePrefix, snapshotNameUUIDLength), + handler: NewCSIHandler(snapshotter, timeout, snapshotNamePrefix, snapshotNameUUIDLength), runningOperations: goroutinemap.NewGoRoutineMap(true), createSnapshotContentRetryCount: createSnapshotContentRetryCount, createSnapshotContentInterval: createSnapshotContentInterval, diff --git a/pkg/snapshotter/snapshotter.go b/pkg/snapshotter/snapshotter.go new file mode 100644 index 00000000..ee5031df --- /dev/null +++ b/pkg/snapshotter/snapshotter.go @@ -0,0 +1,135 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package snapshotter + +import ( + "context" + "fmt" + + "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/golang/protobuf/ptypes" + "github.com/golang/protobuf/ptypes/timestamp" + csirpc "github.com/kubernetes-csi/csi-lib-utils/rpc" + + "google.golang.org/grpc" + + "k8s.io/api/core/v1" + "k8s.io/klog" +) + +// Snapshotter implements CreateSnapshot/DeleteSnapshot operations against a remote CSI driver. +type Snapshotter interface { + // CreateSnapshot creates a snapshot for a volume + CreateSnapshot(ctx context.Context, snapshotName string, volume *v1.PersistentVolume, parameters map[string]string, snapshotterCredentials map[string]string) (driverName string, snapshotId string, timestamp int64, size int64, readyToUse bool, err error) + + // DeleteSnapshot deletes a snapshot from a volume + DeleteSnapshot(ctx context.Context, snapshotID string, snapshotterCredentials map[string]string) (err error) + + // GetSnapshotStatus returns if a snapshot is ready to use, creation time, and restore size. + GetSnapshotStatus(ctx context.Context, snapshotID string) (bool, int64, int64, error) +} + +type snapshot struct { + conn *grpc.ClientConn +} + +func NewSnapshotter(conn *grpc.ClientConn) Snapshotter { + return &snapshot{ + conn: conn, + } +} + +func (s *snapshot) CreateSnapshot(ctx context.Context, snapshotName string, volume *v1.PersistentVolume, parameters map[string]string, snapshotterCredentials map[string]string) (string, string, int64, int64, bool, error) { + klog.V(5).Infof("CSI CreateSnapshot: %s", snapshotName) + if volume.Spec.CSI == nil { + return "", "", 0, 0, false, fmt.Errorf("CSIPersistentVolumeSource not defined in spec") + } + + client := csi.NewControllerClient(s.conn) + + driverName, err := csirpc.GetDriverName(ctx, s.conn) + if err != nil { + return "", "", 0, 0, false, err + } + + req := csi.CreateSnapshotRequest{ + SourceVolumeId: volume.Spec.CSI.VolumeHandle, + Name: snapshotName, + Parameters: parameters, + Secrets: snapshotterCredentials, + } + + rsp, err := client.CreateSnapshot(ctx, &req) + if err != nil { + return "", "", 0, 0, false, err + } + + klog.V(5).Infof("CSI CreateSnapshot: %s driver name [%s] snapshot ID [%s] time stamp [%d] size [%d] readyToUse [%v]", snapshotName, driverName, rsp.Snapshot.SnapshotId, rsp.Snapshot.CreationTime, rsp.Snapshot.SizeBytes, rsp.Snapshot.ReadyToUse) + creationTime, err := timestampToUnixTime(rsp.Snapshot.CreationTime) + if err != nil { + return "", "", 0, 0, false, err + } + return driverName, rsp.Snapshot.SnapshotId, creationTime, rsp.Snapshot.SizeBytes, rsp.Snapshot.ReadyToUse, nil +} + +func (s *snapshot) DeleteSnapshot(ctx context.Context, snapshotID string, snapshotterCredentials map[string]string) (err error) { + client := csi.NewControllerClient(s.conn) + + req := csi.DeleteSnapshotRequest{ + SnapshotId: snapshotID, + Secrets: snapshotterCredentials, + } + + if _, err := client.DeleteSnapshot(ctx, &req); err != nil { + return err + } + + return nil +} + +func (s *snapshot) GetSnapshotStatus(ctx context.Context, snapshotID string) (bool, int64, int64, error) { + client := csi.NewControllerClient(s.conn) + + req := csi.ListSnapshotsRequest{ + SnapshotId: snapshotID, + } + + rsp, err := client.ListSnapshots(ctx, &req) + if err != nil { + return false, 0, 0, err + } + + if rsp.Entries == nil || len(rsp.Entries) == 0 { + return false, 0, 0, fmt.Errorf("can not find snapshot for snapshotID %s", snapshotID) + } + + creationTime, err := timestampToUnixTime(rsp.Entries[0].Snapshot.CreationTime) + if err != nil { + return false, 0, 0, err + } + return rsp.Entries[0].Snapshot.ReadyToUse, creationTime, rsp.Entries[0].Snapshot.SizeBytes, nil +} + +func timestampToUnixTime(t *timestamp.Timestamp) (int64, error) { + time, err := ptypes.Timestamp(t) + if err != nil { + return -1, err + } + // TODO: clean this up, we probably don't need this translation layer + // and can just use time.Time + return time.UnixNano(), nil +} diff --git a/pkg/connection/connection_test.go b/pkg/snapshotter/snapshotter_test.go similarity index 60% rename from pkg/connection/connection_test.go rename to pkg/snapshotter/snapshotter_test.go index ed70f211..dbe973d3 100644 --- a/pkg/connection/connection_test.go +++ b/pkg/snapshotter/snapshotter_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package connection +package snapshotter import ( "context" @@ -25,9 +25,13 @@ import ( "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/mock/gomock" "github.com/golang/protobuf/ptypes" + "github.com/kubernetes-csi/csi-lib-utils/connection" "github.com/kubernetes-csi/csi-test/driver" + + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" @@ -37,7 +41,7 @@ const ( driverName = "foo/bar" ) -func createMockServer(t *testing.T) (*gomock.Controller, *driver.MockCSIDriver, *driver.MockIdentityServer, *driver.MockControllerServer, CSIConnection, error) { +func createMockServer(t *testing.T) (*gomock.Controller, *driver.MockCSIDriver, *driver.MockIdentityServer, *driver.MockControllerServer, *grpc.ClientConn, error) { // Start the mock server mockController := gomock.NewController(t) identityServer := driver.NewMockIdentityServer(mockController) @@ -50,7 +54,7 @@ func createMockServer(t *testing.T) (*gomock.Controller, *driver.MockCSIDriver, // Create a client connection to it addr := drv.Address() - csiConn, err := New(addr) + csiConn, err := connection.Connect(addr) if err != nil { return nil, nil, nil, nil, nil, err } @@ -58,323 +62,6 @@ func createMockServer(t *testing.T) (*gomock.Controller, *driver.MockCSIDriver, return mockController, drv, identityServer, controllerServer, csiConn, nil } -func TestGetPluginInfo(t *testing.T) { - tests := []struct { - name string - output *csi.GetPluginInfoResponse - injectError bool - expectError bool - }{ - { - name: "success", - output: &csi.GetPluginInfoResponse{ - Name: "csi/example", - VendorVersion: "0.3.0", - Manifest: map[string]string{ - "hello": "world", - }, - }, - expectError: false, - }, - { - name: "gRPC error", - output: nil, - injectError: true, - expectError: true, - }, - { - name: "empty name", - output: &csi.GetPluginInfoResponse{ - Name: "", - }, - expectError: true, - }, - } - - mockController, driver, identityServer, _, csiConn, err := createMockServer(t) - if err != nil { - t.Fatal(err) - } - defer mockController.Finish() - defer driver.Stop() - defer csiConn.Close() - - for _, test := range tests { - - in := &csi.GetPluginInfoRequest{} - - out := test.output - var injectedErr error - if test.injectError { - injectedErr = fmt.Errorf("mock error") - } - - // Setup expectation - identityServer.EXPECT().GetPluginInfo(gomock.Any(), in).Return(out, injectedErr).Times(1) - - name, err := csiConn.GetDriverName(context.Background()) - if test.expectError && err == nil { - t.Errorf("test %q: Expected error, got none", test.name) - } - if !test.expectError && err != nil { - t.Errorf("test %q: got error: %v", test.name, err) - } - if err == nil && name != "csi/example" { - t.Errorf("got unexpected name: %q", name) - } - } -} - -func TestSupportsControllerCreateSnapshot(t *testing.T) { - tests := []struct { - name string - output *csi.ControllerGetCapabilitiesResponse - injectError bool - expectError bool - expectResult bool - }{ - { - name: "success", - output: &csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{ - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, - }, - }, - }, - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT, - }, - }, - }, - }, - }, - expectError: false, - expectResult: true, - }, - { - name: "gRPC error", - output: nil, - injectError: true, - expectError: true, - expectResult: false, - }, - { - name: "no create snapshot", - output: &csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{ - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, - }, - }, - }, - }, - }, - expectError: false, - expectResult: false, - }, - { - name: "empty capability", - output: &csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{ - { - Type: nil, - }, - }, - }, - expectError: false, - expectResult: false, - }, - { - name: "no capabilities", - output: &csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{}, - }, - expectError: false, - expectResult: false, - }, - } - - mockController, driver, _, controllerServer, csiConn, err := createMockServer(t) - if err != nil { - t.Fatal(err) - } - defer mockController.Finish() - defer driver.Stop() - defer csiConn.Close() - - for _, test := range tests { - - in := &csi.ControllerGetCapabilitiesRequest{} - - out := test.output - var injectedErr error - if test.injectError { - injectedErr = fmt.Errorf("mock error") - } - - // Setup expectation - controllerServer.EXPECT().ControllerGetCapabilities(gomock.Any(), in).Return(out, injectedErr).Times(1) - - ok, err := csiConn.SupportsControllerCreateSnapshot(context.Background()) - if test.expectError && err == nil { - t.Errorf("test %q: Expected error, got none", test.name) - } - if !test.expectError && err != nil { - t.Errorf("test %q: got error: %v", test.name, err) - } - if err == nil && test.expectResult != ok { - t.Errorf("test fail expected result %t but got %t\n", test.expectResult, ok) - } - } -} - -func TestSupportsControllerListSnapshots(t *testing.T) { - tests := []struct { - name string - output *csi.ControllerGetCapabilitiesResponse - injectError bool - expectError bool - expectResult bool - }{ - { - name: "success", - output: &csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{ - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, - }, - }, - }, - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT, - }, - }, - }, - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_LIST_SNAPSHOTS, - }, - }, - }, - }, - }, - expectError: false, - expectResult: true, - }, - { - name: "have create_delete_snapshot but no list snapshot ", - output: &csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{ - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, - }, - }, - }, - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT, - }, - }, - }, - }, - }, - expectError: false, - expectResult: false, - }, - { - name: "gRPC error", - output: nil, - injectError: true, - expectError: true, - expectResult: false, - }, - { - name: "no list snapshot", - output: &csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{ - { - Type: &csi.ControllerServiceCapability_Rpc{ - Rpc: &csi.ControllerServiceCapability_RPC{ - Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, - }, - }, - }, - }, - }, - expectError: false, - expectResult: false, - }, - { - name: "empty capability", - output: &csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{ - { - Type: nil, - }, - }, - }, - expectError: false, - expectResult: false, - }, - { - name: "no capabilities", - output: &csi.ControllerGetCapabilitiesResponse{ - Capabilities: []*csi.ControllerServiceCapability{}, - }, - expectError: false, - expectResult: false, - }, - } - - mockController, driver, _, controllerServer, csiConn, err := createMockServer(t) - if err != nil { - t.Fatal(err) - } - defer mockController.Finish() - defer driver.Stop() - defer csiConn.Close() - - for _, test := range tests { - - in := &csi.ControllerGetCapabilitiesRequest{} - - out := test.output - var injectedErr error - if test.injectError { - injectedErr = fmt.Errorf("mock error") - } - - // Setup expectation - controllerServer.EXPECT().ControllerGetCapabilities(gomock.Any(), in).Return(out, injectedErr).Times(1) - - ok, err := csiConn.SupportsControllerListSnapshots(context.Background()) - if test.expectError && err == nil { - t.Errorf("test %q: Expected error, got none", test.name) - } - if !test.expectError && err != nil { - t.Errorf("test %q: got error: %v", test.name, err) - } - if err == nil && test.expectResult != ok { - t.Errorf("test fail expected result %t but got %t\n", test.expectResult, ok) - } - } -} - func TestCreateSnapshot(t *testing.T) { defaultName := "snapshot-test" defaultID := "testid" @@ -535,7 +222,8 @@ func TestCreateSnapshot(t *testing.T) { controllerServer.EXPECT().CreateSnapshot(gomock.Any(), in).Return(out, injectedErr).Times(1) } - driverName, snapshotId, timestamp, size, readyToUse, err := csiConn.CreateSnapshot(context.Background(), test.snapshotName, test.volume, test.parameters, test.secrets) + s := NewSnapshotter(csiConn) + driverName, snapshotId, timestamp, size, readyToUse, err := s.CreateSnapshot(context.Background(), test.snapshotName, test.volume, test.parameters, test.secrets) if test.expectError && err == nil { t.Errorf("test %q: Expected error, got none", test.name) } @@ -642,7 +330,8 @@ func TestDeleteSnapshot(t *testing.T) { controllerServer.EXPECT().DeleteSnapshot(gomock.Any(), in).Return(out, injectedErr).Times(1) } - err := csiConn.DeleteSnapshot(context.Background(), test.snapshotID, test.secrets) + s := NewSnapshotter(csiConn) + err := s.DeleteSnapshot(context.Background(), test.snapshotID, test.secrets) if test.expectError && err == nil { t.Errorf("test %q: Expected error, got none", test.name) } @@ -650,7 +339,6 @@ func TestDeleteSnapshot(t *testing.T) { t.Errorf("test %q: got error: %v", test.name, err) } } - } func TestGetSnapshotStatus(t *testing.T) { @@ -740,7 +428,8 @@ func TestGetSnapshotStatus(t *testing.T) { controllerServer.EXPECT().ListSnapshots(gomock.Any(), in).Return(out, injectedErr).Times(1) } - ready, createTime, size, err := csiConn.GetSnapshotStatus(context.Background(), test.snapshotID) + s := NewSnapshotter(csiConn) + ready, createTime, size, err := s.GetSnapshotStatus(context.Background(), test.snapshotID) if test.expectError && err == nil { t.Errorf("test %q: Expected error, got none", test.name) }