This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Ensure DTensor wraps inner float8 Tensors #224

Closed
wants to merge 8 commits

Conversation

@drisspg (Contributor) commented Feb 22, 2024

Summary

This special-cases the creation of Float8Tensor for when the input data and scale are DTensors. If that's the case, we want to maintain the DTensor invariant that it is the outermost tensor, wrapping the Float8Tensor.
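
A minimal sketch of the layering this enforces, assuming a simplified constructor (the make_float8 helper and the Float8Tensor argument list are illustrative, not the library's actual API):

    import torch
    from torch.distributed._tensor import DTensor

    from float8_experimental.float8_tensor import Float8Tensor


    def make_float8(data, scale, float8_dtype=torch.float8_e4m3fn):
        # Special case: if data/scale are DTensors, build the Float8Tensor from
        # their local shards and re-wrap in DTensor so DTensor stays outermost,
        # i.e. DTensor(Float8Tensor(...)) rather than Float8Tensor(DTensor(...)).
        if isinstance(data, DTensor):
            assert isinstance(scale, DTensor), "expected data and scale to both be DTensors"
            local_fp8 = Float8Tensor(data.to_local(), scale.to_local(), float8_dtype)  # illustrative args
            return DTensor.from_local(local_fp8, data.device_mesh, data.placements)
        return Float8Tensor(data, scale, float8_dtype)  # illustrative args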

I added an example forward/backward (which doesn't make the most sense) and was initially getting the error below.

Updated

I realized that I wasn't calling the Float8Tensor constructor that can handle DTensor in the backward, which led to the error below.

I have remedied that and was able to get E2E working (numerical correctness not yet tested). This did require that the mm_op be manually waited on in matmul; to me this feels like a bug in AsyncCollectiveTensor, but I need to track it down.

Old error
[rank1]:   File "/home/drisspg/meta/float8_experimental/float8_experimental/float8_tensor.py", line 221, in __torch_dispatch__
[rank1]:     return FLOAT8_OPS_TABLE[func](func, args, kwargs)
[rank1]:   File "/home/drisspg/meta/float8_experimental/float8_experimental/float8_ops.py", line 81, in float8_mm
[rank1]:     assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor), (
[rank1]: AssertionError: Expecting  both Float8Tensor for mm inputs but found <class 'float8_experimental.float8_tensor.Float8Tensor'> and <class 'torch.distributed._tensor.api.DTensor'>
E0222 16:10:18.512000 140692135212864 torch/distri

Why the Error?

  1. The output of the scaled_mm is a regular DTensor (wrapping a plain torch.Tensor); i.e. two DTensor(Float8Tensor) inputs are converted to a DTensor(torch.Tensor) output.
  2. We send this output through the NoopForward, which does nothing in the forward and converts the grad to a Float8Tensor of the e5m2 dtype (a rough sketch follows this list).
  3. This creation of the Float8Tensor from the backpropping DTensor hits the special logic in the Float8Tensor construction that makes sure DTensor re-wraps the Float8Tensor.
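
A rough sketch of steps 2 and 3, assuming an illustrative casting helper (the real autograd function and helper in float8_experimental differ):

    import torch


    class NoopForward(torch.autograd.Function):
        # Step 2: identity in the forward pass; cast the incoming grad to e5m2 in the backward.
        @staticmethod
        def forward(ctx, tensor):
            return tensor

        @staticmethod
        def backward(ctx, grad_output):
            # Step 3: grad_output arrives here as a DTensor, so this Float8Tensor
            # construction must hit the special-case logic and keep DTensor outermost.
            return to_float8(grad_output, float8_dtype=torch.float8_e5m2)  # hypothetical helper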

@drisspg drisspg requested a review from wanchaol February 22, 2024 02:18
@facebook-github-bot added the CLA Signed label Feb 22, 2024
@drisspg drisspg mentioned this pull request Feb 22, 2024
@drisspg drisspg changed the title Insure DTensor wraps inner float8 Tensors Ensure DTensor wraps inner float8 Tensors Feb 22, 2024
@drisspg drisspg force-pushed the special_handiling_of_dtensor_input branch from e7c2c95 to bee5211 Compare February 23, 2024 00:14
a = args[0]
b = args[1]
# an AsyncCollectiveTensor can unexpectedly show up here; wait on it before the float8 mm
if isinstance(b, AsyncCollectiveTensor):
drisspg (Contributor, Author) commented:

not sure how this is creeping in here

Contributor commented:

this is a bit surprising, so this change fixed the AsyncCollectiveTensor issue?

drisspg (Contributor, Author) commented:

Yeah, exactly. I'm not totally sure yet how this collective made it into the dispatch for Float8Tensor.

@drisspg (Contributor, Author) commented Feb 24, 2024

Yeah, tbh this is the only part I don't love; I am having a hard time figuring out where this is creeping in.

This dispatch should be DTensor(Float8Tensor()) -> aten.mm.default
... okay, as I am writing this I think I might know now.

I think we have

DTensor(Float8Tensor()), DTensor(AsyncCollectiveTensor(Float8Tensor())) -> aten.mm.default

The first layer of __torch_dispatch__ goes through DTensor's table for both. Then, after unwrapping, we have

Float8Tensor(), AsyncCollectiveTensor(Float8Tensor()) -> aten.mm.default

And the arg parser chooses the leftmost subclass to dispatch to, and thus we need this. I wonder if we can either handle this in DTensor or if there is another way?

I think this is because we iterate over args in order: https://github.com/pytorch/pytorch/blob/4328e772bf4d1b5e697a30d893d3b2d2e6a153c7/torch/csrc/utils/python_arg_parser.cpp#L273 (this is the torch_function pointer though, and I am not sure if it applies to dispatch).

cc @bdhirsh if this makes sense as well
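
For reference, a minimal sketch of the workaround discussed here: unwrap any AsyncCollectiveTensor argument before the float8 mm handler asserts both inputs are Float8Tensors (handler body abbreviated; not the PR's exact diff):

    from torch.distributed._functional_collectives import AsyncCollectiveTensor

    from float8_experimental.float8_tensor import Float8Tensor


    def _unwrap_async(t):
        # If a pending collective result leaked through DTensor's dispatch,
        # wait on it so the plain Float8Tensor underneath is exposed.
        if isinstance(t, AsyncCollectiveTensor):
            return t.wait()
        return t


    def float8_mm(aten_op, args, kwargs=None):
        a = _unwrap_async(args[0])
        b = _unwrap_async(args[1])
        assert isinstance(a, Float8Tensor) and isinstance(b, Float8Tensor), (
            "Expecting both Float8Tensor for mm inputs"
        )
        ...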

Alban commented:

If your Float8Tensor class expects any other subclass to run before you, you can add this to your class:

        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

Returning NotImplemented will trigger the handler of the next type (just like it does for __add__ in regular Python).
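
In context, that check would sit at the top of Float8Tensor's __torch_dispatch__, just before the FLOAT8_OPS_TABLE lookup seen in the traceback above (a sketch, not the PR's exact diff):

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # If the call involves a type Float8Tensor does not inherit from
        # (e.g. AsyncCollectiveTensor), return NotImplemented so that type's
        # handler gets dispatched to first, mirroring Python's NotImplemented
        # protocol for binary ops like __add__.
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented
        return FLOAT8_OPS_TABLE[func](func, args, kwargs)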

drisspg (Contributor, Author) commented:

Thanks Alban, this makes a lot of sense. Just updated, and everything appears to be working as expected!

@drisspg drisspg force-pushed the special_handiling_of_dtensor_input branch from 1a57df3 to 77bc503 Compare February 23, 2024 23:22
@drisspg drisspg force-pushed the special_handiling_of_dtensor_input branch from 77bc503 to 48bd03a Compare February 23, 2024 23:23
@wanchaol wanchaol (Contributor) left a comment

This looks pretty decent to me! I have only one question about ACT (AsyncCollectiveTensor) and some minor suggestions.

@drisspg drisspg marked this pull request as ready for review February 24, 2024 00:23
@drisspg drisspg force-pushed the special_handiling_of_dtensor_input branch from 5f52303 to 93c8ac6 Compare February 26, 2024 17:35
@facebook-github-bot (Contributor) commented:
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented:
@drisspg merged this pull request in b67e5cf.
