Commit 4d33082

[mlir][nvgpu] NVGPU Tutorials (#87065)
I have a tutorial at EuroLLVM 2024 ([Zero to Hero: Programming Nvidia Hopper Tensor Core with MLIR's NVGPU Dialect](https://llvm.swoogo.com/2024eurollvm/session/2086997/zero-to-hero-programming-nvidia-hopper-tensor-core-with-mlir's-nvgpu-dialect)). For that, I implemented tutorial codes in Python. The focus is the nvgpu dialect and how to use its advanced features. I thought it might be useful to upstream this.

The tutorial codes are as follows:

- **Ch0.py:** Hello World
- **Ch1.py:** 2D Saxpy
- **Ch2.py:** 2D Saxpy using TMA
- **Ch3.py:** GEMM 128x128x64 using Tensor Core and TMA
- **Ch4.py:** Multistage performant GEMM using Tensor Core and TMA
- **Ch5.py:** Warp Specialized GEMM using Tensor Core and TMA

I might implement one more chapter:

- **Ch6.py:** Warp Specialized Persistent ping-pong GEMM

This PR also introduces the nvdsl class, making IR building in the tutorial easier.
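As a taste of the style these helpers enable, the sketch below is condensed from Ch0.py in this commit; the decorator behavior (building `func.func` / `gpu.launch` and JIT-compiling on call) is as described in that file's comments.

```python
from mlir.dialects import gpu
from tools.nvdsl import *  # NVDSL helpers introduced by this PR


@NVDSL.mlir_func  # builds an MLIR func.func; `alpha` becomes an `index` argument
def main(alpha):
    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))  # builds a gpu.launch
    def kernel():
        tidx = gpu.thread_id(gpu.Dimension.x)
        # `+` on NVDSL values emits arith.addi
        gpu.printf("GPU thread %llu has %llu\n", [tidx, alpha + tidx])

    kernel()  # emit the launch into the function body


main(100)  # JIT-compiles and runs the generated IR
```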
1 parent 506c84a commit 4d33082

File tree

10 files changed: +1490 -0 lines changed


mlir/test/Examples/NVGPU/Ch0.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
# Chapter 0 : Hello World
# ===----------------------------------------------------------------------===//
#
# This program demonstrates Hello World:
# 1. Build MLIR function with arguments
# 2. Build MLIR GPU kernel
# 3. Print from a GPU thread
# 4. Pass arguments, JIT compile and run the MLIR function
#
# ===----------------------------------------------------------------------===//


from mlir.dialects import gpu
from tools.nvdsl import *


# 1. The decorator generates a MLIR func.func.
# Everything inside the Python function becomes the body of the func.
# The decorator also translates `alpha` to an `index` type.
@NVDSL.mlir_func
def main(alpha):
    # 2. The decorator generates a MLIR gpu.launch.
    # Everything inside the Python function becomes the body of the gpu.launch.
    # This allows for late outlining of the GPU kernel, enabling optimizations
    # like constant folding from host to device.
    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(4, 1, 1))
    def kernel():
        tidx = gpu.thread_id(gpu.Dimension.x)
        # + operator generates arith.addi
        myValue = alpha + tidx
        # Print from a GPU thread
        gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])

    # 3. Call the GPU kernel
    kernel()


alpha = 100
# 4. The `mlir_func` decorator JIT compiles the IR and executes the MLIR function.
main(alpha)


# CHECK: GPU thread 0 has 100
# CHECK: GPU thread 1 has 101
# CHECK: GPU thread 2 has 102
# CHECK: GPU thread 3 has 103

mlir/test/Examples/NVGPU/Ch1.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
# Chapter 1 : 2D Saxpy
# ===----------------------------------------------------------------------===//
#
# This program demonstrates 2D Saxpy:
# 1. Use the GPU dialect to allocate and copy memory between host and GPU
# 2. Compute the 2D SAXPY kernel using operator overloading
# 3. Pass numpy arrays to MLIR as memref arguments
# 4. Verify the MLIR program against a reference computation in Python
#
# ===----------------------------------------------------------------------===//


from mlir import ir
from mlir.dialects import gpu, memref
from tools.nvdsl import *
import numpy as np


@NVDSL.mlir_func
def saxpy(x, y, alpha):
    # 1. Use MLIR GPU dialect to allocate and copy memory
    token_ty = ir.Type.parse("!gpu.async.token")
    t1 = gpu.wait(token_ty, [])
    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
    t6 = gpu.wait(token_ty, [t5])

    # 2. Compute 2D SAXPY kernel
    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1))
    def saxpy_kernel():
        bidx = gpu.block_id(gpu.Dimension.x)
        tidx = gpu.thread_id(gpu.Dimension.x)
        x_val = memref.load(x_dev, [bidx, tidx])
        y_val = memref.load(y_dev, [bidx, tidx])

        # SAXPY: y[i] += a * x[i];
        y_val += x_val * alpha

        memref.store(y_val, y_dev, [bidx, tidx])

    saxpy_kernel()

    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
    gpu.wait(token_ty, [t7])


# 3. Pass numpy arrays to MLIR
M = 256
N = 32
alpha = 2.0
x = np.random.randn(M, N).astype(np.float32)
y = np.ones((M, N), np.float32)
saxpy(x, y, alpha)

# 4. Verify MLIR with reference computation
ref = np.ones((M, N), np.float32)
ref += x * alpha
np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements

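The launch shape in Ch1.py maps one thread block per row and one thread per column, so each GPU thread owns exactly one element of the 256x32 arrays. A minimal pure-Python sketch of that index mapping (names illustrative, not part of the tutorial files):

```python
# Sketch of the Ch1.py index mapping: with grid=(M, 1, 1) and block=(N, 1, 1),
# thread (bidx, tidx) computes y[bidx, tidx] += alpha * x[bidx, tidx].
import numpy as np

M, N, alpha = 256, 32, 2.0
x = np.random.randn(M, N).astype(np.float32)
y = np.ones((M, N), np.float32)

for bidx in range(M):      # one thread block per row
    for tidx in range(N):  # one thread per column
        y[bidx, tidx] += x[bidx, tidx] * alpha
```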
mlir/test/Examples/NVGPU/Ch2.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
# Chapter 2 : 2D Saxpy with TMA
# ===----------------------------------------------------------------------===//
#
# This program demonstrates 2D Saxpy. It is the same as Chapter 1,
# but it loads data using TMA (Tensor Memory Accelerator).
#
# This chapter demonstrates:
# 1. Compute 2D SAXPY in the same way as Ch1.py, but load data using TMA
# 2. Create and initialize one asynchronous transactional barrier (mbarrier)
# 3. Thread 0 issues the TMA load request for each thread block
# 4. Each thread block loads a <1x32xf32> tile for x and one for y
# 5. Wait for completion of the TMA load with the mbarrier
#
# ===----------------------------------------------------------------------===//

from mlir import ir
from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
from tools.nvdsl import *
from mlir import runtime as rt
from mlir.extras import types as T
import numpy as np


@NVDSL.mlir_func
def saxpy(x, y, alpha):
    token_ty = ir.Type.parse("!gpu.async.token")
    t1 = gpu.wait(token_ty, [])
    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
    t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
    t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
    t6 = gpu.wait(token_ty, [t5])

    x_tma = TMA([1, N], x.type)
    y_tma = TMA([1, N], y.type)
    x_tma.create_descriptor(x_dev)
    y_tma.create_descriptor(y_dev)
    sz_x = get_type_size(x_tma.tma_memref)
    sz_y = get_type_size(y_tma.tma_memref)
    sz = sz_x + sz_y

    @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1), smem=sz)
    def saxpy_tma_kernel():
        bidx = gpu.block_id(gpu.Dimension.x)
        tidx = gpu.thread_id(gpu.Dimension.x)
        isThread0 = tidx == 0

        # 1. Create and initialize asynchronous transactional barrier (mbarrier)
        mbar_group = Mbarriers(number_of_barriers=1)
        mbar_group[0].init(1, predicate=isThread0)

        # 2. Execute Tensor Memory Accelerator (TMA) Load
        x_smem = get_dynamic_shared_memory([1, N], T.f32())
        y_smem = get_dynamic_shared_memory([1, N], T.f32(), offset=sz_x)
        x_tma.load(x_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
        y_tma.load(y_smem, mbar_group[0], coords=[0, bidx], predicate=isThread0)
        mbar_group[0].arrive(txcount=sz, predicate=isThread0)

        # 3. Wait for completion of TMA load with mbarrier
        mbar_group[0].try_wait()

        x_val = memref.load(x_smem, [const(0), tidx])
        y_val = memref.load(y_smem, [const(0), tidx])

        # SAXPY: y[i] += a * x[i];
        y_val += x_val * alpha

        memref.store(y_val, y_dev, [bidx, tidx])

    saxpy_tma_kernel()

    t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
    gpu.wait(token_ty, [t7])


# 3. Pass numpy arrays to MLIR
M = 256
N = 32
alpha = 2.0
x = np.random.randn(M, N).astype(np.float32)
y = np.ones((M, N), np.float32)
saxpy(x, y, alpha)

# 4. Verify MLIR with reference computation
ref = np.ones((M, N), np.float32)
ref += x * alpha
np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements

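The mbarrier in Ch2.py completes only after the expected number of bytes has arrived, so the `txcount` passed to `arrive` has to equal the bytes delivered by the two TMA loads. A quick worked check of that arithmetic (plain Python, names illustrative):

```python
import numpy as np

# Each thread block loads one <1x32xf32> tile for x and one for y.
N = 32
f32_bytes = np.dtype(np.float32).itemsize  # 4
sz_x = 1 * N * f32_bytes                   # 128 bytes for the x tile
sz_y = 1 * N * f32_bytes                   # 128 bytes for the y tile
sz = sz_x + sz_y                           # 256 bytes: txcount and dynamic smem size
assert sz == 256
```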
mlir/test/Examples/NVGPU/Ch3.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN:   %PYTHON %s | FileCheck %s

# ===----------------------------------------------------------------------===//
# Chapter 3 : GEMM 128x128x64 with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program demonstrates a 128x128x64 GEMM (matrix multiplication)
#
# This chapter demonstrates:
# 1. Execute a TMA load for the two input matrices
# 2. Perform Tensor Core GEMM 128x128x64 by a warpgroup
# 3. Store the fragmented accumulator registers to global memory by the warpgroup
#
# ===----------------------------------------------------------------------===//


from mlir import ir
from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
from tools.nvdsl import *
from mlir.extras import types as T
import numpy as np


def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    p,
):
    """
    TMA loads two input matrices from global memory to shared memory.
    It performs the following operations:

       - tma.load a_shared_memory[0] at coordinate [0, 0]  (Loads 128x64)
       - tma.load b_shared_memory[0] at coordinate [0, 0]  (Loads 64x64)
       - tma.load b_shared_memory[1] at coordinate [64, 0] (Loads 64x64)

       mbarrier.arrive ta_count = 128x64xf16 + 64x128xf16
    """

    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    ta_count = size_tma_a + (size_tma_b * 2)

    off_b = size_tma_a
    off_b2 = off_b + size_tma_b
    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)

    mbar_group[0].arrive(ta_count, predicate=p)

    a_tma.load(a, mbar_group[0], coords=[0, 0], predicate=p)
    b_tma.load(b1, mbar_group[0], coords=[0, 0], predicate=p)
    b_tma.load(b2, mbar_group[0], coords=[64, 0], predicate=p)


@NVDSL.mlir_func
def gemm_128_128_64(a, b, d):
    token_ty = ir.Type.parse("!gpu.async.token")
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([128, 64], a.type, swizzle=sw)
    b_tma = TMA([64, 64], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)
    a_size = get_type_size(a.type)
    b_size = get_type_size(b.type)
    smem_size_in_bytes = a_size + b_size

    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
    def gemm_tma_kernel():
        tidx = gpu.thread_id(gpu.Dimension.x)

        mbar_group = Mbarriers(number_of_barriers=1)
        isThread0 = tidx == 0

        mbar_group[0].init(1, predicate=isThread0)
        a_tma.prefetch(predicate=isThread0)
        b_tma.prefetch(predicate=isThread0)

        a_smem = get_dynamic_shared_memory((M, K), T.f16())
        b_smem = get_dynamic_shared_memory((K, N), T.f16(), offset=a_size)

        # 1. TMA load for the two input matrices
        tma_load(mbar_group, a_tma, b_tma, isThread0)

        # 2. All threads wait for TMA load completion
        mbar_group[0].try_wait()

        # 3. Perform Tensor Core GEMM 128x128x64 by a warpgroup
        A = WGMMAMatrix(WGMMAType.Descriptor, [M, K], desc=a_tma, smem=a_smem)
        B = WGMMAMatrix(WGMMAType.Descriptor, [K, N], desc=b_tma, smem=b_smem)
        D = WGMMAMatrix(WGMMAType.Accumulator, shape=[M, N], ty=T.f32())

        # Matrix Multiply
        D += A @ B

        # 4. Store the fragmented accumulator registers to global memory by the warpgroup
        D.store_accumulator(d_dev)

    gemm_tma_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])


# Pass arguments from Python to MLIR
M = 128
N = 128
K = 64
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)
gemm_128_128_64(a, b, d)

ref_d = a.astype(np.float16) @ b.astype(np.float16)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements

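The shared-memory offsets and the mbarrier transaction count in Ch3.py all derive from the tile sizes: A is one 128x64 f16 tile and B is brought in as two 64x64 f16 tiles. A quick worked check of those numbers (plain Python, names illustrative):

```python
import numpy as np

f16_bytes = np.dtype(np.float16).itemsize  # 2
size_tma_a = 128 * 64 * f16_bytes          # 16384 bytes for the A tile
size_tma_b = 64 * 64 * f16_bytes           # 8192 bytes per B tile
ta_count = size_tma_a + 2 * size_tma_b     # 32768 bytes expected by the mbarrier

off_b = size_tma_a                         # b1 placed right after A in shared memory
off_b2 = off_b + size_tma_b                # b2 placed right after b1

smem_size_in_bytes = 128 * 64 * f16_bytes + 64 * 128 * f16_bytes  # a_size + b_size
assert ta_count == smem_size_in_bytes == 32768
```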