@@ -85,6 +85,7 @@
 import ctypes
 from mlir import runtime as rt
 
+
 def generate_matmul(input_type=np.float16,
                     output_type=np.float32,
                     M=4096,
@@ -96,16 +97,19 @@ def generate_matmul(input_type=np.float16,
                     use_warp_specilization=True,
                     saveIR=False,
                     max_num_stages=3):
-    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+    with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown(
+    ):
         if use_warp_specilization:
-            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
-                                                                 BLOCK_K, max_num_stages)
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
+                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
+                max_num_stages)
         else:
-            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
-                                                                         BLOCK_N, BLOCK_K, max_num_stages)
+            mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
+                input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
+                max_num_stages)
 
         mlir_nvgpu_module.operation.verify()
-
+
         # Save generated IR
         if saveIR:
             # print(mlir_nvgpu_module)
@@ -119,8 +123,11 @@ def generate_matmul(input_type=np.float16,
         options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
         support_lib = os.getenv("SUPPORT_LIB")
         if not os.path.exists(support_lib):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
-        compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
+                                    support_lib)
+        compiler = nvgpucompiler.NvgpuCompiler(options,
+                                               opt_level=3,
+                                               shared_libs=[support_lib])
 
         # Compile
         engine = compiler.compile_and_jit(mlir_nvgpu_module)
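For context, an engine returned by compile_and_jit above is typically driven from NumPy through the mlir.runtime helpers imported at the top of the file. A minimal sketch, assuming illustrative 128x4096x128 operands and a placeholder entry-point name "mlir_matmul" (the real symbol is produced by the builder and is not shown in this diff):

    import ctypes
    import numpy as np
    from mlir import runtime as rt

    a = np.random.randn(128, 4096).astype(np.float16)
    b = np.random.randn(4096, 128).astype(np.float16)
    c = np.zeros((128, 128), np.float32)
    # Wrap each NumPy array in a ranked memref descriptor and pass double
    # pointers to the JIT-compiled entry point ("mlir_matmul" is a placeholder).
    args = [
        ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(arr)))
        for arr in (a, b, c)
    ]
    engine.invoke("mlir_matmul", *args)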
@@ -144,13 +151,15 @@ def matmul(input_type=np.float16,
     ity = "f16" if input_type == np.float16 else "f32"
     oty = "f16" if output_type == np.float16 else "f32"
     gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
-    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
-          "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
-          str(max_num_stages) + " --===")
+    print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " +
+          ity + ", Size " + str(M) + "x" + str(N) + "x" + str(K) + ", Tile " +
+          str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) +
+          ", stages " + str(max_num_stages) + " --===")
 
     # Build IR and compile
-    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
-                             saveIR, max_num_stages)
+    engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M,
+                             BLOCK_N, BLOCK_K, use_warp_specilization, saveIR,
+                             max_num_stages)
 
     # Allocate matrices and invoke the matmul
     c = np.zeros((M, N), output_type)
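For a concrete reading of the banner printed above: with the multistage call at the bottom of the file (f32 += f16 * f16, M=N=128, K=4096, 3 stages) it produces a line of the following shape, where <BM>x<BN>x<BK> stands in for the BLOCK_M/BLOCK_N/BLOCK_K defaults elided from this diff:

    ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x4096, Tile <BM>x<BN>x<BK>, stages 3 --===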
@@ -181,6 +190,18 @@ def matmul(input_type=np.float16,
 
 
 # GEMM Multistage f32 += f16 * f16
-matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+matmul(np.float16,
+       np.float32,
+       128,
+       128,
+       4096,
+       max_num_stages=3,
+       use_warp_specilization=False)
 # GEMM Warp Specilized f32 += f16 * f16
-matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
+matmul(np.float16,
+       np.float32,
+       256,
+       1024,
+       512,
+       max_num_stages=3,
+       use_warp_specilization=True)
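As the compile path above enforces, the script refuses to run unless SUPPORT_LIB names an existing shared library providing the CUDA runtime wrappers to load alongside the JIT-compiled module. A setup sketch; the library name and path are illustrative and depend on the local MLIR build:

    import os
    # Must be set before the script queries os.getenv("SUPPORT_LIB");
    # the path below is a placeholder, not a real location.
    os.environ["SUPPORT_LIB"] = "/path/to/libmlir_cuda_runtime.so"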