Merge "Refactor treble_sepolicy_tests.py"

This commit is contained in:
Thiébaud Weksteen 2023-03-09 06:10:48 +00:00 committed by Gerrit Code Review
commit c691211c02

View file

@ -51,172 +51,166 @@ class scontext:
self.entrypointpaths = []
self.error = ""
def PrintScontexts():
for d in sorted(alldomains.keys()):
sctx = alldomains[d]
print(d)
print("\tcoredomain="+str(sctx.coredomain))
print("\tappdomain="+str(sctx.appdomain))
print("\tfromSystem="+str(sctx.fromSystem))
print("\tfromVendor="+str(sctx.fromVendor))
print("\tattributes="+str(sctx.attributes))
print("\tentrypoints="+str(sctx.entrypoints))
print("\tentrypointpaths=")
if sctx.entrypointpaths is not None:
for path in sctx.entrypointpaths:
print("\t\t"+str(path))
alldomains = {}
coredomains = set()
appdomains = set()
vendordomains = set()
pol = None
class TestPolicy:
"""A policy loaded in memory with its domains easily accessible."""
# compat vars
alltypes = set()
oldalltypes = set()
compatMapping = None
pubtypes = set()
def __init__(self):
self.alldomains = {}
self.coredomains = set()
self.appdomains = set()
self.vendordomains = set()
self.pol = None
# Distinguish between PRODUCT_FULL_TREBLE and PRODUCT_FULL_TREBLE_OVERRIDE
FakeTreble = False
# compat vars
self.alltypes = set()
self.oldalltypes = set()
self.compatMapping = None
self.pubtypes = set()
def GetAllDomains(pol):
global alldomains
for result in pol.QueryTypeAttribute("domain", True):
alldomains[result] = scontext()
# Distinguish between PRODUCT_FULL_TREBLE and PRODUCT_FULL_TREBLE_OVERRIDE
self.FakeTreble = False
def GetAppDomains():
global appdomains
global alldomains
for d in alldomains:
# The application of the "appdomain" attribute is trusted because core
# selinux policy contains neverallow rules that enforce that only zygote
# and runas spawned processes may transition to processes that have
# the appdomain attribute.
if "appdomain" in alldomains[d].attributes:
alldomains[d].appdomain = True
appdomains.add(d)
def GetAllDomains(self):
for result in self.pol.QueryTypeAttribute("domain", True):
self.alldomains[result] = scontext()
def GetCoreDomains():
global alldomains
global coredomains
for d in alldomains:
domain = alldomains[d]
# TestCoredomainViolations will verify if coredomain was incorrectly
# applied.
if "coredomain" in domain.attributes:
domain.coredomain = True
coredomains.add(d)
# check whether domains are executed off of /system or /vendor
if d in coredomainAllowlist:
continue
# TODO(b/153112003): add checks to prevent app domains from being
# incorrectly labeled as coredomain. Apps don't have entrypoints as
# they're always dynamically transitioned to by zygote.
if d in appdomains:
continue
# TODO(b/153112747): need to handle cases where there is a dynamic
# transition OR there happens to be no context in AOSP files.
if not domain.entrypointpaths:
continue
def GetAppDomains(self):
for d in self.alldomains:
# The application of the "appdomain" attribute is trusted because core
# selinux policy contains neverallow rules that enforce that only zygote
# and runas spawned processes may transition to processes that have
# the appdomain attribute.
if "appdomain" in self.alldomains[d].attributes:
self.alldomains[d].appdomain = True
self.appdomains.add(d)
for path in domain.entrypointpaths:
vendor = any(MatchPathPrefix(path, prefix) for prefix in
["/vendor", "/odm"])
system = any(MatchPathPrefix(path, prefix) for prefix in
["/init", "/system_ext", "/product" ])
def GetCoreDomains(self):
for d in self.alldomains:
domain = self.alldomains[d]
# TestCoredomainViolations will verify if coredomain was incorrectly
# applied.
if "coredomain" in domain.attributes:
domain.coredomain = True
self.coredomains.add(d)
# check whether domains are executed off of /system or /vendor
if d in coredomainAllowlist:
continue
# TODO(b/153112003): add checks to prevent app domains from being
# incorrectly labeled as coredomain. Apps don't have entrypoints as
# they're always dynamically transitioned to by zygote.
if d in self.appdomains:
continue
# TODO(b/153112747): need to handle cases where there is a dynamic
# transition OR there happens to be no context in AOSP files.
if not domain.entrypointpaths:
continue
# only mark entrypoint as system if it is not in legacy /system/vendor
if MatchPathPrefix(path, "/system/vendor"):
vendor = True
elif MatchPathPrefix(path, "/system"):
system = True
for path in domain.entrypointpaths:
vendor = any(MatchPathPrefix(path, prefix) for prefix in
["/vendor", "/odm"])
system = any(MatchPathPrefix(path, prefix) for prefix in
["/init", "/system_ext", "/product" ])
if not vendor and not system:
domain.error += "Unrecognized entrypoint for " + d + " at " + path + "\n"
# only mark entrypoint as system if it is not in legacy /system/vendor
if MatchPathPrefix(path, "/system/vendor"):
vendor = True
elif MatchPathPrefix(path, "/system"):
system = True
domain.fromSystem = domain.fromSystem or system
domain.fromVendor = domain.fromVendor or vendor
if not vendor and not system:
domain.error += "Unrecognized entrypoint for " + d + " at " + path + "\n"
###
# Add the entrypoint type and path(s) to each domain.
#
def GetDomainEntrypoints(pol):
global alldomains
for x in pol.QueryExpandedTERule(tclass=set(["file"]), perms=set(["entrypoint"])):
if not x.sctx in alldomains:
continue
alldomains[x.sctx].entrypoints.append(str(x.tctx))
# postinstall_file represents a special case specific to A/B OTAs.
# Update_engine mounts a partition and relabels it postinstall_file.
# There is no file_contexts entry associated with postinstall_file
# so skip the lookup.
if x.tctx == "postinstall_file":
continue
entrypointpath = pol.QueryFc(x.tctx)
if not entrypointpath:
continue
alldomains[x.sctx].entrypointpaths.extend(entrypointpath)
###
# Get attributes associated with each domain
#
def GetAttributes(pol):
global alldomains
for domain in alldomains:
for result in pol.QueryTypeAttribute(domain, False):
alldomains[domain].attributes.add(result)
domain.fromSystem = domain.fromSystem or system
domain.fromVendor = domain.fromVendor or vendor
def GetAllTypes(pol, oldpol):
global alltypes
global oldalltypes
alltypes = pol.GetAllTypes(False)
oldalltypes = oldpol.GetAllTypes(False)
###
# Add the entrypoint type and path(s) to each domain.
#
def GetDomainEntrypoints(self):
for x in self.pol.QueryExpandedTERule(tclass=set(["file"]), perms=set(["entrypoint"])):
if not x.sctx in self.alldomains:
continue
self.alldomains[x.sctx].entrypoints.append(str(x.tctx))
# postinstall_file represents a special case specific to A/B OTAs.
# Update_engine mounts a partition and relabels it postinstall_file.
# There is no file_contexts entry associated with postinstall_file
# so skip the lookup.
if x.tctx == "postinstall_file":
continue
entrypointpath = self.pol.QueryFc(x.tctx)
if not entrypointpath:
continue
self.alldomains[x.sctx].entrypointpaths.extend(entrypointpath)
def setup(pol):
GetAllDomains(pol)
GetAttributes(pol)
GetDomainEntrypoints(pol)
GetAppDomains()
GetCoreDomains()
###
# Get attributes associated with each domain
#
def GetAttributes(self):
for domain in self.alldomains:
for result in self.pol.QueryTypeAttribute(domain, False):
self.alldomains[domain].attributes.add(result)
# setup for the policy compatibility tests
def compatSetup(pol, oldpol, mapping, types):
global compatMapping
global pubtypes
def setup(self, pol):
self.pol = pol
self.GetAllDomains()
self.GetAttributes()
self.GetDomainEntrypoints()
self.GetAppDomains()
self.GetCoreDomains()
GetAllTypes(pol, oldpol)
compatMapping = mapping
pubtypes = types
def GetAllTypes(self, basepol, oldpol):
self.alltypes = basepol.GetAllTypes(False)
self.oldalltypes = oldpol.GetAllTypes(False)
# setup for the policy compatibility tests
def compatSetup(self, basepol, oldpol, mapping, types):
self.GetAllTypes(basepol, oldpol)
self.compatMapping = mapping
self.pubtypes = types
def DomainsWithAttribute(self, attr):
domains = []
for domain in self.alldomains:
if attr in self.alldomains[domain].attributes:
domains.append(domain)
return domains
def PrintScontexts(self):
for d in sorted(self.alldomains.keys()):
sctx = self.alldomains[d]
print(d)
print("\tcoredomain="+str(sctx.coredomain))
print("\tappdomain="+str(sctx.appdomain))
print("\tfromSystem="+str(sctx.fromSystem))
print("\tfromVendor="+str(sctx.fromVendor))
print("\tattributes="+str(sctx.attributes))
print("\tentrypoints="+str(sctx.entrypoints))
print("\tentrypointpaths=")
if sctx.entrypointpaths is not None:
for path in sctx.entrypointpaths:
print("\t\t"+str(path))
def DomainsWithAttribute(attr):
global alldomains
domains = []
for domain in alldomains:
if attr in alldomains[domain].attributes:
domains.append(domain)
return domains
#############################################################
# Tests
#############################################################
def TestCoredomainViolations():
global alldomains
def TestCoredomainViolations(test_policy):
# verify that all domains launched from /system have the coredomain
# attribute
ret = ""
for d in alldomains:
domain = alldomains[d]
for d in test_policy.alldomains:
domain = test_policy.alldomains[d]
if domain.fromSystem and domain.fromVendor:
ret += "The following domain is system and vendor: " + d + "\n"
for domain in alldomains.values():
for domain in test_policy.alldomains.values():
ret += domain.error
violators = []
for d in alldomains:
domain = alldomains[d]
for d in test_policy.alldomains:
domain = test_policy.alldomains[d]
if domain.fromSystem and "coredomain" not in domain.attributes:
violators.append(d);
if len(violators) > 0:
@ -228,8 +222,8 @@ def TestCoredomainViolations():
# verify that all domains launched form /vendor do not have the coredomain
# attribute
violators = []
for d in alldomains:
domain = alldomains[d]
for d in test_policy.alldomains:
domain = test_policy.alldomains[d]
if domain.fromVendor and "coredomain" in domain.attributes:
violators.append(d)
if len(violators) > 0:
@ -243,17 +237,13 @@ def TestCoredomainViolations():
###
# Make sure that any new public type introduced in the new policy that was not
# present in the old policy has been recorded in the mapping file.
def TestNoUnmappedNewTypes():
global alltypes
global oldalltypes
global compatMapping
global pubtypes
newt = alltypes - oldalltypes
def TestNoUnmappedNewTypes(test_policy):
newt = test_policy.alltypes - test_policy.oldalltypes
ret = ""
violators = []
for n in newt:
if n in pubtypes and compatMapping.rTypeattributesets.get(n) is None:
if n in test_policy.pubtypes and test_policy.compatMapping.rTypeattributesets.get(n) is None:
violators.append(n)
if len(violators) > 0:
@ -270,16 +260,13 @@ def TestNoUnmappedNewTypes():
###
# Make sure that any public type removed in the current policy has its
# declaration added to the mapping file for use in non-platform policy
def TestNoUnmappedRmTypes():
global alltypes
global oldalltypes
global compatMapping
rmt = oldalltypes - alltypes
def TestNoUnmappedRmTypes(test_policy):
rmt = test_policy.oldalltypes - test_policy.alltypes
ret = ""
violators = []
for o in rmt:
if o in compatMapping.pubtypes and not o in compatMapping.types:
if o in test_policy.compatMapping.pubtypes and not o in test_policy.compatMapping.types:
violators.append(o)
if len(violators) > 0:
@ -292,34 +279,32 @@ def TestNoUnmappedRmTypes():
ret += "https://android-review.googlesource.com/c/platform/system/sepolicy/+/822743\n"
return ret
def TestTrebleCompatMapping():
ret = TestNoUnmappedNewTypes()
ret += TestNoUnmappedRmTypes()
def TestTrebleCompatMapping(test_policy):
ret = TestNoUnmappedNewTypes(test_policy)
ret += TestNoUnmappedRmTypes(test_policy)
return ret
def TestViolatorAttribute(attribute):
global FakeTreble
def TestViolatorAttribute(test_policy, attribute):
ret = ""
if FakeTreble:
if test_policy.FakeTreble:
return ret
violators = DomainsWithAttribute(attribute)
violators = test_policy.DomainsWithAttribute(attribute)
if len(violators) > 0:
ret += "SELinux: The following domains violate the Treble ban "
ret += "against use of the " + attribute + " attribute: "
ret += " ".join(str(x) for x in sorted(violators)) + "\n"
return ret
def TestViolatorAttributes():
def TestViolatorAttributes(test_policy):
ret = ""
ret += TestViolatorAttribute("socket_between_core_and_vendor_violators")
ret += TestViolatorAttribute("vendor_executes_system_violators")
ret += TestViolatorAttribute(test_policy, "socket_between_core_and_vendor_violators")
ret += TestViolatorAttribute(test_policy, "vendor_executes_system_violators")
return ret
# TODO move this to sepolicy_tests
def TestCoreDataTypeViolations():
global pol
return pol.AssertPathTypesDoNotHaveAttr(["/data/vendor/", "/data/vendor_ce/",
def TestCoreDataTypeViolations(test_policy):
return test_policy.pol.AssertPathTypesDoNotHaveAttr(["/data/vendor/", "/data/vendor_ce/",
"/data/vendor_de/"], [], "core_data_file_type")
###
@ -349,7 +334,7 @@ def do_main(libpath):
Args:
libpath: string, path to libsepolwrap.so
"""
global pol, FakeTreble
test_policy = TestPolicy()
usage = "treble_sepolicy_tests "
usage += "-f nonplat_file_contexts -f plat_file_contexts "
@ -402,27 +387,27 @@ def do_main(libpath):
oldpol = policy.Policy(options.oldpolicy, None, libpath)
mapping = mini_parser.MiniCilParser(options.mapping)
pubpol = mini_parser.MiniCilParser(options.base_pub_policy)
compatSetup(basepol, oldpol, mapping, pubpol.types)
test_policy.compatSetup(basepol, oldpol, mapping, pubpol.types)
if options.faketreble:
FakeTreble = True
test_policy.FakeTreble = True
pol = policy.Policy(options.policy, options.file_contexts, libpath)
setup(pol)
test_policy.setup(pol)
if DEBUG:
PrintScontexts()
test_policy.PrintScontexts()
results = ""
# If an individual test is not specified, run all tests.
if options.tests is None:
for t in Tests.values():
results += t()
results += t(test_policy)
else:
for tn in options.tests:
t = Tests.get(tn)
if t:
results += t()
results += t(test_policy)
else:
err = "Error: unknown test: " + tn + "\n"
err += "Available tests:\n"