Skip to content

Commit 1e7d6d3

Browse files
authored
[mlir][vector] Propagate scalability to gather/scatter ptrs vector (llvm#97584)
In convert-vector-to-llvm the first operand (vector of pointers holding all memory addresses to read) to the masked.gather (and scatter) intrinsic has a fixed vector type. This may result in intrinsics where the scalable flag has been dropped: ``` %0 = llvm.intr.masked.gather %1, %2, %3 {alignment = 4 : i32} : (!llvm.vec<4 x ptr>, vector<[4]xi1>, vector<[4]xi32>) -> vector<[4]xi32> ``` Fortunately the operand is overloaded on the result type so we end up with the correct IR when lowering to LLVM, but this is still incorrect. This patch fixes it by propagating scalability.
1 parent efc5a6a commit 1e7d6d3

File tree

5 files changed

+91
-10
lines changed

5 files changed

+91
-10
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,8 @@ def LLVM_masked_gather : LLVM_OneResultIntrOp<"masked.gather"> {
900900
$_resultType, $ptrs, $mask, $pass_thru, $_int_attr($alignment));
901901
}];
902902
list<int> llvmArgIndices = [0, 2, 3, 1];
903+
904+
let hasVerifier = 1;
903905
}
904906

905907
/// Create a call to Masked Scatter intrinsic.
@@ -919,6 +921,8 @@ def LLVM_masked_scatter : LLVM_ZeroResultIntrOp<"masked.scatter"> {
919921
$value, $ptrs, $mask, $_int_attr($alignment));
920922
}];
921923
list<int> llvmArgIndices = [0, 1, 3, 2];
924+
925+
let hasVerifier = 1;
922926
}
923927

924928
/// Create a call to Masked Expand Load intrinsic.

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,14 @@ static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
102102
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
103103
const LLVMTypeConverter &typeConverter,
104104
MemRefType memRefType, Value llvmMemref, Value base,
105-
Value index, uint64_t vLen) {
105+
Value index, VectorType vectorType) {
106106
assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
107107
"unsupported memref type");
108+
assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
108109
auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
109-
auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
110+
auto ptrsType =
111+
LLVM::getVectorType(pType, vectorType.getDimSize(0),
112+
/*isScalable=*/vectorType.getScalableDims()[0]);
110113
return rewriter.create<LLVM::GEPOp>(
111114
loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
112115
base, index);
@@ -288,9 +291,9 @@ class VectorGatherOpConversion
288291
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
289292
auto vType = gather.getVectorType();
290293
// Resolve address.
291-
Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
292-
memRefType, base, ptr, adaptor.getIndexVec(),
293-
/*vLen=*/vType.getDimSize(0));
294+
Value ptrs =
295+
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
296+
base, ptr, adaptor.getIndexVec(), vType);
294297
// Replace with the gather intrinsic.
295298
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
296299
gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
@@ -305,8 +308,7 @@ class VectorGatherOpConversion
305308
// Resolve address.
306309
Value ptrs = getIndexedPtrs(
307310
rewriter, loc, typeConverter, memRefType, base, ptr,
308-
/*index=*/vectorOperands[0],
309-
LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
311+
/*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
310312
// Create the gather intrinsic.
311313
return rewriter.create<LLVM::masked_gather>(
312314
loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
@@ -343,9 +345,9 @@ class VectorScatterOpConversion
343345
VectorType vType = scatter.getVectorType();
344346
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
345347
adaptor.getIndices(), rewriter);
346-
Value ptrs = getIndexedPtrs(
347-
rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
348-
ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
348+
Value ptrs =
349+
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
350+
adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
349351

350352
// Replace with the scatter intrinsic.
351353
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3076,6 +3076,40 @@ void InlineAsmOp::getEffects(
30763076
}
30773077
}
30783078

3079+
//===----------------------------------------------------------------------===//
3080+
// masked_gather (intrinsic)
3081+
//===----------------------------------------------------------------------===//
3082+
3083+
LogicalResult LLVM::masked_gather::verify() {
3084+
auto ptrsVectorType = getPtrs().getType();
3085+
Type expectedPtrsVectorType =
3086+
LLVM::getVectorType(extractVectorElementType(ptrsVectorType),
3087+
LLVM::getVectorNumElements(getRes().getType()));
3088+
// Vector of pointers type should match result vector type, other than the
3089+
// element type.
3090+
if (ptrsVectorType != expectedPtrsVectorType)
3091+
return emitOpError("expected operand #1 type to be ")
3092+
<< expectedPtrsVectorType;
3093+
return success();
3094+
}
3095+
3096+
//===----------------------------------------------------------------------===//
3097+
// masked_scatter (intrinsic)
3098+
//===----------------------------------------------------------------------===//
3099+
3100+
LogicalResult LLVM::masked_scatter::verify() {
3101+
auto ptrsVectorType = getPtrs().getType();
3102+
Type expectedPtrsVectorType =
3103+
LLVM::getVectorType(extractVectorElementType(ptrsVectorType),
3104+
LLVM::getVectorNumElements(getValue().getType()));
3105+
// Vector of pointers type should match value vector type, other than the
3106+
// element type.
3107+
if (ptrsVectorType != expectedPtrsVectorType)
3108+
return emitOpError("expected operand #2 type to be ")
3109+
<< expectedPtrsVectorType;
3110+
return success();
3111+
}
3112+
30793113
//===----------------------------------------------------------------------===//
30803114
// LLVMDialect initialization, type parsing, and registration.
30813115
//===----------------------------------------------------------------------===//

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2248,6 +2248,19 @@ func.func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3
22482248

22492249
// -----
22502250

2251+
func.func @gather_op_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) -> vector<[3]xf32> {
2252+
%0 = arith.constant 0: index
2253+
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<[3]xi32>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
2254+
return %1 : vector<[3]xf32>
2255+
}
2256+
2257+
// CHECK-LABEL: func @gather_op_scalable
2258+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
2259+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
2260+
// CHECK: return %[[G]] : vector<[3]xf32>
2261+
2262+
// -----
2263+
22512264
func.func @gather_op_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
22522265
%0 = arith.constant 0: index
22532266
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32, 1>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
@@ -2351,6 +2364,18 @@ func.func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<
23512364

23522365
// -----
23532366

2367+
func.func @scatter_op_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) {
2368+
%0 = arith.constant 0: index
2369+
vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<[3]xi32>, vector<[3]xi1>, vector<[3]xf32>
2370+
return
2371+
}
2372+
2373+
// CHECK-LABEL: func @scatter_op_scalable
2374+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
2375+
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
2376+
2377+
// -----
2378+
23542379
func.func @scatter_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) {
23552380
%0 = arith.constant 0: index
23562381
vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xindex>, vector<3xindex>, vector<3xi1>, vector<3xindex>

mlir/test/Target/LLVMIR/llvmir-invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,14 @@ llvm.func @masked_gather_intr_wrong_type(%ptrs : vector<7xf32>, %mask : vector<7
261261

262262
// -----
263263

264+
llvm.func @masked_gather_intr_wrong_type_scalable(%ptrs : !llvm.vec<7 x ptr>, %mask : vector<[7]xi1>) -> vector<[7]xf32> {
265+
// expected-error @below{{expected operand #1 type to be '!llvm.vec<? x 7 x ptr>'}}
266+
%0 = llvm.intr.masked.gather %ptrs, %mask { alignment = 1: i32} : (!llvm.vec<7 x ptr>, vector<[7]xi1>) -> vector<[7]xf32>
267+
llvm.return %0 : vector<[7]xf32>
268+
}
269+
270+
// -----
271+
264272
llvm.func @masked_scatter_intr_wrong_type(%vec : f32, %ptrs : !llvm.vec<7xptr>, %mask : vector<7xi1>) {
265273
// expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
266274
llvm.intr.masked.scatter %vec, %ptrs, %mask { alignment = 1: i32} : f32, vector<7xi1> into !llvm.vec<7xptr>
@@ -269,6 +277,14 @@ llvm.func @masked_scatter_intr_wrong_type(%vec : f32, %ptrs : !llvm.vec<7xptr>,
269277

270278
// -----
271279

280+
llvm.func @masked_scatter_intr_wrong_type_scalable(%vec : vector<[7]xf32>, %ptrs : !llvm.vec<7xptr>, %mask : vector<[7]xi1>) {
281+
// expected-error @below{{expected operand #2 type to be '!llvm.vec<? x 7 x ptr>'}}
282+
llvm.intr.masked.scatter %vec, %ptrs, %mask { alignment = 1: i32} : vector<[7]xf32>, vector<[7]xi1> into !llvm.vec<7xptr>
283+
llvm.return
284+
}
285+
286+
// -----
287+
272288
llvm.func @stepvector_intr_wrong_type() -> vector<7xf32> {
273289
// expected-error @below{{op result #0 must be LLVM dialect-compatible vector of signless integer, but got 'vector<7xf32>'}}
274290
%0 = llvm.intr.experimental.stepvector : vector<7xf32>

0 commit comments

Comments
 (0)