Skip to content

Commit f4a3fb4

Browse files
committed
Address comments
1 parent 824be43 commit f4a3fb4

File tree

2 files changed

+39
-37
lines changed

2 files changed

+39
-37
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,53 +1096,55 @@ class VectorExtractOpConversion
10961096
SmallVector<OpFoldResult> positionVec = getMixedValues(
10971097
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
10981098

1099-
// The LLVM lowering models multi dimension vectors as stacked 1-d vectors.
1100-
// The stacking is modeled using arrays. We do this conversion from a
1101-
// N-d vector extract to stacked 1-d vector extract in two steps:
1102-
// - Extract a 1-d vector or a stack of 1-d vectors (llvm.extractvalue)
1103-
// - Extract a scalar out of the 1-d vector if needed (llvm.extractelement)
1104-
1105-
// Determine if we need to extract a slice out of the original vector. We
1106-
// always need to extract a slice if the input rank >= 2.
1107-
bool isSlicingExtract = extractOp.getSourceVectorType().getRank() >= 2;
1099+
// The Vector -> LLVM lowering models N-D vectors as nested aggregates of
1100+
// 1-d vectors. This nesting is modeled using arrays. We do this conversion
1101+
// from a N-d vector extract to a nested aggregate vector extract in two
1102+
// steps:
1103+
// - Extract a member from the nested aggregate. The result can be
1104+
// a lower rank nested aggregate or a vector (1-D). This is done using
1105+
// `llvm.extractvalue`.
1106+
// - Extract a scalar out of the vector if needed. This is done using
1107+
// `llvm.extractelement`.
1108+
1109+
// Determine if we need to extract a member out of the aggregate. We
1110+
// always need to extract a member if the input rank >= 2.
1111+
bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
11081112
// Determine if we need to extract a scalar as the result. We extract
1109-
// a scalar if the extract is full rank i.e. the number of indices is equal
1110-
// to source vector rank.
1111-
bool isScalarExtract = static_cast<int64_t>(positionVec.size()) ==
1112-
extractOp.getSourceVectorType().getRank();
1113+
// a scalar if the extract is full rank, i.e., the number of indices is
1114+
// equal to source vector rank.
1115+
bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
1116+
extractOp.getSourceVectorType().getRank();
1117+
1118+
// Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
1119+
// need to add a position for this change.
1120+
if (extractOp.getSourceVectorType().getRank() == 0) {
1121+
auto idxType = rewriter.getIndexType();
1122+
Value position = rewriter.create<LLVM::ConstantOp>(
1123+
loc, typeConverter->convertType(idxType),
1124+
rewriter.getIntegerAttr(idxType, 0));
1125+
positionVec.push_back(position);
1126+
}
11131127

11141128
Value extracted = adaptor.getVector();
1115-
if (isSlicingExtract) {
1129+
if (extractsScalar) {
11161130
ArrayRef<OpFoldResult> position(positionVec);
1117-
if (isScalarExtract) {
1118-
// If we are extracting a scalar from the returned slice, we need to
1119-
// extract a N-1 D slice.
1131+
if (extractsAggregate) {
1132+
// If we are extracting a scalar from the extracted member, we drop
1133+
// the last index, which will be used to extract the scalar out of the
1134+
// vector.
11201135
position = position.drop_back();
11211136
}
11221137
// llvm.extractvalue does not support dynamic dimensions.
1123-
if (!llvm::all_of(position,
1124-
[](OpFoldResult x) { return isa<Attribute>(x); })) {
1138+
if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
11251139
return failure();
11261140
}
11271141
extracted = rewriter.create<LLVM::ExtractValueOp>(
11281142
loc, extracted, getAsIntegers(position));
11291143
}
11301144

1131-
if (isScalarExtract) {
1132-
Value position;
1133-
if (positionVec.empty()) {
1134-
// A scalar extract with no position is a 0-D vector extract. The LLVM
1135-
// type converter converts 0-D vectors to 1-D vectors, so we need to add
1136-
// a constant position.
1137-
auto idxType = rewriter.getIndexType();
1138-
position = rewriter.create<LLVM::ConstantOp>(
1139-
loc, typeConverter->convertType(idxType),
1140-
rewriter.getIntegerAttr(idxType, 0));
1141-
} else {
1142-
position = getAsLLVMValue(rewriter, loc, positionVec.back());
1143-
}
1144-
extracted =
1145-
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, position);
1145+
if (extractsScalar) {
1146+
extracted = rewriter.create<LLVM::ExtractElementOp>(
1147+
loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back()));
11461148
}
11471149

11481150
rewriter.replaceOp(extractOp, extracted);

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,7 +1295,7 @@ func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(%arg0: vector<1x16xf
12951295
return %0 : f32
12961296
}
12971297

1298-
// Multi-dim vectors are supported if the inner most dimension is dynamic.
1298+
// Multi-dim vectors are supported if the innermost index is dynamic.
12991299

13001300
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx(
13011301
// CHECK: llvm.extractvalue
@@ -1306,7 +1306,7 @@ func.func @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(%arg0: vect
13061306
return %0 : f32
13071307
}
13081308

1309-
// Multi-dim vectors are supported if the inner most dimension is dynamic.
1309+
// Multi-dim vectors are supported if the innermost index is dynamic.
13101310

13111311
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_inner_dynamic_idx_scalable(
13121312
// CHECK: llvm.extractvalue
@@ -1329,7 +1329,7 @@ func.func @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(%arg0: vect
13291329
return %0 : f32
13301330
}
13311331

1332-
// Multi-dim vectors with outer dimension as dynamic are not supported, but it
1332+
// Multi-dim vectors with outer indices as dynamic are not supported, but it
13331333
// shouldn't crash.
13341334

13351335
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_outer_dynamic_idx_scalable(

0 commit comments

Comments
 (0)