Skip to content

Commit 0246503

Browse files
committed
[func][spirv] Add necessary patch, pass-pipeline, and test case for new SPIR-V pipeline.
The patch has an upstream PR pending review: llvm/llvm-project#86750. The patch can be removed once the PR gets merged and the LLVM version is updated. The test cases in this PR only provide lowering from func to spirv. XeGPU test cases using the full pipeline are coming as part of the next PR from Dimple (drprajap).
1 parent b160c49 commit 0246503

File tree

4 files changed

+407
-0
lines changed

4 files changed

+407
-0
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
Upstream PR in progress: https://github.com/llvm/llvm-project/pull/86750
2+
3+
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
4+
index 57d8e894a24b..3fc68c65de05 100644
5+
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
6+
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
7+
@@ -306,6 +306,18 @@ public:
8+
}
9+
};
10+
11+
+/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
12+
+class ExtractAlignedPointerAsIndexOpPattern
13+
+ : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
14+
+public:
15+
+ using OpConversionPattern::OpConversionPattern;
16+
+
17+
+ LogicalResult
18+
+ matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
19+
+ OpAdaptor adaptor,
20+
+ ConversionPatternRewriter &rewriter) const override;
21+
+};
22+
+
23+
} // namespace
24+
25+
//===----------------------------------------------------------------------===//
26+
@@ -905,6 +917,20 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
27+
return success();
28+
}
29+
30+
+//===----------------------------------------------------------------------===//
31+
+// ExtractAlignedPointerAsIndexOp
32+
+//===----------------------------------------------------------------------===//
33+
+
34+
+LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
35+
+ memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
36+
+ ConversionPatternRewriter &rewriter) const {
37+
+ auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
38+
+ Type indexType = typeConverter.getIndexType();
39+
+ rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
40+
+ adaptor.getSource());
41+
+ return success();
42+
+}
43+
+
44+
//===----------------------------------------------------------------------===//
45+
// Pattern population
46+
//===----------------------------------------------------------------------===//
47+
@@ -912,10 +938,11 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
48+
namespace mlir {
49+
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
50+
RewritePatternSet &patterns) {
51+
- patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
52+
- DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
53+
- LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
54+
- ReinterpretCastPattern, CastPattern>(typeConverter,
55+
- patterns.getContext());
56+
+ patterns
57+
+ .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
58+
+ DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
59+
+ MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
60+
+ CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
61+
+ typeConverter, patterns.getContext());
62+
}
63+
} // namespace mlir
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// gpu dialect with intel intrinsic functions (func dialect) to
2+
// llvm dialect (for host code) and
3+
// spirv dialect (for device code) lowering pipeline.
4+
// Ready for imex runner starting from GPU dialect.
5+
builtin.module(
6+
gpu.module(convert-func-to-spirv)
7+
gpu.module(convert-vector-to-spirv)
8+
imex-convert-gpu-to-spirv{enable-vc-intrinsic=true}
9+
spirv.module(spirv-lower-abi-attrs
10+
spirv-update-vce)
11+
func.func(llvm-request-c-wrappers)
12+
serialize-spirv
13+
convert-vector-to-scf
14+
convert-gpu-to-gpux
15+
convert-scf-to-cf
16+
convert-cf-to-llvm
17+
convert-vector-to-llvm
18+
convert-index-to-llvm
19+
convert-arith-to-llvm
20+
convert-func-to-llvm
21+
convert-math-to-llvm
22+
convert-gpux-to-llvm
23+
convert-index-to-llvm
24+
expand-strided-metadata
25+
lower-affine
26+
finalize-memref-to-llvm
27+
reconcile-unrealized-casts)
28+
// End
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/func-to-llvm.pp \
2+
// RUN: --runner imex-cpu-runner -e main \
3+
// RUN: --entry-point-result=void \
4+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
5+
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/func-to-llvm.pp \
6+
// RUN: --runner imex-cpu-runner -e main \
7+
// RUN: --entry-point-result=void \
8+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
9+
10+
module @gemm attributes {gpu.container_module,
11+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
12+
memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<5.000000e-01>
13+
memref.global "private" constant @__constant_16x16xf16 : memref<16x16xf16> = dense<1.099610e+00>
14+
func.func @test(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
15+
%c1 = arith.constant 1 : index
16+
%memref = gpu.alloc host_shared () : memref<8x16xf16>
17+
memref.copy %arg0, %memref : memref<8x16xf16> to memref<8x16xf16>
18+
%memref_0 = gpu.alloc host_shared () : memref<16x16xf16>
19+
memref.copy %arg1, %memref_0 : memref<16x16xf16> to memref<16x16xf16>
20+
%memref_1 = gpu.alloc host_shared () : memref<8x16xf32>
21+
gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf16>, %memref_0 : memref<16x16xf16>, %memref_1 : memref<8x16xf32>)
22+
gpu.dealloc %memref : memref<8x16xf16>
23+
gpu.dealloc %memref_0 : memref<16x16xf16>
24+
return %memref_1 : memref<8x16xf32>
25+
}
26+
27+
gpu.module @test_kernel {
28+
func.func private @llvm.genx.raw.sends2.noresult.i1.v8i32.v64i64(i8, i8, i1, i8, i8, i8, i32, i32, vector<8xi32>, vector<64xi64>) attributes{
29+
linkage_attributes=#spirv.linkage_attributes<
30+
linkage_name="llvm.genx.raw.sends2.noresult.i1.v8i32.v64i64",
31+
linkage_type=<Import>
32+
>,
33+
VectorComputeFunctionINTEL}
34+
func.func private @llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32(vector<128xi32>, vector<64xi32>, i32) -> vector<128xf32> attributes{
35+
linkage_attributes=#spirv.linkage_attributes<
36+
linkage_name="llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32",
37+
linkage_type=<Import>
38+
>,
39+
VectorComputeFunctionINTEL}
40+
func.func private @llvm.genx.raw.send2.v128i32.i1.v8i32(i8, i8, i1, i8, i8, i8, i32, i32, vector<8xi32>, vector<128xi32>) -> vector<128xi32> attributes{
41+
linkage_attributes=#spirv.linkage_attributes<
42+
linkage_name="llvm.genx.raw.send2.v128i32.i1.v8i32",
43+
linkage_type=<Import>
44+
>,
45+
VectorComputeFunctionINTEL}
46+
func.func private @llvm.genx.raw.send2.v32i64.i1.v8i32(i8, i8, i1, i8, i8, i8, i32, i32, vector<8xi32>, vector<32xi64>) -> vector<32xi64> attributes{
47+
linkage_attributes=#spirv.linkage_attributes<
48+
linkage_name="llvm.genx.raw.send2.v32i64.i1.v8i32",
49+
linkage_type=<Import>
50+
>,
51+
VectorComputeFunctionINTEL}
52+
gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL,spirv.entry_point_abi = #spirv.entry_point_abi<>} {
53+
%reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [128], strides: [1] : memref<8x16xf16> to memref<128xf16>
54+
%cst = arith.constant dense<0> : vector<4xi64>
55+
%intptr = memref.extract_aligned_pointer_as_index %reinterpret_cast : memref<128xf16> -> index
56+
%0 = arith.index_castui %intptr : index to i64
57+
%1 = vector.insert %0, %cst[0] : i64 into vector<4xi64>
58+
%2 = vector.bitcast %1 : vector<4xi64> to vector<8xi32>
59+
%cst_0 = arith.constant dense<0> : vector<4xi64>
60+
%cst_to_non_cst = arith.addi %cst_0, %cst_0 : vector<4xi64>
61+
%intptr_1 = memref.extract_aligned_pointer_as_index %arg1 : memref<16x16xf16> -> index
62+
%3 = arith.index_castui %intptr_1 : index to i64
63+
%4 = vector.insert %3, %cst_to_non_cst[0] : i64 into vector<4xi64>
64+
%5 = vector.bitcast %4 : vector<4xi64> to vector<8xi32>
65+
%c31_i32 = arith.constant 31 : i32
66+
%c15_i32 = arith.constant 15 : i32
67+
%c31_i32_2 = arith.constant 31 : i32
68+
%6 = vector.insert %c31_i32, %5 [2] : i32 into vector<8xi32>
69+
%7 = vector.insert %c15_i32, %6 [3] : i32 into vector<8xi32>
70+
%8 = vector.insert %c31_i32_2, %7 [4] : i32 into vector<8xi32>
71+
%c0_i32 = arith.constant 0 : i32
72+
%c0_i32_3 = arith.constant 0 : i32
73+
%9 = vector.insert %c0_i32, %8 [5] : i32 into vector<8xi32>
74+
%10 = vector.insert %c0_i32_3, %9 [6] : i32 into vector<8xi32>
75+
%c3855_i32 = arith.constant 3855 : i32
76+
%11 = vector.insert %c3855_i32, %10 [7] : i32 into vector<8xi32>
77+
%reinterpret_cast_4 = memref.reinterpret_cast %arg2 to offset: [0], sizes: [128], strides: [1] : memref<8x16xf32> to memref<128xf32>
78+
%cst_5_t = arith.constant dense<0> : vector<4xi64>
79+
%cst_5 = arith.addi %cst_5_t, %cst_5_t : vector<4xi64>
80+
%intptr_6 = memref.extract_aligned_pointer_as_index %reinterpret_cast_4 : memref<128xf32> -> index
81+
%12 = arith.index_castui %intptr_6 : index to i64
82+
%13 = vector.insert %12, %cst_5 [0] : i64 into vector<4xi64>
83+
%14 = vector.bitcast %13 : vector<4xi64> to vector<8xi32>
84+
%c0_i8 = arith.constant 0 : i8
85+
%c0_i8_7 = arith.constant 0 : i8
86+
%true = arith.constant true
87+
%c1_i8 = arith.constant 1 : i8
88+
%c4_i8 = arith.constant 4 : i8
89+
%c15_i8 = arith.constant 15 : i8
90+
%c0_i32_8 = arith.constant 0 : i32
91+
%c42133376_i32 = arith.constant 42133376 : i32
92+
%cst_9 = arith.constant dense<0> : vector<32xi64>
93+
%15 = func.call @llvm.genx.raw.send2.v32i64.i1.v8i32(%c0_i8, %c0_i8_7, %true, %c1_i8, %c4_i8, %c15_i8, %c0_i32_8, %c42133376_i32, %2, %cst_9) : (i8, i8, i1, i8, i8, i8, i32, i32, vector<8xi32>, vector<32xi64>) -> vector<32xi64>
94+
%16 = vector.bitcast %15 : vector<32xi64> to vector<128xf16>
95+
%c0_i8_10 = arith.constant 0 : i8
96+
%c0_i8_11 = arith.constant 0 : i8
97+
%true_12 = arith.constant true
98+
%c1_i8_13 = arith.constant 1 : i8
99+
%c8_i8 = arith.constant 8 : i8
100+
%c15_i8_14 = arith.constant 15 : i8
101+
%c0_i32_15 = arith.constant 0 : i32
102+
%c42074755_i32 = arith.constant 42074755 : i32
103+
%cst_16 = arith.constant dense<0> : vector<128xi32>
104+
%17 = func.call @llvm.genx.raw.send2.v128i32.i1.v8i32(%c0_i8_10, %c0_i8_11, %true_12, %c1_i8_13, %c8_i8, %c15_i8_14, %c0_i32_15, %c42074755_i32, %11, %cst_16) : (i8, i8, i1, i8, i8, i8, i32, i32, vector<8xi32>, vector<128xi32>) -> vector<128xi32>
105+
%18 = vector.bitcast %16 : vector<128xf16> to vector<64xi32>
106+
%c134744586_i32 = arith.constant 134744586 : i32
107+
%19 = func.call @llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32(%17, %18, %c134744586_i32) : (vector<128xi32>, vector<64xi32>, i32) -> vector<128xf32>
108+
%c0_i8_17 = arith.constant 0 : i8
109+
%c0_i8_18 = arith.constant 0 : i8
110+
%true_19 = arith.constant true
111+
%c1_i8_20 = arith.constant 1 : i8
112+
%c8_i8_21 = arith.constant 8 : i8
113+
%c15_i8_22 = arith.constant 15 : i8
114+
%c0_i32_23 = arith.constant 0 : i32
115+
%c33748868_i32 = arith.constant 33748868 : i32
116+
%20 = vector.bitcast %19 : vector<128xf32> to vector<64xi64>
117+
func.call @llvm.genx.raw.sends2.noresult.i1.v8i32.v64i64(%c0_i8_17, %c0_i8_18, %true_19, %c1_i8_20, %c8_i8_21, %c15_i8_22, %c0_i32_23, %c33748868_i32, %14, %20) : (i8, i8, i1, i8, i8, i8, i32, i32, vector<8xi32>, vector<64xi64>) -> ()
118+
gpu.return
119+
}
120+
}
121+
func.func @main() attributes {llvm.emit_c_interface} {
122+
%0 = memref.get_global @__constant_8x16xf16 : memref<8x16xf16>
123+
%1 = memref.get_global @__constant_16x16xf16 : memref<16x16xf16>
124+
%2 = call @test(%0, %1) : (memref<8x16xf16>, memref<16x16xf16>) -> memref<8x16xf32>
125+
%cast = memref.cast %2 : memref<8x16xf32> to memref<*xf32>
126+
call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
127+
// CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}}
128+
// CHECK-COUNT-128: 8.79688
129+
return
130+
}
131+
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
132+
}

0 commit comments

Comments
 (0)