@@ -711,50 +711,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   return nullptr;
 }
 
-static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
-                        int64_t rank) {
-  // No need to expand if we are already at the desired rank
-  auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
-  assert(tensorType && "expected a ranked tensor type");
-  int64_t tensorRank = tensorType.getRank();
-  int64_t numExtraDims = rank - tensorRank;
-  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
-  if (!numExtraDims)
-    return tensor;
-
-  // Compute reassociation indices
-  SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
-  int64_t index = 0;
-  if (tensorRank != 0) {
-    for (index = 0; index <= numExtraDims; index++)
-      reassociationIndices[0].push_back(index);
-    for (size_t position = 1; position < reassociationIndices.size();
-         position++)
-      reassociationIndices[position].push_back(index++);
-  }
-
-  // Compute result type
-  SmallVector<int64_t> resultShape;
-  for (index = 0; index < numExtraDims; index++)
-    resultShape.push_back(1);
-  for (auto size : tensorType.getShape())
-    resultShape.push_back(size);
-  auto resultType =
-      RankedTensorType::get(resultShape, tensorType.getElementType());
-
-  // Emit 'tensor.expand_shape' op
-  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
-                                                reassociationIndices);
-}
-
-static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
-                                           Location loc, ValueRange operands,
-                                           int64_t rank) {
-  return llvm::map_to_vector(operands, [&](Value operand) {
-    return expandRank(rewriter, loc, operand, rank);
-  });
-}
-
 using IndexPool = DenseMap<int64_t, Value>;
 
 // Emit an 'arith.constant' op for the given index if it has not been created
@@ -1036,6 +992,17 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
   return success();
 }
 
+static ValueRange getBroadcastableOperands(Operation *operation,
+                                           ValueRange operands) {
+  // Shift cannot broadcast
+  if (isa<tosa::MulOp>(operation))
+    return operands.take_front(2);
+  // Input1_zp and output_zp cannot broadcast
+  if (isa<tosa::NegateOp>(operation))
+    return operands.take_front(1);
+  return operands;
+}
+
 static LogicalResult
 elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                  ConversionPatternRewriter &rewriter,
@@ -1052,19 +1019,12 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
   // Lower operation
   IndexPool indexPool;
   auto loc = operation->getLoc();
-  auto rank =
-      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
-  // For the mul op we need to avoid expanding the rank of the optional shift
-  // input.
-  auto operandsToExpand =
-      isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
-
-  auto expandedOperands =
-      expandInputRanks(rewriter, loc, operandsToExpand, rank);
+  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
   auto [targetShape, masterOperands] =
-      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
-  auto broadcastOperands = broadcastDynamicDimensions(
-      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
+      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
+  auto broadcastOperands =
+      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
+                                 targetShape, masterOperands);
   return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
                                     targetShape, converter);
 }