This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit cdb7867

weifengpy authored and facebook-github-bot committed
set vocab_size=32 to avoid must be divisible by 16 error (#265)
Summary: `pytest -s test/test_fsdp2/test_fsdp2_eager.py -k test_transformer_parity_dynamic`

```
E   File "/home/weif/local/pytorch-official/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 205, in forward
E     output = self.output(h).float()
E   File "/home/weif/local/pytorch-official/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
E     return self._call_impl(*args, **kwargs)
E   File "/home/weif/local/pytorch-official/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
E     return forward_call(*args, **kwargs)
E   File "/data/users/weif/float8_experimental/float8_experimental/float8_dynamic_linear.py", line 71, in forward
E     y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
E   File "/data/users/weif/float8_experimental/float8_experimental/float8_tensor.py", line 297, in __torch_dispatch__
E     return FLOAT8_OPS_TABLE[func](func, args, kwargs)
E   File "/data/users/weif/float8_experimental/float8_experimental/float8_ops.py", line 151, in float8_mm
E     tensor_out, amax = addmm_float8_unwrapped(
E   File "/data/users/weif/float8_experimental/float8_experimental/float8_python_api.py", line 55, in addmm_float8_unwrapped
E     output, output_amax = torch._scaled_mm(
E RuntimeError: mat2 shape (768x8 must be divisible by 16
E Exception raised from _scaled_mm_out_cuda at /data/users/weif/pytorch-official/pytorch/aten/src/ATen/native/cuda/Blas.cpp:874 (most recent call first):
```

Pull Request resolved: #265

Reviewed By: drisspg, awgu

Differential Revision: D57596582

Pulled By: weifengpy

fbshipit-source-id: 8a00601457c4e72271adbba29dd2af8273173aa3
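
For context, the check that fails here comes from `torch._scaled_mm` (see `_scaled_mm_out_cuda` in the traceback), which requires the fp8 matmul operand dimensions to be multiples of 16. The transformer's output projection weight has shape `(vocab_size, dim)`, so with the previous vocabulary size the matmul sees `mat2` of shape 768x8 (per the error above) and trips the check; passing `vocab_size=32` makes that trailing dimension a multiple of 16. Below is a minimal sketch of the divisibility rule, not code from this repo or from PyTorch; the helper name and example shapes are illustrative only.

```
def check_fp8_mat2_dims(k: int, n: int, multiple: int = 16) -> None:
    """Illustrative helper (not the real API): mimic the shape check behind the
    error above, where the dims of mat2 (k x n) must be divisible by 16."""
    if k % multiple != 0 or n % multiple != 0:
        raise RuntimeError(f"mat2 shape ({k}x{n} must be divisible by {multiple}")

try:
    # vocab_size=8 -> output weight (8, 768) -> mat2 is 768x8, failing like the traceback
    check_fp8_mat2_dims(k=768, n=8)
except RuntimeError as e:
    print(e)  # mat2 shape (768x8 must be divisible by 16

# vocab_size=32 -> mat2 is 768x32, which satisfies the constraint
check_fp8_mat2_dims(k=768, n=32)
```
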
1 parent f7a920d commit cdb7867

File tree

1 file changed: +6 -1 lines changed


test/test_fsdp2/test_fsdp2_eager.py

Lines changed: 6 additions & 1 deletion
@@ -57,7 +57,12 @@ def init_multi_module(self) -> nn.Module:
     def init_transformer(self, weight_tying: bool) -> nn.Module:
         torch.manual_seed(42)
         args = ModelArgs(
-            n_layers=3, dim=768, n_heads=12, dropout_p=0.0, weight_tying=weight_tying
+            n_layers=3,
+            dim=768,
+            n_heads=12,
+            dropout_p=0.0,
+            weight_tying=weight_tying,
+            vocab_size=32,
         )
         module = Transformer(args).cuda()
         self.broadcast_module(module)
