
feat(fp8): use fbgemm kernels and load fp8 weights directly #2248

Merged
OlivierDehaene merged 13 commits into main from feat/fp8_fbgemm on Jul 20, 2024

Conversation

OlivierDehaene (Contributor, Author):

@danieldk, since you were the one that reworked the weights logic, do you think there is a better way to plug the new fp8 weights into Transformers?

@@ -302,6 +302,9 @@ def get_model(
    if quantize in ["awq", "exl2", "gptq", "marlin"]:
        # These quantizers only work with float16 params.
        dtype = torch.float16
    elif quantize == "fp8":
        # gemm kernels are fp8xfp8->bf16
        dtype = torch.bfloat16
OlivierDehaene (Contributor, Author):

@danieldk is this compatible with the marlin kernels?

danieldk (Member):

The marlin kernels support both float16 and bfloat16.

However, do we want to set the default to bfloat16, since most models are float16? Is it needed for the fbgemm quantization kernel?

OlivierDehaene (Contributor, Author):

Yes, it's required for fbgemm.
I can add a method in layers/fp8.py to check whether we will use fbgemm and set the default dtype appropriately.
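
For reference, a minimal sketch of what such a check in layers/fp8.py could look like; the helper names, the fbgemm import path, and the compute-capability cutoff are assumptions for illustration, not the code that was eventually added.

import torch


def _fbgemm_fp8_available() -> bool:
    # The fp8 GEMM kernels ship with the GenAI build of fbgemm-gpu and need a
    # recent GPU; both the module path and the >= (8, 9) cutoff are assumptions.
    try:
        import fbgemm_gpu.experimental.gen_ai  # noqa: F401
    except ImportError:
        return False
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def get_fp8_default_dtype() -> torch.dtype:
    # fbgemm's fp8 x fp8 GEMM produces bfloat16 outputs, so default to bfloat16
    # when the kernels are usable; otherwise float16 (which marlin also supports).
    return torch.bfloat16 if _fbgemm_fp8_available() else torch.float16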

OlivierDehaene (Contributor, Author):

Added a check.

)

if w.dtype == torch.float8_e4m3fn:
    # FIXME: here to avoid circular import
    from text_generation_server.layers.fp8 import Fp8Weight
OlivierDehaene (Contributor, Author):

Not happy about this.

danieldk (Member) — Jul 19, 2024:

Hmmm, this breaks the abstraction quite a bit. This class is also used by other (explicit) quantizers like eetq.

I think it would make more sense to put this implementation in something like a HybridFP8FP16Loader in the fp8 module. Then we could add some logic to the get_loader function along these lines: when quantize==None is set, enumerate the tensors (which should be cheap if we only read the dtypes) and, if an FP8 weight is encountered, return the hybrid loader.

That puts the implementation nicely with the fp8 code and it wouldn't clutter UnquantizedWeight further if we e.g. also want to support bitsandbytes or eetq checkpoints in the future.
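
A rough sketch of that suggestion: the loader name follows the comment above and Fp8Weight's import path comes from the diff earlier in this thread, while the UnquantizedWeight import path and the method signature are assumptions rather than the merged implementation.

from dataclasses import dataclass

import torch

# Fp8Weight's import path is the one shown in the diff above; the
# UnquantizedWeight path is an assumption.
from text_generation_server.layers.fp8 import Fp8Weight
from text_generation_server.utils.weights import UnquantizedWeight


@dataclass
class HybridFP8FP16Loader:
    """Hands out Fp8Weight for float8 tensors and UnquantizedWeight otherwise."""

    dtype: torch.dtype

    def get_weights(self, weights, prefix: str):
        w = weights.get_tensor(f"{prefix}.weight")
        if w.dtype == torch.float8_e4m3fn:
            # Pre-quantized checkpoint: pick up the stored weight scale as well.
            scale = weights.get_tensor(f"{prefix}.weight_scale", cast=False)
            return Fp8Weight(weight=w, weight_scale=scale, dtype=self.dtype)
        # Plain float16/bfloat16 tensor: no fp8 handling needed here.
        return UnquantizedWeight(w)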

OlivierDehaene (Contributor, Author):

Agreed, I modified the code to add another loader.

Comment on lines 65 to 151
if self.weight_scale is None:
    return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
return get_fp8_linear().from_fp8(
    self.weight, self.weight_scale, self.input_scale, bias, self.dtype
)
danieldk (Member):

Nice! Looks like a pattern we could reuse in the future for quantizers that support both pre-quantized checkpoints and on-the-fly quantization.
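
To illustrate the pattern, here is a minimal sketch of the two constructors on a hypothetical fbgemm-backed Fp8Linear; the from_unquant/from_fp8 signatures mirror the snippet above, while the class layout and the torch.ops.fbgemm.quantize_fp8_per_row call are assumptions for illustration, not code taken from this PR.

import torch


class Fp8Linear(torch.nn.Module):
    # Sketch only; the real layer lives in layers/fp8.py and may differ.
    def __init__(self, qweight, scale, scale_upper_bound, bias, dtype):
        super().__init__()
        self.qweight = qweight
        self.scale = scale
        self.scale_upper_bound = scale_upper_bound
        self.bias = bias
        self.dtype = dtype

    @classmethod
    def from_unquant(cls, weight, bias, dtype):
        # On-the-fly path: quantize a float16/bfloat16 weight to fp8 row-wise
        # (op name and return values assumed from fbgemm's GenAI kernels).
        qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(weight)
        return cls(qweight, scale, None, bias, dtype)

    @classmethod
    def from_fp8(cls, weight, scale, input_scale, bias, dtype):
        # Pre-quantized path: the checkpoint already ships fp8 weights,
        # a weight scale, and optionally an activation-scale upper bound.
        return cls(weight, scale, input_scale, bias, dtype)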

Comment on lines 133 to 139
input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False)
return Fp8Weight(
    weight=w,
    weight_scale=scale,
    input_scale=input_scale,
    dtype=weights.dtype,
)
danieldk (Member):

I changed input_scale to input_scale_ub, which is less ambiguous.
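
Continuing the hypothetical Fp8Linear sketch above, this is roughly how an activation-scale upper bound (the renamed input_scale_ub) could flow into the fbgemm matmul; the op names, keyword arguments, and the separate bias add are assumptions, not the merged forward pass.

import torch


# Method sketch, meant to slot into the Fp8Linear class sketched earlier.
def forward(self, input: torch.Tensor) -> torch.Tensor:
    # Dynamically quantize the activations to fp8, clamping the computed
    # per-row scale to the configured upper bound (input_scale_ub).
    qinput, input_scale = torch.ops.fbgemm.quantize_fp8_per_row(
        input, scale_ub=self.scale_upper_bound
    )
    # fp8 x fp8 row-wise GEMM; fbgemm's kernel returns bfloat16 outputs.
    out = torch.ops.fbgemm.f8f8bf16_rowwise(
        qinput, self.qweight, input_scale, self.scale, use_fast_accum=True
    )
    if self.bias is not None:
        out = out + self.bias
    return out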

danieldk (Member) left a comment:

Looks great!

OlivierDehaene merged commit 53ec0b7 into main on Jul 20, 2024.
9 checks passed
OlivierDehaene deleted the feat/fp8_fbgemm branch on July 20, 2024 at 17:02.
ErikKaum pushed a commit that referenced this pull request Jul 25, 2024
* feat(fp8): add support for fbgemm

* allow loading fp8 weights directly

* update outlines

* fix makefile

* build fbgemm

* avoid circular import and fix dockerfile

* add default dtype

* refactored weights loader

* fix auto conversion

* fix quantization config parsing

* force new nccl on install

* missing get_weights implementation

* increase timeout
ErikKaum pushed a commit that referenced this pull request Jul 26, 2024