Enable restricted split + cat in order to enable SP #253
Conversation
```python
return list(out)  # Errors: can't `cat_cuda` with float8_e4m3fn
```
Oh, so this means that `torch.cat` can't be applied to dtype `e4m3fn`?

Normally I'd expect this is something we could just make our CUDA kernel support (concatenating tensors of the same dtype), but I'm not sure whether there are further complications for the fp8 dtype.

But if all we want is to concatenate the fp8 tensors together, a simpler approach: do `fp8_inner_tensor.view(torch.uint8)`, perform the `torch.cat`, then after the cat operation do `fp8_catted_tensor.view(torch.float8_e4m3fn)`. I wonder if this would unblock?
sgtm! thanks for supporting this!
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary
This comes from needing to support sequence parallelism in torchtitan.