You are given information about users of your website. The information includes a username, a phone number and/or an email address. Write a program that takes a list of tuples, where each tuple holds the information for one user, and returns a list of lists, where each sublist contains the indices of the tuples that describe the same person. For example:
Input:
[("MLGuy42", "[email protected]", "123-4567"), ("CS229DungeonMaster", "123-4567", "[email protected]"), ("Doomguy", "[email protected]", "[email protected]"), ("andrew26", "[email protected]", "[email protected]")]
Output:
[[0, 1, 3], [2]]
This is because "MLGuy42", "CS229DungeonMaster" and "andrew26" are all the same person.
Each sublist in the output should be sorted and the outer list should be sorted by the first element in the sublist.
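One way to put any grouping into this canonical order is a small helper like the one below. It is a sketch of my own, not part of the question, and `normalize_groups` is a hypothetical name. It relies on the groups being disjoint, so sorting the lists lexicographically is the same as sorting them by their first element.

```python
def normalize_groups(groups):
    """Sort each group, then order the groups by their first element.

    Assumes the groups are disjoint, so no two groups share a first
    element and plain lexicographic sorting of the lists is enough.
    """
    return sorted(sorted(group) for group in groups)
```

This is mostly handy when comparing the output of different solutions, as some of the answers below do.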
Below is my code for this problem. It seems to work fine, but I'm wondering whether there is a better or more efficient solution.
```python
def find_duplicates(user_info):
    results = list()
    seen = dict()
    for i, user in enumerate(user_info):
        first_seen = True
        key_info = None
        for info in user:
            if info in seen:
                first_seen = False
                key_info = info
                break
        if first_seen:
            results.append([i])
            pos = len(results) - 1
        else:
            index = seen[key_info]
            results[index].append(i)
            pos = index
        for info in user:
            seen[info] = pos
    return results
```
Comment by 200_success (Nov 21, 2017): By the way, this is a thinly disguised union-find problem.
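Following that hint, here is a minimal union-find (disjoint-set) sketch. It is not code from the question or the answers, and `find_duplicates_union_find` with its `find`/`union` helpers are names of my own. It assumes every field value is hashable and that any shared value means the same person. Unlike a single greedy pass, it merges groups transitively, so a later record that links two earlier groups joins them.

```python
def find_duplicates_union_find(user_info):
    """Group row indices that share any field value, transitively."""
    parent = list(range(len(user_info)))  # every row starts in its own set

    def find(x):
        # Walk up to the root, halving the path on the way to keep trees shallow.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra != rb:
            # Keep the smaller index as the root (arbitrary but deterministic).
            if ra < rb:
                parent[rb] = ra
            else:
                parent[ra] = rb

    owner = {}  # field value -> first row index that contained it
    for i, user in enumerate(user_info):
        for info in user:
            if info in owner:
                union(owner[info], i)
            else:
                owner[info] = i

    groups = {}  # root -> row indices belonging to that set
    for i in range(len(user_info)):
        groups.setdefault(find(i), []).append(i)
    # Rows are visited in increasing order, so each list is already sorted, and
    # dict insertion order (Python 3.7+) orders the groups by their first element.
    return list(groups.values())
```

Each record is unioned with the first record that contained each of its values, instead of keeping a separate node per value, so `parent` stays the same size as the input.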
3 Answers
There is a way to increase readability: a `for`-`else` statement can replace flags that are only set right before a `break`. The `else` clause runs only if the loop completes without hitting `break`.
```python
def find_duplicates(user_info):
    results = list()
    seen = dict()
    for i, user in enumerate(user_info):
        for info in user:
            if info in seen:
                index = seen[info]
                results[index].append(i)
                pos = index
                break
        else:
            results.append([i])
            pos = len(results) - 1
        for info in user:
            seen[info] = pos
    return results
```
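Stripped to its essentials, the `for`-`else` pattern looks like this. `first_known` is a hypothetical helper for illustration only, not part of the answer's code.

```python
def first_known(values, seen):
    """Return the first element of `values` that is a key in `seen`, else None."""
    for value in values:
        if value in seen:
            found = value
            break
    else:
        # Only reached when the loop ran to completion without `break`.
        found = None
    return found
```

Note that an empty `values` also runs the `else` branch, since the loop finishes without breaking.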
How I tested speeds
Your code, as it currently stands, works exactly the same with integers, so I randomly generated large data sets like this (and played around with `MAXVALUE` and `range` a bit):
```python
from random import randrange
from timeit import timeit

MAXVALUE = 1000
# Repeat with 5 data sets so that one outlier data set doesn't skew the timings.
for a in range(5):
    user_info = [[randrange(MAXVALUE) for i in range(3)] for _ in range(1000)]
    print(timeit(lambda: find_duplicates(user_info), number=10000))
```
Credit for improving code: Maarten Fabré
Comment by Maarten Fabré (Nov 21, 2017): I think your second `if info in seen:` should be `if info not in seen:`.

Reply by Neil (Nov 21, 2017): @MaartenFabré You're right. When I looked at it, I found that the check didn't make it faster, so I removed it altogether. Thanks.
A good day starts with a test
```python
import unittest
# import or paste your function here

a = ('a', 'a@a', '1')
b = ('b', 'b@b', '2')
c = ('c', 'c@c', '3')
ab = ('a', 'b@b', '12')

tests = [
    ([a, b], [[0], [1]]),
    ([a, b, c], [[0], [1], [2]]),
    ([a, b], [[0], [1]]),
    ([a, ab, b], [[0, 1, 2]]),
    ([a, ab, b, c], [[0, 1, 2], [3]]),
    ([a, ab, c, b], [[0, 1, 3], [2]]),
    ([c, a, ab, b], [[0], [1, 2, 3]]),
    ([a, b, ab], [[0, 1, 2]]),
]

class Test(unittest.TestCase):
    def test_some(self):
        for n, t in enumerate(tests):
            ud = t[0]
            ref = t[1]
            res = find_duplicates(ud)
            assert ref==res, "n:{}, ud:{}, ref:{}, res:{}".format(n, ud, ref, res)

if __name__ == "__main__":
    unittest.main()
```
gives:

```
======================================================================
FAIL: test_some (user_info.Test)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\Verena\workspace\user_info\src\user_info.py", line xx, in test_some
    assert ref==res, "n:{}, ud:{}, ref:{}, res:{}".format(n, ud, ref, res)
AssertionError: n:7, ud:[('a', 'a@a', '1'), ('b', 'b@b', '2'), ('a', 'b@b', '12')], ref:[[0, 1, 2]], res:[[0, 2], [1]]
----------------------------------------------------------------------
Ran 1 test in 0.001s
```
I was toying with this, thinking a `defaultdict` would also work here instead of a list of lists.

I had a solution I found rather elegant, taking the intersection of `seen.keys()` and `set(user)`, but that was prohibitively slow. This solution is about as fast as nfn neil's:
```python
import collections

def find_duplicates_defaultdict(user_info):
    results = collections.defaultdict(list)  # group index -> row indices
    seen = dict()  # field value -> group index
    for i, user in enumerate(user_info):
        for info in user:
            try:
                pos = seen[info]
                break
            except KeyError:
                pass
        else:
            # No field seen before: start a new group.
            pos = len(results)
        results[pos].append(i)
        for info in user:
            if info not in seen:
                seen[info] = pos
    # Expects an insertion-ordered dict (Python 3.7+); keys are inserted as 0, 1, 2, ...
    # Older versions of Python might need something like:
    #     return [results[key] for key in sorted(results.keys())]
    return list(results.values())
```
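Both the intersection idea mentioned above and the `return results.values()` line rely on dictionary views, which behave differently from lists. A short illustration with made-up values:

```python
import collections

seen = {'a': 0, 'a@a': 0, '1': 0}
user = ('a', 'b@b', '12')

# Key views support set operators, so the shared fields can be found directly.
common = seen.keys() & set(user)
print(common)  # {'a'}

results = collections.defaultdict(list)
results[0].append(0)
results[1].append(1)

# A values view is not a list and never compares equal to one...
print(results.values() == [[0], [1]])        # False
# ...so wrap it in list() when callers (e.g. the unittest above) expect a list.
print(list(results.values()) == [[0], [1]])  # True
```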
Order of iteration
These results depend on the order of iteration. For a list like this:

```python
info = [
    ("MLGuy42", "[email protected]", "123-4567"),
    ("CS229DungeonMaster", "123-4567", "[email protected]"),
    ("Doomguy", "[email protected]", "[email protected]"),
    ("andrew26", "[email protected]", "[email protected]"),
    ("andrew26", "[email protected]", "[email protected]")
]
```

it returns `[[0, 1], [2], [3, 4]]` instead of `[[0, 1, 3, 4], [2]]`, which would have been the result if line 4 had come before line 3.
To solve this, I found a different implementation that merges every group a record overlaps with. It iterates over the data twice (and over all groups for each record), so it is slower, but more complete:
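```python
import collections

def find_duplicates_set(user_info):
    if not user_info:  # user_info[0] below needs at least one row
        return []
    # group index -> set of all field values belonging to that group
    subsets = collections.defaultdict(set)
    subsets[0] = set(user_info[0])
    for user in user_info:
        # every existing group that shares at least one value with this user
        indices = {i for i, subset in subsets.items() if not subset.isdisjoint(user)}
        if not indices:
            subsets[max(subsets.keys()) + 1] = set(user)
        elif len(indices) == 1:
            subsets[indices.pop()].update(user)
        else:
            # This user links several groups: merge them into the lowest index.
            indices = sorted(indices)
            i0 = indices.pop(0)
            subsets[i0].update(user)
            for i in indices:
                subset = subsets.pop(i)
                subsets[i0].update(subset)
    # Second pass: assign each row to the (unique) group containing its values.
    results = collections.defaultdict(list)
    for i, user in enumerate(user_info):
        for index, subset in subsets.items():
            if not subset.isdisjoint(user):
                results[index].append(i)
                break
    return list(results.values())
```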
Execution speed
```python
from random import randrange
from timeit import timeit

# find_duplicates is the question's code; find_duplicates_nfn is nfn neil's for-else version.
MAXVALUE = 1000
for a in range(3):
    user_info = [[randrange(MAXVALUE) for i in range(3)] for _ in range(10 * MAXVALUE)]
    print('original: ', timeit(lambda: find_duplicates(user_info), number=1000))
    print('nfn_neil: ', timeit(lambda: find_duplicates_nfn(user_info), number=1000))
    print('defaultdict: ', timeit(lambda: find_duplicates_defaultdict(user_info), number=1000))
    print('set: ', timeit(lambda: find_duplicates_set(user_info), number=1000))
```
Results (three runs):

```
original:  6.129232463274093
nfn_neil:  4.994730664504459
defaultdict:  4.738290764427802
set:  18.66864765893115
original:  5.425038123095874
nfn_neil:  4.78540134785726
defaultdict:  4.616922919762146
set:  19.058994075487135
original:  5.55867017791752
nfn_neil:  4.920460685316357
defaultdict:  4.8429226022271905
set:  19.008669542017742
```