def changeNodes(self, existingTree, newTree) -> int:
    """Count how many nodes differ between two N-ary trees.

    Children of matching nodes are paired up by their ``key``.  A paired
    node contributes 1 when its ``key`` or ``val`` differs from its
    counterpart.  A subtree that exists on only one side contributes every
    node in it, as measured by ``self.count``.

    Args:
        existingTree: root of the original tree, or a falsy value for "no tree".
        newTree: root of the updated tree, or a falsy value for "no tree".

    Returns:
        Number of changed, removed and added nodes.
    """
    # Guard clauses: no tree at all, or one side missing entirely.
    if not (existingTree or newTree):
        return 0
    if not (existingTree and newTree):
        return self.count(existingTree) + self.count(newTree)

    # NOTE(review): a differing root *key* costs only 1 here and the
    # children are still compared by key; below the root, paired nodes
    # always share a key, so only the root can hit this case.
    same_node = existingTree.key == newTree.key and existingTree.val == newTree.val
    total = 0 if same_node else 1

    old_children = self.get_children_dict(existingTree.children)
    new_children = self.get_children_dict(newTree.children)

    # Classify every child key seen on either side.
    for key in old_children.keys() | new_children.keys():
        if key not in new_children:
            # Subtree removed in the new tree: all of its nodes changed.
            total += self.count(old_children[key])
        elif key not in old_children:
            # Subtree added in the new tree: all of its nodes changed.
            total += self.count(new_children[key])
        else:
            # Present on both sides: recurse to compare the pair.
            total += self.changeNodes(old_children[key], new_children[key])
    return total
def count(self, root):
    """Return the number of nodes in the subtree rooted at ``root``.

    A falsy ``root`` (e.g. ``None``) is treated as an empty subtree and
    yields 0; falsy entries inside ``children`` are skipped the same way.
    """
    if not root:
        return 0
    # This node plus the sizes of all of its child subtrees.
    return 1 + sum(self.count(child) for child in root.children)
def get_children_dict(self, children):
    """Index ``children`` by their ``key`` attribute.

    Returns a new dict mapping each ``child.key`` to that child.  When
    several children share a key, the last one in iteration order wins.
    """
    return {node.key: node for node in children}