 from mlir import runtime as rt


-def generate_matmul(input_type=np.float16,
-                    output_type=np.float32,
-                    M=4096,
-                    N=4096,
-                    K=4096,
-                    BLOCK_M=128,
-                    BLOCK_N=128,
-                    BLOCK_K=64,
-                    use_warp_specilization=True,
-                    saveIR=False,
-                    max_num_stages=3):
-    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown(
-    ):
+def generate_matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specilization=True,
+    saveIR=False,
+    max_num_stages=3,
+):
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
         if use_warp_specilization:
             mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
-                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
-                max_num_stages)
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )
         else:
             mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
-                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
-                max_num_stages)
+                input_type,
+                output_type,
+                M,
+                N,
+                K,
+                BLOCK_M,
+                BLOCK_N,
+                BLOCK_K,
+                max_num_stages,
+            )

         mlir_nvgpu_module.operation.verify()

         # Save generated IR
         if saveIR:
             # print(mlir_nvgpu_module)
             original_stdout = sys.stdout
-            with open('gemm.mlir', 'w') as f:
+            with open("gemm.mlir", "w") as f:
                 sys.stdout = f
                 print(mlir_nvgpu_module)
                 sys.stdout = original_stdout
@@ -123,43 +140,77 @@ def generate_matmul(input_type=np.float16,
         options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
         support_lib = os.getenv("SUPPORT_LIB")
         if not os.path.exists(support_lib):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
-                                    support_lib)
-        compiler = nvgpucompiler.NvgpuCompiler(options,
-                                               opt_level=3,
-                                               shared_libs=[support_lib])
+            raise FileNotFoundError(
+                errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+            )
+        compiler = nvgpucompiler.NvgpuCompiler(
+            options, opt_level=3, shared_libs=[support_lib]
+        )

         # Compile
         engine = compiler.compile_and_jit(mlir_nvgpu_module)
         return engine


-def matmul(input_type=np.float16,
-           output_type=np.float32,
-           M=128,
-           N=128,
-           K=128,
-           BLOCK_M=128,
-           BLOCK_N=128,
-           BLOCK_K=64,
-           use_warp_specilization=True,
-           saveIR=False,
-           max_num_stages=3,
-           print_results=False,
-           no_verify=False):
+def matmul(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=128,
+    N=128,
+    K=128,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=64,
+    use_warp_specilization=True,
+    saveIR=False,
+    max_num_stages=3,
+    print_results=False,
+    no_verify=False,
+):
     # Print the configuration
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
     gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
-    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " +
-          ity + ", Size " + str(M) + "x" + str(N) + "x" + str(K) + ", Tile " +
-          str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) +
-          ", stages " + str(max_num_stages) + " --===")
+    print(
+        "===-- Running GEMM "
+        + gemmty
+        + " "
+        + oty
+        + " += "
+        + ity
+        + " * "
+        + ity
+        + ", Size "
+        + str(M)
+        + "x"
+        + str(N)
+        + "x"
+        + str(K)
+        + ", Tile "
+        + str(BLOCK_M)
+        + "x"
+        + str(BLOCK_N)
+        + "x"
+        + str(BLOCK_K)
+        + ", stages "
+        + str(max_num_stages)
+        + " --==="
+    )

     # Build IR and compile
-    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M,
-                             BLOCK_N, BLOCK_K, use_warp_specilization, saveIR,
-                             max_num_stages)
+    engine = generate_matmul(
+        input_type,
+        output_type,
+        M,
+        N,
+        K,
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        use_warp_specilization,
+        saveIR,
+        max_num_stages,
+    )

     # Allocate matrices and invoke the matmul
     c = np.zeros((M, N), output_type)
@@ -168,13 +219,17 @@ def matmul(input_type=np.float16,
     mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
     mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
     mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
-    kernelName = "mlir_matmul_warpspecialized" if use_warp_specilization else "mlir_matmul_multistage"
+    kernelName = (
+        "mlir_matmul_warpspecialized"
+        if use_warp_specilization
+        else "mlir_matmul_multistage"
+    )

     # Launch the MLIR generated kernel
     engine.invoke(kernelName, mem_a, mem_b, mem_c)

     float_formatter = "{:.2f}".format
-    np.set_printoptions(formatter={'float_kind': float_formatter})
+    np.set_printoptions(formatter={"float_kind": float_formatter})

     if print_results:
         print(c)
@@ -190,18 +245,22 @@ def matmul(input_type=np.float16,


 # GEMM Multistage f32 += f16 * f16
-matmul(np.float16,
-       np.float32,
-       128,
-       128,
-       4096,
-       max_num_stages=3,
-       use_warp_specilization=False)
+matmul(
+    np.float16,
+    np.float32,
+    128,
+    128,
+    4096,
+    max_num_stages=3,
+    use_warp_specilization=False,
+)
 # GEMM Warp Specilized f32 += f16 * f16
-matmul(np.float16,
-       np.float32,
-       256,
-       1024,
-       512,
-       max_num_stages=3,
-       use_warp_specilization=True)
+matmul(
+    np.float16,
+    np.float32,
+    256,
+    1024,
+    512,
+    max_num_stages=3,
+    use_warp_specilization=True,
+)
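For readers skimming the commit: at runtime the driver above hands each NumPy matrix to the JIT-compiled kernel as a pointer to a pointer to a ranked memref descriptor, which is why every array is wrapped twice with ctypes. Below is a minimal end-to-end sketch of that calling convention plus a host-side reference check. It reuses the script's own generate_matmul() and the "mlir_matmul_multistage" entry point from the diff; the SUPPORT_LIB path and the comparison tolerance are placeholders of mine, not values from the commit.

import ctypes
import os

import numpy as np
from mlir import runtime as rt

# Placeholder path; point this at the MLIR CUDA runtime wrappers on your system.
os.environ["SUPPORT_LIB"] = "/path/to/libmlir_cuda_runtime.so"

M, N, K = 128, 128, 4096
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
c = np.zeros((M, N), np.float32)

# Build the IR and JIT-compile it with the helper defined in the diff;
# BLOCK_M/BLOCK_N/BLOCK_K keep their defaults (128x128x64 tiles).
engine = generate_matmul(
    np.float16, np.float32, M, N, K, use_warp_specilization=False, max_num_stages=3
)

# ExecutionEngine ABI: each argument is a pointer to a pointer to a ranked
# memref descriptor that wraps the NumPy buffer in place (no copy).
mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
engine.invoke("mlir_matmul_multistage", mem_a, mem_b, mem_c)

# f32 += f16 * f16 reference on the host; the tolerance is illustrative only.
ref = a.astype(np.float32) @ b.astype(np.float32)
assert np.allclose(c, ref, rtol=5e-2)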