diff --git a/module_ctx.go b/module_ctx.go index c13df93..abd87c1 100644 --- a/module_ctx.go +++ b/module_ctx.go @@ -147,6 +147,9 @@ type ModuleContext interface { OtherModuleErrorf(m Module, fmt string, args ...interface{}) OtherModuleDependencyTag(m Module) DependencyTag + GetDirectDepWithTag(name string, tag DependencyTag) Module + GetDirectDep(name string) (Module, DependencyTag) + VisitDirectDeps(visit func(Module)) VisitDirectDepsIf(pred func(Module) bool, visit func(Module)) VisitDepsDepthFirst(visit func(Module)) @@ -300,6 +303,34 @@ func (m *baseModuleContext) OtherModuleDependencyTag(logicModule Module) Depende return nil } +// GetDirectDep returns the Module and DependencyTag for the direct dependency with the specified +// name, or nil if none exists. +func (m *baseModuleContext) GetDirectDep(name string) (Module, DependencyTag) { + for _, dep := range m.module.directDeps { + if dep.module.Name() == name { + return dep.module.logicModule, dep.tag + } + } + + return nil, nil +} + +// GetDirectDepWithTag returns the Module the direct dependency with the specified name, or nil if +// none exists. It panics if the dependency does not have the specified tag. +func (m *baseModuleContext) GetDirectDepWithTag(name string, tag DependencyTag) Module { + for _, dep := range m.module.directDeps { + if dep.module.Name() == name { + if dep.tag != tag { + panic(fmt.Errorf("found dependency %q with tag %#v, expected tag %#v", + dep.module, dep.tag, tag)) + } + return dep.module.logicModule + } + } + + return nil +} + func (m *baseModuleContext) VisitDirectDeps(visit func(Module)) { defer func() { if r := recover(); r != nil { @@ -488,6 +519,9 @@ type TopDownMutatorContext interface { OtherModuleErrorf(m Module, fmt string, args ...interface{}) OtherModuleDependencyTag(m Module) DependencyTag + GetDirectDepWithTag(name string, tag DependencyTag) Module + GetDirectDep(name string) (Module, DependencyTag) + VisitDirectDeps(visit func(Module)) VisitDirectDepsIf(pred func(Module) bool, visit func(Module)) VisitDepsDepthFirst(visit func(Module))