Given an array S of n integers, are there elements a, b, c, and d in S such that a + b + c + d = target? Find all unique quadruplets in the array that sum to target.
Note: The solution set must not contain duplicate quadruplets.
The idea is to put all the pair sums `a` in a hashmap along with the corresponding indexes and, once done, check whether `target - a` is also present in the hashmap. If both `a` and `target - a` are present, then, since the question asks for unique quadruplets, we can filter out the invalid combinations using the indexes.
class Solution(object):
    def fourSum(self, arr, target):
        seen = {}
        for i in range(len(arr)-1):
            for j in range(i+1, len(arr)):
                if arr[i]+arr[j] in seen:
                    seen[arr[i]+arr[j]].add((i,j))
                else:
                    seen[arr[i]+arr[j]] = {(i,j)}
        result = []
        for key in seen:
            if -key + target in seen:
                for (i,j) in seen[key]:
                    for (p,q) in seen[-key + target]:
                        sorted_index = sorted([arr[i], arr[j], arr[p], arr[q]])
                        if i not in (p, q) and j not in (p, q) and sorted_index not in result:
                            result.append(sorted_index)
        return result
2 Answers
- Use `enumerate` rather than `range(len(...))` + `__getitem__`. It is both faster and more readable.
- To limit items of the second iteration to be "after the current item" you can use `itertools.combinations`.
- To avoid the need to check for the special case of "is the item already in the dictionary?", use a `collections.defaultdict`.
- You could use a `set` rather than a `list` to store the final results and spare yourself the need to check for duplicates.
- `-key + target` is better written as `target - key`.
import itertools
from collections import defaultdict


def four_sum(array, target):
    seen = defaultdict(set)
    for (i, first), (j, second) in itertools.combinations(enumerate(array), 2):
        seen[first + second].add((i, j))

    result = set()
    for key, first_indices in seen.items():
        second_indices = seen.get(target - key, set())
        for p, q in second_indices:
            for i, j in first_indices:
                # Not reusing the same number twice
                if not ({i, j} & {p, q}):
                    indices = tuple(sorted(array[x] for x in (i, j, p, q)))
                    result.add(indices)
    return result
- Yours is actually slower compared to OP's on leetcode; I must agree it is more readable though. Yours: 335ms, OP: 239ms. It must return a `list`, so I've changed it a bit, but I still didn't really expect that. :) – Ludisposed, Nov 27, 2017 at 9:51
- "Note: The solution set must not contain duplicate quadruplets." Yeah, online judges and their requirements matching their specs... – 301_Moved_Permanently, Nov 27, 2017 at 9:53
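Regarding the comment that the judge expects a `list`: a minimal sketch of one way to convert a set of tuples, such as the one `four_sum` returns, into a list of lists. The helper name `as_list_of_lists` is hypothetical, and sorting the output is an assumption made only to get a deterministic order.

```python
def as_list_of_lists(quadruplets):
    """Convert an iterable of tuples into a sorted list of lists."""
    # Sorting is not required by the problem statement; it only makes
    # the output order reproducible, since sets are unordered.
    return sorted(list(quad) for quad in quadruplets)
```

For example, `as_list_of_lists({(-1, 0, 0, 1), (-2, -1, 1, 2)})` gives `[[-2, -1, 1, 2], [-1, 0, 0, 1]]`.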
Implementation
- Why not build the result with the condition `i < j < p < q`?
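A minimal sketch of what that condition could look like, assuming the pair-sum map is built with `itertools.combinations` so that `i < j` and `p < q` already hold for every stored pair. The function name `four_sum_ordered` is hypothetical. Checking `j < p` then guarantees four distinct indexes and visits each index quadruple only once; a set is still needed because different indexes can hold equal values.

```python
from collections import defaultdict
from itertools import combinations


def four_sum_ordered(array, target):
    """Sketch: accept only index quadruples with i < j < p < q."""
    # Map each pair sum to the index pairs (i, j), i < j, producing it.
    pairs_by_sum = defaultdict(list)
    for (i, first), (j, second) in combinations(enumerate(array), 2):
        pairs_by_sum[first + second].append((i, j))

    result = set()
    for key, first_pairs in pairs_by_sum.items():
        # .get() avoids inserting keys into the defaultdict while iterating it.
        second_pairs = pairs_by_sum.get(target - key, ())
        for i, j in first_pairs:
            for p, q in second_pairs:
                # i < j and p < q already hold, so j < p implies i < j < p < q.
                if j < p:
                    values = (array[i], array[j], array[p], array[q])
                    result.add(tuple(sorted(values)))
    return result
```

With `four_sum_ordered([1, 0, -1, 0, -2, 2], 0)` this yields the three quadruplets `(-2, -1, 1, 2)`, `(-2, 0, 0, 2)` and `(-1, 0, 0, 1)`.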
Algorithm
- The code builds the hash map from combinations of all indexes of `nums`. Combinations of the unique values of `nums` (or of the indexes of unique values) are a better choice. Case: `fourSum([0 for x in range(n)], 0)`
- The code puts integers from `nums` into the hash map that can never be part of a result. Case: `fourSum([x for x in range(1, n, 1)], 0)`
- The code checks whether `target - key` also exists for a `key` of the hash map only in the final loop; this check can be done earlier. Case: `fourSum([x for x in range(0, n*10, 10)], n*5+1)`
- You can split the hash map into two parts: `a, b` pairs and `c, d` pairs. This does not change the complexity of building the hash map, but the final loop then does roughly 1/2 * 1/2 of the work.
Speedup
- Best: the algorithm itself (big O notation), e.g. reducing O(n^2) memory to O(n).
- Sometimes good: the algorithm's constants, e.g. splitting the hash map into first and second pairs.
- Bad: dirty, low-level, language-specific constant speed-ups, e.g. replacing `itertools.combinations` with explicit loops. This is an anti-pattern. Reasons: the code becomes less understandable, maintainable and changeable, and, paradoxically, slower. Slower because bottlenecks are usually caused by a cascade of several algorithms, e.g. O(n^3) * O(n^3). With clean code it is easier to reduce the problem to O(n^5) or less. With dirty code we usually end up with O(n^6) with a small constant.
Code (the same O(n^2) memory)
from itertools import combinations
from collections import defaultdict, Counter


def fourSum(self, nums, target):
    if len(nums) < 4:
        return []

    half_target = target // 2
    counter = Counter(nums)
    uniques_wide = sorted(counter)
    x_min, x_max = target - 3 * uniques_wide[-1], target - 3 * uniques_wide[0]  # bad
    uniques = [x for x in uniques_wide if x_min <= x <= x_max]
    duplicates = [x for x in uniques if counter[x] > 1]

    target_minus_xy_sums = set(target - x - y for x, y in combinations(uniques, 2))
    target_minus_xy_sums |= set(target - x - x for x in duplicates)

    ab_sum_pairs, cd_sum_pairs = defaultdict(list), defaultdict(list)
    for (x, y) in combinations(uniques, 2):
        if x + y in target_minus_xy_sums:
            if x + y <= half_target:
                ab_sum_pairs[x + y].append((x, y))
            if x + y >= half_target:
                cd_sum_pairs[x + y].append((x, y))
    for x in duplicates:
        if x + x in target_minus_xy_sums:
            if x + x <= half_target:
                ab_sum_pairs[x + x].append((x, x))
            if x + x >= half_target:
                cd_sum_pairs[x + x].append((x, x))

    return [[a, b, c, d]
            for ab in ab_sum_pairs
            for (a, b) in ab_sum_pairs[ab]
            for (c, d) in cd_sum_pairs[target - ab]
            if b < c or (b == c and [a, b, c, d].count(b) <= counter[b])]  # if bi < ci
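When comparing variants like the ones above, a brute-force reference can help confirm that an optimisation did not change the results. The following is only a sketch for small inputs; the name `four_sum_brute_force` is hypothetical, and the sorted list-of-lists output is an assumption chosen to make comparisons order-independent.

```python
from itertools import combinations


def four_sum_brute_force(nums, target):
    """O(n^4) reference: test every combination of four elements."""
    # combinations() picks elements by position, so equal values at
    # different positions are treated as distinct elements.
    found = {
        tuple(sorted(quad))
        for quad in combinations(nums, 4)
        if sum(quad) == target
    }
    return sorted(list(quad) for quad in found)
```

For `four_sum_brute_force([1, 0, -1, 0, -2, 2], 0)` the result is `[[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]`. The output of a faster implementation can be normalised the same way (sort each quadruplet, then sort the list) before comparing.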