Ensure DTensor wraps inner float8 Tensors #224
Conversation
…tance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
[rank1]: AssertionError
float8_experimental/float8_ops.py
Outdated
a = args[0]
b = args[1]
if isinstance(b, AsyncCollectiveTensor):
not sure how this is creeping in here
this is a bit surprising, so this change fixed the AsyncCollectiveTensor issue?
Yeah, exactly. I'm not totally sure yet how this collective made it into the dispatch for Float8Tensor.
yeah tbh this is the only part I don't love; I am having a hard time figuring out where this is creeping in..
This dispatch should be DTensor(Float8Tensor()) -> aten.mm.default
.... okay, as I am writing this I think I might know.
I think we have
DTensor(Float8Tensor()), DTensor(AsyncCollectiveTensor(Float8Tensor())) -> aten.mm.default
The first layer of torch dispatch goes through DTensor's table for both. Then we have the unwrap
Float8Tensor(), AsyncCollectiveTensor(Float8Tensor()) -> aten.mm.default
and the arg parser chooses the leftmost subclass to dispatch to, and thus we need this check. I wonder if we can either handle this in DTensor or if there is another way?
I think this is because we iterate over args in order: https://github.com/pytorch/pytorch/blob/4328e772bf4d1b5e697a30d893d3b2d2e6a153c7/torch/csrc/utils/python_arg_parser.cpp#L273 (this is the torch_function path though, and I am not sure if it applies to dispatch)
cc @bdhirsh if this makes sense as well
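To illustrate the ordering, here is a small self-contained example (not float8 code; the Left/Right classes are made up) showing that __torch_function__ asks the leftmost overloaded argument's subclass first. This is the path the linked python_arg_parser.cpp code implements; __torch_dispatch__ resolves handlers with a similar subclass-priority rule.

import torch

class Left(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Asked first because the Left instance is the leftmost overloaded arg,
        # even though a Right instance is also present in args/types.
        print("Left asked first; types =", types)
        return "handled by Left"

class Right(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        print("Right asked")
        return "handled by Right"

a = torch.randn(2, 2).as_subclass(Left)
b = torch.randn(2, 2).as_subclass(Right)
print(torch.mm(a, b))  # prints "Left asked first; ..." then "handled by Left"; Right is never asked

This mirrors the Float8Tensor(), AsyncCollectiveTensor(Float8Tensor()) case: Float8Tensor's handler is invoked first and has to cope with the AsyncCollectiveTensor showing up in its arguments.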
If your Float8Tensor class expects any other subclass to run before you, you can add this to your class:

if not all(issubclass(cls, t) for t in types):
    return NotImplemented

The NotImplemented will trigger the handler of the next type (just like it does for __add__ in regular python).
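For concreteness, a minimal sketch of where that check could live, assuming Float8Tensor overrides __torch_dispatch__ (the class body below is a placeholder, not the actual float8_experimental implementation):

import torch

class Float8Tensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # If any overloaded arg is an unrelated subclass (e.g. DTensor or
        # AsyncCollectiveTensor), step aside: returning NotImplemented hands
        # the call to that subclass's handler, mirroring Python's __add__ protocol.
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented
        ...  # normal Float8Tensor handling of func (elided in this sketch)

In the mixed case above, the expectation is that AsyncCollectiveTensor's handler then runs (and waits) first, so Float8Tensor only ever sees plain Float8Tensor arguments.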
Thanks Alban, this makes a lot of sense. Just updated and everything appears to be working as expected!
This looks pretty decent to me! I have only one question about AsyncCollectiveTensor (ACT) and some minor suggestions.
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary
This special-cases the creation of Float8Tensor for when the input data and scale are DTensors. In that case we want to maintain the DTensor invariant that it is the outermost tensor, i.e. that DTensor wraps the Float8Tensor rather than the other way around.
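As an illustration of that invariant, a minimal sketch of how the special case could look; the helper name wrap_float8_maybe_dtensor and the make_float8 callable are hypothetical stand-ins for the real Float8Tensor constructor, not the code in this PR:

from torch.distributed._tensor import DTensor

def wrap_float8_maybe_dtensor(bits, scale, orig_dtype, make_float8):
    # If the quantized data (and its scale) are DTensors, build the Float8Tensor
    # from the local shards and re-wrap it with DTensor.from_local, so that
    # DTensor stays the outermost subclass.
    if isinstance(bits, DTensor):
        local_scale = scale.to_local() if isinstance(scale, DTensor) else scale
        local_fp8 = make_float8(bits.to_local(), local_scale, orig_dtype)
        return DTensor.from_local(local_fp8, bits.device_mesh, bits.placements)
    return make_float8(bits, scale, orig_dtype)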
I added an example forward/backward, which doesn't make the most sense, and am currently getting this error:
Updated
I realized that in the backward I wasn't calling the Float8 constructor that can handle DTensor, which led to the error below.
I have remedied that and was able to get E2E working (numerical correctness not yet tested). This did require that the mm_op result be manually waited on in matmul; to me this feels like a bug in AsyncCollectiveTensor, but I need to track it down.
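For illustration, the manual wait amounts to something like the following (a sketch, not the exact change in this PR; it assumes AsyncCollectiveTensor.wait() from torch.distributed._functional_collectives blocks on the pending collective and returns the concrete tensor):

from torch.distributed._functional_collectives import AsyncCollectiveTensor

def _maybe_wait(t):
    # If a matmul input is still an in-flight collective result, block until the
    # collective finishes so downstream code sees the real (Float8) tensor.
    if isinstance(t, AsyncCollectiveTensor):
        return t.wait()
    return t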
Old error
Why the Error?