I am using the following custom feature extractor for my StableBaselines3 model:

import torch
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class Encoder(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim=2):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, embedding_dim),
            nn.ReLU()
        )
        self.regressor = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.regressor(x)
        return x


model = Encoder(input_dim, embedding_dim, hidden_dim)
model.load_state_dict(torch.load('trained_model.pth'))

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False


class CustomFeatureExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space, features_dim):
        super(CustomFeatureExtractor, self).__init__(observation_space, features_dim)
        self.model = model  # Use the pre-trained model as the feature extractor
        self._features_dim = features_dim

    def forward(self, observations):
        features = self.model(observations)
        return features


policy_kwargs = {
    "features_extractor_class": CustomFeatureExtractor,
    "features_extractor_kwargs": {"features_dim": 64}
}
model = PPO("MlpPolicy", env=envs, policy_kwargs=policy_kwargs)

Training works well so far, with no issues and good results. Now I want to stop freezing the weights and train the feature extractor as well, starting from the pre-trained weights. How can I do that with a custom feature extractor that wraps another class? My feature extractor is not the same as the one in the documentation, so I am not sure whether it will be trained. Or will it start training as soon as I unfreeze the layers?

asked Jul 7, 2024 at 22:11

1 Answer

UPDATED answer

Because your CustomFeatureExtractor wraps an Encoder that is already frozen (requires_grad = False), all of the CustomFeatureExtractor's weights are frozen, so by default it is not trainable. You need to unfreeze it manually:


model = PPO("MlpPolicy", env='FrozenLake8x8', policy_kwargs=policy_kwargs)

# get the model's feature extractor
feature_extr: CustomFeatureExtractor = model.policy.features_extractor

# make all parameters trainable
for name, param in feature_extr.named_parameters():
    param.requires_grad = True

# check parameters before training
encoder = feature_extr.model.encoder
for name, param in encoder[0].named_parameters():
    print(name, param.mean())

# train the model
model.learn(total_timesteps=5)

# check parameters after training (if the mean changed, the parameters are being trained)
feature_extr: CustomFeatureExtractor = model.policy.features_extractor
encoder = feature_extr.model.encoder
for name, param in encoder[0].named_parameters():
    print(name, param.mean())
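
The unfreezing step can also be sketched in plain PyTorch, without Stable Baselines3. This is a minimal illustration, not the library's internals: the small `encoder` network, its sizes, and the helper `count_parameters` are hypothetical stand-ins. It assumes the optimizer was built from `module.parameters()` while the weights were still frozen, and shows that flipping `requires_grad` back to `True` afterwards is enough for the next optimizer step to change them. Whether SB3 builds its optimizer this way is my assumption; the before/after print above remains the direct check.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the pre-trained Encoder; sizes are arbitrary.
encoder = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))


def count_parameters(module: nn.Module) -> tuple:
    """Return (trainable, frozen) parameter counts for a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in module.parameters() if not p.requires_grad)
    return trainable, frozen


# Freeze everything, as in the question.
for p in encoder.parameters():
    p.requires_grad = False
frozen_counts = count_parameters(encoder)

# Build the optimizer while the parameters are still frozen (assumed scenario).
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-2)

# Unfreeze afterwards, as in the answer.
for p in encoder.parameters():
    p.requires_grad = True
unfrozen_counts = count_parameters(encoder)

# One training step on a dummy loss, then compare the first layer's weights.
before = encoder[0].weight.detach().clone()
loss = encoder(torch.randn(16, 4)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
weights_changed = not torch.equal(before, encoder[0].weight)

print(frozen_counts, unfrozen_counts, weights_changed)
```

The same `count_parameters` helper could be pointed at `model.policy.features_extractor` to confirm that nothing is left frozen before calling `model.learn()`.
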
answered Jul 16, 2024 at 12:55

7 Comments

Thank you for your answer, Johnny. Yes, I understand that. My question, however, was more about Stable Baselines3. By running model = PPO("MlpPolicy", env=envs, policy_kwargs=policy_kwargs), are we also updating the Feature Encoder, or only training the policy and value networks? Also, my Feature Extractor class is composed of 2 classes, not 1 as in the documentation: stable-baselines3.readthedocs.io/en/v1.0/guide/…. Is that a problem?
Simply calling `model = PPO(...)` doesn't start the training procedure. The PPO model inherits the FE layers together with their attributes. Later, during training, everything that has gradients enabled will be trained. My guess is that you froze the FE layers before starting the training procedure, and maybe even earlier, before merging the Encoder into PPO. However, this is not clear from your snippet alone, because it doesn't allow reproducing the training procedure you are using. The number of nn.Modules in the FE doesn't matter, because the PPO model only cares about the output dims (hidden_dims).
Understood, thank you for the follow-up! My concern was that CustomFeatureExtractor calls my pre-trained network like so: features = self.model(observations), so I thought it might not be trained, as it could be considered a plain function. But this is not a problem?
By calling self.model(observations) you run the forward pass with gradient tracking, so gradients are computed and the weights are updated later. It would be inference only if the snippet ran inside a with torch.no_grad(): context, but I doubt that is your case. In any case, if you provide a code snippet with the training, I can check that as well.
Understood, it's much more clear now. After defining the model as above, for training I just run model.learn().
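
The point from the comments, that calling a wrapped network via self.model(observations) still lets gradients flow unless the call sits inside torch.no_grad(), can be checked in plain PyTorch. The sketch below is illustrative only: `Wrapper` is a hypothetical stand-in for CustomFeatureExtractor, and the `nn.Linear(4, 2)` inner model and batch size are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Wrapper(nn.Module):
    """Hypothetical stand-in for CustomFeatureExtractor: stores a model and calls it."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        # Assigning an nn.Module as an attribute registers it as a submodule.
        self.model = inner

    def forward(self, observations):
        return self.model(observations)


inner = nn.Linear(4, 2)
wrapper = Wrapper(inner)

# The inner model's parameters are visible through the wrapper.
param_names = [name for name, _ in wrapper.named_parameters()]

x = torch.randn(8, 4)

# Normal call: the output is attached to the autograd graph.
out = wrapper(x)
out.sum().backward()
grad_populated = inner.weight.grad is not None

# Same call under no_grad: inference only, nothing to backpropagate through.
with torch.no_grad():
    out_no_grad = wrapper(x)

print(param_names, out.requires_grad, out_no_grad.requires_grad, grad_populated)
```

Because the wrapped model is registered as a submodule, how many classes the feature extractor is composed of makes no difference to which parameters are collected.
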