
Suppose we have a binary matrix A with shape n x m. I want to identify rows that have duplicates in the matrix, i.e. rows for which there is another index along the same dimension with the same elements in the same positions.

It's very important not to convert this matrix into a dense representation, since the real matrices I'm working with are quite large and hard to fit in memory.

Using PyTorch for the implementation:

# This is just a toy sparse binary matrix with n = 10 and m = 100
A = torch.randint(0, 2, (10, 100), dtype=torch.float32).to_sparse()

Intuitively, we can compute the dot product of this matrix with itself, producing a new m x m matrix B whose entry (i, j) counts the positions along dimension 0 where both index i and index j contain a 1.

B = A.T @ A # In PyTorch, this operation will also produce a sparse representation
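
As a quick illustration of this counting property, consider a toy example (A_toy and B_toy are names introduced here just for the check):

import torch

# Toy check: columns 0 and 1 are identical, column 2 differs
A_toy = torch.tensor([[1., 1., 0.],
                      [0., 0., 1.],
                      [1., 1., 1.]]).to_sparse()
B_toy = (A_toy.T @ A_toy).to_dense()
# B_toy[0, 1] == 2, matching the number of 1s in both column 0 and column 1,
# while B_toy[0, 2] == 1 does not match their counts, so only (0, 1) is a
# duplicate pair
print(B_toy)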

At this point, I've tried to combine these values by comparing them with A.sum(0):

num_elements = A.sum(0)
duplicate_rows = torch.logical_and(
    num_elements[B.indices()[0]] == num_elements[B.indices()[1]],
    num_elements[B.indices()[0]] == B.values()
)

But this did not work! (For one thing, A.sum(0) on a sparse tensor is itself sparse and cannot be indexed this way, and the diagonal entries of B always satisfy both conditions, so the mask does not isolate true duplicates.)

I think the solution can be written using only operations on PyTorch sparse tensors (without Python loops and so on), which could also be a benefit in terms of performance.

asked Mar 4 at 21:48
  • Are you looking for duplicate rows or columns? B = A.T @ A would compare columns, while A @ A.T would compare rows. Are you just trying to indicate duplicate rows in general, or do you want to assign duplicate rows to groups? Have you looked into writing loops at the Python level and using torchscript to compile? Commented Mar 4 at 22:14
  • It doesn't matter: if I can do it on rows/columns, then I can just transpose to do it on the other dimension. About torchscript, thank you for the advice, I will take a look. Commented Mar 4 at 22:18
  • A couple of thoughts from a scipy.sparse perspective: sum can be performed via @ with an appropriate array, and the result is a dense array. unique on that array sorts it and checks adjacent entries for duplicates; that kind of sort could reduce the search space. With CSR format you could also deduce each row's length from the indptr steps, and since the matrix is binary you are only interested in duplicate column indices (see the sketch after this comment). Commented Mar 5 at 2:01
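
The scipy.sparse idea from the last comment, sketched under the assumption of a random binary CSR matrix (names are illustrative only):

import numpy as np
from scipy import sparse

A_csr = sparse.random(10, 100, density=0.3, format="csr")
A_csr.data[:] = 1.0  # force a binary matrix

# Row lengths fall directly out of indptr, with no dense conversion
row_lengths = np.diff(A_csr.indptr)

# Rows with different numbers of 1s can never be duplicates, so grouping
# by length first shrinks the pairwise search space
lengths, groups = np.unique(row_lengths, return_inverse=True)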

2 Answers


I've found a solution that only takes advantage of the torch sparse representation and is very efficient in terms of both computation and memory consumption:

# A is the sparse matrix
B = (A.T @ A).coalesce()  # or A @ A.T, depending on the dimension we are working on
num_elements = A.sum(0).to_dense()
duplicates = torch.logical_and(
    B.indices()[0] < B.indices()[1],  # Consider only elements above the diagonal
    torch.logical_and(
        B.values() == num_elements[B.indices()[0]],
        B.values() == num_elements[B.indices()[1]],
    )
)
duplicate_indices = B.indices()[1, duplicates].unique()

At this point we can use duplicate_indices to filter out the duplicate columns. Since only upper-triangular entries were considered, the lowest-index member of each duplicate group never appears in duplicate_indices, so one representative per group survives:

unique_indices = A.indices()[
    :, ~torch.isin(A.indices()[1], duplicate_indices)
]

unique_indices now holds the COO indices of the filtered matrix A; since the matrix is binary, the values are all 1s.

Additionally, we can normalize the result so that the remaining column indices form a contiguous range:

_, unique_indices[1] = torch.unique(unique_indices[1], return_inverse=True)
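
For completeness, a minimal sketch of rebuilding a sparse tensor from the filtered, renumbered indices (A_filtered is a hypothetical name; the values are all 1 because the matrix is binary):

# Rebuild a sparse COO tensor from the surviving indices
num_cols = int(unique_indices[1].max()) + 1  # columns were renumbered above
A_filtered = torch.sparse_coo_tensor(
    unique_indices,
    torch.ones(unique_indices.size(1)),
    (A.size(0), num_cols),
).coalesce()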
answered Mar 5 at 12:47


Here is an implementation that identifies the duplicate rows in a binary sparse matrix. It returns a mask of the rows to keep, but can easily be adjusted to give e.g. the indices of duplicate rows. It also handles cases where 3 or more rows are duplicates of each other, keeping only 1 row per group (the lowest-index row is always kept, for simplicity).

def get_unique_row_mask_sparse(A):
    # Number of matching 1s between each pair of rows
    B = (A @ A.T).coalesce()

    # Number of 1s in each row
    row_sums = torch.sparse.sum(A, dim=1).to_dense()

    indices = B.indices()
    i, j = indices[0], indices[1]

    # Two rows i and j are duplicates if:
    # 1) B[i, j] == row_sums[i] == row_sums[j]
    # 2) i != j (exclude the diagonal)
    # Moreover, we only keep the upper triangle of the matrix to avoid
    # counting each pair twice
    same_row_sums = row_sums[i] == row_sums[j]
    matches_equal_sums = B.values() == row_sums[i]
    not_diagonal = i != j
    upper_triangular = i < j
    is_duplicate_pair = same_row_sums & matches_equal_sums & not_diagonal & upper_triangular
    duplicate_pairs = indices[:, is_duplicate_pair]
    # For each duplicate pair (i, j), we keep row i and drop row j
    keep_mask = torch.ones(A.size(0), dtype=torch.bool)
    for pair_idx in range(duplicate_pairs.size(1)):
        row_i, row_j = duplicate_pairs[:, pair_idx]
        keep_mask[row_j] = False
    return keep_mask
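
Since the question asks to avoid Python-level loops, the final marking loop can also be replaced by a single vectorized write (a sketch, equivalent under the same upper-triangular convention):

# Every row that appears as j in some duplicate pair (i < j) is dropped;
# the lowest-index row of each group never appears as j, so it is kept
keep_mask = torch.ones(A.size(0), dtype=torch.bool)
keep_mask[duplicate_pairs[1]] = False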

Testing code:

torch.manual_seed(42)
A = torch.randint(0, 2, (10, 100), dtype=torch.float32).to_sparse()
# Force some duplicate rows for testing
A_dense = A.to_dense()
A_dense[3] = A_dense[1]
A_dense[6] = A_dense[1]
A_dense[9] = A_dense[2]
A = A_dense.to_sparse()
keep_mask = get_unique_row_mask_sparse(A)
print(keep_mask)

Gives the result:

tensor([ True, True, True, False, True, True, False, True, True, False])

You can run the following to create a new sparse tensor from this. Note that the surviving rows have to be renumbered to a contiguous range so they fit the smaller shape:

A_indices = A.indices()
rows_mask = keep_mask[A_indices[0]]
# Renumber the surviving rows to a contiguous 0..k-1 range
new_row = torch.cumsum(keep_mask.long(), dim=0) - 1
kept_indices = A_indices[:, rows_mask]
kept_indices[0] = new_row[kept_indices[0]]
A_unique = torch.sparse_coo_tensor(
    kept_indices,
    A.values()[rows_mask],
    (int(keep_mask.sum()), A.size(1))
).coalesce()
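
As a quick follow-up check (a sketch that relies on the seed and the forced duplicates above):

# Three of the ten rows were dropped, and the kept rows match A exactly
assert A_unique.shape == (7, 100)
assert torch.equal(A_unique.to_dense(), A.to_dense()[keep_mask])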
answered Mar 5 at 3:20

1 Comment

Thanks! Your solution works well with my data. I've also found another solution that only uses PyTorch operations. Nonetheless, I will accept your answer, which is quite good!
