Skip to content

[mlir][tosa] Enhance CONV3D & DEPTHWISE_CONV2D verifier #135738

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 150 additions & 94 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,150 @@ static LogicalResult verifyConvOpModes(T op) {
return success();
}

//===----------------------------------------------------------------------===//
// ERROR_IF functions.
// ERROR_IF is a predicate that must set an error if the condition holds.
//===----------------------------------------------------------------------===//

template <typename T>
static LogicalResult verifyConvOpErrorIf(T op) {
llvm::ArrayRef<int64_t> padding = op.getPad();
if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
return op.emitOpError("expect all padding values to be >= 0, got ")
<< padding;

llvm::ArrayRef<int64_t> strides = op.getStride();
if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
return op.emitOpError("expect all stride values to be >= 1, got ")
<< strides;

llvm::ArrayRef<int64_t> dilations = op.getDilation();
if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
return op.emitOpError("expect all dilation values to be >= 1, got ")
<< dilations;

const RankedTensorType outputType =
llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
if (!outputType)
// Skip following checks if output is not ranked
return success();

const RankedTensorType inputType =
llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
const RankedTensorType weightType =
llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

if (inputType && weightType) {
const auto verifyOutputSize =
[&op](const int64_t inputSize, const int64_t kernelSize,
const int64_t outputSize, const int64_t padBefore,
const int64_t padAfter, const int64_t stride,
const int64_t dilation, const llvm::StringRef dimName,
const llvm::StringRef dimAxis,
const llvm::StringRef padBeforeName,
const llvm::StringRef padAfterName) -> LogicalResult {
if (inputSize == ShapedType::kDynamic ||
kernelSize == ShapedType::kDynamic)
return success();

// ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1

const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
stride);
if (!calculatedOutSizeMinusOne.has_value())
return op.emitOpError("expected input_")
<< dimName << " - 1 + pad_" << padBeforeName << " + pad_"
<< padAfterName << " - (kernel_" << dimName
<< " - 1) * dilation_" << dimAxis
<< " to be wholly divisible by stride_" << dimAxis << ", got ("
<< inputSize << " - 1 + " << padBefore << " + " << padAfter
<< " - (" << kernelSize << " - 1) * " << dilation << ") / "
<< stride;

const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
return op.emitOpError("calculated output ")
<< dimName << " did not match expected: "
<< "calculated=" << calculatedOutSize
<< ", expected=" << outputSize;

return success();
};

// input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
if (failed(verifyOutputSize(
inputType.getDimSize(1), weightType.getDimSize(1),
outputType.getDimSize(1), padding[0], padding[1], strides[0],
dilations[0], "height", "y", "top", "bottom")))
return failure();

if (failed(verifyOutputSize(
inputType.getDimSize(2), weightType.getDimSize(2),
outputType.getDimSize(2), padding[2], padding[3], strides[1],
dilations[1], "width", "x", "left", "right")))
return failure();
}

// input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
if (failed(verifyOutputSize(
inputType.getDimSize(1), weightType.getDimSize(0),
outputType.getDimSize(1), padding[0], padding[1], strides[0],
dilations[0], "height", "y", "top", "bottom")))
return failure();

if (failed(verifyOutputSize(
inputType.getDimSize(2), weightType.getDimSize(1),
outputType.getDimSize(2), padding[2], padding[3], strides[1],
dilations[1], "width", "x", "left", "right")))
return failure();
}

// input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
if (failed(verifyOutputSize(
inputType.getDimSize(1), weightType.getDimSize(1),
outputType.getDimSize(1), padding[0], padding[1], strides[0],
dilations[0], "depth", "d", "front", "back")))
return failure();

if (failed(verifyOutputSize(
inputType.getDimSize(2), weightType.getDimSize(2),
outputType.getDimSize(2), padding[2], padding[3], strides[1],
dilations[1], "height", "y", "top", "bottom")))
return failure();

if (failed(verifyOutputSize(
inputType.getDimSize(3), weightType.getDimSize(3),
outputType.getDimSize(3), padding[4], padding[5], strides[2],
dilations[2], "width", "x", "left", "right")))
return failure();
}
}

const RankedTensorType biasType =
llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
if (!biasType)
// Skip following checks if bias is not ranked
return success();

const int64_t biasChannels = biasType.getDimSize(0);
const int64_t outputChannels = outputType.getDimSize(3);
if (biasChannels == ShapedType::kDynamic ||
outputChannels == ShapedType::kDynamic)
// Skip following checks if biasChannels or outputChannels is dynamic dim
return success();

if (biasChannels != outputChannels && biasChannels != 1)
return op.emitOpError(
"bias channels expected to be equal to output channels (")
<< outputChannels << ") or 1, got " << biasChannels;

return success();
}

// verify that inType and outType have same element types
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
Expand Down Expand Up @@ -2586,99 +2730,9 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
}

LogicalResult Conv2DOp::verify() {
if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
verifyConvOpErrorIf(*this).failed())
return failure();

llvm::ArrayRef<int64_t> padding = getPad();
if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
return emitOpError("expect all padding values to be >= 0, got ") << padding;

llvm::ArrayRef<int64_t> strides = getStride();
if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
return emitOpError("expect all stride values to be >= 1, got ") << strides;

llvm::ArrayRef<int64_t> dilations = getDilation();
if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
return emitOpError("expect all dilation values to be >= 1, got ")
<< dilations;

const RankedTensorType outputType =
llvm::dyn_cast<RankedTensorType>(getOutput().getType());
if (!outputType)
// Skip following checks if output is not ranked
return success();

const RankedTensorType inputType =
llvm::dyn_cast<RankedTensorType>(getInput().getType());
const RankedTensorType weightType =
llvm::dyn_cast<RankedTensorType>(getWeight().getType());

if (inputType && weightType) {
const auto verifyOutputSize =
[this](const int64_t inputSize, const int64_t kernelSize,
const int64_t outputSize, const int64_t padBefore,
const int64_t padAfter, const int64_t stride,
const int64_t dilation, const llvm::StringRef dimName,
const llvm::StringRef dimAxis,
const llvm::StringRef padBeforeName,
const llvm::StringRef padAfterName) -> LogicalResult {
if (inputSize == ShapedType::kDynamic ||
kernelSize == ShapedType::kDynamic)
return success();

const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
stride);
if (!calculatedOutSizeMinusOne.has_value())
return emitOpError("expected input_")
<< dimName << " - 1 + pad_" << padBeforeName << " + pad_"
<< padAfterName << " - (kernel_" << dimName
<< " - 1) * dilation_" << dimAxis
<< " to be wholly divisible by stride_" << dimAxis << ", got ("
<< inputSize << " - 1 + " << padBefore << " + " << padAfter
<< " - (" << kernelSize << " - 1) * " << dilation << ") / "
<< stride;

const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
return emitOpError("calculated output ")
<< dimName << " did not match expected: "
<< "calculated=" << calculatedOutSize
<< ", expected=" << outputSize;

return success();
};

if (failed(verifyOutputSize(
inputType.getDimSize(1), weightType.getDimSize(1),
outputType.getDimSize(1), padding[0], padding[1], strides[0],
dilations[0], "height", "y", "top", "bottom")))
return failure();

if (failed(verifyOutputSize(
inputType.getDimSize(2), weightType.getDimSize(2),
outputType.getDimSize(2), padding[2], padding[3], strides[1],
dilations[1], "width", "x", "left", "right")))
return failure();
}

const RankedTensorType biasType =
llvm::dyn_cast<RankedTensorType>(getBias().getType());
if (!biasType)
// Skip following checks if bias is not ranked
return success();

const int64_t biasChannels = biasType.getDimSize(0);
const int64_t outputChannels = outputType.getDimSize(3);
if (biasChannels == ShapedType::kDynamic ||
outputChannels == ShapedType::kDynamic)
// Skip following checks if biasChannels or outputChannels is dynamic dim
return success();

if (biasChannels != outputChannels && biasChannels != 1)
return emitOpError(
"bias channels expected to be equal to output channels (")
<< outputChannels << ") or 1, got " << biasChannels;
return success();
}

Expand Down Expand Up @@ -2753,7 +2807,8 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
}

LogicalResult Conv3DOp::verify() {
if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
verifyConvOpErrorIf(*this).failed())
return failure();
return success();
}
Expand Down Expand Up @@ -2863,7 +2918,8 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
}

LogicalResult DepthwiseConv2DOp::verify() {
if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
verifyConvOpErrorIf(*this).failed())
return failure();
return success();
}
Expand Down
44 changes: 22 additions & 22 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -878,22 +878,22 @@ func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : ten
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>

// CHECK-LABEL: @conv3d_f32
func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
// CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<43x3x4x5x27xf32>, %bias: tensor<43xf32>) -> () {
// CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<43x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x43xf32>) permutation = [1, 2, 3, 4, 0]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x43xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x47x45x43x28xf32>) {
// CHECK-SAME: ins(%arg2 : tensor<43xf32>) outs(%[[INIT]] : tensor<1x47x45x43x43xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: linalg.yield %[[IN]] : f32
// CHECK: } -> tensor<1x47x45x43x28xf32>
// CHECK: } -> tensor<1x47x45x43x43xf32>
// CHECK: linalg.conv_3d_ndhwc_dhwcf
// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
// CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>)
// CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
// CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x43xf32>)
// CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
%0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<43x3x4x5x27xf32>, tensor<43xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x43xf32>
return
}

Expand All @@ -919,40 +919,40 @@ func.func @conv3d_scalar_bias_f32(%input: tensor<1x49x48x47x27xf32>, %weights: t
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>

// CHECK-LABEL: @conv3d_i8
func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
// CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<43x3x4x5x27xi8>, %bias: tensor<43xi32>) -> () {
// CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<43x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x43xi8>) permutation = [1, 2, 3, 4, 0]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x43xi32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%arg2 : tensor<28xi32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x28xi32>) {
// CHECK-SAME: ins(%arg2 : tensor<43xi32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x43xi32>) {
// CHECK: ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
// CHECK: linalg.yield %[[IN]] : i32
// CHECK: } -> tensor<1x47x45x43x28xi32>
// CHECK: } -> tensor<1x47x45x43x43xi32>
// CHECK: %[[IZP:.+]] = arith.constant -128 : i32
// CHECK: %[[FZP:.+]] = arith.constant 42 : i32
// CHECK: linalg.conv_3d_ndhwc_dhwcf_q
// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
// CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
// CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
// CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x43xi8>, i32, i32)
// CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x43xi32>) -> tensor<1x47x45x43x43xi32>

%input_zp = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
%weight_zp = "tosa.const"() <{values = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x47x45x43x28xi32>
%0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<43x3x4x5x27xi8>, tensor<43xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x47x45x43x43xi32>
return
}

// -----

// CHECK-LABEL: @conv3d_f16_f32_acc
func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<28x3x4x5x27xf16>, %bias: tensor<28xf16>) -> () {
func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<43x3x4x5x27xf16>, %bias: tensor<43xf16>) -> () {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
// CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>)
// CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>)
// CHECK: arith.extf %{{.*}} : f16 to f32
// CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
// CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf16>
%0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<28x3x4x5x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x28xf16>
// CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
// CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf16>
%0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<43x3x4x5x27xf16>, tensor<43xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x43xf16>
return
}

Expand Down
Loading
Loading