LC834. Sum of Distances in Tree

An undirected, connected tree with N nodes labelled 0...N-1 and N-1 edges is given.

The ith edge connects nodes edges[i][0] and edges[i][1] together.

Return a list ans, where ans[i] is the sum of the distances between node i and all other nodes.

Example 1:

Input: N = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
Output: [8,12,6,10,10,10]
Explanation: 
Here is a diagram of the given tree:
  0
 / \
1   2
   /|\
  3 4 5
We can see that dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
equals 1 + 1 + 2 + 2 + 2 = 8.  Hence, answer[0] = 8, and so on.

Note: 1 <= N <= 10000

Approach:

The above tree will give the following distance matrix:

0: 0,1,1,2,2,2
1: 1,0,2,3,3,3
2: 1,2,0,1,1,1
3: 2,3,1,0,2,2
4: 2,3,1,2,0,2
5: 2,3,1,2,2,0
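As a cross-check on the matrix, here is a minimal sketch that builds it with one BFS per node. The function names bfsDistances and distanceMatrix are made up for this illustration and are not part of the LeetCode interface.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Distances from src to every node of an unweighted tree, found by BFS.
std::vector<int> bfsDistances(int src, const std::vector<std::vector<int>>& adj) {
    std::vector<int> dist(adj.size(), -1);  // -1 marks "not reached yet"
    std::queue<int> q;
    dist[src] = 0;
    q.push(src);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : adj[u]) {
            if (dist[v] == -1) {
                dist[v] = dist[u] + 1;
                q.push(v);
            }
        }
    }
    return dist;
}

// Full N x N distance matrix; row i holds the distances from node i.
std::vector<std::vector<int>> distanceMatrix(int n, const std::vector<std::vector<int>>& edges) {
    assert(static_cast<int>(edges.size()) == n - 1);  // a tree has N-1 edges
    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }
    std::vector<std::vector<int>> mat;
    for (int i = 0; i < n; ++i) mat.push_back(bfsDistances(i, adj));
    return mat;
}
```

For the example tree, summing each row of the returned matrix gives 8, 12, 6, 10, 10, 10, which matches the expected output. This costs O(N^2) time and memory, so it is only useful for checking small cases.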

The first straightforward idea is:

convert the edge list into an adjacency list (the neighbours of a node become its parent and children once a root is chosen);

iterate over every node as the root and compute its row of the matrix above.

The following DFS algorithm is O(N^2) and gets TLE.

class Solution {
public:
    vector<int> sumOfDistancesInTree(int N, vector<vector<int>>& edges) {
        // Build an adjacency list from the edge list.
        vector<vector<int>> adj(N);
        for (const auto& e : edges)
        {
            adj[e[0]].push_back(e[1]);
            adj[e[1]].push_back(e[0]);
        }
        vector<int> distsum(N, 0);
        // One full DFS per node: N traversals of O(N) each.
        for (int i = 0; i < N; i++)
        {
            dfs(i, adj, -1, 0, distsum[i]);
        }
        return distsum;
    }

    // Adds the depth of every node below root to res; depth is root's own depth.
    void dfs(int root, const vector<vector<int>>& adj, int parent, int depth, int& res)
    {
        for (int nd : adj[root])
        {
            if (nd == parent) continue; // do not walk back up the tree
            res += depth + 1;
            dfs(nd, adj, root, depth + 1, res);
        }
    }
};

The algorithm above repeats a lot of work. For example, when we make 1 the root, 0 becomes its child, and everything below 0 has already been calculated. The relation is easy to see: every node outside 1's subtree (0 and all of its other descendants) is now 1 farther away, and every node in 1's subtree (including 1 itself) is 1 nearer:

res[1]=res[0]-cnt[1]+(n-cnt[1])

In the following figure, when we choose A's child B as the new root, every node in B's subtree gets one step closer (the green part, cnt[B] nodes), and every remaining node gets one step farther (N - cnt[B] nodes).

[Figure: tree rooted at A with child B; B's subtree highlighted in green]
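The re-rooting step is plain arithmetic, so a tiny helper is enough to try it on the example. This is only a sketch: rerootSum is a hypothetical name, and childCount stands for the cnt value of the child (the size of its subtree, the child included).

```cpp
#include <cassert>

// Distance sum of a child, derived from its parent's distance sum.
// childCount = number of nodes in the child's subtree, child included.
long long rerootSum(long long parentSum, int childCount, int n) {
    assert(childCount >= 1 && childCount < n);  // a proper subtree of an n-node tree
    // childCount nodes move 1 closer, the other n - childCount nodes move 1 farther.
    return parentSum - childCount + (n - childCount);
}
```

With N = 6 and res[0] = 8 from the example: node 1 has a subtree of size 1, so rerootSum(8, 1, 6) is 8 - 1 + 5 = 12; node 2 has a subtree of size 4, so rerootSum(8, 4, 6) is 8 - 4 + 2 = 6. Going one level further, node 3 has a subtree of size 1 under node 2, so rerootSum(6, 1, 6) is 10.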

One post-order traversal (children first, then the parent) is enough to compute the cnt and distsum arrays. cnt[i] is the number of nodes in the subtree rooted at i (including i), and distsum[i] is the sum of the distances from i to the nodes in its own subtree.

cnt[i]=sum(cnt[child])+1
distsum[i]=sum(distsum[child])+sum(cnt[child])

The 1st equation: add up the counts of all children, plus one for the root node itself.

The 2nd equation is less obvious. To get distsum[i] we first add up distsum[child] over all children, but each of those sums is measured from the child. Measured from i, every node in a child's subtree is one edge farther away, so we must add 1 for every node below i. That missing term is the number of nodes under i, which is sum(cnt[child]) = cnt[i]-1.
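The two recurrences can be sketched as a single recursive pass. This is an illustration under assumptions rather than the final solution: postOrder is a made-up name, adj is assumed to be an adjacency list, and the parent argument is used to avoid walking back up the tree.

```cpp
#include <cassert>
#include <vector>

// Post-order pass: fills cnt and distsum for the subtree rooted at node.
void postOrder(int node, int parent, const std::vector<std::vector<int>>& adj,
               std::vector<int>& cnt, std::vector<int>& distsum) {
    assert(cnt.size() == adj.size() && distsum.size() == adj.size());
    cnt[node] = 1;      // the node itself
    distsum[node] = 0;
    for (int child : adj[node]) {
        if (child == parent) continue;  // skip the edge we came from
        postOrder(child, node, adj, cnt, distsum);
        cnt[node] += cnt[child];
        // Every node under child is one edge farther from node than from child.
        distsum[node] += distsum[child] + cnt[child];
    }
}
```

Called as postOrder(0, -1, adj, cnt, distsum) on the example tree, this gives cnt = 6,1,4,1,1,1 and distsum = 8,0,3,0,0,0. Only distsum[0] is a complete answer at this point; the other entries cover just their own subtrees.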

After this step, we need to change the root. A child's answer is easily calculated from its parent's answer, so we process the parent first and then the child; this can be done with a pre-order traversal.

So the idea is now clearer:

We first build the cnt and distsum arrays assuming node 0 is the root, using a post-order traversal. (Note that for every node other than 0, distsum covers ONLY that node's subtree at this point.)

After that, we take each of 0's children as the new root in turn and calculate its distsum from the equation above, using a pre-order traversal. (A BFS also works for this step, but DFS is more concise and conceptually clearer.)
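Since BFS is mentioned as an alternative, here is a hedged sketch of both passes done without recursion. It relies on one observation: in BFS order from node 0 every parent appears before its children, so walking the order backwards gives a valid "children first" pass and walking it forwards gives a valid "parent first" pass. The name sumOfDistancesIterative and the order/parent arrays are introduced only for this illustration.

```cpp
#include <cassert>
#include <vector>

std::vector<int> sumOfDistancesIterative(int n, const std::vector<std::vector<int>>& edges) {
    assert(n >= 1 && static_cast<int>(edges.size()) == n - 1);
    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }
    // BFS order from node 0; order doubles as the queue.
    std::vector<int> order, parent(n, -1);
    order.reserve(n);
    order.push_back(0);
    for (std::size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        for (int v : adj[u]) {
            if (v == parent[u]) continue;  // skip the edge back to the parent
            parent[v] = u;
            order.push_back(v);
        }
    }
    std::vector<int> cnt(n, 1), res(n, 0);
    // Pass 1 (children before parents): subtree sizes and subtree distance sums.
    for (int i = n - 1; i > 0; --i) {
        int v = order[i], p = parent[v];
        cnt[p] += cnt[v];
        res[p] += res[v] + cnt[v];
    }
    // Pass 2 (parents before children): re-root from each parent to its child.
    for (int i = 1; i < n; ++i) {
        int v = order[i];
        res[v] = res[parent[v]] - cnt[v] + (n - cnt[v]);
    }
    return res;
}
```

On the example input this returns 8,12,6,10,10,10. One reason to prefer this form is that it avoids deep recursion when the tree is a long chain, at the cost of two extra arrays.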

This problem is hard if you are not familiar with the recursive DFS approach.

As in the straightforward approach, we must never walk back to a node we have already visited, since that would lead to infinite recursion. Passing the parent (as above), a visited array, or a hash set all achieve this.

The following code is the most-voted solution from the LC community:

class Solution {
public:
    vector<int> sumOfDistancesInTree(int N, vector<vector<int>>& edges) 
    {
        vector<unordered_set<int>> tree(N);//adjacency list (one neighbour set per node)
        vector<int> res(N, 0);
        vector<int> count(N, 0);
        if (N == 1) return res;
        for (const auto& e : edges) //by reference, to avoid copying each edge
        {
            tree[e[0]].insert(e[1]);
            tree[e[1]].insert(e[0]);
        }
        unordered_set<int> seen1, seen2;
        dfs(0, seen1, tree, res, count);
        dfs2(0, seen2, tree, res, count, N);
        return res;
    }

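    //Post-order pass: count[root] = subtree size, res[root] = distance sum within root's subtree
    void dfs(int root, unordered_set<int>& seen, vector<unordered_set<int>>& tree, vector<int>& res, vector<int>& count) 
    {
        seen.insert(root);
        for (auto i : tree[root])//neighbours of root
        {
            if (!seen.count(i)) //not visited, so i is a child of root
            {
                dfs(i, seen, tree, res, count); //child first
                count[root] += count[i];
                res[root] += res[i] + count[i]; //then the root: each node under i is 1 farther from root
            }
        }
        count[root]++; //count the root itself
    }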

    void dfs2(int root, unordered_set<int>& seen, vector<unordered_set<int>>& tree, vector<int>& res, vector<int>& count, int N) 
    {
        seen.insert(root);
        for (auto i : tree[root])
        {
            if (!seen.count(i)) 
            {
                //re-root from root to i: count[i] nodes get 1 closer, N - count[i] get 1 farther
                res[i] = res[root] - count[i] + N - count[i];
                dfs2(i, seen, tree, res, count, N);
            }
        }
    }
};
