I have a list of (potentially overlapping) ranges, e.g. [(3, 9), (8, 10), (1, 18), (1, 1000000)]. I need to process them in order and for each range calculate how many numbers from the range have not been seen so far (i.e. not present in previous ranges). For example, for the above list the result would be [7, 1, 10, 999982].

A simple solution using a set:
```python
def get_missing(ranges):
    seen = set()
    result = []
    for start, end in ranges:
        missing = 0
        for n in range(start, end + 1):
            if n not in seen:
                missing += 1
                seen.add(n)
        result.append(missing)
    return result
```
How can I improve the performance of this solution (i.e. do it without looping through each number of each range)?
Comment by FMc (Jan 29, 2023 at 7:18): Do the ranges tend to overlap a lot (as in your example, in which each addition merely expands the seen-range we already had) or is it more common for the ranges to have almost no overlap at all (lots of discrete, non-overlapping seen-ranges to keep track of)? Or both?

Comment by Eugene Yarmash (Jan 29, 2023 at 8:12): @FMc: I think it's both.
6 Answers
The primary problem with your current approach. As you know from previous answers and comments, the limitation of your current approach is that it falters when faced with very large intervals, which cause the inner loop to balloon in size. A little bit of progress can be made by taking fuller advantage of sets to perform the intersection logic (as shown below in the benchmarks). But that improvement is only modest: the sets eliminate the inner loop in your own code, but behind the scenes the Python sets are doing iteration of their own.
Some infrastructure. Let's create a simple dataclass to make the code more readable, a utility function to generate intervals to our specifications (many or few, big or small), and a utility function to benchmark the various approaches to the problem. We will compare your code (slightly modified to handle the new dataclass), a set-based alternative similar to the one in another answer, and a much faster approach using an IntervalUnion (to be discussed below).
```python
import sys
import time
from dataclasses import dataclass
from random import randint

def main(args):
    intervals = tuple(create_intervals(
        n = 1000,
        start_limit = 100000000,
        max_size = 100000,
    ))
    funcs = (
        get_missing_orig,
        get_missing_sets,
        get_missing_interval_union,
    )
    exp = None
    for func in funcs:
        got, dur = measure(func, intervals)
        if exp is None:
            exp = got
        print(func.__name__, dur, got == exp)

@dataclass(frozen = True, order = True)
class Interval:
    start: int
    end: int

def create_intervals(n, start_limit, max_size):
    for _ in range(n):
        start = randint(0, start_limit)
        end = start + randint(0, max_size)
        yield Interval(start, end)

def measure(func, intervals):
    t1 = time.time()
    got = func(intervals)
    return (got, time.time() - t1)

def get_missing_orig(ranges):
    # Your original implementation, slightly adjusted.
    seen = set()
    result = []
    for x in ranges:
        missing = 0
        for n in range(x.start, x.end + 1):
            if n not in seen:
                missing += 1
                seen.add(n)
        result.append(missing)
    return result

def get_missing_sets(intervals):
    # A set-based approach.
    counts = []
    seen = set()
    for x in intervals:
        s = set(range(x.start, x.end + 1)) - seen
        counts.append(len(s))
        seen.update(s)
    return counts

def get_missing_interval_union(intervals):
    # An approach that just stores intervals, as few as possible.
    iu = IntervalUnion()
    return [iu.add(x) for x in intervals]

if __name__ == '__main__':
    main(sys.argv[1:])
```
The intuition behind IntervalUnion. We're trying to avoid two problems. First, we want to process and store only the intervals, not all of their implied values. Second, we don't want to end up having to make passes over an ever-growing collection of intervals. Instead, we would rather merge intervals whenever they overlap. If we can keep the size of the data universe in check, our computation will also be quick. For starters we need a couple of utility functions: one to tell us whether two intervals can be merged and, if so, how many of their values are overlapping; and another that can merge two intervals into one.
```python
def overlapping(x, y):
    # Takes two intervals. Returns a (CAN_MERGE, N_OVERLAPPING) tuple.
    # Intervals can be merged if they overlap or abut.
    # N of overlapping values is 1 + min-end - max-start.
    n = 1 + min(x.end, y.end) - max(x.start, y.start)
    if n >= 0:
        return (True, n)
    else:
        return (False, 0)

def merge(x, y):
    # Takes two overlapping intervals and returns their merger.
    return Interval(
        min(x.start, y.start),
        max(x.end, y.end),
    )
```
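To make the merge rule concrete, here is a small self-contained check of `overlapping()` and `merge()` (the `Interval` dataclass and both functions are repeated from above, in slightly condensed form, so the snippet runs on its own). It illustrates that abutting intervals are mergeable with zero shared values, while a one-value gap prevents a merge.

```python
from dataclasses import dataclass


@dataclass(frozen=True, order=True)
class Interval:
    start: int
    end: int


def overlapping(x, y):
    # Same rule as above: mergeable if the intervals overlap or abut.
    n = 1 + min(x.end, y.end) - max(x.start, y.start)
    return (True, n) if n >= 0 else (False, 0)


def merge(x, y):
    return Interval(min(x.start, y.start), max(x.end, y.end))


# Overlapping: 8 and 9 are shared.
print(overlapping(Interval(3, 9), Interval(8, 10)))  # (True, 2)
# Abutting: nothing is shared, but 1..3 and 4..6 can still be merged.
print(overlapping(Interval(1, 3), Interval(4, 6)))   # (True, 0)
# One-value gap (4 is missing): not mergeable.
print(overlapping(Interval(1, 3), Interval(5, 6)))   # (False, 0)
print(merge(Interval(3, 9), Interval(8, 10)))        # Interval(start=3, end=10)
```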
The data structure of an IntervalUnion. An IntervalUnion holds a SortedList of Interval instances. The SortedList provides the ability to add and remove intervals without having to keep the list sorted ourselves. The SortedList will do that work efficiently for us, and the various add/remove operations will operate on the order of O(log N) rather than O(N) or O(N log N). The add() method orchestrates those details, which are explained in the code comments, and returns the number that you need, namely how many distinct new values are represented by the interval we just added.
```python
from sortedcontainers import SortedList

class IntervalUnion:

    def __init__(self):
        self.intervals = SortedList()

    def add(self, x):
        # Setup and initialization:
        # - The N of values in the initial interval x.
        # - N of overlaps observed as we add/merge x into the IntervalUnion.
        # - Convenience variable for the SortedList of existing intervals.
        # - Existing intervals to be removed as a result of those mergers.
        n_vals = x.end - x.start + 1
        n_overlaps = 0
        xs = self.intervals
        removals = []

        # Get the index where interval x would be added in the SortedList.
        # From that location we will look leftward and rightward to find
        # nearby intervals that can be merged with x. To the left, we
        # just need to check the immediate neighbor. To the right, we
        # must keep checking until no more merges are possible.
        i = xs.bisect_left(x)
        for j in range(max(0, i - 1), len(xs)):
            y = self.intervals[j]
            can_merge, n = overlapping(x, y)
            if can_merge:
                # If we can merge, do it. Then add y to the list of intervals
                # to be removed, and increment the tally of overlaps.
                x = merge(x, y)
                removals.append(y)
                n_overlaps += n
            elif j >= i:
                # Stop on the first rightward inability to merge.
                break

        # Remove and add.
        for y in removals:
            xs.remove(y)
        xs.add(x)

        # Return the distinct new values added to the IntervalUnion.
        return n_vals - n_overlaps
```
Benchmarks. In terms of space, the IntervalUnion is quite efficient: it stores only intervals and it merges them whenever possible. At one extreme (all of the intervals overlap), the space used is O(1) because the IntervalUnion never contains more than one interval. At the other extreme (no overlap), the space used is O(N), where N represents the number of intervals.
In terms of time, the IntervalUnion becomes faster than the other
approaches when the interval sizes reach about 300 (at least in my limited number
of experiments). When the intervals get even bigger, the advantages
of the IntervalUnion are substantial. For example:
```
# max_size = 300
get_missing_orig 0.027595043182373047 True
get_missing_sets 0.01658797264099121 True
get_missing_interval_union 0.013303995132446289 True

# max_size = 1000
get_missing_orig 0.10612797737121582 True
get_missing_sets 0.054525136947631836 True
get_missing_interval_union 0.013611078262329102 True

# max_size = 10000
get_missing_orig 1.1063508987426758 True
get_missing_sets 0.5742030143737793 True
get_missing_interval_union 0.013240814208984375 True

# max_size = 100000
get_missing_orig 9.316476106643677 True
get_missing_sets 6.468451023101807 True
get_missing_interval_union 0.016165733337402344 True
```
Comment by Eugene Yarmash (Mar 2, 2023 at 14:34): Would it make sense to use `SortedList(key=attrgetter('start'))` for storing intervals?

Comment by FMc (Mar 2, 2023 at 15:32): @EugeneYarmash I don't think that would help (and it might hurt). The `Interval` dataclass is already arranging things so that sorting is being driven by `start` and `end`. Using `attrgetter('start')` would cause the sorting to ignore `end`. Although my memory of the details has already grown fuzzy, I vaguely recall thinking during my experiments with this code that there were some situations where it mattered that the intervals were being sorted by both of their attributes.

Comment by Eugene Yarmash (Mar 2, 2023 at 16:48): I think it shouldn't matter as you look both leftward and rightward for an index when merging ranges, and there can't possibly be more than one range with the same `start` after a merge. I wonder if using a key would speed up the `SortedList`.

Comment by FMc (Mar 2, 2023 at 16:56): @EugeneYarmash Perhaps you are right (and my memory poor) about the necessity of sorting with both attributes. Regarding speed, I'm skeptical that it would make an appreciable difference (the dataclass `order = True` setting is effectively providing the key-function). But that's just a guess about performance. Benchmarking is the only way to know for sure.
Don't write, and never present, undocumented code.
Python got it right in specifying docstrings so that it is easy to copy them along with the code, and tedious to copy the code without them.
Thinking of the problem as a set problem looks as valid as giving the chosen environment's implementation a try first.
You missed that the number wanted is the size of the difference between the current range and the union of previous ranges:
```python
def debutants(inclusive_ranges):
    """ For each range, list the number of numbers not seen before.
    inclusive_ranges is an iterable of pairs (start, stop).
    """
    seen = set()
    result = []
    for start, end in inclusive_ranges:
        current = set(range(start, end + 1)) - seen
        result.append(len(current))
        seen |= current
    return result

if __name__ == '__main__':
    print(debutants([(3, 9), (8, 10), (1, 18)]))
    help(debutants)
```
There are specialised data structures for operations on unions of intervals.
Comment by greybeard (Apr 9, 2022 at 10:03): I'm old. Python type hinting is one of those new-fangled things.

Comment by Reinderien (Apr 9, 2022 at 10:30): This solution will not scale if the entry count remains small but the entries themselves are large. There are alternative algorithms that will successfully process e.g. (1e12, 9e12).

Comment by greybeard (Apr 9, 2022 at 10:55): @Reinderien There are two sides to "as valid as", and "giving a try" isn't claiming the one solution to rule them all. I did mention alternative data structures, and would have a look at any alternative algorithm presented here.
I think the complicating factor here is that most of these operations on individual intervals can return multiple intervals back to you, which can be tricky to work with.
I'd probably start by creating a data structure for both individual intervals and collections of intervals, and then restrict my public API to the collection type only, to try and prevent confusion. For this kind of thing it's probably also a good idea to go for immutability. Let's start with this:
```python
from typing import NamedTuple

class Intervals:
    def __init__(self, *intervals):
        self.intervals = intervals

    @classmethod
    def empty(cls):
        return cls()

    @classmethod
    def single(cls, start, end):
        return cls(_Interval(start, end))

    def __repr__(self):
        if len(self.intervals) == 0:
            return "Intervals.empty()"
        if len(self.intervals) == 1:
            return f"Intervals.single({self.intervals[0].start}, {self.intervals[0].end})"
        return f"Intervals({self.intervals})"

class _Interval(NamedTuple):
    start: int
    end: int
```
This gives us an initial data structure to work with. Now to complete this task we will need a way of showing the size of a group of intervals, a way of computing the difference between two sets of intervals, and a way of computing the combined union of two sets of intervals (to keep track of our running total).
The size calculation is quite straightforward. First we do it on the single interval class:
```python
def __len__(self):
    return (self.end - self.start) + 1
```
And then our composite:
```python
def __len__(self):
    return sum(len(interval) for interval in self.intervals)
```
The others will need a little more work, and for that we'll need a mechanism for simply concatenating two sets of intervals without checking for overlaps. This mechanism will also need to handle both individual intervals and composites (due to the varied return types of operations like subtraction).
```python
def __add__(self, other) -> "Intervals":
    if isinstance(other, _Interval):
        return self + Intervals(other)
    return Intervals(*(self.intervals + other.intervals))
```
It will also be handy to create an 'overlaps' method for our single Interval type:
```python
def overlaps(self, other: "_Interval") -> bool:
    # Both conditions must hold; with `or`, almost any pair would "overlap".
    return other.start <= self.end and other.end >= self.start
```
Now, for our union / overlap we can implement that very simply for a single interval:
```python
def __or__(self, other: "_Interval") -> Intervals:
    if not self.overlaps(other):
        return Intervals(self, other)
    return Intervals.single(min(self.start, other.start), max(self.end, other.end))
```
And then for two sets of intervals:
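```python
def __or__(self, other: "Intervals") -> "Intervals":
    output = self.empty()
    for self_interval in self.intervals:
        # A pairwise union can return several intervals, so keep all the
        # pieces and apply the next `|` to each single _Interval.
        pieces = [self_interval]
        for other_interval in other.intervals:
            pieces = [p for piece in pieces for p in (piece | other_interval).intervals]
        output = output + Intervals(*pieces)
    return output
```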
Subtraction is harder, and this is where the mixing between single and composite becomes the most noticeable. First we implement it for single intervals, taking into account each of the three ways that two intervals can overlap:
```python
def __sub__(self, other: "_Interval") -> Intervals:
    if not self.overlaps(other):
        return Intervals(self)
    if other.start > self.start:
        interval1 = _Interval(self.start, other.start - 1)
        if other.end < self.end:
            return Intervals(
                interval1,
                _Interval(other.end + 1, self.end)
            )
        else:
            return Intervals(interval1)
    else:
        if other.end < self.end:
            return Intervals.single(other.end + 1, self.end)
        # other covers self completely: nothing is left.
        return Intervals.empty()
```
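The case analysis can be checked in isolation. The sketch below uses a hypothetical plain-tuple helper, `subtract` (not part of the answer's classes), that returns the leftover pieces as a list; note the last case, where `other` covers `self` completely and nothing is left.

```python
def subtract(a, b):
    """Return the parts of closed interval a not covered by closed interval b."""
    a_start, a_end = a
    b_start, b_end = b
    if b_start > a_end or b_end < a_start:
        return [a]                               # no overlap: a is untouched
    pieces = []
    if b_start > a_start:
        pieces.append((a_start, b_start - 1))    # part of a left of b
    if b_end < a_end:
        pieces.append((b_end + 1, a_end))        # part of a right of b
    return pieces                                # empty when b covers a


print(subtract((1, 2), (5, 6)))    # [(1, 2)]
print(subtract((8, 10), (3, 9)))   # [(10, 10)]
print(subtract((1, 18), (3, 10)))  # [(1, 2), (11, 18)]
print(subtract((4, 6), (1, 18)))   # []
```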
And then we can implement that for the composite:
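```python
def __sub__(self, other):
    output = self.empty()
    for self_interval in self.intervals:
        # A pairwise difference can return zero, one or two intervals, so
        # keep all the pieces and subtract the next interval from each one.
        pieces = [self_interval]
        for other_interval in other.intervals:
            pieces = [p for piece in pieces for p in (piece - other_interval).intervals]
        output = output + Intervals(*pieces)
    return output
```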
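This is the same as the method for union, so we could dry this out at this point. Because the pairwise unions can leave duplicate or overlapping pieces once the running total holds more than one interval, the union result also needs normalising (sorting and merging) afterwards:

```python
def _combine(self, other, combination: Callable[["_Interval", "_Interval"], "Intervals"]):
    output = self.empty()
    for self_interval in self.intervals:
        # Pairwise results can hold several intervals; apply the next
        # operation to each single piece.
        pieces = [self_interval]
        for other_interval in other.intervals:
            pieces = [p for piece in pieces for p in combination(piece, other_interval).intervals]
        output = output + Intervals(*pieces)
    return output

def __sub__(self, other):
    return self._combine(other, operator.__sub__)

def __or__(self, other):
    # Include `other` (in case self is empty), then merge duplicates/overlaps.
    return (self._combine(other, operator.__or__) + other)._normalised()

def _normalised(self):
    merged = []
    for interval in sorted(self.intervals):
        if merged and interval.start <= merged[-1].end + 1:
            merged[-1] = _Interval(merged[-1].start, max(merged[-1].end, interval.end))
        else:
            merged.append(interval)
    return Intervals(*merged)
```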
Putting it all together gives us:
```python
# Lets the annotations below mention _Interval before it is defined.
from __future__ import annotations

import operator
from typing import Callable, NamedTuple


class Intervals:
    def __init__(self, *intervals):
        self.intervals = intervals

    def __len__(self):
        return sum(len(interval) for interval in self.intervals)

    def __add__(self, other):
        if isinstance(other, _Interval):
            return self + Intervals(other)
        return Intervals(*(self.intervals + other.intervals))

    def _combine(self, other, combination: Callable[[_Interval, _Interval], Intervals]):
        output = self.empty()
        for self_interval in self.intervals:
            # Pairwise results can hold several intervals; apply the next
            # operation to each single piece.
            pieces = [self_interval]
            for other_interval in other.intervals:
                pieces = [p for piece in pieces for p in combination(piece, other_interval).intervals]
            output = output + Intervals(*pieces)
        return output

    def _normalised(self):
        # Sort, then merge overlapping or adjacent intervals.
        merged = []
        for interval in sorted(self.intervals):
            if merged and interval.start <= merged[-1].end + 1:
                merged[-1] = _Interval(merged[-1].start, max(merged[-1].end, interval.end))
            else:
                merged.append(interval)
        return Intervals(*merged)

    def __sub__(self, other):
        return self._combine(other, operator.__sub__)

    def __or__(self, other):
        # Include `other` (in case self is empty), then merge duplicates/overlaps.
        return (self._combine(other, operator.__or__) + other)._normalised()

    @classmethod
    def empty(cls):
        return cls()

    @classmethod
    def single(cls, start, end):
        return cls(_Interval(start, end))

    def __repr__(self):
        if len(self.intervals) == 0:
            return "Intervals.empty()"
        if len(self.intervals) == 1:
            return f"Intervals.single({self.intervals[0].start}, {self.intervals[0].end})"
        return f"Intervals({self.intervals})"


class _Interval(NamedTuple):
    start: int
    end: int

    def __len__(self):
        return (self.end - self.start) + 1

    def overlaps(self, other: _Interval) -> bool:
        return other.start <= self.end and other.end >= self.start

    def __sub__(self, other: _Interval) -> Intervals:
        if not self.overlaps(other):
            return Intervals(self)
        if other.start > self.start:
            interval1 = _Interval(self.start, other.start - 1)
            if other.end < self.end:
                return Intervals(
                    interval1,
                    _Interval(other.end + 1, self.end)
                )
            else:
                return Intervals(interval1)
        else:
            if other.end < self.end:
                return Intervals.single(other.end + 1, self.end)
            # other covers self completely: nothing is left.
            return Intervals.empty()

    def __or__(self, other: _Interval) -> Intervals:
        if not self.overlaps(other):
            return Intervals(self, other)
        return Intervals.single(min(self.start, other.start), max(self.end, other.end))
```
Now we have a high-level API to work with. Let's start by creating our intervals and an empty set of intervals to act as a running total:
```python
running_total = Intervals.empty()
basic_intervals = [Intervals.single(3, 9), Intervals.single(8, 10), Intervals.single(1, 18), Intervals.single(1, 1000000)]
```
Now we can go through those intervals, updating our running total, and computing the difference (+ its size) each time.
```python
for interval in basic_intervals:
    print(f"Running Total: {running_total}")
    print(f"New Interval: {interval}")
    print(f"Difference: {interval - running_total}")
    print(len(interval - running_total))
    running_total = interval | running_total
```
This gives us:
```
Running Total: Intervals.empty()
New Interval: Intervals.single(3, 9)
Difference: Intervals.single(3, 9)
7
Running Total: Intervals.single(3, 9)
New Interval: Intervals.single(8, 10)
Difference: Intervals.single(10, 10)
1
Running Total: Intervals.single(3, 10)
New Interval: Intervals.single(1, 18)
Difference: Intervals((_Interval(start=1, end=2), _Interval(start=11, end=18)))
10
Running Total: Intervals.single(1, 18)
New Interval: Intervals.single(1, 1000000)
Difference: Intervals.single(19, 1000000)
999982
```
It's a bit fiddly figuring out the space-time complexity here. It could be O(n^2) in the worst case for time if none of your intervals overlap, which we could speed up by having 'Intervals' be an interval tree (which would let us zero in more efficiently on the overlapping intervals to do the relevant calculations). This would further complicate the code, though, and may be less efficient if lack of overlap is rare. This implementation should scale much better than your reference implementation when the ranges are large, because its cost depends on the number of intervals rather than on how many numbers they contain. It also has a much more manageable memory footprint, as it only keeps track of the start and end of each interval and nothing else.
I like the fact that you thought to use a set for your current approach. That is typically the right data structure for "Have I seen this before?" type questions, and is actually doing a lot to help keep the speed up. I also quite like your variable naming: the names all say what they mean without too much fuss.
As you noted, your current algorithm is working through all the numbers in a range. That is dangerous, because it gets slow not from having more ranges to worry about but just from using bigger numbers. Technically, that means it's O(2^n) in time and space complexity if n is the size of the input in bits.
The natural way to find out how many numbers are in a range is subtraction. For your first range, you want to do 9-3 (plus one if both 9 and 3 should be counted). Now, let's suppose you have (3,9) having already seen (4,6). You have the 7 from the larger range, minus 3 (i.e. 6-4+1) from the smaller range, which is entirely consumed. Sure enough, there are 7-3=4 numbers left: 3, 7, 8, 9. Importantly, if we had (3000000, 9000000) instead, we can still just subtract in one step and carry on. Even if we had to use continuous numbers like 9.45-3.62, it just works. So far so good.
Now, there are a couple of extra catches that you'd need to work through to get a really good algorithm. First, after (4, 6) and (3, 9) you have two previous ranges which cover some of the same numbers. If you just do the same thing with all previously seen ranges, you'll double-count 4,5,6. The solution is pretty clear: once you have processed 3-9, you can delete 4-6 from the previous entries because it is entirely swallowed up. This is important because it will allow you to make an amortised efficient algorithm. That means that even if there are individual cases which might be expensive to process, doing that processing cleans everything up in a way that ensures the average is cheap. For example, suppose you have loads of little previous ranges 1-2, 3-4, 5-6, all the way up to 999997-999998 and then you get 1-1000000. There's no way around it: you'll have to process all the shrapnel. But if you then get 0-1000001, you just have one big range left to worry about.
There are two other key considerations you'll need to worry about, which I won't go into in quite as much detail. They're "left as an exercise for the reader", but to be honest I need to get to bed!
- Although it's easy to subtract the previous range which is entirely inside the new one, for the algorithm to work you need to think about previous ranges which only partially overlap. Think about both what the score should be and how to absorb them for the amortisation.
- For the most efficient algorithm, you'll need a slightly faster way of finding which previous ranges overlap (partly or entirely) with the new one. You could just loop through them all, but that would give you an O(n^2) algorithm. You should be able to get down to an O(n log(n)) algorithm for the whole thing. You'll need a bit of care around how you keep things tidy to allow an efficient search for the right place to go.
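The two considerations above are left open in this answer; the following is one possible sketch of how they could fit together (an illustration, not the answerer's code). It keeps the seen numbers as sorted, disjoint, non-adjacent ranges in two parallel lists, `starts` and `ends`, uses `bisect_left` on `ends` to find the first previous range that can touch the new one, subtracts the clipped overlap of each touched range (which handles partial overlaps), and then replaces that whole run with a single merged range. Each stored range is removed at most once after being added, which is the amortisation argument. One caveat: the searches are O(log n), but slice assignment on a Python list shifts elements, so a single update can still cost O(n) in the worst case; a balanced tree or a sorted-container library would be needed for a true O(n log n) bound.

```python
from bisect import bisect_left


def get_missing(ranges):
    """For each inclusive (start, end) range, count numbers not seen before."""
    starts, ends = [], []  # disjoint, non-adjacent seen ranges, sorted
    result = []
    for start, end in ranges:
        # The stored ranges are disjoint and sorted, so their ends are sorted
        # too. The first range that can touch [start, end] is the first whose
        # end reaches start - 1 (adjacent ranges get merged as well).
        i = bisect_left(ends, start - 1)
        j = i
        new = end - start + 1
        lo, hi = start, end
        while j < len(starts) and starts[j] <= end + 1:
            # Subtract the part of range j inside [start, end]; this is zero
            # when the two ranges are merely adjacent.
            new -= max(0, min(end, ends[j]) - max(start, starts[j]) + 1)
            lo = min(lo, starts[j])
            hi = max(hi, ends[j])
            j += 1
        # Replace the swallowed ranges i..j-1 with one merged range.
        starts[i:j] = [lo]
        ends[i:j] = [hi]
        result.append(new)
    return result


print(get_missing([(3, 9), (8, 10), (1, 18), (1, 1000000)]))
# [7, 1, 10, 999982]
```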
Another approach to improve the performance would be to use interval trees, as mentioned by @OnceAndFutureKing, to keep track of ranges and calculate the missing numbers more efficiently.
An interval tree is a data structure that can be used to efficiently query for overlapping intervals. In the context of this problem, we can use an interval tree to keep track of the ranges seen so far, and query it for each new range to determine the missing numbers in that range.
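```python
class IntervalTreeNode:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.left = None   # subtree of intervals entirely before this one
        self.right = None  # subtree of intervals entirely after this one

    def overlap(self, start, end):
        return start <= self.end and end >= self.start

class IntervalTree:
    # A binary search tree of disjoint intervals, ordered by position.
    def __init__(self):
        self.root = None

    def insert(self, start, end):
        # [start, end] must be disjoint from every stored interval, so it
        # belongs unambiguously to the left or right of each node.
        new_node = IntervalTreeNode(start, end)
        if not self.root:
            self.root = new_node
            return
        node = self.root
        while True:
            if end < node.start:
                if not node.left:
                    node.left = new_node
                    return
                node = node.left
            else:
                if not node.right:
                    node.right = new_node
                    return
                node = node.right

    def query(self, start, end):
        # Return all stored intervals overlapping [start, end], sorted.
        # Overlaps can sit in both subtrees, so explore every side the
        # query reaches past instead of following a single path.
        found = []
        stack = [self.root]
        while stack:
            node = stack.pop()
            if node is None:
                continue
            if node.overlap(start, end):
                found.append(node)
            if start < node.start:
                stack.append(node.left)
            if end > node.end:
                stack.append(node.right)
        return sorted(found, key=lambda node: node.start)

    def add(self, start, end):
        # Insert only the not-yet-covered parts of [start, end], keeping the
        # stored intervals disjoint, and return how many numbers they hold.
        missing = 0
        cursor = start
        for node in self.query(start, end):
            if node.start > cursor:
                self.insert(cursor, node.start - 1)
                missing += node.start - cursor
            cursor = max(cursor, node.end + 1)
        if cursor <= end:
            self.insert(cursor, end)
            missing += end - cursor + 1
        return missing

def get_missing(ranges):
    tree = IntervalTree()
    return [tree.add(start, end) for start, end in ranges]
```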
For the example input, get_missing returns [7, 1, 10, 999982], the count of previously unseen numbers in each range.
The algorithm explained

The interval tree solution works as follows:

1. Create an interval tree data structure to store the ranges seen so far.
2. For each range (start, end) in the input, in order, query the interval tree for all stored intervals that overlap [start, end].
3. The number of missing numbers in the range (start, end) is end - start + 1 - sum(min(end, end_i) - max(start, start_i) + 1 for start_i, end_i in overlapping_intervals), where overlapping_intervals is the result of the query and the stored intervals are disjoint.
4. Add the range to the interval tree, so that later queries see it.
5. Repeat steps 2 to 4 for each range in the input and store the results in a list.
6. Return the list of results as the final answer.
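The counting step can be tried on its own. `count_missing` below is a hypothetical helper that clips each overlapping interval to [start, end] before subtracting its size, assuming the stored intervals are disjoint (otherwise a number could be subtracted twice).

```python
def count_missing(start, end, overlapping):
    """Numbers in [start, end] not covered by the given disjoint intervals."""
    covered = sum(
        min(end, end_i) - max(start, start_i) + 1
        for start_i, end_i in overlapping
        # Only intervals that actually intersect [start, end] contribute.
        if start_i <= end and end_i >= start
    )
    return end - start + 1 - covered


# (1, 18) against the previously seen (3, 9) and (10, 10): 18 - 8 = 10.
print(count_missing(1, 18, [(3, 9), (10, 10)]))  # 10
# Clipping matters: only 8 and 9 of (3, 9) fall inside (8, 10).
print(count_missing(8, 10, [(3, 9)]))            # 1
```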
This solution leverages the interval tree data structure to find overlapping intervals efficiently, which allows us to avoid looping through each number of each range. Instead, we can calculate the missing numbers by subtracting the number of values covered by the overlapping intervals (clipped to the range) from the total number of values in the range.
Time / Space complexity
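The time complexity of the solution is O(n log n), where n is the total number of ranges, provided the tree stays reasonably balanced and each query touches only a few stored intervals. In that case the insert and query operations take O(log n) time on average and are performed once per range. A plain binary search tree like this one is not self-balancing, however, so input that arrives in sorted order can degrade each operation to O(n).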
The space complexity of the solution is O(n), where n is the total number of ranges. This is because the interval tree stores the ranges, and each node in the tree requires a constant amount of space.
It's worth noting that the time and space complexity are dominated by the use of the interval tree, and the performance of the solution is largely dependent on the efficiency of the interval tree implementation.
In general, the interval tree solution is a good choice when the number of ranges is large and performance is a concern. For smaller inputs or less demanding requirements, simpler solutions such as the set-based solution provided in the original question may be sufficient.
It's also possible that there are other, more specialized algorithms that would be more efficient for a particular problem, but they may be more complex to implement and may not be worth considering unless there is a compelling reason to do so.
Comment by greybeard (Jan 30, 2023 at 21:29): From the question: process in order; how many numbers from [each range] have not been seen so far (i.e. not present in previous ranges).

Comment by Toby Speight (Jan 31, 2023 at 8:01): You have presented an alternative solution, but haven't reviewed the code. Please edit to show what aspects of the question code prompted you to write this version, and in what ways it's an improvement over the original. It may be worth (re-)reading How to Answer.
To make the algorithm more time/space efficient it's necessary to keep only start/end points for each interval instead of individual numbers. For efficiency, we could keep a sorted list of intervals as well as a running total of seen numbers. Then for each interval:
- Add it to the sorted list of intervals.
- Merge any overlapping intervals (this should be easy since the list is sorted).
- Get the difference between the total numbers in the intervals list and our running total and add it to the result list.
The code:
```python
from bisect import insort_left

def get_missing(intervals):
    processed = []   # sorted, merged (start, end) tuples seen so far
    result = []
    seen_total = 0   # how many distinct numbers `processed` covers
    for interval in intervals:
        # Store tuples, so merging never mutates the caller's data and
        # list or tuple inputs compare consistently.
        insort_left(processed, tuple(interval))
        processed = merge(processed)
        total = get_total(processed)
        result.append(total - seen_total)
        seen_total = total
    return result

def merge(intervals):
    merged = []
    for interval in intervals:
        if merged and merged[-1][1] >= interval[0] - 1:
            # Overlapping or adjacent: extend the last merged interval.
            merged[-1] = (merged[-1][0], max(merged[-1][1], interval[1]))
        else:
            merged.append(interval)
    return merged

def get_total(intervals):
    return sum(interval[1] - interval[0] + 1 for interval in intervals)

assert get_missing([
    [3, 9], [8, 10], [1, 18], [1, 10000000], [499999, 2000000], [5000000, 10000001]
]) == [7, 1, 10, 9999982, 0, 1]
```
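The time complexity of this algorithm is O(n^2), with n being the number of intervals, and the space complexity is O(n). The time complexity can be improved towards O(n log n) by using a SortedList instead of a plain Python list for storing intervals, and by merging only the neighbours of the newly inserted interval (updating the running total incrementally) rather than rebuilding and re-summing the whole list on every step.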