Merge "Allow traversal over the trie structure"
This commit is contained in:
commit
be74a9fa01
2 changed files with 104 additions and 9 deletions
|
@ -22,6 +22,19 @@ from itertools import chain
|
|||
|
||||
@dataclasses.dataclass()
|
||||
class Node:
|
||||
"""A node in the signature trie."""
|
||||
|
||||
# The type of the node.
|
||||
#
|
||||
# Leaf nodes are of type "member".
|
||||
# Interior nodes can be either "package", or "class".
|
||||
type: str
|
||||
|
||||
# The selector of the node.
|
||||
#
|
||||
# That is a string that can be used to select the node, e.g. in a pattern
|
||||
# that is passed to InteriorNode.get_matching_rows().
|
||||
selector: str
|
||||
|
||||
def values(self, selector):
|
||||
"""Get the values from a set of selected nodes.
|
||||
|
@ -48,6 +61,10 @@ class Node:
|
|||
"""
|
||||
raise NotImplementedError("Please Implement this method")
|
||||
|
||||
def child_nodes(self):
    """Get an iterable of the child nodes of this node.

    This base implementation is only a placeholder; it always raises.

    :return: never returns from this implementation.
    :raises NotImplementedError: always.
    """
    raise NotImplementedError("Please Implement this method")
|
||||
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
@dataclasses.dataclass()
|
||||
|
@ -173,22 +190,68 @@ class InteriorNode(Node):
|
|||
element_type, _ = InteriorNode.split_element(element)
|
||||
return element_type
|
||||
|
||||
def add(self, signature, value):
|
||||
@staticmethod
def elements_to_selector(elements):
    """Compute a selector for a set of elements.

    A selector uniquely identifies a specific Node in the trie. It is
    essentially a prefix of a signature (without the leading L).

    e.g. a trie containing "Ljava/lang/Object;->String()Ljava/lang/String;"
    would contain nodes with the following selectors:
    * "java"
    * "java/lang"
    * "java/lang/Object"
    * "java/lang/Object;->String()Ljava/lang/String;"

    :param elements: an iterable of elements, each of which
        InteriorNode.split_element() splits into a (type, value) pair.
    :return: the selector string; the empty string if there are no
        elements.
    """
    selector = ""
    preceding_type = ""
    for element in elements:
        element_type, element_value = InteriorNode.split_element(element)

        # Choose the text that joins this element to what precedes it.
        if element_type in ("package", "wildcard"):
            separator = "/"
        elif element_type == "class":
            # A class that directly follows another class is joined to it
            # with "$"; otherwise it follows a package.
            separator = "$" if preceding_type == "class" else "/"
        elif element_type == "member":
            separator = ";->"
        else:
            # Any other element type is appended without a separator.
            separator = ""

        # Nothing is inserted while the selector is still empty, so the
        # result never starts with a separator.
        if selector:
            selector += separator
        selector += element_value

        preceding_type = element_type

    return selector
|
||||
|
||||
def add(self, signature, value, only_if_matches=False):
|
||||
"""Associate the value with the specific signature.
|
||||
|
||||
:param signature: the member signature
|
||||
:param value: the value to associate with the signature
|
||||
:param only_if_matches: True if the value is added only if the signature
|
||||
matches at least one of the existing top level packages.
|
||||
:return: n/a
|
||||
"""
|
||||
# Split the signature into elements.
|
||||
elements = self.signature_to_elements(signature)
|
||||
# Find the Node associated with the deepest class.
|
||||
node = self
|
||||
for element in elements[:-1]:
|
||||
for index, element in enumerate(elements[:-1]):
|
||||
if element in node.nodes:
|
||||
node = node.nodes[element]
|
||||
elif only_if_matches and index == 0:
|
||||
return
|
||||
else:
|
||||
next_node = InteriorNode()
|
||||
selector = self.elements_to_selector(elements[0:index + 1])
|
||||
next_node = InteriorNode(
|
||||
type=InteriorNode.element_type(element), selector=selector)
|
||||
node.nodes[element] = next_node
|
||||
node = next_node
|
||||
# Add a Leaf containing the value and associate it with the member
|
||||
|
@ -201,7 +264,12 @@ class InteriorNode(Node):
|
|||
"specific member")
|
||||
if last_element in node.nodes:
|
||||
raise Exception(f"Duplicate signature: {signature}")
|
||||
node.nodes[last_element] = Leaf(value)
|
||||
leaf = Leaf(
|
||||
type=last_element_type,
|
||||
selector=signature,
|
||||
value=value,
|
||||
)
|
||||
node.nodes[last_element] = leaf
|
||||
|
||||
def get_matching_rows(self, pattern):
|
||||
"""Get the values (plural) associated with the pattern.
|
||||
|
@ -212,10 +280,6 @@ class InteriorNode(Node):
|
|||
If the pattern is a class then this will return a list containing the
|
||||
values associated with all members of that class.
|
||||
|
||||
If the pattern is a package then this will return a list containing the
|
||||
values associated with all the members of all the classes in that
|
||||
package and sub-packages.
|
||||
|
||||
If the pattern ends with "*" then the preceding part is treated as a
|
||||
package and this will return a list containing the values associated
|
||||
with all the members of all the classes in that package.
|
||||
|
@ -261,6 +325,9 @@ class InteriorNode(Node):
|
|||
if selector(key):
|
||||
node.append_values(values, lambda x: True)
|
||||
|
||||
def child_nodes(self):
    """Get an iterable of the child nodes of this node.

    :return: the values view of this node's `nodes` mapping.
    """
    return self.nodes.values()
|
||||
|
||||
|
||||
@dataclasses.dataclass()
|
||||
class Leaf(Node):
|
||||
|
@ -275,6 +342,9 @@ class Leaf(Node):
|
|||
def append_values(self, values, selector):
|
||||
values.append([self.value])
|
||||
|
||||
def child_nodes(self):
    """Get an iterable of the child nodes of this node.

    :return: always an empty list.
    """
    return []
|
||||
|
||||
|
||||
def signature_trie():
    """Create the root node of a new signature trie.

    :return: an InteriorNode of type "root" whose selector is the empty
        string.
    """
    return InteriorNode(type="root", selector="")
|
||||
|
|
|
@ -27,6 +27,10 @@ class TestSignatureToElements(unittest.TestCase):
|
|||
def signature_to_elements(signature):
|
||||
return InteriorNode.signature_to_elements(signature)
|
||||
|
||||
@staticmethod
def elements_to_signature(elements):
    """Convert elements back into a selector string.

    Thin wrapper around InteriorNode.elements_to_selector() so the tests can
    check the round trip from signature to elements and back.

    :param elements: the elements to convert.
    :return: whatever InteriorNode.elements_to_selector() returns for them.
    """
    return InteriorNode.elements_to_selector(elements)
|
||||
|
||||
def test_nested_inner_classes(self):
|
||||
elements = [
|
||||
("package", "java"),
|
||||
|
@ -38,6 +42,7 @@ class TestSignatureToElements(unittest.TestCase):
|
|||
]
|
||||
signature = "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V"
|
||||
self.assertEqual(elements, self.signature_to_elements(signature))
|
||||
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
|
||||
|
||||
def test_basic_member(self):
|
||||
elements = [
|
||||
|
@ -48,6 +53,7 @@ class TestSignatureToElements(unittest.TestCase):
|
|||
]
|
||||
signature = "Ljava/lang/Object;->hashCode()I"
|
||||
self.assertEqual(elements, self.signature_to_elements(signature))
|
||||
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
|
||||
|
||||
def test_double_dollar_class(self):
|
||||
elements = [
|
||||
|
@ -61,6 +67,7 @@ class TestSignatureToElements(unittest.TestCase):
|
|||
signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0;" \
|
||||
"-><init>(Ljava/lang/CharSequence;)V"
|
||||
self.assertEqual(elements, self.signature_to_elements(signature))
|
||||
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
|
||||
|
||||
def test_no_member(self):
|
||||
elements = [
|
||||
|
@ -72,6 +79,7 @@ class TestSignatureToElements(unittest.TestCase):
|
|||
]
|
||||
signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0"
|
||||
self.assertEqual(elements, self.signature_to_elements(signature))
|
||||
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
|
||||
|
||||
def test_wildcard(self):
|
||||
elements = [
|
||||
|
@ -81,6 +89,7 @@ class TestSignatureToElements(unittest.TestCase):
|
|||
]
|
||||
signature = "java/lang/*"
|
||||
self.assertEqual(elements, self.signature_to_elements(signature))
|
||||
self.assertEqual(signature, self.elements_to_signature(elements))
|
||||
|
||||
def test_recursive_wildcard(self):
|
||||
elements = [
|
||||
|
@ -90,6 +99,7 @@ class TestSignatureToElements(unittest.TestCase):
|
|||
]
|
||||
signature = "java/lang/**"
|
||||
self.assertEqual(elements, self.signature_to_elements(signature))
|
||||
self.assertEqual(signature, self.elements_to_signature(elements))
|
||||
|
||||
def test_no_packages_wildcard(self):
|
||||
elements = [
|
||||
|
@ -97,6 +107,7 @@ class TestSignatureToElements(unittest.TestCase):
|
|||
]
|
||||
signature = "*"
|
||||
self.assertEqual(elements, self.signature_to_elements(signature))
|
||||
self.assertEqual(signature, self.elements_to_signature(elements))
|
||||
|
||||
def test_no_packages_recursive_wildcard(self):
|
||||
elements = [
|
||||
|
@ -104,6 +115,7 @@ class TestSignatureToElements(unittest.TestCase):
|
|||
]
|
||||
signature = "**"
|
||||
self.assertEqual(elements, self.signature_to_elements(signature))
|
||||
self.assertEqual(signature, self.elements_to_signature(elements))
|
||||
|
||||
def test_invalid_no_class_or_wildcard(self):
|
||||
signature = "java/lang"
|
||||
|
@ -121,6 +133,7 @@ class TestSignatureToElements(unittest.TestCase):
|
|||
]
|
||||
signature = "Ljavax/crypto/extObjectInputStream"
|
||||
self.assertEqual(elements, self.signature_to_elements(signature))
|
||||
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
|
||||
|
||||
def test_invalid_pattern_wildcard(self):
|
||||
pattern = "Ljava/lang/Class*"
|
||||
|
@ -200,6 +213,18 @@ Ljava/util/zip/ZipFile;-><clinit>()V
|
|||
"Ljava/util/zip/ZipFile;-><clinit>()V",
|
||||
])
|
||||
|
||||
def test_node_wildcard(self):
    """Check that "**" can be matched starting from a child of the root."""
    trie = self.read_trie()
    # Only the first child is needed, so take it straight from the iterable
    # rather than building a list of all the children.
    node = next(iter(trie.child_nodes()))
    self.check_node_patterns(node, "**", [
        "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
        "Ljava/lang/Character;->serialVersionUID:J",
        "Ljava/lang/Object;->hashCode()I",
        "Ljava/lang/Object;->toString()Ljava/lang/String;",
        "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
        "Ljava/util/zip/ZipFile;-><clinit>()V",
    ])
|
||||
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue