
Maximum Sized Array

Number: 3679

Difficulty: Medium

Paid? Yes

Companies: N/A


Problem Description

Given a non-negative integer s, determine the maximum integer n such that the sum of all elements of the n×n×n array A does not exceed s, where each element is defined as A[i][j][k] = i * (j OR k) for 0 ≤ i, j, k < n and OR is bitwise OR.
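For small s the definition can be checked directly. The sketch below evaluates the array sum in O(n^3) time and scans n upward. The function names `brute_force_total` and `brute_force_max_n` are illustrative and not part of the original solution. Like the solution further down, it assumes the total never decreases as n grows.

```python
def brute_force_total(n: int) -> int:
    """Sum of A[i][j][k] = i * (j | k) over 0 <= i, j, k < n, straight from the definition."""
    return sum(i * (j | k) for i in range(n) for j in range(n) for k in range(n))


def brute_force_max_n(s: int) -> int:
    """Largest n with brute_force_total(n) <= s, found by trying n = 1, 2, 3, ..."""
    n = 1  # n = 1 always fits: every element has i = 0, so the total is 0
    # The total is non-decreasing in n, so stop at the first n + 1 that overshoots s.
    while brute_force_total(n + 1) <= s:
        n += 1
    return n


print(brute_force_total(2), brute_force_total(3))  # 3 45
print(brute_force_max_n(10), brute_force_max_n(0))  # 2 1
```

This is only practical as a cross-check for small inputs, but it reproduces the expected outputs listed with the Python solution (2 for s = 10, 1 for s = 0).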


Key Insights

  • Because the factor i does not depend on j or k, the total splits into a product: (sum of i over 0 ≤ i < n) × (sum of (j OR k) over all pairs 0 ≤ j, k < n). The first factor is 0 + 1 + ... + (n − 1) = n*(n − 1)/2.
  • The pair sum, i.e. the sum of (j OR k) over all 0 ≤ j, k < n, can be computed one bit position at a time. For each bit b, count the pairs (j, k) whose OR has bit b set, then multiply that count by 2^b.
  • Both factors are non-decreasing in n, so the total is too. Binary search can therefore find the maximum n with n*(n − 1)/2 × (OR-sum over all pairs) ≤ s. The smallest candidate, n = 1, always qualifies (its total is 0), and an upper bound can be found by doubling n until the total exceeds s.
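Written out as formulas, the decoupling and the per-bit count look as follows. The symbols T(n) (the total), S(n) (the pair sum) and z_b(n) (how many integers in [0, n − 1] have bit b clear) are names introduced here for illustration.

```latex
\begin{aligned}
T(n) &= \sum_{i=0}^{n-1}\sum_{j=0}^{n-1}\sum_{k=0}^{n-1} i\,(j \mathbin{\mathrm{OR}} k)
      = \Bigl(\sum_{i=0}^{n-1} i\Bigr)\Bigl(\sum_{j=0}^{n-1}\sum_{k=0}^{n-1} (j \mathbin{\mathrm{OR}} k)\Bigr)
      = \frac{n(n-1)}{2}\, S(n), \\
S(n) &= \sum_{b \ge 0} 2^b \bigl(n^2 - z_b(n)^2\bigr), \\
z_b(n) &= \Bigl\lfloor \frac{n}{2^{b+1}} \Bigr\rfloor 2^b + \min\bigl(n \bmod 2^{b+1},\ 2^b\bigr).
\end{aligned}
```

The sum over b is effectively finite: once 2^b ≥ n, z_b(n) = n and the term vanishes. Since n(n − 1)/2 and S(n) are both non-negative and non-decreasing in n, so is T(n), which is the property the binary search relies on.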

Space and Time Complexity

Time Complexity: O(B · log n_max), where n_max is the answer and B is the number of bit positions examined per evaluation of the total (a fixed cap such as 60 is ample).
Space Complexity: O(1)


Solution

We solve the problem by first expressing the total sum as a product of two separate sums. The outer sum over i gives n*(n − 1)/2. The inner sum of (j OR k) over all pairs (j, k) is computed one bit position at a time, for b from 0 up to a fixed maximum. For each bit b:

  • Bit b repeats with period 2^(b+1): within each period, the first 2^b values have bit b clear and the next 2^b have it set.
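  • Count the integers in [0, n − 1] with bit b clear: each of the ⌊n / 2^(b+1)⌋ full periods contributes 2^b, and the r = n mod 2^(b+1) leftover values contribute min(r, 2^b), because a partial period starts with its clear values.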
  • Bit b is clear in (j OR k) exactly when it is clear in both j and k, so the number of such pairs is the square of that count. The other n^2 minus that square pairs have bit b set.
  • The contribution for bit b is then 2^b multiplied by (n^2 minus the square of the count of numbers not having that bit set).
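The per-bit steps above can be sketched as two small functions. The names `count_bit_clear` and `or_pair_sum` are illustrative. Unlike a loop over a fixed number of bits, this version stops once 2^b ≥ n, since no value below n has a higher bit set.

```python
def count_bit_clear(n: int, b: int) -> int:
    """How many integers in [0, n - 1] have bit b equal to 0."""
    cycle = 1 << (b + 1)  # bit b follows a period of 2^(b+1)
    half = 1 << b         # the first 2^b values of each period have bit b clear
    full_cycles, remainder = divmod(n, cycle)
    # Each full period contributes `half` clear values; the partial period
    # starts with its clear values, so it contributes at most `half` more.
    return full_cycles * half + min(remainder, half)


def or_pair_sum(n: int) -> int:
    """Sum of (j | k) over all pairs 0 <= j, k < n, accumulated bit by bit."""
    total = 0
    b = 0
    # Bits with 2^b >= n are clear in every j, k < n, so they contribute nothing.
    while (1 << b) < n:
        zeros = count_bit_clear(n, b)
        pairs_with_bit_set = n * n - zeros * zeros  # all pairs minus "both clear"
        total += (1 << b) * pairs_with_bit_set
        b += 1
    return total


print(count_bit_clear(3, 0), or_pair_sum(3))  # 2 15
```

For n = 3 this gives 15, the same value as adding up the nine pairs (j, k) directly.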

Multiply the resulting OR sum by n*(n − 1)/2, the sum over i, to get the final total.
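A small worked instance of this multiplication, writing S(n) for the sum of (j OR k) over all pairs and T(n) for the total (both names introduced here). For n = 2 only bit 0 matters, and it is clear only in 0. For n = 3, bit 0 is clear in 0 and 2, and bit 1 is clear in 0 and 1, so each bit is set in 9 − 2^2 = 5 pairs:

```latex
\begin{aligned}
n = 2:&\quad \sum_{i=0}^{1} i = 1, \quad S(2) = 1 \cdot (4 - 1^2) = 3, \quad T(2) = 1 \cdot 3 = 3, \\
n = 3:&\quad \sum_{i=0}^{2} i = 3, \quad S(3) = 1 \cdot (9 - 2^2) + 2 \cdot (9 - 2^2) = 15, \quad T(3) = 3 \cdot 15 = 45.
\end{aligned}
```

Together with T(1) = 0, this gives an answer of 2 for s = 10 and 1 for s = 0.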

To find the largest n whose total does not exceed s, we first double an upper bound, starting from n = 1, until the total exceeds s. We then binary-search between 1 and that bound. This works because the total never decreases as n grows.
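The doubling-then-bisect search can be written as a generic helper over any monotone predicate. This is a sketch under stated assumptions. The names `largest_satisfying` and `fits` are hypothetical, and `fits(1)` is assumed to be true, which holds here because n = 1 always has total 0. As a small variant of the approach described above, it keeps the last successful doubling as the lower bound instead of restarting the binary search from 1.

```python
from typing import Callable


def largest_satisfying(fits: Callable[[int], bool]) -> int:
    """Largest n >= 1 with fits(n) True, assuming fits(1) is True and fits is
    monotone (True up to some point, False afterwards)."""
    # Phase 1: doubling. Grow `high` until it fails; `low` is the last success.
    low, high = 1, 2
    while fits(high):
        low, high = high, high * 2
    # Phase 2: binary search with the invariant fits(low) True, fits(high) False.
    while high - low > 1:
        mid = (low + high) // 2
        if fits(mid):
            low = mid
        else:
            high = mid
    return low


def total(n: int) -> int:
    # Direct O(n^3) evaluation of the array sum; fine for a small demo.
    return sum(i * (j | k) for i in range(n) for j in range(n) for k in range(n))


print(largest_satisfying(lambda n: total(n) <= 10))  # 2
```

Keeping the invariant "low fits, high does not" means the loop ends with `low` as the answer, so no separate `result` variable is needed.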

Key data structures and techniques:

  • Bitwise operations for calculating bit contributions.
  • Binary search to efficiently narrow down the valid maximum n.
  • A constant number of integer variables for the search bounds and running sums.

Code Solutions

```python
# Python solution with detailed comments
def maximum_sized_array(s):
    # Sum of (j OR k) over all pairs (j, k) with 0 <= j, k < n
    def compute_or_sum(n):
        or_sum = 0
        # 60 bit positions cover every n < 2^60.
        max_bits = 60
        for b in range(max_bits):
            # Bit b repeats with period 2^(b+1): 2^b values with it clear, then 2^b with it set.
            cycle = 1 << (b + 1)
            half_cycle = 1 << b
            # Count numbers in [0, n - 1] with bit b NOT set.
            full_cycles = n // cycle
            remainder = n % cycle
            # The partial period starts with its clear values, so it adds at most half_cycle.
            count_zeros = full_cycles * half_cycle + min(remainder, half_cycle)
            # The bit is off in (j OR k) only when it is off in both j and k.
            pairs_bit_off = count_zeros * count_zeros
            pairs_bit_on = n * n - pairs_bit_off
            # Contribution of this bit across all pairs
            or_sum += (1 << b) * pairs_bit_on
        return or_sum

    # Total sum of the 3D array for a given n
    def total_sum(n):
        # Sum of i for i from 0 to n - 1 is n * (n - 1) / 2.
        sum_i = n * (n - 1) // 2
        # Total is sum_i times the OR-sum over pairs (j, k).
        return sum_i * compute_or_sum(n)

    # Grow high exponentially until total_sum(high) > s.
    # total_sum(1) == 0 <= s, so n = 1 is always feasible.
    low, high = 1, 1
    while total_sum(high) <= s:
        high *= 2

    # Binary search for the maximum n in [low, high] with total_sum(n) <= s.
    result = 1
    while low <= high:
        mid = (low + high) // 2
        if total_sum(mid) <= s:
            result = mid  # mid is feasible; try larger values
            low = mid + 1
        else:
            high = mid - 1
    return result


# Example usage:
print(maximum_sized_array(10))  # Expected output: 2
print(maximum_sized_array(0))   # Expected output: 1
```