Minimum XOR Sum of Two Arrays

Number: 1989

Difficulty: Hard

Paid? No

Companies: Media.net


Problem Description

Given two integer arrays nums1 and nums2 of the same length n, rearrange the elements of nums2 to minimize the XOR sum of corresponding elements, defined as (nums1[0] XOR nums2[0]) + (nums1[1] XOR nums2[1]) + ... + (nums1[n-1] XOR nums2[n-1]). Return the XOR sum obtained with the best rearrangement.
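To make the definition concrete, here is a minimal brute-force sketch that tries every rearrangement of nums2 with itertools.permutations. The function names xor_sum and minimum_xor_sum_brute_force are our own, and the approach is only practical for very small n, but it is handy as a reference for checking a faster solution.

```python
from itertools import permutations


def xor_sum(a, b):
    """XOR sum as defined above: the sum of a[i] ^ b[i] over all positions i."""
    return sum(x ^ y for x, y in zip(a, b))


def minimum_xor_sum_brute_force(nums1, nums2):
    """Try all n! rearrangements of nums2 and keep the smallest XOR sum."""
    return min(xor_sum(nums1, perm) for perm in permutations(nums2))


print(xor_sum([1, 2], [2, 3]))                      # (1^2) + (2^3) = 3 + 1 = 4
print(minimum_xor_sum_brute_force([1, 2], [2, 3]))  # 2: pair 1 with 3 and 2 with 2
```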


Key Insights

  • There are n! possible rearrangements of nums2; with n up to 14 that is 14! ≈ 8.7 × 10^10, far too many to brute-force.
  • Use bitmask dynamic programming to represent which elements of nums2 have been used.
  • Each DP state is a mask of used nums2 indices. Because exactly one nums1 element is consumed per pairing, the number of set bits equals the index of the next nums1 element to pair.
  • Transition by pairing that nums1 element with each unused index j in nums2 and setting bit j in the mask.
  • Time complexity is O(n * 2^n): there are 2^n masks, and each tries at most n indices (about 2.3 × 10^5 steps for n = 14), which is well within the limits.
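The state encoding described in the list can be illustrated with two small helpers. This is a sketch; describe_state and transitions are hypothetical names, not part of the original solution. describe_state decodes a mask into the next nums1 index and the unused nums2 indices, and transitions lists the masks reachable in one pairing step.

```python
def describe_state(mask, n):
    """Decode a DP mask into (next nums1 index, unused nums2 indices).

    Bit j of mask is set when nums2[j] has already been paired. One nums1
    element is consumed per pairing, so the popcount is the next nums1 index.
    """
    i = bin(mask).count("1")
    unused = [j for j in range(n) if not (mask >> j) & 1]
    return i, unused


def transitions(mask, n):
    """Masks reachable in one step: pair the next nums1 element with an unused nums2[j]."""
    return [mask | (1 << j) for j in range(n) if not (mask >> j) & 1]


# nums2[0] and nums2[2] are used, so nums1[0] and nums1[1] are already paired.
print(describe_state(0b0101, 4))  # (2, [1, 3])
print(transitions(0b0101, 4))     # [7, 13], i.e. 0b0111 and 0b1101
```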

Space and Time Complexity

Time Complexity: O(n * 2^n)
Space Complexity: O(2^n) (for the memoization table)
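As a quick arithmetic check of these bounds at the largest n, the hypothetical helper below compares the number of rearrangements a brute force would visit with the number of DP states (the memo table size) and state/index pairs the bitmask DP examines.

```python
from math import factorial


def search_sizes(n):
    """Return (brute-force rearrangements, DP states, DP state/index pairs) for size n."""
    return factorial(n), 2 ** n, n * 2 ** n


print(search_sizes(14))  # (87178291200, 16384, 229376)
```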


Solution

We use a bitmask DP approach. Define dp(mask) as the minimum XOR sum for the pairs that remain, where mask is a bitmask marking which elements of nums2 have already been paired. Let cnt be the number of set bits in mask; it is the index of the next nums1 element to pair. For every unused index j in nums2 (bit j not set), we try pairing nums1[cnt] with nums2[j], add nums1[cnt] XOR nums2[j], and recurse on mask with bit j set. dp(mask) is the minimum over all such choices, and memoization ensures each mask is solved only once. The base case is the mask with all n bits set: every element has been paired, so dp returns 0. The answer is dp(0).
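The same recurrence can also be filled in iteratively, which avoids recursion and lru_cache. This is a sketch of a bottom-up variant, not the page's own solution; minimum_xor_sum_bottom_up is a name we chose. Here best[mask] holds the cost of the pairs already made (the forward direction), whereas the memoized dp(mask) holds the cost of the pairs still to make. Processing masks in increasing order is valid because setting a bit always yields a larger mask, so best[mask] is final by the time it is read.

```python
def minimum_xor_sum_bottom_up(nums1, nums2):
    """Tabulated bitmask DP over masks of used nums2 indices."""
    n = len(nums1)
    full = (1 << n) - 1
    # best[mask] = minimum XOR sum after pairing nums1[0 .. popcount(mask) - 1]
    # with exactly the nums2 elements whose bits are set in mask.
    best = [float("inf")] * (1 << n)
    best[0] = 0
    for mask in range(1 << n):
        i = bin(mask).count("1")  # index of the next nums1 element to pair
        # When mask == full no bit is unset, so nums1[i] is never read.
        for j in range(n):
            if not (mask >> j) & 1:
                nxt = mask | (1 << j)  # nxt > mask, so it is processed later
                cand = best[mask] + (nums1[i] ^ nums2[j])
                if cand < best[nxt]:
                    best[nxt] = cand
    return best[full]


print(minimum_xor_sum_bottom_up([1, 0, 3], [5, 3, 4]))  # 8
```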


Code Solutions

```python
from functools import lru_cache

def minimumXORSum(nums1, nums2):
    n = len(nums1)

    @lru_cache(maxsize=None)
    def dp(mask):
        # Number of set bits in mask = current index in nums1
        i = bin(mask).count("1")
        if i == n:
            return 0  # every element has been paired
        min_sum = float('inf')
        # Try every unused index in nums2
        for j in range(n):
            if ((mask >> j) & 1) == 0:
                # XOR of the current pair plus the best sum for the next state
                min_sum = min(min_sum, (nums1[i] ^ nums2[j]) + dp(mask | (1 << j)))
        return min_sum

    return dp(0)

# Example usage:
print(minimumXORSum([1, 2], [2, 3]))  # 2
```
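The solution returns only the minimum sum. If the rearrangement itself is needed, one option, sketched here under our own naming (minimum_xor_arrangement is hypothetical), is to reuse the memoized dp and walk forward from mask 0, at each step taking any nums2 index whose pairing cost plus dp of the next mask equals dp of the current mask.

```python
from functools import lru_cache


def minimum_xor_arrangement(nums1, nums2):
    """Return (minimum XOR sum, one rearrangement of nums2 that achieves it)."""
    n = len(nums1)

    @lru_cache(maxsize=None)
    def dp(mask):
        # Same recurrence as above: best sum for pairing the remaining nums1 elements.
        i = bin(mask).count("1")
        if i == n:
            return 0
        return min((nums1[i] ^ nums2[j]) + dp(mask | (1 << j))
                   for j in range(n) if not (mask >> j) & 1)

    # Follow one optimal choice per step to recover the arrangement.
    mask, arrangement = 0, []
    for i in range(n):
        for j in range(n):
            if not (mask >> j) & 1:
                nxt = mask | (1 << j)
                if (nums1[i] ^ nums2[j]) + dp(nxt) == dp(mask):
                    arrangement.append(nums2[j])
                    mask = nxt
                    break
    return dp(0), arrangement


print(minimum_xor_arrangement([1, 2], [2, 3]))  # (2, [3, 2])
```

When several rearrangements tie, this walk returns the one whose choices use the lowest nums2 index at each step.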