
Handling Sum Queries After Update

Number: 2703

Difficulty: Hard

Paid? No

Companies: Trilogy


Problem Description

You are given two 0-indexed arrays of equal length, nums1 (containing only 0s and 1s) and nums2, and a 2D array of queries. There are three types of queries:

  1. [1, l, r]: Flip all bits in nums1 between indices l and r (inclusive).
  2. [2, p, 0]: For every index i, update nums2[i] = nums2[i] + nums1[i] * p.
  3. [3, 0, 0]: Report the sum of all elements in nums2.

Return an array of answers corresponding to all type 3 queries.
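The following direct simulation makes the semantics of the three query types concrete. It is a reference helper for illustration only: the name handle_queries_naive is made up and is not part of the original solution. It costs O(n) per query, so it suits small inputs or cross-checking a faster solution.

```python
from typing import List


def handle_queries_naive(nums1: List[int], nums2: List[int],
                         queries: List[List[int]]) -> List[int]:
    """Hypothetical reference helper: simulates each query literally, O(n) per query."""
    nums1 = list(nums1)  # copy so the caller's lists are not mutated
    nums2 = list(nums2)
    answers = []
    for kind, a, b in queries:
        if kind == 1:
            # Type 1 [1, l, r]: flip every bit of nums1 in the inclusive range [l, r].
            for i in range(a, b + 1):
                nums1[i] ^= 1
        elif kind == 2:
            # Type 2 [2, p, 0]: nums2[i] += nums1[i] * p for every index i.
            for i in range(len(nums2)):
                nums2[i] += nums1[i] * a
        else:
            # Type 3 [3, 0, 0]: report the current total of nums2.
            answers.append(sum(nums2))
    return answers


print(handle_queries_naive([1, 0, 1], [0, 0, 0], [[1, 1, 1], [2, 1, 0], [3, 0, 0]]))  # [3]
```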


Key Insights

  • For query type 2, instead of updating every element of nums2, maintain a global sum for nums2 and update it by adding p times the number of ones in nums1.
  • The challenge is effectively maintaining the count of ones in nums1 under range flip operations.
  • A segment tree with lazy propagation supports both operations needed on nums1, a range flip and a count of ones (a range sum), in O(log(n)) time each.
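The first insight can be tried on its own with the hypothetical helper below. The name handle_queries_counting is illustrative. The helper keeps only a running total of nums2 and a running count of ones in nums1. Type 2 and type 3 queries become O(1). A type 1 flip still walks every index in [l, r], and that is the cost the segment tree is meant to remove.

```python
from typing import List


def handle_queries_counting(nums1: List[int], nums2: List[int],
                            queries: List[List[int]]) -> List[int]:
    """Hypothetical helper: tracks only sum(nums2) and the number of ones in nums1."""
    bits = list(nums1)   # local copy of nums1 that is flipped in place
    ones = sum(bits)     # current number of 1s in nums1
    total = sum(nums2)   # current sum of nums2; individual entries are never stored
    answers = []
    for kind, a, b in queries:
        if kind == 1:
            # Still O(r - l + 1): this loop is the bottleneck a segment tree removes.
            for i in range(a, b + 1):
                ones += 1 - 2 * bits[i]  # a 0 becomes 1 (+1), a 1 becomes 0 (-1)
                bits[i] ^= 1
        elif kind == 2:
            # sum(nums2[i] + nums1[i] * p) == sum(nums2) + p * (number of ones)
            total += a * ones
        else:
            answers.append(total)
    return answers
```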

Space and Time Complexity

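Time Complexity: O(n + q * log(n)), where n is the number of elements in nums1/nums2 and q is the number of queries. The costs break down as follows:
  • Building the segment tree and summing nums2 take O(n).
  • Each type 1 query takes O(log(n)).
  • A type 2 query asks for the count over the whole array, which the code answers at the root in O(1).
  • A type 3 query is O(1).
This bound is tighter than, and implies, O((n + q) * log(n)).

Space Complexity: O(n) for the segment tree, plus O(log(n)) recursion depth.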


Solution

We maintain:

  • A segment tree for nums1 that supports range sum queries (to count the number of 1s) and range flip updates. Lazy propagation makes a range flip efficient by storing pending flips in internal nodes instead of visiting every element in the range.
  • A variable, sum2, that holds the current total of nums2.

For each query:

  • Type 1: Use the segment tree to flip bits in the range [l, r].
  • Type 2: Query the segment tree for the total number of ones in nums1 and then update sum2 by adding (number_of_1s * p).
  • Type 3: Append the current sum2 to the answer list.

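The key trick is to never update the elements of nums2 individually. Type 3 queries only ask for the total of nums2. A type 2 query increases that total by exactly p times the number of ones in nums1, so updating the global sum directly is enough.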


Code Solutions

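class SegmentTree:
    """Counts the 1s in a 0/1 array and supports flipping any index range.

    tree[node] holds the number of 1s in node's segment; lazy[node] is True when
    a flip of that segment is pending and has not yet been applied to tree[node].
    """

    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (4 * self.n)      # 4n slots suffice for this recursive layout
        self.lazy = [False] * (4 * self.n)  # pending-flip flags, one per node
        self.build(0, 0, self.n - 1, nums)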
        
    def build(self, node, start, end, nums):
        if start == end:
            self.tree[node] = nums[start]
        else:
            mid = (start + end) // 2
            self.build(node * 2 + 1, start, mid, nums)
            self.build(node * 2 + 2, mid + 1, end, nums)
            self.tree[node] = self.tree[node * 2 + 1] + self.tree[node * 2 + 2]
    
    def push(self, node, start, end):
        # If there is a pending flip, apply it to current node and propagate to children.
        if self.lazy[node]:
            # Flip the count of ones in this segment
            self.tree[node] = (end - start + 1) - self.tree[node]
            if start != end:
                self.lazy[node * 2 + 1] ^= True
                self.lazy[node * 2 + 2] ^= True
            self.lazy[node] = False

    def update_range(self, node, start, end, l, r):
        self.push(node, start, end)
        if start > r or end < l:
            return
        if l <= start and end <= r:
            self.lazy[node] = True
            self.push(node, start, end)
            return
        mid = (start + end) // 2
        self.update_range(node * 2 + 1, start, mid, l, r)
        self.update_range(node * 2 + 2, mid + 1, end, l, r)
        self.tree[node] = self.tree[node * 2 + 1] + self.tree[node * 2 + 2]

    def query_range(self, node, start, end, l, r):
        if start > r or end < l:
            return 0
        self.push(node, start, end)
        if l <= start and end <= r:
            return self.tree[node]
        mid = (start + end) // 2
        left_sum = self.query_range(node * 2 + 1, start, mid, l, r)
        right_sum = self.query_range(node * 2 + 2, mid + 1, end, l, r)
        return left_sum + right_sum

class Solution:
    def handleQueries(self, nums1, nums2, queries):
        n = len(nums1)
        segtree = SegmentTree(nums1)
        sum2 = sum(nums2)
        res = []
        for q in queries:
            if q[0] == 1:
                l, r = q[1], q[2]
                segtree.update_range(0, 0, n-1, l, r)  # flip range in nums1
            elif q[0] == 2:
                p = q[1]
                # Get current count of ones in entire nums1 using segment tree query.
                count_ones = segtree.query_range(0, 0, n-1, 0, n-1)
                sum2 += count_ones * p
            elif q[0] == 3:
                res.append(sum2)
        return res

# Example Usage:
# solution = Solution()
# print(solution.handleQueries([1,0,1], [0,0,0], [[1,1,1],[2,1,0],[3,0,0]]))
# Expected output: [3]