Skip to content

Commit 98d8dce

Browse files
authored
[mlir][affine] implement inferType for delinearize (#74644)
1 parent 58c2a4e commit 98d8dce

File tree

4 files changed

+25
-12
lines changed

4 files changed

+25
-12
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
include "mlir/Dialect/Arith/IR/ArithBase.td"
1717
include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
1818
include "mlir/Interfaces/ControlFlowInterfaces.td"
19+
include "mlir/Interfaces/InferTypeOpInterface.td"
1920
include "mlir/Interfaces/LoopLikeInterface.td"
2021
include "mlir/Interfaces/SideEffectInterfaces.td"
2122

@@ -63,10 +64,6 @@ def AffineApplyOp : Affine_Op<"apply", [Pure]> {
6364
// has a constant builder. That way we wouldn't need to explicitly specify the
6465
// result types here.
6566
let builders = [
66-
OpBuilder<(ins "AffineMap":$map, "ValueRange":$mapOperands),
67-
[{
68-
build($_builder, $_state, $_builder.getIndexType(), map, mapOperands);
69-
}]>,
7067
OpBuilder<(ins "ArrayRef<AffineExpr> ":$exprList,"ValueRange":$mapOperands),
7168
[{
7269
build($_builder, $_state, $_builder.getIndexType(),
@@ -541,13 +538,6 @@ class AffineMinMaxOpBase<string mnemonic, list<Trait> traits = []> :
541538
let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
542539
let results = (outs Index);
543540

544-
let builders = [
545-
OpBuilder<(ins "AffineMap":$affineMap, "ValueRange":$mapOperands),
546-
[{
547-
build($_builder, $_state, $_builder.getIndexType(), affineMap, mapOperands);
548-
}]>
549-
];
550-
551541
let extraClassDeclaration = [{
552542
static StringRef getMapAttrStrName() { return "map"; }
553543
AffineMap getAffineMap() { return getMap(); }
@@ -1068,7 +1058,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
10681058
//===----------------------------------------------------------------------===//
10691059

10701060
def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
1071-
[Pure]> {
1061+
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
10721062
let summary = "delinearize an index";
10731063
let description = [{
10741064
The `affine.delinearize_index` operation takes a single index value and

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4474,6 +4474,17 @@ LogicalResult AffineVectorStoreOp::verify() {
44744474
// DelinearizeIndexOp
44754475
//===----------------------------------------------------------------------===//
44764476

4477+
LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
4478+
MLIRContext *context, std::optional<::mlir::Location> location,
4479+
ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
4480+
RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
4481+
AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
4482+
regions);
4483+
inferredReturnTypes.assign(adaptor.getBasis().size(),
4484+
IndexType::get(context));
4485+
return success();
4486+
}
4487+
44774488
void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
44784489
Value linearIndex,
44794490
ArrayRef<OpFoldResult> basis) {

mlir/lib/Dialect/Affine/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRAffineDialect
1515
MLIRArithDialect
1616
MLIRDialectUtils
1717
MLIRIR
18+
MLIRInferTypeOpInterface
1819
MLIRLoopLikeInterface
1920
MLIRMemRefDialect
2021
MLIRShapedOpInterfaces

mlir/test/python/dialects/affine.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,17 @@ def affine_store_test(arg0):
4444
return mem
4545

4646

47+
# CHECK-LABEL: TEST: testAffineDelinearizeInfer
48+
@constructAndPrintInModule
49+
def testAffineDelinearizeInfer():
50+
# CHECK: %[[C0:.*]] = arith.constant 0 : index
51+
c0 = arith.ConstantOp(T.index(), 0)
52+
# CHECK: %[[C1:.*]] = arith.constant 1 : index
53+
c1 = arith.ConstantOp(T.index(), 1)
54+
# CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1:.*]] into (%[[C1:.*]], %[[C0:.*]]) : index, index
55+
two_indices = affine.AffineDelinearizeIndexOp(c1, [c1, c0])
56+
57+
4758
# CHECK-LABEL: TEST: testAffineLoadOp
4859
@constructAndPrintInModule
4960
def testAffineLoadOp():

0 commit comments

Comments
 (0)