How to correctly load weights in my model. #497

Answered by manujosephv
dpvargas asked this question in Q&A

Hi!

I'm trying to use a PyTorch Tabular model inside a federated learning environment, so I need to obtain the weights from different nodes, aggregate them, and load the new weights into each node.

The extraction part works, but I am having trouble with the loading part. After I define my model and train it on each node with:

node_x_model = TabularModel(data_config=data_config, model_config=model_config, optimizer_config=OptimizerConfig(), trainer_config=trainer_config, experiment_config=experiment_config)
node_x_model.fit(train=trainset, seed = 42)

I am able to extract the weights correctly. After operating on these weights, I obtain new weights, which we'll call new_model, that I load with:
node_x_model.model.load_state_dict(new_model)
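
For readers following along: the post does not show the aggregation step. A minimal sketch, assuming a plain (optionally weighted) average of the nodes' state dicts, could look like the block below. The function name `average_state_dicts` and the `weights` parameter are hypothetical and not part of PyTorch Tabular; the code only relies on the values supporting `+` and `*`, which holds for `torch` tensors.

```python
def average_state_dicts(state_dicts, weights=None):
    """Weighted average of several state dicts that share the same keys.

    Hypothetical helper for illustration; not part of PyTorch Tabular.
    Values may be torch tensors (or anything supporting + and *).
    """
    if not state_dicts:
        raise ValueError("need at least one state dict")
    if weights is None:
        # Default: every node contributes equally.
        weights = [1.0] * len(state_dicts)
    if len(weights) != len(state_dicts):
        raise ValueError("weights and state_dicts must have the same length")

    total = float(sum(weights))
    averaged = {}
    for key in state_dicts[0].keys():
        # Accumulate the normalized, weighted sum for this parameter.
        acc = state_dicts[0][key] * (weights[0] / total)
        for sd, w in zip(state_dicts[1:], weights[1:]):
            acc = acc + sd[key] * (w / total)
        averaged[key] = acc
    return averaged
```

The result plays the role of new_model in this post, for example `new_model = average_state_dicts([node_a_model.model.state_dict(), node_b_model.model.state_dict()])`, where the node names are placeholders.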

Up to this point there is no problem, but when the next round starts, each node has to train again using the same statement:

node_x_model.fit(train=trainset, seed = 42)

The results each node obtains are the same as the ones it obtained in the previous round. It seems that fit is overwriting the loaded weights. I have reviewed the API docs, but I have not found a solution. Does anyone know how to do this?

Thanks in advance.


I think you need to use the low-level API for this (it's in the docs). When you call fit, we re-initialize the model (i.e., with random weights).

In the low-level API, you have access to the model, so you can load weights into it and then just "train" the model.
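
As a rough sketch of what one federated round might look like with the low-level API: the helper below prepares the datamodule and model, loads the aggregated weights, and then trains without going through fit. The method names `prepare_dataloader`, `prepare_model`, and `train` are taken from the low-level API section of the PyTorch Tabular docs as I recall it, so please check them against your installed version. The helper `train_one_round` and its parameters are hypothetical.

```python
def train_one_round(tabular_model, train_df, aggregated_state=None, seed=42):
    """Train one round starting from aggregated weights instead of a fresh init.

    Hypothetical helper. `tabular_model` is assumed to be a configured
    pytorch_tabular TabularModel; verify the three low-level method names
    against the version you have installed.
    """
    # Build the datamodule (fits the data transforms on train_df).
    datamodule = tabular_model.prepare_dataloader(train=train_df, seed=seed)

    # Build the model; at this point it has freshly initialized weights.
    model = tabular_model.prepare_model(datamodule)

    # Overwrite the fresh weights with the aggregated ones, if provided.
    if aggregated_state is not None:
        model.load_state_dict(aggregated_state)

    # Train the prepared model directly, without the re-initialization in fit().
    tabular_model.train(model, datamodule)

    # Return the updated weights so they can be sent back for aggregation.
    return model.state_dict()
```

In the first round `aggregated_state` can be left as `None`; in later rounds it would be the new_model weights from the question.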

I've already solved the problem this way, but it's good to keep this saved in the discussion for other people. Thanks for the answer!

Answer selected by dpvargas