Skip to content

Commit cb6411d

Browse files
committed
[mlir][ArmSME] Move tests out of vector-ops-to-llvm.mlir
These tests basically were integration tests as unit tests, checking too many passes at once to be useful, and brittle to any changes. This patch moves these tests to the `vector -> ArmSME` conversion tests. The rest of the lowerings are already checked (e.g. in ArmSME to SCF tests).
1 parent 79d4d16 commit cb6411d

File tree

2 files changed

+216
-357
lines changed

2 files changed

+216
-357
lines changed

mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,3 +620,219 @@ func.func @vector_print_tile(%tile: vector<[4]x[4]xf32>)
620620
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
621621
// CHECK-NEXT: %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
622622
// CHECK-NEXT: vector.print %[[TILE_SLICE]] : vector<[4]xf32>
623+
624+
//===----------------------------------------------------------------------===//
625+
// vector.load
626+
//===----------------------------------------------------------------------===//
627+
628+
// -----
629+
630+
// CHECK-LABEL: @vector_load_i8_with_offset(
631+
// CHECK-SAME: %[[MEMREF:.*]]: memref<?x?xi8>)
632+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
633+
// CHECK: %[[C123:.*]] = arith.constant 123 : index
634+
// CHECK: arm_sme.tile_load %[[MEMREF]][%[[C123]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
635+
func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
636+
%c0 = arith.constant 0 : index
637+
%c123 = arith.constant 123 : index
638+
%tile = vector.load %arg0[%c123, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
639+
return %tile : vector<[16]x[16]xi8>
640+
}
641+
642+
// -----
643+
644+
// CHECK-LABEL: @vector_load_i8_from_rank_1_memref(
645+
// CHECK-SAME: %[[MEMREF:.*]]: memref<?xi8>)
646+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
647+
// CHECK: arm_sme.tile_load %[[MEMREF]][%[[C0]]] : memref<?xi8>, vector<[16]x[16]xi8>
648+
func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
649+
%c0 = arith.constant 0 : index
650+
%tile = vector.load %arg0[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
651+
return %tile : vector<[16]x[16]xi8>
652+
}
653+
654+
// -----
655+
656+
// CHECK-LABEL: @vector_load_i16(
657+
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
658+
func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
659+
%c0 = arith.constant 0 : index
660+
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
661+
return %tile : vector<[8]x[8]xi16>
662+
}
663+
664+
// -----
665+
666+
// CHECK-LABEL: @vector_load_i32(
667+
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
668+
func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
669+
%c0 = arith.constant 0 : index
670+
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
671+
return %tile : vector<[4]x[4]xi32>
672+
}
673+
674+
// -----
675+
676+
// CHECK-LABEL: @vector_load_i64(
677+
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
678+
func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
679+
%c0 = arith.constant 0 : index
680+
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
681+
return %tile : vector<[2]x[2]xi64>
682+
}
683+
684+
// -----
685+
686+
// CHECK-LABEL: @vector_load_f16(
687+
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
688+
func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
689+
%c0 = arith.constant 0 : index
690+
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
691+
return %tile : vector<[8]x[8]xf16>
692+
}
693+
694+
// -----
695+
696+
// CHECK-LABEL: @vector_load_bf16(
697+
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
698+
func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
699+
%c0 = arith.constant 0 : index
700+
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
701+
return %tile : vector<[8]x[8]xbf16>
702+
}
703+
704+
// -----
705+
706+
// CHECK-LABEL: @vector_load_f32(
707+
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
708+
func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
709+
%c0 = arith.constant 0 : index
710+
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
711+
return %tile : vector<[4]x[4]xf32>
712+
}
713+
714+
// -----
715+
716+
// CHECK-LABEL: @vector_load_f64(
717+
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
718+
func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
719+
%c0 = arith.constant 0 : index
720+
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
721+
return %tile : vector<[2]x[2]xf64>
722+
}
723+
724+
// -----
725+
726+
// CHECK-LABEL: @vector_load_i128(
727+
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
728+
func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
729+
%c0 = arith.constant 0 : index
730+
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
731+
return %tile : vector<[1]x[1]xi128>
732+
}
733+
734+
735+
//===----------------------------------------------------------------------===//
736+
// vector.store
737+
//===----------------------------------------------------------------------===//
738+
739+
// -----
740+
741+
// CHECK-LABEL: @vector_store_i8(
742+
// CHECK-SAME: %[[MEMREF:.*]]: memref<?x?xi8>) {
743+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
744+
// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
745+
// CHECK: arm_sme.tile_store %[[TILE]], %[[MEMREF]][%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
746+
func.func @vector_store_i8(%arg0 : memref<?x?xi8>) {
747+
%c0 = arith.constant 0 : index
748+
%tile = arm_sme.get_tile : vector<[16]x[16]xi8>
749+
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
750+
return
751+
}
752+
753+
// -----
754+
755+
// CHECK-LABEL: @vector_store_i16
756+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
757+
func.func @vector_store_i16(%arg0 : memref<?x?xi16>) {
758+
%c0 = arith.constant 0 : index
759+
%tile = arm_sme.get_tile : vector<[8]x[8]xi16>
760+
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
761+
return
762+
}
763+
764+
// -----
765+
766+
// CHECK-LABEL: @vector_store_i32
767+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
768+
func.func @vector_store_i32(%arg0 : memref<?x?xi32>) {
769+
%c0 = arith.constant 0 : index
770+
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>
771+
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
772+
return
773+
}
774+
775+
// -----
776+
777+
// CHECK-LABEL: @vector_store_i64
778+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
779+
func.func @vector_store_i64(%arg0 : memref<?x?xi64>) {
780+
%c0 = arith.constant 0 : index
781+
%tile = arm_sme.get_tile : vector<[2]x[2]xi64>
782+
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
783+
return
784+
}
785+
786+
// -----
787+
788+
// CHECK-LABEL: @vector_store_f16
789+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
790+
func.func @vector_store_f16(%arg0 : memref<?x?xf16>) {
791+
%c0 = arith.constant 0 : index
792+
%tile = arm_sme.get_tile : vector<[8]x[8]xf16>
793+
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
794+
return
795+
}
796+
797+
// -----
798+
799+
// CHECK-LABEL: @vector_store_bf16
800+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
801+
func.func @vector_store_bf16(%arg0 : memref<?x?xbf16>) {
802+
%c0 = arith.constant 0 : index
803+
%tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
804+
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
805+
return
806+
}
807+
// -----
808+
809+
// CHECK-LABEL: @vector_store_f32
810+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
811+
func.func @vector_store_f32(%arg0 : memref<?x?xf32>) {
812+
%c0 = arith.constant 0 : index
813+
%tile = arm_sme.get_tile : vector<[4]x[4]xf32>
814+
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
815+
return
816+
}
817+
818+
// -----
819+
820+
// CHECK-LABEL: @vector_store_f64
821+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
822+
func.func @vector_store_f64(%arg0 : memref<?x?xf64>) {
823+
%c0 = arith.constant 0 : index
824+
%tile = arm_sme.get_tile : vector<[2]x[2]xf64>
825+
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
826+
return
827+
}
828+
829+
// -----
830+
831+
// CHECK-LABEL: @vector_store_i128
832+
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
833+
func.func @vector_store_i128(%arg0 : memref<?x?xi128>) {
834+
%c0 = arith.constant 0 : index
835+
%tile = arm_sme.get_tile : vector<[1]x[1]xi128>
836+
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
837+
return
838+
}

0 commit comments

Comments
 (0)