Skip to content

Commit 54f4632

Browse files
Tai78641 and Jerry-Ge
authored and committed
Remove FullyConnectedOp from TOSA Dialect
This patch removes the FullyConnected operator from the TOSA Dialect and all associated tests and transforms. Signed-off-by: Tai Ly <[email protected]> Change-Id: Ib8c928cb21daf325f00cdad302680af2d7c13da5
1 parent 01072e5 commit 54f4632

File tree

17 files changed

+4
-642
lines changed

17 files changed

+4
-642
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,6 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
150150
outputShape, acc_type);
151151
}]>;
152152

153-
// The tosa.fully_connected op has its own builder as it does not have
154-
// strides/dilation/padding.
155-
def Tosa_FCOpQuantInfoBuilder : OpBuilder<
156-
(ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias),
157-
[{
158-
buildFCOpWithQuantInfo($_builder, $_state, outputType,
159-
input, weight, bias);
160-
}]>;
161-
162153
// The tosa.matmul op is also intended to be generated where a fully_connected
163154
// op must be constructed where the weight is not a constant. In this case,
164155
// the fully_connected op must be expressed using matmul.

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -224,32 +224,6 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
224224
}];
225225
}
226226

227-
//===----------------------------------------------------------------------===//
228-
// Operator: fully_connected
229-
//===----------------------------------------------------------------------===//
230-
def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
231-
let summary = "Fully Connected operator";
232-
233-
let description = [{
234-
Performs a fully connected network.
235-
}];
236-
237-
let arguments = (ins
238-
Tosa_Tensor2D:$input,
239-
TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
240-
Tosa_Tensor1D:$bias,
241-
OptionalAttr<I32Attr>:$input_zp,
242-
OptionalAttr<I32Attr>:$weight_zp
243-
);
244-
245-
let results = (outs
246-
Tosa_Tensor2D:$output
247-
);
248-
249-
let builders = [Tosa_FCOpQuantInfoBuilder];
250-
let hasVerifier = 1;
251-
}
252-
253227
//===----------------------------------------------------------------------===//
254228
// Operator: matmul
255229
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
8181
"number">;
8282

8383
// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
84-
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
84+
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
8585
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
8686
Tosa_QuantizedInt, AnyFloat]>;
8787

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ namespace tosa {
2626

2727
// Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
2828
// The rewrites can be selectively added to a conversion pass.
29-
void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
3029
void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
3130
RewritePatternSet &patterns);
3231
void populateTosaDecomposeDepthwise(MLIRContext *ctx,

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 0 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -607,84 +607,6 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
607607
}
608608
};
609609

610-
class FullyConnectedConverter
611-
: public OpConversionPattern<tosa::FullyConnectedOp> {
612-
public:
613-
using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
614-
LogicalResult
615-
matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
616-
ConversionPatternRewriter &rewriter) const final {
617-
Location loc = op.getLoc();
618-
auto outputTy = cast<ShapedType>(op.getType());
619-
auto input = op.getInput();
620-
auto inputTy = cast<ShapedType>(input.getType());
621-
622-
auto bias = op.getBias();
623-
624-
auto weight = op.getWeight();
625-
auto weightTy = cast<ShapedType>(weight.getType());
626-
auto weightShape = weightTy.getShape();
627-
628-
auto outputETy = outputTy.getElementType();
629-
630-
SmallVector<Value> dynDims;
631-
dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
632-
633-
if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
634-
dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
635-
}
636-
637-
if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
638-
dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
639-
}
640-
641-
SmallVector<Value> filteredDims = condenseValues(dynDims);
642-
643-
SmallVector<int64_t> permutation = {1, 0};
644-
auto permutationAttr = rewriter.getI64TensorAttr(permutation);
645-
Value permutationValue =
646-
rewriter.create<arith::ConstantOp>(loc, permutationAttr);
647-
648-
SmallVector<int64_t> newWeightShape = {weightShape[1], weightShape[0]};
649-
Type newWeightTy =
650-
RankedTensorType::get(newWeightShape, weightTy.getElementType());
651-
652-
Value transposedWeight = rewriter.create<tosa::TransposeOp>(
653-
loc, newWeightTy, weight, permutationValue);
654-
655-
Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
656-
loc, outputTy.getShape(), outputETy, filteredDims);
657-
658-
Value broadcastBias =
659-
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
660-
661-
if (!op.getInputZp() && !op.getWeightZp()) {
662-
Value matmul = rewriter
663-
.create<linalg::MatmulOp>(
664-
loc, TypeRange{op.getType()},
665-
ValueRange{input, transposedWeight}, broadcastBias)
666-
->getResult(0);
667-
668-
rewriter.replaceOp(op, matmul);
669-
return success();
670-
}
671-
672-
auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
673-
auto outputZp =
674-
rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
675-
Value matmul =
676-
rewriter
677-
.create<linalg::QuantizedMatmulOp>(
678-
loc, TypeRange{op.getType()},
679-
ValueRange{input, transposedWeight, inputZp, outputZp},
680-
broadcastBias)
681-
->getResult(0);
682-
683-
rewriter.replaceOp(op, matmul);
684-
return success();
685-
}
686-
};
687-
688610
class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
689611
public:
690612
using OpConversionPattern::OpConversionPattern;
@@ -1090,7 +1012,6 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
10901012
DepthwiseConvConverter,
10911013
MatMulConverter,
10921014
AvgPool2dConverter,
1093-
FullyConnectedConverter,
10941015
TransposeConverter
10951016
>(patterns->getContext());
10961017

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ struct TosaToLinalgNamed
6262
target.addIllegalOp<tosa::MaxPool2dOp>();
6363
target.addIllegalOp<tosa::AvgPool2dOp>();
6464
target.addIllegalOp<tosa::MatMulOp>();
65-
target.addIllegalOp<tosa::FullyConnectedOp>();
6665
target.addIllegalOp<tosa::TransposeOp>();
6766

6867
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 3 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -540,26 +540,9 @@ static void buildTransConvOpWithQuantInfo(
540540
}
541541
}
542542

543-
/// The tosa.fully_connected op has its own builder as it does not have
544-
/// strides/dilation/padding.
545-
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
546-
Type outputType, Value input, Value weight,
547-
Value bias) {
548-
549-
result.addOperands({input, weight, bias});
550-
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
551-
if (quantAttr) {
552-
result.addAttribute("quantization_info", quantAttr);
553-
result.addTypes(
554-
buildConvOpResultTypeInfo(builder, outputType, input, weight));
555-
} else {
556-
result.addTypes(outputType);
557-
}
558-
}
559-
560-
/// The tosa.matmul op is also intended to be generated where a
561-
/// fully_connected op must be constructed where the weight is not a constant.
562-
/// In this case, the fully_connected op must be expressed using matmul.
543+
/// The tosa.matmul op is also intended to be generated where a fully_connected
544+
/// op must be constructed where the weight is not a constant. In this case,
545+
/// the fully_connected op must be expressed using matmul.
563546
/// TODO: Add link to the leglization document explaining this.
564547
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
565548
OperationState &result, Type outputType,
@@ -863,76 +846,6 @@ bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
863846
return succeeded(verifyCompatibleShape(l[0], r[0]));
864847
}
865848

866-
LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
867-
MLIRContext *context, ::std::optional<Location> location,
868-
FullyConnectedOp::Adaptor adaptor,
869-
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
870-
ShapeAdaptor inputShape(adaptor.getInput().getType());
871-
ShapeAdaptor weightShape(adaptor.getWeight().getType());
872-
ShapeAdaptor biasShape(adaptor.getBias().getType());
873-
874-
// All shapes are dynamic.
875-
SmallVector<int64_t> outShape;
876-
outShape.resize(2, ShapedType::kDynamic);
877-
878-
if (inputShape.hasRank()) {
879-
outShape[0] = inputShape.getDimSize(0);
880-
}
881-
882-
if (weightShape.hasRank()) {
883-
outShape[1] = weightShape.getDimSize(0);
884-
}
885-
886-
if (biasShape.hasRank()) {
887-
outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
888-
: outShape[1];
889-
}
890-
891-
inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
892-
return success();
893-
}
894-
895-
LogicalResult FullyConnectedOp::verify() {
896-
// All TOSA conv ops have an input() and weight().
897-
auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
898-
899-
RankedTensorType weightType =
900-
llvm::dyn_cast<RankedTensorType>(getWeight().getType());
901-
902-
// Must be ranked tensor types
903-
if (!inputType) {
904-
emitOpError("expect a ranked tensor for input, got ") << getInput();
905-
return failure();
906-
}
907-
if (!weightType) {
908-
emitOpError("expect a ranked tensor for weight, got ") << getWeight();
909-
return failure();
910-
}
911-
912-
auto inputEType = inputType.getElementType();
913-
auto weightEType = weightType.getElementType();
914-
915-
bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
916-
bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
917-
918-
// Either both must be quantized or both unquantized.
919-
if (inputIsQuant != weightIsQuant) {
920-
emitOpError(
921-
"expect both input and weight to be float or not together, got ")
922-
<< inputEType << " and " << weightEType;
923-
return failure();
924-
}
925-
926-
// Quantized type must have constructed the quantizationattr, and unquantized
927-
// types should not have a quantizationattr.
928-
if ((inputIsQuant && !getInputZp()) || (!inputIsQuant && getInputZp())) {
929-
emitOpError("input zero point is required for quantized type, and not "
930-
"allowed for float type");
931-
return failure();
932-
}
933-
return success();
934-
}
935-
936849
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
937850
MLIRContext *context, ::std::optional<Location> location,
938851
MatMulOp::Adaptor adaptor,

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
add_mlir_dialect_library(MLIRTosaTransforms
22
TosaDecomposeTransposeConv.cpp
3-
TosaDecomposeConv2D.cpp
43
TosaDecomposeDepthwise.cpp
54
TosaFolders.cpp
65
TosaInferShapes.cpp

0 commit comments

Comments
 (0)