diff --git a/context.go b/context.go index 5f8ee26..00dd7d8 100644 --- a/context.go +++ b/context.go @@ -4159,7 +4159,7 @@ func (c *Context) VerifyProvidersWereUnchanged() []error { for m := range toProcess { for i, provider := range m.providers { if provider != nil { - hash, err := proptools.HashProvider(provider) + hash, err := proptools.CalculateHash(provider) if err != nil { errors = append(errors, fmt.Errorf("provider %q on module %q was modified after being set, and no longer hashable afterwards: %s", providerRegistry[i].typ, m.Name(), err.Error())) continue diff --git a/proptools/hash_provider.go b/proptools/hash_provider.go index 6205bc7..c75bb7f 100644 --- a/proptools/hash_provider.go +++ b/proptools/hash_provider.go @@ -18,32 +18,31 @@ import ( "cmp" "encoding/binary" "fmt" - "hash/maphash" + "hash" + "hash/fnv" "math" "reflect" "sort" + "unsafe" ) -var seed maphash.Seed = maphash.MakeSeed() - // byte to insert between elements of lists, fields of structs/maps, etc in order // to try and make sure the hash is different when values are moved around between // elements. 36 is arbitrary, but it's the ascii code for a record separator var recordSeparator []byte = []byte{36} -func HashProvider(provider interface{}) (uint64, error) { - hasher := maphash.Hash{} - hasher.SetSeed(seed) +func CalculateHash(value interface{}) (uint64, error) { + hasher := fnv.New64() ptrs := make(map[uintptr]bool) - v := reflect.ValueOf(provider) + v := reflect.ValueOf(value) var err error if v.IsValid() { - err = hashProviderInternal(&hasher, v, ptrs) + err = calculateHashInternal(hasher, v, ptrs) } return hasher.Sum64(), err } -func hashProviderInternal(hasher *maphash.Hash, v reflect.Value, ptrs map[uintptr]bool) error { +func calculateHashInternal(hasher hash.Hash64, v reflect.Value, ptrs map[uintptr]bool) error { var int64Array [8]byte int64Buf := int64Array[:] binary.LittleEndian.PutUint64(int64Buf, uint64(v.Kind())) @@ -55,7 +54,7 @@ func hashProviderInternal(hasher *maphash.Hash, v reflect.Value, ptrs map[uintpt hasher.Write(int64Buf) for i := 0; i < v.NumField(); i++ { hasher.Write(recordSeparator) - err := hashProviderInternal(hasher, v.Field(i), ptrs) + err := calculateHashInternal(hasher, v.Field(i), ptrs) if err != nil { return fmt.Errorf("in field %s: %s", v.Type().Field(i).Name, err.Error()) } @@ -77,12 +76,12 @@ func hashProviderInternal(hasher *maphash.Hash, v reflect.Value, ptrs map[uintpt }) for i := 0; i < v.Len(); i++ { hasher.Write(recordSeparator) - err := hashProviderInternal(hasher, keys[indexes[i]], ptrs) + err := calculateHashInternal(hasher, keys[indexes[i]], ptrs) if err != nil { return fmt.Errorf("in map: %s", err.Error()) } hasher.Write(recordSeparator) - err = hashProviderInternal(hasher, keys[indexes[i]], ptrs) + err = calculateHashInternal(hasher, keys[indexes[i]], ptrs) if err != nil { return fmt.Errorf("in map: %s", err.Error()) } @@ -92,7 +91,7 @@ func hashProviderInternal(hasher *maphash.Hash, v reflect.Value, ptrs map[uintpt hasher.Write(int64Buf) for i := 0; i < v.Len(); i++ { hasher.Write(recordSeparator) - err := hashProviderInternal(hasher, v.Index(i), ptrs) + err := calculateHashInternal(hasher, v.Index(i), ptrs) if err != nil { return fmt.Errorf("in %s at index %d: %s", v.Kind().String(), i, err.Error()) } @@ -103,15 +102,16 @@ func hashProviderInternal(hasher *maphash.Hash, v reflect.Value, ptrs map[uintpt hasher.Write(int64Buf[:1]) return nil } - addr := v.Pointer() - binary.LittleEndian.PutUint64(int64Buf, uint64(addr)) + // Hardcoded value to indicate it is a pointer + binary.LittleEndian.PutUint64(int64Buf, uint64(0x55)) hasher.Write(int64Buf) + addr := v.Pointer() if _, ok := ptrs[addr]; ok { // We could make this an error if we want to disallow pointer cycles in the future return nil } ptrs[addr] = true - err := hashProviderInternal(hasher, v.Elem(), ptrs) + err := calculateHashInternal(hasher, v.Elem(), ptrs) if err != nil { return fmt.Errorf("in pointer: %s", err.Error()) } @@ -122,13 +122,20 @@ func hashProviderInternal(hasher *maphash.Hash, v reflect.Value, ptrs map[uintpt } else { // The only way get the pointer out of an interface to hash it or check for cycles // would be InterfaceData(), but that's deprecated and seems like it has undefined behavior. - err := hashProviderInternal(hasher, v.Elem(), ptrs) + err := calculateHashInternal(hasher, v.Elem(), ptrs) if err != nil { return fmt.Errorf("in interface: %s", err.Error()) } } case reflect.String: - hasher.WriteString(v.String()) + strLen := len(v.String()) + if strLen == 0 { + // unsafe.StringData is unspecified in this case + int64Buf[0] = 0 + hasher.Write(int64Buf[:1]) + return nil + } + hasher.Write(unsafe.Slice(unsafe.StringData(v.String()), strLen)) case reflect.Bool: if v.Bool() { int64Buf[0] = 1 @@ -146,7 +153,7 @@ func hashProviderInternal(hasher *maphash.Hash, v reflect.Value, ptrs map[uintpt binary.LittleEndian.PutUint64(int64Buf, math.Float64bits(v.Float())) hasher.Write(int64Buf) default: - return fmt.Errorf("providers may only contain primitives, strings, arrays, slices, structs, maps, and pointers, found: %s", v.Kind().String()) + return fmt.Errorf("data may only contain primitives, strings, arrays, slices, structs, maps, and pointers, found: %s", v.Kind().String()) } return nil } diff --git a/proptools/hash_provider_test.go b/proptools/hash_provider_test.go index 1c97aec..338c6e4 100644 --- a/proptools/hash_provider_test.go +++ b/proptools/hash_provider_test.go @@ -5,9 +5,9 @@ import ( "testing" ) -func mustHash(t *testing.T, provider interface{}) uint64 { +func mustHash(t *testing.T, data interface{}) uint64 { t.Helper() - result, err := HashProvider(provider) + result, err := CalculateHash(data) if err != nil { t.Fatal(err) } @@ -15,11 +15,11 @@ func mustHash(t *testing.T, provider interface{}) uint64 { } func TestHashingMapGetsSameResults(t *testing.T) { - provider := map[string]string{"foo": "bar", "baz": "qux"} - first := mustHash(t, provider) - second := mustHash(t, provider) - third := mustHash(t, provider) - fourth := mustHash(t, provider) + data := map[string]string{"foo": "bar", "baz": "qux"} + first := mustHash(t, data) + second := mustHash(t, data) + third := mustHash(t, data) + fourth := mustHash(t, data) if first != second || second != third || third != fourth { t.Fatal("Did not get the same result every time for a map") } @@ -27,29 +27,29 @@ func TestHashingMapGetsSameResults(t *testing.T) { func TestHashingNonSerializableTypesFails(t *testing.T) { testCases := []struct { - name string - provider interface{} + name string + data interface{} }{ { - name: "function pointer", - provider: []func(){nil}, + name: "function pointer", + data: []func(){nil}, }, { - name: "channel", - provider: []chan int{make(chan int)}, + name: "channel", + data: []chan int{make(chan int)}, }, { - name: "list with non-serializable type", - provider: []interface{}{"foo", make(chan int)}, + name: "list with non-serializable type", + data: []interface{}{"foo", make(chan int)}, }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - _, err := HashProvider(testCase) + _, err := CalculateHash(testCase) if err == nil { t.Fatal("Expected hashing error but didn't get one") } - expected := "providers may only contain primitives, strings, arrays, slices, structs, maps, and pointers" + expected := "data may only contain primitives, strings, arrays, slices, structs, maps, and pointers" if !strings.Contains(err.Error(), expected) { t.Fatalf("Expected %q, got %q", expected, err.Error()) } @@ -59,32 +59,32 @@ func TestHashingNonSerializableTypesFails(t *testing.T) { func TestHashSuccessful(t *testing.T) { testCases := []struct { - name string - provider interface{} + name string + data interface{} }{ { - name: "int", - provider: 5, + name: "int", + data: 5, }, { - name: "string", - provider: "foo", + name: "string", + data: "foo", }, { - name: "*string", - provider: StringPtr("foo"), + name: "*string", + data: StringPtr("foo"), }, { - name: "array", - provider: [3]string{"foo", "bar", "baz"}, + name: "array", + data: [3]string{"foo", "bar", "baz"}, }, { - name: "slice", - provider: []string{"foo", "bar", "baz"}, + name: "slice", + data: []string{"foo", "bar", "baz"}, }, { name: "struct", - provider: struct { + data: struct { foo string bar int }{ @@ -94,19 +94,35 @@ func TestHashSuccessful(t *testing.T) { }, { name: "map", - provider: map[string]int{ + data: map[string]int{ "foo": 3, "bar": 4, }, }, { - name: "list of interfaces with different types", - provider: []interface{}{"foo", 3, []string{"bar", "baz"}}, + name: "list of interfaces with different types", + data: []interface{}{"foo", 3, []string{"bar", "baz"}}, }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - mustHash(t, testCase.provider) + mustHash(t, testCase.data) }) } } + +func TestHashingDereferencePointers(t *testing.T) { + str1 := "this is a hash test for pointers" + str2 := "this is a hash test for pointers" + data := []struct { + content *string + }{ + {content: &str1}, + {content: &str2}, + } + first := mustHash(t, data[0]) + second := mustHash(t, data[1]) + if first != second { + t.Fatal("Got different results for the same string") + } +} diff --git a/provider.go b/provider.go index 297861e..b2e0876 100644 --- a/provider.go +++ b/provider.go @@ -158,7 +158,7 @@ func (c *Context) setProvider(m *moduleInfo, provider *providerKey, value any) { if m.providerInitialValueHashes == nil { m.providerInitialValueHashes = make([]uint64, len(providerRegistry)) } - hash, err := proptools.HashProvider(value) + hash, err := proptools.CalculateHash(value) if err != nil { panic(fmt.Sprintf("Can't set value of provider %s: %s", provider.typ, err.Error())) }