Open
@rayhuang90
Description
Hi there, a ManagedCollisionEmbeddingCollection with multiple tables and shared features returns all-zero embeddings after apply_optimizer_in_backward with RowWiseAdagrad is applied. This result is unexpected.
The bug likely relates to how the RowWiseAdagrad state and the associated embedding rows are initialized and updated during eviction events in ManagedCollisionEmbeddingCollection.
Below is a minimal reproducible Python example (mch_rowrisegrad_bug.py), launched on a single rank with:

torchrun --standalone --nnodes=1 --node-rank=0 --nproc-per-node=1 mch_rowrisegrad_bug.py
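The original script is not reproduced in this excerpt. For context, the following is a minimal sketch of the setup the description implies: two embedding tables, one of them shared by the item_tag and user_tag features, wrapped in ManagedCollisionEmbeddingCollection with per-table MCH modules, RowWiseAdagrad fused into the backward pass via apply_optimizer_in_backward, and single-rank sharding through DistributedModelParallel. Table names, sizes, the eviction interval, the input ids, and the sharder import path are assumptions, and exact constructor arguments may differ across torchrec versions; this is not the author's actual mch_rowrisegrad_bug.py.

```python
# Hypothetical reconstruction of the reported setup (a sketch, not the author's script).
import torch
import torch.distributed as dist

from torchrec.distributed.mc_embedding import ManagedCollisionEmbeddingCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
)
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def main() -> None:
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
    rank = dist.get_rank()
    device = torch.device("cuda", rank) if torch.cuda.is_available() else torch.device("cpu")

    # Assumed table/feature layout: "tag_table" is shared by two features.
    tables = [
        EmbeddingConfig(
            name="tag_table",
            embedding_dim=8,
            num_embeddings=1000,
            feature_names=["item_tag", "user_tag"],  # shared-feature table
        ),
        EmbeddingConfig(
            name="id_table",
            embedding_dim=8,
            num_embeddings=1000,
            feature_names=["item_id"],
        ),
    ]

    # One ZCH/MCH module per table; eviction_interval=1 (assumed) exercises the
    # eviction path on every training step.
    mc_modules = {
        table.name: MCHManagedCollisionModule(
            zch_size=table.num_embeddings,
            device=device,
            eviction_interval=1,
            eviction_policy=DistanceLFU_EvictionPolicy(),
        )
        for table in tables
    }

    mc_ec = ManagedCollisionEmbeddingCollection(
        EmbeddingCollection(tables=tables, device=device),
        ManagedCollisionCollection(
            managed_collision_modules=mc_modules,
            embedding_configs=tables,
        ),
        return_remapped_features=True,
    )

    # Fuse RowWiseAdagrad updates into the backward pass of the embedding tables.
    apply_optimizer_in_backward(RowWiseAdagrad, mc_ec.parameters(), {"lr": 0.01})

    model = DistributedModelParallel(
        mc_ec,
        device=device,
        sharders=[ManagedCollisionEmbeddingCollectionSharder()],
    )

    # One raw id per feature; "item_tag" and "user_tag" land in the same table.
    kjt = KeyedJaggedTensor.from_lengths_per_key(
        keys=["item_tag", "user_tag", "item_id"],
        values=torch.tensor([1001, 2002, 3003], device=device),
        lengths=torch.tensor([1, 1, 1], device=device),
    )

    # A few forward/backward steps so remapping, eviction, and the fused
    # optimizer all run before the embeddings are inspected.
    for _ in range(2):
        emb_result, remapped_ids = model(kjt)
        loss = torch.cat([jt.values() for jt in emb_result.values()]).sum()
        loss.backward()

    for key, jt in emb_result.items():
        print(f"[RANK{rank}] emb_result key: {key}, jt: {jt}")
    print(f"[RANK{rank}] remapped_ids: {remapped_ids}")


if __name__ == "__main__":
    main()
```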
Unexpected Result
[RANK0] emb_result key: item_tag, jt: JaggedTensor({
    [[[-0.00694586057215929, 0.005635389592498541, 0.029554935172200203, -0.014213510788977146, 0.027853110805153847, 0.023257633671164513, 0.004495333414524794, -0.01736217364668846]]]
})
[RANK0] emb_result key: user_tag, jt: JaggedTensor({
    [[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]  # --> all-zero embeddings, bug
})
[RANK0] emb_result key: item_id, jt: JaggedTensor({
    [[[0.026882518082857132, -0.008349019102752209, 0.025774799287319183, 0.010714510455727577, 0.022058645263314247, -0.02674921043217182, 0.029537828639149666, 0.007071810774505138]]]
})
[RANK0] remapped_ids: KeyedJaggedTensor({
    "item_tag": [[997]],
    "user_tag": [[998]],
    "item_id": [[998]]
})
My Current Environment
fbgemm_gpu==1.1.0+cu118
numpy==2.1.2
protobuf==3.19.6
torch==2.6.0+cu118
torchrec==1.1.0+cu118
transformers==4.48.0
triton==3.2.0