educative.io

Another approach could be to use a sliding window for the path sum

Hi guys,

Instead of recalculating the path sums at every node, another approach could be to use a sliding window to maintain the path sum. I believe this improves the time complexity to O(N log N).

O(N) for traversing the tree.
For the sliding window, every root-to-leaf path is accessed once. A balanced tree has (N+1) / 2 such paths (one per leaf), and each has about log N elements (the height of the tree), so (N+1) / 2 * log N elements are accessed in total:
O((N+1) / 2 * log N) => O(N log N)

O(N) + O(N log N) => O(N log N) should be the time complexity.
What are your thoughts?

  private static int findPaths(TreeNode root, int reqSum) {
    List<Integer> currentPath = new ArrayList<>();
    return findPathRecursively(root, currentPath, 0, reqSum, 0);
  }

  private static int findPathRecursively(
      TreeNode node,
      List<Integer> currentPath,
      int prevSum,
      int reqSum,
      int begIndex
  ) {
    if (node == null) {
      return 0;
    }

    currentPath.add(node.getValue());
    int currSum = prevSum + node.getValue();
    int matchedPaths = 0;

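    // Shrink the window from the top while its sum is >= reqSum, counting a
    // match before each shrink. Assumes positive node values and reqSum; the
    // bounds check stops the loop from reading past the end of currentPath.
    while (begIndex < currentPath.size() && currSum >= reqSum) {
      if (currSum == reqSum) {
        matchedPaths++;
      }

      currSum -= currentPath.get(begIndex);
      begIndex++;
    }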

    int leftMatchedPaths = findPathRecursively(node.getLeftNode(), currentPath, currSum, reqSum,
        begIndex);
    int rightMatchedPaths = findPathRecursively(node.getRightNode(), currentPath, currSum, reqSum,
        begIndex);

    currentPath.remove(currentPath.size() - 1);
    return leftMatchedPaths + rightMatchedPaths + matchedPaths;
  }

Thanks
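As a quick check of the counting in the estimate above, here is a small Python sketch that builds a perfect binary tree and adds up the lengths of all its root-to-leaf paths. The names `TreeNode`, `perfect_tree`, and `root_to_leaf_elements` are illustrative, not from the course. For N = 2^(h+1) - 1 nodes, the total should equal (N+1)/2 leaves times log2(N+1) nodes per path.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def perfect_tree(height):
    """Build a perfect binary tree of the given height (all values 1)."""
    if height < 0:
        return None
    return TreeNode(1, perfect_tree(height - 1), perfect_tree(height - 1))


def root_to_leaf_elements(node, depth=1):
    """Sum of the lengths (in nodes) of all root-to-leaf paths."""
    if node is None:
        return 0
    if node.left is None and node.right is None:
        return depth  # this leaf closes a path of `depth` nodes
    return (root_to_leaf_elements(node.left, depth + 1)
            + root_to_leaf_elements(node.right, depth + 1))


for height in range(1, 6):
    n = 2 ** (height + 1) - 1  # node count N of a perfect binary tree
    # (N + 1) / 2 leaves, each on a path of log2(N + 1) = height + 1 nodes
    predicted = (n + 1) // 2 * (height + 1)
    print(n, root_to_leaf_elements(perfect_tree(height)), predicted)
```

This counts every element of every root-to-leaf path, so it is the upper bound the estimate relies on rather than a measurement of the window itself.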

This is the solution I thought of as well. However, I think the time complexity is still O(N^2): in the worst case (the tree is a linked list, i.e. every node except the leaf has a single child), the depth of the tree is N, so the root-to-leaf path also has N nodes.

I do think this is a more elegant approach. I implemented it in python:

EDIT: Note that I'm no expert, though! There may be something wrong with my solution. I think I need an extra case in the while loop to account for zeros: when there are zeros, several paths through the same nodes can reach the sum.

class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_paths(root, S):
    # Assumes positive node values: once a window's sum exceeds S, the only
    # way to fix it is to drop nodes from the top (start) of the window.
    current_path = []  # values on the root-to-current-node path

    def dfs(node, start_index, running_sum):
        if node is None:
            return 0

        current_path.append(node.val)
        end_index = len(current_path) - 1
        running_sum += node.val
        while running_sum > S and start_index < end_index:
            running_sum -= current_path[start_index]
            start_index += 1

        # Check the window ending at this node exactly once, so a prefix
        # shared by several root-to-leaf paths is not counted repeatedly.
        total_sums = 1 if running_sum == S else 0
        total_sums += dfs(node.left, start_index, running_sum)
        total_sums += dfs(node.right, start_index, running_sum)

        current_path.pop()  # backtrack before returning to the parent
        return total_sums

    return dfs(root, 0, 0)
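To check a sliding-window version against inputs with zeros (or negative values), it helps to have a reference that recalculates the path sums at every node, as the first post describes. Below is a sketch of that brute-force idea in Python; `count_paths_brute_force` is a hypothetical name. On the example tree, a window that only shrinks while its sum exceeds S finds just [0, 5] and [0, 5, 0], which is the zero problem mentioned in the EDIT note.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_paths_brute_force(root, S):
    """Count downward paths summing to S by checking, at every node, each
    path that ends there (walking back up the current root-to-node path).

    O(N * depth) time, but correct for zero and negative values too.
    """
    current_path = []

    def dfs(node):
        if node is None:
            return 0
        current_path.append(node.val)
        count = 0
        path_sum = 0
        # Every suffix of current_path is a downward path ending at `node`.
        for value in reversed(current_path):
            path_sum += value
            if path_sum == S:
                count += 1
        count += dfs(node.left) + dfs(node.right)
        current_path.pop()  # backtrack before returning to the parent
        return count

    return dfs(root)


# Zeros: root 0 -> child 5 -> grandchild 0, target 5.
# Matching paths: [0, 5], [5], [5, 0], [0, 5, 0]
root = TreeNode(0, TreeNode(5, TreeNode(0)))
print(count_paths_brute_force(root, 5))  # 4
```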

Sorry for the repost; I'm still working through this. TL;DR: I agree that traversal is O(N), but I think that in the worst case each node is removed from the window a constant number of times on average, so the runtime is O(N). The general intuition is that shallower nodes can be removed from the sliding window more times (up to log(N) times for the root node in the worst case), but there are fewer of them.

This solution makes O(N) recursive calls (one per node, plus the calls on null children). The only thing left to determine is how many times a given node can be removed from the window as we advance the left index, since that determines the total number of times the while loop runs in the solution below.

The runtime is best when the tree is a linked list: each node can be removed at most once, which gives O(N) runtime. This reduces to the sliding-window problem of counting the subarrays of an array that add up to a given sum.
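For reference, the array version of that problem, assuming all values are positive, can be sketched in Python like this (`count_subarrays_with_sum` is an illustrative name). Each index enters the window once and leaves it at most once, which is where the O(N) bound for a linked-list-shaped tree comes from.

```python
def count_subarrays_with_sum(nums, S):
    """Count contiguous subarrays of `nums` summing to S.

    Assumes every value in nums is positive, so each element enters and
    leaves the window at most once: O(len(nums)) time overall.
    """
    count = 0
    left = 0
    window_sum = 0
    for right, value in enumerate(nums):
        window_sum += value
        # Shrink from the left while the window sum is too large,
        # but keep at least the current element in the window.
        while window_sum > S and left < right:
            window_sum -= nums[left]
            left += 1
        if window_sum == S:
            count += 1
    return count


print(count_subarrays_with_sum([1, 2, 3, 4, 5], 5))  # 2: [2, 3] and [5]
```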

The runtime is worst when we have a complete or perfect binary tree. In this case, the number of removals for each node is a function of the node's depth. The root node can be present in (and thus removed from) every leaf node's sliding window, so up to log(N) times, but there's only one of it. Each leaf node can be removed at most once, but there are (N+1)/2 of them. In general, the number of times a given node in a perfect binary tree can be removed is log(N) - depth. The average depth is within a constant of log(N), so I think the average number of times a given node can be removed is actually closer to constant. Interested to hear anyone's thoughts on this.

The average depth of a node in a complete binary tree is about log(N) (roughly h - 1 for a perfect tree of height h).
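Written out for a perfect binary tree of height h (so N = 2^{h+1} - 1, with h standing in for log(N)), and taking the per-node bound above at face value, i.e. a node at depth d is removed at most h - d times, the sums work out as follows. This only totals that bound; it does not establish the bound itself.

```latex
\begin{aligned}
N &= \sum_{d=0}^{h} 2^{d} = 2^{h+1} - 1 \\
\sum_{d=0}^{h} d\,2^{d} &= (h-1)\,2^{h+1} + 2
  \quad\Longrightarrow\quad
  \bar{d} = \frac{1}{N}\sum_{d=0}^{h} d\,2^{d} = h - 1 + \frac{h+1}{N} \\
\sum_{d=0}^{h} 2^{d}\,(h-d) &= hN - N\bar{d} = N - (h+1) = O(N)
\end{aligned}
```

So under that per-node bound, the average number of removals per node, h - \bar{d} = 1 - (h+1)/N, is below 1.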


Here’s the code I wrote:

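#include <cstddef>
#include <vector>
using namespace std;

// Minimal node type; the course's own TreeNode is assumed to look like this.
class TreeNode {
 public:
  int val = 0;
  TreeNode *left = nullptr;
  TreeNode *right = nullptr;
  TreeNode(int x) { val = x; }
};

class CountAllPathSum {
 public:
  // Assumes positive node values and a positive target S.
  static int countPaths(TreeNode *root, int S) {
    vector<int> curr_path;
    int path_count = 0;
    count_paths_recurse(root, S, 0, curr_path, 0, path_count);
    return path_count;
  }
 private:
  static void count_paths_recurse(TreeNode *curr, int target_sum, int curr_sum,
                                  vector<int> &curr_path, size_t left_index, int &path_count) {
    if (!curr) {
      return;
    }
    curr_sum += curr->val;
    curr_path.push_back(curr->val);
    // Shrink the window from the top, but never drop the current node
    // (an empty window would wrongly match a target of 0).
    while (curr_sum > target_sum && left_index + 1 < curr_path.size()) {
      curr_sum -= curr_path[left_index];
      left_index++;
    }
    if (curr_sum == target_sum) {
      path_count++;
    }
    count_paths_recurse(curr->left, target_sum, curr_sum, curr_path, left_index, path_count);
    count_paths_recurse(curr->right, target_sum, curr_sum, curr_path, left_index, path_count);
    // curr_path is shared by reference (no copy per call), so backtrack here.
    curr_path.pop_back();
  }
};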
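One way to probe the O(N) question empirically is to instrument the while loop and count how often the left index advances. The sketch below is a Python port of the sliding-window idea with that counter added (`count_paths_and_removals` is a hypothetical name); like the window logic itself, it assumes positive node values and a positive target.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_paths_and_removals(root, S):
    """Sliding-window path count that also reports how many times the
    window's left index advanced (one per iteration of the while loop).

    Assumes positive node values and a positive target S.
    """
    current_path = []
    paths = 0
    removals = 0

    def dfs(node, left, window_sum):
        nonlocal paths, removals
        if node is None:
            return
        current_path.append(node.val)
        window_sum += node.val
        # Drop nodes from the top of the window while its sum is too large,
        # but always keep the current node in the window.
        while window_sum > S and left < len(current_path) - 1:
            window_sum -= current_path[left]
            left += 1
            removals += 1
        if window_sum == S:
            paths += 1
        dfs(node.left, left, window_sum)
        dfs(node.right, left, window_sum)
        current_path.pop()  # backtrack

    dfs(root, 0, 0)
    return paths, removals


# A linked-list-shaped tree of five 1s with target 2.
chain = TreeNode(1, TreeNode(1, TreeNode(1, TreeNode(1, TreeNode(1)))))
print(count_paths_and_removals(chain, 2))  # (4, 3)
```

Trying perfect trees of increasing height, with different values and targets, lets you compare `removals` against N directly.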

Course: Grokking the Coding Interview: Patterns for Coding Questions - Learn Interactively
Lesson: Count Paths for a Sum (medium) - Grokking the Coding Interview: Patterns for Coding Questions