When embedding_dim exceeds 1024, the results of embedding_lookup become incorrect; however, FBGEMM only reports the dimension as unsupported once embedding_dim is greater than 2048. This can be reproduced with the command torchrun --master_addr=localhost --master_port=12345 --nnodes=1 --nproc-per-node=1 test_tw_shard.py in an environment with torchrec==1.4.0+cu126, torch==2.9.0+cu126, fbgemm-gpu==1.4.0+cu126.
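As a sanity check on the environment, the installed versions can be printed with a small sketch like the following (using the standard-library importlib.metadata; the distribution names are the pip names listed above):

from importlib.metadata import version

# Print the installed versions of the packages listed above
# (pip distribution names; assumed to resolve via importlib.metadata).
for pkg in ("torch", "torchrec", "fbgemm-gpu"):
    print(pkg, version(pkg))

The full contents of test_tw_shard.py follow.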
import os
from typing import Dict, cast
import torch
import torch.distributed as dist
import torchrec
import numpy as np
from torch import nn
from torchrec import EmbeddingBagCollection
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import (
DistributedModelParallel,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.optim import optimizers
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
large_table_cnt = 2
small_table_cnt = 2
# The lookup corruption appears once embedding_dim exceeds 1024:
# 1280 reproduces it, while 1024 produces correct results.
embedding_dim = 1280
# embedding_dim = 1024
large_tables = [
torchrec.EmbeddingBagConfig(
name="large_table_" + str(i),
embedding_dim=embedding_dim,
num_embeddings=40960,
feature_names=["large_table_feature_" + str(i)],
pooling=torchrec.PoolingType.SUM,
)
for i in range(large_table_cnt)
]
small_tables = [
torchrec.EmbeddingBagConfig(
name="small_table_" + str(i),
embedding_dim=embedding_dim,
num_embeddings=128,
feature_names=["small_table_feature_" + str(i)],
pooling=torchrec.PoolingType.SUM,
)
for i in range(small_table_cnt)
]
def gen_constraints(
sharding_type: ShardingType = ShardingType.DATA_PARALLEL,
) -> Dict[str, ParameterConstraints]:
large_table_constraints = {
"large_table_" + str(i): ParameterConstraints(
sharding_types=[sharding_type.value],
)
for i in range(large_table_cnt)
}
small_table_constraints = {
"small_table_" + str(i): ParameterConstraints(
sharding_types=[sharding_type.value],
)
for i in range(small_table_cnt)
}
constraints = {**large_table_constraints, **small_table_constraints}
return constraints
# Wrapped so that torch.fx symbolic tracing treats this debug print helper as a leaf call.
@torch.fx.wrap
def _print_emb(feat: torch.Tensor, emb: torch.Tensor) -> None:
for i in range(emb.shape[0]):
print(f"row_{i}:", "feat:", feat[i], "emb:", emb[i], flush=True)
class DebugModel(nn.Module):
def __init__(self):
super().__init__()
self.ebc = EmbeddingBagCollection(tables=large_tables + small_tables, device="meta")
self.linear = nn.Linear(embedding_dim * (small_table_cnt + large_table_cnt), 1)
def forward(self, kjt: KeyedJaggedTensor):
emb = self.ebc(kjt)
_print_emb(kjt.to_dict()["large_table_feature_0"].values(), emb.to_dict()["large_table_feature_0"])
return torch.mean(self.linear(emb.values()))
rank = int(os.environ["RANK"])
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
backend = "nccl"
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
backend = "gloo"
dist.init_process_group(backend=backend)
world_size = dist.get_world_size()
print("world_size:", world_size)
model = DebugModel()
# Run Adagrad on the embedding parameters during the backward pass (in-backward / fused optimizer).
apply_optimizer_in_backward(optimizers.Adagrad, model.ebc.parameters(), {"lr": 0.001})
topology = Topology(world_size=world_size, compute_device=device.type)
constraints = gen_constraints(ShardingType.TABLE_WISE)
planner = EmbeddingShardingPlanner(
topology=topology,
constraints=constraints,
)
sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
plan = planner.collective_plan(model, sharders, dist.GroupMember.WORLD)
sharded_model = DistributedModelParallel(
model,
plan=plan,
sharders=sharders,
device=device,
)
# The dense parameters (the linear layer) get a regular Adam optimizer; parameters whose
# optimizer already runs in the backward pass are filtered out.
dense_optimizer = KeyedOptimizerWrapper(
dict(in_backward_optimizer_filter(sharded_model.named_parameters())),
lambda params: torch.optim.Adam(params, lr=0.001),
)
optimizer = CombinedOptimizer([sharded_model.fused_optimizer, dense_optimizer])
print(f"rank:{rank},sharding plan: {plan}")
batch_size = 4096
# Build a KJT where each of the four features gets batch_size rows with exactly
# one index per row, sampled from [0, 10).
kjt = KeyedJaggedTensor(
keys=["large_table_feature_" + str(i) for i in range(large_table_cnt)]
+ ["small_table_feature_" + str(i) for i in range(small_table_cnt)],
values=torch.cat([
torch.randint(0, 10, (batch_size * 2,))
, torch.randint(0, 10, (batch_size * 2,))]
),
lengths=torch.ones(batch_size * (small_table_cnt + large_table_cnt), dtype=torch.int32),
).to(device=device)
losses = sharded_model.forward(kjt)
torch.sum(losses, dim=0).backward()
optimizer.step()
From the logs below, which show the correspondence between the printed feature values and the looked-up embeddings, it can be seen that once the sample row index exceeds 3263, zeros start appearing at the end of the embedding, and once the row index exceeds 3295, the embedding becomes entirely zero, which is incorrect. (A programmatic check for these zero rows is sketched after the log excerpt.)
param | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | -----
large_table_0 | table_wise | fused | [0]
large_table_1 | table_wise | fused | [0]
small_table_0 | table_wise | fused | [0]
small_table_1 | table_wise | fused | [0]
param | shard offsets | shard sizes | placement
------------- | ------------- | ------------- | -------------
large_table_0 | [0, 0] | [40960, 1280] | rank:0/cuda:0
large_table_1 | [0, 0] | [40960, 1280] | rank:0/cuda:0
small_table_0 | [0, 0] | [128, 1280] | rank:0/cuda:0
small_table_1 | [0, 0] | [128, 1280] | rank:0/cuda:0
row_0: feat: tensor(8, device='cuda:0') emb: tensor([-8.4739e-05, 1.7072e-04, 1.8372e-03, ..., 3.0118e-03,
4.7338e-04, 3.2164e-03], device='cuda:0', grad_fn=<SelectBackward0>)
row_1: feat: tensor(8, device='cuda:0') emb: tensor([-8.4739e-05, 1.7072e-04, 1.8372e-03, ..., 3.0118e-03,
4.7338e-04, 3.2164e-03], device='cuda:0', grad_fn=<SelectBackward0>)
row_2: feat: tensor(9, device='cuda:0') emb: tensor([0.0021, 0.0006, 0.0036, ..., 0.0030, 0.0031, 0.0031], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3: feat: tensor(0, device='cuda:0') emb: tensor([ 0.0007, -0.0018, 0.0049, ..., -0.0016, -0.0007, 0.0004],
device='cuda:0', grad_fn=<SelectBackward0>)
row_4: feat: tensor(3, device='cuda:0') emb: tensor([-0.0022, -0.0043, -0.0019, ..., 0.0009, -0.0046, -0.0012],
device='cuda:0', grad_fn=<SelectBackward0>)
row_5: feat: tensor(5, device='cuda:0') emb: tensor([-0.0012, 0.0030, 0.0022, ..., 0.0025, 0.0027, -0.0018],
device='cuda:0', grad_fn=<SelectBackward0>)
row_6: feat: tensor(8, device='cuda:0') emb: tensor([-8.4739e-05, 1.7072e-04, 1.8372e-03, ..., 3.0118e-03,
4.7338e-04, 3.2164e-03], device='cuda:0', grad_fn=<SelectBackward0>)
row_7: feat: tensor(9, device='cuda:0') emb: tensor([0.0021, 0.0006, 0.0036, ..., 0.0030, 0.0031, 0.0031], device='cuda:0',
grad_fn=<SelectBackward0>)
row_8: feat: tensor(9, device='cuda:0') emb: tensor([0.0021, 0.0006, 0.0036, ..., 0.0030, 0.0031, 0.0031], device='cuda:0',
grad_fn=<SelectBackward0>)
row_9: feat: tensor(7, device='cuda:0') emb: tensor([-1.4103e-03, -7.9132e-05, -1.3803e-03, ..., 1.9040e-03,
1.7904e-03, -2.7463e-03], device='cuda:0', grad_fn=<SelectBackward0>)
row_10: feat: tensor(6, device='cuda:0') emb: tensor([ 0.0026, 0.0001, -0.0009, ..., -0.0030, 0.0020, 0.0028],
device='cuda:0', grad_fn=<SelectBackward0>)
...
row_3263: feat: tensor(2, device='cuda:0') emb: tensor([-0.0023, -0.0007, -0.0030, ..., -0.0008, -0.0013, -0.0034],
device='cuda:0', grad_fn=<SelectBackward0>)
row_3264: feat: tensor(7, device='cuda:0') emb: tensor([-1.4103e-03, -7.9132e-05, -1.3803e-03, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00], device='cuda:0', grad_fn=<SelectBackward0>)
row_3265: feat: tensor(7, device='cuda:0') emb: tensor([-1.4103e-03, -7.9132e-05, -1.3803e-03, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00], device='cuda:0', grad_fn=<SelectBackward0>)
row_3266: feat: tensor(9, device='cuda:0') emb: tensor([0.0021, 0.0006, 0.0036, ..., 0.0000, 0.0000, 0.0000], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3267: feat: tensor(4, device='cuda:0') emb: tensor([ 0.0005, -0.0034, 0.0009, ..., 0.0000, 0.0000, 0.0000],
device='cuda:0', grad_fn=<SelectBackward0>)
row_3268: feat: tensor(3, device='cuda:0') emb: tensor([-0.0022, -0.0043, -0.0019, ..., 0.0000, 0.0000, 0.0000],
device='cuda:0', grad_fn=<SelectBackward0>)
row_3269: feat: tensor(4, device='cuda:0') emb: tensor([ 0.0005, -0.0034, 0.0009, ..., 0.0000, 0.0000, 0.0000],
device='cuda:0', grad_fn=<SelectBackward0>)
row_3270: feat: tensor(9, device='cuda:0') emb: tensor([0.0021, 0.0006, 0.0036, ..., 0.0000, 0.0000, 0.0000], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3271: feat: tensor(9, device='cuda:0') emb: tensor([0.0021, 0.0006, 0.0036, ..., 0.0000, 0.0000, 0.0000], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3272: feat: tensor(9, device='cuda:0') emb: tensor([0.0021, 0.0006, 0.0036, ..., 0.0000, 0.0000, 0.0000], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3273: feat: tensor(3, device='cuda:0') emb: tensor([-0.0022, -0.0043, -0.0019, ..., 0.0000, 0.0000, 0.0000],
device='cuda:0', grad_fn=<SelectBackward0>)
...
row_3295: feat: tensor(6, device='cuda:0') emb: tensor([ 0.0026, 0.0001, -0.0009, ..., 0.0000, 0.0000, 0.0000],
device='cuda:0', grad_fn=<SelectBackward0>)
row_3296: feat: tensor(6, device='cuda:0') emb: tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3297: feat: tensor(6, device='cuda:0') emb: tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3298: feat: tensor(7, device='cuda:0') emb: tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3299: feat: tensor(7, device='cuda:0') emb: tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3300: feat: tensor(7, device='cuda:0') emb: tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3301: feat: tensor(9, device='cuda:0') emb: tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3302: feat: tensor(1, device='cuda:0') emb: tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3303: feat: tensor(2, device='cuda:0') emb: tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3304: feat: tensor(4, device='cuda:0') emb: tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0',
grad_fn=<SelectBackward0>)
row_3305: feat: tensor(5, device='cuda:0') emb: tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0',
grad_fn=<SelectBackward0>)
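Instead of eyeballing the printed rows, the corruption can also be summarized programmatically. The sketch below is a hypothetical variant of the _print_emb helper in the script above (same @torch.fx.wrap pattern, called with the same emb.to_dict()["large_table_feature_0"] tensor); it reports the first row with trailing zeros and the first fully zero row. For the run logged above it should report 3264 and 3296, respectively.

import torch


@torch.fx.wrap
def _check_emb(emb: torch.Tensor) -> None:
    # Rows whose pooled embedding is entirely zero.
    fully_zero = (emb == 0).all(dim=1)
    # Rows whose embedding ends in an exact 0.0 but is not all zero; with randomly
    # initialized weights this should essentially never happen for a correct lookup.
    trailing_zero = (emb[:, -1] == 0) & ~fully_zero
    first_trailing = int(trailing_zero.nonzero().min()) if trailing_zero.any() else None
    first_all_zero = int(fully_zero.nonzero().min()) if fully_zero.any() else None
    print(
        f"rows={emb.shape[0]}",
        f"first_row_with_trailing_zeros={first_trailing}",
        f"first_fully_zero_row={first_all_zero}",
        f"fully_zero_rows={int(fully_zero.sum())}",
        flush=True,
    )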