-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Add integration test for vector.deinterleave
#95469
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % one suggestion
871ae6b
to
6c28f21
Compare
@llvm/pr-subscribers-mlir Author: Angel Zhang (angelz913) ChangesThis commit is dependent on #95313. Full diff: https://github.com/llvm/llvm-project/pull/95469.diff 1 Files Affected:
diff --git a/mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir b/mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir
new file mode 100644
index 0000000000000..b7616f7b2ee3e
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-vulkan-runner %s \
+// RUN: --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
+// RUN: --entry-point-result=void | FileCheck %s
+
+// CHECK: [0, 2]
+// CHECK: [1, 3]
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+ gpu.module @kernels {
+ gpu.func @kernel_vector_deinterleave(%arg0 : memref<4xi32>, %arg1 : memref<2xi32>, %arg2 : memref<2xi32>)
+ kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+ %idx0 = arith.constant 0 : index
+ %idx1 = arith.constant 1 : index
+ %idx2 = arith.constant 2 : index
+ %idx3 = arith.constant 3 : index
+ %idx4 = arith.constant 4 : index
+
+ %src = arith.constant dense<[0, 0, 0, 0]> : vector<4xi32>
+
+ %val0 = memref.load %arg0[%idx0] : memref<4xi32>
+ %val1 = memref.load %arg0[%idx1] : memref<4xi32>
+ %val2 = memref.load %arg0[%idx2] : memref<4xi32>
+ %val3 = memref.load %arg0[%idx3] : memref<4xi32>
+
+ %src0 = vector.insert %val0, %src[%idx0] : i32 into vector<4xi32>
+ %src1 = vector.insert %val1, %src0[%idx1] : i32 into vector<4xi32>
+ %src2 = vector.insert %val2, %src1[%idx2] : i32 into vector<4xi32>
+ %src3 = vector.insert %val3, %src2[%idx3] : i32 into vector<4xi32>
+
+ %res0, %res1 = vector.deinterleave %src3 : vector<4xi32> -> vector<2xi32>
+
+ %res0_0 = vector.extract %res0[%idx0] : i32 from vector<2xi32>
+ %res0_1 = vector.extract %res0[%idx1] : i32 from vector<2xi32>
+ %res1_0 = vector.extract %res1[%idx0] : i32 from vector<2xi32>
+ %res1_1 = vector.extract %res1[%idx1] : i32 from vector<2xi32>
+
+ memref.store %res0_0, %arg1[%idx0]: memref<2xi32>
+ memref.store %res0_1, %arg1[%idx1]: memref<2xi32>
+ memref.store %res1_0, %arg2[%idx0]: memref<2xi32>
+ memref.store %res1_1, %arg2[%idx1]: memref<2xi32>
+
+ gpu.return
+ }
+ }
+
+ func.func @main() {
+ // Allocate 3 buffers.
+ %buf0 = memref.alloc() : memref<4xi32>
+ %buf1 = memref.alloc() : memref<2xi32>
+ %buf2 = memref.alloc() : memref<2xi32>
+
+ %idx0 = arith.constant 0 : index
+ %idx1 = arith.constant 1 : index
+ %idx4 = arith.constant 4 : index
+
+ // Initialize input buffer.
+ %buf0_vals = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+ vector.store %buf0_vals, %buf0[%idx0] : memref<4xi32>, vector<4xi32>
+
+ // Initialize output buffers.
+ %value0 = arith.constant 0 : i32
+ %buf3 = memref.cast %buf1 : memref<2xi32> to memref<?xi32>
+ %buf4 = memref.cast %buf2 : memref<2xi32> to memref<?xi32>
+ call @fillResource1DInt(%buf3, %value0) : (memref<?xi32>, i32) -> ()
+ call @fillResource1DInt(%buf4, %value0) : (memref<?xi32>, i32) -> ()
+
+ gpu.launch_func @kernels::@kernel_vector_deinterleave
+ blocks in (%idx4, %idx1, %idx1) threads in (%idx1, %idx1, %idx1)
+ args(%buf0 : memref<4xi32>, %buf1 : memref<2xi32>, %buf2 : memref<2xi32>)
+ %buf5 = memref.cast %buf3 : memref<?xi32> to memref<*xi32>
+ %buf6 = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
+ call @printMemrefI32(%buf5) : (memref<*xi32>) -> ()
+ call @printMemrefI32(%buf6) : (memref<*xi32>) -> ()
+ return
+ }
+ func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
+ func.func private @printMemrefI32(%ptr : memref<*xi32>)
+}
|
6c28f21
to
a8683f3
Compare
a8683f3
to
d2fa33a
Compare
The added test here is broken. Can you please revert or quickly fix forward? The test should be using |
…lvm#95469)" This reverts commit b1b7643.
This commit is dependent on #95313.