
Sum of Distances in Tree

Number: 863

Difficulty: Hard

Paid? No

Companies: Google, Microsoft, PhonePe, TikTok, MathWorks, Amazon


Problem Description

Given an undirected tree with n nodes labeled from 0 to n - 1 (and exactly n - 1 edges), compute for each node the sum of distances to all other nodes in the tree.
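For reference, a direct approach runs a breadth-first search from every node and sums the distances it finds. That costs O(n^2) time, so this is only a sketch for checking faster solutions on small inputs. The function name `sum_of_distances_brute_force` and the six-node example tree (edges 0-1, 0-2, 2-3, 2-4, 2-5) are illustrative assumptions, not part of the original problem statement.

```python
from collections import deque
from typing import List


def sum_of_distances_brute_force(n: int, edges: List[List[int]]) -> List[int]:
    """Hypothetical O(n^2) reference: one BFS from every node."""
    tree: List[List[int]] = [[] for _ in range(n)]
    for u, v in edges:
        tree[u].append(v)
        tree[v].append(u)

    answer = []
    for source in range(n):
        dist = [-1] * n  # -1 marks "not yet visited"
        dist[source] = 0
        queue = deque([source])
        while queue:
            node = queue.popleft()
            for neighbor in tree[node]:
                if dist[neighbor] == -1:
                    dist[neighbor] = dist[node] + 1
                    queue.append(neighbor)
        answer.append(sum(dist))  # every node is reachable in a tree
    return answer


print(sum_of_distances_brute_force(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]))
# [8, 12, 6, 10, 10, 10]
```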


Key Insights

  • Use tree dynamic programming with two DFS traversals.
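  • The first DFS (postorder) computes two values for each node: the number of nodes in its subtree, and the sum of distances from that node to every node in its subtree.
  • The second DFS (preorder) re-roots the answer from each parent to its children. Moving from a parent to a child changes the sum by n - 2 * count[child], because the nodes in the child's subtree each get one step closer and every other node gets one step farther.
  • Each node and edge is processed a constant number of times, which gives an O(n) solution. Running a separate search from every node would cost O(n^2).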

Space and Time Complexity

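Time Complexity: O(n), since each DFS visits every node once and scans every edge twice (once from each endpoint).
Space Complexity: O(n) (due to the tree adjacency list, the count and res arrays, and the recursion stack)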


Solution

We use a two-step DFS approach:

  1. In the first DFS (postorder), for each node we compute:
    • count[node]: number of nodes in the subtree (including itself).
    • res[node]: sum of distances from node to all nodes in its subtree.
  2. In the second DFS (preorder), we compute each child's final result from its parent's. The parent's result is already final at that point, because preorder visits a parent before its children:
    • When moving from parent to child, update res[child] = res[parent] - count[child] + (n - count[child]).
    • This works because the count[child] nodes in the child's subtree each get one step closer (subtract count[child]). The remaining n - count[child] nodes each get one step farther (add n - count[child]).
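The update rule follows from splitting all nodes by whether they lie in the child's subtree. The derivation uses notation introduced here, not in the original:

- p is the parent and c is the child.
- S_c is the set of nodes in c's subtree, so |S_c| = count[c].
- d(u, v) is the number of edges on the path between u and v.
- res[p] is p's final answer.

```latex
\begin{aligned}
\mathrm{res}[c] &= \sum_{v \in S_c} \bigl(d(p, v) - 1\bigr) + \sum_{v \notin S_c} \bigl(d(p, v) + 1\bigr) \\
&= \mathrm{res}[p] - \mathrm{count}[c] + \bigl(n - \mathrm{count}[c]\bigr) \\
&= \mathrm{res}[p] + n - 2\,\mathrm{count}[c]
\end{aligned}
```

For instance, take the tree with edges 0-1, 0-2, 2-3, 2-4, 2-5, rooted at node 0. There res[0] = 8 and count[2] = 4, so res[2] = 8 + 6 - 2 * 4 = 6.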

The main data structures are an adjacency list for the tree and two arrays (or vectors) for count and res. Both DFS traversals are written recursively. On a path-shaped tree the recursion depth reaches n, which can exceed Python's default recursion limit of 1000.
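As a worked illustration, take a hypothetical six-node tree with edges 0-1, 0-2, 2-3, 2-4, 2-5, rooted at node 0. The `count` and `res` arrays below are the values the first DFS produces for that tree, worked out by hand. For example, node 2's subtree is {2, 3, 4, 5}, and res[0] = (0 + 1) + (3 + 4) = 8. The loop then applies the second-DFS update in preorder. The `parent` array and `preorder` list are assumptions introduced only for this trace.

```python
# Hypothetical example tree: edges 0-1, 0-2, 2-3, 2-4, 2-5, rooted at node 0.
n = 6
parent = [-1, 0, 0, 2, 2, 2]    # parent[v] in the rooted tree (-1 for the root)
preorder = [0, 1, 2, 3, 4, 5]   # every parent appears before its children

# State after the first (postorder) DFS:
count = [6, 1, 4, 1, 1, 1]      # subtree sizes
res = [8, 0, 3, 0, 0, 0]        # distances to nodes inside each subtree only

# Second (preorder) DFS: re-root from each parent to its child.
for child in preorder[1:]:
    p = parent[child]
    res[child] = res[p] - count[child] + (n - count[child])

print(res)  # [8, 12, 6, 10, 10, 10]
```

Only res[0] is a complete answer after the first pass. Every other entry becomes complete once its parent's entry is final.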


Code Solutions

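# Python solution
import sys
from typing import List


class Solution:
    def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
        # Both DFS helpers recurse once per tree level, so a path-shaped tree
        # needs a recursion depth close to n; make sure Python's limit allows it.
        sys.setrecursionlimit(max(sys.getrecursionlimit(), n + 100))

        # Build tree as an adjacency list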
        tree = [[] for _ in range(n)]
        for u, v in edges:
            tree[u].append(v)
            tree[v].append(u)
        
        # Initialize result and count arrays
        res = [0] * n
        count = [1] * n  # each node counts as 1
        
        # Postorder DFS: compute count and res for subtrees
        def dfs(node: int, parent: int):
            for child in tree[node]:
                if child == parent:
                    continue
                dfs(child, node)
                count[node] += count[child]
                res[node] += res[child] + count[child]
        dfs(0, -1)
        
        # Preorder DFS: update results using parent's information
        def dfs2(node: int, parent: int):
            for child in tree[node]:
                if child == parent:
                    continue
                # Update result for child using parent's result formula
                res[child] = res[node] - count[child] + (n - count[child])
                dfs2(child, node)
        dfs2(0, -1)
        
        return res
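Both passes in the solution above recurse once per tree level, so a path-shaped tree drives the recursion depth to n. The sketch below runs the same two-pass algorithm without recursion, assuming the same input format. It records one DFS order from node 0. It then sweeps that order backwards for the postorder accumulation and forwards for the preorder re-rooting. The standalone function name `sum_of_distances_iterative` is hypothetical.

```python
from typing import List


def sum_of_distances_iterative(n: int, edges: List[List[int]]) -> List[int]:
    tree: List[List[int]] = [[] for _ in range(n)]
    for u, v in edges:
        tree[u].append(v)
        tree[v].append(u)

    # Record a DFS preorder from node 0 and each node's parent.
    parent = [-1] * n
    order = []
    visited = [False] * n
    visited[0] = True
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for nxt in tree[node]:
            if not visited[nxt]:
                visited[nxt] = True
                parent[nxt] = node
                stack.append(nxt)

    count = [1] * n  # subtree sizes; each node counts itself
    res = [0] * n
    # Pass 1 (postorder): reversed preorder handles children before parents.
    for node in reversed(order):
        p = parent[node]
        if p != -1:
            count[p] += count[node]
            res[p] += res[node] + count[node]

    # Pass 2 (preorder): each parent's res is final before its children.
    for node in order:
        p = parent[node]
        if p != -1:
            res[node] = res[p] - count[node] + (n - count[node])
    return res


print(sum_of_distances_iterative(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]))
# [8, 12, 6, 10, 10, 10]
```

Storing the traversal order once avoids a second search. It also makes the "parent before child" requirement of the second pass explicit.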