From 0fa91a7df24dd282ddd23b6670fb1faa259591a2 Mon Sep 17 00:00:00 2001 From: wackxu Date: Mon, 27 Aug 2018 20:43:34 +0800 Subject: [PATCH] add UT for connection.go --- pkg/connection/connection_test.go | 771 ++++++++++++++++++++++++++++++ 1 file changed, 771 insertions(+) diff --git a/pkg/connection/connection_test.go b/pkg/connection/connection_test.go index c2b03924..a6aafd6c 100644 --- a/pkg/connection/connection_test.go +++ b/pkg/connection/connection_test.go @@ -1 +1,772 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package connection + +import ( + "context" + "fmt" + "reflect" + "testing" + "time" + + "github.com/container-storage-interface/spec/lib/go/csi/v0" + "github.com/golang/mock/gomock" + "github.com/kubernetes-csi/csi-test/driver" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) + +const ( + driverName = "foo/bar" +) + +func createMockServer(t *testing.T) (*gomock.Controller, *driver.MockCSIDriver, *driver.MockIdentityServer, *driver.MockControllerServer, CSIConnection, error) { + // Start the mock server + mockController := gomock.NewController(t) + identityServer := driver.NewMockIdentityServer(mockController) + controllerServer := driver.NewMockControllerServer(mockController) + drv := driver.NewMockCSIDriver(&driver.MockCSIDriverServers{ + Identity: identityServer, + Controller: controllerServer, + }) + drv.Start() + + // Create a client connection to it + addr := drv.Address() + csiConn, err := New(addr, 10) + if err != nil { + return nil, nil, nil, nil, nil, err + } + + return mockController, drv, identityServer, controllerServer, csiConn, nil +} + +func TestGetPluginInfo(t *testing.T) { + tests := []struct { + name string + output *csi.GetPluginInfoResponse + injectError bool + expectError bool + }{ + { + name: "success", + output: &csi.GetPluginInfoResponse{ + Name: "csi/example", + VendorVersion: "0.3.0", + Manifest: map[string]string{ + "hello": "world", + }, + }, + expectError: false, + }, + { + name: "gRPC error", + output: nil, + injectError: true, + expectError: true, + }, + { + name: "empty name", + output: &csi.GetPluginInfoResponse{ + Name: "", + }, + expectError: true, + }, + } + + mockController, driver, identityServer, _, csiConn, err := createMockServer(t) + if err != nil { + t.Fatal(err) + } + defer mockController.Finish() + defer driver.Stop() + defer csiConn.Close() + + for _, test := range tests { + + in := &csi.GetPluginInfoRequest{} + + out := test.output + var injectedErr error = nil + if test.injectError { + injectedErr = fmt.Errorf("mock error") + } + + // Setup expectation + identityServer.EXPECT().GetPluginInfo(gomock.Any(), in).Return(out, injectedErr).Times(1) + + name, err := csiConn.GetDriverName(context.Background()) + if test.expectError && err == nil { + t.Errorf("test %q: Expected error, got none", test.name) + } + if !test.expectError && err != nil { + t.Errorf("test %q: got error: %v", test.name, err) + } + if err == nil && name != "csi/example" { + t.Errorf("got unexpected name: %q", name) + } + } +} + +func TestSupportsControllerCreateSnapshot(t *testing.T) { + tests := []struct { + name string + output *csi.ControllerGetCapabilitiesResponse + injectError bool + expectError bool + }{ + { + name: "success", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT, + }, + }, + }, + }, + }, + expectError: false, + }, + { + name: "gRPC error", + output: nil, + injectError: true, + expectError: true, + }, + { + name: "no create snapshot", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, + }, + }, + }, + }, + }, + expectError: false, + }, + { + name: "empty capability", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: nil, + }, + }, + }, + expectError: false, + }, + { + name: "no capabilities", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{}, + }, + expectError: false, + }, + } + + mockController, driver, _, controllerServer, csiConn, err := createMockServer(t) + if err != nil { + t.Fatal(err) + } + defer mockController.Finish() + defer driver.Stop() + defer csiConn.Close() + + for _, test := range tests { + + in := &csi.ControllerGetCapabilitiesRequest{} + + out := test.output + var injectedErr error = nil + if test.injectError { + injectedErr = fmt.Errorf("mock error") + } + + // Setup expectation + controllerServer.EXPECT().ControllerGetCapabilities(gomock.Any(), in).Return(out, injectedErr).Times(1) + + _, err = csiConn.SupportsControllerCreateSnapshot(context.Background()) + if test.expectError && err == nil { + t.Errorf("test %q: Expected error, got none", test.name) + } + if !test.expectError && err != nil { + t.Errorf("test %q: got error: %v", test.name, err) + } + } +} + +func TestSupportsControllerListSnapshots(t *testing.T) { + tests := []struct { + name string + output *csi.ControllerGetCapabilitiesResponse + injectError bool + expectError bool + }{ + { + name: "success", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_LIST_SNAPSHOTS, + }, + }, + }, + }, + }, + expectError: false, + }, + { + name: "gRPC error", + output: nil, + injectError: true, + expectError: true, + }, + { + name: "no list snapshot", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, + }, + }, + }, + }, + }, + expectError: false, + }, + { + name: "empty capability", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{ + { + Type: nil, + }, + }, + }, + expectError: false, + }, + { + name: "no capabilities", + output: &csi.ControllerGetCapabilitiesResponse{ + Capabilities: []*csi.ControllerServiceCapability{}, + }, + expectError: false, + }, + } + + mockController, driver, _, controllerServer, csiConn, err := createMockServer(t) + if err != nil { + t.Fatal(err) + } + defer mockController.Finish() + defer driver.Stop() + defer csiConn.Close() + + for _, test := range tests { + + in := &csi.ControllerGetCapabilitiesRequest{} + + out := test.output + var injectedErr error = nil + if test.injectError { + injectedErr = fmt.Errorf("mock error") + } + + // Setup expectation + controllerServer.EXPECT().ControllerGetCapabilities(gomock.Any(), in).Return(out, injectedErr).Times(1) + + _, err = csiConn.SupportsControllerListSnapshots(context.Background()) + if test.expectError && err == nil { + t.Errorf("test %q: Expected error, got none", test.name) + } + if !test.expectError && err != nil { + t.Errorf("test %q: got error: %v", test.name, err) + } + } +} + +func TestCreateSnapshot(t *testing.T) { + defaultName := "snapshot-test" + defaultID := "testid" + createTime := time.Now().UnixNano() + + createSecrets := map[string]string{"foo": "bar"} + defaultParameter := map[string]string{ + "param1": "value1", + "param2": "value2", + } + + csiVolume := FakeCSIVolume() + volumeWithoutCSI := FakeVolume() + + defaultRequest := &csi.CreateSnapshotRequest{ + Name: defaultName, + SourceVolumeId: csiVolume.Spec.CSI.VolumeHandle, + } + + attributesRequest := &csi.CreateSnapshotRequest{ + Name: defaultName, + Parameters: defaultParameter, + SourceVolumeId: csiVolume.Spec.CSI.VolumeHandle, + } + + secretsRequest := &csi.CreateSnapshotRequest{ + Name: defaultName, + SourceVolumeId: csiVolume.Spec.CSI.VolumeHandle, + CreateSnapshotSecrets: createSecrets, + } + + defaultResponse := &csi.CreateSnapshotResponse{ + Snapshot: &csi.Snapshot{ + Id: defaultID, + SizeBytes: 1000, + SourceVolumeId: csiVolume.Spec.CSI.VolumeHandle, + CreatedAt: createTime, + Status: &csi.SnapshotStatus{ + Type: csi.SnapshotStatus_READY, + Details: "success", + }, + }, + } + + pluginInfoOutput := &csi.GetPluginInfoResponse{ + Name: driverName, + VendorVersion: "0.3.0", + Manifest: map[string]string{ + "hello": "world", + }, + } + + type snapshotResult struct { + driverName string + snapshotId string + timestamp int64 + size int64 + status *csi.SnapshotStatus + } + + result := &snapshotResult{ + size: 1000, + driverName: driverName, + snapshotId: defaultID, + timestamp: createTime, + status: &csi.SnapshotStatus{ + Type: csi.SnapshotStatus_READY, + Details: "success", + }, + } + + tests := []struct { + name string + snapshotName string + volume *v1.PersistentVolume + parameters map[string]string + secrets map[string]string + input *csi.CreateSnapshotRequest + output *csi.CreateSnapshotResponse + injectError codes.Code + expectError bool + expectResult *snapshotResult + }{ + { + name: "success", + snapshotName: defaultName, + volume: csiVolume, + input: defaultRequest, + output: defaultResponse, + expectError: false, + expectResult: result, + }, + { + name: "attributes", + snapshotName: defaultName, + volume: csiVolume, + parameters: defaultParameter, + input: attributesRequest, + output: defaultResponse, + expectError: false, + expectResult: result, + }, + { + name: "secrets", + snapshotName: defaultName, + volume: csiVolume, + secrets: createSecrets, + input: secretsRequest, + output: defaultResponse, + expectError: false, + expectResult: result, + }, + { + name: "fail for volume without csi source", + snapshotName: defaultName, + volume: volumeWithoutCSI, + input: nil, + output: nil, + expectError: true, + }, + { + name: "gRPC transient error", + snapshotName: defaultName, + volume: csiVolume, + input: defaultRequest, + output: nil, + injectError: codes.DeadlineExceeded, + expectError: true, + }, + { + name: "gRPC final error", + snapshotName: defaultName, + volume: csiVolume, + input: defaultRequest, + output: nil, + injectError: codes.NotFound, + expectError: true, + }, + } + + mockController, driver, identityServer, controllerServer, csiConn, err := createMockServer(t) + if err != nil { + t.Fatal(err) + } + defer mockController.Finish() + defer driver.Stop() + defer csiConn.Close() + + for _, test := range tests { + in := test.input + out := test.output + var injectedErr error = nil + if test.injectError != codes.OK { + injectedErr = status.Error(test.injectError, fmt.Sprintf("Injecting error %d", test.injectError)) + } + + // Setup expectation + if in != nil { + identityServer.EXPECT().GetPluginInfo(gomock.Any(), gomock.Any()).Return(pluginInfoOutput, nil).Times(1) + controllerServer.EXPECT().CreateSnapshot(gomock.Any(), in).Return(out, injectedErr).Times(1) + } + + driverName, snapshotId, timestamp, size, status, err := csiConn.CreateSnapshot(context.Background(), test.snapshotName, test.volume, test.parameters, test.secrets) + if test.expectError && err == nil { + t.Errorf("test %q: Expected error, got none", test.name) + } + if !test.expectError && err != nil { + t.Errorf("test %q: got error: %v", test.name, err) + } + if test.expectResult != nil { + if driverName != test.expectResult.driverName { + t.Errorf("test %q: expected driverName: %q, got: %q", test.name, test.expectResult.driverName, driverName) + } + + if snapshotId != test.expectResult.snapshotId { + t.Errorf("test %q: expected snapshotId: %v, got: %v", test.name, test.expectResult.snapshotId, snapshotId) + } + + if timestamp != test.expectResult.timestamp { + t.Errorf("test %q: expected create time: %v, got: %v", test.name, test.expectResult.timestamp, timestamp) + } + + if size != test.expectResult.size { + t.Errorf("test %q: expected size: %v, got: %v", test.name, test.expectResult.size, size) + } + + if !reflect.DeepEqual(status, test.expectResult.status) { + t.Errorf("test %q: expected status: %v, got: %v", test.name, test.expectResult.status, status) + } + } + } +} + +func TestDeleteSnapshot(t *testing.T) { + defaultID := "testid" + secrets := map[string]string{"foo": "bar"} + + defaultRequest := &csi.DeleteSnapshotRequest{ + SnapshotId: defaultID, + } + + secretsRequest := &csi.DeleteSnapshotRequest{ + SnapshotId: defaultID, + DeleteSnapshotSecrets: secrets, + } + + tests := []struct { + name string + snapshotID string + secrets map[string]string + input *csi.DeleteSnapshotRequest + output *csi.DeleteSnapshotResponse + injectError codes.Code + expectError bool + }{ + { + name: "success", + snapshotID: defaultID, + input: defaultRequest, + output: &csi.DeleteSnapshotResponse{}, + expectError: false, + }, + { + name: "secrets", + snapshotID: defaultID, + secrets: secrets, + input: secretsRequest, + output: &csi.DeleteSnapshotResponse{}, + expectError: false, + }, + { + name: "gRPC transient error", + snapshotID: defaultID, + input: defaultRequest, + output: nil, + injectError: codes.DeadlineExceeded, + expectError: true, + }, + { + name: "gRPC final error", + snapshotID: defaultID, + input: defaultRequest, + output: nil, + injectError: codes.NotFound, + expectError: true, + }, + } + + mockController, driver, _, controllerServer, csiConn, err := createMockServer(t) + if err != nil { + t.Fatal(err) + } + defer mockController.Finish() + defer driver.Stop() + defer csiConn.Close() + + for _, test := range tests { + in := test.input + out := test.output + var injectedErr error = nil + if test.injectError != codes.OK { + injectedErr = status.Error(test.injectError, fmt.Sprintf("Injecting error %d", test.injectError)) + } + + // Setup expectation + if in != nil { + controllerServer.EXPECT().DeleteSnapshot(gomock.Any(), in).Return(out, injectedErr).Times(1) + } + + err := csiConn.DeleteSnapshot(context.Background(), test.snapshotID, test.secrets) + if test.expectError && err == nil { + t.Errorf("test %q: Expected error, got none", test.name) + } + if !test.expectError && err != nil { + t.Errorf("test %q: got error: %v", test.name, err) + } + } + +} + +func TestGetSnapshotStatus(t *testing.T) { + defaultID := "testid" + createdAt := time.Now().UnixNano() + + defaultRequest := &csi.ListSnapshotsRequest{ + SnapshotId: defaultID, + } + + defaultResponse := &csi.ListSnapshotsResponse{ + Entries: []*csi.ListSnapshotsResponse_Entry{ + { + Snapshot: &csi.Snapshot{ + Id: defaultID, + SizeBytes: 1000, + SourceVolumeId: "volumeid", + CreatedAt: createdAt, + Status: &csi.SnapshotStatus{ + Type: csi.SnapshotStatus_READY, + Details: "success", + }, + }, + }, + }, + } + + tests := []struct { + name string + snapshotID string + input *csi.ListSnapshotsRequest + output *csi.ListSnapshotsResponse + injectError codes.Code + expectError bool + expectStatus *csi.SnapshotStatus + expectCreateAt int64 + }{ + { + name: "success", + snapshotID: defaultID, + input: defaultRequest, + output: defaultResponse, + expectError: false, + expectStatus: &csi.SnapshotStatus{ + Type: csi.SnapshotStatus_READY, + Details: "success", + }, + expectCreateAt: createdAt, + }, + { + name: "gRPC transient error", + snapshotID: defaultID, + input: defaultRequest, + output: nil, + injectError: codes.DeadlineExceeded, + expectError: true, + }, + { + name: "gRPC final error", + snapshotID: defaultID, + input: defaultRequest, + output: nil, + injectError: codes.NotFound, + expectError: true, + }, + } + + mockController, driver, _, controllerServer, csiConn, err := createMockServer(t) + if err != nil { + t.Fatal(err) + } + defer mockController.Finish() + defer driver.Stop() + defer csiConn.Close() + + for _, test := range tests { + in := test.input + out := test.output + var injectedErr error = nil + if test.injectError != codes.OK { + injectedErr = status.Error(test.injectError, fmt.Sprintf("Injecting error %d", test.injectError)) + } + + // Setup expectation + if in != nil { + controllerServer.EXPECT().ListSnapshots(gomock.Any(), in).Return(out, injectedErr).Times(1) + } + + status, createTime, err := csiConn.GetSnapshotStatus(context.Background(), test.snapshotID) + if test.expectError && err == nil { + t.Errorf("test %q: Expected error, got none", test.name) + } + if !test.expectError && err != nil { + t.Errorf("test %q: got error: %v", test.name, err) + } + if test.expectStatus != nil && !reflect.DeepEqual(test.expectStatus, status) { + t.Errorf("test %q: expected status: %v, got: %v", test.name, test.expectStatus, status) + } + if test.expectCreateAt != createTime { + t.Errorf("test %q: expected createTime: %v, got: %v", test.name, test.expectCreateAt, createTime) + } + } +} + +func FakeCSIVolume() *v1.PersistentVolume { + volume := v1.PersistentVolume{ + ObjectMeta: metav1.ObjectMeta{ + Name: "fake-csi-volume", + }, + Spec: v1.PersistentVolumeSpec{ + ClaimRef: &v1.ObjectReference{ + Kind: "PersistentVolumeClaim", + APIVersion: "v1", + UID: types.UID("uid123"), + Namespace: "default", + Name: "test-claim", + }, + PersistentVolumeSource: v1.PersistentVolumeSource{ + CSI: &v1.CSIPersistentVolumeSource{ + Driver: driverName, + VolumeHandle: "foo", + }, + }, + StorageClassName: "default", + }, + Status: v1.PersistentVolumeStatus{ + Phase: v1.VolumeBound, + }, + } + + return &volume +} + +func FakeVolume() *v1.PersistentVolume { + volume := v1.PersistentVolume{ + ObjectMeta: metav1.ObjectMeta{ + Name: "fake-csi-volume", + }, + Spec: v1.PersistentVolumeSpec{ + ClaimRef: &v1.ObjectReference{ + Kind: "PersistentVolumeClaim", + APIVersion: "v1", + UID: types.UID("uid123"), + Namespace: "default", + Name: "test-claim", + }, + PersistentVolumeSource: v1.PersistentVolumeSource{ + GCEPersistentDisk: &v1.GCEPersistentDiskVolumeSource{}, + }, + StorageClassName: "default", + }, + Status: v1.PersistentVolumeStatus{ + Phase: v1.VolumeBound, + }, + } + + return &volume +}