
Commit ac51032

Ver 0.6

1 parent c80990a commit ac51032

7 files changed: +113 additions, -28 deletions


README.md

Lines changed: 20 additions & 15 deletions
```diff
@@ -28,26 +28,31 @@ import torchbnn
 ## Update Records
 
 ### Version 0.1
-* **modules** : BayesLinear, BayesConv2d, BayesBatchNorm2d
-* **utils** : convert_model(nonbayes_to_bayes, bayes_to_nonbayes)
-* **functional** : bayesian_kl_loss
+* **modules** : BayesLinear, BayesConv2d, BayesBatchNorm2d are added.
+* **utils** : convert_model(nonbayes_to_bayes, bayes_to_nonbayes) is added.
+* **functional.py** : bayesian_kl_loss is added.
 
 ### Version 0.2
-* **prior_sigma** is used when initialize modules and functions instead of **prior_log_sigma**
-* **Modules(BayesLinear, BayesConv2d, BayesBatchNorm2d)** are re-defined with prior_sigma instead of prior_log_sigma.
-* **convert_model(nonbayes_to_bayes, bayes_to_nonbayes)** is also changed with prior_sigma instead of prior_log_sigma.
-* **Modules(BayesLinear, BayesConv2d, BayesBatchNorm2d)** : Base initialization method is changed to the method of Adv-BNN from the original torch method.
-* **functional** : **bayesian_kl_loss** is changed similar to ones in **torch.functional**
-* **loss** : **BKLLoss** is added based on bayesian_kl_loss similar to ones in **torch.loss**
+* **prior_sigma** is now used instead of **prior_log_sigma** when initializing modules and functions.
+* **modules** are re-defined with prior_sigma instead of prior_log_sigma.
+* **utils/convert_model.py** is likewise changed to use prior_sigma instead of prior_log_sigma.
+* **modules** : The base initialization method is changed from the original torch method to the method of Adv-BNN.
+* **functional.py** : **bayesian_kl_loss** is changed to resemble the functions in **torch.functional**.
+* **modules/loss.py** : **BKLLoss** is added, based on bayesian_kl_loss and resembling the losses in **torch.loss**.
 
 ### Version 0.3
-* **bayesian_kl_loss/BKLLoss returns tensor.Tensor([0]) as default**
-    * In the previous version, bayesian_kl_loss/BKLLoss returns 0 of int type if there is no Bayesian layers. However, considering all torch loss returns tensor and .item() is used to make them to int type, they are changed to return tensor.Tensor([0]) if there is no Bayesian layers.
+* **functional.py** :
+    * **bayesian_kl_loss returns torch.Tensor([0]) by default** : In the previous version, bayesian_kl_loss returned an int 0 if the model had no Bayesian layers. However, since every torch loss returns a tensor (with .item() used to convert it to a number), it now returns torch.Tensor([0]) when there are no Bayesian layers.
 
 ### Version 0.4
-* **bayesian_kl_loss/BKLLoss is modified**
-    * In some cases, the device problem(cuda/cpu) has occurred. Thus, losses are initialized with tensor.Tensor([0]) on the device on which the model is.
+* **functional.py** :
+    * **bayesian_kl_loss is modified** : In some cases a device mismatch (cuda/cpu) occurred, so losses are now initialized with torch.Tensor([0]) on the same device as the model.
 
 ### Version 0.5
-* **nonbayes_to_bayes, bayes_to_nonbayes is modified**
-    * Before this version, they replace the original model. From now, we can handle it with the 'inplace' argument. Set 'inplace=True' for replace the input model and 'inplace=False' for getting a new model. 'inplace=True' is recommended cause it shortens memories and no future problems with deepcopy.
+* **utils/convert_model.py** :
+    * **nonbayes_to_bayes, bayes_to_nonbayes are modified** : Previously they always replaced the original model. They now take an 'inplace' argument: set 'inplace=True' to modify the input model in place and 'inplace=False' to get a new model. 'inplace=True' is recommended because it saves memory and avoids deepcopy issues.
+
+### Version 0.6
+* **utils/freeze_model.py** :
+    * **freeze, unfreeze methods are added** : Bayesian modules return different outputs on every forward pass, even for identical inputs, because of their randomized forward propagation. Sometimes, however, we need to freeze this randomness to analyze a model in depth. The freeze method turns a Bayesian model into a deterministic (non-Bayesian) model with the same parameters; a usage sketch follows this diff.
+* **modules** : To support **freeze**, the attributes freeze, weight_eps, and bias_eps are added to each module. If freeze is False (default), weight_eps and bias_eps are re-initialized with normal noise at every forward pass. If freeze is True, weight_eps and bias_eps are left unchanged.
```
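A minimal usage sketch of the Version 0.6 feature described above. The import paths are assumed from this commit's file tree (torchbnn.modules for the layers, torchbnn.utils for the helpers); the released package may expose them differently.

```python
import torch
import torch.nn as nn
from torchbnn.modules import BayesLinear   # import paths assumed from this commit's tree
from torchbnn.utils import freeze, unfreeze

model = nn.Sequential(
    BayesLinear(prior_mu=0.0, prior_sigma=0.1, in_features=4, out_features=2),
)
x = torch.randn(1, 4)

# Default (freeze == False): weight_eps and bias_eps are resampled on every
# forward pass, so two calls on the same input almost surely differ.
y1, y2 = model(x), model(x)
print(torch.equal(y1, y2))   # False

# Frozen: the eps buffers are held fixed, so the model behaves like a
# deterministic (non-Bayesian) network with the same parameters.
freeze(model)
y3, y4 = model(x), model(x)
print(torch.equal(y3, y4))   # True

unfreeze(model)              # restore randomized forward passes
```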

torchbnn/functional.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -35,7 +35,7 @@ def bayesian_kl_loss(model, reduction='mean', last_layer_only=False) :
     device = torch.device("cuda" if next(model.parameters()).is_cuda else "cpu")
     kl = torch.Tensor([0]).to(device)
     kl_sum = torch.Tensor([0]).to(device)
-    n = 0
+    n = torch.Tensor([0]).to(device)
 
     for m in model.modules() :
         if isinstance(m, (BayesLinear, BayesConv2d)):
@@ -57,8 +57,8 @@ def bayesian_kl_loss(model, reduction='mean', last_layer_only=False) :
             kl = _kl_loss(m.bias_mu, m.bias_log_sigma, m.prior_mu, m.prior_log_sigma)
             kl_sum += kl
             n += len(m.bias_mu.view(-1))
-
-    if last_layer_only :
+
+    if last_layer_only or n == 0 :
         return kl
 
     if reduction == 'mean' :
```
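For context on the two changes: making `n` a tensor keeps the running parameter count on the same device as `kl_sum`, and the new `n == 0` guard returns the zero tensor `kl` early when the model has no Bayesian layers, which avoids dividing by a zero count in the `'mean'` reduction. The `_kl_loss` helper is not shown in this diff; presumably it computes the standard closed-form KL divergence between the Gaussian posterior and the Gaussian prior, which for reference is

$$
\mathrm{KL}\left(\mathcal{N}(\mu,\sigma^{2})\,\|\,\mathcal{N}(\mu_{p},\sigma_{p}^{2})\right)
= \log\frac{\sigma_{p}}{\sigma} + \frac{\sigma^{2} + (\mu-\mu_{p})^{2}}{2\sigma_{p}^{2}} - \frac{1}{2}
$$

with $\sigma = \exp(\text{weight\_log\_sigma})$ and $\sigma_{p} = \exp(\text{prior\_log\_sigma})$.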

torchbnn/modules/batchnorm.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -35,10 +35,16 @@ def __init__(self, prior_mu, prior_sigma, num_features, eps=1e-5, momentum=0.1,
             self.prior_mu = prior_mu
             self.prior_sigma = prior_sigma
             self.prior_log_sigma = math.log(prior_sigma)
+
+            self.freeze = False
+
             self.weight_mu = Parameter(torch.Tensor(num_features))
             self.weight_log_sigma = Parameter(torch.Tensor(num_features))
+            self.weight_eps = torch.randn_like(self.weight_log_sigma)
+
             self.bias_mu = Parameter(torch.Tensor(num_features))
             self.bias_log_sigma = Parameter(torch.Tensor(num_features))
+            self.bias_eps = torch.randn_like(self.bias_log_sigma)
         else:
             self.register_parameter('weight_mu', None)
             self.register_parameter('weight_log_sigma', None)
@@ -95,8 +101,11 @@ def forward(self, input):
             exponential_average_factor = self.momentum
 
         if self.affine :
-            weight = self.weight_mu + torch.exp(self.weight_log_sigma) * torch.randn_like(self.weight_log_sigma)
-            bias = self.bias_mu + torch.exp(self.bias_log_sigma) * torch.randn_like(self.bias_log_sigma)
+            if not self.freeze :
+                self.weight_eps = torch.randn_like(self.weight_log_sigma)
+                self.bias_eps = torch.randn_like(self.bias_log_sigma)
+            weight = self.weight_mu + torch.exp(self.weight_log_sigma) * self.weight_eps
+            bias = self.bias_mu + torch.exp(self.bias_log_sigma) * self.bias_eps
         else :
             weight = None
             bias = None
```

torchbnn/modules/conv.py

Lines changed: 12 additions & 3 deletions
```diff
@@ -47,23 +47,28 @@ def __init__(self, prior_mu, prior_sigma, in_channels, out_channels, kernel_size
         self.prior_mu = prior_mu
         self.prior_sigma = prior_sigma
         self.prior_log_sigma = math.log(prior_sigma)
-        self.bias = bias
 
+        self.freeze = False
 
         if transposed:
             self.weight_mu = Parameter(torch.Tensor(
                 in_channels, out_channels // groups, *kernel_size))
             self.weight_log_sigma = Parameter(torch.Tensor(
                 in_channels, out_channels // groups, *kernel_size))
+            self.weight_eps = torch.randn_like(self.weight_log_sigma)
         else:
             self.weight_mu = Parameter(torch.Tensor(
                 out_channels, in_channels // groups, *kernel_size))
             self.weight_log_sigma = Parameter(torch.Tensor(
                 out_channels, in_channels // groups, *kernel_size))
+            self.weight_eps = torch.randn_like(self.weight_log_sigma)
 
+        self.bias = bias
+
         if bias:
             self.bias_mu = Parameter(torch.Tensor(out_channels))
             self.bias_log_sigma = Parameter(torch.Tensor(out_channels))
+            self.bias_eps = torch.randn_like(self.bias_log_sigma)
         else:
             self.register_parameter('bias_mu', None)
             self.register_parameter('bias_log_sigma', None)
@@ -126,7 +131,9 @@ def __init__(self, prior_mu, prior_log_sigma, in_channels, out_channels, kernel_
     def conv2d_forward(self, input, weight):
 
         if self.bias:
-            bias = self.bias_mu + torch.exp(self.bias_log_sigma) * torch.randn_like(self.bias_log_sigma)
+            if not self.freeze :
+                self.bias_eps = torch.randn_like(self.bias_log_sigma)
+            bias = self.bias_mu + torch.exp(self.bias_log_sigma) * self.bias_eps
         else :
             bias = None
 
@@ -140,6 +147,8 @@ def conv2d_forward(self, input, weight):
             self.padding, self.dilation, self.groups)
 
     def forward(self, input):
-        weight = self.weight_mu + torch.exp(self.weight_log_sigma) * torch.randn_like(self.weight_log_sigma)
+        if not self.freeze :
+            self.weight_eps = torch.randn_like(self.weight_log_sigma)
+        weight = self.weight_mu + torch.exp(self.weight_log_sigma) * self.weight_eps
 
         return self.conv2d_forward(input, weight)
```

torchbnn/modules/linear.py

Lines changed: 13 additions & 4 deletions
```diff
@@ -22,20 +22,25 @@ class BayesLinear(Module):
 
     def __init__(self, prior_mu, prior_sigma, in_features, out_features, bias=True):
         super(BayesLinear, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
         self.prior_mu = prior_mu
         self.prior_sigma = prior_sigma
         self.prior_log_sigma = math.log(prior_sigma)
-        self.in_features=in_features
-        self.out_features = out_features
+
+        self.freeze = False
 
         self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
         self.weight_log_sigma = Parameter(torch.Tensor(out_features, in_features))
+        self.weight_eps = torch.randn_like(self.weight_log_sigma)
 
         self.bias = bias
 
         if bias:
             self.bias_mu = Parameter(torch.Tensor(out_features))
             self.bias_log_sigma = Parameter(torch.Tensor(out_features))
+            self.bias_eps = torch.randn_like(self.bias_log_sigma)
         else:
             self.register_parameter('bias_mu', None)
             self.register_parameter('bias_log_sigma', None)
@@ -63,10 +68,14 @@ def reset_parameters(self):
         # self.bias_log_sigma.data.fill_(self.prior_log_sigma)
 
     def forward(self, input):
-        weight = self.weight_mu + torch.exp(self.weight_log_sigma) * torch.randn_like(self.weight_log_sigma)
+        if not self.freeze :
+            self.weight_eps = torch.randn_like(self.weight_log_sigma)
+        weight = self.weight_mu + torch.exp(self.weight_log_sigma) * self.weight_eps
 
         if self.bias:
-            bias = self.bias_mu + torch.exp(self.bias_log_sigma) * torch.randn_like(self.bias_log_sigma)
+            if not self.freeze :
+                self.bias_eps = torch.randn_like(self.bias_log_sigma)
+            bias = self.bias_mu + torch.exp(self.bias_log_sigma) * self.bias_eps
         else :
             bias = None
 
```
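The same pattern appears in batchnorm.py, conv.py, and linear.py: each parameter is drawn via the reparameterization trick as mu + exp(log_sigma) * eps with eps ~ N(0, 1), and freezing simply stops resampling eps. Below is a self-contained distillation of that mechanic; it is an illustration of the pattern, not the library's class.

```python
import torch

class ReparamWeight:
    """Minimal stand-in for the weight-sampling logic in the diffs above."""

    def __init__(self, mu, log_sigma):
        self.mu, self.log_sigma = mu, log_sigma
        self.freeze = False
        self.eps = torch.randn_like(log_sigma)   # noise drawn once at init

    def sample(self):
        if not self.freeze:                       # unfrozen: fresh noise per call
            self.eps = torch.randn_like(self.log_sigma)
        return self.mu + torch.exp(self.log_sigma) * self.eps

w = ReparamWeight(torch.zeros(3), torch.full((3,), -2.0))
a, b = w.sample(), w.sample()   # almost surely different draws
w.freeze = True
c, d = w.sample(), w.sample()   # identical: eps is held fixed
assert torch.equal(c, d)
```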

torchbnn/utils/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1 +1,2 @@
-from .convert_model import bayes_to_nonbayes, nonbayes_to_bayes
+from .convert_model import bayes_to_nonbayes, nonbayes_to_bayes
+from .freeze_model import freeze, unfreeze
```

torchbnn/utils/freeze_model.py

Lines changed: 52 additions & 0 deletions
```diff
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+from ..modules import *
+
+bayes_layer = [BayesLinear, BayesConv2d, BayesBatchNorm2d]
+
+"""
+Methods for freezing a Bayesian layer.
+
+Arguments:
+    layer (nn.Module): the layer to be frozen.
+
+"""
+
+def _freeze(layer):
+    for inst in bayes_layer :
+        if isinstance(layer, inst) :
+            layer.freeze = True
+        else :
+            continue
+
+def _unfreeze(layer):
+    for inst in bayes_layer :
+        if isinstance(layer, inst) :
+            layer.freeze = False
+        else :
+            continue
+    return layer
+
+"""
+Methods for freezing a Bayesian model.
+
+Arguments:
+    model (nn.Module): the model to be frozen.
+
+"""
+
+def freeze(model):
+    for name, m in model.named_children() :
+        if isinstance(m, nn.Sequential) :
+            for layer in m :
+                _freeze(layer)
+        else :
+            _freeze(m)
+
+def unfreeze(model):
+    for name, m in model.named_children() :
+        if isinstance(m, nn.Sequential) :
+            for layer in m :
+                _unfreeze(layer)
+        else :
+            _unfreeze(m)
```
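As written, freeze and unfreeze visit only the model's direct children, descending one extra level into nn.Sequential containers. A hypothetical variant, not part of this commit, could walk model.modules() instead, which yields every nested submodule and so reaches arbitrarily deep Bayesian layers:

```python
import torch.nn as nn
from torchbnn.modules import BayesLinear, BayesConv2d, BayesBatchNorm2d

def set_freeze(model: nn.Module, value: bool = True):
    """Hypothetical recursive variant (not in this commit): flip the freeze
    flag on every Bayesian layer anywhere in the module tree."""
    for m in model.modules():
        if isinstance(m, (BayesLinear, BayesConv2d, BayesBatchNorm2d)):
            m.freeze = value
```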
