-
-
Notifications
You must be signed in to change notification settings - Fork 48
Expand file tree
/
Copy path0834-sum-of-distances-in-tree.py
More file actions
38 lines (29 loc) · 1008 Bytes
/
0834-sum-of-distances-in-tree.py
File metadata and controls
38 lines (29 loc) · 1008 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# time complexity: O(n)
# space complexity: O(n)
import collections
from typing import List
class Solution:
    """Solution to LeetCode 834, "Sum of Distances in Tree"."""

    def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
        """Return, for every node, the sum of its distances to all other nodes.

        Uses the classic two-pass "rerooting" technique, implemented
        iteratively so deep (e.g. path-shaped) trees do not hit Python's
        recursion limit:

        1. Root the tree at node 0. Walking nodes children-before-parents,
           compute ``count[v]`` (size of v's subtree) and ``ans[v]`` (sum of
           distances from v to the nodes in its own subtree). After this pass
           ``ans[0]`` is the full answer for the root.
        2. Walking nodes parents-before-children, move the root from a parent
           p to its child c: the ``count[c]`` nodes in c's subtree get one step
           closer and the other ``n - count[c]`` nodes get one step farther, so
           ``ans[c] = ans[p] - count[c] + (n - count[c])``.

        Args:
            n: Number of nodes, labelled ``0 .. n-1``.
            edges: ``n - 1`` undirected edges ``[u, v]`` forming a tree.

        Returns:
            A list of length ``n`` whose i-th entry is the sum of distances
            from node i to every other node. ``[]`` when ``n == 0``.

        Runs in O(n) time and O(n) extra space.
        """
        if n == 0:
            return []

        # Sets (as in a dedup'd adjacency list) tolerate a repeated edge.
        graph: List[set] = [set() for _ in range(n)]
        for u, v in edges:
            graph[u].add(v)
            graph[v].add(u)

        # Iterative DFS from node 0: record each node's parent and a preorder
        # in which every node appears after its parent.
        parent = [-1] * n
        order: List[int] = []
        stack = [0]
        while stack:
            node = stack.pop()
            order.append(node)
            for nxt in graph[node]:
                if nxt != parent[node]:
                    parent[nxt] = node
                    stack.append(nxt)

        count = [1] * n  # each subtree contains at least its own root
        ans = [0] * n

        # Pass 1 (post-order): reversed preorder visits all children before
        # their parent, so a child's count/ans are final when folded upward.
        for node in reversed(order):
            p = parent[node]
            if p != -1:
                count[p] += count[node]
                # Every node in the child's subtree is one edge farther from p.
                ans[p] += ans[node] + count[node]

        # Pass 2 (pre-order): a parent's answer is final before its children's.
        for node in order[1:]:
            ans[node] = ans[parent[node]] - count[node] + (n - count[node])

        return ans
if __name__ == "__main__":
    # Demo on the LeetCode example; guarded so importing this module has no
    # side effects. Expected output: [8, 12, 6, 10, 10, 10]
    n = 6
    edges = [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]
    print(Solution().sumOfDistancesInTree(n, edges))