Merge "Allow traversal over the trie structure"

This commit is contained in:
Paul Duffin 2022-03-16 10:54:45 +00:00 committed by Gerrit Code Review
commit be74a9fa01
2 changed files with 104 additions and 9 deletions

View file

@ -22,6 +22,19 @@ from itertools import chain
@dataclasses.dataclass()
class Node:
"""A node in the signature trie."""
# The type of the node.
#
# Leaf nodes are of type "member".
# Interior nodes can be either "package", or "class".
type: str
# The selector of the node.
#
# That is a string that can be used to select the node, e.g. in a pattern
# that is passed to InteriorNode.get_matching_rows().
selector: str
def values(self, selector):
"""Get the values from a set of selected nodes.
@ -48,6 +61,10 @@ class Node:
"""
raise NotImplementedError("Please Implement this method")
def child_nodes(self):
"""Get an iterable of the child nodes of this node."""
raise NotImplementedError("Please Implement this method")
# pylint: disable=line-too-long
@dataclasses.dataclass()
@ -173,22 +190,68 @@ class InteriorNode(Node):
element_type, _ = InteriorNode.split_element(element)
return element_type
def add(self, signature, value):
@staticmethod
def elements_to_selector(elements):
"""Compute a selector for a set of elements.
A selector uniquely identifies a specific Node in the trie. It is
essentially a prefix of a signature (without the leading L).
e.g. a trie containing "Ljava/lang/Object;->String()Ljava/lang/String;"
would contain nodes with the following selectors:
* "java"
* "java/lang"
* "java/lang/Object"
* "java/lang/Object;->String()Ljava/lang/String;"
"""
signature = ""
preceding_type = ""
for element in elements:
element_type, element_value = InteriorNode.split_element(element)
separator = ""
if element_type == "package":
separator = "/"
elif element_type == "class":
if preceding_type == "class":
separator = "$"
else:
separator = "/"
elif element_type == "wildcard":
separator = "/"
elif element_type == "member":
separator += ";->"
if signature:
signature += separator
signature += element_value
preceding_type = element_type
return signature
def add(self, signature, value, only_if_matches=False):
"""Associate the value with the specific signature.
:param signature: the member signature
:param value: the value to associated with the signature
:param only_if_matches: True if the value is added only if the signature
matches at least one of the existing top level packages.
:return: n/a
"""
# Split the signature into elements.
elements = self.signature_to_elements(signature)
# Find the Node associated with the deepest class.
node = self
for element in elements[:-1]:
for index, element in enumerate(elements[:-1]):
if element in node.nodes:
node = node.nodes[element]
elif only_if_matches and index == 0:
return
else:
next_node = InteriorNode()
selector = self.elements_to_selector(elements[0:index + 1])
next_node = InteriorNode(
type=InteriorNode.element_type(element), selector=selector)
node.nodes[element] = next_node
node = next_node
# Add a Leaf containing the value and associate it with the member
@ -201,7 +264,12 @@ class InteriorNode(Node):
"specific member")
if last_element in node.nodes:
raise Exception(f"Duplicate signature: {signature}")
node.nodes[last_element] = Leaf(value)
leaf = Leaf(
type=last_element_type,
selector=signature,
value=value,
)
node.nodes[last_element] = leaf
def get_matching_rows(self, pattern):
"""Get the values (plural) associated with the pattern.
@ -212,10 +280,6 @@ class InteriorNode(Node):
If the pattern is a class then this will return a list containing the
values associated with all members of that class.
If the pattern is a package then this will return a list containing the
values associated with all the members of all the classes in that
package and sub-packages.
If the pattern ends with "*" then the preceding part is treated as a
package and this will return a list containing the values associated
with all the members of all the classes in that package.
@ -261,6 +325,9 @@ class InteriorNode(Node):
if selector(key):
node.append_values(values, lambda x: True)
def child_nodes(self):
return self.nodes.values()
@dataclasses.dataclass()
class Leaf(Node):
@ -275,6 +342,9 @@ class Leaf(Node):
def append_values(self, values, selector):
values.append([self.value])
def child_nodes(self):
return []
def signature_trie():
return InteriorNode()
return InteriorNode(type="root", selector="")

View file

@ -27,6 +27,10 @@ class TestSignatureToElements(unittest.TestCase):
def signature_to_elements(signature):
return InteriorNode.signature_to_elements(signature)
@staticmethod
def elements_to_signature(elements):
return InteriorNode.elements_to_selector(elements)
def test_nested_inner_classes(self):
elements = [
("package", "java"),
@ -38,6 +42,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_basic_member(self):
elements = [
@ -48,6 +53,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "Ljava/lang/Object;->hashCode()I"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_double_dollar_class(self):
elements = [
@ -61,6 +67,7 @@ class TestSignatureToElements(unittest.TestCase):
signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0;" \
"-><init>(Ljava/lang/CharSequence;)V"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_no_member(self):
elements = [
@ -72,6 +79,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_wildcard(self):
elements = [
@ -81,6 +89,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "java/lang/*"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, self.elements_to_signature(elements))
def test_recursive_wildcard(self):
elements = [
@ -90,6 +99,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "java/lang/**"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, self.elements_to_signature(elements))
def test_no_packages_wildcard(self):
elements = [
@ -97,6 +107,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "*"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, self.elements_to_signature(elements))
def test_no_packages_recursive_wildcard(self):
elements = [
@ -104,6 +115,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "**"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, self.elements_to_signature(elements))
def test_invalid_no_class_or_wildcard(self):
signature = "java/lang"
@ -121,6 +133,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "Ljavax/crypto/extObjectInputStream"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_invalid_pattern_wildcard(self):
pattern = "Ljava/lang/Class*"
@ -200,6 +213,18 @@ Ljava/util/zip/ZipFile;-><clinit>()V
"Ljava/util/zip/ZipFile;-><clinit>()V",
])
def test_node_wildcard(self):
trie = self.read_trie()
node = list(trie.child_nodes())[0]
self.check_node_patterns(node, "**", [
"Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
"Ljava/lang/Character;->serialVersionUID:J",
"Ljava/lang/Object;->hashCode()I",
"Ljava/lang/Object;->toString()Ljava/lang/String;",
"Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
"Ljava/util/zip/ZipFile;-><clinit>()V",
])
# pylint: enable=line-too-long