Skip to content

Commit e0addfc

Browse files
author
Lily Orth-Smith
committed
Comments
1 parent bcbab88 commit e0addfc

File tree

2 files changed

+18
-54
lines changed

2 files changed

+18
-54
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
9797
return success();
9898
}
9999

100+
// Helper to resolve the alignment for vector load/store, gather and scatter
101+
// ops. If useVectorAlignment is true, get the preferred alignment for the
102+
// vector type in the operation. This option is used for hardware backends with
103+
// vectorization. Otherwise, use the preferred alignment of the element type of
104+
// the memref. Note that if you choose to use vector alignment, the shape of the
105+
// vector type must be resolved before the ConvertVectorToLLVM pass is run.
100106
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
101107
VectorType vectorType,
102108
MemRefType memrefType, unsigned &align,

mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir

Lines changed: 12 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,9 @@ func.func @load(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8
1111
return %0 : vector<8xf32>
1212
}
1313

14-
// VEC-ALIGN-LABEL: func @load
15-
// VEC-ALIGN: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
16-
// VEC-ALIGN: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]] : i64
17-
// VEC-ALIGN: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
18-
// VEC-ALIGN: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
19-
// VEC-ALIGN: llvm.load %[[GEP]] {alignment = 32 : i64} : !llvm.ptr -> vector<8xf32>
14+
// ALL-LABEL: func @load
2015

21-
// MEMREF-ALIGN-LABEL: func @load
22-
// MEMREF-ALIGN: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
23-
// MEMREF-ALIGN: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]] : i64
24-
// MEMREF-ALIGN: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
25-
// MEMREF-ALIGN: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
16+
// VEC-ALIGN: llvm.load %[[GEP]] {alignment = 32 : i64} : !llvm.ptr -> vector<8xf32>
2617
// MEMREF-ALIGN: llvm.load %[[GEP]] {alignment = 4 : i64} : !llvm.ptr -> vector<8xf32>
2718

2819
// -----
@@ -37,18 +28,9 @@ func.func @store(%base : memref<200x100xf32>, %i : index, %j : index) {
3728
return
3829
}
3930

40-
// VEC-ALIGN-LABEL: func @store
41-
// VEC-ALIGN: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
42-
// VEC-ALIGN: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]] : i64
43-
// VEC-ALIGN: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
44-
// VEC-ALIGN: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
45-
// VEC-ALIGN: llvm.store %{{.*}}, %[[GEP]] {alignment = 16 : i64} : vector<4xf32>, !llvm.ptr
31+
// ALL-LABEL: func @store
4632

47-
// MEMREF-ALIGN-LABEL: func @store
48-
// MEMREF-ALIGN: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
49-
// MEMREF-ALIGN: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]] : i64
50-
// MEMREF-ALIGN: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
51-
// MEMREF-ALIGN: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
33+
// VEC-ALIGN: llvm.store %{{.*}}, %[[GEP]] {alignment = 16 : i64} : vector<4xf32>, !llvm.ptr
5234
// MEMREF-ALIGN: llvm.store %{{.*}}, %[[GEP]] {alignment = 4 : i64} : vector<4xf32>, !llvm.ptr
5335

5436
// -----
@@ -63,19 +45,10 @@ func.func @masked_load(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: ve
6345
return %0 : vector<16xf32>
6446
}
6547

66-
// VEC-ALIGN-LABEL: func @masked_load
67-
// VEC-ALIGN: %[[CO:.*]] = arith.constant 0 : index
68-
// VEC-ALIGN: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
69-
// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
70-
// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 64 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
71-
// VEC-ALIGN: return %[[L]] : vector<16xf32>
48+
// ALL-LABEL: func @masked_load
7249

73-
// MEMREF-ALIGN-LABEL: func @masked_load
74-
// MEMREF-ALIGN: %[[CO:.*]] = arith.constant 0 : index
75-
// MEMREF-ALIGN: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
76-
// MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
50+
// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 64 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
7751
// MEMREF-ALIGN: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
78-
// MEMREF-ALIGN: return %[[L]] : vector<16xf32>
7952

8053
// -----
8154

@@ -89,16 +62,9 @@ func.func @masked_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: v
8962
return
9063
}
9164

92-
// VEC-ALIGN-LABEL: func @masked_store
93-
// VEC-ALIGN: %[[CO:.*]] = arith.constant 0 : index
94-
// VEC-ALIGN: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
95-
// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
96-
// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 64 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
65+
// ALL-LABEL: func @masked_store
9766

98-
// MEMREF-ALIGN-LABEL: func @masked_store
99-
// MEMREF-ALIGN: %[[CO:.*]] = arith.constant 0 : index
100-
// MEMREF-ALIGN: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
101-
// MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
67+
// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 64 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
10268
// MEMREF-ALIGN: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
10369

10470
// -----
@@ -113,12 +79,9 @@ func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3x
11379
return
11480
}
11581

116-
// VEC-ALIGN-LABEL: func @scatter
117-
// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
118-
// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 16 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
82+
// ALL-LABEL: func @scatter
11983

120-
// MEMREF-ALIGN-LABEL: func @scatter
121-
// MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
84+
// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 16 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
12285
// MEMREF-ALIGN: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
12386

12487
// -----
@@ -133,12 +96,7 @@ func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi
13396
return %1 : vector<3xf32>
13497
}
13598

136-
// VEC-ALIGN-LABEL: func @gather
137-
// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
138-
// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
139-
// VEC-ALIGN: return %[[G]] : vector<3xf32>
99+
// ALL-LABEL: func @gather
140100

141-
// MEMREF-ALIGN-LABEL: func @gather
142-
// MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
101+
// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
143102
// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
144-
// MEMREF-ALIGN: return %[[G]] : vector<3xf32>

0 commit comments

Comments
 (0)