Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 39f2da0

Browse files
committed
test backend
1 parent c2b9506 commit 39f2da0

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

‎dgmc/models/dgmc.py‎

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,14 @@ def reset_parameters(self):
8080

8181
def __top_k__(self, x_s, x_t):
8282
r"""Memory-efficient top-k correspondence computation."""
83-
x_s, x_t = LazyTensor(x_s.unsqueeze(-2)), LazyTensor(x_t.unsqueeze(-3))
84-
S_ij = (-x_s * x_t).sum(dim=-1)
85-
return S_ij.argKmin(self.k, dim=2, backend=self.backend)
83+
x_s, x_t = x_s.unsqueeze(-2), x_t.unsqueeze(-3)
84+
if self.backend != 'test': # pragma: no cover
85+
x_s, x_t = LazyTensor(x_s), LazyTensor(x_t)
86+
S_ij = (-x_s * x_t).sum(dim=-1)
87+
return S_ij.argKmin(self.k, dim=2, backend=self.backend)
88+
else:
89+
S_ij = (x_s * x_t).sum(dim=-1)
90+
return S_ij.topk(self.k, dim=2)[1]
8691

8792
def __include_gt__(self, S_idx, s_mask, y):
8893
r"""Includes the ground-truth values in :obj:`y` to the index tensor

‎test/models/test_dgmc.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_dgmc_repr():
2929
def test_dgmc_on_single_graphs():
3030
set_seed()
3131
model = DGMC(psi_1, psi_2, num_steps=1)
32-
model.backend = 'CPU'
32+
model.backend = 'test'
3333
x, e = data.x, data.edge_index
3434
y = torch.arange(data.num_nodes)
3535
y = torch.stack([y, y], dim=0)
@@ -68,7 +68,7 @@ def test_dgmc_on_single_graphs():
6868
def test_dgmc_on_multiple_graphs():
6969
set_seed()
7070
model = DGMC(psi_1, psi_2, num_steps=1)
71-
model.backend = 'CPU'
71+
model.backend = 'test'
7272

7373
batch = Batch.from_data_list([data, data])
7474
x, e, b = batch.x, batch.edge_index, batch.batch
@@ -88,7 +88,7 @@ def test_dgmc_on_multiple_graphs():
8888

8989
def test_dgmc_include_gt():
9090
model = DGMC(psi_1, psi_2, num_steps=1)
91-
model.backend = 'CPU'
91+
model.backend = 'test'
9292

9393
S_idx = torch.tensor([[[0, 1], [1, 2]], [[1, 2], [0, 1]]])
9494
s_mask = torch.tensor([[True, False], [True, True]])

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /