From dab3b1a1c064d07db3bcf222730937fe6abce70a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thi=C3=A9baud=20Weksteen?= Date: Mon, 6 Mar 2023 13:54:07 +1100 Subject: [PATCH] Refactor treble_sepolicy_tests.py Introduce a new class TestPolicy to capture all the previous global variables. This class contains the constructor and loading methods (Get*) to load its internal state. The tests are modified to accept a TestPolicy as first argument. This commit is a no-op. There is no change to the tests. `git show --ignore-space-change` can be used to skip over the re-indentation due to the new class. Bug: 269182257 Test: m selinux_policy (runs treble_sepolicy_tests against all compatible versions) Test: Set DEBUG=True, compare generated scontexts. Identical. Change-Id: Ia8da115dc1c0109b835e03b95da029b35712d251 --- tests/treble_sepolicy_tests.py | 327 ++++++++++++++++----------------- 1 file changed, 156 insertions(+), 171 deletions(-) diff --git a/tests/treble_sepolicy_tests.py b/tests/treble_sepolicy_tests.py index b49f138a7..c966423c9 100644 --- a/tests/treble_sepolicy_tests.py +++ b/tests/treble_sepolicy_tests.py @@ -51,172 +51,166 @@ class scontext: self.entrypointpaths = [] self.error = "" -def PrintScontexts(): - for d in sorted(alldomains.keys()): - sctx = alldomains[d] - print(d) - print("\tcoredomain="+str(sctx.coredomain)) - print("\tappdomain="+str(sctx.appdomain)) - print("\tfromSystem="+str(sctx.fromSystem)) - print("\tfromVendor="+str(sctx.fromVendor)) - print("\tattributes="+str(sctx.attributes)) - print("\tentrypoints="+str(sctx.entrypoints)) - print("\tentrypointpaths=") - if sctx.entrypointpaths is not None: - for path in sctx.entrypointpaths: - print("\t\t"+str(path)) -alldomains = {} -coredomains = set() -appdomains = set() -vendordomains = set() -pol = None +class TestPolicy: + """A policy loaded in memory with its domains easily accessible.""" -# compat vars -alltypes = set() -oldalltypes = set() -compatMapping = None -pubtypes = set() + def __init__(self): + self.alldomains = {} + self.coredomains = set() + self.appdomains = set() + self.vendordomains = set() + self.pol = None -# Distinguish between PRODUCT_FULL_TREBLE and PRODUCT_FULL_TREBLE_OVERRIDE -FakeTreble = False + # compat vars + self.alltypes = set() + self.oldalltypes = set() + self.compatMapping = None + self.pubtypes = set() -def GetAllDomains(pol): - global alldomains - for result in pol.QueryTypeAttribute("domain", True): - alldomains[result] = scontext() + # Distinguish between PRODUCT_FULL_TREBLE and PRODUCT_FULL_TREBLE_OVERRIDE + self.FakeTreble = False -def GetAppDomains(): - global appdomains - global alldomains - for d in alldomains: - # The application of the "appdomain" attribute is trusted because core - # selinux policy contains neverallow rules that enforce that only zygote - # and runas spawned processes may transition to processes that have - # the appdomain attribute. - if "appdomain" in alldomains[d].attributes: - alldomains[d].appdomain = True - appdomains.add(d) + def GetAllDomains(self): + for result in self.pol.QueryTypeAttribute("domain", True): + self.alldomains[result] = scontext() -def GetCoreDomains(): - global alldomains - global coredomains - for d in alldomains: - domain = alldomains[d] - # TestCoredomainViolations will verify if coredomain was incorrectly - # applied. - if "coredomain" in domain.attributes: - domain.coredomain = True - coredomains.add(d) - # check whether domains are executed off of /system or /vendor - if d in coredomainAllowlist: - continue - # TODO(b/153112003): add checks to prevent app domains from being - # incorrectly labeled as coredomain. Apps don't have entrypoints as - # they're always dynamically transitioned to by zygote. - if d in appdomains: - continue - # TODO(b/153112747): need to handle cases where there is a dynamic - # transition OR there happens to be no context in AOSP files. - if not domain.entrypointpaths: - continue + def GetAppDomains(self): + for d in self.alldomains: + # The application of the "appdomain" attribute is trusted because core + # selinux policy contains neverallow rules that enforce that only zygote + # and runas spawned processes may transition to processes that have + # the appdomain attribute. + if "appdomain" in self.alldomains[d].attributes: + self.alldomains[d].appdomain = True + self.appdomains.add(d) - for path in domain.entrypointpaths: - vendor = any(MatchPathPrefix(path, prefix) for prefix in - ["/vendor", "/odm"]) - system = any(MatchPathPrefix(path, prefix) for prefix in - ["/init", "/system_ext", "/product" ]) + def GetCoreDomains(self): + for d in self.alldomains: + domain = self.alldomains[d] + # TestCoredomainViolations will verify if coredomain was incorrectly + # applied. + if "coredomain" in domain.attributes: + domain.coredomain = True + self.coredomains.add(d) + # check whether domains are executed off of /system or /vendor + if d in coredomainAllowlist: + continue + # TODO(b/153112003): add checks to prevent app domains from being + # incorrectly labeled as coredomain. Apps don't have entrypoints as + # they're always dynamically transitioned to by zygote. + if d in self.appdomains: + continue + # TODO(b/153112747): need to handle cases where there is a dynamic + # transition OR there happens to be no context in AOSP files. + if not domain.entrypointpaths: + continue - # only mark entrypoint as system if it is not in legacy /system/vendor - if MatchPathPrefix(path, "/system/vendor"): - vendor = True - elif MatchPathPrefix(path, "/system"): - system = True + for path in domain.entrypointpaths: + vendor = any(MatchPathPrefix(path, prefix) for prefix in + ["/vendor", "/odm"]) + system = any(MatchPathPrefix(path, prefix) for prefix in + ["/init", "/system_ext", "/product" ]) - if not vendor and not system: - domain.error += "Unrecognized entrypoint for " + d + " at " + path + "\n" + # only mark entrypoint as system if it is not in legacy /system/vendor + if MatchPathPrefix(path, "/system/vendor"): + vendor = True + elif MatchPathPrefix(path, "/system"): + system = True - domain.fromSystem = domain.fromSystem or system - domain.fromVendor = domain.fromVendor or vendor + if not vendor and not system: + domain.error += "Unrecognized entrypoint for " + d + " at " + path + "\n" -### -# Add the entrypoint type and path(s) to each domain. -# -def GetDomainEntrypoints(pol): - global alldomains - for x in pol.QueryExpandedTERule(tclass=set(["file"]), perms=set(["entrypoint"])): - if not x.sctx in alldomains: - continue - alldomains[x.sctx].entrypoints.append(str(x.tctx)) - # postinstall_file represents a special case specific to A/B OTAs. - # Update_engine mounts a partition and relabels it postinstall_file. - # There is no file_contexts entry associated with postinstall_file - # so skip the lookup. - if x.tctx == "postinstall_file": - continue - entrypointpath = pol.QueryFc(x.tctx) - if not entrypointpath: - continue - alldomains[x.sctx].entrypointpaths.extend(entrypointpath) -### -# Get attributes associated with each domain -# -def GetAttributes(pol): - global alldomains - for domain in alldomains: - for result in pol.QueryTypeAttribute(domain, False): - alldomains[domain].attributes.add(result) + domain.fromSystem = domain.fromSystem or system + domain.fromVendor = domain.fromVendor or vendor -def GetAllTypes(pol, oldpol): - global alltypes - global oldalltypes - alltypes = pol.GetAllTypes(False) - oldalltypes = oldpol.GetAllTypes(False) + ### + # Add the entrypoint type and path(s) to each domain. + # + def GetDomainEntrypoints(self): + for x in self.pol.QueryExpandedTERule(tclass=set(["file"]), perms=set(["entrypoint"])): + if not x.sctx in self.alldomains: + continue + self.alldomains[x.sctx].entrypoints.append(str(x.tctx)) + # postinstall_file represents a special case specific to A/B OTAs. + # Update_engine mounts a partition and relabels it postinstall_file. + # There is no file_contexts entry associated with postinstall_file + # so skip the lookup. + if x.tctx == "postinstall_file": + continue + entrypointpath = self.pol.QueryFc(x.tctx) + if not entrypointpath: + continue + self.alldomains[x.sctx].entrypointpaths.extend(entrypointpath) -def setup(pol): - GetAllDomains(pol) - GetAttributes(pol) - GetDomainEntrypoints(pol) - GetAppDomains() - GetCoreDomains() + ### + # Get attributes associated with each domain + # + def GetAttributes(self): + for domain in self.alldomains: + for result in self.pol.QueryTypeAttribute(domain, False): + self.alldomains[domain].attributes.add(result) -# setup for the policy compatibility tests -def compatSetup(pol, oldpol, mapping, types): - global compatMapping - global pubtypes + def setup(self, pol): + self.pol = pol + self.GetAllDomains() + self.GetAttributes() + self.GetDomainEntrypoints() + self.GetAppDomains() + self.GetCoreDomains() - GetAllTypes(pol, oldpol) - compatMapping = mapping - pubtypes = types + def GetAllTypes(self, basepol, oldpol): + self.alltypes = basepol.GetAllTypes(False) + self.oldalltypes = oldpol.GetAllTypes(False) + + # setup for the policy compatibility tests + def compatSetup(self, basepol, oldpol, mapping, types): + self.GetAllTypes(basepol, oldpol) + self.compatMapping = mapping + self.pubtypes = types + + def DomainsWithAttribute(self, attr): + domains = [] + for domain in self.alldomains: + if attr in self.alldomains[domain].attributes: + domains.append(domain) + return domains + + def PrintScontexts(self): + for d in sorted(self.alldomains.keys()): + sctx = self.alldomains[d] + print(d) + print("\tcoredomain="+str(sctx.coredomain)) + print("\tappdomain="+str(sctx.appdomain)) + print("\tfromSystem="+str(sctx.fromSystem)) + print("\tfromVendor="+str(sctx.fromVendor)) + print("\tattributes="+str(sctx.attributes)) + print("\tentrypoints="+str(sctx.entrypoints)) + print("\tentrypointpaths=") + if sctx.entrypointpaths is not None: + for path in sctx.entrypointpaths: + print("\t\t"+str(path)) -def DomainsWithAttribute(attr): - global alldomains - domains = [] - for domain in alldomains: - if attr in alldomains[domain].attributes: - domains.append(domain) - return domains ############################################################# # Tests ############################################################# -def TestCoredomainViolations(): - global alldomains +def TestCoredomainViolations(test_policy): # verify that all domains launched from /system have the coredomain # attribute ret = "" - for d in alldomains: - domain = alldomains[d] + for d in test_policy.alldomains: + domain = test_policy.alldomains[d] if domain.fromSystem and domain.fromVendor: ret += "The following domain is system and vendor: " + d + "\n" - for domain in alldomains.values(): + for domain in test_policy.alldomains.values(): ret += domain.error violators = [] - for d in alldomains: - domain = alldomains[d] + for d in test_policy.alldomains: + domain = test_policy.alldomains[d] if domain.fromSystem and "coredomain" not in domain.attributes: violators.append(d); if len(violators) > 0: @@ -228,8 +222,8 @@ def TestCoredomainViolations(): # verify that all domains launched form /vendor do not have the coredomain # attribute violators = [] - for d in alldomains: - domain = alldomains[d] + for d in test_policy.alldomains: + domain = test_policy.alldomains[d] if domain.fromVendor and "coredomain" in domain.attributes: violators.append(d) if len(violators) > 0: @@ -243,17 +237,13 @@ def TestCoredomainViolations(): ### # Make sure that any new public type introduced in the new policy that was not # present in the old policy has been recorded in the mapping file. -def TestNoUnmappedNewTypes(): - global alltypes - global oldalltypes - global compatMapping - global pubtypes - newt = alltypes - oldalltypes +def TestNoUnmappedNewTypes(test_policy): + newt = test_policy.alltypes - test_policy.oldalltypes ret = "" violators = [] for n in newt: - if n in pubtypes and compatMapping.rTypeattributesets.get(n) is None: + if n in test_policy.pubtypes and test_policy.compatMapping.rTypeattributesets.get(n) is None: violators.append(n) if len(violators) > 0: @@ -270,16 +260,13 @@ def TestNoUnmappedNewTypes(): ### # Make sure that any public type removed in the current policy has its # declaration added to the mapping file for use in non-platform policy -def TestNoUnmappedRmTypes(): - global alltypes - global oldalltypes - global compatMapping - rmt = oldalltypes - alltypes +def TestNoUnmappedRmTypes(test_policy): + rmt = test_policy.oldalltypes - test_policy.alltypes ret = "" violators = [] for o in rmt: - if o in compatMapping.pubtypes and not o in compatMapping.types: + if o in test_policy.compatMapping.pubtypes and not o in test_policy.compatMapping.types: violators.append(o) if len(violators) > 0: @@ -292,34 +279,32 @@ def TestNoUnmappedRmTypes(): ret += "https://android-review.googlesource.com/c/platform/system/sepolicy/+/822743\n" return ret -def TestTrebleCompatMapping(): - ret = TestNoUnmappedNewTypes() - ret += TestNoUnmappedRmTypes() +def TestTrebleCompatMapping(test_policy): + ret = TestNoUnmappedNewTypes(test_policy) + ret += TestNoUnmappedRmTypes(test_policy) return ret -def TestViolatorAttribute(attribute): - global FakeTreble +def TestViolatorAttribute(test_policy, attribute): ret = "" - if FakeTreble: + if test_policy.FakeTreble: return ret - violators = DomainsWithAttribute(attribute) + violators = test_policy.DomainsWithAttribute(attribute) if len(violators) > 0: ret += "SELinux: The following domains violate the Treble ban " ret += "against use of the " + attribute + " attribute: " ret += " ".join(str(x) for x in sorted(violators)) + "\n" return ret -def TestViolatorAttributes(): +def TestViolatorAttributes(test_policy): ret = "" - ret += TestViolatorAttribute("socket_between_core_and_vendor_violators") - ret += TestViolatorAttribute("vendor_executes_system_violators") + ret += TestViolatorAttribute(test_policy, "socket_between_core_and_vendor_violators") + ret += TestViolatorAttribute(test_policy, "vendor_executes_system_violators") return ret # TODO move this to sepolicy_tests -def TestCoreDataTypeViolations(): - global pol - return pol.AssertPathTypesDoNotHaveAttr(["/data/vendor/", "/data/vendor_ce/", +def TestCoreDataTypeViolations(test_policy): + return test_policy.pol.AssertPathTypesDoNotHaveAttr(["/data/vendor/", "/data/vendor_ce/", "/data/vendor_de/"], [], "core_data_file_type") ### @@ -349,7 +334,7 @@ def do_main(libpath): Args: libpath: string, path to libsepolwrap.so """ - global pol, FakeTreble + test_policy = TestPolicy() usage = "treble_sepolicy_tests " usage += "-f nonplat_file_contexts -f plat_file_contexts " @@ -402,27 +387,27 @@ def do_main(libpath): oldpol = policy.Policy(options.oldpolicy, None, libpath) mapping = mini_parser.MiniCilParser(options.mapping) pubpol = mini_parser.MiniCilParser(options.base_pub_policy) - compatSetup(basepol, oldpol, mapping, pubpol.types) + test_policy.compatSetup(basepol, oldpol, mapping, pubpol.types) if options.faketreble: - FakeTreble = True + test_policy.FakeTreble = True pol = policy.Policy(options.policy, options.file_contexts, libpath) - setup(pol) + test_policy.setup(pol) if DEBUG: - PrintScontexts() + test_policy.PrintScontexts() results = "" # If an individual test is not specified, run all tests. if options.tests is None: for t in Tests.values(): - results += t() + results += t(test_policy) else: for tn in options.tests: t = Tests.get(tn) if t: - results += t() + results += t(test_policy) else: err = "Error: unknown test: " + tn + "\n" err += "Available tests:\n"