This repository was archived by the owner on Aug 7, 2024. It is now read-only.
Summary:
In certain cases, activations and gradients can have a bounded range.
For example, consider sigmoid -> fc -> ln -> sigmoid:
1. the range of sigmoid in the forward pass is bounded, so we can scale
statically if we are ok with a slight accuracy drop in the case where
the observed values do not reach the theoretical bound
2. the range of the derivative of sigmoid is bounded
(https://math.stackexchange.com/questions/78575/derivative-of-sigmoid-function-sigma-x-frac11e-x)
3. the derivative of LN (https://liorsinai.github.io/mathematics/2022/05/18/layernorm.html)
depends on the incoming gradient and the trainable LN parameters, so we can
derive a bound from the incoming bound and the max of the LN parameters
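To make the idea concrete, here is a sketch of deriving a static scale from a theoretical amax bound (illustrative only, not this PR's implementation; `static_scale_from_bound` is a hypothetical helper):

```python
# Theoretical amax bounds for the sigmoid -> fc -> ln -> sigmoid example:
# forward: sigmoid(x) is in (0, 1), so amax <= 1.0
# backward: sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) <= 0.25
SIGMOID_FWD_BOUND = 1.0
SIGMOID_BWD_BOUND = 0.25

# max representable value of the float8 e4m3 format
E4M3_MAX = 448.0

def static_scale_from_bound(amax_bound: float, fp8_max: float = E4M3_MAX) -> float:
    # The scale maps the theoretical amax onto the fp8 range. If the observed
    # values never reach the bound, part of the range goes unused, which is
    # the slight accuracy drop mentioned above.
    return fp8_max / amax_bound

x_scale = static_scale_from_bound(SIGMOID_FWD_BOUND)      # 448.0
dL_dY_scale = static_scale_from_bound(SIGMOID_BWD_BOUND)  # 1792.0
```

Unlike dynamic or delayed scaling, these scales never need to be recomputed at runtime, which is where the overhead savings in the benchmark below come from.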
This PR adds static scaling as an option for x, w, and dL_dY, along with a
quick benchmark to verify that performance is as expected.
TODO: add numerics testing.
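For the numerics TODO, one possible shape of a test is a quantize/dequantize round trip under the static scale. The sketch below uses a crude plain-Python stand-in for e4m3 rounding (3 mantissa bits, max 448; `quantize_sim` and `roundtrip_rel_error` are hypothetical names, not this PR's API):

```python
import math

E4M3_MAX = 448.0

def quantize_sim(x: float, mantissa_bits: int = 3, max_val: float = E4M3_MAX) -> float:
    """Crude e4m3-style rounding: clamp, then keep `mantissa_bits` fraction
    bits of the normalized significand (normal numbers only)."""
    if x == 0.0:
        return 0.0
    x = max(min(x, max_val), -max_val)
    m, e = math.frexp(x)              # x = m * 2**e, with 0.5 <= |m| < 1
    steps = 2 ** (mantissa_bits + 1)  # grid resolution for m in [0.5, 1)
    return round(m * steps) / steps * 2 ** e

def roundtrip_rel_error(x: float, scale: float) -> float:
    # quantize with a static scale, dequantize, measure relative error
    deq = quantize_sim(x * scale) / scale
    return abs(deq - x) / abs(x)

# sigmoid outputs live in (0, 1); the static scale maps the bound 1.0 to E4M3_MAX
static_scale = E4M3_MAX / 1.0
for x in (0.01, 0.25, 0.6, 0.999):
    # 3 mantissa bits -> worst-case relative rounding error of 2**-4
    assert roundtrip_rel_error(x, static_scale) <= 2 ** -4
```

A real test would use the float8 dtypes directly and also compare static against dynamic scaling on distributions that sit well below the theoretical bound.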
Test Plan:
```
// baseline
python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type sigmoid_linear_ln_sigmoid
...
experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
category
0_gemm         0.160     0.098       0.613       1.632
1_f8_overhead  0.000     0.100         inf       0.000
2_other        0.147     0.121       0.823       1.215
All            0.307     0.319       1.040       0.962
// static scaling for x (easier to justify numerics given a bounded activation such as sigmoid)
python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type sigmoid_linear_ln_sigmoid --scaling_type_x static
experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
category
0_gemm         0.665     0.362       0.545       1.834
1_f8_overhead  0.000     0.269         inf       0.000
2_other        0.396     0.273       0.689       1.452
All            1.061     0.904       0.853       1.173
// static scaling for x and dL_dY (handwaving for now, the actual code would
// need to read the LN params to get the max)
python benchmarks/profile_linear_float8.py ~/local/tmp/test --model_type sigmoid_linear_ln_sigmoid --scaling_type_x static --scaling_type_dL_dY static
...
experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
category
0_gemm         0.665     0.365       0.549       1.822
1_f8_overhead  0.000     0.242         inf       0.000
2_other        0.395     0.273       0.690       1.448
All            1.060     0.879       0.830       1.205
```
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: b8fd4e9
Pull Request resolved: #306