Skip to content

Commit 5984d74

Browse files
committed
[MLIR][Shape] Allow get_extent to operate on extent tensors and indices
Differential Revision: https://reviews.llvm.org/D84435
1 parent 7f600da commit 5984d74

File tree

6 files changed

+113
-28
lines changed

6 files changed

+113
-28
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,10 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
235235
an error then it returns an error size.
236236
}];
237237
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
238-
Shape_SizeType:$dim);
239-
let results = (outs Shape_SizeType:$extent);
240-
let assemblyFormat = "$shape `,` $dim `:` type($shape) attr-dict";
238+
Shape_SizeOrIndexType:$dim);
239+
let results = (outs Shape_SizeOrIndexType:$extent);
240+
let assemblyFormat = "$shape `,` $dim `:` type($shape) `,` type($dim) `->` "
241+
"type($extent) attr-dict";
241242

242243
let builders = [
243244
// Builder that allows passing a constant dimension as a simple integer.
@@ -251,6 +252,7 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
251252
}];
252253

253254
let hasFolder = 1;
255+
let verifier = [{ return ::verify(*this); }];
254256
}
255257

256258
def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -535,10 +535,30 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
535535
// GetExtentOp
536536
//===----------------------------------------------------------------------===//
537537

538+
static LogicalResult verify(GetExtentOp op) {
539+
Type shapeTy = op.shape().getType();
540+
Type dimTy = op.dim().getType();
541+
Type extentTy = op.extent().getType();
542+
bool errorPropagationPossible =
543+
shapeTy.isa<ShapeType>() || dimTy.isa<SizeType>();
544+
if (errorPropagationPossible) {
545+
if (!extentTy.isa<SizeType>())
546+
op.emitError()
547+
<< "if at least one of the operands can hold error values then the "
548+
"result must be of type `size` to propagate them";
549+
} else {
550+
if (extentTy.isa<SizeType>())
551+
op.emitError() << "if none of the operands can hold error values then "
552+
"the result must be of type `index`";
553+
}
554+
return success();
555+
}
556+
538557
Optional<int64_t> GetExtentOp::getConstantDim() {
539-
if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) {
558+
if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
540559
return constSizeOp.value().getLimitedValue();
541-
}
560+
if (auto constantOp = dim().getDefiningOp<ConstantOp>())
561+
return constantOp.value().cast<IntegerAttr>().getInt();
542562
return llvm::None;
543563
}
544564

@@ -558,8 +578,14 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
558578
int64_t dim) {
559579
auto loc = result.location;
560580
auto dimAttr = builder.getIndexAttr(dim);
561-
Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr);
562-
build(builder, result, shape, dimValue);
581+
if (shape.getType().isa<ShapeType>()) {
582+
Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
583+
build(builder, result, builder.getType<SizeType>(), shape, dim);
584+
} else {
585+
Value dim =
586+
builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
587+
build(builder, result, builder.getIndexType(), shape, dim);
588+
}
563589
}
564590

565591
//===----------------------------------------------------------------------===//

mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -136,28 +136,25 @@ func @rank(%shape : tensor<?xindex>) -> index {
136136
// `shape_of` operation.
137137
// CHECK-LABEL: @get_extent_shape_of
138138
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
139-
func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
140-
-> !shape.size {
139+
func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : index) -> index {
141140
// CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
142141
// CHECK: return %[[RESULT]] : index
143142
%shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
144-
%result = shape.get_extent %shape, %idx : tensor<?xindex>
145-
return %result : !shape.size
143+
%result = shape.get_extent %shape, %idx : tensor<?xindex>, index -> index
144+
return %result : index
146145
}
147146

148147
// -----
149148

150-
// Express `get_extent` as `std.extract_element` when it relies directly on the
151-
// outcome of a `from_extent_tensor` operation.
149+
// Express `get_extent` as `std.extract_element`.
152150
// CHECK-LABEL: @get_extent_from_extent_tensor
153151
// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
154-
func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
155-
%idx : !shape.size) -> !shape.size {
152+
func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
153+
-> index {
156154
// CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
157155
// CHECK: return %[[RESULT]] : index
158-
%shape = shape.from_extent_tensor %extents : tensor<?xindex>
159-
%result = shape.get_extent %shape, %idx : !shape.shape
160-
return %result : !shape.size
156+
%result = shape.get_extent %extents, %idx : tensor<?xindex>, index -> index
157+
return %result : index
161158
}
162159

163160
// -----

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -235,13 +235,49 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
235235

236236
// -----
237237

238+
// Basic folding.
239+
// CHECK-LABEL: func @basic
240+
func @basic() -> index {
241+
// CHECK: constant 2 : index
242+
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
243+
%c2 = constant 2 : index
244+
%1 = shape.get_extent %0, %c2 : tensor<?xindex>, index -> index
245+
return %1 : index
246+
}
247+
248+
// -----
249+
250+
// Should not fold.
251+
// CHECK-LABEL: func @out_of_bounds
252+
func @out_of_bounds() -> index {
253+
// CHECK: shape.const_shape
254+
// CHECK: shape.get_extent
255+
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
256+
%c3 = constant 3 : index
257+
%1 = shape.get_extent %0, %c3 : tensor<?xindex>, index -> index
258+
return %1 : index
259+
}
260+
261+
// -----
262+
263+
// Should not fold.
264+
// CHECK-LABEL: func @not_const
265+
func @not_const(%arg0: tensor<?xindex>) -> index {
266+
// CHECK: shape.get_extent
267+
%c3 = constant 3 : index
268+
%0 = shape.get_extent %arg0, %c3 : tensor<?xindex>, index -> index
269+
return %0 : index
270+
}
271+
272+
// -----
273+
238274
// Basic folding.
239275
// CHECK-LABEL: func @basic
240276
func @basic() -> !shape.size {
241277
// CHECK: shape.const_size 2
242-
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
278+
%0 = shape.const_shape [0, 1, 2] : !shape.shape
243279
%c2 = shape.const_size 2
244-
%1 = shape.get_extent %0, %c2 : tensor<?xindex>
280+
%1 = shape.get_extent %0, %c2 : !shape.shape, !shape.size -> !shape.size
245281
return %1 : !shape.size
246282
}
247283

@@ -252,24 +288,23 @@ func @basic() -> !shape.size {
252288
func @out_of_bounds() -> !shape.size {
253289
// CHECK: shape.const_shape
254290
// CHECK: shape.get_extent
255-
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
291+
%0 = shape.const_shape [0, 1, 2] : !shape.shape
256292
%c3 = shape.const_size 3
257-
%1 = shape.get_extent %0, %c3 : tensor<?xindex>
293+
%1 = shape.get_extent %0, %c3 : !shape.shape, !shape.size -> !shape.size
258294
return %1 : !shape.size
259295
}
260296

261297
// -----
262298

263299
// Should not fold.
264300
// CHECK-LABEL: func @not_const
265-
func @not_const(%arg0: tensor<?xindex>) -> !shape.size {
301+
func @not_const(%arg0 : !shape.shape) -> !shape.size {
266302
// CHECK: shape.get_extent
267303
%c3 = shape.const_size 3
268-
%0 = shape.get_extent %arg0, %c3 : tensor<?xindex>
304+
%0 = shape.get_extent %arg0, %c3 : !shape.shape, !shape.size -> !shape.size
269305
return %0 : !shape.size
270306
}
271307

272-
273308
// -----
274309
// cstr_eq with non-constant but known equal shapes can be removed.
275310
// CHECK-LABEL: func @f

mlir/test/Dialect/Shape/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,21 @@ func @rank(%arg : !shape.shape) {
102102
%0 = shape.rank %arg : !shape.shape -> index
103103
}
104104

105+
// -----
106+
107+
func @get_extent_error_free(%arg : tensor<?xindex>) -> !shape.size {
108+
%c0 = constant 0 : index
109+
// expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
110+
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> !shape.size
111+
return %result : !shape.size
112+
}
113+
114+
// -----
115+
116+
func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
117+
%c0 = shape.const_size 0
118+
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
119+
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> index
120+
return %result : index
121+
}
122+

mlir/test/Dialect/Shape/ops.mlir

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,20 @@ func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
163163

164164
func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
165165
%c0 = shape.const_size 0
166-
%result = shape.get_extent %arg, %c0 : !shape.shape
166+
%result = shape.get_extent %arg, %c0 :
167+
!shape.shape, !shape.size -> !shape.size
167168
return %result : !shape.size
168169
}
169170

170-
func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
171+
func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
172+
%c0 = constant 0 : index
173+
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> index
174+
return %result : index
175+
}
176+
177+
func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
171178
%c0 = shape.const_size 0
172-
%result = shape.get_extent %arg, %c0 : tensor<?xindex>
179+
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size
173180
return %result : !shape.size
174181
}
175182

0 commit comments

Comments
 (0)