Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Add a compiled inner func to sync amac #221

Closed
wants to merge 5 commits into from
Closed

Add a compiled inner func to sync amac #221

wants to merge 5 commits into from

Conversation

drisspg
Copy link
Contributor

@drisspg drisspg commented Feb 16, 2024

Summary

See this comment: #220 (comment)

I still need to verify that this works, I also want to add a test for when we compile the outer func.. not sure what it happends
cc @awgu, @bdhirsh

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 16, 2024
@drisspg
Copy link
Contributor Author

drisspg commented Feb 16, 2024

When I log graph breaks I get:

[DEBUG]   File "/home/drisspg/miniconda3/envs/nightly/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 190, in unimplemented
[rank7]:[2024-02-16 10:08:07,178] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     raise Unsupported(msg)
[rank7]:[2024-02-16 10:08:07,178] [0/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] torch._dynamo.exc.Unsupported: call_function args: NestedUserFunctionVariable() ```

@drisspg
Copy link
Contributor Author

drisspg commented Feb 16, 2024

Actually perf looks fine

float8_mod = swap_linear_with_float8_linear(module, Float8DynamicLinear)
compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts)
compiled_swap_func(float8_mod)
assert cnts.frame_count == 0, "Compiled graph should have 1 frame!"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not intuitive to me. I was curious why the frame count is not 1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not intuitive to me either and actually this test is not really doing what I thought it was doing, I talked to Brian yesterday. What I think I think is happening is that the outer compile which I am checking sees the inner compiled region and just backtracks out not compiling anything. So this test doesn't really do anything..

@facebook-github-bot
Copy link
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])
# This allows for the compile to succede on the inner func and fail on the graph breaks
# at the beginning and and of syncing
inner_func()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a little hard to tell but I think this is all code motion? (moving it inside of inner_func())

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah tabs strike again here, but just indented one layer

# 4. set a flag to signal amaxes/scales are ready
# We only update the flag if we know it will be checked by the modules
if fp8_config.enable_amax_init:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this the thing where you mentioned that this config is actually needed for correctness

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was andrews original comment here: #220 (comment)

I used the wrong config setting

f"All layers must have the same last seen input_dtype, got {x_dtypes}"
)
x_dtype = next(iter(x_dtypes))
def inner_func():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add a comment on why we want the inner fn? (we have graph breaks that we want to fix eventually, but in the meantime we want to ensure that everything in inner_func is compiled)

Copy link
Contributor

@bdhirsh bdhirsh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stamp

@facebook-github-bot
Copy link
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@drisspg merged this pull request in b508920.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants