
Kth Smallest Element in a Sorted Matrix

Number: 378

Difficulty: Medium

Paid? No

Companies: Meta, Amazon, PhonePe, Apple, TikTok, Microsoft, Oracle, Salesforce, Uber, Google, X


Problem Description

Given an n x n matrix where each of the rows and columns is sorted in ascending order, return the kth smallest element in the matrix. Note that it is the kth smallest element in the sorted order (not the kth distinct element).
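A brute-force baseline makes the duplicate rule concrete. This is a minimal sketch; the name `kth_smallest_brute` is ours, not part of the problem. It flattens and sorts all n^2 values, so it costs O(n^2 log n) time and O(n^2) space, which is what the approaches in the Key Insights avoid.

```python
from itertools import chain


def kth_smallest_brute(matrix, k):
    # Flatten all n*n values and sort them. Duplicates are kept, so index k - 1
    # is the kth smallest in sorted order (not the kth distinct value).
    return sorted(chain.from_iterable(matrix))[k - 1]


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_brute(matrix, 7))  # 13
print(kth_smallest_brute(matrix, 8))  # 13 (the second 13 counts too)
print(kth_smallest_brute(matrix, 9))  # 15
```

With k = 8 the answer is 13, not 15: the two 13s occupy positions 7 and 8 of the sorted order, whereas the 8th distinct value would be 15.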


Key Insights

  • The matrix is sorted both row-wise and column-wise.
  • Instead of flattening and sorting the entire matrix (which costs O(n^2 log n) time and O(n^2) space), take advantage of the sorted properties.
  • Binary search can be applied over the range of values in the matrix.
  • Count the elements less than or equal to a candidate value mid (possible in O(n) time) to decide whether the kth smallest element is at most mid or greater than mid.
  • Alternatively, a min-heap can merge the sorted rows in O(k log n) time, but the heap needs O(n) extra space, versus O(1) for the binary search.
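The min-heap alternative can be sketched as a k-way merge of the rows. This is an illustrative version under our own naming (`kth_smallest_heap` is hypothetical), using Python's standard `heapq` module: it seeds the heap with the first element of each row, pops the smallest entry k - 1 times, and pushes the popped entry's right neighbor each time. It relies only on the rows being sorted and keeps up to min(n, k) entries in the heap, which is the extra space the binary-search solution avoids.

```python
import heapq
from typing import List


def kth_smallest_heap(matrix: List[List[int]], k: int) -> int:
    n = len(matrix)
    # Seed the heap with the head of each row; rows beyond the kth can never
    # contribute to the first k values, so at most min(n, k) rows are needed.
    heap = [(matrix[r][0], r, 0) for r in range(min(n, k))]
    heapq.heapify(heap)
    # Pop the current smallest value k - 1 times, replacing it with the next
    # value from the same row (if that row has one left).
    for _ in range(k - 1):
        _, r, c = heapq.heappop(heap)
        if c + 1 < n:
            heapq.heappush(heap, (matrix[r][c + 1], r, c + 1))
    # The top of the heap is now the kth smallest value.
    return heap[0][0]


print(kth_smallest_heap([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
```

The heap version is simpler to reason about when k is small, while the binary search keeps memory constant regardless of k.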

Space and Time Complexity

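Time Complexity: O(n * log(max - min)), where max and min are the largest and smallest values in the matrix: the binary search makes O(log(max - min)) probes, and each counting pass takes O(n) steps.

Space Complexity: O(1) extra space; the algorithm is iterative and uses only a few variables besides the input.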
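To make the two factors of the time bound concrete, here is an instrumented variant of the binary search. The function name and the `probes`/`max_walk` counters are our own additions for illustration. Each probe at least halves the candidate range, so there are at most ceil(log2(max - min + 1)) probes, and each staircase walk only moves up or right, so it takes at most 2n - 1 steps.

```python
def kth_smallest_with_stats(matrix, k):
    """Value-range binary search, instrumented to measure its work.

    Returns (answer, probes, max_walk): probes is the number of
    binary-search iterations, max_walk the longest counting walk.
    """
    n = len(matrix)
    low, high = matrix[0][0], matrix[-1][-1]
    probes, max_walk = 0, 0
    while low < high:
        mid = low + (high - low) // 2
        # Staircase count of entries <= mid, starting at the bottom-left corner.
        count, steps, row, col = 0, 0, n - 1, 0
        while row >= 0 and col < n:
            steps += 1
            if matrix[row][col] <= mid:
                count += row + 1
                col += 1
            else:
                row -= 1
        probes += 1
        max_walk = max(max_walk, steps)
        if count < k:
            low = mid + 1
        else:
            high = mid
    return low, probes, max_walk


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
answer, probes, max_walk = kth_smallest_with_stats(matrix, 8)
print(answer, probes, max_walk)  # 13 4 5
value_range = matrix[-1][-1] - matrix[0][0] + 1  # 15 candidate values
# (value_range - 1).bit_length() == ceil(log2(value_range)) for value_range >= 1.
assert probes <= (value_range - 1).bit_length()
# Each walk moves up or right only, so it takes at most 2n - 1 steps.
assert max_walk <= 2 * len(matrix) - 1
```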


Solution

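We solve the problem by binary searching over the value range between the smallest element, matrix[0][0], and the largest, matrix[n-1][n-1]. For each candidate mid, we count how many elements are less than or equal to mid by walking a staircase from the bottom-left corner: if the current element is <= mid, every element above it in the same column is too, so we add row + 1 and move right; otherwise we move up. If the count is less than k, the kth smallest element is greater than mid, so we set low = mid + 1; otherwise it is at most mid, so we set high = mid. The loop stops when low == high. That value is the smallest v with at least k elements <= v, and it must itself appear in the matrix (if it did not, v - 1 would have the same count, contradicting minimality), so it is the kth smallest element.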
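The invariant behind the search can be checked directly: the answer is the smallest value v whose count of elements <= v is at least k. This sketch uses a standalone version of the counting helper (it takes `matrix` as a parameter, unlike the nested helper in the solution) and finds that v with a plain linear scan over the value range, which is the same target the binary search reaches in O(log(max - min)) probes.

```python
def count_less_equal(matrix, x):
    """Count entries <= x with a staircase walk from the bottom-left corner."""
    n = len(matrix)
    count, row, col = 0, n - 1, 0
    while row >= 0 and col < n:
        if matrix[row][col] <= x:
            # Rows 0..row of this column are all <= x (columns are sorted).
            count += row + 1
            col += 1
        else:
            # Everything to the right in this row is also > x; move up.
            row -= 1
    return count


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
k = 8
print([count_less_equal(matrix, v) for v in (11, 12, 13, 14)])  # [5, 6, 8, 8]
# Linear scan for the smallest v with count >= k (binary search finds the same v).
answer = next(v for v in range(matrix[0][0], matrix[-1][-1] + 1)
              if count_less_equal(matrix, v) >= k)
print(answer)  # 13
```

The counts jump from 6 to 8 at v = 13 because 13 appears twice, which is why the search lands on 13 for both k = 7 and k = 8.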


Code Solutions

# Python solution using binary search

def kthSmallest(matrix, k):
    n = len(matrix)
    # Set initial low and high based on the smallest and largest values in the matrix
    low, high = matrix[0][0], matrix[-1][-1]
    
    def count_less_equal(x):
        # Count the number of elements less than or equal to x
        count = 0
        # Start from the bottom-left corner of the matrix
        row, col = n - 1, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= x:
                # All elements in this column up to this row are <= x
                count += row + 1
                col += 1
            else:
                row -= 1
        return count
    
    # Binary search to find the kth smallest element
    while low < high:
        mid = low + (high - low) // 2
        if count_less_equal(mid) < k:
            low = mid + 1  # kth element is greater than mid
        else:
            high = mid  # kth element is less than or equal to mid
    return low

# Example usage (sorted values: 1, 5, 9, 10, 11, 12, 13, 13, 15):
print(kthSmallest([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13