Support variable bindings in selects

This allows us to recreate soong config value variables in selects.

This adds a new "any" pattern to selects, which is the same as "default"
except that it doesn't match undefined variables, and it's (currently)
the only pattern that can accept a binding.

The syntax looks like:
```
select(soong_config_variable("my_namespace", "my_variable"), {
    any @ my_binding: "foo" + my_binding,
    default: "other value",
})
```

Bug: 323382414
Test: m nothing --no-skip-soong-tests
Change-Id: I4feb4073172d8797dee5472f43f9c248a76c3f1f
This commit is contained in:
Cole Faust 2024-06-20 18:14:45 -07:00 committed by Bartłomiej Rudecki
parent 1e62c68bfe
commit 738bb54ded
Signed by: przekichane
GPG key ID: 751F23C6F014EF76
7 changed files with 243 additions and 50 deletions

View file

@ -131,6 +131,10 @@ func (p *Property) String() string {
func (p *Property) Pos() scanner.Position { return p.NamePos }
func (p *Property) End() scanner.Position { return p.Value.End() }
func (p *Property) MarkReferencedVariables(scope *Scope) {
p.Value.MarkReferencedVariables(scope)
}
// An Expression is a Value in a Property or Assignment. It can be a literal (String or Bool), a
// Map, a List, an Operator that combines two expressions of the same type, or a Variable that
// references and Assignment.
@ -152,6 +156,11 @@ type Expression interface {
// value. It will modify the AST in-place. This is used to implement soong config value
// variables, but should be removed when those have switched to selects.
PrintfInto(value string) error
// MarkReferencedVariables marks the variables in the given scope referenced if there
// is a matching variable reference in this expression. This happens naturally during
// Eval as well, but for selects, we need to mark variables as referenced without
// actually evaluating the expression yet.
MarkReferencedVariables(scope *Scope)
}
// ExpressionsAreSame tells whether the two values are the same Expression.
@ -350,6 +359,11 @@ func (x *Operator) PrintfInto(value string) error {
return x.Args[1].PrintfInto(value)
}
func (x *Operator) MarkReferencedVariables(scope *Scope) {
x.Args[0].MarkReferencedVariables(scope)
x.Args[1].MarkReferencedVariables(scope)
}
func (x *Operator) Pos() scanner.Position { return x.Args[0].Pos() }
func (x *Operator) End() scanner.Position { return x.Args[1].End() }
@ -384,6 +398,12 @@ func (x *Variable) PrintfInto(value string) error {
return nil
}
func (x *Variable) MarkReferencedVariables(scope *Scope) {
if assignment := scope.Get(x.Name); assignment != nil {
assignment.Referenced = true
}
}
func (x *Variable) String() string {
return x.Name
}
@ -438,6 +458,12 @@ func (x *Map) PrintfInto(value string) error {
panic("printfinto() is unsupported on maps")
}
func (x *Map) MarkReferencedVariables(scope *Scope) {
for _, prop := range x.Properties {
prop.MarkReferencedVariables(scope)
}
}
func (x *Map) String() string {
propertyStrings := make([]string, len(x.Properties))
for i, property := range x.Properties {
@ -553,6 +579,12 @@ func (x *List) PrintfInto(value string) error {
return nil
}
func (x *List) MarkReferencedVariables(scope *Scope) {
for _, val := range x.Values {
val.MarkReferencedVariables(scope)
}
}
func (x *List) String() string {
valueStrings := make([]string, len(x.Values))
for i, value := range x.Values {
@ -599,6 +631,9 @@ func (x *String) PrintfInto(value string) error {
return nil
}
func (x *String) MarkReferencedVariables(scope *Scope) {
}
func (x *String) String() string {
return fmt.Sprintf("%q@%s", x.Value, x.LiteralPos)
}
@ -629,6 +664,9 @@ func (x *Int64) PrintfInto(value string) error {
return nil
}
func (x *Int64) MarkReferencedVariables(scope *Scope) {
}
func (x *Int64) String() string {
return fmt.Sprintf("%q@%s", x.Value, x.LiteralPos)
}
@ -659,6 +697,9 @@ func (x *Bool) PrintfInto(value string) error {
return nil
}
func (x *Bool) MarkReferencedVariables(scope *Scope) {
}
func (x *Bool) String() string {
return fmt.Sprintf("%t@%s", x.Value, x.LiteralPos)
}
@ -808,6 +849,7 @@ func (s *Select) Copy() Expression {
func (s *Select) Eval(scope *Scope) (Expression, error) {
s.Scope = scope
s.MarkReferencedVariables(scope)
return s, nil
}
@ -816,6 +858,15 @@ func (x *Select) PrintfInto(value string) error {
panic("Cannot call PrintfInto on a select expression")
}
func (x *Select) MarkReferencedVariables(scope *Scope) {
for _, c := range x.Cases {
c.MarkReferencedVariables(scope)
}
if x.Append != nil {
x.Append.MarkReferencedVariables(scope)
}
}
func (s *Select) String() string {
return "<select>"
}
@ -827,12 +878,29 @@ func (s *Select) Type() Type {
return UnknownType
}
type SelectPattern struct {
Value Expression
Binding Variable
}
func (c *SelectPattern) Pos() scanner.Position { return c.Value.Pos() }
func (c *SelectPattern) End() scanner.Position {
if c.Binding.NamePos.IsValid() {
return c.Binding.End()
}
return c.Value.End()
}
type SelectCase struct {
Patterns []Expression
Patterns []SelectPattern
ColonPos scanner.Position
Value Expression
}
func (x *SelectCase) MarkReferencedVariables(scope *Scope) {
x.Value.MarkReferencedVariables(scope)
}
func (c *SelectCase) Copy() *SelectCase {
ret := *c
ret.Value = c.Value.Copy()
@ -880,5 +948,8 @@ func (x *UnsetProperty) PrintfInto(value string) error {
return nil
}
func (x *UnsetProperty) MarkReferencedVariables(scope *Scope) {
}
func (n *UnsetProperty) Pos() scanner.Position { return n.Position }
func (n *UnsetProperty) End() scanner.Position { return n.Position }

View file

@ -29,6 +29,7 @@ var errTooManyErrors = errors.New("too many errors")
const maxErrors = 1
const default_select_branch_name = "__soong_conditions_default__"
const any_select_branch_name = "__soong_conditions_any__"
type ParseError struct {
Err error
@ -502,44 +503,72 @@ func (p *parser) parseSelect() Expression {
return nil
}
parseOnePattern := func() Expression {
maybeParseBinding := func() (Variable, bool) {
if p.scanner.TokenText() != "@" {
return Variable{}, false
}
p.next()
value := Variable{
Name: p.scanner.TokenText(),
NamePos: p.scanner.Position,
}
p.accept(scanner.Ident)
return value, true
}
parseOnePattern := func() SelectPattern {
var result SelectPattern
switch p.tok {
case scanner.Ident:
switch p.scanner.TokenText() {
case "default":
case "any":
result.Value = &String{
LiteralPos: p.scanner.Position,
Value: any_select_branch_name,
}
p.next()
return &String{
if binding, exists := maybeParseBinding(); exists {
result.Binding = binding
}
return result
case "default":
result.Value = &String{
LiteralPos: p.scanner.Position,
Value: default_select_branch_name,
}
case "true":
p.next()
return &Bool{
return result
case "true":
result.Value = &Bool{
LiteralPos: p.scanner.Position,
Value: true,
}
case "false":
p.next()
return &Bool{
return result
case "false":
result.Value = &Bool{
LiteralPos: p.scanner.Position,
Value: false,
}
p.next()
return result
default:
p.errorf("Expted a string, true, false, or default, got %s", p.scanner.TokenText())
p.errorf("Expected a string, true, false, or default, got %s", p.scanner.TokenText())
}
case scanner.String:
if s := p.parseStringValue(); s != nil {
if strings.HasPrefix(s.Value, "__soong") {
p.errorf("select branch conditions starting with __soong are reserved for internal use")
return nil
p.errorf("select branch patterns starting with __soong are reserved for internal use")
return result
}
return s
result.Value = s
return result
}
fallthrough
default:
p.errorf("Expted a string, true, false, or default, got %s", p.scanner.TokenText())
p.errorf("Expected a string, true, false, or default, got %s", p.scanner.TokenText())
}
return nil
return result
}
hasNonUnsetValue := false
@ -551,11 +580,7 @@ func (p *parser) parseSelect() Expression {
return nil
}
for i := 0; i < len(conditions); i++ {
if p := parseOnePattern(); p != nil {
c.Patterns = append(c.Patterns, p)
} else {
return nil
}
c.Patterns = append(c.Patterns, parseOnePattern())
if i < len(conditions)-1 {
if !p.accept(',') {
return nil
@ -569,11 +594,7 @@ func (p *parser) parseSelect() Expression {
return nil
}
} else {
if p := parseOnePattern(); p != nil {
c.Patterns = append(c.Patterns, p)
} else {
return nil
}
c.Patterns = append(c.Patterns, parseOnePattern())
}
c.ColonPos = p.scanner.Position
if !p.accept(':') {
@ -599,16 +620,17 @@ func (p *parser) parseSelect() Expression {
return nil
}
patternsEqual := func(a, b Expression) bool {
switch a2 := a.(type) {
patternsEqual := func(a, b SelectPattern) bool {
// We can ignore the bindings, they don't affect which pattern is matched
switch a2 := a.Value.(type) {
case *String:
if b2, ok := b.(*String); ok {
if b2, ok := b.Value.(*String); ok {
return a2.Value == b2.Value
} else {
return false
}
case *Bool:
if b2, ok := b.(*Bool); ok {
if b2, ok := b.Value.(*Bool); ok {
return a2.Value == b2.Value
} else {
return false
@ -619,7 +641,7 @@ func (p *parser) parseSelect() Expression {
}
}
patternListsEqual := func(a, b []Expression) bool {
patternListsEqual := func(a, b []SelectPattern) bool {
if len(a) != len(b) {
return false
}
@ -632,18 +654,29 @@ func (p *parser) parseSelect() Expression {
}
for i, c := range result.Cases {
// Check for duplicates
// Check for duplicate patterns across different branches
for _, d := range result.Cases[i+1:] {
if patternListsEqual(c.Patterns, d.Patterns) {
p.errorf("Found duplicate select patterns: %v", c.Patterns)
return nil
}
}
// check for duplicate bindings within this branch
for i := range c.Patterns {
if c.Patterns[i].Binding.Name != "" {
for j := i + 1; j < len(c.Patterns); j++ {
if c.Patterns[i].Binding.Name == c.Patterns[j].Binding.Name {
p.errorf("Found duplicate select pattern binding: %s", c.Patterns[i].Binding.Name)
return nil
}
}
}
}
// Check that the only all-default cases is the last one
if i < len(result.Cases)-1 {
isAllDefault := true
for _, x := range c.Patterns {
if x2, ok := x.(*String); !ok || x2.Value != default_select_branch_name {
if x2, ok := x.Value.(*String); !ok || x2.Value != default_select_branch_name {
isAllDefault = false
break
}

View file

@ -860,6 +860,17 @@ func TestParserError(t *testing.T) {
`,
err: "Duplicate select condition found: arch()",
},
{
name: "select with duplicate binding",
input: `
m {
foo: select((arch(), os()), {
(any @ bar, any @ bar): true,
}),
}
`,
err: "Found duplicate select pattern binding: bar",
},
// TODO: test more parser errors
}

View file

@ -143,7 +143,7 @@ func (p *printer) printSelect(s *Select) {
return
}
if len(s.Cases) == 1 && len(s.Cases[0].Patterns) == 1 {
if str, ok := s.Cases[0].Patterns[0].(*String); ok && str.Value == default_select_branch_name {
if str, ok := s.Cases[0].Patterns[0].Value.(*String); ok && str.Value == default_select_branch_name {
p.printExpression(s.Cases[0].Value)
p.pos = s.RBracePos
return
@ -196,22 +196,7 @@ func (p *printer) printSelect(s *Select) {
p.printToken("(", p.pos)
}
for i, pat := range c.Patterns {
switch pat := pat.(type) {
case *String:
if pat.Value != default_select_branch_name {
p.printToken(strconv.Quote(pat.Value), pat.LiteralPos)
} else {
p.printToken("default", pat.LiteralPos)
}
case *Bool:
s := "false"
if pat.Value {
s = "true"
}
p.printToken(s, pat.LiteralPos)
default:
panic("Unhandled case")
}
p.printSelectPattern(pat)
if i < len(c.Patterns)-1 {
p.printToken(",", p.pos)
p.requestSpace()
@ -240,6 +225,33 @@ func (p *printer) printSelect(s *Select) {
}
}
func (p *printer) printSelectPattern(pat SelectPattern) {
switch pat := pat.Value.(type) {
case *String:
if pat.Value == default_select_branch_name {
p.printToken("default", pat.LiteralPos)
} else if pat.Value == any_select_branch_name {
p.printToken("any", pat.LiteralPos)
} else {
p.printToken(strconv.Quote(pat.Value), pat.LiteralPos)
}
case *Bool:
s := "false"
if pat.Value {
s = "true"
}
p.printToken(s, pat.LiteralPos)
default:
panic("Unhandled case")
}
if pat.Binding.Name != "" {
p.requestSpace()
p.printToken("@", pat.Binding.Pos())
p.requestSpace()
p.printExpression(&pat.Binding)
}
}
func (p *printer) printList(list []Expression, pos, endPos scanner.Position) {
p.requestSpace()
p.printToken("[", pos)

View file

@ -733,6 +733,26 @@ foo {
default: [],
}),
}
`,
},
{
name: "Select with bindings",
input: `
foo {
stuff: select(arch(), {
"x86": "a",
any
@ baz: "b" + baz,
}),
}
`,
output: `
foo {
stuff: select(arch(), {
"x86": "a",
any @ baz: "b" + baz,
}),
}
`,
},
}

View file

@ -151,6 +151,17 @@ type ConfigurableValue struct {
boolValue bool
}
func (c *ConfigurableValue) toExpression() parser.Expression {
switch c.typ {
case configurableValueTypeBool:
return &parser.Bool{Value: c.boolValue}
case configurableValueTypeString:
return &parser.String{Value: c.stringValue}
default:
panic(fmt.Sprintf("Unhandled configurableValueType: %s", c.typ.String()))
}
}
func (c *ConfigurableValue) String() string {
switch c.typ {
case configurableValueTypeString:
@ -194,6 +205,7 @@ const (
configurablePatternTypeString configurablePatternType = iota
configurablePatternTypeBool
configurablePatternTypeDefault
configurablePatternTypeAny
)
func (v *configurablePatternType) String() string {
@ -204,6 +216,8 @@ func (v *configurablePatternType) String() string {
return "bool"
case configurablePatternTypeDefault:
return "default"
case configurablePatternTypeAny:
return "any"
default:
panic("unimplemented")
}
@ -223,6 +237,7 @@ type ConfigurablePattern struct {
typ configurablePatternType
stringValue string
boolValue bool
binding string
}
func NewStringConfigurablePattern(s string) ConfigurablePattern {
@ -252,6 +267,9 @@ func (p *ConfigurablePattern) matchesValue(v ConfigurableValue) bool {
if v.typ == configurableValueTypeUndefined {
return false
}
if p.typ == configurablePatternTypeAny {
return true
}
if p.typ != v.typ.patternType() {
return false
}
@ -272,6 +290,9 @@ func (p *ConfigurablePattern) matchesValueType(v ConfigurableValue) bool {
if v.typ == configurableValueTypeUndefined {
return true
}
if p.typ == configurablePatternTypeAny {
return true
}
return p.typ == v.typ.patternType()
}
@ -525,7 +546,8 @@ func (c *singleConfigurable[T]) evaluateNonTransitive(propertyName string, evalu
}
}
if allMatch && !foundMatch {
if r, err := expressionToConfiguredValue[T](case_.value, c.scope); err != nil {
newScope := createScopeWithBindings(c.scope, case_.patterns, values)
if r, err := expressionToConfiguredValue[T](case_.value, newScope); err != nil {
evaluator.PropertyErrorf(propertyName, "%s", err.Error())
return nil
} else {
@ -542,6 +564,27 @@ func (c *singleConfigurable[T]) evaluateNonTransitive(propertyName string, evalu
return nil
}
func createScopeWithBindings(parent *parser.Scope, patterns []ConfigurablePattern, values []ConfigurableValue) *parser.Scope {
result := parent
for i, pattern := range patterns {
if pattern.binding != "" {
if result == parent {
result = parser.NewScope(parent)
}
err := result.HandleAssignment(&parser.Assignment{
Name: pattern.binding,
Value: values[i].toExpression(),
Assigner: "=",
})
if err != nil {
// This shouldn't happen due to earlier validity checks
panic(err.Error())
}
}
}
return result
}
func appendConfiguredValues[T ConfigurableElements](a, b *T) *T {
if a == nil && b == nil {
return nil

View file

@ -445,10 +445,12 @@ func (ctx *unpackContext) unpackToConfigurable(propertyName string, property *pa
for _, c := range v.Cases {
patterns := make([]ConfigurablePattern, len(c.Patterns))
for i, pat := range c.Patterns {
switch pat := pat.(type) {
switch pat := pat.Value.(type) {
case *parser.String:
if pat.Value == "__soong_conditions_default__" {
patterns[i].typ = configurablePatternTypeDefault
} else if pat.Value == "__soong_conditions_any__" {
patterns[i].typ = configurablePatternTypeAny
} else {
patterns[i].typ = configurablePatternTypeString
patterns[i].stringValue = pat.Value
@ -459,6 +461,7 @@ func (ctx *unpackContext) unpackToConfigurable(propertyName string, property *pa
default:
panic("unimplemented")
}
patterns[i].binding = pat.Binding.Name
}
case_ := reflect.New(configurableCaseType)