[mlir] GEMM Hopper Tensor Core Integration Test #81478
Merged
10 commits
8548c41 [mlir] GEMM Hopper Tensor Core Integration Test (grypp)
379fa54 format with yapf (grypp)
39f7213 format it with black (grypp)
9b65ffd fix the spelling mistake (grypp)
a31f9b1 address comments (grypp)
777f208 format (grypp)
0934dcb Allow multiple stages, and fix the kernels. (grypp)
861fe27 format (grypp)
89370ef Add asserts 128x128x64 (grypp)
8f7efd6 Address comments (grypp)
@@ -0,0 +1,2 @@
if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
    config.unsupported = True
@@ -0,0 +1,341 @@
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s


# ===--- GEMM Hopper Tensor Core Integration Test ---===
#
# This test aims to validate the correctness of the supported GEMM kernels in
# NVGPU dialects, with current support for Multistage and Warp Specialization
# kernels.
# The test constructs and metaprograms IR using Python bindings, allowing
# generic IR building. This flexibility enables changes to the shape,
# tile size, or data type of the GEMM for testing purposes.
# The entry function is `matmul`, where one can specify GEMM shape, tile size,
# data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum
# number of stages.
# Verification is done via numpy's matmul operation.
#
# Example:
#   matmul(input_type=np.float16,                # input types
#          output_type=np.float32,               # output type
#          M=4096, N=4096, K=4096,               # Shape
#          BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
#          use_warp_specialization=True,         # Enable Warp Specialization
#          max_num_stages=3)                     # Number of stages in shared memory
#
# ===--- Parallelism Across CTAs ---===
#
# GEMM includes three loops defining the shape of the GEMM, specified in the
# `matmul` function.
# The program builds IR using the following loop structure, tiling the loops
# with the given tile size and parallelizing the two outermost loops into the
# first and second dimensions of CTAs.
#
#   for(bi = 0; bi < M; bi += BLOCK_M)      # parallelize across blockIdx.x
#       for(bj = 0; bj < N; bj += BLOCK_N)  # parallelize across blockIdx.y
#           for(bk = 0; bk < K; bk += BLOCK_K)
#               for(i = bi; i < (bi + BLOCK_M); ++i)
#                   for(j = bj; j < (bj + BLOCK_N); ++j)
#                       for(k = bk; k < (bk + BLOCK_K); ++k)
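#
# For example, with M=N=K=4096 and BLOCK_M=BLOCK_N=128, BLOCK_K=64 as above,
# the two parallelized loops yield a 32x32 grid of CTAs (4096/128 = 32 in each
# dimension), and each CTA iterates the `bk` loop K/BLOCK_K = 64 times.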
#
# ===--- Multistage Kernel ---===
#
# This kernel launches a single warp group (128 threads). The primary thread
# (pthread) requests loads from TMA. Threads collectively wait for the data and
# perform mma operations. After the main loop completes, the threads
# cooperatively store the fragmented accumulator registers first to shared
# memory and then from shared memory to global memory; this part is called the
# epilogue.
#
# Execution Timeline of the Multistage Kernel with 3 stages:
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
# |       |Prologue ---->  |MainLoop ---->                                                      |Epilogue               |
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
# |wgroup | ........       |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
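#
# With 3 stages, the prologue issues TMA loads for the first three tiles up
# front. Each main-loop iteration then waits on the oldest in-flight stage,
# performs the mma on it, and the primary thread issues one more TMA load, so
# up to three loads remain in flight and stage indices cycle modulo the stage
# count.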
#
# ===--- Warp Specialization Kernel ---===
#
# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
# as `producer warp group` and another as `consumer warp group`. The
# `producer warp group` is responsible for requesting TMA load, while the
# `consumer warp group` performs the mma operation. The epilogue section is
# handled by the `consumer warp group` as its threads own the fragmented registers.
#
# Execution Timeline of the Warp Specialization Kernel with 2 stages:
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
# |wgroup1 | .......|         |         |         |                       |   |              | [shmem->global] |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global] |
# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
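#
# Note that a CTA therefore runs 256 threads in total (2 x 128). Only the
# primary thread of the producer warp group issues TMA loads, while the
# consumer warp group executes the mma operations and owns the accumulator
# registers; in the epilogue the consumer group first copies registers to
# shared memory, and both warp groups then cooperate on the shared-to-global
# copy.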

import errno
import numpy as np
import subprocess
import ctypes
from tools import nvgpucompiler
from tools import matmulBuilder
import contextlib
import os
import sys
import pathlib
from mlir import runtime as rt


def generate_matmul(
    input_type=np.float16,
    output_type=np.float32,
    M=4096,
    N=4096,
    K=4096,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    use_warp_specialization=True,
    saveIR=False,
    max_num_stages=3,
    options=f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3",
):
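    # The default pipeline options target Hopper: sm_90a is the
    # architecture-specific variant that exposes the Tensor Core (wgmma) and
    # TMA features used by these kernels, and +ptx80 requests PTX ISA 8.0.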
    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
        if use_warp_specialization:
            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
                input_type,
                output_type,
                M,
                N,
                K,
                BLOCK_M,
                BLOCK_N,
                BLOCK_K,
                max_num_stages,
            )
        else:
            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
                input_type,
                output_type,
                M,
                N,
                K,
                BLOCK_M,
                BLOCK_N,
                BLOCK_K,
                max_num_stages,
            )

        mlir_nvgpu_module.operation.verify()

        # Save generated IR
        if saveIR:
            # print(mlir_nvgpu_module)
            original_stdout = sys.stdout
            with open("gemm.mlir", "w") as f:
                sys.stdout = f
                print(mlir_nvgpu_module)
                sys.stdout = original_stdout

        # Get compiler
        support_lib = os.getenv("SUPPORT_LIB")
        if not os.path.exists(support_lib):
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), support_lib
            )
        compiler = nvgpucompiler.NvgpuCompiler(
            options, opt_level=3, shared_libs=[support_lib]
        )

        # Compile
        engine = compiler.compile_and_jit(mlir_nvgpu_module)
        return engine


def matmul(
    input_type=np.float16,
    output_type=np.float32,
    M=128,
    N=128,
    K=128,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    use_warp_specialization=True,
    saveIR=False,
    max_num_stages=3,
    print_results=False,
    no_verify=False,
):
    # Print the configuration
    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
    num_stages = min(required_stages, max_num_stages)
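    # For example, the smallest case exercised by test_short (M=N=128, K=64
    # with a 128x128x64 tile) gives required_stages = (128*64 + 64*128) //
    # (128*64 + 64*128) = 1, so num_stages is clamped to 1 regardless of
    # max_num_stages.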
ity = "f16" if input_type == np.float16 else "f32" | ||
oty = "f16" if output_type == np.float16 else "f32" | ||
gemmty = "Warp specialization" if use_warp_specialization else "Multistage" | ||
print( | ||
"===-- Running GEMM " | ||
+ gemmty | ||
+ " " | ||
+ oty | ||
+ " += " | ||
+ ity | ||
+ " * " | ||
+ ity | ||
+ ", Size " | ||
+ str(M) | ||
+ "x" | ||
+ str(N) | ||
+ "x" | ||
+ str(K) | ||
+ ", Tile " | ||
+ str(BLOCK_M) | ||
+ "x" | ||
+ str(BLOCK_N) | ||
+ "x" | ||
+ str(BLOCK_K) | ||
+ ", stages " | ||
+ str(num_stages) | ||
+ " --===" | ||
) | ||
|
||
# Build IR and compile | ||
engine = generate_matmul( | ||
input_type, | ||
output_type, | ||
M, | ||
N, | ||
K, | ||
BLOCK_M, | ||
BLOCK_N, | ||
BLOCK_K, | ||
use_warp_specialization, | ||
saveIR, | ||
num_stages, | ||
) | ||
|
||
# Allocate matrices and invoke the matmul | ||
c = np.zeros((M, N), output_type) | ||
a = np.random.randn(M, K).astype(input_type) | ||
b = np.random.randn(K, N).astype(input_type) | ||
mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a))) | ||
mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b))) | ||
mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c))) | ||
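    # The JIT-compiled kernel takes each memref argument as a pointer to a
    # pointer to its ranked memref descriptor, hence the double
    # ctypes.pointer(...) wrapping above.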
    kernelName = matmulBuilder.make_kernel_name(
        input_type,
        output_type,
        M,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        num_stages,
        use_warp_specialization,
    )

    # Launch the MLIR generated kernel
    engine.invoke(kernelName, mem_a, mem_b, mem_c)

    float_formatter = "{:.2f}".format
    np.set_printoptions(formatter={"float_kind": float_formatter})

    if print_results:
        print(c)

    # Verify the results
    if not no_verify:
        ref = a.astype(input_type) @ b.astype(input_type)
        if print_results:
            print(ref)
        np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)

    print("PASS ")


# Takes longer time to run
def test_long():
    for stages in range(1, 7):
        for M in [128, 512, 1024, 4096, 8192]:
            for N in [128, 512, 1024, 4096, 8192]:
                for K in [64, 128, 512, 1024, 4096, 8192]:
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=False,
                        no_verify=True,
                    )
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=True,
                    )


def test_short():
    for stages in [1, 3]:
        for M in [128, 512]:
            for N in [128]:
                for K in [64, 256]:
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=False,
                    )
                    matmul(
                        np.float16,
                        np.float32,
                        M,
                        N,
                        K,
                        max_num_stages=stages,
                        use_warp_specialization=True,
                    )


# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS

test_short()
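The tiled loop structure described in the file header corresponds to the NumPy sketch below. This is an illustrative reference only; the helper name tiled_matmul_reference and its defaults are hypothetical and not part of the patch. On the GPU, each (bi, bj) tile maps to one CTA and the bk loop is the K-loop that the kernels pipeline through shared memory.

import numpy as np


def tiled_matmul_reference(a, b, BLOCK_M=128, BLOCK_N=128, BLOCK_K=64):
    # a: (M, K), b: (K, N); shapes are assumed to be multiples of the tile
    # sizes, as in the configurations exercised by test_short()/test_long().
    M, K = a.shape
    _, N = b.shape
    c = np.zeros((M, N), dtype=np.float32)
    for bi in range(0, M, BLOCK_M):  # parallelized across blockIdx.x on the GPU
        for bj in range(0, N, BLOCK_N):  # parallelized across blockIdx.y on the GPU
            for bk in range(0, K, BLOCK_K):  # K-loop, staged through shared memory
                c[bi : bi + BLOCK_M, bj : bj + BLOCK_N] += (
                    a[bi : bi + BLOCK_M, bk : bk + BLOCK_K].astype(np.float32)
                    @ b[bk : bk + BLOCK_K, bj : bj + BLOCK_N].astype(np.float32)
                )
    return c


# Example check against numpy's own matmul (mirrors the test's verification):
# a = np.random.randn(128, 64).astype(np.float16)
# b = np.random.randn(64, 128).astype(np.float16)
# np.testing.assert_allclose(
#     tiled_matmul_reference(a, b), a.astype(np.float32) @ b.astype(np.float32)
# )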
mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg (3 additions, 0 deletions)
@@ -0,0 +1,3 @@
# Files in this directory are tools, not tests.
config.unsupported = True