diff --git a/context.go b/context.go index a4fea26..482a2b1 100644 --- a/context.go +++ b/context.go @@ -1542,31 +1542,15 @@ func blueprintDepsMutator(ctx BottomUpMutatorContext) { } } -// findMatchingVariant searches the moduleGroup for a module with the same variant as module, -// and returns the matching module, or nil if one is not found. -func (c *Context) findMatchingVariant(module *moduleInfo, possible *moduleGroup, reverse bool) *moduleInfo { +// findExactVariantOrSingle searches the moduleGroup for a module with the same variant as module, +// and returns the matching module, or nil if one is not found. A group with exactly one module +// is always considered matching. +func findExactVariantOrSingle(module *moduleInfo, possible *moduleGroup, reverse bool) *moduleInfo { if len(possible.modules) == 1 { return possible.modules[0] } else { - var variantToMatch variationMap - if !reverse { - // For forward dependency, ignore local variants by matching against - // dependencyVariant which doesn't have the local variants - variantToMatch = module.variant.dependencyVariations - } else { - // For reverse dependency, use all the variants - variantToMatch = module.variant.variations - } - for _, m := range possible.modules { - if m.variant.variations.equal(variantToMatch) { - return m - } - } - for _, m := range possible.aliases { - if m.variant.variations.equal(variantToMatch) { - return m.target - } - } + found, _ := findVariant(module, possible, nil, false, reverse) + return found } return nil @@ -1589,7 +1573,7 @@ func (c *Context) addDependency(module *moduleInfo, tag DependencyTag, depName s return c.discoveredMissingDependencies(module, depName) } - if m := c.findMatchingVariant(module, possibleDeps, false); m != nil { + if m := findExactVariantOrSingle(module, possibleDeps, false); m != nil { module.newDirectDeps = append(module.newDirectDeps, depInfo{m, tag}) atomic.AddUint32(&c.depsModified, 1) return nil @@ -1626,7 +1610,7 @@ func (c *Context) findReverseDependency(module *moduleInfo, destName string) (*m }} } - if m := c.findMatchingVariant(module, possibleDeps, true); m != nil { + if m := findExactVariantOrSingle(module, possibleDeps, true); m != nil { return m, nil } @@ -1644,7 +1628,7 @@ func (c *Context) findReverseDependency(module *moduleInfo, destName string) (*m }} } -func (c *Context) findVariant(module *moduleInfo, possibleDeps *moduleGroup, variations []Variation, far bool, reverse bool) (*moduleInfo, variationMap) { +func findVariant(module *moduleInfo, possibleDeps *moduleGroup, variations []Variation, far bool, reverse bool) (*moduleInfo, variationMap) { // We can't just append variant.Variant to module.dependencyVariant.variantName and // compare the strings because the result won't be in mutator registration order. // Create a new map instead, and then deep compare the maps. diff --git a/context_test.go b/context_test.go index 0541c06..073693c 100644 --- a/context_test.go +++ b/context_test.go @@ -606,3 +606,169 @@ func TestParseFailsForModuleWithoutName(t *testing.T) { t.Errorf("Incorrect errors; expected:\n%s\ngot:\n%s", expectedErrs, errs) } } + +func Test_findVariant(t *testing.T) { + module := &moduleInfo{ + variant: variant{ + name: "normal_local", + variations: variationMap{ + "normal": "normal", + "local": "local", + }, + dependencyVariations: variationMap{ + "normal": "normal", + }, + }, + } + + type alias struct { + variations variationMap + target int + } + + makeDependencyGroup := func(modules []*moduleInfo, aliases []alias) *moduleGroup { + group := &moduleGroup{ + name: "dep", + modules: modules, + } + + for _, alias := range aliases { + group.aliases = append(group.aliases, &moduleAlias{ + variant: variant{ + variations: alias.variations, + }, + target: group.modules[alias.target], + }) + } + + for _, m := range group.modules { + m.group = group + } + return group + } + + tests := []struct { + name string + possibleDeps *moduleGroup + variations []Variation + far bool + reverse bool + want string + }{ + { + name: "AddVariationDependencies(nil)", + // A dependency that matches the non-local variations of the module + possibleDeps: makeDependencyGroup([]*moduleInfo{ + { + variant: variant{ + name: "normal", + variations: variationMap{ + "normal": "normal", + }, + }, + }, + }, nil), + variations: nil, + far: false, + reverse: false, + want: "normal", + }, + { + name: "AddVariationDependencies(nil) to alias", + // A dependency with an alias that matches the non-local variations of the module + possibleDeps: makeDependencyGroup([]*moduleInfo{ + { + variant: variant{ + name: "normal_a", + variations: variationMap{ + "normal": "normal", + "a": "a", + }, + }, + }, + }, []alias{ + { + variations: variationMap{ + "normal": "normal", + }, + target: 0, + }, + }), + variations: nil, + far: false, + reverse: false, + want: "normal_a", + }, + { + name: "AddVariationDependencies(a)", + // A dependency with local variations + possibleDeps: makeDependencyGroup([]*moduleInfo{ + { + variant: variant{ + name: "normal_a", + variations: variationMap{ + "normal": "normal", + "a": "a", + }, + }, + }, + }, nil), + variations: []Variation{{"a", "a"}}, + far: false, + reverse: false, + want: "normal_a", + }, + { + name: "AddFarVariationDependencies(far)", + // A dependency with far variations + possibleDeps: makeDependencyGroup([]*moduleInfo{ + { + variant: variant{ + name: "far", + variations: variationMap{ + "far": "far", + }, + }, + }, + }, nil), + variations: []Variation{{"far", "far"}}, + far: true, + reverse: false, + want: "far", + }, + { + name: "AddFarVariationDependencies(far) to alias", + // A dependency with far variations and aliases + possibleDeps: makeDependencyGroup([]*moduleInfo{ + { + variant: variant{ + name: "far_a", + variations: variationMap{ + "far": "far", + "a": "a", + }, + }, + }, + }, []alias{ + { + variations: variationMap{ + "far": "far", + }, + target: 0, + }, + }), + variations: []Variation{{"far", "far"}}, + far: true, + reverse: false, + want: "far_a", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, _ := findVariant(module, tt.possibleDeps, tt.variations, tt.far, tt.reverse) + if g, w := got.String(), fmt.Sprintf("module %q variant %q", "dep", tt.want); g != w { + t.Errorf("findVariant() got = %v, want %v", g, w) + } + }) + } +} diff --git a/module_ctx.go b/module_ctx.go index c992c0a..34f70f4 100644 --- a/module_ctx.go +++ b/module_ctx.go @@ -510,7 +510,7 @@ func (m *baseModuleContext) OtherModuleDependencyVariantExists(variations []Vari if possibleDeps == nil { return false } - found, _ := m.context.findVariant(m.module, possibleDeps, variations, false, false) + found, _ := findVariant(m.module, possibleDeps, variations, false, false) return found != nil } @@ -519,7 +519,7 @@ func (m *baseModuleContext) OtherModuleReverseDependencyVariantExists(name strin if possibleDeps == nil { return false } - found, _ := m.context.findVariant(m.module, possibleDeps, nil, false, true) + found, _ := findVariant(m.module, possibleDeps, nil, false, true) return found != nil }