Summary:
This special-cases the creation of Float8Tensor for when the input data and scale are DTensors. If that's the case, we want to maintain the DTensor invariant that it is the outermost tensor, i.e. DTensor wraps Float8Tensor.
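A minimal sketch of that invariant, assuming a hypothetical `to_float8` helper (the real construction path in float8_experimental differs in its details):

```python
# Hedged sketch only: `to_float8` and the Float8Tensor constructor shape here
# are illustrative, not the exact float8_experimental API.
import torch
from torch.distributed._tensor import DTensor
from float8_experimental.float8_tensor import Float8Tensor

def to_float8(data, scale, float8_dtype=torch.float8_e4m3fn):
    if isinstance(data, DTensor):
        # Special case: build the Float8Tensor from the local shards, then
        # re-wrap in DTensor so DTensor stays the outermost wrapper.
        assert isinstance(scale, DTensor), "data and scale should both be DTensors"
        local_fp8 = to_float8(data.to_local(), scale.to_local(), float8_dtype)
        return DTensor.from_local(local_fp8, data.device_mesh, data.placements)
    # Plain-tensor path: quantize and wrap in Float8Tensor.
    bits = (data * scale).to(float8_dtype)
    return Float8Tensor(bits, scale, data.dtype)  # constructor args are illustrative
```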
I add an example forward/backward (which doesn't make the most sense) and am currently getting this error:
## Updated
I realized that I wasn't calling the float8 constructor in the backward that is able to handle DTensor, which led to the error below.
I have remedied that, and I was able to get E2E working (numerical correctness not yet tested). This did require that the mm_op output be manually waited on in matmul; to me this feels like a bug in AsyncCollectiveTensor, but I need to track it down.
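A hedged sketch of the manual wait (the `_maybe_wait` helper name is mine, not repo code):

```python
# Sketch: force a pending functional collective to resolve before the mm.
# Assumes the tensor may be an AsyncCollectiveTensor from
# torch.distributed._functional_collectives.
import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor

def _maybe_wait(t: torch.Tensor) -> torch.Tensor:
    # If the collective hasn't been waited on yet, wait so downstream ops
    # (here, the scaled mm) see a materialized plain tensor.
    if isinstance(t, AsyncCollectiveTensor):
        return t.wait()
    return t
```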
##### Old error
```
[rank1]: File "/home/drisspg/meta/float8_experimental/float8_experimental/float8_tensor.py", line 221, in __torch_dispatch__
[rank1]: return FLOAT8_OPS_TABLE[func](func, args, kwargs)
[rank1]: File "/home/drisspg/meta/float8_experimental/float8_experimental/float8_ops.py", line 81, in float8_mm
[rank1]: assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor), (
[rank1]: AssertionError: Expecting both Float8Tensor for mm inputs but found <class 'float8_experimental.float8_tensor.Float8Tensor'> and <class 'torch.distributed._tensor.api.DTensor'>
E0222 16:10:18.512000 140692135212864 torch/distri
```
##### Why the error?
1. The output of the scaled_mm is a regular DTensor(torch.Tensor): the op converts two DTensor(Float8Tensor) inputs into a DTensor(torch.Tensor) output.
2. We send this output through the NoopForward, which does nothing in the forward and converts the grad to a Float8Tensor of the e5m2 dtype in the backward.
3. This creation of the Float8Tensor from the backpropagating DTensor hits the special logic in the Float8Tensor construction that makes sure DTensor re-wraps the Float8Tensor (see the sketch below).
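Putting steps 2 and 3 together, a toy sketch of the NoopForward pattern whose backward routes through the DTensor-aware constructor (the class body and the scale computation are illustrative, not the repo's scaling strategy):

```python
# Toy sketch, reusing the hypothetical DTensor-aware `to_float8` from above.
import torch

class NoopForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x  # no-op in the forward

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is a DTensor here (the mm output in step 1 was a DTensor),
        # so this construction hits the special case and returns
        # DTensor(Float8Tensor) rather than Float8Tensor(DTensor).
        grad_scale = torch.finfo(torch.float8_e5m2).max / grad_output.abs().amax()
        return to_float8(grad_output, grad_scale, torch.float8_e5m2)
```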
Pull Request resolved: #224
Reviewed By: bdhirsh
Differential Revision: D54204167
Pulled By: drisspg
fbshipit-source-id: 9c3c3ccb3cae8b90f5ab5c61fc0e7b96d89176d3