Skip to content

[mlir][tosa] Remove FullyConnectedOp from TOSA Dialect #126152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,6 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
outputShape, acc_type);
}]>;

// The tosa.fully_connected op has its own builder as it does not have
// strides/dilation/padding.
def Tosa_FCOpQuantInfoBuilder : OpBuilder<
(ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias),
[{
buildFCOpWithQuantInfo($_builder, $_state, outputType,
input, weight, bias);
}]>;

// The tosa.matmul op is also intended to be generated where a fully_connected
// op must be constructed where the weight is not a constant. In this case,
// the fully_connected op must be expressed using matmul.
Expand Down
26 changes: 0 additions & 26 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -224,32 +224,6 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
}];
}

//===----------------------------------------------------------------------===//
// Operator: fully_connected
//===----------------------------------------------------------------------===//
def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
let summary = "Fully Connected operator";

let description = [{
Performs a fully connected network.
}];

let arguments = (ins
Tosa_Tensor2D:$input,
TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
Tosa_Tensor1D:$bias,
OptionalAttr<I32Attr>:$input_zp,
OptionalAttr<I32Attr>:$weight_zp
);

let results = (outs
Tosa_Tensor2D:$output
);

let builders = [Tosa_FCOpQuantInfoBuilder];
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// Operator: matmul
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
"number">;

// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
Tosa_QuantizedInt, AnyFloat]>;

Expand Down
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ namespace tosa {

// Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
// The rewrites can be selectively added to a conversion pass.
void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
Expand Down
79 changes: 0 additions & 79 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,84 +607,6 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
}
};

class FullyConnectedConverter
: public OpConversionPattern<tosa::FullyConnectedOp> {
public:
using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
auto outputTy = cast<ShapedType>(op.getType());
auto input = op.getInput();
auto inputTy = cast<ShapedType>(input.getType());

auto bias = op.getBias();

auto weight = op.getWeight();
auto weightTy = cast<ShapedType>(weight.getType());
auto weightShape = weightTy.getShape();

auto outputETy = outputTy.getElementType();

SmallVector<Value> dynDims;
dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
}

if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
}

SmallVector<Value> filteredDims = condenseValues(dynDims);

SmallVector<int64_t> permutation = {1, 0};
auto permutationAttr = rewriter.getI64TensorAttr(permutation);
Value permutationValue =
rewriter.create<arith::ConstantOp>(loc, permutationAttr);

SmallVector<int64_t> newWeightShape = {weightShape[1], weightShape[0]};
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());

Value transposedWeight = rewriter.create<tosa::TransposeOp>(
loc, newWeightTy, weight, permutationValue);

Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
loc, outputTy.getShape(), outputETy, filteredDims);

Value broadcastBias =
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

if (!op.getInputZp() && !op.getWeightZp()) {
Value matmul = rewriter
.create<linalg::MatmulOp>(
loc, TypeRange{op.getType()},
ValueRange{input, transposedWeight}, broadcastBias)
->getResult(0);

rewriter.replaceOp(op, matmul);
return success();
}

auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
auto outputZp =
rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
Value matmul =
rewriter
.create<linalg::QuantizedMatmulOp>(
loc, TypeRange{op.getType()},
ValueRange{input, transposedWeight, inputZp, outputZp},
broadcastBias)
->getResult(0);

rewriter.replaceOp(op, matmul);
return success();
}
};

class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -1090,7 +1012,6 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
DepthwiseConvConverter,
MatMulConverter,
AvgPool2dConverter,
FullyConnectedConverter,
TransposeConverter
>(patterns->getContext());

Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ struct TosaToLinalgNamed
target.addIllegalOp<tosa::MaxPool2dOp>();
target.addIllegalOp<tosa::AvgPool2dOp>();
target.addIllegalOp<tosa::MatMulOp>();
target.addIllegalOp<tosa::FullyConnectedOp>();
target.addIllegalOp<tosa::TransposeOp>();

target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
Expand Down
93 changes: 3 additions & 90 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,26 +566,9 @@ static void buildTransConvOpWithQuantInfo(
result.addTypes(finalOutputType);
}

/// The tosa.fully_connected op has its own builder as it does not have
/// strides/dilation/padding.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Type outputType, Value input, Value weight,
Value bias) {

result.addOperands({input, weight, bias});
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
if (quantAttr) {
result.addAttribute("quantization_info", quantAttr);
result.addTypes(
buildConvOpResultTypeInfo(builder, outputType, input, weight));
} else {
result.addTypes(outputType);
}
}

/// The tosa.matmul op is also intended to be generated where a
/// fully_connected op must be constructed where the weight is not a constant.
/// In this case, the fully_connected op must be expressed using matmul.
/// The tosa.matmul op is also intended to be generated where a fully_connected
/// op must be constructed where the weight is not a constant. In this case,
/// the fully_connected op must be expressed using matmul.
/// TODO: Add link to the leglization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
OperationState &result, Type outputType,
Expand Down Expand Up @@ -889,76 +872,6 @@ bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return succeeded(verifyCompatibleShape(l[0], r[0]));
}

LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
FullyConnectedOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape(adaptor.getInput().getType());
ShapeAdaptor weightShape(adaptor.getWeight().getType());
ShapeAdaptor biasShape(adaptor.getBias().getType());

// All shapes are dynamic.
SmallVector<int64_t> outShape;
outShape.resize(2, ShapedType::kDynamic);

if (inputShape.hasRank()) {
outShape[0] = inputShape.getDimSize(0);
}

if (weightShape.hasRank()) {
outShape[1] = weightShape.getDimSize(0);
}

if (biasShape.hasRank()) {
outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
: outShape[1];
}

inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
return success();
}

LogicalResult FullyConnectedOp::verify() {
// All TOSA conv ops have an input() and weight().
auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());

RankedTensorType weightType =
llvm::dyn_cast<RankedTensorType>(getWeight().getType());

// Must be ranked tensor types
if (!inputType) {
emitOpError("expect a ranked tensor for input, got ") << getInput();
return failure();
}
if (!weightType) {
emitOpError("expect a ranked tensor for weight, got ") << getWeight();
return failure();
}

auto inputEType = inputType.getElementType();
auto weightEType = weightType.getElementType();

bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
bool weightIsQuant = !llvm::isa<FloatType>(weightEType);

// Either both must be quantized or both unquantized.
if (inputIsQuant != weightIsQuant) {
emitOpError(
"expect both input and weight to be float or not together, got ")
<< inputEType << " and " << weightEType;
return failure();
}

// Quantized type must have constructed the quantizationattr, and unquantized
// types should not have a quantizationattr.
if ((inputIsQuant && !getInputZp()) || (!inputIsQuant && getInputZp())) {
emitOpError("input zero point is required for quantized type, and not "
"allowed for float type");
return failure();
}
return success();
}

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
MatMulOp::Adaptor adaptor,
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
TosaDecomposeTransposeConv.cpp
TosaDecomposeConv2D.cpp
TosaDecomposeDepthwise.cpp
TosaFolders.cpp
TosaInferShapes.cpp
Expand Down
Loading