diff --git a/libs/dyn/convert/from_typed.go b/libs/dyn/convert/from_typed.go index 9ad0fad6aa..b286eba57d 100644 --- a/libs/dyn/convert/from_typed.go +++ b/libs/dyn/convert/from_typed.go @@ -55,6 +55,16 @@ func fromTyped(src any, ref dyn.Value, options ...fromTypedOptions) (dyn.Value, } } + // Handle SDK native types using JSON marshaling. + // Check for Invalid kind first to avoid panic when calling Type() on invalid value. + if srcv.Kind() != reflect.Invalid && isSDKNativeType(srcv.Type()) { + v, err := fromTypedSDKNative(srcv, ref, options...) + if err != nil { + return dyn.InvalidValue, err + } + return v.WithLocations(ref.Locations()), nil + } + var v dyn.Value var err error switch srcv.Kind() { diff --git a/libs/dyn/convert/normalize.go b/libs/dyn/convert/normalize.go index b0d2ebfa89..787d1ff9e7 100644 --- a/libs/dyn/convert/normalize.go +++ b/libs/dyn/convert/normalize.go @@ -41,6 +41,11 @@ func (n normalizeOptions) normalizeType(typ reflect.Type, src dyn.Value, seen [] typ = typ.Elem() } + // Handle SDK native types as strings since they use custom JSON marshaling. + if isSDKNativeType(typ) { + return n.normalizeString(reflect.TypeOf(""), src, path) + } + switch typ.Kind() { case reflect.Struct: return n.normalizeStruct(typ, src, append(seen, typ), path) diff --git a/libs/dyn/convert/sdk_native_types.go b/libs/dyn/convert/sdk_native_types.go new file mode 100644 index 0000000000..934b5a5f7c --- /dev/null +++ b/libs/dyn/convert/sdk_native_types.go @@ -0,0 +1,92 @@ +package convert + +import ( + "encoding/json" + "fmt" + "reflect" + "slices" + + "github.com/databricks/cli/libs/dyn" + sdkduration "github.com/databricks/databricks-sdk-go/common/types/duration" + sdkfieldmask "github.com/databricks/databricks-sdk-go/common/types/fieldmask" + sdktime "github.com/databricks/databricks-sdk-go/common/types/time" +) + +// sdkNativeTypes is a list of SDK native types that use custom JSON marshaling +// and should be treated as strings in dyn.Value. These types all implement +// json.Marshaler and json.Unmarshaler interfaces. +var sdkNativeTypes = []reflect.Type{ + reflect.TypeFor[sdkduration.Duration](), // Protobuf duration format (e.g., "300s") + reflect.TypeFor[sdktime.Time](), // RFC3339 timestamp format (e.g., "2023-12-25T10:30:00Z") + reflect.TypeFor[sdkfieldmask.FieldMask](), // Comma-separated paths (e.g., "name,age,email") +} + +// isSDKNativeType checks if the given type is one of the SDK's native types +// that use custom JSON marshaling and should be treated as strings. +func isSDKNativeType(typ reflect.Type) bool { + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + for _, sdkType := range sdkNativeTypes { + if typ == sdkType { + return true + } + } + return false +} + +// fromTypedSDKNative converts SDK native types to dyn.Value. +// SDK native types (duration.Duration, time.Time, fieldmask.FieldMask) use +// custom JSON marshaling with string representations. +func fromTypedSDKNative(src reflect.Value, ref dyn.Value, options ...fromTypedOptions) (dyn.Value, error) { + // Check for zero value first. + if src.IsZero() && !slices.Contains(options, includeZeroValues) { + return dyn.NilValue, nil + } + + // Use JSON marshaling since SDK native types implement json.Marshaler. + jsonBytes, err := json.Marshal(src.Interface()) + if err != nil { + return dyn.InvalidValue, err + } + + // All SDK native types marshal to JSON strings. Unmarshal to get the raw string value. + // For example: duration.Duration(300s) -> JSON "300s" -> string "300s" + var str string + if err := json.Unmarshal(jsonBytes, &str); err != nil { + return dyn.InvalidValue, err + } + + // Handle empty string as zero value. + if str == "" && !slices.Contains(options, includeZeroValues) { + return dyn.NilValue, nil + } + + return dyn.V(str), nil +} + +// toTypedSDKNative converts a dyn.Value to an SDK native type. +// SDK native types (duration.Duration, time.Time, fieldmask.FieldMask) use +// custom JSON marshaling with string representations. +func toTypedSDKNative(dst reflect.Value, src dyn.Value) error { + switch src.Kind() { + case dyn.KindString: + // Use JSON unmarshaling since SDK native types implement json.Unmarshaler. + // Marshal the string to create a valid JSON string literal for unmarshaling. + jsonBytes, err := json.Marshal(src.MustString()) + if err != nil { + return err + } + return json.Unmarshal(jsonBytes, dst.Addr().Interface()) + case dyn.KindNil: + dst.SetZero() + return nil + default: + // Fall through to the error case. + } + + return TypeError{ + value: src, + msg: fmt.Sprintf("expected a string, found a %s", src.Kind()), + } +} diff --git a/libs/dyn/convert/sdk_native_types_test.go b/libs/dyn/convert/sdk_native_types_test.go new file mode 100644 index 0000000000..454b5fea18 --- /dev/null +++ b/libs/dyn/convert/sdk_native_types_test.go @@ -0,0 +1,366 @@ +package convert + +import ( + "testing" + "time" + + "github.com/databricks/cli/libs/diag" + "github.com/databricks/cli/libs/dyn" + sdkduration "github.com/databricks/databricks-sdk-go/common/types/duration" + sdkfieldmask "github.com/databricks/databricks-sdk-go/common/types/fieldmask" + sdktime "github.com/databricks/databricks-sdk-go/common/types/time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Roundtrip tests - verify SDK native types convert to dyn.Value and back for both value and pointer types + +func TestDurationRoundtrip(t *testing.T) { + tests := []struct { + name string + duration time.Duration + expectedString string + }{ + {"5min", 5 * time.Minute, "300s"}, + {"7days", 7 * 24 * time.Hour, "604800s"}, + {"1hour", 1 * time.Hour, "3600s"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test value type + t.Run("value", func(t *testing.T) { + src := *sdkduration.New(tt.duration) + dynValue, err := FromTyped(src, dyn.NilValue) + require.NoError(t, err) + assert.Equal(t, tt.expectedString, dynValue.MustString()) + + var out sdkduration.Duration + err = ToTyped(&out, dynValue) + require.NoError(t, err) + assert.Equal(t, src.AsDuration(), out.AsDuration()) + }) + + // Test pointer type + t.Run("pointer", func(t *testing.T) { + src := *sdkduration.New(tt.duration) + dynValue, err := FromTyped(&src, dyn.NilValue) + require.NoError(t, err) + assert.Equal(t, tt.expectedString, dynValue.MustString()) + + var out *sdkduration.Duration + err = ToTyped(&out, dynValue) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, src.AsDuration(), out.AsDuration()) + }) + }) + } +} + +func TestTimeRoundtrip(t *testing.T) { + tests := []struct { + name string + time time.Time + expectedString string + }{ + { + "no_nanos", + time.Date(2023, 12, 25, 10, 30, 0, 0, time.UTC), + "2023-12-25T10:30:00Z", + }, + { + "with_nanos", + time.Date(2023, 12, 25, 10, 30, 0, 123456789, time.UTC), + "2023-12-25T10:30:00.123456789Z", + }, + { + "epoch", + time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), + "1970-01-01T00:00:00Z", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test value type + t.Run("value", func(t *testing.T) { + src := *sdktime.New(tt.time) + dynValue, err := FromTyped(src, dyn.NilValue) + require.NoError(t, err) + assert.Equal(t, tt.expectedString, dynValue.MustString()) + + var out sdktime.Time + err = ToTyped(&out, dynValue) + require.NoError(t, err) + assert.Equal(t, src.AsTime(), out.AsTime()) + }) + + // Test pointer type + t.Run("pointer", func(t *testing.T) { + src := *sdktime.New(tt.time) + dynValue, err := FromTyped(&src, dyn.NilValue) + require.NoError(t, err) + assert.Equal(t, tt.expectedString, dynValue.MustString()) + + var out *sdktime.Time + err = ToTyped(&out, dynValue) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, src.AsTime(), out.AsTime()) + }) + }) + } +} + +func TestFieldMaskRoundtrip(t *testing.T) { + tests := []struct { + name string + paths []string + expectedString string + }{ + {"single", []string{"name"}, "name"}, + {"multiple", []string{"name", "age", "email"}, "name,age,email"}, + {"nested", []string{"user.name", "user.email"}, "user.name,user.email"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test value type + t.Run("value", func(t *testing.T) { + src := *sdkfieldmask.New(tt.paths) + dynValue, err := FromTyped(src, dyn.NilValue) + require.NoError(t, err) + assert.Equal(t, tt.expectedString, dynValue.MustString()) + + var out sdkfieldmask.FieldMask + err = ToTyped(&out, dynValue) + require.NoError(t, err) + assert.Equal(t, src.Paths, out.Paths) + }) + + // Test pointer type + t.Run("pointer", func(t *testing.T) { + src := *sdkfieldmask.New(tt.paths) + dynValue, err := FromTyped(&src, dyn.NilValue) + require.NoError(t, err) + assert.Equal(t, tt.expectedString, dynValue.MustString()) + + var out *sdkfieldmask.FieldMask + err = ToTyped(&out, dynValue) + require.NoError(t, err) + require.NotNil(t, out) + assert.Equal(t, src.Paths, out.Paths) + }) + }) + } +} + +// Edge case tests + +func TestNilValuesFromTyped(t *testing.T) { + t.Run("duration", func(t *testing.T) { + var src *sdkduration.Duration + nv, err := FromTyped(src, dyn.NilValue) + require.NoError(t, err) + assert.Equal(t, dyn.NilValue, nv) + }) + + t.Run("time", func(t *testing.T) { + var src *sdktime.Time + nv, err := FromTyped(src, dyn.NilValue) + require.NoError(t, err) + assert.Equal(t, dyn.NilValue, nv) + }) + + t.Run("fieldmask", func(t *testing.T) { + var src *sdkfieldmask.FieldMask + nv, err := FromTyped(src, dyn.NilValue) + require.NoError(t, err) + assert.Equal(t, dyn.NilValue, nv) + }) +} + +func TestNilValuesNormalize(t *testing.T) { + t.Run("duration", func(t *testing.T) { + var typ *sdkduration.Duration + vin := dyn.NilValue + vout, diags := Normalize(typ, vin) + assert.Len(t, diags, 1) + assert.Equal(t, diag.Warning, diags[0].Severity) + assert.Equal(t, `expected a string value, found null`, diags[0].Summary) + assert.Equal(t, dyn.InvalidValue, vout) + }) + + t.Run("time", func(t *testing.T) { + var typ *sdktime.Time + vin := dyn.NilValue + vout, diags := Normalize(typ, vin) + assert.Len(t, diags, 1) + assert.Equal(t, diag.Warning, diags[0].Severity) + assert.Equal(t, `expected a string value, found null`, diags[0].Summary) + assert.Equal(t, dyn.InvalidValue, vout) + }) + + t.Run("fieldmask", func(t *testing.T) { + var typ *sdkfieldmask.FieldMask + vin := dyn.NilValue + vout, diags := Normalize(typ, vin) + assert.Len(t, diags, 1) + assert.Equal(t, diag.Warning, diags[0].Severity) + assert.Equal(t, `expected a string value, found null`, diags[0].Summary) + assert.Equal(t, dyn.InvalidValue, vout) + }) +} + +func TestToTypedErrors(t *testing.T) { + wrongTypeInput := dyn.V(map[string]dyn.Value{"foo": dyn.V("bar")}) + + t.Run("duration_wrong_type", func(t *testing.T) { + var out sdkduration.Duration + err := ToTyped(&out, wrongTypeInput) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected a string") + }) + + t.Run("duration_invalid_format", func(t *testing.T) { + var out sdkduration.Duration + err := ToTyped(&out, dyn.V("7d")) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid google.protobuf.Duration value") + }) + + t.Run("time_wrong_type", func(t *testing.T) { + var out sdktime.Time + err := ToTyped(&out, wrongTypeInput) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected a string") + }) + + t.Run("time_invalid_format", func(t *testing.T) { + var out sdktime.Time + err := ToTyped(&out, dyn.V("not-a-time")) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid google.protobuf.Timestamp value") + }) + + t.Run("fieldmask_wrong_type", func(t *testing.T) { + var out sdkfieldmask.FieldMask + err := ToTyped(&out, wrongTypeInput) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected a string") + }) +} + +func TestSpecialCases(t *testing.T) { + t.Run("duration_zero", func(t *testing.T) { + var src sdkduration.Duration + nv, err := FromTyped(src, dyn.NilValue) + require.NoError(t, err) + assert.Equal(t, dyn.NilValue, nv) + }) + + t.Run("fieldmask_empty_fromtyped", func(t *testing.T) { + src := sdkfieldmask.New([]string{}) + nv, err := FromTyped(src, dyn.NilValue) + require.NoError(t, err) + // Empty field mask marshals to empty string + assert.Equal(t, dyn.V(""), nv) + }) + + t.Run("fieldmask_empty_totyped", func(t *testing.T) { + var out sdkfieldmask.FieldMask + err := ToTyped(&out, dyn.V("")) + require.NoError(t, err) + assert.Empty(t, out.Paths) + }) +} + +// End-to-end tests with structs containing SDK native types + +func TestSDKTypesRoundTripWithPostgresBranchSpec(t *testing.T) { + type BranchSpec struct { + ExpireTime *sdktime.Time `json:"expire_time,omitempty"` + SourceBranchTime *sdktime.Time `json:"source_branch_time,omitempty"` + Ttl *sdkduration.Duration `json:"ttl,omitempty"` + IsProtected bool `json:"is_protected,omitempty"` + } + + original := BranchSpec{ + ExpireTime: sdktime.New(time.Date(2024, 12, 31, 23, 59, 59, 0, time.UTC)), + SourceBranchTime: sdktime.New(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)), + Ttl: sdkduration.New(7 * 24 * time.Hour), + IsProtected: true, + } + + dynValue, err := FromTyped(original, dyn.NilValue) + require.NoError(t, err) + assert.Equal(t, "2024-12-31T23:59:59Z", dynValue.Get("expire_time").MustString()) + assert.Equal(t, "2024-01-01T00:00:00Z", dynValue.Get("source_branch_time").MustString()) + assert.Equal(t, "604800s", dynValue.Get("ttl").MustString()) + assert.True(t, dynValue.Get("is_protected").MustBool()) + + var roundtrip BranchSpec + err = ToTyped(&roundtrip, dynValue) + require.NoError(t, err) + require.NotNil(t, roundtrip.ExpireTime) + require.NotNil(t, roundtrip.SourceBranchTime) + require.NotNil(t, roundtrip.Ttl) + assert.Equal(t, original.ExpireTime.AsTime(), roundtrip.ExpireTime.AsTime()) + assert.Equal(t, original.SourceBranchTime.AsTime(), roundtrip.SourceBranchTime.AsTime()) + assert.Equal(t, original.Ttl.AsDuration(), roundtrip.Ttl.AsDuration()) + assert.Equal(t, original.IsProtected, roundtrip.IsProtected) +} + +func TestSDKTypesRoundTripWithUpdateRequest(t *testing.T) { + type UpdateRequest struct { + Name string `json:"name"` + UpdateMask sdkfieldmask.FieldMask `json:"update_mask"` + } + + original := UpdateRequest{ + Name: "projects/123/branches/456", + UpdateMask: *sdkfieldmask.New([]string{"spec.ttl", "spec.is_protected"}), + } + + dynValue, err := FromTyped(original, dyn.NilValue) + require.NoError(t, err) + assert.Equal(t, "projects/123/branches/456", dynValue.Get("name").MustString()) + assert.Equal(t, "spec.ttl,spec.is_protected", dynValue.Get("update_mask").MustString()) + + var roundtrip UpdateRequest + err = ToTyped(&roundtrip, dynValue) + require.NoError(t, err) + assert.Equal(t, original.Name, roundtrip.Name) + assert.Equal(t, []string{"spec.ttl", "spec.is_protected"}, roundtrip.UpdateMask.Paths) +} + +func TestSDKTypesNormalizeWithStruct(t *testing.T) { + type BranchSpec struct { + ExpireTime *sdktime.Time `json:"expire_time,omitempty"` + Ttl *sdkduration.Duration `json:"ttl,omitempty"` + IsProtected bool `json:"is_protected,omitempty"` + } + + var typ BranchSpec + vin := dyn.V(map[string]dyn.Value{ + "expire_time": dyn.V("2024-12-31T23:59:59Z"), + "ttl": dyn.V("604800s"), + "is_protected": dyn.V(true), + }) + + vout, diags := Normalize(typ, vin) + assert.Empty(t, diags) + assert.Equal(t, "2024-12-31T23:59:59Z", vout.Get("expire_time").MustString()) + assert.Equal(t, "604800s", vout.Get("ttl").MustString()) + assert.True(t, vout.Get("is_protected").MustBool()) + + var out BranchSpec + err := ToTyped(&out, vout) + require.NoError(t, err) + require.NotNil(t, out.ExpireTime) + require.NotNil(t, out.Ttl) + assert.Equal(t, time.Date(2024, 12, 31, 23, 59, 59, 0, time.UTC), out.ExpireTime.AsTime()) + assert.Equal(t, 7*24*time.Hour, out.Ttl.AsDuration()) + assert.True(t, out.IsProtected) +} diff --git a/libs/dyn/convert/to_typed.go b/libs/dyn/convert/to_typed.go index 7fa386e5b5..9899e650b8 100644 --- a/libs/dyn/convert/to_typed.go +++ b/libs/dyn/convert/to_typed.go @@ -37,6 +37,11 @@ func ToTyped(dst any, src dyn.Value) error { panic("cannot set destination value") } + // Handle SDK native types using JSON unmarshaling. + if isSDKNativeType(dstv.Type()) { + return toTypedSDKNative(dstv, src) + } + switch dstv.Kind() { case reflect.Struct: return toTypedStruct(dstv, src)