Commit 26da27e

format with yapf
1 parent 72e6a31

File tree: 3 files changed, +483 -238 lines


mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py (36 additions, 15 deletions)
@@ -85,6 +85,7 @@
 import ctypes
 from mlir import runtime as rt
 
+
 def generate_matmul(input_type=np.float16,
                     output_type=np.float32,
                     M=4096,
@@ -96,16 +97,19 @@ def generate_matmul(input_type=np.float16,
                     use_warp_specilization=True,
                     saveIR=False,
                     max_num_stages=3):
-    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown(
+    ):
         if use_warp_specilization:
-            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
-                                                                 BLOCK_K, max_num_stages)
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
+                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
+                max_num_stages)
         else:
-            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
-                                                                         BLOCK_N, BLOCK_K, max_num_stages)
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
+                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
+                max_num_stages)
 
         mlir_nvgpu_module.operation.verify()
-
+
         # Save generated IR
         if saveIR:
             # print(mlir_nvgpu_module)
@@ -119,8 +123,11 @@
     options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
     support_lib = os.getenv("SUPPORT_LIB")
     if not os.path.exists(support_lib):
-        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
-    compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
+                                support_lib)
+    compiler = nvgpucompiler.NvgpuCompiler(options,
+                                           opt_level=3,
+                                           shared_libs=[support_lib])
 
     # Compile
     engine = compiler.compile_and_jit(mlir_nvgpu_module)
@@ -144,13 +151,15 @@ def matmul(input_type=np.float16,
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
     gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
-    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
-          "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
-          str(max_num_stages) + " --===")
+    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " +
+          ity + ", Size " + str(M) + "x" + str(N) + "x" + str(K) + ", Tile " +
+          str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) +
+          ", stages " + str(max_num_stages) + " --===")
 
     # Build IR and compile
-    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
-                             saveIR, max_num_stages)
+    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M,
+                             BLOCK_N, BLOCK_K, use_warp_specilization, saveIR,
+                             max_num_stages)
 
     # Allocate matrices and invoke the matmul
     c = np.zeros((M, N), output_type)
@@ -181,6 +190,18 @@ def matmul(input_type=np.float16,
 
 
 # GEMM Multistage f32 += f16 * f16
-matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+matmul(np.float16,
+       np.float32,
+       128,
+       128,
+       4096,
+       max_num_stages=3,
+       use_warp_specilization=False)
 # GEMM Warp Specilized f32 += f16 * f16
-matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
+matmul(np.float16,
+       np.float32,
+       256,
+       1024,
+       512,
+       max_num_stages=3,
+       use_warp_specilization=True)
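
For reference, rewraps like the ones in this diff can be reproduced with yapf's Python API. The sketch below is illustrative only: it assumes the stock "pep8" style, since the commit message does not record which style configuration was actually used, so the exact wrapping may differ from the hunks above.

# Illustrative sketch: reformat one of the long calls from this diff with yapf.
# Assumption: the stock "pep8" style; the commit does not say which style
# configuration was used, so the exact line splits may differ.
from yapf.yapflib.yapf_api import FormatCode

src = ("matmul(np.float16, np.float32, 128, 128, 4096, "
       "max_num_stages=3, use_warp_specilization=False)\n")

# FormatCode returns the reformatted source plus a flag indicating whether
# anything changed.
formatted, changed = FormatCode(src, style_config="pep8")
print(formatted)  # the call is rewrapped to fit within the column limit
print(changed)    # True: the original line exceeds 79 columns

Running yapf with its -i (in-place) flag over the file applies the same transformation on disk, which is presumably how a commit like this one is produced.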
