Two partitions of a set have a greatest lower bound (meet) and a least upper bound (join).
See here: Meets and joins in the lattice of partitions (Math SE)
The meet is easy to calculate. I am trying to improve the calculation of the join.
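For reference, here are the two operations, using π and σ for the two partitions. These are the standard lattice definitions, not taken from the linked answer:

```latex
\pi \wedge \sigma = \{\, A \cap B \;:\; A \in \pi,\ B \in \sigma,\ A \cap B \neq \emptyset \,\}

x \sim_{\pi \vee \sigma} y \iff \exists\, x = z_0, z_1, \dots, z_k = y \text{ such that each pair } z_{i-1}, z_i \text{ lies in a common block of } \pi \text{ or of } \sigma
```

In other words, the blocks of the join are the connected components of the graph whose edges link elements that share a block in either partition. Connectivity only needs a spanning set of edges per block, such as consecutive elements, rather than all 2-element subsets.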
As part of a library I have written a class for set partitions. Its `join` method can be found here. It uses the method `pairs`, which can be found here.
Creating all the 2-subsets of a block is not a great idea for big blocks. I intended this approach only as a prototype. But to my surprise it was still faster than my Python implementation of the answer on Math SE.
The `SetPart` is essentially the list of blocks wrapped in a class.
This is the property `pairs`:
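```python
from functools import cached_property
from itertools import combinations

@cached_property
def pairs(self):
    # All 2-element subsets of all blocks, collected in one set.
    result = set()
    for block in self.blocks:
        for pair in combinations(block, 2):
            result.add(pair)
    return result
```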
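A block of size n has n(n-1)/2 two-element subsets, so the number of pairs grows quadratically with the block size, not factorially. For connectivity alone, n - 1 pairs per block (for example, consecutive elements) would be enough. The small sketch below, with the made-up helper names `pair_count` and `chain_count`, compares both counts for blocks with the sizes of the scaled partition `a` (40, 30 and 20):

```python
from math import comb

def pair_count(blocks):
    """Number of 2-element subsets that the `pairs` property would generate."""
    return sum(comb(len(block), 2) for block in blocks)

def chain_count(blocks):
    """Number of pairs needed if each block is linked as a chain instead."""
    return sum(len(block) - 1 for block in blocks)

# Stand-ins with the same block sizes as the scaled partition `a`.
blocks_a = [list(range(n)) for n in (40, 30, 20)]
print(pair_count(blocks_a))   # 1405 (= 780 + 435 + 190)
print(chain_count(blocks_a))  # 87 (= 39 + 29 + 19)
```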
The method `join_pairs` is the essential part of the algorithm. It uses the method `merge_pair`, which merges the blocks of two elements into one block (if they are not already in the same block).
```python
def join_pairs(self, other):
    from discretehelpers.set_part import SetPart
    # Start with the trivial partition and merge every pair from both partitions.
    result = SetPart([], self.domain)
    for pair in self.pairs.union(other.pairs):
        result.merge_pair(*pair)
    return result
```
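In the full class, `merge_pair` ends with `self.__init__(...)`, so every single merge re-sorts all blocks and rebuilds the element-to-block dictionary. A common alternative for repeated merging is a disjoint-set (union-find) forest. The sketch below is a standalone function on plain lists of blocks and is not part of `discretehelpers`; the name `join_union_find` is made up for illustration. It also links only consecutive elements of each block instead of all 2-element subsets, which is still enough to connect every block.

```python
def join_union_find(blocks_a, blocks_b):
    """Join of two partitions, given as lists of blocks, via a disjoint-set forest."""
    parent = {}

    def find(x):
        # Register x if it is new, then walk to the root with path halving.
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(x, y):
        root_x, root_y = find(x), find(y)
        if root_x != root_y:
            parent[root_y] = root_x

    # Consecutive elements suffice to connect a block: n - 1 unions per block.
    for block in list(blocks_a) + list(blocks_b):
        block = list(block)
        for x, y in zip(block, block[1:]):
            union(x, y)

    # Group all registered elements by their root.
    groups = {}
    for x in parent:
        groups.setdefault(find(x), []).append(x)
    return sorted(sorted(group) for group in groups.values() if len(group) > 1)

a = [[0, 1, 2, 4], [5, 6, 9], [7, 8]]
b = [[0, 1], [2, 3, 4], [6, 8, 9]]
print(join_union_find(a, b))  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
```

Without union by rank, path halving alone gives a logarithmic amortized bound per operation instead of the near-constant one, which is still much cheaper than re-initializing a `SetPart` after every merge.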
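The method `join` uses `join_pairs`. Two elements in the same block of the meet are in the same block of both partitions, so they always end up together in the join, and one of them can stand in for the others. `join` therefore keeps the smallest element of each meet block as its representative and removes the other elements (called trash) from the blocks of both partitions. After `join_pairs` has been applied to these smaller partitions, the trash is added back to the blocks of its representatives. This way far fewer 2-element subsets are used in the calculation.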
```python
def join(self, other):
    from itertools import chain
    from discretehelpers.set_part import SetPart

    # Each block of the meet is represented by its minimum; the rest is "trash".
    meet_part = self.meet(other)
    trash = set()
    rep_to_trash = dict()
    for block in meet_part.blocks:
        block_rep = min(block)
        block_trash = set(block) - {block_rep}
        trash |= block_trash
        rep_to_trash[block_rep] = block_trash

    # Join the smaller partitions that no longer contain the trash.
    clean_s_blocks = [sorted(set(block) - trash) for block in self.blocks]
    clean_o_blocks = [sorted(set(block) - trash) for block in other.blocks]
    clean_s_part = SetPart(clean_s_blocks, self.domain)
    clean_o_part = SetPart(clean_o_blocks, self.domain)
    clean_join_part = clean_s_part.join_pairs(clean_o_part)

    # Put the trash back into the blocks of its representatives.
    s_elements = set(chain.from_iterable(self.blocks))
    o_elements = set(chain.from_iterable(other.blocks))
    dirty_domain = s_elements | o_elements
    clean_domain = dirty_domain - trash
    dirty_join_blocks = []
    clean_blocks_with_singletons = clean_join_part.blocks_with_singletons(elements=clean_domain)
    for clean_block in clean_blocks_with_singletons:
        dirty_block = set(clean_block)
        for element in clean_block:
            if element in rep_to_trash:
                dirty_block |= rep_to_trash[element]
        dirty_join_blocks.append(sorted(dirty_block))
    return SetPart(dirty_join_blocks, domain=self.domain)
```
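The bookkeeping in the first part of `join` can be tried out in isolation. The following standalone sketch mirrors that code on plain lists; the helper names `split_meet_blocks` and `clean` are hypothetical. It uses the meet of the small example partitions given further down, which works out by hand to [[0, 1], [2, 4], [6, 9]].

```python
def split_meet_blocks(meet_blocks):
    """Keep the minimum of each meet block as its representative.

    Returns the set of removed ("trash") elements and a dict from each
    representative to the elements it stands for.
    """
    trash = set()
    rep_to_trash = {}
    for block in meet_blocks:
        rep = min(block)
        block_trash = set(block) - {rep}
        trash |= block_trash
        rep_to_trash[rep] = block_trash
    return trash, rep_to_trash

def clean(blocks, trash):
    """Remove trash elements; blocks that shrink to one element are dropped."""
    cleaned = (sorted(set(block) - trash) for block in blocks)
    return [block for block in cleaned if len(block) > 1]

meet_blocks = [[0, 1], [2, 4], [6, 9]]  # meet of the small example partitions
trash, rep_to_trash = split_meet_blocks(meet_blocks)
print(sorted(trash))  # [1, 4, 9]
print(clean([[0, 1, 2, 4], [5, 6, 9], [7, 8]], trash))  # [[0, 2], [5, 6], [7, 8]]
print(clean([[0, 1], [2, 3, 4], [6, 8, 9]], trash))     # [[2, 3], [6, 8]]
```

Joining the two cleaned partitions gives {0, 2, 3} and {5, 6, 7, 8}. Adding the trash back to the representatives 0, 2 and 6 yields [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]].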
My standard example consists of these two partitions:
```python
a = SetPart([[0, 1, 2, 4], [5, 6, 9], [7, 8]])
b = SetPart([[0, 1], [2, 3, 4], [6, 8, 9]])
```
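When experimenting with faster join implementations, it can help to check them against a deliberately naive reference on this example. The following sketch works on plain lists rather than `SetPart` objects, and the name `join_reference` is made up. It keeps merging overlapping blocks until no two overlap. It is slow, since it rescans all pairs of blocks after every merge, so it is only meant for testing.

```python
def join_reference(blocks_a, blocks_b):
    """Slow but simple join, for cross-checking other implementations."""
    blocks = [set(block) for block in list(blocks_a) + list(blocks_b)]
    merged = True
    while merged:
        merged = False
        # Find the first pair of overlapping blocks and merge them.
        for i in range(len(blocks)):
            for j in range(i + 1, len(blocks)):
                if blocks[i] & blocks[j]:
                    blocks[i] |= blocks.pop(j)
                    merged = True
                    break
            if merged:
                break
    return sorted(sorted(block) for block in blocks if len(block) > 1)

a = [[0, 1, 2, 4], [5, 6, 9], [7, 8]]
b = [[0, 1], [2, 3, 4], [6, 8, 9]]
print(join_reference(a, b))  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
```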
To test performance, I have increased their size by a factor of 10, e.g. by replacing 5 with 50, ..., 59:
```python
import timeit
from discretehelpers.set_part import SetPart

a = SetPart([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89]])
b = SetPart([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]])

print(timeit.timeit(lambda: a.meet(b), number=1000))  # 0.03588864533230662
print(timeit.timeit(lambda: a.meet_pairs(b), number=1000))  # 1.766225082334131
print(timeit.timeit(lambda: a.join(b), number=1000))  # 0.7479327972978354
print(timeit.timeit(lambda: a.join_pairs(b), number=1000))  # 4.01305090310052
```
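The scaled partitions can also be generated instead of written out. This sketch uses a hypothetical helper `scale` that replaces each element x by the ten elements 10x, …, 10x + 9. The resulting lists should equal the literal blocks above, so wrapping them in `SetPart` could replace the long literals.

```python
def scale(blocks, factor=10):
    """Replace each element x by the run factor*x, ..., factor*x + factor - 1."""
    return [[factor * x + d for x in block for d in range(factor)] for block in blocks]

small_a = [[0, 1, 2, 4], [5, 6, 9], [7, 8]]
small_b = [[0, 1], [2, 3, 4], [6, 8, 9]]
big_a = scale(small_a)
big_b = scale(small_b)
print([len(block) for block in big_a])  # [40, 30, 20]
print([len(block) for block in big_b])  # [20, 30, 30]
```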
My question is mainly about improving the algorithm, but fancy Python tricks are also welcome.
I have edited the question. By removing redundant elements before the calculation, I have avoided choking the process with the quadratic growth of the number of pairs. But this is probably still far from the optimal solution.
A simplified version of the `SetPart` class can be found here.
It can be run in a single Python file and has no dependencies.
As requested, here is the whole class, including the methods shown above:
```python
from functools import cached_property
from itertools import combinations, product, chain


def have(val):
    return val is not None


########################################################################################################################


class SetPart(object):

    def __init__(self, blocks=None, domain='N'):
        """
        :param blocks: List of lists. Each block is a list with at least two elements. The blocks do not intersect.
        :param domain: Set of allowed elements of blocks. Can be a finite set. Elements are usually integers.
        By default, the domain is the set of non-negative integers.
        """
        if domain not in ['N', 'Z']:
            assert type(domain) in [set, list, tuple, range]
            self.domain = set(domain)
        else:
            self.domain = domain  # keep the letters

        if blocks is None:
            self.set_trivial()
            return

        blocks = sorted(sorted(block) for block in blocks if len(block) > 1)

        if not blocks:
            self.set_trivial()
            return

        self.blocks = blocks

        _ = dict()
        for block_index, block in enumerate(self.blocks):
            for element in block:
                _[element] = block_index
        self.non_singleton_to_block_index = _

        self.non_singletons = set(self.non_singleton_to_block_index.keys())
        assert len(self.non_singletons) == sum([len(block) for block in self.blocks])

        if self.domain == 'N':
            self.length = max(self.non_singletons) + 1

        self.trivial = False

    ##############################################################################################

    def set_trivial(self):
        self.trivial = True
        self.blocks = []
        self.non_singleton_to_block_index = dict()
        self.non_singletons = set()
        if self.domain == 'N':
            self.length = 0

    ###############################################

    def __eq__(self, other):
        return self.blocks == other.blocks

    ###############################################

    def element_in_domain(self, element):
        if self.domain == 'N':
            return type(element) == int and element >= 0
        elif self.domain == 'Z':
            return type(element) == int
        else:
            return element in self.domain

    ###############################################

    def merge_pair(self, a, b):
        """
        When elements `a` and `b` are in different blocks, both blocks will be merged.
        Changes the partition. Returns nothing.
        """
        if a == b:
            return  # nothing to do
        assert self.element_in_domain(a) and self.element_in_domain(b)
        a_found = a in self.non_singletons
        b_found = b in self.non_singletons
        if a_found and b_found:
            block_index_a = self.non_singleton_to_block_index[a]
            block_index_b = self.non_singleton_to_block_index[b]
            if block_index_a == block_index_b:
                return  # nothing to do
            block_a = self.blocks[block_index_a]
            block_b = self.blocks[block_index_b]
            merged_block = sorted(block_a + block_b)
            self.blocks.remove(block_a)
            self.blocks.remove(block_b)
            self.blocks.append(merged_block)
        elif not a_found and not b_found:
            self.blocks.append([a, b])
        else:  # exactly one of `a` and `b` is in a block
            if b_found and not a_found:
                a, b = b, a
            block_index_a = self.non_singleton_to_block_index[a]
            self.blocks[block_index_a].append(b)
        self.__init__(self.blocks, self.domain)  # reinitialize

    ###############################################

    def blocks_with_singletons(self, elements=None):
        """
        :param elements: Any subset of the domain.
        :return: Blocks with added singleton-blocks for each element in `elements` that is not in an actual block.
        """
        assert type(elements) in [set, list, range]
        assert self.non_singletons.issubset(set(elements))
        singletons = set(elements).difference(self.non_singletons)
        singleton_blocks = [[_] for _ in singletons]
        return sorted(self.blocks + singleton_blocks)

    ##############################################################################################

    @cached_property
    def pairs(self):
        """
        :return: For each block the set of all its 2-element subsets. All those in one set.
        Slow for big blocks, because the number of pairs grows quadratically with the block size.
        """
        result = set()
        for block in self.blocks:
            for pair in combinations(block, 2):
                result.add(pair)
        return result

    ###############################################

    def join_pairs(self, other):
        """
        :param other: another set partition
        :return: The join of the two set partitions.
        This method uses the property `pairs`, so it is also slow for big blocks.
        """
        assert self.domain == other.domain
        result = SetPart([], self.domain)
        for pair in self.pairs.union(other.pairs):
            result.merge_pair(*pair)
        return result

    ###############################################

    def meet(self, other):
        """
        :param other: another set partition
        :return: The meet of the two set partitions.
        Let M be the meet of partitions A and B. The blocks of M are the intersections of the blocks of A and B.
        """
        meet_blocks = []
        for s_block, o_block in product(self.blocks, other.blocks):
            intersection = set(s_block) & set(o_block)
            if intersection:
                meet_blocks.append(sorted(intersection))
        return SetPart(meet_blocks, self.domain)

    ###############################################

    def join(self, other):
        """
        :param other: another set partition
        :return: The join of the two set partitions.
        This method uses the method `join_pairs`.
        The cost of the quadratic number of pairs is reduced by making the input partitions smaller.
        """
        meet_part = self.meet(other)
        trash = set()
        rep_to_trash = dict()
        for block in meet_part.blocks:
            block_rep = min(block)
            block_trash = set(block) - {block_rep}
            trash |= block_trash
            rep_to_trash[block_rep] = block_trash
        clean_s_blocks = [sorted(set(block) - trash) for block in self.blocks]
        clean_o_blocks = [sorted(set(block) - trash) for block in other.blocks]
        clean_s_part = SetPart(clean_s_blocks, self.domain)
        clean_o_part = SetPart(clean_o_blocks, self.domain)
        clean_join_part = clean_s_part.join_pairs(clean_o_part)
        s_elements = set(chain.from_iterable(self.blocks))
        o_elements = set(chain.from_iterable(other.blocks))
        dirty_domain = s_elements | o_elements
        clean_domain = dirty_domain - trash
        dirty_join_blocks = []
        clean_blocks_with_singletons = clean_join_part.blocks_with_singletons(elements=clean_domain)
        for clean_block in clean_blocks_with_singletons:
            dirty_block = set(clean_block)
            for element in clean_block:
                if element in rep_to_trash:
                    dirty_block |= rep_to_trash[element]
            dirty_join_blocks.append(sorted(dirty_block))
        return SetPart(dirty_join_blocks, domain=self.domain)
```
This is an attempt to build the same functionality. I didn't read all of your code, but tried to implement the algorithm you linked to on Math StackExchange.
I suspect you can do it even faster for large sets if you use pandas.
```python
class FastSetPartition:
    '''Set partition with similar interface to SetPart'''

    def __init__(self, blocks):
        """
        :param blocks: List of lists. Each block is a list with at least two elements. The blocks do not intersect.
        A block with a single element is named a singleton. Singletons are not stored.
        """
        self.blocks = frozenset([frozenset(block) for block in blocks])

    def join(self, other):
        """
        :param self: One set partition
        :param other: Another set partition
        :return: The join of the two set partitions.
        """
        # Find the domain relevant for the join.
        dm_self = frozenset([x for block in self.blocks for x in block])
        dm_other = frozenset([x for block in other.blocks for x in block])
        dm = dm_self | dm_other

        # For each element in the domain, register which block it is part of. Assign an index to each block
        self_block_list = [block for block in self.blocks]
        x2self_no = {x: block_no for block_no, block in enumerate(self_block_list) for x in block}
        other_block_list = [block for block in other.blocks]
        x2other_no = {x: block_no for block_no, block in enumerate(other_block_list) for x in block}

        join_block_list = []
        x2join_no = {}

        # We build the join-blocks one block at a time
        # The one we are currently working on is called current_block
        # The latest addition is called new_elements
        current_block = set()
        new_elements = set()

        def new_block(x):
            '''Begin building a new join-block'''
            nonlocal current_block
            nonlocal new_elements
            # It is important that current_block starts out empty.
            # This way we add blocks from both self and other.
            current_block = set([])
            new_elements = set([x])

        def join_blocks(block_list, x2block_no):
            '''Expand the current block with the given partition'''
            nonlocal current_block
            nonlocal new_elements
            # Expand only new elements via the block_list
            added_elements = set()
            added_blocks = set()
            for x in new_elements:
                if x in x2block_no:
                    added_blocks.add(x2block_no[x])
                else:  # x is a singleton in the current partition
                    added_elements.add(x)
            for block_no in added_blocks:
                added_elements |= block_list[block_no]
            # Grow current block
            new_elements = added_elements - current_block
            current_block |= added_elements
            return len(new_elements)

        def close_block():
            '''Add the current block to the join result'''
            join_block_no = len(join_block_list)
            for x in current_block:
                x2join_no[x] = join_block_no
            join_block_list.append(current_block)

        def join_self_blocks():
            return join_blocks(self_block_list, x2self_no)

        def join_other_blocks():
            return join_blocks(other_block_list, x2other_no)

        for x in dm:
            if x in x2join_no:
                continue
            new_block(x)
            while join_self_blocks() + join_other_blocks() > 0:
                pass
            close_block()

        # Wrap the collected join-blocks as a new partition.
        return FastSetPartition(join_block_list)
```
Thanks for the algorithm. It works better than mine. My rewrite is shown as a separate answer, and I have added it to the library. Developers rarely say nice things about `nonlocal`. (See e.g. this answer to my other question on this page.) So someone may still find ways to improve the code. (But probably not me.) (Watchduck, commented Jun 15, 2024 at 13:07)
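One way to avoid `nonlocal` is to keep the state of the growing block in local variables of a single function with an explicit worklist. The sketch below is a standalone function on plain lists of blocks. The name `join_worklist` is made up, and it is not the library's `join_fast`. It visits elements depth-first instead of strictly alternating between the two partitions, but it produces the same blocks. Each input block is expanded at most once.

```python
def join_worklist(blocks_a, blocks_b):
    """Join of two partitions given as lists of blocks, grown one result block at a time."""
    index_a = {x: i for i, block in enumerate(blocks_a) for x in block}
    index_b = {x: i for i, block in enumerate(blocks_b) for x in block}
    used_a, used_b = set(), set()  # indices of input blocks already expanded
    seen = set()                   # elements already placed in a result block
    result = []
    for start in list(index_a) + list(index_b):
        if start in seen:
            continue
        current = {start}
        worklist = [start]
        while worklist:
            x = worklist.pop()
            grown = []
            # Expand x via its block in each partition, unless that block was used.
            i = index_a.get(x)
            if i is not None and i not in used_a:
                used_a.add(i)
                grown.extend(blocks_a[i])
            j = index_b.get(x)
            if j is not None and j not in used_b:
                used_b.add(j)
                grown.extend(blocks_b[j])
            for y in grown:
                if y not in current:
                    current.add(y)
                    worklist.append(y)
        seen |= current
        if len(current) > 1:
            result.append(sorted(current))
    return sorted(result)

a = [[0, 1, 2, 4], [5, 6, 9], [7, 8]]
b = [[0, 1], [2, 3, 4], [6, 8, 9]]
print(join_worklist(a, b))  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
```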
This is my rewritten version of the solution by Peer Sommerlund:
```python
def join_fast(self, other):
    from discretehelpers.set_part import SetPart

    domain = self.non_singletons | other.non_singletons
    result_blocks = []
    result_dict = {}

    # We build the join-blocks one block at a time.
    # The one we are currently working on is called `current_block`.
    # The latest addition is called `new_elements`.
    current_block = set()
    new_elements = set()

    def new_block(e):
        """begin building a new join-block"""
        nonlocal current_block
        nonlocal new_elements
        # It is important that current_block starts out empty.
        # This way we add blocks from both self and other.
        current_block = set([])
        new_elements = set([e])

    def join_blocks(part):
        """expand the current block with the given partition"""
        nonlocal current_block
        nonlocal new_elements
        part_dict = part.non_singleton_to_block_index
        # expand only new elements via the block_list
        added_elements = set()
        added_blocks = set()
        for e in new_elements:
            if e in part_dict:
                added_blocks.add(part_dict[e])
            else:  # `e` is a singleton in the current partition
                added_elements.add(e)
        for block_index in added_blocks:
            added_elements |= set(part.blocks[block_index])
        # grow current block
        new_elements = added_elements - current_block
        current_block |= added_elements
        return len(new_elements)

    def close_block():
        """add the current block to the join result"""
        join_block_index = len(result_blocks)
        for e in current_block:
            result_dict[e] = join_block_index
        result_blocks.append(current_block)

    for element in domain:
        if element in result_dict:
            continue
        new_block(element)
        while join_blocks(self) + join_blocks(other) > 0:
            pass
        close_block()

    return SetPart(result_blocks, self.domain)
```
For the big example partitions, it is about 12 times faster than my improved method `join` (here named `join_slow`):
```python
print(timeit.timeit(lambda: a.join_slow(b), number=1000))  # 0.8285696366801858
print(timeit.timeit(lambda: a.join_fast(b), number=1000))  # 0.06911674793809652
```
The `if len(pairs) > 0:` check is redundant and can be safely removed. Also, have you tried a different interpreter than the standard CPython interpreter, like PyPy, to see if there is any performance boost from that?

In `SetPart([[0, 1], [2, 3, 4], [6, 8, 9]])`, I do not see 5. Replacing a single number by appending 0..9 should increase size by 9. I may have found Discrete helpers on en.wikiversity; it is miles from the source code, which does not follow PEPs 8 and 257 regarding documentation.

`[5, 6, 9]` is a block of partition `a`. It does contain the element 5. I meant increase by a factor of 10.

`join` uses `meet` and `join_pairs`, and that uses `pairs` and `merge_pair`. But the rest of the class can be ignored. The question is only about the `join` method.