[mlir][tensor] add tensor insert/extract op folders #142458
@@ -22,6 +22,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"

@@ -33,10 +34,12 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
#include <vector>

using namespace mlir;
using namespace mlir::tensor;
@@ -1288,6 +1291,68 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
///        tensor<12xf64>
/// %extracted_element = tensor.extract %val[%c10] : tensor<12xf64>
///
/// to
///
/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
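/// (The collapsed index is delinearized over the source group's sizes:
/// with basis (3, 4), index 10 maps to (10 floordiv 4, 10 mod 4) = (2, 2).)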
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
          extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extractOp, collapseOp.getSrc(), sourceIndices);
    return success();
  }
};

} // namespace

void ExtractOp::getAsmResultNames(
@@ -1303,6 +1368,23 @@ LogicalResult ExtractOp::verify() {
  return success();
}

/// If we have an ExtractOp consuming an InsertOp with the same
/// indices, we can return the InsertOp's scalar directly.
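/// Example:
/// ```
/// %inserted = tensor.insert %scalar into %dest[%i, %j] : tensor<4x4xf32>
/// %extracted = tensor.extract %inserted[%i, %j] : tensor<4x4xf32>
/// ```
/// Here %extracted folds to %scalar.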
// TODO: This only checks the immediate producer; extend to go up the
// insert/extract chain if the slices are disjoint.
static Value foldExtractAfterInsert(ExtractOp extractOp) {
  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  auto isSame = [](Value a, Value b) {
    return getAsOpFoldResult(a) == getAsOpFoldResult(b);
  };
  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();

  return {};
}
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute tensor = adaptor.getTensor()) {
    // If this is a splat elements attribute, simply return the value.

@@ -1350,6 +1432,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
    return elementsAttr.getValues<Attribute>()[indices];
  }

  if (Value result = foldExtractAfterInsert(*this))
    return result;

  return {};
}
@@ -1358,6 +1443,11 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
  results.add<ExtractFromTensorCast>(context);
}

void mlir::tensor::populateFoldCollapseExtractPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractFromCollapseShape>(patterns.getContext());
}
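Unlike the canonicalization patterns, the collapse-shape folder is exposed only through this populate function. A minimal sketch of how a pass might opt in, assuming the greedy driver from mlir/Transforms/GreedyPatternRewriteDriver.h; the surrounding pass and `funcOp` are hypothetical, not part of this diff:

```cpp
// Hypothetical opt-in from a pass's runOnOperation(); assumes the usual
// pass boilerplate and that `funcOp` is the function being rewritten.
void runOnOperation() {
  func::FuncOp funcOp = getOperation();
  RewritePatternSet patterns(funcOp.getContext());
  tensor::populateFoldCollapseExtractPatterns(patterns);
  // Apply greedily to a fixpoint; failure only signals non-convergence.
  if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
    signalPassFailure();
}
```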
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
namespace {

/// Pattern to fold an insert op of a constant destination and scalar to a new
/// constant.
///
/// Example:
/// ```
/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
/// %c0 = arith.constant 0 : index
/// %c4_f32 = arith.constant 4.0 : f32
/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
/// ```
/// is rewritten into:
/// ```
/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
/// ```
class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern<InsertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    // Requires a ranked tensor type.
    auto destType =
        llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
    if (!destType)
      return failure();

    // Pattern requires constant indices.
    SmallVector<uint64_t, 8> indices;
    for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
      auto indiceAttr = dyn_cast<Attribute>(indice);
      if (!indiceAttr)
        return failure();
      indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
    }

    // Requires a constant scalar to insert.
    OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
    Attribute scalarAttr = dyn_cast<Attribute>(scalar);
    if (!scalarAttr)
      return failure();

    if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
            insertOp.getDest().getDefiningOp())) {
      if (auto sourceAttr =
              llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
        // Update the attribute at the inserted index.
        auto sourceValues = sourceAttr.getValues<Attribute>();
        auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
        std::vector<Attribute> updatedValues;
Review thread (marked resolved by asraa):
- Reviewer: That is an unfortunate slow path for what will ultimately be converted to a vector of int or float.
- Author: Are you suggesting copying the source values instead of looping? Or switching on the int/float type and then converting back to attributes?
- Reviewer: I'm saying we should be able to avoid using individual per-element attributes in the common case, indeed.
- Author: Yeah, in retrospect it makes sense to type-switch on the attribute, and the options are DenseIntElementsAttr and DenseFPElementsAttr.
        updatedValues.reserve(sourceAttr.getNumElements());
        for (auto i = 0; i < sourceAttr.getNumElements(); ++i) {
          updatedValues.push_back(i == flattenedIndex ? scalarAttr
                                                      : sourceValues[i]);
        }
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
            insertOp, sourceAttr.getType(),
            DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
                                   updatedValues));
        return success();
      }
    }

    return failure();
  }
};

} // namespace
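Picking up the review thread above: a minimal sketch (not part of this diff) of the type-switched fast path the reviewers converge on, shown for the float case. `foldDenseFPInsert` and its parameters are hypothetical names; an integer twin would use DenseIntElementsAttr, APInt, and IntegerAttr.

```cpp
// Hypothetical fast path replacing the per-element Attribute loop: copy the
// raw APFloat payload once and overwrite only the inserted slot.
static bool foldDenseFPInsert(InsertOp insertOp, Attribute scalarAttr,
                              int64_t flattenedIndex, ElementsAttr sourceAttr,
                              PatternRewriter &rewriter) {
  auto denseFp = llvm::dyn_cast<DenseFPElementsAttr>(sourceAttr);
  if (!denseFp)
    return false;
  // No per-element Attribute boxing: materialize the values as APFloats.
  auto values = llvm::to_vector(denseFp.getValues<APFloat>());
  // The insert op's verifier guarantees the scalar matches the element type.
  values[flattenedIndex] = llvm::cast<FloatAttr>(scalarAttr).getValue();
  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
      insertOp, denseFp.getType(),
      DenseElementsAttr::get(denseFp.getType(), values));
  return true;
}
```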
void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");

@@ -1557,6 +1717,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  return {};
}

void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertOpConstantFold>(context);
}
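Because InsertOpConstantFold is registered here as a canonicalization pattern, the generic `-canonicalize` pass picks it up automatically; the collapse-shape folder above, by contrast, remains opt-in via populateFoldCollapseExtractPatterns, which is what the test below exercises.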
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
@@ -0,0 +1,31 @@ (new test file)
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-extract-from-collapse-shape %s | FileCheck %s

// CHECK-LABEL: @extract_from_collapse_shape
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x1x8xi8>)
func.func @extract_from_collapse_shape(%arg0: tensor<1x1x8xi8>) -> (i8, i8) {
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<1x1x8xi8> into tensor<8xi8>
  %extracted = tensor.extract %collapsed[%c0] : tensor<8xi8>
  %extracted_0 = tensor.extract %collapsed[%c1] : tensor<8xi8>
  func.return %extracted, %extracted_0 : i8, i8
}

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[RESULT0:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] : tensor<1x1x8xi8>
// CHECK-DAG: %[[RESULT1:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C1]]] : tensor<1x1x8xi8>
// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]] : i8, i8

// -----

// CHECK-LABEL: @extract_from_static_shape
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @extract_from_static_shape(%arg0 : tensor<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x6x32xf32> into tensor<12x32xf32>
  %1 = tensor.extract %0[%arg1, %arg2] : tensor<12x32xf32>
  return %1 : f32
}
// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
// CHECK-NEXT: %[[RESULT:.*]] = tensor.extract %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : tensor<2x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
Review comment (on the `SmallVector<uint64_t, 8> indices` declaration in InsertOpConstantFold):
- Nit: don't use a specific size for SmallVector without a good reason to.