@@ -95,12 +95,12 @@ def generate_matmul(
95
95
BLOCK_M = 128 ,
96
96
BLOCK_N = 128 ,
97
97
BLOCK_K = 64 ,
98
- use_warp_specilization = True ,
98
+ use_warp_specialization = True ,
99
99
saveIR = False ,
100
100
max_num_stages = 3 ,
101
101
):
102
102
with matmulBuilder .ir .Context () as ctx , matmulBuilder .ir .Location .unknown ():
103
- if use_warp_specilization :
103
+ if use_warp_specialization :
104
104
mlir_nvgpu_module = matmulBuilder .generate_matmul_ws (
105
105
input_type ,
106
106
output_type ,
@@ -161,7 +161,7 @@ def matmul(
161
161
BLOCK_M = 128 ,
162
162
BLOCK_N = 128 ,
163
163
BLOCK_K = 64 ,
164
- use_warp_specilization = True ,
164
+ use_warp_specialization = True ,
165
165
saveIR = False ,
166
166
max_num_stages = 3 ,
167
167
print_results = False ,
@@ -170,7 +170,7 @@ def matmul(
170
170
# Print the configuration
171
171
ity = "f16" if input_type == np .float16 else "f32"
172
172
oty = "f16" if output_type == np .float16 else "f32"
173
- gemmty = "Warp Specilization " if use_warp_specilization else "Multistage"
173
+ gemmty = "Warp specialization " if use_warp_specialization else "Multistage"
174
174
print (
175
175
"===-- Running GEMM "
176
176
+ gemmty
@@ -207,7 +207,7 @@ def matmul(
207
207
BLOCK_M ,
208
208
BLOCK_N ,
209
209
BLOCK_K ,
210
- use_warp_specilization ,
210
+ use_warp_specialization ,
211
211
saveIR ,
212
212
max_num_stages ,
213
213
)
@@ -221,7 +221,7 @@ def matmul(
221
221
mem_c = ctypes .pointer (ctypes .pointer (rt .get_ranked_memref_descriptor (c )))
222
222
kernelName = (
223
223
"mlir_matmul_warpspecialized"
224
- if use_warp_specilization
224
+ if use_warp_specialization
225
225
else "mlir_matmul_multistage"
226
226
)
227
227
@@ -252,7 +252,7 @@ def matmul(
252
252
128 ,
253
253
4096 ,
254
254
max_num_stages = 3 ,
255
- use_warp_specilization = False ,
255
+ use_warp_specialization = False ,
256
256
)
257
257
# GEMM Warp Specialized f32 += f16 * f16
258
258
matmul (
@@ -262,5 +262,5 @@ def matmul(
262
262
1024 ,
263
263
512 ,
264
264
max_num_stages = 3 ,
265
- use_warp_specilization = True ,
265
+ use_warp_specialization = True ,
266
266
)
0 commit comments