I am trying to learn dynamic programming using hash tables. I was given the "Quadruple sum" problem from firecode.io as a challenge:
Given a sorted array of integers and an integer target, find all the unique quadruplets which sum up to the given target.
Note: Each quadruplet must have elements (input[i], input[j], input[k], input[l]), such that \$i < j < k < l\$. The ordering of unique quadruplets within the output list does not matter.
Examples:
Input : [1,2,3,4,5,6,7,8] Target: 10 Output: [(1, 2, 3, 4)]
Input : [1,2,3,4,5,6,7] Target: 20 Output: [(2, 5, 6, 7), (3, 4, 6, 7)]
Solution #1
def all_sums(xs):
    dict = {}
    l = [[(i,j) for i in range(len(xs)) if i < j] for j in range(len(xs))]
    pairs = [item for sublist in l for item in sublist]
    for i,j in pairs:
        if xs[i]+xs[j] not in dict:
            dict[xs[i]+xs[j]] = set()
        dict[xs[i]+xs[j]].add((i,j))
    return dict

def iquadruple_sum(xs, target):
    sums = all_sums(xs)
    for s, positions in sums.items():
        for (k,l) in positions:
            if target - s in sums:
                for (i,j) in sums[target - s]:
                    if len({i,j,k,l}) == 4: yield sorted((xs[i],xs[j],xs[k],xs[l]))

def quadruple_sum(xs, target):
    return list(set(tuple(x) for x in iquadruple_sum(xs, target)))
This takes \$O(n^4)\$ time in the worst case (if the array is mostly complementary). It also takes \$O(n^2)\$ space, vs \$O(n)\$. The worst case is unlikely though.
Alternative solution #2
def all_sums(xs):
    "Returns dictionary of all possible ordered sums in O(n^2) time and space."
    from collections import defaultdict
    from itertools import product
    d = defaultdict(set)
    pairs = ((i,j) for (i,x),(j,y) in product(enumerate(xs), enumerate(xs)) if i<j)
    for i,j in pairs: d[xs[i]+xs[j]].add((i,j))
    return d

def iquadruple_sum(xs, target):
    sums = all_sums(xs)
    for s, positions in sums.items():
        for (k,l) in positions:
            for (i,j) in sums[target - s]:
                if len({i,j,k,l}) == 4: yield sorted((xs[i],xs[j],xs[k],xs[l]))

def quadruple_sum(xs, target):
    return list(set(tuple(x) for x in iquadruple_sum(xs, target)))
1 Answer
Just reviewing solution 2.
1. Bugs
Some inputs cause the code to raise an exception:
>>> quadruple_sum([1, 2, 3, 4, 5], 10)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "cr185635.py", line 21, in quadruple_sum
    return list(set(tuple(x) for x in iquadruple_sum(xs, target)))
  File "cr185635.py", line 21, in <genexpr>
    return list(set(tuple(x) for x in iquadruple_sum(xs, target)))
  File "cr185635.py", line 14, in iquadruple_sum
    for s, positions in sums.items():
RuntimeError: dictionary changed size during iteration
The problem is with the line:
for (i,j) in sums[target - s]:
sums is a collections.defaultdict, and this means that if target - s is not already in the dictionary, then merely looking up the value causes the dictionary to be updated. There needs to be a guard like the one you have in solution 1:
if target - s in sums:
    for i, j in sums[target - s]:
(In practice you'd want to move the guard outside the loop over k, l as well.)
The code constructs the quadruple like this:
yield sorted((xs[i],xs[j],xs[k],xs[l]))
This means that the quadruple is sorted in order by its elements, but in the problem description it says, "Each quadruplet must have elements (input[i], input[j], input[k], input[l]), such that \$i<j<k<l\$", in other words, the quadruples must be sorted in order by their indices.
>>> quadruple_sum([4, 3, 2, 1], 10)
[(1, 2, 3, 4)] # expected [(4, 3, 2, 1)]
This also causes some quadruples to be missed:
>>> quadruple_sum([2, 1, 1, 2, 2], 6)
[(1, 1, 2, 2)] # expected [(1, 1, 2, 2), (2, 1, 1, 2)]
This problem is easy to miss if you only ever test with the input in sorted order. The correct code needs to look something like this:
yield tuple(xs[m] for m in sorted((i, j, k, l)))
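Both points in this section can be checked in isolation. The following is a minimal sketch, not part of the reviewed code: the sample keys and the helper name by_index are made up for illustration. It shows that a plain lookup on a defaultdict inserts the missing key while a membership test does not, and that sorting the indexes (rather than the values) keeps the quadruple in input order.

```python
from collections import defaultdict

# Bug 1: indexing a defaultdict with a missing key inserts that key.
sums = defaultdict(set)
sums[3].add((0, 1))
before = len(sums)
_ = sums[99]              # the lookup alone creates the entry 99 -> set()
after_lookup = len(sums)  # one larger than before

# A membership test does not insert anything.
present = 42 in sums
after_guard = len(sums)   # unchanged by the membership test


# Bug 2: order the quadruple by index, not by value.
def by_index(xs, i, j, k, l):
    """Return the four selected elements in order of their indexes.

    Hypothetical helper, named here only for illustration.
    """
    return tuple(xs[m] for m in sorted((i, j, k, l)))
```

This is why iterating over sums.items() while looking up sums[target - s] can change the dictionary's size mid-iteration, and why the guard has to be a membership test rather than a lookup.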
2. Testing
How could you have found these bugs? By testing, of course! It's easy to spot that the problem can be solved like this:
list(set(q for q in combinations(xs, 4) if sum(q) == target))
This is inefficient (takes \$Θ(n^4)\$ in all cases) but is clearly correct, so we can use it as a test oracle together with randomized test case generation. For example,
import random
import unittest
from itertools import combinations

class TestQuadupleSum(unittest.TestCase):
    def test_quadruple_sum(self):
        for _ in range(100):
            # Small random input, so that duplicates are likely.
            n = random.randrange(10)
            xs = [random.randrange(10) for _ in range(n)]
            if len(xs) >= 4 and random.random() < .9:
                # Usually pick a target that has at least one solution.
                target = sum(random.sample(xs, 4))
            else:
                target = random.randrange(sum(xs) + 1)
            expected = list(set(q for q in combinations(xs, 4)
                                if sum(q) == target))
            found = quadruple_sum(xs, target)
            self.assertCountEqual(expected, found,
                                  "xs={}, target={}: expected {} but found {}"
                                  .format(xs, target, expected, found))
This soon fails with one of the bugs in §1 above:
======================================================================
FAIL: test_quadruple_sum (cr185635.TestQuadupleSum)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "cr185635.py", line 59, in test_quadruple_sum
    .format(xs, target, expected, found))
AssertionError: Element counts were not equal:
First has 1, Second has 0: (1, 5, 8, 3)
First has 0, Second has 1: (1, 3, 5, 8) :
xs=[1, 5, 8, 7, 3], target=17: expected [(1, 5, 8, 3)] but found [(1, 3, 5, 8)]
----------------------------------------------------------------------
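If you want to reuse the oracle from this section outside the test case, it can be wrapped in a function. This is only a sketch, and the name quadruple_sum_oracle is my own choice. It relies on the documented behaviour of itertools.combinations, which emits each combination in the order the elements appear in the input, so every tuple is already ordered by index.

```python
from itertools import combinations


def quadruple_sum_oracle(xs, target):
    """Brute-force reference: unique 4-combinations of xs that sum to target.

    Each tuple keeps the elements in input (index) order, because
    combinations() preserves the order of its input.
    """
    return list(set(q for q in combinations(xs, 4) if sum(q) == target))
```

It examines every one of the \$\binom{n}{4}\$ combinations, so it is only suitable for the small inputs used in testing.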
3. Review
The term "dynamic programming" is normally used to describe the technique in which you build up the solution for a large problem instance by combining the solutions for smaller instances of the same problem. For example, in the well-known "change-making problem" you have to count the number of ways of making change for an amount \$a\$, and in the dynamic programming approach you count the number of ways of making change for various amounts \$b < a\$ and then add them. (The name "dynamic programming" is a bit obscure and I wrote a piece about this on my blog.)
What you have here is not "dynamic programming" in the usual sense of the term, since the code does not work by finding solutions to smaller instances of the "quadruple sums" problem and then combining them.
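To make the contrast concrete, here is a minimal sketch of the change-making count written in the dynamic programming style. The function name count_change and the coin denominations in the note below are illustrative choices of mine, and the sketch assumes an unlimited supply of each coin.

```python
def count_change(amount, coins):
    """Count the ways to make `amount` using unlimited coins of each denomination.

    ways[b] holds the count for the smaller amount b, built from the counts
    for even smaller amounts: that reuse is what makes this dynamic programming.
    """
    ways = [1] + [0] * amount          # one way to make 0: use no coins
    for coin in coins:
        for b in range(coin, amount + 1):
            ways[b] += ways[b - coin]  # extend solutions for the smaller amount b - coin
    return ways[amount]
```

For instance, count_change(10, [1, 2, 5]) returns 10. Notice that the answer for each amount is assembled from answers to smaller instances of the same problem, which is exactly the ingredient that the quadruple-sum code lacks.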
Solution 2 makes a list of pairs of distinct indexes like this:
pairs = ((i,j) for (i,x),(j,y) in product(enumerate(xs), enumerate(xs)) if i<j)
When you want to generate distinct subsets of a collection without repetition, you need itertools.combinations. In this case you would write:
pairs = ((i,j) for (i,x),(j,y) in combinations(enumerate(xs), 2))
The code then looks up the elements at these pairs of indices, like this:
for i,j in pairs: d[xs[i]+xs[j]].add((i,j))
But in fact xs[i] and xs[j] were the values x and y that you got out of the enumeration on the previous line but then discarded! Instead of discarding these elements and then looking them up again, hang on to them and use them:
for (i, x), (j, y) in combinations(enumerate(xs), 2):
    d[x + y].add((i, j))
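As a quick sanity check that the two ways of generating index pairs agree, here is a sketch that builds the sum table both ways. The function names pair_sums_product and pair_sums_combinations are hypothetical and exist only for this comparison.

```python
from collections import defaultdict
from itertools import combinations, product


def pair_sums_product(xs):
    """Original approach: all n*n index pairs, filtered down to i < j."""
    d = defaultdict(set)
    for (i, x), (j, y) in product(enumerate(xs), enumerate(xs)):
        if i < j:
            d[x + y].add((i, j))
    return d


def pair_sums_combinations(xs):
    """Suggested approach: combinations() yields each i < j pair exactly once."""
    d = defaultdict(set)
    for (i, x), (j, y) in combinations(enumerate(xs), 2):
        d[x + y].add((i, j))
    return d
```

Both return the same mapping, but the second one generates only the \$n(n-1)/2\$ pairs it needs instead of generating \$n^2\$ pairs and throwing more than half of them away.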
The value of target - s does not change in any of the inner loops, so you could cache it in a local variable and avoid recomputing it.
The two loops over positions and sums[target - s]:
for (k,l) in positions:
    for (i,j) in sums[target - s]:
can be combined into one using itertools.product:
for (i, j), (k, l) in product(positions, sums[target - s]):
The functions all_sums and iquadruple_sum are short and each is called from just one place, so it might make sense to inline them at their single point of use.
4. Revised code
from collections import defaultdict
from itertools import combinations, product

def quadruple_sum(xs, target):
    "Return list of unique quadruples of xs that add up to target."
    # Map from sum to set of pairs of indexes with that sum.
    sums = defaultdict(set)
    for (i, x), (j, y) in combinations(enumerate(xs), 2):
        sums[x + y].add((i, j))
    result = set()
    for s, pairs in sums.items():
        t = target - s
        if t in sums:
            for (i, j), (k, l) in product(pairs, sums[t]):
                indexes = {i, j, k, l}
                if len(indexes) == 4:
                    result.add(tuple(xs[m] for m in sorted(indexes)))
    return list(result)