@@ -1096,53 +1096,55 @@ class VectorExtractOpConversion
1096
1096
SmallVector<OpFoldResult> positionVec = getMixedValues (
1097
1097
adaptor.getStaticPosition (), adaptor.getDynamicPosition (), rewriter);
1098
1098
1099
- // The LLVM lowering models multi dimension vectors as stacked 1-d vectors.
1100
- // The stacking is modeled using arrays. We do this conversion from a
1101
- // N-d vector extract to stacked 1-d vector extract in two steps:
1102
- // - Extract a 1-d vector or a stack of 1-d vectors (llvm.extractvalue)
1103
- // - Extract a scalar out of the 1-d vector if needed (llvm.extractelement)
1104
-
1105
- // Determine if we need to extract a slice out of the original vector. We
1106
- // always need to extract a slice if the input rank >= 2.
1107
- bool isSlicingExtract = extractOp.getSourceVectorType ().getRank () >= 2 ;
1099
+ // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
1100
+ // 1-d vectors. This nesting is modeled using arrays. We do this conversion
1101
+ // from a N-d vector extract to a nested aggregate vector extract in two
1102
+ // steps:
1103
+ // - Extract a member from the nested aggregate. The result can be
1104
+ // a lower rank nested aggregate or a vector (1-D). This is done using
1105
+ // `llvm.extractvalue`.
1106
+ // - Extract a scalar out of the vector if needed. This is done using
1107
+ // `llvm.extractelement`.
1108
+
1109
+ // Determine if we need to extract a member out of the aggregate. We
1110
+ // always need to extract a member if the input rank >= 2.
1111
+ bool extractsAggregate = extractOp.getSourceVectorType ().getRank () >= 2 ;
1108
1112
// Determine if we need to extract a scalar as the result. We extract
1109
- // a scalar if the extract is full rank i.e. the number of indices is equal
1110
- // to source vector rank.
1111
- bool isScalarExtract = static_cast <int64_t >(positionVec.size ()) ==
1112
- extractOp.getSourceVectorType ().getRank ();
1113
+ // a scalar if the extract is full rank, i.e., the number of indices is
1114
+ // equal to source vector rank.
1115
+ bool extractsScalar = static_cast <int64_t >(positionVec.size ()) ==
1116
+ extractOp.getSourceVectorType ().getRank ();
1117
+
1118
+ // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
1119
+ // need to add a position for this change.
1120
+ if (extractOp.getSourceVectorType ().getRank () == 0 ) {
1121
+ auto idxType = rewriter.getIndexType ();
1122
+ Value position = rewriter.create <LLVM::ConstantOp>(
1123
+ loc, typeConverter->convertType (idxType),
1124
+ rewriter.getIntegerAttr (idxType, 0 ));
1125
+ positionVec.push_back (position);
1126
+ }
1113
1127
1114
1128
Value extracted = adaptor.getVector ();
1115
- if (isSlicingExtract ) {
1129
+ if (extractsScalar ) {
1116
1130
ArrayRef<OpFoldResult> position (positionVec);
1117
- if (isScalarExtract) {
1118
- // If we are extracting a scalar from the returned slice, we need to
1119
- // extract a N-1 D slice.
1131
+ if (extractsAggregate) {
1132
+ // If we are extracting a scalar from the extracted member, we drop
1133
+ // the last index, which will be used to extract the scalar out of the
1134
+ // vector.
1120
1135
position = position.drop_back ();
1121
1136
}
1122
1137
// llvm.extractvalue does not support dynamic dimensions.
1123
- if (!llvm::all_of (position,
1124
- [](OpFoldResult x) { return isa<Attribute>(x); })) {
1138
+ if (!llvm::all_of (position, llvm::IsaPred<Attribute>)) {
1125
1139
return failure ();
1126
1140
}
1127
1141
extracted = rewriter.create <LLVM::ExtractValueOp>(
1128
1142
loc, extracted, getAsIntegers (position));
1129
1143
}
1130
1144
1131
- if (isScalarExtract) {
1132
- Value position;
1133
- if (positionVec.empty ()) {
1134
- // A scalar extract with no position is a 0-D vector extract. The LLVM
1135
- // type converter converts 0-D vectors to 1-D vectors, so we need to add
1136
- // a constant position.
1137
- auto idxType = rewriter.getIndexType ();
1138
- position = rewriter.create <LLVM::ConstantOp>(
1139
- loc, typeConverter->convertType (idxType),
1140
- rewriter.getIntegerAttr (idxType, 0 ));
1141
- } else {
1142
- position = getAsLLVMValue (rewriter, loc, positionVec.back ());
1143
- }
1144
- extracted =
1145
- rewriter.create <LLVM::ExtractElementOp>(loc, extracted, position);
1145
+ if (extractsScalar) {
1146
+ extracted = rewriter.create <LLVM::ExtractElementOp>(
1147
+ loc, extracted, getAsLLVMValue (rewriter, loc, positionVec.back ()));
1146
1148
}
1147
1149
1148
1150
rewriter.replaceOp (extractOp, extracted);
0 commit comments