Description
This PR fixes a parity issue with Google's Gemma models by moving the addition of the unit offset to after the dtype conversion.
Motivation and Context
The Gemma models from Hugging Face are loaded with `torch.bfloat16` precision by default. When a unit add is performed on a `torch.bfloat16` tensor, the following behavior occurs.

Example:
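A minimal reconstruction of the behavior (the starting value `3.140625` is an illustrative choice, not necessarily the one from the original report; it is exactly representable in `bfloat16`, and its unit add falls halfway between two representable `bfloat16` values, so it rounds down):

```python
import torch

# Unit add performed while the tensor is still bfloat16.
a = torch.tensor([3.140625], dtype=torch.bfloat16)
print(a + 1)                      # tensor([4.1250], dtype=torch.bfloat16)
print((a + 1).to(torch.float32))  # tensor([4.1250]) -- precision already lost

# Unit add performed after the dtype conversion.
print(a.to(torch.float32) + 1)    # tensor([4.1406])
```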
The value returned is `4.1250` instead of `4.1406`, and it will remain this value even when the returned value is converted to `torch.float16` or `torch.float32`.

If the unit add is performed after the dtype conversion, the value returned is the expected value.
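In converter terms, the fix amounts to reordering two operations. A minimal sketch of the idea, assuming a helper that prepares a norm weight for export (the helper name is illustrative, not the script's actual code):

```python
import torch

def prepare_norm_weight(weight: torch.Tensor) -> torch.Tensor:
    """Apply Gemma's unit offset to a norm weight before export."""
    # Before this change (buggy): the offset was added while the tensor
    # was still bfloat16, so the rounding error survived later casts.
    #   return (weight + 1).to(torch.float32)
    # After this change: convert the dtype first, then add the offset.
    return weight.to(torch.float32) + 1
```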
When comparing the LayerNorm weights from the Hugging Face Gemma 2B model and the GGUF Gemma 2B model produced by `convert-hf-to-gguf.py` before this change, the tensor values are different. Each tensor below is of size 2048 and in `float16` precision.

After converting the GGUF model to ONNX and running a parity test with ONNX Runtime, ORT reports a parity mismatch for both prompt processing and token generation.
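A hedged sketch of how such a weight comparison can be done with the `gguf` Python package and `transformers` (the model ID, GGUF file path, and GGUF tensor name are assumptions; note that the GGUF file stores the weight with the unit offset already applied, so the offset is added to the Hugging Face side before comparing):

```python
import numpy as np
import torch
from gguf import GGUFReader
from transformers import AutoModelForCausalLM

hf = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)
reader = GGUFReader("gemma-2b.gguf")

# HF stores the raw norm weight; the converter stores weight + 1.
hf_w = hf.model.layers[0].input_layernorm.weight.detach()
expected = (hf_w.to(torch.float32) + 1).to(torch.float16).numpy()

gguf_w = next(t.data for t in reader.tensors if t.name == "blk.0.attn_norm.weight")
print(np.array_equal(expected, gguf_w))
```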
When comparing the LayerNorm weights from the Hugging Face Gemma 2B model and the GGUF Gemma 2B model produced by `convert-hf-to-gguf.py` after this change, the tensor values match. Each tensor below is of size 2048 and in `float16` precision.

After converting the new GGUF model to ONNX and running the same parity test with ONNX Runtime, ORT reports that parity is achieved.
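For reference, a hedged sketch of what such a prompt-processing parity check can look like with ONNX Runtime (the ONNX file name and the graph's input/output names are assumptions; a real export, e.g. one with KV-cache inputs, will need a fuller feed dict):

```python
import numpy as np
import onnxruntime as ort
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-2b")
hf = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.float32)
sess = ort.InferenceSession("gemma-2b.onnx")

ids = tok("The quick brown fox", return_tensors="pt")["input_ids"]
with torch.no_grad():
    ref = hf(input_ids=ids).logits.numpy()

# Assumed graph I/O names; inspect sess.get_inputs()/get_outputs() for the real ones.
(onnx_logits,) = sess.run(["logits"], {"input_ids": ids.numpy()})
print("max abs diff:", np.abs(onnx_logits - ref).max())
```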