Skip to content

[mlir][ArmSME][test] Prepare tests for tile allocation changes #91358

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented May 7, 2024

This patch:

  1. Removes some duplicate test cases
  2. Removes unnecessary uses of -convert-arm-sme-to-llvm
  3. Ensures tile values have uses via test.some_use()

1 and 2 will make these tests easier to update. 3 will be needed as ArmSME operations will be pure.

This patch:

 1. Removes some duplicate test cases
 2. Removes unnecessary uses of `-convert-arm-sme-to-llvm`
 3. Ensures tile values have uses via `test.some_use()`

1 and 2 will make these tests easier to update. 3 will be needed as
ArmSME operations will be pure.
@llvmbot
Copy link
Member

llvmbot commented May 7, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

Changes

This patch:

  1. Removes some duplicate test cases
  2. Removes unnecessary uses of -convert-arm-sme-to-llvm
  3. Ensures tile values have uses via test.some_use()

1 and 2 will make these tests easier to update. 3 will be needed as ArmSME operations will be pure.


Patch is 38.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91358.diff

7 Files Affected:

  • (modified) mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir (+22-2)
  • (modified) mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir (+1-1)
  • (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+6)
  • (renamed) mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir (+166-131)
  • (modified) mlir/test/Dialect/ArmSME/enable-arm-za.mlir (+6-14)
  • (modified) mlir/test/Dialect/ArmSME/outer-product-fusion.mlir (+4-3)
  • (modified) mlir/test/Dialect/ArmSME/tile-zero-masks.mlir (+29-14)
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 81087cc02099..f48046a8d799 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -25,6 +25,7 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
+  "test.some_use" (%tile_update) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
@@ -36,6 +37,7 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
@@ -47,6 +49,7 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -58,6 +61,7 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
+  "test.some_use" (%tile_update) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
@@ -69,6 +73,7 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vec
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
+  "test.some_use" (%tile_update) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
@@ -80,6 +85,7 @@ func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
@@ -91,6 +97,7 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vec
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
   return
 }
 
@@ -102,6 +109,7 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
@@ -113,6 +121,7 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
+  "test.some_use" (%tile_update) : (vector<[2]x[2]xf64>) -> ()
   return
 }
 
@@ -124,6 +133,7 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
+  "test.some_use" (%tile_update) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
@@ -135,6 +145,7 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
@@ -146,6 +157,7 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -157,6 +169,7 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
+  "test.some_use" (%tile_update) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
@@ -168,6 +181,7 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vec
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
+  "test.some_use" (%tile_update) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
@@ -179,6 +193,7 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
@@ -190,6 +205,7 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vec
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
   return
 }
 
@@ -201,6 +217,7 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
@@ -212,6 +229,7 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
+  "test.some_use" (%tile_update) : (vector<[2]x[2]xf64>) -> ()
   return
 }
 
@@ -441,7 +459,8 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile_slice_index : index, %mask : v
 func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile_slice_index : index) -> () {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
-  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+  %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -452,7 +471,8 @@ func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>,
 func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile_slice_index : index) -> () {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
-  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+  %tile_update =  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
   return
 }
 
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index 59665c471921..15767ff1dec3 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -9,6 +9,6 @@ func.func @arm_sme_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs :
   // expected-error@+2 {{failed to legalize operation 'arm_sme.outerproduct'}}
   // expected-error@+1 {{unsupported type}}
   %0 = arm_sme.outerproduct %lhs, %rhs  acc(%acc) : vector<[16]xi8>, vector<[16]xi8>
-  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+  "test.some_use"(%0) : (vector<[16]x[16]xi8>) -> ()
 }
 
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 6c393bc38af9..a2f2beff78c4 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -20,6 +20,7 @@
 func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -30,6 +31,7 @@ func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
 func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -60,6 +62,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
   %pad = arith.constant 0 : i32
   %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -94,6 +97,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32
   %c3 = arith.constant 3 : index
   %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -104,6 +108,7 @@ func.func @arm_sme_tile_load_zero_pad__unsupported_mask_op(%src : memref<?x?xi32
   %pad = arith.constant 0 : i32
   // expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -113,6 +118,7 @@ func.func @arm_sme_tile_load_nonzero_pad__unsupported_mask_op(%src : memref<?x?x
   %c0 = arith.constant 0 : index
   // expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
similarity index 52%
rename from mlir/test/Dialect/ArmSME/tile-allocation.mlir
rename to mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
index 9c368dd4fa23..e144bac970a7 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
@@ -1,9 +1,10 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file | FileCheck %s
 
 // -----
 
+// Note: Tile IDs >= 16 are in-memory tile IDs (i.e. spills).
+
 // CHECK-LABEL: mixed_tiles
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65534 : i32}
 func.func @mixed_tiles() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
   // CHECK-NEXT: tile_id = 0
@@ -18,76 +19,61 @@ func.func @mixed_tiles() {
   // CHECK-NEXT: tile_id = 7
   %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // ZA15.Q is still free.
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_b
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_b() {
   // CHECK-NEXT: tile_id = 0
   %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
-  return
-}
-
-// -----
-
-func.func @za_b__out_of_tiles() {
-  %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[16]x[16]xi8>
+  "test.some_use"(%za0_b) : (vector<[16]x[16]xi8>) -> ()
+  "test.some_use"(%next_tile) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
 // -----
 
+// CHECK-LABEL: za_b_overlapping_za_q
 func.func @za_b_overlapping_za_q() {
+  // CHECK-NEXT: tile_id = 0
   %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: za0_h
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 43690 : i32}
-func.func @za0_h() {
-  // CHECK-NEXT: tile_id = 0
-  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+  "test.some_use"(%za0_b) : (vector<[16]x[16]xi8>) -> ()
+  "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_h
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h() {
   // CHECK-NEXT: tile_id = 0
   %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // CHECK-NEXT: tile_id = 1
   %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: za_h__out_of_tiles
-func.func @za_h__out_of_tiles() {
-  // CHECK-NEXT: tile_id = 0
-  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
-  // CHECK-NEXT: tile_id = 1
-  %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[8]x[8]xi16>
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%next_tile) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_h_overlapping_za_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h_overlapping_za_s() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
   // CHECK-NEXT: tile_id = 0
@@ -98,13 +84,15 @@ func.func @za_h_overlapping_za_s() {
   // ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
   // CHECK-NEXT: tile_id = 3
   %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%za3_s) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_h_overlapping_za_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h_overlapping_za_d() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
   // CHECK-NEXT: tile_id = 0
@@ -121,40 +109,55 @@ func.func @za_h_overlapping_za_d() {
   // ZA7.Q, ZA15.Q
   // CHECK-NEXT: tile_id = 7
   %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za5_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za7_d) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
 // -----
 
+// CHECK-LABEL: za_h_overlapping_za_q
 func.func @za_h_overlapping_za_q() {
+  // CHECK-NEXT: tile_id = 0
   %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
-  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // CHECK-NEXT: tile_id = 1
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 3
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 5
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 7
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 9
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 11
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 13
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 15
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: za0_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 34952 : i32}
-func.func @za0_s() {
-  // CHECK-NEXT: tile_id = 0
-  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_s() {
   // CHECK-NEXT: tile_id = 0
   %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
@@ -164,25 +167,20 @@ func.func @za_s() {
   %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // CHECK-NEXT: tile_id = 3
   %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  return
-}
-
-// -----
-
-func.func @za_s__out_of_tiles() {
-  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm...
[truncated]

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@MacDue MacDue merged commit 1a49810 into llvm:main May 8, 2024
@MacDue MacDue deleted the test_refactor branch May 8, 2024 09:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants