[mlir][spirv] Add integration test for vector.deinterleave #95469

Merged: 1 commit merged into llvm:main on Jun 14, 2024

Conversation

angelz913
Contributor

This commit depends on #95313.

Member

@kuhar kuhar left a comment


LGTM % one suggestion

@angelz913 angelz913 force-pushed the vector-deinterleave-e2e branch from 871ae6b to 6c28f21 Compare June 13, 2024 22:06
@llvmbot llvmbot added the mlir label Jun 13, 2024
@llvmbot
Member

llvmbot commented Jun 13, 2024

@llvm/pr-subscribers-mlir

Author: Angel Zhang (angelz913)

Changes

This commit is dependent on #95313.


Full diff: https://github.com/llvm/llvm-project/pull/95469.diff

1 file affected:

  • (added) mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir (+81)
diff --git a/mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir b/mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir
new file mode 100644
index 0000000000000..b7616f7b2ee3e
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-vulkan-runner %s \
+// RUN:  --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
+// RUN:  --entry-point-result=void | FileCheck %s
+
+// CHECK: [0, 2]
+// CHECK: [1, 3]
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+  gpu.module @kernels {
+    gpu.func @kernel_vector_deinterleave(%arg0 : memref<4xi32>, %arg1 : memref<2xi32>, %arg2 : memref<2xi32>)
+      kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+      %idx0 = arith.constant 0 : index
+      %idx1 = arith.constant 1 : index
+      %idx2 = arith.constant 2 : index
+      %idx3 = arith.constant 3 : index
+      %idx4 = arith.constant 4 : index
+
+      %src = arith.constant dense<[0, 0, 0, 0]> : vector<4xi32>
+
+      %val0 = memref.load %arg0[%idx0] : memref<4xi32>
+      %val1 = memref.load %arg0[%idx1] : memref<4xi32>
+      %val2 = memref.load %arg0[%idx2] : memref<4xi32>
+      %val3 = memref.load %arg0[%idx3] : memref<4xi32>
+
+      %src0 = vector.insert %val0, %src[%idx0] : i32 into vector<4xi32>
+      %src1 = vector.insert %val1, %src0[%idx1] : i32 into vector<4xi32>
+      %src2 = vector.insert %val2, %src1[%idx2] : i32 into vector<4xi32>
+      %src3 = vector.insert %val3, %src2[%idx3] : i32 into vector<4xi32>
+
+      %res0, %res1 = vector.deinterleave %src3 : vector<4xi32> -> vector<2xi32>
+
+      %res0_0 = vector.extract %res0[%idx0] : i32 from vector<2xi32>
+      %res0_1 = vector.extract %res0[%idx1] : i32 from vector<2xi32>
+      %res1_0 = vector.extract %res1[%idx0] : i32 from vector<2xi32>
+      %res1_1 = vector.extract %res1[%idx1] : i32 from vector<2xi32>
+
+      memref.store %res0_0, %arg1[%idx0]: memref<2xi32>
+      memref.store %res0_1, %arg1[%idx1]: memref<2xi32>
+      memref.store %res1_0, %arg2[%idx0]: memref<2xi32>
+      memref.store %res1_1, %arg2[%idx1]: memref<2xi32>
+
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    // Allocate 3 buffers.
+    %buf0 = memref.alloc() : memref<4xi32>
+    %buf1 = memref.alloc() : memref<2xi32>
+    %buf2 = memref.alloc() : memref<2xi32>
+
+    %idx0 = arith.constant 0 : index
+    %idx1 = arith.constant 1 : index
+    %idx4 = arith.constant 4 : index
+
+    // Initialize input buffer.
+    %buf0_vals = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+    vector.store %buf0_vals, %buf0[%idx0] : memref<4xi32>, vector<4xi32>
+
+    // Initialize output buffers.
+    %value0 = arith.constant 0 : i32
+    %buf3 = memref.cast %buf1 : memref<2xi32> to memref<?xi32>
+    %buf4 = memref.cast %buf2 : memref<2xi32> to memref<?xi32>
+    call @fillResource1DInt(%buf3, %value0) : (memref<?xi32>, i32) -> ()
+    call @fillResource1DInt(%buf4, %value0) : (memref<?xi32>, i32) -> ()
+
+    gpu.launch_func @kernels::@kernel_vector_deinterleave
+        blocks in (%idx4, %idx1, %idx1) threads in (%idx1, %idx1, %idx1)
+        args(%buf0 : memref<4xi32>, %buf1 : memref<2xi32>, %buf2 : memref<2xi32>)
+    %buf5 = memref.cast %buf3 : memref<?xi32> to memref<*xi32>
+    %buf6 = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
+    call @printMemrefI32(%buf5) : (memref<*xi32>) -> ()
+    call @printMemrefI32(%buf6) : (memref<*xi32>) -> ()
+    return
+  }
+  func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
+  func.func private @printMemrefI32(%ptr : memref<*xi32>)
+}
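
As a semantics note for the CHECK lines above: vector.deinterleave returns the even-indexed elements of the source in its first result and the odd-indexed elements in its second, so the [0, 1, 2, 3] input produces [0, 2] and [1, 3]. A minimal sketch of the same split written with vector.shuffle is shown below, as an illustration of the expected element placement only; it is not claimed to be the lowering added by the dependency PR.

    // %src holds [0, 1, 2, 3].
    // Even lanes go to the first result, odd lanes to the second.
    %even = vector.shuffle %src, %src [0, 2] : vector<4xi32>, vector<4xi32>
    %odd  = vector.shuffle %src, %src [1, 3] : vector<4xi32>, vector<4xi32>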

@angelz913 angelz913 force-pushed the vector-deinterleave-e2e branch from 6c28f21 to a8683f3 Compare June 13, 2024 22:07
@angelz913 angelz913 force-pushed the vector-deinterleave-e2e branch from a8683f3 to d2fa33a Compare June 14, 2024 03:36
@kuhar kuhar merged commit b1b7643 into llvm:main Jun 14, 2024
5 of 6 checks passed
@angelz913 angelz913 deleted the vector-deinterleave-e2e branch June 14, 2024 14:17
@cota
Contributor

cota commented Jun 14, 2024

The added test here is broken. Can you please revert or quickly fix forward? The test should be using arith.constant's, not literals, as indices. Thanks!
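
For context, the two index styles being discussed look roughly like the sketch below (a minimal example using a hypothetical vector %v, not code from the merged revision). vector.extract and vector.insert accept either a literal (static) position or an SSA index value produced by arith.constant; per this comment, the arith.constant form is the one the test should use.

    // Literal (static) position, the style this comment asks to avoid here:
    %a = vector.extract %v[0] : i32 from vector<4xi32>

    // SSA position produced by arith.constant, the style this comment asks for:
    %idx0 = arith.constant 0 : index
    %b = vector.extract %v[%idx0] : i32 from vector<4xi32>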

angelz913 added a commit to angelz913/llvm-project that referenced this pull request Jun 14, 2024
kuhar pushed a commit that referenced this pull request Jun 14, 2024
…#95607)

Reverts #95469 because using literals instead of
`arith.constant` as indices broke the tests.
@angelz913 angelz913 restored the vector-deinterleave-e2e branch June 17, 2024 14:20
@angelz913 angelz913 deleted the vector-deinterleave-e2e branch June 17, 2024 14:21