diff --git a/bootstrap/bpdoc/bpdoc.go b/bootstrap/bpdoc/bpdoc.go index dcb6f65..608cfac 100644 --- a/bootstrap/bpdoc/bpdoc.go +++ b/bootstrap/bpdoc/bpdoc.go @@ -152,15 +152,18 @@ func setDefaults(properties []Property, defaults reflect.Value) { continue } - if f.Type().Kind() == reflect.Interface { + if f.Kind() == reflect.Interface { f = f.Elem() } - if f.Type().Kind() == reflect.Ptr { + if f.Kind() == reflect.Ptr { + if f.IsNil() { + continue + } f = f.Elem() } - if f.Type().Kind() == reflect.Struct { + if f.Kind() == reflect.Struct { setDefaults(prop.Properties, f) } else { prop.Default = fmt.Sprintf("%v", f.Interface()) diff --git a/proptools/clone.go b/proptools/clone.go index ad78e38..4cc1103 100644 --- a/proptools/clone.go +++ b/proptools/clone.go @@ -43,6 +43,7 @@ func CopyProperties(dstValue, srcValue reflect.Value) { srcFieldValue := srcValue.Field(i) dstFieldValue := dstValue.Field(i) dstFieldInterfaceValue := reflect.Value{} + origDstFieldValue := dstFieldValue switch srcFieldValue.Kind() { case reflect.Bool, reflect.String, reflect.Int, reflect.Uint: @@ -93,7 +94,7 @@ func CopyProperties(dstValue, srcValue reflect.Value) { fallthrough case reflect.Ptr: if srcFieldValue.IsNil() { - dstFieldValue.Set(srcFieldValue) + origDstFieldValue.Set(srcFieldValue) break } @@ -110,13 +111,13 @@ func CopyProperties(dstValue, srcValue reflect.Value) { if dstFieldInterfaceValue.IsValid() { dstFieldInterfaceValue.Set(newValue) } else { - dstFieldValue.Set(newValue) + origDstFieldValue.Set(newValue) } } case reflect.Bool, reflect.String: newValue := reflect.New(srcFieldValue.Type()) newValue.Elem().Set(srcFieldValue) - dstFieldValue.Set(newValue) + origDstFieldValue.Set(newValue) default: panic(fmt.Errorf("can't clone field %q: points to a %s", field.Name, srcFieldValue.Kind())) diff --git a/proptools/extend.go b/proptools/extend.go index 832ee0b..a4c46e8 100644 --- a/proptools/extend.go +++ b/proptools/extend.go @@ -191,11 +191,15 @@ func extendPropertyErrorf(property string, format string, a ...interface{}) *Ext func extendProperties(dst interface{}, src interface{}, filter ExtendPropertyFilterFunc, order ExtendPropertyOrderFunc) error { - dstValue, err := getStruct(dst) + srcValue, err := getStruct(src) if err != nil { + if _, ok := err.(getStructEmptyError); ok { + return nil + } return err } - srcValue, err := getStruct(src) + + dstValue, err := getOrCreateStruct(dst) if err != nil { return err } @@ -212,20 +216,23 @@ func extendProperties(dst interface{}, src interface{}, filter ExtendPropertyFil func extendMatchingProperties(dst []interface{}, src interface{}, filter ExtendPropertyFilterFunc, order ExtendPropertyOrderFunc) error { + srcValue, err := getStruct(src) + if err != nil { + if _, ok := err.(getStructEmptyError); ok { + return nil + } + return err + } + dstValues := make([]reflect.Value, len(dst)) for i := range dst { var err error - dstValues[i], err = getStruct(dst[i]) + dstValues[i], err = getOrCreateStruct(dst[i]) if err != nil { return err } } - srcValue, err := getStruct(src) - if err != nil { - return err - } - return extendPropertiesRecursive(dstValues, srcValue, "", filter, false, order) } @@ -270,6 +277,7 @@ func extendPropertiesRecursive(dstValues []reflect.Value, srcValue reflect.Value found = true dstFieldValue := dstValue.FieldByIndex(dstField.Index) + origDstFieldValue := dstFieldValue if srcFieldValue.Kind() != dstFieldValue.Kind() { return extendPropertyErrorf(propertyName, "mismatched types %s and %s", @@ -306,11 +314,12 @@ func extendPropertiesRecursive(dstValues []reflect.Value, srcValue reflect.Value } // Pointer to a struct - if dstFieldValue.IsNil() != srcFieldValue.IsNil() { - return extendPropertyErrorf(propertyName, "nilitude mismatch") + if srcFieldValue.IsNil() { + continue } if dstFieldValue.IsNil() { - continue + dstFieldValue = reflect.New(srcFieldValue.Type().Elem()) + origDstFieldValue.Set(dstFieldValue) } dstFieldValue = dstFieldValue.Elem() @@ -435,14 +444,32 @@ func extendPropertiesRecursive(dstValues []reflect.Value, srcValue reflect.Value return nil } +type getStructEmptyError struct{} + +func (getStructEmptyError) Error() string { return "interface containing nil pointer" } + +func getOrCreateStruct(in interface{}) (reflect.Value, error) { + value, err := getStruct(in) + if _, ok := err.(getStructEmptyError); ok { + value := reflect.ValueOf(in) + newValue := reflect.New(value.Type().Elem()) + value.Set(newValue) + } + + return value, err +} + func getStruct(in interface{}) (reflect.Value, error) { value := reflect.ValueOf(in) if value.Kind() != reflect.Ptr { return reflect.Value{}, fmt.Errorf("expected pointer to struct, got %T", in) } - value = value.Elem() - if value.Kind() != reflect.Struct { + if value.Type().Elem().Kind() != reflect.Struct { return reflect.Value{}, fmt.Errorf("expected pointer to struct, got %T", in) } + if value.IsNil() { + return reflect.Value{}, getStructEmptyError{} + } + value = value.Elem() return value, nil } diff --git a/proptools/extend_test.go b/proptools/extend_test.go index 0acd139..3fd61f2 100644 --- a/proptools/extend_test.go +++ b/proptools/extend_test.go @@ -496,12 +496,69 @@ func appendPropertiesTestCases() []appendPropertyTestCase { }, }, }, + { + // Nil pointer to a struct + in1: &struct { + Nested *struct { + S string + } + }{}, + in2: &struct { + Nested *struct { + S string + } + }{ + Nested: &struct { + S string + }{ + S: "string", + }, + }, + out: &struct { + Nested *struct { + S string + } + }{ + Nested: &struct { + S string + }{ + S: "string", + }, + }, + }, + { + // Nil pointer to a struct in an interface + in1: &struct { + Nested interface{} + }{ + Nested: (*struct{ S string })(nil), + }, + in2: &struct { + Nested interface{} + }{ + Nested: &struct { + S string + }{ + S: "string", + }, + }, + out: &struct { + Nested interface{} + }{ + Nested: &struct { + S string + }{ + S: "string", + }, + }, + }, // Errors { // Non-pointer in1 in1: struct{}{}, + in2: &struct{}{}, err: errors.New("expected pointer to struct, got struct {}"), out: struct{}{}, }, @@ -515,6 +572,7 @@ func appendPropertiesTestCases() []appendPropertyTestCase { { // Non-struct in1 in1: &[]string{"bad"}, + in2: &struct{}{}, err: errors.New("expected pointer to struct, got *[]string"), out: &[]string{"bad"}, }, @@ -606,23 +664,6 @@ func appendPropertiesTestCases() []appendPropertyTestCase { }, err: extendPropertyErrorf("s", "interface not a pointer"), }, - { - // Pointer nilitude mismatch - in1: &struct{ S *struct{ S string } }{ - S: &struct{ S string }{ - S: "string1", - }, - }, - in2: &struct{ S *struct{ S string } }{ - S: nil, - }, - out: &struct{ S *struct{ S string } }{ - S: &struct{ S string }{ - S: "string1", - }, - }, - err: extendPropertyErrorf("s", "nilitude mismatch"), - }, { // Pointer not a struct in1: &struct{ S *[]string }{ @@ -912,6 +953,7 @@ func appendMatchingPropertiesTestCases() []appendMatchingPropertiesTestCase { { // Non-pointer in1 in1: []interface{}{struct{}{}}, + in2: &struct{}{}, err: errors.New("expected pointer to struct, got struct {}"), out: []interface{}{struct{}{}}, }, @@ -925,6 +967,7 @@ func appendMatchingPropertiesTestCases() []appendMatchingPropertiesTestCase { { // Non-struct in1 in1: []interface{}{&[]string{"bad"}}, + in2: &struct{}{}, err: errors.New("expected pointer to struct, got *[]string"), out: []interface{}{&[]string{"bad"}}, }, diff --git a/proptools/typeequal.go b/proptools/typeequal.go index e68f91a..d9b3c18 100644 --- a/proptools/typeequal.go +++ b/proptools/typeequal.go @@ -46,12 +46,14 @@ func typeEqual(v1, v2 reflect.Value) bool { if v1.Type().Elem().Kind() != reflect.Struct { return true } - if v1.IsNil() != v2.IsNil() { - return false - } - if v1.IsNil() { + if v1.IsNil() && !v2.IsNil() { + return concreteType(v2) + } else if v2.IsNil() && !v1.IsNil() { + return concreteType(v1) + } else if v1.IsNil() && v2.IsNil() { return true } + v1 = v1.Elem() v2 = v2.Elem() } @@ -74,3 +76,34 @@ func typeEqual(v1, v2 reflect.Value) bool { return true } + +// Returns true if v recursively contains no interfaces +func concreteType(v reflect.Value) bool { + if v.Kind() == reflect.Interface { + return false + } + + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return true + } + v = v.Elem() + } + + if v.Kind() != reflect.Struct { + return true + } + + for i := 0; i < v.NumField(); i++ { + v := v.Field(i) + + switch kind := v.Kind(); kind { + case reflect.Interface, reflect.Ptr, reflect.Struct: + if !concreteType(v) { + return false + } + } + } + + return true +} diff --git a/proptools/typeequal_test.go b/proptools/typeequal_test.go index d374609..cd86cd3 100644 --- a/proptools/typeequal_test.go +++ b/proptools/typeequal_test.go @@ -88,7 +88,7 @@ var typeEqualTestCases = []struct { // Mismatching nilitude embedded pointer to struct in1: &struct{ S *struct{ S1 string } }{S: &struct{ S1 string }{}}, in2: &struct{ S *struct{ S1 string } }{}, - out: false, + out: true, }, { // Matching embedded interface to pointer to struct @@ -116,13 +116,13 @@ var typeEqualTestCases = []struct { }, { // Matching pointer to non-struct - in1: struct{ S1 *string }{ S1: StringPtr("test1") }, - in2: struct{ S1 *string }{ S1: StringPtr("test2") }, + in1: struct{ S1 *string }{S1: StringPtr("test1")}, + in2: struct{ S1 *string }{S1: StringPtr("test2")}, out: true, }, { // Matching nilitude pointer to non-struct - in1: struct{ S1 *string }{ S1: StringPtr("test1") }, + in1: struct{ S1 *string }{S1: StringPtr("test1")}, in2: struct{ S1 *string }{}, out: true, }, diff --git a/unpack.go b/unpack.go index 3e21203..3580eff 100644 --- a/unpack.go +++ b/unpack.go @@ -139,6 +139,8 @@ func unpackStructValue(namePrefix string, structValue reflect.Value, panic(fmt.Errorf("field %s is not settable", propertyName)) } + origFieldValue := fieldValue + // To make testing easier we validate the struct field's type regardless // of whether or not the property was specified in the parsed string. switch kind := fieldValue.Kind(); kind { @@ -163,7 +165,8 @@ func unpackStructValue(namePrefix string, structValue reflect.Value, switch ptrKind := fieldValue.Type().Elem().Kind(); ptrKind { case reflect.Struct: if fieldValue.IsNil() { - panic(fmt.Errorf("field %s contains a nil pointer", propertyName)) + fieldValue = reflect.New(fieldValue.Type().Elem()) + origFieldValue.Set(fieldValue) } fieldValue = fieldValue.Elem() case reflect.Bool, reflect.String: diff --git a/unpack_test.go b/unpack_test.go index 7b314dd..ea4dbc7 100644 --- a/unpack_test.go +++ b/unpack_test.go @@ -28,15 +28,17 @@ import ( var validUnpackTestCases = []struct { input string output []interface{} + empty []interface{} errs []error }{ - {` - m { - name: "abc", - blank: "", - } + { + input: ` + m { + name: "abc", + blank: "", + } `, - []interface{}{ + output: []interface{}{ struct { Name *string Blank *string @@ -47,46 +49,46 @@ var validUnpackTestCases = []struct { Unset: nil, }, }, - nil, }, - {` - m { - name: "abc", - } + { + input: ` + m { + name: "abc", + } `, - []interface{}{ + output: []interface{}{ struct { Name string }{ Name: "abc", }, }, - nil, }, - {` - m { - isGood: true, - } + { + input: ` + m { + isGood: true, + } `, - []interface{}{ + output: []interface{}{ struct { IsGood bool }{ IsGood: true, }, }, - nil, }, - {` - m { - isGood: true, - isBad: false, - } + { + input: ` + m { + isGood: true, + isBad: false, + } `, - []interface{}{ + output: []interface{}{ struct { IsGood *bool IsBad *bool @@ -97,17 +99,17 @@ var validUnpackTestCases = []struct { IsUgly: nil, }, }, - nil, }, - {` - m { - stuff: ["asdf", "jkl;", "qwert", - "uiop", "bnm,"], - empty: [] - } + { + input: ` + m { + stuff: ["asdf", "jkl;", "qwert", + "uiop", "bnm,"], + empty: [] + } `, - []interface{}{ + output: []interface{}{ struct { Stuff []string Empty []string @@ -118,17 +120,17 @@ var validUnpackTestCases = []struct { Nil: nil, }, }, - nil, }, - {` - m { - nested: { - name: "abc", + { + input: ` + m { + nested: { + name: "abc", + } } - } `, - []interface{}{ + output: []interface{}{ struct { Nested struct { Name string @@ -139,17 +141,17 @@ var validUnpackTestCases = []struct { }, }, }, - nil, }, - {` - m { - nested: { - name: "def", + { + input: ` + m { + nested: { + name: "def", + } } - } `, - []interface{}{ + output: []interface{}{ struct { Nested interface{} }{ @@ -158,19 +160,19 @@ var validUnpackTestCases = []struct { }, }, }, - nil, }, - {` - m { - nested: { - foo: "abc", - }, - bar: false, - baz: ["def", "ghi"], - } + { + input: ` + m { + nested: { + foo: "abc", + }, + bar: false, + baz: ["def", "ghi"], + } `, - []interface{}{ + output: []interface{}{ struct { Nested struct { Foo string @@ -185,19 +187,19 @@ var validUnpackTestCases = []struct { Baz: []string{"def", "ghi"}, }, }, - nil, }, - {` - m { - nested: { - foo: "abc", - }, - bar: false, - baz: ["def", "ghi"], - } + { + input: ` + m { + nested: { + foo: "abc", + }, + bar: false, + baz: ["def", "ghi"], + } `, - []interface{}{ + output: []interface{}{ struct { Nested struct { Foo string `allowNested:"true"` @@ -214,19 +216,19 @@ var validUnpackTestCases = []struct { Baz: []string{"def", "ghi"}, }, }, - nil, }, - {` - m { - nested: { - foo: "abc", - }, - bar: false, - baz: ["def", "ghi"], - } + { + input: ` + m { + nested: { + foo: "abc", + }, + bar: false, + baz: ["def", "ghi"], + } `, - []interface{}{ + output: []interface{}{ struct { Nested struct { Foo string @@ -241,24 +243,25 @@ var validUnpackTestCases = []struct { Baz: []string{"def", "ghi"}, }, }, - []error{ + errs: []error{ &Error{ Err: fmt.Errorf("filtered field nested.foo cannot be set in a Blueprint file"), - Pos: mkpos(27, 4, 8), + Pos: mkpos(30, 4, 9), }, }, }, // Anonymous struct - {` - m { - name: "abc", - nested: { - name: "def", - }, - } + { + input: ` + m { + name: "abc", + nested: { + name: "def", + }, + } `, - []interface{}{ + output: []interface{}{ struct { EmbeddedStruct Nested struct { @@ -277,19 +280,19 @@ var validUnpackTestCases = []struct { }, }, }, - nil, }, // Anonymous interface - {` - m { - name: "abc", - nested: { - name: "def", - }, - } + { + input: ` + m { + name: "abc", + nested: { + name: "def", + }, + } `, - []interface{}{ + output: []interface{}{ struct { EmbeddedInterface Nested struct { @@ -308,19 +311,19 @@ var validUnpackTestCases = []struct { }, }, }, - nil, }, // Anonymous struct with name collision - {` - m { - name: "abc", - nested: { - name: "def", - }, - } + { + input: ` + m { + name: "abc", + nested: { + name: "def", + }, + } `, - []interface{}{ + output: []interface{}{ struct { Name string EmbeddedStruct @@ -344,19 +347,19 @@ var validUnpackTestCases = []struct { }, }, }, - nil, }, // Anonymous interface with name collision - {` - m { - name: "abc", - nested: { - name: "def", - }, - } + { + input: ` + m { + name: "abc", + nested: { + name: "def", + }, + } `, - []interface{}{ + output: []interface{}{ struct { Name string EmbeddedInterface @@ -380,21 +383,21 @@ var validUnpackTestCases = []struct { }, }, }, - nil, }, // Variables - {` - list = ["abc"] - string = "def" - list_with_variable = [string] - m { - name: string, - list: list, - list2: list_with_variable, - } + { + input: ` + list = ["abc"] + string = "def" + list_with_variable = [string] + m { + name: string, + list: list, + list2: list_with_variable, + } `, - []interface{}{ + output: []interface{}{ struct { Name string List []string @@ -405,18 +408,18 @@ var validUnpackTestCases = []struct { List2: []string{"def"}, }, }, - nil, }, // Multiple property structs - {` - m { - nested: { - name: "abc", + { + input: ` + m { + nested: { + name: "abc", + } } - } `, - []interface{}{ + output: []interface{}{ struct { Nested struct { Name string @@ -438,7 +441,62 @@ var validUnpackTestCases = []struct { struct { }{}, }, - nil, + }, + + // Nil pointer to struct + { + input: ` + m { + nested: { + name: "abc", + } + } + `, + output: []interface{}{ + struct { + Nested *struct { + Name string + } + }{ + Nested: &struct{ Name string }{ + Name: "abc", + }, + }, + }, + empty: []interface{}{ + &struct { + Nested *struct { + Name string + } + }{}, + }, + }, + + // Interface containing nil pointer to struct + { + input: ` + m { + nested: { + name: "abc", + } + } + `, + output: []interface{}{ + struct { + Nested interface{} + }{ + Nested: &EmbeddedStruct{ + Name: "abc", + }, + }, + }, + empty: []interface{}{ + &struct { + Nested interface{} + }{ + Nested: (*EmbeddedStruct)(nil), + }, + }, }, } @@ -464,9 +522,13 @@ func TestUnpackProperties(t *testing.T) { continue } - output := []interface{}{} - for _, p := range testCase.output { - output = append(output, proptools.CloneEmptyProperties(reflect.ValueOf(p)).Interface()) + var output []interface{} + if len(testCase.empty) > 0 { + output = testCase.empty + } else { + for _, p := range testCase.output { + output = append(output, proptools.CloneEmptyProperties(reflect.ValueOf(p)).Interface()) + } } _, errs = unpackProperties(module.Properties, output...) if len(errs) != 0 && len(testCase.errs) == 0 {