This is a LeetCode.com problem. https://leetcode.com/problems/count-of-smaller-numbers-after-self/description/
You are given an integer array nums and you have to return a new counts array. The counts array has the property where counts[i] is the number of smaller elements to the right of nums[i].
Example:
Given nums = [5, 2, 6, 1]
To the right of 5 there are 2 smaller elements (2 and 1). To the right of 2 there is only 1 smaller element (1). To the right of 6 there is 1 smaller element (1). To the right of 1 there are 0 smaller elements.
Return the array [2, 1, 1, 0].
Below is my solution, which is accepted. However, I noticed that my runtime is a lot higher than that of many other submissions. Is there a better algorithm for this problem, or is there something in the implementation I can improve? How Pythonic is my code? Do you see any logic that could be done in a better way, in terms of both performance and code beauty?
Algorithm: I traverse the input list from the end and insert each element into a BST. While inserting, I also track for each node how many nodes are to its left (the size of its left subtree), which is the number of elements in that subtree strictly smaller than the node. I also compute a position variable that starts at zero at the root and is incremented by X each time the insertion path visits a node whose value is less than the new value, where X is that node's left subtree size plus one. If the values are equal, I increment it only by the node's left subtree size. When the insertion finishes, position is the number of already-inserted elements (the elements to the right in nums) that are strictly smaller than the new one.
```python
class bst_node(object):
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
        self.left_count = 0 # less than

# Name of this class is given by LeetCode, not my choice.
class Solution(object):
    def add_node(self, bst_pointer, new_node, position):
        if bst_pointer == None:
            return position, new_node
        if new_node.val < bst_pointer.val:
            bst_pointer.left_count += 1
            position, bst_pointer.left = self.add_node(bst_pointer.left, new_node, position)
        else:
            if new_node.val > bst_pointer.val:
                position += 1
            position += bst_pointer.left_count
            position, bst_pointer.right = self.add_node(bst_pointer.right, new_node, position)
        return position, bst_pointer

    # This method signature is also given by Leetcode.
    def countSmaller(self, nums):
        """
        :type nums: List[int]
        :rtype: List[int]
        """
        res = []
        bst = None
        for n in nums[::-1]:
            smaller_after, bst = self.add_node(bst, bst_node(n), 0)
            res.append(smaller_after)
        return res[::-1]

# Few test cases ...
print(Solution().countSmaller([]))
print(Solution().countSmaller([5, 2, 6, 1]))
```
2 Answers
Your BST can become unbalanced. In the worst case when the input is sorted, the tree degenerates to a linked list, and you traverse through it all to insert each element. Various algorithms exist for self-balancing search trees, but you might also want to look into the SortedContainers library. Even if you want to create your solution from scratch, you can find good ideas there.
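If the goal is a guaranteed O(n log n) bound without a balanced tree or a third-party package, one alternative (a swapped-in technique, not something either answer proposes) is a Fenwick tree, also called a binary indexed tree, indexed by the rank of each value. The sketch below assumes the values are compressed to ranks 1..k by sorting the distinct values; the function name `count_smaller_fenwick` is hypothetical. Walking `nums` from the right, the prefix sum up to rank r - 1 is the number of already-seen (that is, later) elements strictly smaller than the current one.

```python
def count_smaller_fenwick(nums):
    """Hypothetical O(n log n) variant using a Fenwick (binary indexed) tree."""
    # Coordinate compression: map each distinct value to a 1-based rank.
    ranks = {v: i + 1 for i, v in enumerate(sorted(set(nums)))}
    tree = [0] * (len(ranks) + 1)  # 1-indexed; tree[0] is unused

    def add(i):
        # Record one more occurrence of rank i.
        while i < len(tree):
            tree[i] += 1
            i += i & -i

    def prefix_sum(i):
        # Number of recorded values whose rank is <= i.
        total = 0
        while i > 0:
            total += tree[i]
            i -= i & -i
        return total

    result = []
    for v in reversed(nums):
        r = ranks[v]
        result.append(prefix_sum(r - 1))  # strictly smaller values seen so far
        add(r)
    result.reverse()
    return result
```

Equal values share a rank, so querying up to `r - 1` excludes them, matching the "strictly smaller" requirement.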
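Using `sortedcontainers.SortedList` (from the third-party `sortedcontainers` package), the solution becomes very simple:

```python
import sortedcontainers

class Solution(object):
    def countSmaller(self, nums):
        result = []
        seen = sortedcontainers.SortedList()
        for num in reversed(nums):
            # Index of the first element >= num == count of smaller elements seen
            result.append(seen.bisect_left(num))
            seen.add(num)
        result.reverse()
        return result
```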
This is twice as fast as yours for 300 random numbers, and the gap grows with the input size. For `list(range(100))` it is ten times as fast.
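For comparison, a standard-library-only variant of the same idea can use `bisect.bisect_left` and `bisect.insort` on a plain list (the name `count_smaller_bisect` is hypothetical). Inserting into a Python list shifts every element after the insertion point, so each insert is O(n) and the worst case stays quadratic; unlike `SortedList`, this is a lower-overhead version of the idea rather than an asymptotic improvement, and whether it beats the BST on a given input is something to measure rather than assume.

```python
import bisect


def count_smaller_bisect(nums):
    """Hypothetical stdlib-only sketch: keep the seen values in a sorted list."""
    seen = []  # values to the right of the current position, kept sorted
    result = []
    for num in reversed(nums):
        # Position of the first element >= num == how many seen values are smaller.
        result.append(bisect.bisect_left(seen, num))
        bisect.insort(seen, num)  # O(n) list insertion
    result.reverse()
    return result
```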
Comments:

- Stack crashed (Oct 23, 2017): The input order matters, so rebalancing the tree would give an incorrect answer unless one comes up with a complicated count-management algorithm for when nodes are moved around.
- Janne Karila (Oct 24, 2017): @Stackcrashed That's a valid concern, though I don't think it has to be complicated. Add or subtract `left_count` when rotating left or right, respectively.
- Stack crashed (Oct 25, 2017): `[::-1]` seems to be a better option. stackoverflow.com/questions/3705670/…
- Janne Karila (Oct 25, 2017): @Stackcrashed That question is about creating a reversed copy of a list. In this case we can reverse the list in place without copying.
- Stack crashed (Oct 25, 2017): I guess you didn't read the link. It talks about exactly that: in place isn't better/faster.
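Janne Karila's comment about rotations can be made concrete. The sketch below (the class `Node` and the functions `rotate_left`, `rotate_right` and `subtree_size` are hypothetical names, not from the question) shows single rotations that keep `left_count` equal to the size of each node's left subtree. Only the node whose left subtree actually changes needs updating, and the adjustment also counts the rotated node itself, hence the `+ 1`. A full self-balancing tree (AVL, red-black, treap) would call these from its rebalancing logic, which is not shown.

```python
class Node(object):
    __slots__ = ("val", "left", "right", "left_count")

    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
        self.left_count = 0  # number of nodes in the left subtree


def rotate_right(p):
    """Rotate right at p and return the new subtree root (p's old left child)."""
    l = p.left
    p.left = l.right
    l.right = p
    # p no longer has l or l's left subtree on its left side.
    p.left_count -= l.left_count + 1
    return l


def rotate_left(p):
    """Rotate left at p and return the new subtree root (p's old right child)."""
    r = p.right
    p.right = r.left
    r.left = p
    # r now has p and p's left subtree on its left side as well.
    r.left_count += p.left_count + 1
    return r


def subtree_size(node):
    """Count nodes directly; useful for checking the left_count invariant."""
    if node is None:
        return 0
    return 1 + subtree_size(node.left) + subtree_size(node.right)
```

Because a rotation preserves the in-order sequence, the `position` computation during later insertions stays correct as long as these counts are maintained.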
Let's get some obvious stuff out of the way first:
- Do not compare to `None` using `==`. Use `variable is None` instead. See this document for rationale: https://www.python.org/dev/peps/pep-0008/#id51
- `countSmaller` should use snake case: `count_smaller` (similar convention).
- `res[::-1]` creates a reversed copy, which you don't need, since `res` is a local variable that is not shared with other functions. You could use `res.reverse()` instead, saving some memory allocations.
- Names of classes need to be in Pascal case, i.e. `bst_node` needs to be `BSTNode`.
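Applying these points to the question's code might look like the sketch below. `Solution` and `countSmaller` keep their names because LeetCode calls them that way; `BSTNode` and the parameter name `node` are renamings, and the loop uses `reversed()` plus an in-place `res.reverse()`. The algorithm itself is unchanged.

```python
class BSTNode(object):
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
        self.left_count = 0  # number of nodes in the left subtree


class Solution(object):
    def add_node(self, node, new_node, position):
        if node is None:  # identity check instead of == None
            return position, new_node
        if new_node.val < node.val:
            node.left_count += 1
            position, node.left = self.add_node(node.left, new_node, position)
        else:
            if new_node.val > node.val:
                position += 1  # count this node itself
            position += node.left_count  # plus everything in its left subtree
            position, node.right = self.add_node(node.right, new_node, position)
        return position, node

    def countSmaller(self, nums):  # name required by LeetCode
        res = []
        bst = None
        for n in reversed(nums):  # iterate backwards without copying nums
            smaller_after, bst = self.add_node(bst, BSTNode(n), 0)
            res.append(smaller_after)
        res.reverse()  # reverse in place instead of building a copy
        return res
```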
Some subtleties: Python actually has an `array` type (in the `array` module). It is rarely used and, to be honest, not very useful, but if you are required to return an array, maybe that is what you should return (or just reflect on the level of competence of the person describing the problem).
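If a typed array really were required, converting the result takes one call to the standard `array` module. The helper name `as_int_array` is hypothetical, and the typecode `"i"` (C `int`) assumes the counts fit in that range. In this problem, though, the LeetCode stub's docstring declares `:rtype: List[int]`, so a plain list is what is expected.

```python
from array import array


def as_int_array(counts):
    """Hypothetical helper: copy a list of counts into a typed array of C ints."""
    return array("i", counts)
```

The result supports indexing and iteration like a list, and `.tolist()` converts it back.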
Now, while in principle it's hard for me to think of a theoretically better algorithm, notice that the naive solution below performs better than yours on moderately sized lists (the profile below uses 50 elements), simply because all the memory allocations and function calls needed to support the tree structure you've built are so expensive:
```python
import cProfile
import random

def count_smaller(nums):
    result = [0] * len(nums)
    for i, n in enumerate(nums):
        for j in nums[i+1:]:
            if j < n:
                result[i] += 1
    return result

def test_count_smaller(n, m):
    for _ in range(n):
        count_smaller(random.sample(range(m), m))

# Solution is the class from the question.
def test_countSmaller(n, m):
    for _ in range(n):
        Solution().countSmaller(random.sample(range(m), m))

cProfile.run('test_count_smaller(100, 50)')
```
Gives:

```
         18888 function calls in 0.019 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      100    0.011    0.000    0.011    0.000 <stdin>:1(count_smaller)
        1    0.000    0.000    0.019    0.019 <stdin>:1(test_count_smaller)
        1    0.000    0.000    0.019    0.019 <string>:1(<module>)
      300    0.000    0.000    0.000    0.000 _weakrefset.py:70(__contains__)
      200    0.000    0.000    0.000    0.000 abc.py:178(__instancecheck__)
     5000    0.003    0.000    0.005    0.000 random.py:229(_randbelow)
      100    0.002    0.000    0.007    0.000 random.py:289(sample)
        1    0.000    0.000    0.019    0.019 {built-in method builtins.exec}
      200    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
      200    0.000    0.000    0.000    0.000 {built-in method builtins.len}
      100    0.000    0.000    0.000    0.000 {built-in method math.ceil}
      100    0.000    0.000    0.000    0.000 {built-in method math.log}
     5000    0.000    0.000    0.000    0.000 {method 'bit_length' of 'int' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
     7584    0.001    0.000    0.001    0.000 {method 'getrandbits' of '_random.Random' objects}
```
While your original code gives:

```python
cProfile.run('test_countSmaller(100, 50)')
```

```
         59449 function calls (33812 primitive calls) in 0.030 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.001    0.001    0.030    0.030 <stdin>:1(test_countSmaller)
      100    0.004    0.000    0.022    0.000 <stdin>:14(countSmaller)
     5000    0.002    0.000    0.002    0.000 <stdin>:2(__init__)
30637/5000    0.016    0.000    0.016    0.000 <stdin>:2(add_node)
        1    0.000    0.000    0.030    0.030 <string>:1(<module>)
      300    0.000    0.000    0.000    0.000 _weakrefset.py:70(__contains__)
      200    0.000    0.000    0.000    0.000 abc.py:178(__instancecheck__)
     5000    0.003    0.000    0.004    0.000 random.py:229(_randbelow)
      100    0.002    0.000    0.007    0.000 random.py:289(sample)
        1    0.000    0.000    0.030    0.030 {built-in method builtins.exec}
      200    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
      100    0.000    0.000    0.000    0.000 {built-in method builtins.len}
      100    0.000    0.000    0.000    0.000 {built-in method math.ceil}
      100    0.000    0.000    0.000    0.000 {built-in method math.log}
     5000    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
     5000    0.000    0.000    0.000    0.000 {method 'bit_length' of 'int' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
     7608    0.001    0.000    0.001    0.000 {method 'getrandbits' of '_random.Random' objects}
```
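The naive loop above can also be written more compactly with a generator expression inside `sum`. This sketch has the same O(n²) behavior as the loop; whether it runs faster or slower than the explicit loop in CPython is something to check with `timeit` rather than assume.

```python
def count_smaller(nums):
    """Brute force, O(n**2): for each element, count later elements that are smaller."""
    return [sum(1 for later in nums[i + 1:] if later < n)
            for i, n in enumerate(nums)]
```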
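Things you could improve without significantly changing your code:

- If you are after performance, don't use classes. Use a `tuple` or `namedtuple` (see here: https://docs.python.org/2/library/collections.html#collections.namedtuple). There's no reason `bst_node` should be a class. Keep in mind that tuples and namedtuples are immutable, so the fields that change during insertion (`left`, `right`, `left_count`) would have to live in a mutable container such as a `list`.
- Invoking a method is also expensive, and you have nothing to gain from `add_node` being a method of `Solution`; a plain function would do.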
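Following these two suggestions, a sketch might store each node as a plain `list` (mutable, unlike a tuple) and replace the recursive method with a single loop in a module-level function, which also removes the per-level call overhead. The index constants and the name `count_smaller_lists` are hypothetical; the insertion and position logic mirror the question's `add_node`.

```python
# Hypothetical node layout: [val, left, right, left_count]
VAL, LEFT, RIGHT, LEFT_COUNT = range(4)


def count_smaller_lists(nums):
    """Same unbalanced-BST idea as the question, iterative, with list nodes."""
    root = None
    result = []
    for n in reversed(nums):
        node = [n, None, None, 0]
        if root is None:
            root = node
            result.append(0)
            continue
        position = 0
        cur = root
        while True:
            if n < cur[VAL]:
                cur[LEFT_COUNT] += 1  # new node joins cur's left subtree
                if cur[LEFT] is None:
                    cur[LEFT] = node
                    break
                cur = cur[LEFT]
            else:
                if n > cur[VAL]:
                    position += 1  # cur itself is smaller than n
                position += cur[LEFT_COUNT]  # so is its whole left subtree
                if cur[RIGHT] is None:
                    cur[RIGHT] = node
                    break
                cur = cur[RIGHT]
        result.append(position)
    result.reverse()
    return result
```

This still degrades to O(n²) on sorted input, since the tree is not rebalanced; it only trims the constant factor.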