10
10
// RUN: -convert-func-to-llvm \
11
11
// RUN: -canonicalize \
12
12
// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_90 features=+ptx80 dump-ptx}))' \
13
- // RUN: 2&> 1 | FileCheck %s --check-prefixes=CHECK-PTX
13
+ // RUN: 2>& 1 | FileCheck %s --check-prefixes=CHECK-PTX
14
14
15
15
// CHECK-PTX: mbarrier.init.shared.b64
16
16
// CHECK-PTX: mbarrier.arrive.expect_tx.shared.b64
19
19
// CHECK-PTX: mbarrier.arrive.expect_tx.shared.b64
20
20
// CHECK-PTX: mbarrier.try_wait.parity.shared.b64
21
21
22
+ // RUN: mlir-opt %s --convert-nvgpu-to-nvvm \
23
+ // RUN: -gpu-kernel-outlining \
24
+ // RUN: -convert-nvvm-to-llvm \
25
+ // RUN: -convert-nvgpu-to-nvvm \
26
+ // RUN: -convert-scf-to-cf \
27
+ // RUN: -convert-vector-to-llvm \
28
+ // RUN: -convert-index-to-llvm=index-bitwidth=32 \
29
+ // RUN: -convert-arith-to-llvm \
30
+ // RUN: -finalize-memref-to-llvm='use-opaque-pointers=1' \
31
+ // RUN: -convert-func-to-llvm \
32
+ // RUN: -expand-strided-metadata --nvvm-attach-target="module=main_kernel features=+ptx80 chip=sm_90 O=3" \
33
+ // RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-index-to-llvm{index-bitwidth=32},canonicalize,cse))' \
34
+ // RUN: | mlir-opt --gpu-to-llvm --gpu-module-to-binary -canonicalize -cse -reconcile-unrealized-casts \
35
+ // RUN: | mlir-cpu-runner \
36
+ // RUN: --shared-libs=%mlir_cuda_runtime \
37
+ // RUN: --shared-libs=%mlir_runner_utils \
38
+ // RUN: --entry-point-result=void \
39
+ // RUN: | FileCheck %s
40
+
41
+
42
+ // CHECK: [GPU] TMA BEFORE lhs[45][7] 0.000000
43
+ // CHECK: [GPU] TMA BEFORE rhs[7][0] 0.000000
44
+ // CHECK: [GPU] TMA LOADED lhs[45][7] 7.000000
45
+ // CHECK: [GPU] TMA LOADED rhs[7][0] 3.000000
46
+
22
47
module @mymod {
23
48
memref.global " private" @bufferLhsGlobal : memref <64 x8 xf32 , 3 >
24
49
memref.global " private" @bufferRhsGlobal : memref <8 x128 xf32 , 3 >
@@ -87,4 +112,4 @@ module @mymod {
87
112
}
88
113
return
89
114
}
90
- }
115
+ }
0 commit comments