It would be nice to be able to restore a model from an automatic training checkpoint (.ckpt file). Apparently, this is not possible now:
os.listdir('saved_models')
['classification-14_epoch=4-valid_loss=0.51.ckpt']
tabular_model=TabularModel.load_from_checkpoint(dir='saved_models')
FileNotFoundError Traceback (most recent call last)
in ()
----> 1 tabular_model=TabularModel.load_from_checkpoint(dir='saved_models')
1 frames
/usr/local/lib/python3.7/dist-packages/omegaconf/omegaconf.py in load(file_)
181
182 if isinstance(file_, (str, pathlib.Path)):
--> 183 with io.open(os.path.abspath(file_), "r", encoding="utf-8") as f:
184 obj = yaml.load(f, Loader=get_yaml_loader())
185 elif getattr(file_, "read", None):
FileNotFoundError: [Errno 2] No such file or directory: '/content/mnt/My Drive/projects/tabular/saved_models/config.yml'
To initialize a TabularModel, we need a few more things, such as the tabular datamodule, to maintain the data transformations.
However, an API can be added to load from a saved checkpoint if you already have an initialized TabularModel. In that case, only the weights will be reloaded.
As a temporary measure, you can use the code here to load from a checkpoint. The PyTorch model is saved under tabular_model.model, and you can use pl_load from PyTorch Lightning to load the weights from a specific checkpoint.
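A minimal sketch of that manual approach, assuming a plain PyTorch setup: it uses torch.load as a stand-in for pl_load, and an nn.Linear stands in for tabular_model.model. The helper name load_checkpoint_weights and the file name demo.ckpt are hypothetical. A PyTorch Lightning checkpoint keeps the weights under its "state_dict" key, while a pure PyTorch checkpoint is the state_dict itself, so the helper accepts either.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_checkpoint_weights(module: nn.Module, path: str) -> nn.Module:
    """Hypothetical helper: load weights from a Lightning .ckpt or a plain state_dict file."""
    # map_location="cpu" lets a GPU-trained checkpoint load on a CPU-only machine.
    # Real Lightning checkpoints also pickle hyperparameters, so on torch>=2.6 you
    # may need torch.load(path, map_location="cpu", weights_only=False) for trusted files.
    ckpt = torch.load(path, map_location="cpu")
    # Lightning nests the weights under "state_dict"; a pure PyTorch file is the dict itself.
    state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    module.load_state_dict(state_dict)
    return module


# Stand-in network; in PyTorch Tabular this would be tabular_model.model.
trained = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.ckpt")
    # Mimic the layout of a Lightning checkpoint: weights under "state_dict".
    torch.save({"epoch": 4, "state_dict": trained.state_dict()}, path)
    fresh = load_checkpoint_weights(nn.Linear(4, 2), path)
    print(torch.equal(fresh.weight, trained.weight))  # True
```

The target module must have the same architecture as the one that was saved; otherwise load_state_dict raises an error about missing or unexpected keys.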
The upcoming update to PyTorch Tabular offers several ways to do this:
- There is a load_weights function, which loads weights from a saved checkpoint (either a PyTorch Lightning checkpoint or a pure PyTorch checkpoint).
- You can pass the path to the state_dict or checkpoint to the TabularModel constructor so that it starts with the saved weights.
- It will also have a more granular API (apart from .fit) to give more flexibility in the modelling process.
I am trying to load, in a separate script, an SSL model that was trained following the tutorial 08-Self-SupervisedLearning-DAE.ipynb. With the load_weights function, I get the error AttributeError: 'TabularModel' object has no attribute 'model'.
@kumar4372 @manujosephv
Same problem here.
You can't just initialize the TabularModel first, for example from configuration files:
tabular_model = TabularModel( data_config="DataConfig.yaml", model_config="ModelConfig.yaml", optimizer_config="OptimizerConfig.yaml", trainer_config="TrainerConfig.yaml", )
And then load from checkpoint:
model = tabular_model.load_weights("best_model.ckpt")
You will get the error:
AttributeError: 'TabularModel' object has no attribute 'model'
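A toy illustration of why this error appears, assuming (as the workaround in this thread suggests) that the model attribute only exists once the model has been prepared. ToyTabularModel and its methods are hypothetical and are NOT PyTorch Tabular's code; they only reproduce the failure pattern in plain Python.

```python
class ToyTabularModel:
    """Toy stand-in (NOT PyTorch Tabular's code): `model` exists only after preparation."""

    def __init__(self, config: dict):
        self.config = config  # configs are stored, but no network is built yet

    def prepare_model(self) -> None:
        # Hypothetical: the network is only created here, after data setup.
        self.model = {"weights": None}

    def load_weights(self, weights) -> None:
        # Raises AttributeError if prepare_model() has not run yet.
        self.model["weights"] = weights


toy = ToyTabularModel({"task": "classification"})
try:
    toy.load_weights([0.1, 0.2])
except AttributeError as err:
    print(err)  # 'ToyTabularModel' object has no attribute 'model'

toy.prepare_model()
toy.load_weights([0.1, 0.2])
print(toy.model["weights"])  # [0.1, 0.2]
```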
As far as I understand, load_weights is a method of the TabularModel class, and you need a fully prepared TabularModel instance before loading weights: you have to create the datamodule and the model itself, and set up a trainer. The easiest way to do this is to call the protected method _prepare_for_training and pass the created datamodule and model to it:
from pytorch_tabular import TabularModel

# train and validation are your pandas DataFrames
tabular_model = TabularModel(
    data_config="DataConfig.yaml",
    model_config="ModelConfig.yaml",
    optimizer_config="OptimizerConfig.yaml",
    trainer_config="TrainerConfig.yaml",
)
datamodule = tabular_model.prepare_dataloader(train=train, validation=validation)
model = tabular_model.prepare_model(datamodule)
# protected method: sets up the trainer and the model on tabular_model
tabular_model._prepare_for_training(model, datamodule)
After that, load_weights works fine:
tabular_model.load_weights("best_model.ckpt")
At least, it is then possible to call evaluate or predict.
I didn't find a detailed example of using load_weights in the documentation.
I think this may be relevant when, for example, after model_sweep you are interested not only in the best model but in several of the best ones (for all the other models you only have their checkpoints, not a TabularModel instance).
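For that several-best-models case, a minimal pure-Python sketch that ranks checkpoint files by the valid_loss embedded in their names, assuming they follow the naming pattern from the os.listdir('saved_models') output at the top of this thread. The helper rank_checkpoints is hypothetical, and all file names except the first one in the demo list are made up for illustration.

```python
import re
from typing import List, Tuple

# Matches the "valid_loss=0.51" fragment of names such as
# "classification-14_epoch=4-valid_loss=0.51.ckpt".
_LOSS_RE = re.compile(r"valid_loss=(\d+(?:\.\d+)?)")


def rank_checkpoints(names: List[str], k: int = 3) -> List[Tuple[float, str]]:
    """Hypothetical helper: return the k .ckpt names with the lowest valid_loss, best first."""
    scored = []
    for name in names:
        if not name.endswith(".ckpt"):
            continue  # skip config.yml and other non-checkpoint files
        match = _LOSS_RE.search(name)
        if match:
            scored.append((float(match.group(1)), name))
    scored.sort()  # ascending loss; ties broken by file name
    return scored[:k]


names = [
    "classification-14_epoch=4-valid_loss=0.51.ckpt",
    "classification-14_epoch=2-valid_loss=0.63.ckpt",  # hypothetical
    "classification-14_epoch=6-valid_loss=0.48.ckpt",  # hypothetical
    "config.yml",
]
for loss, name in rank_checkpoints(names, k=2):
    print(loss, name)
# 0.48 classification-14_epoch=6-valid_loss=0.48.ckpt
# 0.51 classification-14_epoch=4-valid_loss=0.51.ckpt
```

In practice, names would come from os.listdir('saved_models'), and each selected file could then be passed to load_weights on a TabularModel prepared with _prepare_for_training as in the workaround above.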
Let me convert this to a discussion; as a closed issue, I feel it will not get enough visibility. Also, if you think the code needs to change, feel free to raise a new issue so that we can track it.