diff --git a/pkg/connection/connection.go b/pkg/connection/connection.go index 3789f287..1d04f95e 100644 --- a/pkg/connection/connection.go +++ b/pkg/connection/connection.go @@ -51,8 +51,8 @@ type CSIConnection interface { // DeleteSnapshot deletes a snapshot from a volume DeleteSnapshot(ctx context.Context, snapshotID string, snapshotterCredentials map[string]string) (err error) - // GetSnapshotStatus lists snapshot from a volume - GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, error) + // GetSnapshotStatus returns a snapshot's status, creation time, and restore size. + GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, int64, error) // Probe checks that the CSI driver is ready to process requests Probe(ctx context.Context) error @@ -232,7 +232,7 @@ func (c *csiConnection) DeleteSnapshot(ctx context.Context, snapshotID string, s return nil } -func (c *csiConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, error) { +func (c *csiConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, int64, error) { client := csi.NewControllerClient(c.conn) req := csi.ListSnapshotsRequest{ @@ -241,14 +241,14 @@ func (c *csiConnection) GetSnapshotStatus(ctx context.Context, snapshotID string rsp, err := client.ListSnapshots(ctx, &req) if err != nil { - return nil, 0, err + return nil, 0, 0, err } if rsp.Entries == nil || len(rsp.Entries) == 0 { - return nil, 0, fmt.Errorf("can not find snapshot for snapshotID %s", snapshotID) + return nil, 0, 0, fmt.Errorf("can not find snapshot for snapshotID %s", snapshotID) } - return rsp.Entries[0].Snapshot.Status, rsp.Entries[0].Snapshot.CreatedAt, nil + return rsp.Entries[0].Snapshot.Status, rsp.Entries[0].Snapshot.CreatedAt, rsp.Entries[0].Snapshot.SizeBytes, nil } func (c *csiConnection) Close() error { diff --git a/pkg/connection/connection_test.go b/pkg/connection/connection_test.go index 2f88f715..644c107d 100644 --- a/pkg/connection/connection_test.go +++ b/pkg/connection/connection_test.go @@ -658,6 +658,7 @@ func TestDeleteSnapshot(t *testing.T) { func TestGetSnapshotStatus(t *testing.T) { defaultID := "testid" createdAt := time.Now().UnixNano() + size := int64(1000) defaultRequest := &csi.ListSnapshotsRequest{ SnapshotId: defaultID, @@ -668,7 +669,7 @@ func TestGetSnapshotStatus(t *testing.T) { { Snapshot: &csi.Snapshot{ Id: defaultID, - SizeBytes: 1000, + SizeBytes: size, SourceVolumeId: "volumeid", CreatedAt: createdAt, Status: &csi.SnapshotStatus{ @@ -689,6 +690,7 @@ func TestGetSnapshotStatus(t *testing.T) { expectError bool expectStatus *csi.SnapshotStatus expectCreateAt int64 + expectSize int64 }{ { name: "success", @@ -701,6 +703,7 @@ func TestGetSnapshotStatus(t *testing.T) { Details: "success", }, expectCreateAt: createdAt, + expectSize: size, }, { name: "gRPC transient error", @@ -741,7 +744,7 @@ func TestGetSnapshotStatus(t *testing.T) { controllerServer.EXPECT().ListSnapshots(gomock.Any(), in).Return(out, injectedErr).Times(1) } - status, createTime, err := csiConn.GetSnapshotStatus(context.Background(), test.snapshotID) + status, createTime, size, err := csiConn.GetSnapshotStatus(context.Background(), test.snapshotID) if test.expectError && err == nil { t.Errorf("test %q: Expected error, got none", test.name) } @@ -754,6 +757,9 @@ func TestGetSnapshotStatus(t *testing.T) { if test.expectCreateAt != createTime { t.Errorf("test %q: expected createTime: %v, got: %v", test.name, test.expectCreateAt, createTime) } + if test.expectSize != size { + t.Errorf("test %q: expected size: %v, got: %v", test.name, test.expectSize, size) + } } } diff --git a/pkg/controller/csi_handler.go b/pkg/controller/csi_handler.go index 437f4f8f..045433ab 100644 --- a/pkg/controller/csi_handler.go +++ b/pkg/controller/csi_handler.go @@ -32,7 +32,7 @@ import ( type Handler interface { CreateSnapshot(snapshot *crdv1.VolumeSnapshot, volume *v1.PersistentVolume, parameters map[string]string, snapshotterCredentials map[string]string) (string, string, int64, int64, *csi.SnapshotStatus, error) DeleteSnapshot(content *crdv1.VolumeSnapshotContent, snapshotterCredentials map[string]string) error - GetSnapshotStatus(content *crdv1.VolumeSnapshotContent) (*csi.SnapshotStatus, int64, error) + GetSnapshotStatus(content *crdv1.VolumeSnapshotContent) (*csi.SnapshotStatus, int64, int64, error) } // csiHandler is a handler that calls CSI to create/delete volume snapshot. @@ -84,18 +84,19 @@ func (handler *csiHandler) DeleteSnapshot(content *crdv1.VolumeSnapshotContent, return nil } -func (handler *csiHandler) GetSnapshotStatus(content *crdv1.VolumeSnapshotContent) (*csi.SnapshotStatus, int64, error) { +func (handler *csiHandler) GetSnapshotStatus(content *crdv1.VolumeSnapshotContent) (*csi.SnapshotStatus, int64, int64, error) { if content.Spec.CSI == nil { - return nil, 0, fmt.Errorf("CSISnapshot not defined in spec") + return nil, 0, 0, fmt.Errorf("CSISnapshot not defined in spec") } ctx, cancel := context.WithTimeout(context.Background(), handler.timeout) defer cancel() - csiSnapshotStatus, timestamp, err := handler.csiConnection.GetSnapshotStatus(ctx, content.Spec.CSI.SnapshotHandle) + csiSnapshotStatus, timestamp, size, err := handler.csiConnection.GetSnapshotStatus(ctx, content.Spec.CSI.SnapshotHandle) if err != nil { - return nil, 0, fmt.Errorf("failed to list snapshot data %s: %q", content.Name, err) + return nil, 0, 0, fmt.Errorf("failed to list snapshot data %s: %q", content.Name, err) } - return csiSnapshotStatus, timestamp, nil + return csiSnapshotStatus, timestamp, size, nil + } func makeSnapshotName(prefix, snapshotUID string, snapshotNameUUIDLength int) (string, error) { diff --git a/pkg/controller/framework_test.go b/pkg/controller/framework_test.go index 57a5cf4e..983895d2 100644 --- a/pkg/controller/framework_test.go +++ b/pkg/controller/framework_test.go @@ -1092,6 +1092,7 @@ type listCall struct { // information to return status *csi.SnapshotStatus createTime int64 + size int64 err error } @@ -1202,10 +1203,10 @@ func (f *fakeCSIConnection) DeleteSnapshot(ctx context.Context, snapshotID strin return call.err } -func (f *fakeCSIConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, error) { +func (f *fakeCSIConnection) GetSnapshotStatus(ctx context.Context, snapshotID string) (*csi.SnapshotStatus, int64, int64, error) { if f.listCallCounter >= len(f.listCalls) { f.t.Errorf("Unexpected CSI list Snapshot call: snapshotID=%s, index: %d, calls: %+v", snapshotID, f.createCallCounter, f.createCalls) - return nil, 0, fmt.Errorf("unexpected call") + return nil, 0, 0, fmt.Errorf("unexpected call") } call := f.listCalls[f.listCallCounter] f.listCallCounter++ @@ -1217,10 +1218,10 @@ func (f *fakeCSIConnection) GetSnapshotStatus(ctx context.Context, snapshotID st } if err != nil { - return nil, 0, fmt.Errorf("unexpected call") + return nil, 0, 0, fmt.Errorf("unexpected call") } - return call.status, call.createTime, call.err + return call.status, call.createTime, call.size, call.err } func (f *fakeCSIConnection) Close() error { diff --git a/pkg/controller/snapshot_controller.go b/pkg/controller/snapshot_controller.go index 4c3b8f8f..38da92c7 100644 --- a/pkg/controller/snapshot_controller.go +++ b/pkg/controller/snapshot_controller.go @@ -421,12 +421,12 @@ func (ctrl *csiSnapshotController) checkandBindSnapshotContent(snapshot *crdv1.V } func (ctrl *csiSnapshotController) checkandUpdateSnapshotStatusOperation(snapshot *crdv1.VolumeSnapshot, content *crdv1.VolumeSnapshotContent) (*crdv1.VolumeSnapshot, error) { - status, _, err := ctrl.handler.GetSnapshotStatus(content) + status, _, size, err := ctrl.handler.GetSnapshotStatus(content) if err != nil { return nil, fmt.Errorf("failed to check snapshot status %s with error %v", snapshot.Name, err) } - - newSnapshot, err := ctrl.updateSnapshotStatus(snapshot, status, time.Now(), nil, IsSnapshotBound(snapshot, content)) + timestamp := time.Now().UnixNano() + newSnapshot, err := ctrl.updateSnapshotStatus(snapshot, status, timestamp, size, IsSnapshotBound(snapshot, content)) if err != nil { return nil, err } @@ -490,7 +490,7 @@ func (ctrl *csiSnapshotController) createSnapshotOperation(snapshot *crdv1.Volum // Update snapshot status with timestamp for i := 0; i < ctrl.createSnapshotContentRetryCount; i++ { glog.V(5).Infof("createSnapshot [%s]: trying to update snapshot creation timestamp", snapshotKey(snapshot)) - newSnapshot, err = ctrl.updateSnapshotStatus(snapshot, csiSnapshotStatus, time.Unix(0, timestamp), resource.NewQuantity(size, resource.BinarySI), false) + newSnapshot, err = ctrl.updateSnapshotStatus(snapshot, csiSnapshotStatus, timestamp, size, false) if err == nil { break } @@ -638,12 +638,12 @@ func (ctrl *csiSnapshotController) bindandUpdateVolumeSnapshot(snapshotContent * } // UpdateSnapshotStatus converts snapshot status to crdv1.VolumeSnapshotCondition -func (ctrl *csiSnapshotController) updateSnapshotStatus(snapshot *crdv1.VolumeSnapshot, csistatus *csi.SnapshotStatus, timestamp time.Time, size *resource.Quantity, bound bool) (*crdv1.VolumeSnapshot, error) { - glog.V(5).Infof("updating VolumeSnapshot[]%s, set status %v, timestamp %v", snapshotKey(snapshot), csistatus, timestamp) +func (ctrl *csiSnapshotController) updateSnapshotStatus(snapshot *crdv1.VolumeSnapshot, csistatus *csi.SnapshotStatus, createdAt, size int64, bound bool) (*crdv1.VolumeSnapshot, error) { + glog.V(5).Infof("updating VolumeSnapshot[]%s, set status %v, timestamp %v", snapshotKey(snapshot), csistatus, createdAt) status := snapshot.Status change := false timeAt := &metav1.Time{ - Time: timestamp, + Time: time.Unix(0, createdAt), } snapshotClone := snapshot.DeepCopy() @@ -676,8 +676,8 @@ func (ctrl *csiSnapshotController) updateSnapshotStatus(snapshot *crdv1.VolumeSn } } if change { - if size != nil { - status.RestoreSize = size + if size > 0 { + status.RestoreSize = resource.NewQuantity(size, resource.BinarySI) } snapshotClone.Status = status newSnapshotObj, err := ctrl.clientset.VolumesnapshotV1alpha1().VolumeSnapshots(snapshotClone.Namespace).Update(snapshotClone)