Removed `module` arg from `fsdp_pre_all_gather` #217

Conversation
    @@ -232,7 +232,7 @@ def cast_w_to_float8(
                self.fp8_scale_w,
                scale_fn_name,
                torch.float8_e4m3fn,
    -           is_amax_initialized,
    +           self.is_amax_initialized,
We have to push this `is_amax_initialized` into the cast function since the `Float8LinearWeightTensor` subclass cannot keep a reference to the bool.
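A rough sketch of the shape of this change, under assumptions: `_cast_to_fp8` and the `Float8Linear` attributes below are illustrative stand-ins, not the library's actual API. The point is that the flag lives on the module and is passed into the cast explicitly, since the weight tensor subclass cannot carry it.

```python
import torch
import torch.nn as nn

def _cast_to_fp8(x, scale, float8_dtype, is_amax_initialized):
    # Hypothetical cast helper: the flag arrives as an explicit argument
    # instead of being captured from surrounding state.
    if not is_amax_initialized:
        scale = torch.finfo(float8_dtype).max / x.abs().max().clamp(min=1e-12)
    return (x * scale).to(float8_dtype)

class Float8Linear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        self.is_amax_initialized = False
        self.register_buffer("fp8_scale_w", torch.tensor(1.0))

    def cast_w_to_float8(self):
        # The flag is read off the module and threaded into the cast call,
        # mirroring the `self.is_amax_initialized` change above.
        # (Autograd handling is omitted in this sketch.)
        return _cast_to_fp8(
            self.weight.detach(),
            self.fp8_scale_w,
            torch.float8_e4m3fn,
            self.is_amax_initialized,
        )
```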
                return o

            with torch._C.DisableTorchFunctionSubclass():
                if isinstance(args[0], cls):
The heuristic here is just to propagate the `Float8LinearWeightTensor` as long as it is the first argument.
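A minimal sketch of that propagation heuristic, assuming a plain `__torch_function__`-based `torch.Tensor` subclass (the class name matches the discussion, but the body is illustrative, not the exact PR code):

```python
import torch

class Float8LinearWeightTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Run the op on plain tensors to avoid recursing back into this override.
        with torch._C.DisableTorchFunctionSubclass():
            out = func(*args, **kwargs)
            # Heuristic: keep the subclass type only when the subclass is the
            # first positional argument (e.g. w.t(), w.to(dtype)).
            if args and isinstance(args[0], cls) and isinstance(out, torch.Tensor):
                out = out.as_subclass(cls)
            return out

# Usage sketch: view/cast ops on the weight keep the subclass type.
w = torch.randn(4, 4).as_subclass(Float8LinearWeightTensor)
assert isinstance(w.t(), Float8LinearWeightTensor)
```

With this heuristic, ops where the weight is the first argument stay wrapped in the subclass, while everything else falls back to plain tensors.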
    @@ -142,12 +147,34 @@ def from_float(


    class Float8DynamicLinearWeightTensor(torch.Tensor):
        # TODO: Remove `module` arg, save state on subclass, and propagate it.
        def fsdp_pre_all_gather(
            self, module: nn.Module
The problem statement is that in order to implement the pre-all-gather transform, the subclass needs some state (mainly `emulate`, but preferably the cast function as well). In the first prototype, I took a shortcut by passing `module` into `fsdp_pre_all_gather()` so that the transform could read state off the module instead of storing it on the subclass itself. However, the morally right thing (from a design perspective) is to put that state on the subclass. I wonder, though, how this kind of `__torch_function__` subclass interacts with `torch.compile` today. cc: @bdhirsh
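A minimal sketch of the "state on the subclass" direction, under assumptions: the simplified dynamic cast, the `emulate` attribute, and the exact signature/return contract of `fsdp_pre_all_gather` are illustrative here, not the library's or FSDP's actual extension API.

```python
import torch

def _dynamic_cast_to_fp8(t: torch.Tensor) -> torch.Tensor:
    # Simplified dynamic cast: scale by amax, then convert to float8_e4m3fn.
    amax = t.abs().max().clamp(min=1e-12)
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    return (t * scale).to(torch.float8_e4m3fn)

class Float8DynamicLinearWeightTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor, emulate: bool = False):
        return data.as_subclass(cls)

    def __init__(self, data: torch.Tensor, emulate: bool = False):
        # State the pre-all-gather transform needs lives on the subclass
        # itself rather than being read off the owning nn.Module.
        self.emulate = emulate

    def fsdp_pre_all_gather(self):
        # No `module` arg: everything the cast needs is on `self`.
        # (The real FSDP extension point may also expect metadata for the
        # post-all-gather step; this sketch only shows where the state lives.)
        return _dynamic_cast_to_fp8(self.detach())
```

Keeping `emulate` (and eventually the cast configuration) on the tensor lets the FSDP hook stay module-agnostic, which is the design direction argued for above.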
Stack from ghstack (oldest at bottom):

- Removed `module` arg from `fsdp_pre_all_gather` #217
- `use_activation_hooks: bool` to swap #214
- `amax_and_scale_synced` unconditionally #220

On the discussion of whether fp8 all-gather is viable _without_ compiling the pre-all-gather cast to fp8, this PR would add _more_ CPU overhead due to the `__torch_function__` override, making it less viable.