feat(fp8): use fbgemm kernels and load fp8 weights directly #2248
Conversation
Force-pushed from 657071c to 7453b85.
@@ -302,6 +302,9 @@ def get_model(
     if quantize in ["awq", "exl2", "gptq", "marlin"]:
         # These quantizers only work with float16 params.
         dtype = torch.float16
+    elif quantize == "fp8":
+        # gemm kernels are fp8xfp8->bf16
+        dtype = torch.bfloat16
@danieldk is this compatible with the marlin kernels?
Supports both `float16` and `bfloat16`.

However, do we want to set the default to this, since most models are `float16`? Is it needed for the fbgemm quantization kernel?
Yes, it's required for fbgemm. I can add a method in `layers/fp8.py` to check whether we will use fbgemm and set the default appropriately.
Added a check.
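For reference, a minimal sketch of what such a check in `layers/fp8.py` could look like. The helper names, the `fbgemm_gpu.experimental.gen_ai` import, and the compute-capability threshold are assumptions for illustration, not necessarily what the PR actually added:

```python
# A rough sketch, not the PR's actual code: helper names and the
# device-capability threshold are assumptions.
import torch


def is_fbgemm_gpu_available() -> bool:
    # The fbgemm-gpu GenAI fp8 kernels need both the package and a recent
    # CUDA device; treat anything else as unavailable.
    try:
        import fbgemm_gpu.experimental.gen_ai  # noqa: F401
    except ImportError:
        return False
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def get_fp8_default_dtype() -> torch.dtype:
    # The fbgemm GEMM is fp8 x fp8 -> bf16, so default to bfloat16 when the
    # fbgemm path will be used; otherwise keep the usual float16 default.
    return torch.bfloat16 if is_fbgemm_gpu_available() else torch.float16
```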
            )

        if w.dtype == torch.float8_e4m3fn:
            # FIXME: here to avoid circular import
            from text_generation_server.layers.fp8 import Fp8Weight
Not happy about this.
Hmmm, this breaks the abstraction quite a bit. This class is also used by other (explicit) quantizers like eetq.

I think it would make more sense to put this implementation in something like a `HybridFP8FP16Loader` in the `fp8` module. Then we could add some logic to the `get_loader` function, along the lines of: when `quantizer==None` is set, enumerate over the tensors (should be cheap, I think, for just getting the dtypes?) and then, if an FP8 weight is encountered, return the hybrid loader.

That puts the implementation nicely with the fp8 code, and it wouldn't clutter `UnquantizedWeight` further if we e.g. also want to support bitsandbytes or eetq checkpoints in the future.
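A rough sketch of that dispatch, assuming safetensors checkpoints expose per-tensor dtypes via `get_slice(...).get_dtype()`; the loader class name follows the comment above, and `get_loader` is reduced to the relevant branch:

```python
# Illustrative only: loader names follow the review comment, not
# necessarily the merged implementation.
from safetensors import safe_open


def _checkpoint_has_fp8(filenames) -> bool:
    # Only the safetensors header is parsed here, so scanning the dtypes
    # is cheap even for large checkpoints.
    for filename in filenames:
        with safe_open(filename, framework="pt") as f:
            for name in f.keys():
                if f.get_slice(name).get_dtype() == "F8_E4M3":
                    return True
    return False


def get_loader(quantize, filenames):
    if quantize is None and _checkpoint_has_fp8(filenames):
        # Some weights are stored in fp8: load those directly and fall back
        # to the unquantized path for everything else.
        from text_generation_server.layers.fp8 import HybridFP8FP16Loader

        return HybridFP8FP16Loader()
    # ... existing dispatch for explicit quantizers and the default path ...
```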
Agreed, modified the code to add another loader.
        if self.weight_scale is None:
            return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
        return get_fp8_linear().from_fp8(
            self.weight, self.weight_scale, self.input_scale, bias, self.dtype
        )
Nice! Looks like a pattern we could reuse in the future for weights that are either pre-quantized or quantized on the fly.
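Spelled out, a condensed sketch of that pattern: a single weight container whose scale field decides, at linear-construction time, between on-the-fly quantization and the pre-quantized path. The dataclass shape is assumed for illustration; only the `get_linear` body mirrors the snippet above, and `get_fp8_linear` is the helper from `layers/fp8.py` referenced there:

```python
# Condensed sketch; the dataclass shape is assumed for illustration, only
# the get_linear body mirrors the PR snippet.
from dataclasses import dataclass
from typing import Optional

import torch

from text_generation_server.layers.fp8 import get_fp8_linear


@dataclass
class Fp8Weight:
    weight: torch.Tensor
    dtype: torch.dtype
    weight_scale: Optional[torch.Tensor] = None
    input_scale: Optional[torch.Tensor] = None

    def get_linear(self, bias: Optional[torch.Tensor]):
        if self.weight_scale is None:
            # Checkpoint stored higher-precision weights: quantize on the fly.
            return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
        # Checkpoint was already fp8: reuse the stored scales directly.
        return get_fp8_linear().from_fp8(
            self.weight, self.weight_scale, self.input_scale, bias, self.dtype
        )
```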
        input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False)
        return Fp8Weight(
            weight=w,
            weight_scale=scale,
            input_scale=input_scale,
            dtype=weights.dtype,
        )
I changed `input_scale` to `input_scale_ub`, which is less ambiguous.
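Applied to the loading snippet above, the rename would look roughly like this. The helper wrapper is hypothetical, the checkpoint key stays `input_scale` because that is what the quoted code reads, and the interpretation as an upper bound for activation quantization is an inference from the new name:

```python
# Sketch of the rename applied to the snippet above; `weights`, `prefix`,
# `w`, and `scale` stand in for the surrounding loader code being quoted,
# and the helper itself is hypothetical.
def _load_fp8_weight(weights, prefix, w, scale):
    input_scale_ub = weights.get_tensor(f"{prefix}.input_scale", cast=False)
    return Fp8Weight(
        weight=w,
        weight_scale=scale,
        # `input_scale_ub` makes explicit that this is an upper bound applied
        # when quantizing activations, not a fixed per-tensor input scale.
        input_scale_ub=input_scale_ub,
        dtype=weights.dtype,
    )
```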
Force-pushed from af89ce0 to 5789139.
Looks great!
* feat(fp8): add support for fbgemm
* allow loading fp8 weights directly
* update outlines
* fix makefile
* build fbgemm
* avoid circular import and fix dockerfile
* add default dtype
* refactored weights loader
* fix auto conversion
* fix quantization config parsing
* force new nccl on install
* missing get_weights implementation
* increase timeout
@danieldk, since you were the one who reworked the weights logic, do you think there is a better way to plug the new fp8 weights into Transformers?