integrate SAM (segment anything) encoder with Unet #757
Conversation
hi @qubvel is there any update on this?
I've just trained a model using this branch and it worked.
@Rusteam how do you train a model? Can you give an outline? The author is not responding, so please help me train a model. I have sent you an email, please take a look.
Make sure you install this package from my fork: pip install git+https://github.com/Rusteam/segmentation_models.pytorch.git@sam
Then initialize your model as usual: create_model("SAM", "sam-vit_b", encoder_weights=None, **kwargs)
and run your training. You can pass weights="sa-1b" in kwargs if you want to fine-tune from pre-trained weights.
So far I have been able to train the model, but I can't say it's learning; I'm still struggling there. Also, I cannot fit more than 1 sample per batch on a 32 GB GPU with a 512 input size.
ccl-private
commented
May 16, 2023
@Rusteam how about this: https://github.com/tianrun-chen/SAM-Adapter-PyTorch
thanks for sharing, I'll try it if my current approach does not work. I've been able to get some learning with this transformers notebook
Hi @Rusteam, thanks a lot for your contribution and sorry for the delay, I am going to review the request and will let you know
Hey hey hey. While this solution worked, I can't say the model was able to learn on my data. We might need to use the version before my DDP adjustments, make the model handle points and boxes as inputs, or use the SAM image encoder with Unet or other architectures.
Is it a pip package? probably need to add to reqs
just added it to reqs, or should we make it optional?
Yes, I was actually thinking about just pre-trained encoder integration, did you test it?
can we use this model to train on custom data??
@qubvel It didn't work with Unet yet, but I can make it work. Which models would be essential to integrate?
@Rusteam @qubvel can we use this model to train on custom data??
that was my intention as well, but I was unable to make it learn without passing box/point prompts. However, when passing a prompt along with the input image, it does learn. We might need to integrate multiple inputs to forward() for it to work, or just use SAM's image encoder with other arches like Unet
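The "multiple inputs to forward()" idea can be sketched as a toy module (hypothetical names and layers, not SAM's actual prompt encoder):

```python
import torch
import torch.nn as nn

# Toy sketch: a segmentation head whose forward takes both an image and an
# optional box prompt, illustrating a forward() with multiple inputs.
class PromptedSegmenter(nn.Module):
    def __init__(self, channels=8):
        super().__init__()
        self.encoder = nn.Conv2d(3, channels, 3, padding=1)
        self.prompt_proj = nn.Linear(4, channels)  # embed an (x0, y0, x1, y1) box
        self.head = nn.Conv2d(channels, 1, 1)

    def forward(self, image, box=None):
        feats = self.encoder(image)
        if box is not None:
            # broadcast the prompt embedding over all spatial positions
            feats = feats + self.prompt_proj(box)[:, :, None, None]
        return self.head(feats)

model = PromptedSegmenter()
img = torch.randn(2, 3, 64, 64)
box = torch.tensor([[0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.9, 0.9]])
mask = model(img, box)
print(mask.shape)  # torch.Size([2, 1, 64, 64])
```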
siddpiku
commented
Jul 5, 2023
The following worked for me:
- git clone the sam branch
- modify the sam.py file like below to get rid of the errors:
- change def forward(self, x: torch.Tensor) -> list[torch.Tensor]: to def forward(self, x: torch.Tensor):
- import segmentation_models_pytorch as smp (python file in the same folder as the cloned branch)
- smp.create_model("Unet", "sam-vit_b", encoder_weights="sa-1b", encoder_depth=4, decoder_channels=[256, 128, 64, 32])
- try training
What did not work: I tried fine-tuning with 2 RTX A6000 GPUs and a batch size of 2 on the ACDC data (https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html), but my Dice loss did not improve after 700 epochs. (Maybe some other setting works, but I did not have time to try it.)
@qubvel hey any updates?
Rusab
commented
Sep 6, 2023
Please add this; the library hasn't had new features in a long time.
This PR is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 15 days.
csaroff
commented
Nov 17, 2023
Is this PR ready?
It's ready.
17SIM
commented
Nov 21, 2023
The current PR seems to work only with images of size 1024x1024.
Yes, same as the original SAM model
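A generic workaround for that constraint (not part of this PR) is to resize batches to the resolution the encoder expects and resize predicted masks back afterwards:

```python
import torch
import torch.nn.functional as F

# SAM's ViT encoder expects 1024x1024 inputs; resize (or pad) batches up
# front and resize the predicted masks back to the original resolution.
def to_sam_size(images, size=1024):
    return F.interpolate(images, size=(size, size),
                         mode="bilinear", align_corners=False)

batch = torch.randn(2, 3, 512, 512)
resized = to_sam_size(batch)
print(resized.shape)  # torch.Size([2, 3, 1024, 1024])
```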
This PR is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 15 days.
Stinosko
commented
Jan 28, 2024
Any progress on this?
Rusab
commented
Jan 29, 2024
Why is the library dying? No new updates in a long time.
This PR is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 15 days.
@qubvel can you merge this? It did work
isaaccorley
commented
Apr 9, 2024
giswqs
commented
Jan 18, 2025
A relevant PR: huggingface/transformers#32317
I think you can already do this because timm supports the SAM ViT weights, e.g.:
Unet("tu-samvit_base_patch16.sa1b")
But I'm not sure how well SAM works with U-Net instead of its own custom decoder.
isaaccorley
commented
Jan 18, 2025
Agreed, it's likely highly dependent on the prompt embeddings as well.
ogencoglu
commented
Jun 10, 2025
I don't think SAM works out of the box like this.
Closes #756
Added:
vit_h
,vit_b
andvit_l
) to encodersChanged: