from typing import Dict, List, Optional, Any
import copy
import uuid
[docs]
class TreeNode(dict):
"""
A tree node class. This class stores a nested
representation of the tree. Each node is a TreeNode object, and if a node
is a child of another node, it is stored in the parent node's `children`
attribute.
"""
[docs]
@staticmethod
def from_dict(data: Dict) -> "TreeNode":
"""
Create a TreeNode from a dictionary.
:param data: The dictionary to convert to a TreeNode.
:return: A TreeNode object.
"""
def _from_dict(data, parent):
node = TreeNode(parent=parent, payload=None, name=data.pop("name", None))
node.payload = data.pop("payload", {})
for k, v in data.items():
if k == "children":
for child in v:
_from_dict(child, node)
else:
node.payload[k] = v
return node
return _from_dict(copy.deepcopy(data), None)
[docs]
def clone(self) -> "TreeNode":
"""
Clone the tree node (sub-tree) rooted at the current node.
:return: A new TreeNode object with the same data as the current node.
"""
def _clone(node, parent):
new_node = TreeNode(
parent=parent, name=node.name, payload=copy.deepcopy(node.payload)
)
for child in node.children:
_clone(child, new_node)
return new_node
return _clone(self, None)
def __init__(
self,
parent: Optional["TreeNode"] = None,
name: Optional[str] = None,
payload: Optional[Any] = None,
*args,
**kwargs,
):
"""
Initialize a TreeNode. The parent of the node is set to the given parent
node. If the parent is None, the node is the root of the tree. The name
of the node is set to the given name. If the name is None, a random name
is generated. The payload of the node is any additional arguments passed
to the constructor.
:param parent: The parent node of the current node. Default is None.
:param name: The name of the node. Default is None, in which case a
random name is generated.
:param payload: The payload of the node. Default is None.
:param args: Additional arguments to pass to the payload.
:param kwargs: Additional keyword arguments to pass to the payload.
"""
if name is None:
name = str(uuid.uuid4())
self.name = name
if parent is not None and not isinstance(parent, TreeNode):
raise ValueError("Parent must be a TreeNode object")
self.children = []
self._parent = None
self.parent = parent
if payload is not None:
self.payload = payload
elif args or kwargs:
self.payload = dict(*args, **kwargs)
else:
self.payload = None
@property
def parent(self) -> Optional["TreeNode"]:
"""
Get the parent of the node.
:return: The parent of the node.
"""
return self._parent
@parent.setter
def parent(self, parent: Optional["TreeNode"]) -> None:
"""
Set the parent of the node.
:param parent: The new parent of the node.
"""
if parent is not None and not isinstance(parent, TreeNode):
raise ValueError("Parent must be a TreeNode object")
# remove the node from the parent's children
if self._parent is not None:
self._parent.children.remove(self)
# self._parent.children = [child for child in self._parent.children if child != self]
self._parent = parent
# update parent's children
if parent is not None:
parent.children.append(self)
@property
def root(self) -> "TreeNode":
"""
Get the root of the tree.
:return: The root node of the tree.
"""
node = self
while node.parent is not None:
node = node.parent
return node
[docs]
def nodes(self) -> List["TreeNode"]:
"""
Get all the nodes in the current sub-tree.
:return: A list of all the nodes in the current sub-tree.
"""
nodes = []
for child in self.children:
nodes.extend(child.nodes())
nodes.append(self)
return nodes
[docs]
def subtree(self, name: str) -> "TreeNode":
"""
Get the subtree rooted at the node with the given name. This is not
a view, but a new tree rooted at the node with the given name. This
is different from the `node` method, which just changes the current
node position. It's also different from the `subtree` method in the
`FlatForestNode` class, which returns a view of the tree.
:param name: The name of the node.
:return: The subtree rooted at the node with the given name.
"""
from copy import deepcopy
node = deepcopy(self.node(name))
node.parent = None
return node
[docs]
def node(self, name: str) -> "TreeNode":
"""
Get the node with the given name in the current sub-tree. The sub-tree
remains the same, we just change the current node position. If the name
is not found, raise a KeyError.
:param name: The name of the node.
:return: The node with the given name.
"""
def _descend(node, name):
if node.name == name:
return node
for child in node.children:
result = _descend(child, name)
if result is not None:
return result
return None
def _ascend(node, name):
if node.name == name:
return node
if node.parent is not None:
return _ascend(node.parent, name)
return None
asc_node = _ascend(self, name)
if asc_node is not None:
return asc_node
dsc_node = _descend(self, name)
if dsc_node is not None:
return dsc_node
raise KeyError(f"Node with name {name} not found")
[docs]
def add_child(
self, name: Optional[str] = None, payload: Optional[Any] = None, *args, **kwargs
) -> "TreeNode":
"""
Add a child node to the tree. Just invokes `__init__`. See `__init__` for
details.
"""
return TreeNode(parent=self, name=name, payload=payload, *args, **kwargs)
def __repr__(self) -> str:
return self.__str__()
def __str__(self) -> str:
result = f"TreeNode(name={self.name}"
if self._parent is not None:
result += f", parent={self.parent.name}"
result += f", root={self.root.name}"
result += f", payload={self.payload}"
result += f", len(children)={len(self.children)})"
return result
[docs]
@staticmethod
def is_valid(data) -> bool:
"""
Check if the given data is a valid TreeNode data.
:param data: The data to check.
:return: True if the data is a valid TreeNode, False otherwise.
"""
if not isinstance(data, dict):
return False
if "children" in data:
if not isinstance(data["children"], list):
return False
for child in data["children"]:
if not TreeNode.is_valid(child):
return False
return True
[docs]
def to_dict(self):
"""
Convert the subtree rooted at `node` to a dictionary.
:return: A dictionary representation of the subtree.
"""
def _convert(node):
node_dict = {}
node_dict["name"] = node.name
node_dict["payload"] = node.payload
node_dict["children"] = [_convert(child) for child in node.children]
return node_dict
return _convert(self)
def __eq__(self, other) -> bool:
"""
Check if the current node is equal to the given node.
:param other: The other node to compare with.
:return: True if the nodes are equal, False otherwise.
"""
if not isinstance(other, TreeNode):
return False
return hash(self) == hash(other)
def __hash__(self) -> int:
"""
Compute the hash of the current node.
:return: The hash of the node.
"""
return id(self)
def __contains__(self, key) -> bool:
"""
Check if the node's payload contains the given key.
:param key: The key to check for.
:return: True if the key is present in the payload, False otherwise.
"""
return key in self.payload