diff --git a/Android.bp b/Android.bp index 246207a..d0a16ad 100644 --- a/Android.bp +++ b/Android.bp @@ -52,7 +52,7 @@ bootstrap_go_package { "provider.go", "scope.go", "singleton_ctx.go", - "source_file_provider.go" + "source_file_provider.go", ], testSrcs: [ "context_test.go", @@ -120,6 +120,7 @@ bootstrap_go_package { "proptools/escape.go", "proptools/extend.go", "proptools/filter.go", + "proptools/hash_provider.go", "proptools/proptools.go", "proptools/tag.go", "proptools/typeequal.go", @@ -130,6 +131,7 @@ bootstrap_go_package { "proptools/escape_test.go", "proptools/extend_test.go", "proptools/filter_test.go", + "proptools/hash_provider_test.go", "proptools/tag_test.go", "proptools/typeequal_test.go", "proptools/unpack_test.go", diff --git a/bootstrap/command.go b/bootstrap/command.go index bc1d32d..d7dcc27 100644 --- a/bootstrap/command.go +++ b/bootstrap/command.go @@ -25,6 +25,7 @@ import ( "runtime/debug" "runtime/pprof" "runtime/trace" + "strings" "github.com/google/blueprint" ) @@ -134,6 +135,15 @@ func RunBlueprint(args Args, stopBefore StopBefore, ctx *blueprint.Context, conf return ninjaDeps, nil } + providersValidationChan := make(chan []error, 1) + if ctx.GetVerifyProvidersAreUnchanged() { + go func() { + providersValidationChan <- ctx.VerifyProvidersWereUnchanged() + }() + } else { + providersValidationChan <- nil + } + const outFilePermissions = 0666 var out blueprint.StringWriterWriter var f *os.File @@ -172,6 +182,18 @@ func RunBlueprint(args Args, stopBefore StopBefore, ctx *blueprint.Context, conf } } + providerValidationErrors := <-providersValidationChan + if providerValidationErrors != nil { + var sb strings.Builder + for i, err := range providerValidationErrors { + if i != 0 { + sb.WriteString("\n") + } + sb.WriteString(err.Error()) + } + return nil, errors.New(sb.String()) + } + if args.Memprofile != "" { f, err := os.Create(joinPath(ctx.SrcDir(), args.Memprofile)) if err != nil { diff --git a/context.go b/context.go index 28f0cc5..4130700 100644 --- a/context.go +++ b/context.go @@ -101,6 +101,8 @@ type Context struct { // set by SetAllowMissingDependencies allowMissingDependencies bool + verifyProvidersAreUnchanged bool + // set during PrepareBuildActions nameTracker *nameTracker liveGlobals *liveTracker @@ -351,7 +353,8 @@ type moduleInfo struct { // set during PrepareBuildActions actionDefs localBuildActions - providers []interface{} + providers []interface{} + providerInitialValueHashes []uint64 startedMutator *mutatorInfo finishedMutator *mutatorInfo @@ -463,20 +466,21 @@ type mutatorInfo struct { func newContext() *Context { eventHandler := metrics.EventHandler{} return &Context{ - Context: context.Background(), - EventHandler: &eventHandler, - moduleFactories: make(map[string]ModuleFactory), - nameInterface: NewSimpleNameInterface(), - moduleInfo: make(map[Module]*moduleInfo), - globs: make(map[globKey]pathtools.GlobResult), - fs: pathtools.OsFs, - finishedMutators: make(map[*mutatorInfo]bool), - includeTags: &IncludeTags{}, - sourceRootDirs: &SourceRootDirs{}, - outDir: nil, - requiredNinjaMajor: 1, - requiredNinjaMinor: 7, - requiredNinjaMicro: 0, + Context: context.Background(), + EventHandler: &eventHandler, + moduleFactories: make(map[string]ModuleFactory), + nameInterface: NewSimpleNameInterface(), + moduleInfo: make(map[Module]*moduleInfo), + globs: make(map[globKey]pathtools.GlobResult), + fs: pathtools.OsFs, + finishedMutators: make(map[*mutatorInfo]bool), + includeTags: &IncludeTags{}, + sourceRootDirs: &SourceRootDirs{}, + outDir: nil, + requiredNinjaMajor: 1, + requiredNinjaMinor: 7, + requiredNinjaMicro: 0, + verifyProvidersAreUnchanged: true, } } @@ -952,6 +956,18 @@ func (c *Context) SetAllowMissingDependencies(allowMissingDependencies bool) { c.allowMissingDependencies = allowMissingDependencies } +// SetVerifyProvidersAreUnchanged makes blueprint hash all providers immediately +// after SetProvider() is called, and then hash them again after the build finished. +// If the hashes change, it's an error. Providers are supposed to be immutable, but +// we don't have any more direct way to enforce that in go. +func (c *Context) SetVerifyProvidersAreUnchanged(verifyProvidersAreUnchanged bool) { + c.verifyProvidersAreUnchanged = verifyProvidersAreUnchanged +} + +func (c *Context) GetVerifyProvidersAreUnchanged() bool { + return c.verifyProvidersAreUnchanged +} + func (c *Context) SetModuleListFile(listFile string) { c.moduleListFile = listFile } @@ -1730,6 +1746,7 @@ func (c *Context) createVariations(origModule *moduleInfo, mutatorName string, newModule.variant = newVariant(origModule, mutatorName, variationName, local) newModule.properties = newProperties newModule.providers = append([]interface{}(nil), origModule.providers...) + newModule.providerInitialValueHashes = append([]uint64(nil), origModule.providerInitialValueHashes...) newModules = append(newModules, newModule) @@ -4207,6 +4224,34 @@ func (c *Context) SingletonName(singleton Singleton) string { return "" } +// Checks that the hashes of all the providers match the hashes from when they were first set. +// Does nothing on success, returns a list of errors otherwise. It's recommended to run this +// in a goroutine. +func (c *Context) VerifyProvidersWereUnchanged() []error { + if !c.buildActionsReady { + return []error{ErrBuildActionsNotReady} + } + var errors []error + for _, m := range c.modulesSorted { + for i, provider := range m.providers { + if provider != nil { + hash, err := proptools.HashProvider(provider) + if err != nil { + errors = append(errors, fmt.Errorf("provider %q on module %q was modified after being set, and no longer hashable afterwards: %s", providerRegistry[i].typ, m.Name(), err.Error())) + continue + } + if provider != nil && m.providerInitialValueHashes[i] != hash { + errors = append(errors, fmt.Errorf("provider %q on module %q was modified after being set", providerRegistry[i].typ, m.Name())) + } + } else if m.providerInitialValueHashes[i] != 0 { + // This should be unreachable, because in setProvider we check if the provider has already been set. + errors = append(errors, fmt.Errorf("provider %q on module %q was unset somehow, this is an internal error", providerRegistry[i].typ, m.Name())) + } + } + } + return errors +} + // WriteBuildFile writes the Ninja manifest text for the generated build // actions to w. If this is called before PrepareBuildActions successfully // completes then ErrBuildActionsNotReady is returned. diff --git a/proptools/hash_provider.go b/proptools/hash_provider.go new file mode 100644 index 0000000..b52a10e --- /dev/null +++ b/proptools/hash_provider.go @@ -0,0 +1,205 @@ +// Copyright 2023 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proptools + +import ( + "cmp" + "encoding/binary" + "fmt" + "hash/maphash" + "io" + "math" + "reflect" + "sort" +) + +var seed maphash.Seed = maphash.MakeSeed() + +// byte to insert between elements of lists, fields of structs/maps, etc in order +// to try and make sure the hash is different when values are moved around between +// elements. 36 is arbitrary, but it's the ascii code for a record separator +var recordSeparator []byte = []byte{36} + +func HashProvider(provider interface{}) (uint64, error) { + hasher := maphash.Hash{} + hasher.SetSeed(seed) + ptrs := make(map[uintptr]bool) + v := reflect.ValueOf(provider) + var err error + if v.IsValid() { + err = hashProviderInternal(&hasher, v, ptrs) + } + return hasher.Sum64(), err +} + +func hashProviderInternal(hasher io.Writer, v reflect.Value, ptrs map[uintptr]bool) error { + var int64Array [8]byte + int64Buf := int64Array[:] + binary.LittleEndian.PutUint64(int64Buf, uint64(v.Kind())) + hasher.Write(int64Buf) + v.IsValid() + switch v.Kind() { + case reflect.Struct: + binary.LittleEndian.PutUint64(int64Buf, uint64(v.NumField())) + hasher.Write(int64Buf) + for i := 0; i < v.NumField(); i++ { + hasher.Write(recordSeparator) + err := hashProviderInternal(hasher, v.Field(i), ptrs) + if err != nil { + return fmt.Errorf("in field %s: %s", v.Type().Field(i).Name, err.Error()) + } + } + case reflect.Map: + binary.LittleEndian.PutUint64(int64Buf, uint64(v.Len())) + hasher.Write(int64Buf) + indexes := make([]int, v.Len()) + keys := make([]reflect.Value, v.Len()) + values := make([]reflect.Value, v.Len()) + iter := v.MapRange() + for i := 0; iter.Next(); i++ { + indexes[i] = i + keys[i] = iter.Key() + values[i] = iter.Value() + } + sort.SliceStable(indexes, func(i, j int) bool { + return compare_values(keys[indexes[i]], keys[indexes[j]]) < 0 + }) + for i := 0; i < v.Len(); i++ { + hasher.Write(recordSeparator) + err := hashProviderInternal(hasher, keys[indexes[i]], ptrs) + if err != nil { + return fmt.Errorf("in map: %s", err.Error()) + } + hasher.Write(recordSeparator) + err = hashProviderInternal(hasher, keys[indexes[i]], ptrs) + if err != nil { + return fmt.Errorf("in map: %s", err.Error()) + } + } + case reflect.Slice, reflect.Array: + binary.LittleEndian.PutUint64(int64Buf, uint64(v.Len())) + hasher.Write(int64Buf) + for i := 0; i < v.Len(); i++ { + hasher.Write(recordSeparator) + err := hashProviderInternal(hasher, v.Index(i), ptrs) + if err != nil { + return fmt.Errorf("in %s at index %d: %s", v.Kind().String(), i, err.Error()) + } + } + case reflect.Pointer: + if v.IsNil() { + int64Buf[0] = 0 + hasher.Write(int64Buf[:1]) + return nil + } + addr := v.Pointer() + binary.LittleEndian.PutUint64(int64Buf, uint64(addr)) + hasher.Write(int64Buf) + if _, ok := ptrs[addr]; ok { + // We could make this an error if we want to disallow pointer cycles in the future + return nil + } + ptrs[addr] = true + err := hashProviderInternal(hasher, v.Elem(), ptrs) + if err != nil { + return fmt.Errorf("in pointer: %s", err.Error()) + } + case reflect.Interface: + if v.IsNil() { + int64Buf[0] = 0 + hasher.Write(int64Buf[:1]) + } else { + // The only way get the pointer out of an interface to hash it or check for cycles + // would be InterfaceData(), but that's deprecated and seems like it has undefined behavior. + err := hashProviderInternal(hasher, v.Elem(), ptrs) + if err != nil { + return fmt.Errorf("in interface: %s", err.Error()) + } + } + case reflect.String: + hasher.Write([]byte(v.String())) + case reflect.Bool: + if v.Bool() { + int64Buf[0] = 1 + } else { + int64Buf[0] = 0 + } + hasher.Write(int64Buf[:1]) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + binary.LittleEndian.PutUint64(int64Buf, v.Uint()) + hasher.Write(int64Buf) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + binary.LittleEndian.PutUint64(int64Buf, uint64(v.Int())) + hasher.Write(int64Buf) + case reflect.Float32, reflect.Float64: + binary.LittleEndian.PutUint64(int64Buf, math.Float64bits(v.Float())) + hasher.Write(int64Buf) + default: + return fmt.Errorf("providers may only contain primitives, strings, arrays, slices, structs, maps, and pointers, found: %s", v.Kind().String()) + } + return nil +} + +func compare_values(x, y reflect.Value) int { + if x.Type() != y.Type() { + panic("Expected equal types") + } + + switch x.Kind() { + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return cmp.Compare(x.Uint(), y.Uint()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return cmp.Compare(x.Int(), y.Int()) + case reflect.Float32, reflect.Float64: + return cmp.Compare(x.Float(), y.Float()) + case reflect.String: + return cmp.Compare(x.String(), y.String()) + case reflect.Bool: + if x.Bool() == y.Bool() { + return 0 + } else if x.Bool() { + return 1 + } else { + return -1 + } + case reflect.Pointer: + return cmp.Compare(x.Pointer(), y.Pointer()) + case reflect.Array: + for i := 0; i < x.Len(); i++ { + if result := compare_values(x.Index(i), y.Index(i)); result != 0 { + return result + } + } + return 0 + case reflect.Struct: + for i := 0; i < x.NumField(); i++ { + if result := compare_values(x.Field(i), y.Field(i)); result != 0 { + return result + } + } + return 0 + case reflect.Interface: + if x.IsNil() && y.IsNil() { + return 0 + } else if x.IsNil() { + return 1 + } else if y.IsNil() { + return -1 + } + return compare_values(x.Elem(), y.Elem()) + default: + panic(fmt.Sprintf("Could not compare types %s and %s", x.Type().String(), y.Type().String())) + } +} diff --git a/proptools/hash_provider_test.go b/proptools/hash_provider_test.go new file mode 100644 index 0000000..1c97aec --- /dev/null +++ b/proptools/hash_provider_test.go @@ -0,0 +1,112 @@ +package proptools + +import ( + "strings" + "testing" +) + +func mustHash(t *testing.T, provider interface{}) uint64 { + t.Helper() + result, err := HashProvider(provider) + if err != nil { + t.Fatal(err) + } + return result +} + +func TestHashingMapGetsSameResults(t *testing.T) { + provider := map[string]string{"foo": "bar", "baz": "qux"} + first := mustHash(t, provider) + second := mustHash(t, provider) + third := mustHash(t, provider) + fourth := mustHash(t, provider) + if first != second || second != third || third != fourth { + t.Fatal("Did not get the same result every time for a map") + } +} + +func TestHashingNonSerializableTypesFails(t *testing.T) { + testCases := []struct { + name string + provider interface{} + }{ + { + name: "function pointer", + provider: []func(){nil}, + }, + { + name: "channel", + provider: []chan int{make(chan int)}, + }, + { + name: "list with non-serializable type", + provider: []interface{}{"foo", make(chan int)}, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + _, err := HashProvider(testCase) + if err == nil { + t.Fatal("Expected hashing error but didn't get one") + } + expected := "providers may only contain primitives, strings, arrays, slices, structs, maps, and pointers" + if !strings.Contains(err.Error(), expected) { + t.Fatalf("Expected %q, got %q", expected, err.Error()) + } + }) + } +} + +func TestHashSuccessful(t *testing.T) { + testCases := []struct { + name string + provider interface{} + }{ + { + name: "int", + provider: 5, + }, + { + name: "string", + provider: "foo", + }, + { + name: "*string", + provider: StringPtr("foo"), + }, + { + name: "array", + provider: [3]string{"foo", "bar", "baz"}, + }, + { + name: "slice", + provider: []string{"foo", "bar", "baz"}, + }, + { + name: "struct", + provider: struct { + foo string + bar int + }{ + foo: "foo", + bar: 3, + }, + }, + { + name: "map", + provider: map[string]int{ + "foo": 3, + "bar": 4, + }, + }, + { + name: "list of interfaces with different types", + provider: []interface{}{"foo", 3, []string{"bar", "baz"}}, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + mustHash(t, testCase.provider) + }) + } +} diff --git a/provider.go b/provider.go index 48527b1..297861e 100644 --- a/provider.go +++ b/provider.go @@ -16,6 +16,8 @@ package blueprint import ( "fmt" + + "github.com/google/blueprint/proptools" ) // This file implements Providers, modelled after Bazel @@ -151,6 +153,17 @@ func (c *Context) setProvider(m *moduleInfo, provider *providerKey, value any) { } m.providers[provider.id] = value + + if c.verifyProvidersAreUnchanged { + if m.providerInitialValueHashes == nil { + m.providerInitialValueHashes = make([]uint64, len(providerRegistry)) + } + hash, err := proptools.HashProvider(value) + if err != nil { + panic(fmt.Sprintf("Can't set value of provider %s: %s", provider.typ, err.Error())) + } + m.providerInitialValueHashes[provider.id] = hash + } } // provider returns the value, if any, for a given provider for a module. Verifies that it is