Skip to content

[mlir][nvgpu] Fix transposeB in nvgpu.warpgroup.mma #79271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 25, 2024
Merged

Conversation

grypp
Copy link
Member

@grypp grypp commented Jan 24, 2024

The #76150 fixed meaning of transposeB in NVVM dialect which was initially implemented with opposite meaning.

This PR fixes the lowering of nvgpu.warpgroup.mma to NVVM dialect.

This will fix two integration tests:
gemm_f32_f16_f16_128x128x128.mlir
gemm_pred_f32_f16_f16_128x128x128.mlir

The llvm#76150 fixed meaning of `transposeB` in NVVM dialect which was initially implemented with opposite meaning.

This PR fixes the lowering of `nvgpu.warpgroup.mma` to NVVM dialect.

This will fix two integration tests:
gemm_f32_f16_f16_128x128x128.mlir
gemm_pred_f32_f16_f16_128x128x128.mlir
@llvmbot
Copy link
Member

llvmbot commented Jan 24, 2024

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

Changes

The #76150 fixed meaning of transposeB in NVVM dialect which was initially implemented with opposite meaning.

This PR fixes the lowering of nvgpu.warpgroup.mma to NVVM dialect.

This will fix two integration tests:
gemm_f32_f16_f16_128x128x128.mlir
gemm_pred_f32_f16_f16_128x128x128.mlir


Full diff: https://github.com/llvm/llvm-project/pull/79271.diff

1 Files Affected:

  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+1-1)
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 43d05b872a4fbc8..5080956a4589828 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1407,7 +1407,7 @@ struct NVGPUWarpgroupMmaOpLowering
       NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
       NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
       NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
-      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
+      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
 
       auto overflow = NVVM::MMAIntOverflowAttr::get(
           op->getContext(), NVVM::MMAIntOverflow::wrapped);

@grypp
Copy link
Member Author

grypp commented Jan 25, 2024

I am submitting this PR without reviews as it is fixes the sm_90 integration tests. The nightly tests did not catch them, because we don't test sm_90 targets yet.

@grypp grypp merged commit fa13c3e into llvm:main Jan 25, 2024
@grypp grypp deleted the fix-test branch January 25, 2024 08:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants