Skip to content

Commit 76c9b90

Browse files
committed
[mlir][Vector] Fix vector.extract lowering to llvm for 0-d vectors
1 parent 4028bb1 commit 76c9b90

File tree

2 files changed

+84
-38
lines changed

2 files changed

+84
-38
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,43 +1096,50 @@ class VectorExtractOpConversion
10961096
SmallVector<OpFoldResult> positionVec = getMixedValues(
10971097
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
10981098

1099-
// Extract entire vector. Should be handled by folder, but just to be safe.
1100-
ArrayRef<OpFoldResult> position(positionVec);
1101-
if (position.empty()) {
1102-
rewriter.replaceOp(extractOp, adaptor.getVector());
1103-
return success();
1104-
}
1105-
1106-
// One-shot extraction of vector from array (only requires extractvalue).
1107-
// Except for extracting 1-element vectors.
1108-
if (isa<VectorType>(resultType) &&
1109-
position.size() !=
1110-
static_cast<size_t>(extractOp.getSourceVectorType().getRank())) {
1111-
if (extractOp.hasDynamicPosition())
1112-
return failure();
1113-
1114-
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
1115-
loc, adaptor.getVector(), getAsIntegers(position));
1116-
rewriter.replaceOp(extractOp, extracted);
1117-
return success();
1118-
}
1099+
// Determine if we need to extract a scalar as the result. We extract
1100+
// a scalar if the extract is full rank i.e. the number of indices is equal
1101+
// to source vector rank.
1102+
bool isScalarExtract =
1103+
positionVec.size() == extractOp.getSourceVectorType().getRank();
1104+
// Determine if we need to extract a slice out of the original vector. We
1105+
// always need to extract a slice if the input rank >= 2.
1106+
bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
11191107

1120-
// Potential extraction of 1-D vector from array.
11211108
Value extracted = adaptor.getVector();
1122-
if (position.size() > 1) {
1123-
if (extractOp.hasDynamicPosition())
1109+
if (isSlicingExtract) {
1110+
ArrayRef<OpFoldResult> position(positionVec);
1111+
if (isScalarExtract) {
1112+
// If we are extracting a scalar from the returned slice, we need to
1113+
// extract a N-1 D slice.
1114+
position = position.drop_back();
1115+
}
1116+
// llvm.extractvalue does not support dynamic dimensions.
1117+
if (!llvm::all_of(position,
1118+
[](OpFoldResult x) { return isa<Attribute>(x); })) {
11241119
return failure();
1120+
}
1121+
extracted = rewriter.create<LLVM::ExtractValueOp>(
1122+
loc, extracted, getAsIntegers(position));
1123+
}
11251124

1126-
SmallVector<int64_t> nMinusOnePosition =
1127-
getAsIntegers(position.drop_back());
1128-
extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
1129-
nMinusOnePosition);
1125+
if (isScalarExtract) {
1126+
Value position;
1127+
if (positionVec.empty()) {
1128+
// A scalar extract with no position is a 0-D vector extract. The LLVM
1129+
// type converter converts 0-D vectors to 1-D vectors, so we need to add
1130+
// a constant position.
1131+
auto idxType = rewriter.getIndexType();
1132+
position = rewriter.create<LLVM::ConstantOp>(
1133+
loc, typeConverter->convertType(idxType),
1134+
rewriter.getIntegerAttr(idxType, 0));
1135+
} else {
1136+
position = getAsLLVMValue(rewriter, loc, positionVec.back());
1137+
}
1138+
extracted =
1139+
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, position);
11301140
}
11311141

1132-
Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
1133-
// Remaining extraction of element from 1-D LLVM vector.
1134-
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
1135-
lastPosition);
1142+
rewriter.replaceOp(extractOp, extracted);
11361143
return success();
11371144
}
11381145
};

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,26 +1258,65 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16
12581258

12591259
// -----
12601260

1261-
func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
1261+
func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
12621262
%0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x16xf32>
12631263
return %0 : f32
12641264
}
12651265

1266-
// Multi-dim vectors are not supported but this test shouldn't crash.
1266+
// Multi-dim vectors are supported if the inner most dimension is dynamic.
12671267

1268-
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx(
1269-
// CHECK: vector.extract
1268+
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(
1269+
// CHECK: llvm.extractvalue
1270+
// CHECK: llvm.extractelement
12701271

1271-
func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
1272+
func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
12721273
%0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x[16]xf32>
12731274
return %0 : f32
12741275
}
12751276

1276-
// Multi-dim vectors are not supported but this test shouldn't crash.
1277+
// Multi-dim vectors are supported if the inner most dimension is dynamic.
1278+
1279+
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(
1280+
// CHECK: llvm.extractvalue
1281+
// CHECK: llvm.extractelement
1282+
1283+
// -----
12771284

1278-
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(
1285+
func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
1286+
%0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x16xf32>
1287+
return %0 : f32
1288+
}
1289+
1290+
// Multi-dim vectors are supported if the inner most dimension is dynamic.
1291+
1292+
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx(
12791293
// CHECK: vector.extract
12801294

1295+
func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
1296+
%0 = vector.extract %arg0[%arg1, 0]: f32 from vector<1x[16]xf32>
1297+
return %0 : f32
1298+
}
1299+
1300+
// Multi-dim vectors with outer dimension as dynamic are not supported, but it
1301+
// shouldn't crash.
1302+
1303+
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(
1304+
// CHECK: vector.extract
1305+
1306+
// -----
1307+
1308+
func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
1309+
%0 = vector.extract %arg0[]: index from vector<index>
1310+
return %0 : index
1311+
}
1312+
// CHECK-LABEL: @extract_scalar_from_vec_0d_index(
1313+
// CHECK-SAME: %[[A:.*]]: vector<index>)
1314+
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<index> to vector<1xi64>
1315+
// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : index) : i64
1316+
// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<1xi64>
1317+
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
1318+
// CHECK: return %[[T3]] : index
1319+
12811320
// -----
12821321

12831322
func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> {

0 commit comments

Comments
 (0)