# maximum-xor-of-two-non-overlapping-subtrees.py
# Time: O(nlogr), r is sum(values)
# Space: O(n)
# iterative dfs, trie, greedy
class Trie(object):
    """Binary trie over the low ``bit_length`` bits of non-negative integers.

    Each node is a plain dict keyed by a bit (0 or 1) that maps to the child
    node for that bit. Numbers are stored most-significant bit first.
    """

    def __init__(self, bit_length):
        """
        :type bit_length: int  -- how many low-order bits of each number are stored
        """
        self.__root = {}
        self.__bit_length = bit_length

    def insert(self, num):
        """Add the low ``bit_length`` bits of ``num`` to the trie.

        :type num: int
        :rtype: None
        """
        node = self.__root
        # ``range`` (rather than the Python 2-only ``xrange``) keeps this
        # runnable on both Python 2 and Python 3.
        for i in reversed(range(self.__bit_length)):
            curr = (num >> i) & 1
            if curr not in node:
                node[curr] = {}
            node = node[curr]

    def query(self, num):
        """Return the largest ``num ^ x`` over inserted ``x`` (low bits only).

        Returns -1 when the trie's root has no children, i.e. nothing has
        been stored yet.

        :type num: int
        :rtype: int
        """
        if not self.__root:
            return -1
        node, result = self.__root, 0
        for i in reversed(range(self.__bit_length)):
            curr = (num >> i) & 1
            if 1 ^ curr in node:
                # Greedy: taking the opposite bit sets bit i of the XOR.
                node = node[1 ^ curr]
                result |= 1 << i
            else:
                # Every stored path has the full length, so the same-bit
                # child exists whenever the opposite one does not.
                node = node[curr]
        return result
class Solution(object):
    """Iterative-DFS solution: subtree sums plus a binary trie of finished subtrees."""

    def maxXor(self, n, edges, values):
        """
        Return the maximum XOR of the sums of two non-overlapping subtrees of
        the tree rooted at node 0, or 0 when no such pair is found.

        The node count is taken from ``len(values)``; ``n`` is not read.

        :type n: int
        :type edges: List[List[int]]
        :type values: List[int]
        :rtype: int
        """
        def iter_dfs():
            """Compute the subtree sum of every node with an explicit stack."""
            lookup = [0] * len(values)
            stk = [(1, 0, -1)]
            while stk:
                step, u, p = stk.pop()
                if step == 1:
                    # Re-visit u (step 2) only after all children are done.
                    stk.append((2, u, p))
                    for v in adj[u]:
                        if v == p:
                            continue
                        stk.append((1, v, u))
                elif step == 2:
                    lookup[u] = values[u] + sum(lookup[v] for v in adj[u] if v != p)
            return lookup

        def iter_dfs2():
            """Walk the tree again, pairing each subtree sum with finished ones.

            A node's sum is inserted into the trie only after its whole
            subtree has been processed (step 3), so when a node is first
            visited (step 1) the trie holds no sum of its ancestors or
            descendants -- only subtrees disjoint from it.
            """
            trie = Trie(lookup[0].bit_length())
            result = [0]
            stk = [(1, (0, -1, result))]
            while stk:
                step, args = stk.pop()
                if step == 1:
                    u, p, ret = args
                    # query() returns -1 on an empty trie; clamp to 0.
                    ret[0] = max(trie.query(lookup[u]), 0)
                    stk.append((3, (u,)))
                    for v in adj[u]:
                        if v == p:
                            continue
                        new_ret = [0]
                        # Pushed before the child's visit so it runs after it
                        # and folds the child's best into the parent's.
                        stk.append((2, (new_ret, ret)))
                        stk.append((1, (v, u, new_ret)))
                elif step == 2:
                    new_ret, ret = args
                    ret[0] = max(ret[0], new_ret[0])
                elif step == 3:
                    u = args[0]
                    trie.insert(lookup[u])
            return result[0]

        # ``range`` instead of the Python 2-only ``xrange`` so this also runs
        # on Python 3.
        adj = [[] for _ in range(len(values))]
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)
        lookup = iter_dfs()
        return iter_dfs2()
# Time: O(nlogr), r is sum(values)
# Space: O(n)
# dfs, trie, greedy
class Trie(object):
    """Binary trie storing the low ``bit_length`` bits of non-negative integers.

    Nodes are dicts keyed by bit value (0 or 1); bits are consumed from the
    most significant stored bit down to bit 0.
    """

    def __init__(self, bit_length):
        """
        :type bit_length: int  -- number of low-order bits kept per number
        """
        self.__root = {}
        self.__bit_length = bit_length

    def insert(self, num):
        """Store the low ``bit_length`` bits of ``num``.

        :type num: int
        :rtype: None
        """
        node = self.__root
        # ``range`` works on Python 2 and 3; ``xrange`` exists only on Python 2.
        for i in reversed(range(self.__bit_length)):
            curr = (num >> i) & 1
            if curr not in node:
                node[curr] = {}
            node = node[curr]

    def query(self, num):
        """Return the maximum of ``num ^ x`` over stored ``x`` (low bits only),
        or -1 if the root has no children (nothing stored).

        :type num: int
        :rtype: int
        """
        if not self.__root:
            return -1
        node, result = self.__root, 0
        for i in reversed(range(self.__bit_length)):
            curr = (num >> i) & 1
            if 1 ^ curr in node:
                # Prefer the opposite bit: it turns on bit i of the result.
                node = node[1 ^ curr]
                result |= 1 << i
            else:
                # All stored paths are full length, so this child exists.
                node = node[curr]
        return result
class Solution2(object):
    """Recursive-DFS solution: subtree sums plus a binary trie of finished subtrees."""

    def maxXor(self, n, edges, values):
        """
        Return the maximum XOR of the sums of two non-overlapping subtrees of
        the tree rooted at node 0, or 0 when no such pair is found.

        The node count is taken from ``len(values)``; ``n`` is not read.

        NOTE: both helpers recurse once per tree level, so a tree deeper than
        the interpreter's recursion limit raises RecursionError.

        :type n: int
        :type edges: List[List[int]]
        :type values: List[int]
        :rtype: int
        """
        def dfs(u, p):
            """Fill ``lookup[u]`` with the sum of u's subtree and return it."""
            lookup[u] = values[u] + sum(dfs(v, u) for v in adj[u] if v != p)
            return lookup[u]

        def dfs2(u, p):
            """Return the best XOR found for any pair whose later-visited
            subtree lies inside u's subtree."""
            # The trie holds only fully processed subtrees, which cannot be
            # ancestors or descendants of u; query() returns -1 when empty.
            result = max(trie.query(lookup[u]), 0)
            for v in adj[u]:
                if v == p:
                    continue
                result = max(result, dfs2(v, u))
            # Insert u only after its whole subtree is done so that none of
            # its descendants can be paired with it.
            trie.insert(lookup[u])
            return result

        # ``range`` instead of the Python 2-only ``xrange`` so this also runs
        # on Python 3.
        adj = [[] for _ in range(len(values))]
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)
        lookup = [0] * len(values)
        dfs(0, -1)
        trie = Trie(lookup[0].bit_length())
        return dfs2(0, -1)