-
Notifications
You must be signed in to change notification settings - Fork 19
Add a compiled inner func to sync amac #221
Conversation
When I log graph breaks I get:
|
Actually perf looks fine |
test/test_compile.py
Outdated
float8_mod = swap_linear_with_float8_linear(module, Float8DynamicLinear) | ||
compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts) | ||
compiled_swap_func(float8_mod) | ||
assert cnts.frame_count == 0, "Compiled graph should have 1 frame!" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not intuitive to me. I was curious why the frame count is not 1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not intuitive to me either and actually this test is not really doing what I thought it was doing, I talked to Brian yesterday. What I think I think is happening is that the outer compile which I am checking sees the inner compiled region and just backtracks out not compiling anything. So this test doesn't really do anything..
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx]) | ||
# This allows for the compile to succede on the inner func and fail on the graph breaks | ||
# at the beginning and and of syncing | ||
inner_func() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a little hard to tell but I think this is all code motion? (moving it inside of inner_func())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah tabs strike again here, but just indented one layer
# 4. set a flag to signal amaxes/scales are ready | ||
# We only update the flag if we know it will be checked by the modules | ||
if fp8_config.enable_amax_init: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this the thing where you mentioned that this config is actually needed for correctness
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was andrews original comment here: #220 (comment)
I used the wrong config setting
f"All layers must have the same last seen input_dtype, got {x_dtypes}" | ||
) | ||
x_dtype = next(iter(x_dtypes)) | ||
def inner_func(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe add a comment on why we want the inner fn? (we have graph breaks that we want to fix eventually, but in the meantime we want to ensure that everything in inner_func is compiled)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stamp
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
Summary
See this comment: #220 (comment)
I still need to verify that this works, I also want to add a test for when we compile the outer func.. not sure what it happends
cc @awgu, @bdhirsh