
Commit 24acade

[mlir][Shape] Make shape_eq nary
This gets rid of a dubious shape_eq %a, %a fold that folded shape_eq even when %a is not an Attribute.

Differential Revision: https://reviews.llvm.org/D97728
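
As a quick illustration of the change (a sketch only; the function names @same_operand and @nary are made up here, while the op syntax follows the tests in this commit): shape.shape_eq now takes one or more shape or extent-tensor operands, and the same-operand case no longer folds unless the operand is a constant.

// Before this commit, the folder turned this into `constant true` even though
// %a is not a constant; with this commit it is no longer folded.
func @same_operand(%a : !shape.shape) -> i1 {
  %eq = shape.shape_eq %a, %a : !shape.shape, !shape.shape
  return %eq : i1
}

// New n-ary form: all operands are tested for equality by a single op.
func @nary(%a : tensor<?xindex>, %b : tensor<?xindex>, %c : tensor<?xindex>) -> i1 {
  %eq = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
  return %eq : i1
}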

5 files changed: +128 −64 lines

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 26 additions & 8 deletions
@@ -168,20 +168,38 @@ def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
   let hasFolder = 1;
 }
 
-def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
+def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative,
+                                            InferTypeOpInterface]> {
   let summary = "Returns whether the input shapes or extent tensors are equal";
   let description = [{
-    Takes two shape or extent tensor operands and determines whether they are
-    equal. When extent tensors are compared to shapes they are regarded as their
-    equivalent non-error shapes. Error shapes can be tested for equality like
-    any other shape value, meaning that the error value is equal to itself.
+    Takes one or more shape or extent tensor operands and determines whether
+    they are equal. When extent tensors are compared to shapes they are regarded
+    as their equivalent non-error shapes. Error shapes can be tested for
+    equality like any other shape value, meaning that the error value is equal
+    to itself.
   }];
 
-  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
-                       Shape_ShapeOrExtentTensorType:$rhs);
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
   let results = (outs I1:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  // Convenience builder alias for the binary version.
+  let builders = [
+  OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
+    [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
+  ];
+  let extraClassDeclaration = [{
+    // TODO: This should really be automatic. Figure out how to not need this defined.
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+        ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+        ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+        ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      inferredReturnTypes.push_back(::mlir::IntegerType::get(context,
+                                                             /*width=*/1));
+      return success();
+    };
+  }];
+
+  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
   let hasFolder = 1;
 }
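
A note on the new declarative assembly format (a sketch, assuming the usual comma-separated printing of variadic operands rather than verified printer output): $shapes and type($shapes) keep the familiar binary spelling unchanged while admitting any number of operands.

// Two operands: textually the same as with the old $lhs/$rhs format.
%e2 = shape.shape_eq %a, %b : !shape.shape, tensor<?xindex>
// Three operands: only expressible with the new variadic definition.
%e3 = shape.shape_eq %a, %b, %c : !shape.shape, tensor<?xindex>, tensor<?xindex>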

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp

Lines changed: 43 additions & 33 deletions
@@ -474,46 +474,56 @@ struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
 LogicalResult
 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) const {
-  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
-  // on shapes.
-  if (op.lhs().getType().isa<ShapeType>() ||
-      op.rhs().getType().isa<ShapeType>()) {
+  if (!llvm::all_of(op.shapes(),
+                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
     return failure();
+
+  Type i1Ty = rewriter.getI1Type();
+  if (op.shapes().size() <= 1) {
+    rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty,
+                                            rewriter.getBoolAttr(true));
+    return success();
   }
 
   ShapeEqOp::Adaptor transformed(operands);
   auto loc = op.getLoc();
   Type indexTy = rewriter.getIndexType();
   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
-  Value eqRank =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
-  Type i1Ty = rewriter.getI1Type();
-  rewriter.replaceOpWithNewOp<IfOp>(
-      op, i1Ty, eqRank,
-      [&](OpBuilder &b, Location loc) {
-        Value one = b.create<ConstantIndexOp>(loc, 1);
-        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
-        auto loop = b.create<scf::ForOp>(
-            loc, zero, lhsRank, one, ValueRange{init},
-            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
-              Value conj = args[0];
-              Value lhsExtent =
-                  b.create<tensor::ExtractOp>(loc, transformed.lhs(), iv);
-              Value rhsExtent =
-                  b.create<tensor::ExtractOp>(loc, transformed.rhs(), iv);
-              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
-                                                lhsExtent, rhsExtent);
-              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
-              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
-            });
-        b.create<scf::YieldOp>(loc, loop.getResults());
-      },
-      [&](OpBuilder &b, Location loc) {
-        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
-        b.create<scf::YieldOp>(loc, result);
-      });
+  Value firstShape = transformed.shapes().front();
+  Value firstRank = rewriter.create<DimOp>(loc, indexTy, firstShape, zero);
+  Value result = nullptr;
+  // Generate a linear sequence of compares, all with firstShape as lhs.
+  for (Value shape : transformed.shapes().drop_front(1)) {
+    Value rank = rewriter.create<DimOp>(loc, indexTy, shape, zero);
+    Value eqRank =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
+    auto same = rewriter.create<IfOp>(
+        loc, i1Ty, eqRank,
+        [&](OpBuilder &b, Location loc) {
+          Value one = b.create<ConstantIndexOp>(loc, 1);
+          Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
+          auto loop = b.create<scf::ForOp>(
+              loc, zero, firstRank, one, ValueRange{init},
+              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+                Value conj = args[0];
+                Value lhsExtent =
+                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
+                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
+                Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                                  lhsExtent, rhsExtent);
+                Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
+                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
+              });
+          b.create<scf::YieldOp>(loc, loop.getResults());
+        },
+        [&](OpBuilder &b, Location loc) {
+          Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
+          b.create<scf::YieldOp>(loc, result);
+        });
+    result = !result ? same.getResult(0)
+                     : rewriter.create<AndOp>(loc, result, same.getResult(0));
+  }
+  rewriter.replaceOp(op, result);
   return success();
 }
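
Two behavioral points of the rewritten converter are worth spelling out. With zero or one shape operand, equality is trivially true, so the pattern emits a plain constant instead of loops; with two or more operands, every remaining shape is compared against the first and the per-pair results are chained with `and` (the updated test below shows the full three-operand expansion). A rough sketch of the degenerate single-operand case, not verified output:

// Input:
%r = shape.shape_eq %a : tensor<?xindex>
// After the ShapeToStandard conversion (roughly):
%r = constant true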

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 8 additions & 8 deletions
@@ -629,15 +629,15 @@ OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
-  if (lhs() == rhs())
-    return BoolAttr::get(getContext(), true);
-  auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
-  if (lhs == nullptr)
-    return {};
-  auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
-  if (rhs == nullptr)
+  bool allSame = true;
+  if (!operands.empty() && !operands[0])
     return {};
-  return BoolAttr::get(getContext(), lhs == rhs);
+  for (Attribute operand : operands.drop_front(1)) {
+    if (!operand)
+      return {};
+    allSame = allSame && operand == operands[0];
+  }
+  return BoolAttr::get(getContext(), allSame);
 }
 
 //===----------------------------------------------------------------------===//
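
In words: the folder now requires every operand to be a constant attribute and compares each one against the first; any non-constant operand blocks the fold, which is exactly what removes the old shape_eq %a, %a special case. A small sketch of the three outcomes (values and names chosen for illustration; %u stands for some non-constant shape such as a function argument):

%a = shape.const_shape [1, 2, 3] : !shape.shape
%b = shape.const_shape [1, 2, 3] : tensor<?xindex>
%c = shape.const_shape [4, 5, 6] : tensor<?xindex>
%t = shape.shape_eq %a, %b : !shape.shape, tensor<?xindex>                     // folds to true
%f = shape.shape_eq %a, %c : !shape.shape, tensor<?xindex>                     // folds to false
%n = shape.shape_eq %a, %u, %b : !shape.shape, !shape.shape, tensor<?xindex>   // not folded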

mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Lines changed: 47 additions & 0 deletions
@@ -295,6 +295,53 @@ func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
 
 // -----
 
+// CHECK-LABEL: @shape_eq
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> i1
+func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>, %c : tensor<?xindex>) -> i1 {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_EQ:.*]] = cmpi eq, %[[RANK_A]], %[[RANK_B]]
+  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+  // CHECK:   %[[C1:.*]] = constant 1 : index
+  // CHECK:   %[[INIT:.*]] = constant true
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+  // CHECK:     %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_B:.*]] = tensor.extract %[[B]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_EQ:.*]] = cmpi eq, %[[EXTENT_A]], %[[EXTENT_B]]
+  // CHECK:     %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
+  // CHECK:   }
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: } else {
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = constant false
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: }
+  // CHECK: %[[RANK_C:.*]] = dim %[[C]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_EQ:.*]] = cmpi eq, %[[RANK_A]], %[[RANK_C]]
+  // CHECK: %[[SHAPE_EQ2:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+  // CHECK:   %[[C1:.*]] = constant 1 : index
+  // CHECK:   %[[INIT:.*]] = constant true
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+  // CHECK:     %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_C:.*]] = tensor.extract %[[C]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_EQ:.*]] = cmpi eq, %[[EXTENT_A]], %[[EXTENT_C]]
+  // CHECK:     %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
+  // CHECK:   }
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: } else {
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = constant false
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: }
+  // CHECK: %[[RESULT:.*]] = and %[[SHAPE_EQ]], %[[SHAPE_EQ2]] : i1
+  // CHECK: return %[[RESULT]] : i1
+  %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
+  return %result : i1
+}
+
+// -----
+
 // Don't lower `shape.broadcast` if a `shape.shape` type is involved.
 // CHECK-LABEL: @broadcast
 func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 4 additions & 15 deletions
@@ -864,7 +864,8 @@ func @shape_eq_fold_1() -> i1 {
   // CHECK: return %[[RESULT]] : i1
   %a = shape.const_shape [1, 2, 3] : !shape.shape
   %b = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b : !shape.shape, tensor<?xindex>
+  %c = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  %result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<?xindex>, tensor<?xindex>
   return %result : i1
 }
 
@@ -877,7 +878,8 @@ func @shape_eq_fold_0() -> i1 {
   // CHECK: return %[[RESULT]] : i1
   %a = shape.const_shape [1, 2, 3] : tensor<?xindex>
   %b = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+  %c = shape.const_shape [4, 5, 6] : tensor<?xindex>
+  %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
   return %result : i1
 }
 
@@ -908,19 +910,6 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
   return %result : i1
 }
 
-
-// -----
-
-// Fold `shape_eq` for non-constant but same shapes.
-// CHECK-LABEL: @shape_eq_do_fold
-// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
-func @shape_eq_do_fold(%a : !shape.shape) -> i1 {
-  // CHECK: %[[RESULT:.*]] = constant true
-  // CHECK: return %[[RESULT]] : i1
-  %result = shape.shape_eq %a, %a : !shape.shape, !shape.shape
-  return %result : i1
-}
-
 // -----
 
 // Fold `mul` for constant sizes.
