Enforce that providers are not changed
When setProvider() is called, hash the provider and store the hash in the module. Then after the build is done, hash all the providers again and compare the hashes. It's an error if they don't match. Also add a flag to control it in case this check gets slow as we convert more things to providers. However right now it's fast (unnoticable in terms of whole seconds) so just have the flag always enabled. Bug: 322069292 Test: m nothing Change-Id: Ie4e806a6a9f20542ffcc7439eef376d3fb6a98ca
This commit is contained in:
parent
1b3fe6bd2c
commit
7add62142d
6 changed files with 415 additions and 16 deletions
|
@ -52,7 +52,7 @@ bootstrap_go_package {
|
|||
"provider.go",
|
||||
"scope.go",
|
||||
"singleton_ctx.go",
|
||||
"source_file_provider.go"
|
||||
"source_file_provider.go",
|
||||
],
|
||||
testSrcs: [
|
||||
"context_test.go",
|
||||
|
@ -120,6 +120,7 @@ bootstrap_go_package {
|
|||
"proptools/escape.go",
|
||||
"proptools/extend.go",
|
||||
"proptools/filter.go",
|
||||
"proptools/hash_provider.go",
|
||||
"proptools/proptools.go",
|
||||
"proptools/tag.go",
|
||||
"proptools/typeequal.go",
|
||||
|
@ -130,6 +131,7 @@ bootstrap_go_package {
|
|||
"proptools/escape_test.go",
|
||||
"proptools/extend_test.go",
|
||||
"proptools/filter_test.go",
|
||||
"proptools/hash_provider_test.go",
|
||||
"proptools/tag_test.go",
|
||||
"proptools/typeequal_test.go",
|
||||
"proptools/unpack_test.go",
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"runtime/debug"
|
||||
"runtime/pprof"
|
||||
"runtime/trace"
|
||||
"strings"
|
||||
|
||||
"github.com/google/blueprint"
|
||||
)
|
||||
|
@ -134,6 +135,15 @@ func RunBlueprint(args Args, stopBefore StopBefore, ctx *blueprint.Context, conf
|
|||
return ninjaDeps, nil
|
||||
}
|
||||
|
||||
providersValidationChan := make(chan []error, 1)
|
||||
if ctx.GetVerifyProvidersAreUnchanged() {
|
||||
go func() {
|
||||
providersValidationChan <- ctx.VerifyProvidersWereUnchanged()
|
||||
}()
|
||||
} else {
|
||||
providersValidationChan <- nil
|
||||
}
|
||||
|
||||
const outFilePermissions = 0666
|
||||
var out blueprint.StringWriterWriter
|
||||
var f *os.File
|
||||
|
@ -172,6 +182,18 @@ func RunBlueprint(args Args, stopBefore StopBefore, ctx *blueprint.Context, conf
|
|||
}
|
||||
}
|
||||
|
||||
providerValidationErrors := <-providersValidationChan
|
||||
if providerValidationErrors != nil {
|
||||
var sb strings.Builder
|
||||
for i, err := range providerValidationErrors {
|
||||
if i != 0 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString(err.Error())
|
||||
}
|
||||
return nil, errors.New(sb.String())
|
||||
}
|
||||
|
||||
if args.Memprofile != "" {
|
||||
f, err := os.Create(joinPath(ctx.SrcDir(), args.Memprofile))
|
||||
if err != nil {
|
||||
|
|
75
context.go
75
context.go
|
@ -101,6 +101,8 @@ type Context struct {
|
|||
// set by SetAllowMissingDependencies
|
||||
allowMissingDependencies bool
|
||||
|
||||
verifyProvidersAreUnchanged bool
|
||||
|
||||
// set during PrepareBuildActions
|
||||
nameTracker *nameTracker
|
||||
liveGlobals *liveTracker
|
||||
|
@ -351,7 +353,8 @@ type moduleInfo struct {
|
|||
// set during PrepareBuildActions
|
||||
actionDefs localBuildActions
|
||||
|
||||
providers []interface{}
|
||||
providers []interface{}
|
||||
providerInitialValueHashes []uint64
|
||||
|
||||
startedMutator *mutatorInfo
|
||||
finishedMutator *mutatorInfo
|
||||
|
@ -463,20 +466,21 @@ type mutatorInfo struct {
|
|||
func newContext() *Context {
|
||||
eventHandler := metrics.EventHandler{}
|
||||
return &Context{
|
||||
Context: context.Background(),
|
||||
EventHandler: &eventHandler,
|
||||
moduleFactories: make(map[string]ModuleFactory),
|
||||
nameInterface: NewSimpleNameInterface(),
|
||||
moduleInfo: make(map[Module]*moduleInfo),
|
||||
globs: make(map[globKey]pathtools.GlobResult),
|
||||
fs: pathtools.OsFs,
|
||||
finishedMutators: make(map[*mutatorInfo]bool),
|
||||
includeTags: &IncludeTags{},
|
||||
sourceRootDirs: &SourceRootDirs{},
|
||||
outDir: nil,
|
||||
requiredNinjaMajor: 1,
|
||||
requiredNinjaMinor: 7,
|
||||
requiredNinjaMicro: 0,
|
||||
Context: context.Background(),
|
||||
EventHandler: &eventHandler,
|
||||
moduleFactories: make(map[string]ModuleFactory),
|
||||
nameInterface: NewSimpleNameInterface(),
|
||||
moduleInfo: make(map[Module]*moduleInfo),
|
||||
globs: make(map[globKey]pathtools.GlobResult),
|
||||
fs: pathtools.OsFs,
|
||||
finishedMutators: make(map[*mutatorInfo]bool),
|
||||
includeTags: &IncludeTags{},
|
||||
sourceRootDirs: &SourceRootDirs{},
|
||||
outDir: nil,
|
||||
requiredNinjaMajor: 1,
|
||||
requiredNinjaMinor: 7,
|
||||
requiredNinjaMicro: 0,
|
||||
verifyProvidersAreUnchanged: true,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -952,6 +956,18 @@ func (c *Context) SetAllowMissingDependencies(allowMissingDependencies bool) {
|
|||
c.allowMissingDependencies = allowMissingDependencies
|
||||
}
|
||||
|
||||
// SetVerifyProvidersAreUnchanged makes blueprint hash all providers immediately
|
||||
// after SetProvider() is called, and then hash them again after the build finished.
|
||||
// If the hashes change, it's an error. Providers are supposed to be immutable, but
|
||||
// we don't have any more direct way to enforce that in go.
|
||||
func (c *Context) SetVerifyProvidersAreUnchanged(verifyProvidersAreUnchanged bool) {
|
||||
c.verifyProvidersAreUnchanged = verifyProvidersAreUnchanged
|
||||
}
|
||||
|
||||
func (c *Context) GetVerifyProvidersAreUnchanged() bool {
|
||||
return c.verifyProvidersAreUnchanged
|
||||
}
|
||||
|
||||
func (c *Context) SetModuleListFile(listFile string) {
|
||||
c.moduleListFile = listFile
|
||||
}
|
||||
|
@ -1730,6 +1746,7 @@ func (c *Context) createVariations(origModule *moduleInfo, mutatorName string,
|
|||
newModule.variant = newVariant(origModule, mutatorName, variationName, local)
|
||||
newModule.properties = newProperties
|
||||
newModule.providers = append([]interface{}(nil), origModule.providers...)
|
||||
newModule.providerInitialValueHashes = append([]uint64(nil), origModule.providerInitialValueHashes...)
|
||||
|
||||
newModules = append(newModules, newModule)
|
||||
|
||||
|
@ -4207,6 +4224,34 @@ func (c *Context) SingletonName(singleton Singleton) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
// Checks that the hashes of all the providers match the hashes from when they were first set.
|
||||
// Does nothing on success, returns a list of errors otherwise. It's recommended to run this
|
||||
// in a goroutine.
|
||||
func (c *Context) VerifyProvidersWereUnchanged() []error {
|
||||
if !c.buildActionsReady {
|
||||
return []error{ErrBuildActionsNotReady}
|
||||
}
|
||||
var errors []error
|
||||
for _, m := range c.modulesSorted {
|
||||
for i, provider := range m.providers {
|
||||
if provider != nil {
|
||||
hash, err := proptools.HashProvider(provider)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Errorf("provider %q on module %q was modified after being set, and no longer hashable afterwards: %s", providerRegistry[i].typ, m.Name(), err.Error()))
|
||||
continue
|
||||
}
|
||||
if provider != nil && m.providerInitialValueHashes[i] != hash {
|
||||
errors = append(errors, fmt.Errorf("provider %q on module %q was modified after being set", providerRegistry[i].typ, m.Name()))
|
||||
}
|
||||
} else if m.providerInitialValueHashes[i] != 0 {
|
||||
// This should be unreachable, because in setProvider we check if the provider has already been set.
|
||||
errors = append(errors, fmt.Errorf("provider %q on module %q was unset somehow, this is an internal error", providerRegistry[i].typ, m.Name()))
|
||||
}
|
||||
}
|
||||
}
|
||||
return errors
|
||||
}
|
||||
|
||||
// WriteBuildFile writes the Ninja manifest text for the generated build
|
||||
// actions to w. If this is called before PrepareBuildActions successfully
|
||||
// completes then ErrBuildActionsNotReady is returned.
|
||||
|
|
205
proptools/hash_provider.go
Normal file
205
proptools/hash_provider.go
Normal file
|
@ -0,0 +1,205 @@
|
|||
// Copyright 2023 Google Inc. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proptools
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"io"
|
||||
"math"
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
var seed maphash.Seed = maphash.MakeSeed()
|
||||
|
||||
// byte to insert between elements of lists, fields of structs/maps, etc in order
|
||||
// to try and make sure the hash is different when values are moved around between
|
||||
// elements. 36 is arbitrary, but it's the ascii code for a record separator
|
||||
var recordSeparator []byte = []byte{36}
|
||||
|
||||
func HashProvider(provider interface{}) (uint64, error) {
|
||||
hasher := maphash.Hash{}
|
||||
hasher.SetSeed(seed)
|
||||
ptrs := make(map[uintptr]bool)
|
||||
v := reflect.ValueOf(provider)
|
||||
var err error
|
||||
if v.IsValid() {
|
||||
err = hashProviderInternal(&hasher, v, ptrs)
|
||||
}
|
||||
return hasher.Sum64(), err
|
||||
}
|
||||
|
||||
func hashProviderInternal(hasher io.Writer, v reflect.Value, ptrs map[uintptr]bool) error {
|
||||
var int64Array [8]byte
|
||||
int64Buf := int64Array[:]
|
||||
binary.LittleEndian.PutUint64(int64Buf, uint64(v.Kind()))
|
||||
hasher.Write(int64Buf)
|
||||
v.IsValid()
|
||||
switch v.Kind() {
|
||||
case reflect.Struct:
|
||||
binary.LittleEndian.PutUint64(int64Buf, uint64(v.NumField()))
|
||||
hasher.Write(int64Buf)
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
hasher.Write(recordSeparator)
|
||||
err := hashProviderInternal(hasher, v.Field(i), ptrs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("in field %s: %s", v.Type().Field(i).Name, err.Error())
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
binary.LittleEndian.PutUint64(int64Buf, uint64(v.Len()))
|
||||
hasher.Write(int64Buf)
|
||||
indexes := make([]int, v.Len())
|
||||
keys := make([]reflect.Value, v.Len())
|
||||
values := make([]reflect.Value, v.Len())
|
||||
iter := v.MapRange()
|
||||
for i := 0; iter.Next(); i++ {
|
||||
indexes[i] = i
|
||||
keys[i] = iter.Key()
|
||||
values[i] = iter.Value()
|
||||
}
|
||||
sort.SliceStable(indexes, func(i, j int) bool {
|
||||
return compare_values(keys[indexes[i]], keys[indexes[j]]) < 0
|
||||
})
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
hasher.Write(recordSeparator)
|
||||
err := hashProviderInternal(hasher, keys[indexes[i]], ptrs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("in map: %s", err.Error())
|
||||
}
|
||||
hasher.Write(recordSeparator)
|
||||
err = hashProviderInternal(hasher, keys[indexes[i]], ptrs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("in map: %s", err.Error())
|
||||
}
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
binary.LittleEndian.PutUint64(int64Buf, uint64(v.Len()))
|
||||
hasher.Write(int64Buf)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
hasher.Write(recordSeparator)
|
||||
err := hashProviderInternal(hasher, v.Index(i), ptrs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("in %s at index %d: %s", v.Kind().String(), i, err.Error())
|
||||
}
|
||||
}
|
||||
case reflect.Pointer:
|
||||
if v.IsNil() {
|
||||
int64Buf[0] = 0
|
||||
hasher.Write(int64Buf[:1])
|
||||
return nil
|
||||
}
|
||||
addr := v.Pointer()
|
||||
binary.LittleEndian.PutUint64(int64Buf, uint64(addr))
|
||||
hasher.Write(int64Buf)
|
||||
if _, ok := ptrs[addr]; ok {
|
||||
// We could make this an error if we want to disallow pointer cycles in the future
|
||||
return nil
|
||||
}
|
||||
ptrs[addr] = true
|
||||
err := hashProviderInternal(hasher, v.Elem(), ptrs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("in pointer: %s", err.Error())
|
||||
}
|
||||
case reflect.Interface:
|
||||
if v.IsNil() {
|
||||
int64Buf[0] = 0
|
||||
hasher.Write(int64Buf[:1])
|
||||
} else {
|
||||
// The only way get the pointer out of an interface to hash it or check for cycles
|
||||
// would be InterfaceData(), but that's deprecated and seems like it has undefined behavior.
|
||||
err := hashProviderInternal(hasher, v.Elem(), ptrs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("in interface: %s", err.Error())
|
||||
}
|
||||
}
|
||||
case reflect.String:
|
||||
hasher.Write([]byte(v.String()))
|
||||
case reflect.Bool:
|
||||
if v.Bool() {
|
||||
int64Buf[0] = 1
|
||||
} else {
|
||||
int64Buf[0] = 0
|
||||
}
|
||||
hasher.Write(int64Buf[:1])
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
binary.LittleEndian.PutUint64(int64Buf, v.Uint())
|
||||
hasher.Write(int64Buf)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
binary.LittleEndian.PutUint64(int64Buf, uint64(v.Int()))
|
||||
hasher.Write(int64Buf)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
binary.LittleEndian.PutUint64(int64Buf, math.Float64bits(v.Float()))
|
||||
hasher.Write(int64Buf)
|
||||
default:
|
||||
return fmt.Errorf("providers may only contain primitives, strings, arrays, slices, structs, maps, and pointers, found: %s", v.Kind().String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func compare_values(x, y reflect.Value) int {
|
||||
if x.Type() != y.Type() {
|
||||
panic("Expected equal types")
|
||||
}
|
||||
|
||||
switch x.Kind() {
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return cmp.Compare(x.Uint(), y.Uint())
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return cmp.Compare(x.Int(), y.Int())
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return cmp.Compare(x.Float(), y.Float())
|
||||
case reflect.String:
|
||||
return cmp.Compare(x.String(), y.String())
|
||||
case reflect.Bool:
|
||||
if x.Bool() == y.Bool() {
|
||||
return 0
|
||||
} else if x.Bool() {
|
||||
return 1
|
||||
} else {
|
||||
return -1
|
||||
}
|
||||
case reflect.Pointer:
|
||||
return cmp.Compare(x.Pointer(), y.Pointer())
|
||||
case reflect.Array:
|
||||
for i := 0; i < x.Len(); i++ {
|
||||
if result := compare_values(x.Index(i), y.Index(i)); result != 0 {
|
||||
return result
|
||||
}
|
||||
}
|
||||
return 0
|
||||
case reflect.Struct:
|
||||
for i := 0; i < x.NumField(); i++ {
|
||||
if result := compare_values(x.Field(i), y.Field(i)); result != 0 {
|
||||
return result
|
||||
}
|
||||
}
|
||||
return 0
|
||||
case reflect.Interface:
|
||||
if x.IsNil() && y.IsNil() {
|
||||
return 0
|
||||
} else if x.IsNil() {
|
||||
return 1
|
||||
} else if y.IsNil() {
|
||||
return -1
|
||||
}
|
||||
return compare_values(x.Elem(), y.Elem())
|
||||
default:
|
||||
panic(fmt.Sprintf("Could not compare types %s and %s", x.Type().String(), y.Type().String()))
|
||||
}
|
||||
}
|
112
proptools/hash_provider_test.go
Normal file
112
proptools/hash_provider_test.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
package proptools
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustHash(t *testing.T, provider interface{}) uint64 {
|
||||
t.Helper()
|
||||
result, err := HashProvider(provider)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func TestHashingMapGetsSameResults(t *testing.T) {
|
||||
provider := map[string]string{"foo": "bar", "baz": "qux"}
|
||||
first := mustHash(t, provider)
|
||||
second := mustHash(t, provider)
|
||||
third := mustHash(t, provider)
|
||||
fourth := mustHash(t, provider)
|
||||
if first != second || second != third || third != fourth {
|
||||
t.Fatal("Did not get the same result every time for a map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashingNonSerializableTypesFails(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
provider interface{}
|
||||
}{
|
||||
{
|
||||
name: "function pointer",
|
||||
provider: []func(){nil},
|
||||
},
|
||||
{
|
||||
name: "channel",
|
||||
provider: []chan int{make(chan int)},
|
||||
},
|
||||
{
|
||||
name: "list with non-serializable type",
|
||||
provider: []interface{}{"foo", make(chan int)},
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := HashProvider(testCase)
|
||||
if err == nil {
|
||||
t.Fatal("Expected hashing error but didn't get one")
|
||||
}
|
||||
expected := "providers may only contain primitives, strings, arrays, slices, structs, maps, and pointers"
|
||||
if !strings.Contains(err.Error(), expected) {
|
||||
t.Fatalf("Expected %q, got %q", expected, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashSuccessful(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
provider interface{}
|
||||
}{
|
||||
{
|
||||
name: "int",
|
||||
provider: 5,
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
provider: "foo",
|
||||
},
|
||||
{
|
||||
name: "*string",
|
||||
provider: StringPtr("foo"),
|
||||
},
|
||||
{
|
||||
name: "array",
|
||||
provider: [3]string{"foo", "bar", "baz"},
|
||||
},
|
||||
{
|
||||
name: "slice",
|
||||
provider: []string{"foo", "bar", "baz"},
|
||||
},
|
||||
{
|
||||
name: "struct",
|
||||
provider: struct {
|
||||
foo string
|
||||
bar int
|
||||
}{
|
||||
foo: "foo",
|
||||
bar: 3,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
provider: map[string]int{
|
||||
"foo": 3,
|
||||
"bar": 4,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list of interfaces with different types",
|
||||
provider: []interface{}{"foo", 3, []string{"bar", "baz"}},
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
mustHash(t, testCase.provider)
|
||||
})
|
||||
}
|
||||
}
|
13
provider.go
13
provider.go
|
@ -16,6 +16,8 @@ package blueprint
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/blueprint/proptools"
|
||||
)
|
||||
|
||||
// This file implements Providers, modelled after Bazel
|
||||
|
@ -151,6 +153,17 @@ func (c *Context) setProvider(m *moduleInfo, provider *providerKey, value any) {
|
|||
}
|
||||
|
||||
m.providers[provider.id] = value
|
||||
|
||||
if c.verifyProvidersAreUnchanged {
|
||||
if m.providerInitialValueHashes == nil {
|
||||
m.providerInitialValueHashes = make([]uint64, len(providerRegistry))
|
||||
}
|
||||
hash, err := proptools.HashProvider(value)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Can't set value of provider %s: %s", provider.typ, err.Error()))
|
||||
}
|
||||
m.providerInitialValueHashes[provider.id] = hash
|
||||
}
|
||||
}
|
||||
|
||||
// provider returns the value, if any, for a given provider for a module. Verifies that it is
|
||||
|
|
Loading…
Reference in a new issue