I'm trying to improve my Python, so I'm doing coding challenges and getting them reviewed. I would really appreciate it if someone could point out how I can make my code cleaner or more Pythonic. My min_heap with a 3-value tuple seems a bit messy, but I couldn't think of a better way to do it.
Problem:
https://leetcode.com/problems/find-k-pairs-with-smallest-sums/description/
You are given two integer arrays nums1 and nums2 sorted in ascending order and an integer k.

Define a pair (u, v) which consists of one element from the first array and one element from the second array.

Find the k pairs (u1, v1), (u2, v2) ... (uk, vk) with the smallest sums.
Example 1:
Given nums1 = [1,7,11], nums2 = [2,4,6], k = 3

Return: [1,2],[1,4],[1,6]
The first 3 pairs are returned from the sequence:
[1,2],[1,4],[1,6],[7,2],[7,4],[11,2],[7,6],[11,4],[11,6]
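For reference, a brute-force sketch that sorts every pair by its sum makes the problem concrete and is handy as a test oracle for a faster solution. The function name k_smallest_pairs_brute_force is hypothetical, and it runs in O(mn log mn) time, so it is only meant for small inputs.

```python
from itertools import product


def k_smallest_pairs_brute_force(nums1, nums2, k):
    """Hypothetical reference: sort all len(nums1) * len(nums2) pairs by sum."""
    # sorted() is stable, so pairs with equal sums keep their product() order.
    pairs = sorted(product(nums1, nums2), key=sum)
    return [list(pair) for pair in pairs[:k]]


print(k_smallest_pairs_brute_force([1, 7, 11], [2, 4, 6], 3))
# [[1, 2], [1, 4], [1, 6]]
```

Pairs with equal sums, such as [7,6] and [11,2] (both 13), come out here in product() order, which differs from the order in the sequence listed above.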
Solution:
    from heapq import heappop, heappush

    class Solution:
        def kSmallestPairs(self, nums1, nums2, k):
            """
            :type nums1: List[int]
            :type nums2: List[int]
            :type k: int
            :rtype: List[List[int]]
            """
            solution = []
            min_heap = []
            # if nums1 or nums2 is empty or k is 0, we're done
            if not nums1 or not nums2 or not k:
                return solution
            # first initialize min_heap with all
            # (nums1[0..n], nums2[0]) pairs with
            # (sum of pair, (nums1[0..n], nums2[0]), index of nums2)
            for value in nums1:
                heappush(min_heap, (value+nums2[0], (value, nums2[0]), 0))
            while k and min_heap:
                curr_pair = heappop(min_heap)
                solution.append(curr_pair[1])
                nums2_idx = curr_pair[2]
                k -= 1
                # if we haven't exhausted all (nums1[curr], nums2[0..n])
                # pairs, offer pair (nums1[curr], nums2[curr+1]) to min_heap
                if curr_pair[2] == len(nums2)-1: continue
                heappush(min_heap, (curr_pair[1][0]+nums2[nums2_idx+1], (curr_pair[1][0],
                                    nums2[nums2_idx+1]), nums2_idx+1))
            return solution
There's no need for the special case not k, because that is already handled by the while k in the main loop.

In the initial loop over nums1 the expression nums2[0] gets computed twice. This could be avoided by storing the result in a local variable:

    value2 = nums2[0]
    heappush(min_heap, (value + value2, (value, value2), 0))
In the main loop there are some complicated expressions, culminating in:

    heappush(min_heap, (curr_pair[1][0]+nums2[nums2_idx+1], (curr_pair[1][0], nums2[nums2_idx+1]), nums2_idx+1))

which I found hard to understand. Complex expressions can be made clearer by giving names to the subexpressions. In particular, we can use tuple assignment to give names to all the elements of the tuple:

    _, pair, j = heappop(min_heap)
    value1, _ = pair
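A small self-contained sketch of this unpacking, using heap entries shaped like the ones in the question's code, (sum of pair, pair, index into nums2). The values are made up for illustration.

```python
from heapq import heappop, heappush

heap = []
# Entries have the question's shape: (sum of pair, pair, index into nums2).
heappush(heap, (1 + 4, (1, 4), 1))
heappush(heap, (1 + 2, (1, 2), 0))

_, pair, j = heappop(heap)  # smallest sum comes out first; every element named
value1, _ = pair            # also documents that pair has exactly two elements
print(pair, j, value1)      # (1, 2) 0 1

try:
    first, second = (1, 2, 3)  # a three-element "pair" is rejected
except ValueError as exc:
    print("ValueError:", exc)
```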
The reason for writing the tuple assignment value1, _ = pair instead of the lookup value1 = pair[0] is that the tuple assignment shows the reader that the pair has exactly two elements. (And if it doesn't, then Python will raise "ValueError: not enough values to unpack" or "ValueError: too many values to unpack", so we get some checking too.)

Now the heappush call looks like this:

    heappush(min_heap, (value1 + nums2[j + 1], (value1, nums2[j + 1]), j + 1))
Here the expression nums2[j + 1] gets computed twice. This could be avoided by using a local variable:

    value2 = nums2[j + 1]
    heappush(min_heap, (value1 + value2, (value1, value2), j + 1))
Now, it ought to strike you that there is a strong similarity between the revised initialization code and the revised main-loop code above. The duplication could be avoided by defining a local function, like this:
    from heapq import heappop, heappush

    class Solution:
        def kSmallestPairs(self, nums1, nums2, k):
            solution = []
            heap = []

            def add(value1, j):
                # Add the pair (value1, nums2[j]) to the heap if possible.
                if j < len(nums2):
                    value2 = nums2[j]
                    heappush(heap, (value1 + value2, (value1, value2), j))

            for value1 in nums1:
                add(value1, 0)
            while k and heap:
                k -= 1
                _, pair, j = heappop(heap)
                solution.append(pair)
                value1, _ = pair
                add(value1, j + 1)
            return solution
Note that in this version of the code there is no longer any need for the special cases not nums1 and not nums2. If nums1 is empty, then there will be no iterations of for value1 in nums1. If nums2 is empty, then j < len(nums2) will always be false, and so nothing will be added to the heap. Either way, there will be no iterations of the main loop, and so the solution will be empty as required.

The local-function version above always takes at least as many steps as there are elements in nums1. But this is wasteful in many cases. Consider something like:

    kSmallestPairs(list(range(1000000)), [1, 2, 3], 1)
Since k=1, the result should be the single pair [(0, 1)]. But the initialization loop runs over all million elements of nums1.

This could be avoided by initializing the heap with the single pair (nums1[0], nums2[0]), since we know that this is the pair with the smallest sum. Then, when we pop a pair (nums1[i], nums2[j]) from the heap, we have to add two new pairs to the heap, namely (nums1[i + 1], nums2[j]) and (nums1[i], nums2[j + 1]), as either of these pairs might be next in order after the pair we just processed.

However, we have to be careful to add each pair just once: after processing (nums1[1], nums2[0]) and (nums1[0], nums2[1]) we must not add the pair (nums1[1], nums2[1]) twice. This can be avoided by keeping a set of all the pairs we have added to the heap.

    from heapq import heappop, heappush

    class Solution:
        def kSmallestPairs(self, nums1, nums2, k):
            solution = []
            heap = []      # Min-heap of (nums1[i] + nums2[j], i, j)
            added = set()  # Set of indexes (i, j) that have been added to heap.

            def add(i, j):
                # Add (nums1[i] + nums2[j], i, j) to the heap if possible.
                if i < len(nums1) and j < len(nums2) and (i, j) not in added:
                    added.add((i, j))
                    heappush(heap, (nums1[i] + nums2[j], i, j))

            add(0, 0)
            while k and heap:
                k -= 1
                _, i, j = heappop(heap)
                solution.append((nums1[i], nums2[j]))
                add(i + 1, j)
                add(i, j + 1)
            return solution
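A variation on this is sketched below. It is not part of the approach described above. It drops the added set by giving every index pair exactly one predecessor: (i, j + 1) is pushed only when (i, j) is popped, and (i + 1, 0) only when (i, 0) is popped. Since both arrays are sorted in ascending order, no pair's sum is smaller than its predecessor's, so the heap still yields pairs in order of sum. The function name k_smallest_pairs_no_set is hypothetical.

```python
from heapq import heappop, heappush


def k_smallest_pairs_no_set(nums1, nums2, k):
    """Hypothetical variant: same lazy expansion, but no 'added' set."""
    solution = []
    if not nums1 or not nums2:
        return solution  # nothing to seed the heap with
    heap = [(nums1[0] + nums2[0], 0, 0)]  # min-heap of (sum, i, j)
    while k and heap:
        k -= 1
        _, i, j = heappop(heap)
        solution.append((nums1[i], nums2[j]))
        # (i, j + 1) has exactly one predecessor: (i, j).
        if j + 1 < len(nums2):
            heappush(heap, (nums1[i] + nums2[j + 1], i, j + 1))
        # (i + 1, 0) has exactly one predecessor: (i, 0).
        if j == 0 and i + 1 < len(nums1):
            heappush(heap, (nums1[i + 1] + nums2[0], i + 1, 0))
    return solution


print(k_smallest_pairs_no_set([1, 7, 11], [2, 4, 6], 3))
# [(1, 2), (1, 4), (1, 6)]
```

The cost of dropping the set is that this version indexes nums1[0] and nums2[0] directly, so it needs the explicit empty-input check that the set-based version avoids.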