Scribbling

LeetCode: 215. Kth Largest Element in an Array

focalpoint 2021. 12. 11. 11:26

An O(N log K) time solution using a min-heap that keeps only the k largest elements seen so far (O(K) extra memory).

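import heapq
from typing import List

class Solution:
    def findKthLargest(self, nums: List[int], k: int) -> int:
        # Min-heap holding the k largest elements seen so far;
        # its root h[0] is the smallest of them, i.e. the kth largest.
        h = []
        for num in nums:
            if len(h) < k:
                heapq.heappush(h, num)
            else:
                # Push the new element, then evict the smallest one,
                # so the heap size stays at k.
                heapq.heappush(h, num)
                heapq.heappop(h)
        return h[0]
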
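For comparison, Python's standard library already packages the bounded-heap idea. `heapq.nlargest(k, nums)` returns the k largest values with the largest first, so the last entry is the kth largest. This is a minimal sketch, and the function name `kth_largest_nlargest` is hypothetical. The docs describe `nlargest` as equivalent to `sorted(nums, reverse=True)[:k]` and note that it performs best for small k.

```python
import heapq
from typing import List


def kth_largest_nlargest(nums: List[int], k: int) -> int:
    # heapq.nlargest returns the k largest values, largest first,
    # so the last entry is the kth largest.
    return heapq.nlargest(k, nums)[-1]


print(kth_largest_nlargest([3, 2, 1, 5, 6, 4], 2))           # 5
print(kth_largest_nlargest([3, 2, 3, 1, 2, 4, 5, 5, 6], 4))  # 4
```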
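A better approach is quickselect, which runs in O(N) time on average (O(N^2) in the worst case, made unlikely by picking the pivot at random).

However, this version uses O(N) extra memory on average, because every call builds new left, middle, and right lists.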

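import random
from typing import List

class Solution:
    def findKthLargest(self, nums: List[int], k: int) -> int:
        return self.quick_select(nums, k)

    def quick_select(self, nums, k):
        # Pick a random pivot value (nums is never empty here, since 1 <= k <= len(nums)).
        pivot = random.choice(nums)

        # Three-way split: larger than, equal to, and smaller than the pivot.
        left = [num for num in nums if num > pivot]
        middle = [num for num in nums if num == pivot]
        right = [num for num in nums if num < pivot]

        L, M = len(left), len(middle)
        if k <= L:
            # The kth largest is among the elements larger than the pivot.
            return self.quick_select(left, k)
        if k <= L + M:
            # The kth largest equals the pivot.
            return pivot
        # Skip the L + M elements >= pivot and search the smaller ones.
        return self.quick_select(right, k - L - M)
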
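Randomized solutions are easy to get subtly wrong, for example with off-by-one errors in k or with duplicate values. One way to build confidence is to compare them against the sorting baseline `sorted(nums)[-k]` on many small random inputs. This is a minimal sketch. The helper name `check` is hypothetical, and the list-based quickselect from above is repeated as a free function so the block runs on its own.

```python
import random
from typing import List


def quick_select(nums: List[int], k: int) -> int:
    # Same list-based quickselect as above, written as a free function.
    pivot = random.choice(nums)
    left = [x for x in nums if x > pivot]
    middle = [x for x in nums if x == pivot]
    right = [x for x in nums if x < pivot]
    if k <= len(left):
        return quick_select(left, k)
    if k <= len(left) + len(middle):
        return pivot
    return quick_select(right, k - len(left) - len(middle))


def check(trials: int = 1000) -> bool:
    # Compare against the O(N log N) baseline: the kth largest is sorted(nums)[-k].
    rng = random.Random(0)
    for _ in range(trials):
        n = rng.randint(1, 20)
        nums = [rng.randint(-5, 5) for _ in range(n)]  # small value range -> many duplicates
        k = rng.randint(1, n)
        if quick_select(nums, k) != sorted(nums)[-k]:
            return False
    return True


print(check())  # True
```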
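The best solution is the one below.

It is not very intuitive at first, but it is the same quickselect idea with the partitioning done in place. As a result, it needs only O(1) extra memory while still running in O(N) time on average.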

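import random
from typing import List

class Solution:
    def findKthLargest(self, nums: List[int], k: int) -> int:
        # The kth largest element sits at index len(nums) - k in ascending order.
        k = len(nums) - k
        low, high = 0, len(nums) - 1
        while low < high:
            j = self.partition(nums, low, high)
            if j < k:
                low = j + 1     # target lies right of the pivot
            elif j > k:
                high = j - 1    # target lies left of the pivot
            else:
                return nums[k]  # pivot landed exactly at the target index
        return nums[low]

    def partition(self, nums, low, high):
        # Lomuto partition around a random pivot; returns the pivot's final index.
        rand_idx = random.randint(low, high)
        # Move the pivot to the end: swap nums[high] and nums[rand_idx].
        nums[high], nums[rand_idx] = nums[rand_idx], nums[high]
        pivot = nums[high]

        left = low  # nums[low:left] holds elements <= pivot
        for i in range(low, high):
            if nums[i] <= pivot:
                # swap nums[i] and nums[left]
                nums[i], nums[left] = nums[left], nums[i]
                left += 1
        # Place the pivot between the two parts: swap nums[left] and nums[high].
        nums[left], nums[high] = nums[high], nums[left]
        return left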
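There is one caveat with the Lomuto partition above. Every element equal to the pivot goes to the left side. On an all-equal input the pivot therefore always lands at `high`, each pass shrinks the range by only one, and the loop degrades toward O(N^2).

A common remedy is a three-way (Dutch national flag) partition, also done in place. This technique is swapped in here and is not part of the original solution. The sketch below is minimal, and the names `partition3` and `find_kth_largest_3way` are hypothetical. Like the solution above, it reorders the input list.

```python
import random
from typing import List, Tuple


def partition3(nums: List[int], low: int, high: int) -> Tuple[int, int]:
    # Dutch national flag partition of nums[low..high] around a random pivot value.
    # Afterwards: nums[low:lt] < pivot, nums[lt:gt+1] == pivot, nums[gt+1:high+1] > pivot.
    pivot = nums[random.randint(low, high)]
    lt, i, gt = low, low, high
    while i <= gt:
        if nums[i] < pivot:
            nums[lt], nums[i] = nums[i], nums[lt]
            lt += 1
            i += 1
        elif nums[i] > pivot:
            nums[gt], nums[i] = nums[i], nums[gt]
            gt -= 1  # do not advance i: the swapped-in element is unexamined
        else:
            i += 1
    return lt, gt


def find_kth_largest_3way(nums: List[int], k: int) -> int:
    # Same index conversion as above: ascending index of the kth largest.
    target = len(nums) - k
    low, high = 0, len(nums) - 1
    while True:
        lt, gt = partition3(nums, low, high)
        if target < lt:
            high = lt - 1
        elif target > gt:
            low = gt + 1
        else:
            # target falls inside the block of pivot-equal elements.
            return nums[target]


print(find_kth_largest_3way([3, 2, 1, 5, 6, 4], 2))           # 5
print(find_kth_largest_3way([3, 2, 3, 1, 2, 4, 5, 5, 6], 4))  # 4
```

On an all-equal input such as `[1] * 10_000`, the first partition puts every element in the middle block. The function therefore returns after a single linear pass.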