Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 89ddb71

Browse files
author
Tree
committed
Update
1 parent 6e918f1 commit 89ddb71

File tree

5 files changed

+86
-130
lines changed

5 files changed

+86
-130
lines changed

‎README.md‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ optional arguments:
3232
--batch-size BATCH_SIZE
3333
Batch size.(default: 256)
3434
--lr LR Learning rate.(default: 1e-5)
35-
--max-iter MAX_ITER Number of iterations.(default: 100)
35+
--max-iter MAX_ITER Number of iterations.(default: 300)
3636
--num-workers NUM_WORKERS
3737
Number of loading data threads.(default: 6)
3838
--topk TOPK Calculate map of top k.(default: all)
@@ -54,7 +54,7 @@ imagenet100: Top 100 classes, 5000 query images, 10000 training images, MAP@1000
5454

5555
bits | 16 | 32 | 48 | 128
5656
:-: | :-: | :-: | :-: | :-:
57-
cifar10@ALL |
58-
nus-wide-tc21@5000 |
59-
imagenet100@1000 |
57+
cifar10@ALL | 0.7290 | 0.7528 | 0.7512 | 0.7579
58+
nus-wide-tc21@5000 | 0.7981 | 0.8200 | 0.8300 | 0.8424
59+
imagenet100@1000 | 0.3651 | 0.4629 | 0.5094 | 0.5787
6060

‎hashnet.py‎

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from models.model_loader import load_model
77
from torch.optim.lr_scheduler import CosineAnnealingLR
8-
from utils.evaluate import mean_average_precision
8+
from utils.evaluate import mean_average_precision, pr_curve
99
from loguru import logger
1010

1111

@@ -99,6 +99,15 @@ def train(
9999
device,
100100
topk,
101101
)
102+
103+
# Compute pr curve
104+
P, R = pr_curve(
105+
query_code.to(device),
106+
retrieval_code.to(device),
107+
query_targets.to(device),
108+
retrieval_targets.to(device),
109+
device,
110+
)
102111

103112
# Log
104113
logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
@@ -120,6 +129,8 @@ def train(
120129
'rB': retrieval_code.cpu(),
121130
'qL': query_targets.cpu(),
122131
'rL': retrieval_targets.cpu(),
132+
'P': P,
133+
'R': R,
123134
'map': best_map,
124135
}
125136

‎run.py‎

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -39,35 +39,32 @@ def run():
3939
)
4040

4141
# Training
42-
#for code_length in [16, 32, 48, 128]:
43-
for code_length in [16]:
44-
args.code_length = code_length
45-
checkpoint = hashnet.train(
46-
train_dataloader,
47-
query_dataloader,
48-
retrieval_dataloader,
49-
args.arch,
50-
args.code_length,
51-
args.device,
52-
args.lr,
53-
args.max_iter,
54-
args.alpha,
55-
args.topk,
56-
args.evaluate_interval,
42+
checkpoint = hashnet.train(
43+
train_dataloader,
44+
query_dataloader,
45+
retrieval_dataloader,
46+
args.arch,
47+
args.code_length,
48+
args.device,
49+
args.lr,
50+
args.max_iter,
51+
args.alpha,
52+
args.topk,
53+
args.evaluate_interval,
54+
)
55+
logger.info('[code_length:{}][map:{:.4f}]'.format(args.code_length, checkpoint['map']))
56+
57+
# Save checkpoint
58+
torch.save(
59+
checkpoint,
60+
os.path.join('checkpoints', '{}_model_{}_code_{}_alpha_{}_map_{:.4f}.pt'.format(
61+
args.dataset,
62+
args.arch,
63+
args.code_length,
64+
args.alpha,
65+
checkpoint['map']),
5766
)
58-
logger.info('[code_length:{}][map:{:.4f}]'.format(args.code_length, checkpoint['map']))
59-
60-
# Save checkpoint
61-
#torch.save(
62-
# checkpoint,
63-
# os.path.join('checkpoints', '{}_model_{}_code_{}_alpha_{}_map_{:.4f}.pt'.format(
64-
# args.dataset,
65-
# args.arch,
66-
# args.code_length,
67-
# args.alpha,
68-
# checkpoint['map']),
69-
# )
70-
#)
67+
)
7168

7269

7370
def load_config():
@@ -93,8 +90,8 @@ def load_config():
9390
help='Batch size.(default: 256)')
9491
parser.add_argument('--lr', default=1e-5, type=float,
9592
help='Learning rate.(default: 1e-5)')
96-
parser.add_argument('--max-iter', default=100, type=int,
97-
help='Number of iterations.(default: 100)')
93+
parser.add_argument('--max-iter', default=300, type=int,
94+
help='Number of iterations.(default: 300)')
9895
parser.add_argument('--num-workers', default=6, type=int,
9996
help='Number of loading data threads.(default: 6)')
10097
parser.add_argument('--topk', default=-1, type=int,

‎utils/cifar10_to_png.py‎

Lines changed: 0 additions & 95 deletions
This file was deleted.

‎utils/evaluate.py‎

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,46 @@ def mean_average_precision(query_code,
5252

5353
mean_AP = mean_AP / num_query
5454
return mean_AP
55+
56+
57+
def pr_curve(query_code, retrieval_code, query_targets, retrieval_targets, device):
58+
"""
59+
P-R curve.
60+
61+
Args
62+
query_code(torch.Tensor): Query hash code.
63+
retrieval_code(torch.Tensor): Retrieval hash code.
64+
query_targets(torch.Tensor): Query targets.
65+
retrieval_targets(torch.Tensor): Retrieval targets.
66+
device (torch.device): Using CPU or GPU.
67+
68+
Returns
69+
P(torch.Tensor): Precision.
70+
R(torch.Tensor): Recall.
71+
"""
72+
num_query = query_code.shape[0]
73+
num_bit = query_code.shape[1]
74+
P = torch.zeros(num_query, num_bit + 1).to(device)
75+
R = torch.zeros(num_query, num_bit + 1).to(device)
76+
for i in range(num_query):
77+
gnd = (query_targets[i].unsqueeze(0).mm(retrieval_targets.t()) > 0).float().squeeze()
78+
tsum = torch.sum(gnd)
79+
if tsum == 0:
80+
continue
81+
hamm = 0.5 * (retrieval_code.shape[1] - query_code[i, :] @ retrieval_code.t())
82+
tmp = (hamm <= torch.arange(0, num_bit + 1).reshape(-1, 1).float().to(device)).float()
83+
total = tmp.sum(dim=-1)
84+
total = total + (total == 0).float() * 0.1
85+
t = gnd * tmp
86+
count = t.sum(dim=-1)
87+
p = count / total
88+
r = count / tsum
89+
P[i] = p
90+
R[i] = r
91+
mask = (P > 0).float().sum(dim=0)
92+
mask = mask + (mask == 0).float() * 0.1
93+
P = P.sum(dim=0) / mask
94+
R = R.sum(dim=0) / mask
95+
96+
return P, R
97+

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /