
Commit 73f487d

[mlir][TosaToLinalg] Fix bugs in PointwiseConverter (#132526)
1 parent a2e5932 commit 73f487d

2 files changed: +39, -59 lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 16 additions & 56 deletions
@@ -711,50 +711,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   return nullptr;
 }
 
-static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
-                        int64_t rank) {
-  // No need to expand if we are already at the desired rank
-  auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
-  assert(tensorType && "expected a ranked tensor type");
-  int64_t tensorRank = tensorType.getRank();
-  int64_t numExtraDims = rank - tensorRank;
-  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
-  if (!numExtraDims)
-    return tensor;
-
-  // Compute reassociation indices
-  SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
-  int64_t index = 0;
-  if (tensorRank != 0) {
-    for (index = 0; index <= numExtraDims; index++)
-      reassociationIndices[0].push_back(index);
-    for (size_t position = 1; position < reassociationIndices.size();
-         position++)
-      reassociationIndices[position].push_back(index++);
-  }
-
-  // Compute result type
-  SmallVector<int64_t> resultShape;
-  for (index = 0; index < numExtraDims; index++)
-    resultShape.push_back(1);
-  for (auto size : tensorType.getShape())
-    resultShape.push_back(size);
-  auto resultType =
-      RankedTensorType::get(resultShape, tensorType.getElementType());
-
-  // Emit 'tensor.expand_shape' op
-  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
-                                                reassociationIndices);
-}
-
-static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
-                                           Location loc, ValueRange operands,
-                                           int64_t rank) {
-  return llvm::map_to_vector(operands, [&](Value operand) {
-    return expandRank(rewriter, loc, operand, rank);
-  });
-}
-
 using IndexPool = DenseMap<int64_t, Value>;
 
 // Emit an 'arith.constant' op for the given index if it has not been created
@@ -1036,6 +992,17 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
   return success();
 }
 
+static ValueRange getBroadcastableOperands(Operation *operation,
+                                           ValueRange operands) {
+  // Shift cannot broadcast
+  if (isa<tosa::MulOp>(operation))
+    return operands.take_front(2);
+  // Input1_zp and output_zp cannot broadcast
+  if (isa<tosa::NegateOp>(operation))
+    return operands.take_front(1);
+  return operands;
+}
+
 static LogicalResult
 elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                  ConversionPatternRewriter &rewriter,
@@ -1052,19 +1019,12 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
   // Lower operation
   IndexPool indexPool;
   auto loc = operation->getLoc();
-  auto rank =
-      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
-  // For the mul op we need to avoid expanding the rank of the optional shift
-  // input.
-  auto operandsToExpand =
-      isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
-
-  auto expandedOperands =
-      expandInputRanks(rewriter, loc, operandsToExpand, rank);
+  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
   auto [targetShape, masterOperands] =
-      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
-  auto broadcastOperands = broadcastDynamicDimensions(
-      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
+      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
+  auto broadcastOperands =
+      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
+                                 targetShape, masterOperands);
   return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
                                     targetShape, converter);
 }
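
For illustration only (not part of the commit), here is a minimal sketch of the TOSA inputs the new getBroadcastableOperands helper is concerned with, adapted from the test cases added below. The function name @broadcast_exclusions is hypothetical; the point is that the rank-1 shift operand of tosa.mul and the zero-point operands of tosa.negate are excluded from broadcasting, so 0-d data operands now lower without the removed rank-expansion path.

// Hypothetical standalone example, mirroring the new tests below.
func.func @broadcast_exclusions(%arg0: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  // The rank-1 shift operand is not broadcast with the 0-d data operands.
  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %mul = tosa.mul %arg0, %arg0, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>

  // The zero-point operands are likewise kept out of the broadcast set.
  %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
  %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
  %neg = tosa.negate %arg0, %in_zp, %out_zp : (tensor<i32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>

  return %mul, %neg : tensor<i32>, tensor<i32>
}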

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 23 additions & 3 deletions
@@ -664,7 +664,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
   %40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
+  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
   // CHECK: [[ZERO:%.+]] = arith.constant 0
   // CHECK: arith.subi [[ZERO]], %[[ARG1]]
   %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
@@ -856,7 +856,7 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
 // CHECK-LABEL: @test_negate_quantized
 func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8
   // CHECK: [[CNST:%.+]] = arith.constant 7
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
   // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
@@ -871,7 +871,7 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   %0 = tosa.negate %arg0, %in_zp0, %out_zp0 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8
   // CHECK: [[C_128:%.+]] = arith.constant -128
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
   // CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]]
@@ -2317,3 +2317,23 @@ func.func @clamp_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> (
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: @test_0d_input
+func.func @test_0d_input(%arg0: tensor<i32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.muli
+  %shift1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg0, %shift1 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
+  // CHECK: [[ZERO:%.+]] = arith.constant 0
+  // CHECK: arith.subi [[ZERO]], %[[ARG1]]
+  %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %5 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<i32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+
+  return
+}
