Commit 1668312

committed
Ver 0.4
1 parent ff8cee2 commit 1668312

File tree

2 files changed: +8 -4 lines changed


README.md

Lines changed: 5 additions & 1 deletion

@@ -42,4 +42,8 @@ import torchbnn
 
 ### Version 0.3
 * **bayesian_kl_loss/BKLLoss returns torch.Tensor([0]) as default**
-    * In the previous version, bayesian_kl_loss/BKLLoss returned 0 as a Python int if there were no Bayesian layers. However, since all torch losses return tensors and .item() is used to convert them to plain numbers, it was changed to return torch.Tensor([0]) when there are no Bayesian layers.
+    * In the previous version, bayesian_kl_loss/BKLLoss returned 0 as a Python int if there were no Bayesian layers. However, since all torch losses return tensors and .item() is used to convert them to plain numbers, it was changed to return torch.Tensor([0]) when there are no Bayesian layers.
+
+### Version 0.4
+* **bayesian_kl_loss/BKLLoss is modified**
+    * In some cases, a device mismatch (cuda/cpu) occurred. Thus, losses are now initialized with torch.Tensor([0]) on the device on which the model resides.
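The changelog above describes how bayesian_kl_loss/BKLLoss behaves when a model contains no Bayesian layers. Below is a minimal sketch of that behavior, assuming torchbnn is installed and that the function is importable from torchbnn.functional (the alias BF and the plain nn.Linear stand-in model are illustrative, not part of the commit):

```python
import torch.nn as nn
from torchbnn import functional as BF  # assumed import path for torchbnn/functional.py

# Stand-in model with no Bayesian layers (illustrative only).
plain_model = nn.Linear(4, 2)

# Per the Version 0.3 note, the KL loss is a zero tensor rather than the
# Python int 0, so .item() can be applied unconditionally, as with other
# torch losses.
kl = BF.bayesian_kl_loss(plain_model, reduction='mean', last_layer_only=False)
print(kl, kl.item())
```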

torchbnn/functional.py

Lines changed: 3 additions & 3 deletions

@@ -32,9 +32,9 @@ def bayesian_kl_loss(model, reduction='mean', last_layer_only=False) :
         last_layer_only (Bool): True for return only the last layer's KL divergence.
 
     """
-
-    kl = torch.Tensor([0])
-    kl_sum = torch.Tensor([0])
+    device = torch.device("cuda" if next(model.parameters()).is_cuda else "cpu")
+    kl = torch.Tensor([0]).to(device)
+    kl_sum = torch.Tensor([0]).to(device)
     n = 0
 
     for m in model.modules() :
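As a usage note on the fix above, the sketch below shows why device-aware initialization matters once the KL term is combined with another loss computed on the GPU. The nn.Linear stand-in model, the placeholder task loss, the 0.1 weighting factor, and the BF import alias are assumptions for illustration, not part of torchbnn or this commit:

```python
import torch
import torch.nn as nn
from torchbnn import functional as BF  # assumed import path for torchbnn/functional.py

# Illustrative stand-in model, moved to GPU when one is available.
model = nn.Linear(4, 2)
if torch.cuda.is_available():
    model = model.cuda()

# kl and kl_sum are now initialized on the model's device, so the returned
# loss can be combined with other losses on that device without a cuda/cpu
# mismatch.
kl = BF.bayesian_kl_loss(model, reduction='mean', last_layer_only=False)
task_loss = torch.zeros(1, device=next(model.parameters()).device)  # placeholder for e.g. a cross-entropy term
total = task_loss + 0.1 * kl
print(kl.device, total.device)
```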

0 commit comments
