diff --git a/context.go b/context.go index d7db16b..8604f21 100644 --- a/context.go +++ b/context.go @@ -1872,34 +1872,46 @@ func (c *Context) findReverseDependency(module *moduleInfo, config any, destName }} } -// applyIncomingTransitions takes a variationMap being used to add a dependency on a module in a moduleGroup -// and applies the IncomingTransition method of each completed TransitionMutator to modify the requested variation. -// It finds a variant that existed before the TransitionMutator ran that is a subset of the requested variant to -// use as the module context for IncomingTransition. -func (c *Context) applyIncomingTransitions(config any, group *moduleGroup, variant variationMap, requestedVariations []Variation) { +// applyTransitions takes a variationMap being used to add a dependency on a module in a moduleGroup +// and applies the OutgoingTransition and IncomingTransition methods of each completed TransitionMutator to +// modify the requested variation. It finds a variant that existed before the TransitionMutator ran that is +// a subset of the requested variant to use as the module context for IncomingTransition. +func (c *Context) applyTransitions(config any, module *moduleInfo, group *moduleGroup, variant variationMap, + requestedVariations []Variation) { for _, transitionMutator := range c.transitionMutators { + // Apply the outgoing transition if it was not explicitly requested. + explicitlyRequested := slices.ContainsFunc(requestedVariations, func(variation Variation) bool { + return variation.Mutator == transitionMutator.name + }) + sourceVariation := variant[transitionMutator.name] + outgoingVariation := sourceVariation + if !explicitlyRequested { + ctx := &outgoingTransitionContextImpl{ + transitionContextImpl{context: c, source: module, dep: nil, depTag: nil, config: config}, + } + outgoingVariation = transitionMutator.mutator.OutgoingTransition(ctx, sourceVariation) + } + + // Find an appropriate module to use as the context for the IncomingTransition. appliedIncomingTransition := false for _, inputVariant := range transitionMutator.inputVariants[group] { if inputVariant.variant.variations.subsetOf(variant) { - sourceVariation := variant[transitionMutator.name] - + // Apply the incoming transition. ctx := &incomingTransitionContextImpl{ transitionContextImpl{context: c, source: nil, dep: inputVariant, depTag: nil, config: config}, } - outgoingVariation := transitionMutator.mutator.IncomingTransition(ctx, sourceVariation) - variant[transitionMutator.name] = outgoingVariation + finalVariation := transitionMutator.mutator.IncomingTransition(ctx, outgoingVariation) + variant[transitionMutator.name] = finalVariation appliedIncomingTransition = true break } } - if !appliedIncomingTransition { - // The transition mutator didn't apply anything to the target module, remove the variation unless it + if !appliedIncomingTransition && !explicitlyRequested { + // The transition mutator didn't apply anything to the target variant, remove the variation unless it // was explicitly requested when adding the dependency. - if !slices.ContainsFunc(requestedVariations, func(v Variation) bool { return v.Mutator == transitionMutator.name }) { - delete(variant, transitionMutator.name) - } + delete(variant, transitionMutator.name) } } } @@ -1928,7 +1940,7 @@ func (c *Context) findVariant(module *moduleInfo, config any, newVariant[v.Mutator] = v.Variation } - c.applyIncomingTransitions(config, possibleDeps, newVariant, requestedVariations) + c.applyTransitions(config, module, possibleDeps, newVariant, requestedVariations) check := func(variant variationMap) bool { if far { @@ -3002,7 +3014,7 @@ func (c *Context) runMutator(config interface{}, mutator *mutatorInfo, // Update module group to contain newly split variants if module.splitModules != nil { if isTransitionMutator { - // For transition mutators, save the pre-split variant for reusing later in applyIncomingTransitions. + // For transition mutators, save the pre-split variant for reusing later in applyTransitions. transitionMutatorInputVariants[group] = append(transitionMutatorInputVariants[group], module) } group.modules, i = spliceModules(group.modules, i, module.splitModules) diff --git a/transition_test.go b/transition_test.go index 73e92af..3e07b53 100644 --- a/transition_test.go +++ b/transition_test.go @@ -180,7 +180,7 @@ func TestTransition(t *testing.T) { func TestPostTransitionDeps(t *testing.T) { ctx, errs := testTransition(fmt.Sprintf(testTransitionBp, - `post_transition_deps: ["D:late", "E:d", "F"],`)) + `post_transition_deps: ["C", "D:late", "E:d", "F"],`)) assertNoErrors(t, errs) // Module A uses Split to create a and b variants @@ -209,13 +209,13 @@ func TestPostTransitionDeps(t *testing.T) { checkTransitionDeps(t, ctx, A_a, "B(a)", "C(a)") checkTransitionDeps(t, ctx, A_b, "B(b)", "C(b)") - // Verify post-mutator dependencies added to B: - // C(c) is a pre-mutator dependency + // Verify post-mutator dependencies added to B. The first C(c) is a pre-mutator dependency. + // C(c) was added by C and rewritten by OutgoingTransition on B // D(d) was added by D:late and rewritten by IncomingTransition on D // E(d) was added by E:d // F() was added by F, and ignored the existing variation on B - checkTransitionDeps(t, ctx, B_a, "C(c)", "D(d)", "E(d)", "F()") - checkTransitionDeps(t, ctx, B_b, "C(c)", "D(d)", "E(d)", "F()") + checkTransitionDeps(t, ctx, B_a, "C(c)", "C(c)", "D(d)", "E(d)", "F()") + checkTransitionDeps(t, ctx, B_b, "C(c)", "C(c)", "D(d)", "E(d)", "F()") checkTransitionDeps(t, ctx, C_a, "D(d)") checkTransitionDeps(t, ctx, C_b, "D(d)") checkTransitionDeps(t, ctx, C_c, "D(d)")