Skip to content

[mlir][tensor][NFC] Simplify SubsetInsertionOpInterface implementation #69999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,105 +17,58 @@ using namespace mlir::tensor;

namespace {

/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
/// equivalence of tensors.
template <typename OpTy>
bool isSubsetEquivalentToInsertSliceLikeOp(
OpTy insertSliceOp, Value candidate,
function_ref<bool(Value, Value)> equivalenceFn) {
// Look for a matching tensor.extract_slice op.
auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
if (!extractSliceOp)
return false;
if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
return false;
return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
isEqualConstantIntOrValue);
}

template <typename OpTy>
Value buildSubsetExtractionOfInsertSliceLikeOp(OpBuilder &b, Location loc,
OpTy insertSliceOp) {
auto extractOp = b.create<tensor::ExtractSliceOp>(
loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
insertSliceOp.getMixedStrides());
return extractOp.getResult();
}

template <typename OpTy>
SmallVector<Value>
getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(OpTy insertSliceOp) {
SmallVector<Value> neededValues;
// Collect all values that are needed to construct the replacement op.
neededValues.append(insertSliceOp.getOffsets().begin(),
insertSliceOp.getOffsets().end());
neededValues.append(insertSliceOp.getSizes().begin(),
insertSliceOp.getSizes().end());
neededValues.append(insertSliceOp.getStrides().begin(),
insertSliceOp.getStrides().end());
neededValues.push_back(insertSliceOp.getDest());
return neededValues;
}

struct InsertSliceOpInterface
: public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
tensor::InsertSliceOp> {
OpOperand &getSourceOperand(Operation *op) const {
return cast<tensor::InsertSliceOp>(op).getSourceMutable();
}

bool
isEquivalentSubset(Operation *op, Value candidate,
function_ref<bool(Value, Value)> equivalenceFn) const {
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
equivalenceFn);
}

Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
Location loc) const {
return buildSubsetExtractionOfInsertSliceLikeOp(
builder, loc, cast<tensor::InsertSliceOp>(op));
}

SmallVector<Value>
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
cast<tensor::InsertSliceOp>(op));
}
};

struct ParallelInsertSliceOpInterface
struct InsertSliceLikeOpInterface
: public SubsetInsertionOpInterface::ExternalModel<
ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> {
InsertSliceLikeOpInterface<OpTy>, OpTy> {
OpOperand &getSourceOperand(Operation *op) const {
return cast<tensor::ParallelInsertSliceOp>(op).getSourceMutable();
return cast<OpTy>(op).getSourceMutable();
}

OpOperand &getDestinationOperand(Operation *op) const {
return cast<tensor::ParallelInsertSliceOp>(op).getDestMutable();
return cast<OpTy>(op).getDestMutable();
}

/// Return "true" if `insertSliceOp` inserts into a subset that is equivalent
/// to the subset defined by `candidate`. `equivalenceFn` is used to determine
/// equivalence of tensors.
bool
isEquivalentSubset(Operation *op, Value candidate,
function_ref<bool(Value, Value)> equivalenceFn) const {
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
return isSubsetEquivalentToInsertSliceLikeOp(insertSliceOp, candidate,
equivalenceFn);
auto insertSliceOp = cast<OpTy>(op);
// Look for a matching tensor.extract_slice op.
auto extractSliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
if (!extractSliceOp)
return false;
if (!equivalenceFn(extractSliceOp.getSource(), insertSliceOp.getDest()))
return false;
return sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
isEqualConstantIntOrValue);
}

Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
Location loc) const {
return buildSubsetExtractionOfInsertSliceLikeOp(
builder, loc, cast<tensor::ParallelInsertSliceOp>(op));
auto insertSliceOp = cast<OpTy>(op);
auto extractOp = builder.create<tensor::ExtractSliceOp>(
loc, insertSliceOp.getSourceType(), insertSliceOp.getDest(),
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
insertSliceOp.getMixedStrides());
return extractOp.getResult();
}

SmallVector<Value>
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
return getValuesNeededToBuildSubsetExtractionOfInsertSliceLikeOp(
cast<tensor::ParallelInsertSliceOp>(op));
auto insertSliceOp = cast<OpTy>(op);
SmallVector<Value> neededValues;
// Collect all values that are needed to construct the replacement op.
neededValues.append(insertSliceOp.getOffsets().begin(),
insertSliceOp.getOffsets().end());
neededValues.append(insertSliceOp.getSizes().begin(),
insertSliceOp.getSizes().end());
neededValues.append(insertSliceOp.getStrides().begin(),
insertSliceOp.getStrides().end());
neededValues.push_back(insertSliceOp.getDest());
return neededValues;
}
};

Expand All @@ -124,8 +77,9 @@ struct ParallelInsertSliceOpInterface
void mlir::tensor::registerSubsetInsertionOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
InsertSliceOp::attachInterface<InsertSliceLikeOpInterface<InsertSliceOp>>(
*ctx);
ParallelInsertSliceOp::attachInterface<
InsertSliceLikeOpInterface<ParallelInsertSliceOp>>(*ctx);
});
}