I wrote the below as a solution to:
Problem
Find the highest product of three numbers in a list.
Constraints
- Is the input a list of integers? Yes
- Can we get negative inputs? Yes
- Can there be duplicate entries in the input? Yes
- Will there always be at least three integers? No
- Can we assume the inputs are valid? No, check for None input
- Can we assume this fits memory? Yes
Solution
```python
from functools import reduce
from typing import List, Set, Callable


class Solution(object):
    def max_prod_three(self, array: List[int]) -> int:
        if array is None:
            raise TypeError("array cannot be None")
        if len(array) < 3:
            raise ValueError("array must have at least 3 elements")
        pos: List[int] = [n for n in array if n >= 0]
        neg: List[int] = [n for n in array if n < 0]
        pos_len: int = len(pos)
        neg_len: int = len(neg)
        candidates: Set[int] = set()
        mult: Callable[[int, int], int] = lambda x,y:x*y
        # possible combinations are:
        # ---, --+, -+-, +--, -++, +-+, ++-, +++
        # but order does not matter, so left with:
        # --- -> - : len(neg) >= 3, will never be the max if len(pos) >= 3; if len(pos) < 3,
        #            max will be with the 3 lowest absolute values
        # --+ -> + : len(neg) >= 2, len(pos) >= 1, will be the max if it exists and (--) > any (++)
        # -++ -> - : len(neg) >= 1, len(pos) >= 2, will never be the max if len(pos) >= 3;
        #            if len(pos) == 2, max will be with the lowest absolute value
        # +++ -> + : len(pos) >= 3, max is with the 3 highest values
        if neg_len >= 3 and pos_len < 3:
            candidates.add(reduce(mult, sorted(neg)[-3:]))
        if neg_len >= 2 and pos_len >= 1:
            candidates.add(reduce(mult, sorted(neg)[:2] + [max(pos)]))
        if neg_len >= 1 and pos_len == 2:
            candidates.add(reduce(mult, sorted(neg)[-1:] + pos))
        if pos_len >= 3:
            candidates.add(reduce(mult, sorted(pos)[-3:]))
        return max(candidates)
```
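The case analysis in the comments can be sanity-checked against a brute-force reference that simply tries every triple. This is a sketch; the helper names `max_prod_three_brute_force` and `check_against_brute_force` are hypothetical, and the brute force is O(n^3), so it only suits small random inputs.

```python
import random
from itertools import combinations
from typing import Callable, List


def max_prod_three_brute_force(array: List[int]) -> int:
    """Hypothetical reference answer: try every triple (O(n^3))."""
    if array is None:
        raise TypeError("array cannot be None")
    if len(array) < 3:
        raise ValueError("array must have at least 3 elements")
    return max(a * b * c for a, b, c in combinations(array, 3))


def check_against_brute_force(
    func: Callable[[List[int]], int], trials: int = 1000, seed: int = 0
) -> None:
    """Hypothetical harness: compare func with the reference on random lists."""
    rng = random.Random(seed)  # fixed seed so failures are reproducible
    for _ in range(trials):
        size = rng.randint(3, 8)
        # Small range so negatives, zeros and duplicates all show up often.
        array = [rng.randint(-10, 10) for _ in range(size)]
        expected = max_prod_three_brute_force(array)
        actual = func(array)
        assert actual == expected, f"{array}: expected {expected}, got {actual}"
```

For example, `check_against_brute_force(Solution().max_prod_three)` returns silently if the case analysis holds for every generated list, and raises an `AssertionError` naming the failing input otherwise.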
Please could I get some feedback on it? In particular:
- Readability and does the comment make it clear what the code is doing?
- Pythonic-ness
- Typing - I have not really used typing before. Is using it on every variable worth doing, or overkill? If you know of a good resource to help a beginner, please add a link, thanks :)
- Complexity - how would I calculate this in big-O notation? The provided solution here is O(n) for time and O(1) space, but this runs much faster than that (~10x faster for a list of 7,000 entries, ~2x faster for very small lists). Notice I loop over the array twice (to separate -ve and +ve numbers), while the solution given only loops over it once, but with quite a few min/max ops within that loop. As I understand it, my code is also O(n) (since O(2n) == O(n)). But if that is correct, why is the time taken so different?
- Lastly, this is not written to be production code of course. But, what would/could I add to it for some hypothetical production use case?
Answers
Regarding readability, I'd say annotating every variable makes the code harder to read. Without the annotation:
```python
mult = lambda x, y: x * y
```
I can already tell it's a function with two arguments that returns their product. I think the annotation just adds noise:
```python
mult: Callable[[int, int], int] = lambda x, y: x * y
```
Also, these lines are kind of hard to read (there is a lot going on in each line, plus they have duplicated code):
```python
candidates: Set[int] = set()
mult: Callable[[int, int], int] = lambda x,y:x*y
if neg_len >= 3 and pos_len < 3:
    candidates.add(reduce(mult, sorted(neg)[-3:]))
if neg_len >= 2 and pos_len >= 1:
    candidates.add(reduce(mult, sorted(neg)[:2] + [max(pos)]))
if neg_len >= 1 and pos_len == 2:
    candidates.add(reduce(mult, sorted(neg)[-1:] + pos))
if pos_len >= 3:
    candidates.add(reduce(mult, sorted(pos)[-3:]))
return max(candidates)
```
You could change them to something like:
```python
candidates_groups = []
if neg_len >= 3 and pos_len < 3:
    candidates_groups.append(sorted(neg)[-3:])
if neg_len >= 2 and pos_len >= 1:
    candidates_groups.append(sorted(neg)[:2] + [max(pos)])
if neg_len >= 1 and pos_len == 2:
    candidates_groups.append(sorted(neg)[-1:] + pos)
if pos_len >= 3:
    candidates_groups.append(sorted(pos)[-3:])
# reduce is cool, but this is more readable
candidates = [cand[0] * cand[1] * cand[2] for cand in candidates_groups]
return max(candidates)
```
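On Python 3.8 and later, the multiplication step in the refactor above can also be written with `math.prod`, which avoids both `reduce` and indexing each group by position. The group values below are made up purely for illustration.

```python
from math import prod  # available since Python 3.8

# Made-up example groups standing in for candidates_groups above.
candidates_groups = [[-10, -10, 3], [1, 2, 3]]

# prod multiplies all items of each group, whatever its length.
candidates = [prod(group) for group in candidates_groups]
best = max(candidates)
```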
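Then, regarding your code's big-O complexity: as @Pavlo Slavynsky said, sorting a list is already O(n log n), so that is your code's time complexity. Regarding space complexity, you copy every value of the input into the new `pos` and `neg` lists, so that makes it O(n). Note that big-O hides constant factors: `sorted()` runs in C, while a single-pass loop with several min/max operations executes as interpreted Python bytecode, which is likely why your O(n log n) version measured faster at these input sizes.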
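For comparison, the single-pass technique mentioned in the question usually tracks the three largest and two smallest values seen so far. The sketch below is one such version, not necessarily the "provided solution" the question refers to, and the function name is hypothetical. It is O(n) time and O(1) extra space, but every comparison runs as Python bytecode.

```python
from itertools import islice
from typing import List, Optional


def max_prod_three_single_pass(array: Optional[List[int]]) -> int:
    """Hypothetical single-pass version: O(n) time, O(1) extra space."""
    if array is None:
        raise TypeError("array cannot be None")
    if len(array) < 3:
        raise ValueError("array must have at least 3 elements")
    # Seed the trackers from the first three elements (a constant-size copy).
    low, mid, high = sorted(array[:3])
    max1, max2, max3 = high, mid, low  # invariant: max1 >= max2 >= max3
    min1, min2 = low, mid              # invariant: min1 <= min2
    # islice walks the rest of the list without copying it.
    for n in islice(array, 3, None):
        # Update the three largest values.
        if n > max1:
            max1, max2, max3 = n, max1, max2
        elif n > max2:
            max2, max3 = n, max2
        elif n > max3:
            max3 = n
        # Update the two smallest values.
        if n < min1:
            min1, min2 = n, min1
        elif n < min2:
            min2 = n
    # Either the three largest, or the two smallest times the largest.
    return max(max1 * max2 * max3, min1 * min2 * max1)
```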
A few points:
- Why is your `max_prod_three` function in a class at all? It seems to me like it would be fine in the global namespace.
- Why write your own `mult` function when you can just import `mul` from the `operator` module in the standard library?
- I agree with @m-alorda on the type hints being a bit overkill. I'm a big fan of type hints, but they do make the code less legible, and a lot of yours are unnecessary. Type hints are only useful when they tell IDEs and/or other programmers things about your code that wouldn't otherwise be obvious. A call to `len()` is never going to return a value that isn't an integer, and any IDE or Python programmer worth its/their salt will know that. So annotating the return value of a call to `len()` as an integer doesn't add any non-obvious information to the code, and just clutters it up a little.
This was an answer to a code challenge. It already provided the Solution class and function signature, just like on Leetcode. Good point about the mul operator, thanks! Also good advice on the type hints: it seems restricting them to function args and return type is best, with occasional use where they clarify something not obvious. – msm1089, Jul 30, 2021 at 11:04
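Putting the points above together (a module-level function, `operator.mul` instead of a hand-written lambda, and type hints only on the signature), the question's case analysis might look like the sketch below. It also sorts each sign group only once. This is an illustration, not a definitive rewrite; the `Optional[List[int]]` hint is an added assumption reflecting the `None` check.

```python
from functools import reduce
from operator import mul
from typing import List, Optional


def max_prod_three(array: Optional[List[int]]) -> int:
    """Highest product of three numbers, using the question's case analysis."""
    if array is None:
        raise TypeError("array cannot be None")
    if len(array) < 3:
        raise ValueError("array must have at least 3 elements")
    # Sort each sign group once; zeros count as non-negative, as in the question.
    pos = sorted(n for n in array if n >= 0)
    neg = sorted(n for n in array if n < 0)
    groups = []
    if len(neg) >= 3 and len(pos) < 3:
        groups.append(neg[-3:])            # --- : three negatives closest to zero
    if len(neg) >= 2 and pos:
        groups.append(neg[:2] + pos[-1:])  # --+ : two most negative, largest non-negative
    if neg and len(pos) == 2:
        groups.append(neg[-1:] + pos)      # -++ : negative closest to zero, both non-negatives
    if len(pos) >= 3:
        groups.append(pos[-3:])            # +++ : three largest non-negatives
    # With at least three inputs, at least one group above always exists.
    return max(reduce(mul, group) for group in groups)
```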
`sorted()` is O(n log n), so your code is too. Try to sort each array only once.
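Following the suggestion to sort only once, a common variant sorts the whole input a single time and compares just two candidates: the three largest values, and the two smallest (most negative) values times the largest. This is a sketch with a hypothetical function name; it is O(n log n) time and O(n) extra space for the sorted copy.

```python
from typing import List, Optional


def max_prod_three_sorted(array: Optional[List[int]]) -> int:
    """Hypothetical helper: highest product of three, sorting the input once."""
    if array is None:
        raise TypeError("array cannot be None")
    if len(array) < 3:
        raise ValueError("array must have at least 3 elements")
    s = sorted(array)  # the only sort: O(n log n) time, O(n) extra space
    # Either the three largest values, or the two most negative values
    # (whose product is non-negative) times the largest value.
    return max(s[-1] * s[-2] * s[-3], s[0] * s[1] * s[-1])
```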