
Commit 623e90f

format it with black

1 parent 26da27e commit 623e90f
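The change is mechanical: the file was reformatted with black, which explodes over-long signatures and calls into one argument per line (with a trailing comma) and normalizes string quotes to double quotes. A minimal sketch of reproducing that layout through black's Python API (format_str and Mode are its documented entry points; the snippet is illustrative and not part of the commit):

import black

# A one-line signature over black's default 88-column limit, similar to the
# generate_matmul signature rewritten in the diff below.
src = (
    "def generate_matmul(input_type=np.float16, output_type=np.float32, "
    "M=4096, N=4096, K=4096, BLOCK_M=128, BLOCK_N=128, BLOCK_K=64):\n"
    "    pass\n"
)
# black rewrites the parameter list to one entry per line and appends a
# "magic" trailing comma, matching the layout seen throughout this diff.
print(black.format_str(src, mode=black.Mode()))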

File tree

3 files changed: +701 -463 lines changed


mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py

Lines changed: 118 additions & 59 deletions
@@ -86,35 +86,52 @@
 from mlir import runtime as rt
 
 
-def generate_matmul(input_type=np.float16,
-                    output_type=np.float32,
-                    M=4096,
-                    N=4096,
-                    K=4096,
-                    BLOCK_M=128,
-                    BLOCK_N=128,
-                    BLOCK_K=64,
-                    use_warp_specilization=True,
-                    saveIR=False,
-                    max_num_stages=3):
-    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown(
-    ):
+def generate_matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specilization=True,
+    saveIR=False,
+    max_num_stages=3,
+):
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
         if use_warp_specilization:
             mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
-                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
-                max_num_stages)
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )
         else:
             mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
-                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
-                max_num_stages)
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )
 
         mlir_nvgpu_module.operation.verify()
 
         # Save generated IR
         if saveIR:
             # print(mlir_nvgpu_module)
             original_stdout = sys.stdout
-            with open('gemm.mlir', 'w') as f:
+            with open("gemm.mlir", "w") as f:
                 sys.stdout = f
                 print(mlir_nvgpu_module)
                 sys.stdout = original_stdout
@@ -123,43 +140,77 @@ def generate_matmul(input_type=np.float16,
         options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
         support_lib = os.getenv("SUPPORT_LIB")
         if not os.path.exists(support_lib):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
-                                    support_lib)
-        compiler = nvgpucompiler.NvgpuCompiler(options,
-                                               opt_level=3,
-                                               shared_libs=[support_lib])
+            raise FileNotFoundError(
+                errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+            )
+        compiler = nvgpucompiler.NvgpuCompiler(
+            options, opt_level=3, shared_libs=[support_lib]
+        )
 
         # Compile
         engine = compiler.compile_and_jit(mlir_nvgpu_module)
         return engine
 
 
-def matmul(input_type=np.float16,
-           output_type=np.float32,
-           M=128,
-           N=128,
-           K=128,
-           BLOCK_M=128,
-           BLOCK_N=128,
-           BLOCK_K=64,
-           use_warp_specilization=True,
-           saveIR=False,
-           max_num_stages=3,
-           print_results=False,
-           no_verify=False):
+def matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=128,
+    N=128,
+    K=128,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specilization=True,
+    saveIR=False,
+    max_num_stages=3,
+    print_results=False,
+    no_verify=False,
+):
     # Print the configuration
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
     gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
-    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " +
-          ity + ", Size " + str(M) + "x" + str(N) + "x" + str(K) + ", Tile " +
-          str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) +
-          ", stages " + str(max_num_stages) + " --===")
+    print(
+        "===-- Running GEMM "
+        + gemmty
+        + " "
+        + oty
+        + " += "
+        + ity
+        + " * "
+        + ity
+        + ", Size "
+        + str(M)
+        + "x"
+        + str(N)
+        + "x"
+        + str(K)
+        + ", Tile "
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(BLOCK_K)
+        + ", stages "
+        + str(max_num_stages)
+        + " --==="
+    )
 
     # Build IR and compile
-    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M,
-                             BLOCK_N, BLOCK_K, use_warp_specilization, saveIR,
-                             max_num_stages)
+    engine = generate_matmul(
+        input_type,
+        output_type,
+        M,
+        N,
+        K,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        use_warp_specilization,
+        saveIR,
+        max_num_stages,
+    )
 
     # Allocate matrices and invoke the matmul
     c = np.zeros((M, N), output_type)
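An aside on the banner print in this hunk: black only re-wraps the existing +-concatenation, one operand per line; it never rewrites how the string is built. A hypothetical f-string equivalent (not part of the commit; sample values stand in for the locals of matmul()):

# Illustrative only: the same banner built with an f-string instead of the
# concatenation chain black re-wrapped above.
gemmty, ity, oty = "Multistage", "f16", "f32"
M = N = 128
K = 4096
BLOCK_M = BLOCK_N = 128
BLOCK_K = 64
max_num_stages = 3
print(
    f"===-- Running GEMM {gemmty} {oty} += {ity} * {ity}, "
    f"Size {M}x{N}x{K}, Tile {BLOCK_M}x{BLOCK_N}x{BLOCK_K}, "
    f"stages {max_num_stages} --==="
)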
@@ -168,13 +219,17 @@ def matmul(input_type=np.float16,
     mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
     mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
     mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
-    kernelName = "mlir_matmul_warpspecialized" if use_warp_specilization else "mlir_matmul_multistage"
+    kernelName = (
+        "mlir_matmul_warpspecialized"
+        if use_warp_specilization
+        else "mlir_matmul_multistage"
+    )
 
     # Launch the MLIR generated kernel
     engine.invoke(kernelName, mem_a, mem_b, mem_c)
 
     float_formatter = "{:.2f}".format
-    np.set_printoptions(formatter={'float_kind': float_formatter})
+    np.set_printoptions(formatter={"float_kind": float_formatter})
 
     if print_results:
         print(c)
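For readers new to the plumbing in this hunk: the ExecutionEngine's C-interface wrappers take each memref argument as a pointer to a pointer to its ranked descriptor, which is why every operand above is wrapped twice in ctypes.pointer. A minimal sketch of that pattern, using the same MLIR runtime helpers as the test (illustrative shape; not part of the commit):

import ctypes
import numpy as np
from mlir import runtime as rt

a = np.random.randn(4, 4).astype(np.float16)
# Build the ctypes struct describing the NumPy buffer: allocated/aligned
# pointers, offset, and per-dimension sizes and strides.
desc = rt.get_ranked_memref_descriptor(a)
# engine.invoke(name, ...) calls the _mlir_ciface_<name> wrapper, which
# expects each memref as a pointer-to-pointer to its descriptor.
mem_a = ctypes.pointer(ctypes.pointer(desc))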
@@ -190,18 +245,22 @@ def matmul(input_type=np.float16,
 
 
 # GEMM Multistage f32 += f16 * f16
-matmul(np.float16,
-       np.float32,
-       128,
-       128,
-       4096,
-       max_num_stages=3,
-       use_warp_specilization=False)
+matmul(
+    np.float16,
+    np.float32,
+    128,
+    128,
+    4096,
+    max_num_stages=3,
+    use_warp_specilization=False,
+)
 # GEMM Warp Specilized f32 += f16 * f16
-matmul(np.float16,
-       np.float32,
-       256,
-       1024,
-       512,
-       max_num_stages=3,
-       use_warp_specilization=True)
+matmul(
+    np.float16,
+    np.float32,
+    256,
+    1024,
+    512,
+    max_num_stages=3,
+    use_warp_specilization=True,
+)
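As a usage note, the script reads the runtime library path from the SUPPORT_LIB environment variable (see the os.getenv call in the diff above), so a standalone run outside of lit looks roughly like the sketch below; the .so path is a placeholder for a locally built libmlir_cuda_runtime.so and depends on the MLIR build:

import os
import subprocess

# Hypothetical standalone invocation; the lit test harness normally
# substitutes SUPPORT_LIB itself.
env = dict(os.environ, SUPPORT_LIB="/path/to/lib/libmlir_cuda_runtime.so")
subprocess.run(["python", "matmul.py"], env=env, check=True)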
