Skip to content

Commit 143edec

Browse files
committed
[mlir][tosa] Shape inference for a few remaining easy cases:
Handles shape inference for identity, cast, and rescale. These were missed during the initially elementwise work. This includes resize shape propagation, which covers both attribute- and input-type-based propagation. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D105845
1 parent 2d9759c commit 143edec

File tree

3 files changed

+188
-27
lines changed

3 files changed

+188
-27
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,7 +1582,10 @@ def Tosa_ScatterOp : Tosa_Op<"scatter", [
15821582
//===----------------------------------------------------------------------===//
15831583
// Operator: resize
15841584
//===----------------------------------------------------------------------===//
1585-
def Tosa_ResizeOp : Tosa_Op<"resize", [NoSideEffect]> {
1585+
def Tosa_ResizeOp : Tosa_Op<"resize", [
1586+
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
1587+
["inferReturnTypeComponents"]>,
1588+
NoSideEffect]> {
15861589

15871590
let summary = "Resize operation, supports various resize/upsample modes";
15881591

@@ -1617,7 +1620,9 @@ def Tosa_ResizeOp : Tosa_Op<"resize", [NoSideEffect]> {
16171620
//===----------------------------------------------------------------------===//
16181621
// Operator: cast
16191622
//===----------------------------------------------------------------------===//
1620-
def Tosa_CastOp: Tosa_Op<"cast", [NoSideEffect]> {
1623+
def Tosa_CastOp: Tosa_Op<"cast", [NoSideEffect,
1624+
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
1625+
["inferReturnTypeComponents"]>]> {
16211626

16221627
let summary = "Cast operation";
16231628

@@ -1655,7 +1660,9 @@ def Tosa_CastOp: Tosa_Op<"cast", [NoSideEffect]> {
16551660
//===----------------------------------------------------------------------===//
16561661
// Operator: rescale
16571662
//===----------------------------------------------------------------------===//
1658-
def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect]> {
1663+
def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect,
1664+
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
1665+
["inferReturnTypeComponents"]>]> {
16591666
let summary = "Tosa rescale operator";
16601667

16611668
let description = [{
@@ -1723,7 +1730,9 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, NoSideEffect,
17231730
//===----------------------------------------------------------------------===//
17241731
// Operator: identity
17251732
//===----------------------------------------------------------------------===//
1726-
def Tosa_IdentityOp: Tosa_Op<"identity", [NoSideEffect]> {
1733+
def Tosa_IdentityOp: Tosa_Op<"identity", [NoSideEffect,
1734+
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
1735+
["inferReturnTypeComponents"]>]> {
17271736
let summary = "Identity operator";
17281737
let description = [{
17291738
Returns a tensor with the same shape, size, type

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 104 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,12 @@ static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
345345
}
346346
}
347347

348+
static void getF64Values(ArrayAttr arrayAttr, SmallVector<double> &values) {
349+
for (auto it : arrayAttr) {
350+
values.push_back(it.cast<FloatAttr>().getValueAsDouble());
351+
}
352+
}
353+
348354
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
349355
MLIRContext *context, ::llvm::Optional<Location> location,
350356
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -386,13 +392,13 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
386392

387393
// Copy the Operand's rank.
388394
if (!hasRankedInput)
389-
outputShape.resize(operandTy.getRank(), -1);
395+
outputShape.resize(operandTy.getRank(), ShapedType::kDynamicSize);
390396

391397
// Copy shapes until the dim is non-dynamic.
392398
for (int i = 0, s = operandTy.getRank(); i < s; i++) {
393399
if (i == axis || operandTy.isDynamicDim(i))
394400
continue;
395-
if (outputShape[i] == -1)
401+
if (outputShape[i] == ShapedType::kDynamicSize)
396402
outputShape[i] = operandTy.getDimSize(i);
397403
if (outputShape[i] != operandTy.getDimSize(i))
398404
return failure();
@@ -414,7 +420,7 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
414420
// We need to know the length of the concatenation axis of all inputs to
415421
// determine the dimension size of the output shape.
416422
if (!operandTy.hasRank() || operandTy.isDynamicDim(axis)) {
417-
concatDimSize = -1;
423+
concatDimSize = ShapedType::kDynamicSize;
418424
break;
419425
}
420426

@@ -437,7 +443,7 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
437443

438444
// All shapes are dynamic.
439445
SmallVector<int64_t> outShape;
440-
outShape.resize(2, -1);
446+
outShape.resize(2, ShapedType::kDynamicSize);
441447

442448
if (inputTy.hasRank()) {
443449
outShape[0] = inputTy.getDimSize(0);
@@ -448,7 +454,8 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
448454
}
449455

450456
if (biasTy.hasRank()) {
451-
outShape[1] = outShape[1] == -1 ? biasTy.getDimSize(0) : outShape[1];
457+
outShape[1] = outShape[1] == ShapedType::kDynamicSize ? biasTy.getDimSize(0)
458+
: outShape[1];
452459
}
453460

454461
inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
@@ -464,15 +471,16 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
464471

465472
// All shapes are dynamic.
466473
SmallVector<int64_t> outShape;
467-
outShape.resize(3, -1);
474+
outShape.resize(3, ShapedType::kDynamicSize);
468475

469476
if (lhsTy.hasRank()) {
470477
outShape[0] = lhsTy.getDimSize(0);
471478
outShape[1] = lhsTy.getDimSize(1);
472479
}
473480

474481
if (rhsTy.hasRank()) {
475-
outShape[0] = outShape[0] == -1 ? rhsTy.getDimSize(0) : outShape[0];
482+
outShape[0] = outShape[0] == ShapedType::kDynamicSize ? rhsTy.getDimSize(0)
483+
: outShape[0];
476484
outShape[2] = rhsTy.getDimSize(2);
477485
}
478486

@@ -503,15 +511,15 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
503511
return success();
504512
}
505513

506-
outputShape.resize(paddingTy.getDimSize(0), -1);
514+
outputShape.resize(paddingTy.getDimSize(0), ShapedType::kDynamicSize);
507515
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
508516
return success();
509517
}
510518

511519
DenseIntElementsAttr paddings;
512520
// If the paddings value is not a constant, all dimensions must be dynamic.
513521
if (!matchPattern(operands[1], m_Constant(&paddings))) {
514-
outputShape.resize(inputTy.getRank(), -1);
522+
outputShape.resize(inputTy.getRank(), ShapedType::kDynamicSize);
515523
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
516524
return success();
517525
}
@@ -524,7 +532,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
524532
outputShape.reserve(inputTy.getRank());
525533
for (int i = 0, s = inputTy.getRank(); i < s; i++) {
526534
if (inputTy.isDynamicDim(i)) {
527-
outputShape.push_back(-1);
535+
outputShape.push_back(ShapedType::kDynamicSize);
528536
continue;
529537
}
530538

@@ -574,7 +582,7 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
574582
ShapedType inputTy = operands[0].getType().cast<ShapedType>();
575583
SmallVector<int64_t> outputShape;
576584
if (!inputTy.hasRank()) {
577-
outputShape.resize(multiples.size(), -1);
585+
outputShape.resize(multiples.size(), ShapedType::kDynamicSize);
578586
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
579587
return success();
580588
}
@@ -590,7 +598,7 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
590598
outputShape.reserve(multiples.size());
591599
for (int i = 0, s = inputTy.getRank(); i < s; i++) {
592600
int dim = inputTy.getDimSize(i);
593-
if (dim != -1)
601+
if (dim != ShapedType::kDynamicSize)
594602
dim *= multipleValues[i];
595603
outputShape.push_back(dim);
596604
}
@@ -622,14 +630,14 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
622630
int64_t numElements = type.getNumElements();
623631
int64_t staticMul = 1;
624632
for (auto val : newShapeValue) {
625-
if (val != -1) {
633+
if (val != ShapedType::kDynamicSize) {
626634
staticMul *= val;
627635
}
628636
}
629637

630638
// Determine the length of the dynamic dimension.
631639
for (auto &val : newShapeValue) {
632-
if (val == -1)
640+
if (val == ShapedType::kDynamicSize)
633641
val = numElements / staticMul;
634642
}
635643

@@ -655,7 +663,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
655663
// can determine the output rank.
656664
SmallVector<int64_t> outputShape;
657665
if (!inputTy.hasRank()) {
658-
outputShape.resize(permsTy.getDimSize(0), -1);
666+
outputShape.resize(permsTy.getDimSize(0), ShapedType::kDynamicSize);
659667
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
660668
return success();
661669
}
@@ -684,7 +692,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
684692
}
685693

686694
DenseIntElementsAttr perms;
687-
outputShape.resize(inputTy.getRank(), -1);
695+
outputShape.resize(inputTy.getRank(), ShapedType::kDynamicSize);
688696
// If the permutations are a constant we can directly determine the output
689697
// shape.
690698
if (matchPattern(operands[1], m_Constant(&perms))) {
@@ -708,30 +716,100 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
708716
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
709717
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
710718
llvm::SmallVector<int64_t> outputShape;
711-
outputShape.resize(3, -1);
719+
outputShape.resize(3, ShapedType::kDynamicSize);
712720

713721
if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
714722
outputShape[0] = ty.getDimSize(0);
715723
outputShape[2] = ty.getDimSize(2);
716724
}
717725

718726
if (auto ty = operands[1].getType().dyn_cast<RankedTensorType>()) {
719-
if (outputShape[0] == -1)
727+
if (outputShape[0] == ShapedType::kDynamicSize)
720728
outputShape[0] = ty.getDimSize(0);
721-
if (outputShape[1] == -1)
729+
if (outputShape[1] == ShapedType::kDynamicSize)
722730
outputShape[1] = ty.getDimSize(1);
723731
}
724732

725733
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
726734
return success();
727735
}
728736

737+
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
738+
MLIRContext *context, ::llvm::Optional<Location> location,
739+
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
740+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
741+
llvm::SmallVector<int64_t, 4> outputShape;
742+
outputShape.resize(4, ShapedType::kDynamicSize);
743+
744+
int32_t inHeight = ShapedType::kDynamicSize;
745+
int32_t inWidth = ShapedType::kDynamicSize;
746+
747+
if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
748+
outputShape[0] = ty.getDimSize(0);
749+
outputShape[3] = ty.getDimSize(3);
750+
751+
inHeight = ty.getDimSize(1);
752+
inWidth = ty.getDimSize(2);
753+
}
754+
755+
int32_t shift =
756+
attributes.get("shift").cast<IntegerAttr>().getValue().getSExtValue();
757+
llvm::SmallVector<int64_t> newShape;
758+
getI64Values(attributes.get("output_size").cast<ArrayAttr>(), newShape);
759+
outputShape[1] = newShape[0];
760+
outputShape[2] = newShape[1];
761+
762+
llvm::SmallVector<int64_t> strideInt;
763+
llvm::SmallVector<int64_t> offsetInt;
764+
llvm::SmallVector<double> strideFp;
765+
llvm::SmallVector<double> offsetFp;
766+
getI64Values(attributes.get("offset").cast<ArrayAttr>(), offsetInt);
767+
getF64Values(attributes.get("offset_fp").cast<ArrayAttr>(), offsetFp);
768+
getI64Values(attributes.get("stride").cast<ArrayAttr>(), strideInt);
769+
getF64Values(attributes.get("stride_fp").cast<ArrayAttr>(), strideFp);
770+
771+
// If we have a 0 zero in integers we know that the resize indexing needs to
772+
// be performed in floating point. Use the floating point varient to compute
773+
// the resize shape.
774+
bool fpMode = strideInt[0] == 0;
775+
776+
// We can compute the output shape if attribute specifies unknown dimensions
777+
// based on the offset and stride. If we perfectly line up to the last index
778+
// we need to round up the size to include it.
779+
if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && fpMode) {
780+
float sizeFp = (inHeight - offsetFp[0] - 1) / strideFp[0];
781+
float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
782+
outputShape[1] = std::ceil(sizeFp) + round;
783+
}
784+
785+
if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && fpMode) {
786+
float sizeFp = (inWidth - offsetFp[1] - 1) / strideFp[1];
787+
float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
788+
outputShape[2] = std::ceil(sizeFp) + round;
789+
}
790+
791+
if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && !fpMode) {
792+
int64_t size = (inHeight - 1);
793+
size = ((size << shift) - offsetInt[0]) / strideInt[0];
794+
outputShape[1] = size + 1;
795+
}
796+
797+
if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && !fpMode) {
798+
int64_t size = (inWidth - 1);
799+
size = ((size << shift) - offsetInt[1]) / strideInt[1];
800+
outputShape[2] = size + 1;
801+
}
802+
803+
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
804+
return success();
805+
}
806+
729807
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
730808
MLIRContext *context, ::llvm::Optional<Location> location,
731809
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
732810
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
733811
llvm::SmallVector<int64_t> outputShape;
734-
outputShape.resize(3, -1);
812+
outputShape.resize(3, ShapedType::kDynamicSize);
735813

736814
if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
737815
outputShape[0] = ty.getDimSize(0);
@@ -740,14 +818,14 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
740818
}
741819

742820
if (auto ty = operands[1].getType().dyn_cast<RankedTensorType>()) {
743-
if (outputShape[0] == -1)
821+
if (outputShape[0] == ShapedType::kDynamicSize)
744822
outputShape[0] = ty.getDimSize(0);
745823
}
746824

747825
if (auto ty = operands[2].getType().dyn_cast<RankedTensorType>()) {
748-
if (outputShape[0] == -1)
826+
if (outputShape[0] == ShapedType::kDynamicSize)
749827
outputShape[0] = ty.getDimSize(0);
750-
if (outputShape[2] == -1)
828+
if (outputShape[2] == ShapedType::kDynamicSize)
751829
outputShape[2] = ty.getDimSize(2);
752830
}
753831

@@ -859,6 +937,7 @@ NARY_SHAPE_INFER(tosa::BitwiseAndOp)
859937
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
860938
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
861939
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
940+
NARY_SHAPE_INFER(tosa::CastOp)
862941
NARY_SHAPE_INFER(tosa::CeilOp)
863942
NARY_SHAPE_INFER(tosa::ClampOp)
864943
NARY_SHAPE_INFER(tosa::ClzOp)
@@ -868,6 +947,7 @@ NARY_SHAPE_INFER(tosa::ExpOp)
868947
NARY_SHAPE_INFER(tosa::FloorOp)
869948
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
870949
NARY_SHAPE_INFER(tosa::GreaterOp)
950+
NARY_SHAPE_INFER(tosa::IdentityOp)
871951
NARY_SHAPE_INFER(tosa::LogOp)
872952
NARY_SHAPE_INFER(tosa::LogicalAndOp)
873953
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
@@ -882,6 +962,7 @@ NARY_SHAPE_INFER(tosa::NegateOp)
882962
NARY_SHAPE_INFER(tosa::PowOp)
883963
NARY_SHAPE_INFER(tosa::ReciprocalOp)
884964
NARY_SHAPE_INFER(tosa::ReluNOp)
965+
NARY_SHAPE_INFER(tosa::RescaleOp)
885966
NARY_SHAPE_INFER(tosa::ReverseOp)
886967
NARY_SHAPE_INFER(tosa::RsqrtOp)
887968
NARY_SHAPE_INFER(tosa::SelectOp)

0 commit comments

Comments
 (0)