# [mlir][tosa] Add ERROR_IF checks to TRANSPOSE_CONV2D verifier #133234
`@@ -2896,6 +2896,118 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(`
```cpp
LogicalResult TransposeConv2DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
    return failure();

  const llvm::ArrayRef<int64_t> strides = getStride();
  const int64_t strideY = strides[0];
  const int64_t strideX = strides[1];

  if (strideY < 1 || strideX < 1)
    return emitOpError("expect all stride values to be >= 1, got [")
           << strides << "]";

  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
```
> **Review comment:** nit: I think there's a slight bit of readability value in moving these casts to where they're used.
```cpp
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());

  const auto weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());

  const auto checkPadAgainstKernelDim =
      [this](int64_t pad_value, int64_t kernel_dim_size,
             llvm::StringRef pad_name,
             llvm::StringRef kernel_dim_name) -> LogicalResult {
    if (pad_value <= -kernel_dim_size)
      return emitOpError("expected ")
             << pad_name << " > -" << kernel_dim_name
             << ", but got: " << pad_name << "=" << pad_value << " and "
             << kernel_dim_name << "=" << kernel_dim_size;
    return success();
  };
```
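For intuition, here is a standalone restatement of the rule `checkPadAgainstKernelDim` enforces (plain C++ outside MLIR; `padIsValid` and the sample numbers are ours): a negative `out_pad` value may shrink the kernel's contribution, but `pad_value <= -kernel_dim_size` is rejected because it would cancel the kernel extent entirely.

```cpp
#include <cstdint>
#include <cstdio>

// Standalone restatement of the verifier's pad rule (helper name is ours):
// a pad value is valid only if pad > -kernel_dim, i.e. a negative pad may
// shrink the kernel's contribution but not eliminate it.
static bool padIsValid(int64_t padValue, int64_t kernelDimSize) {
  return padValue > -kernelDimSize;
}

int main() {
  // With KH = 3: out_pad_top = -2 is the most negative legal value;
  // -3 would trigger the verifier's error (pad_value <= -kernel_dim_size).
  std::printf("pad=-2, KH=3 -> %s\n", padIsValid(-2, 3) ? "ok" : "error");
  std::printf("pad=-3, KH=3 -> %s\n", padIsValid(-3, 3) ? "ok" : "error");
  return 0;
}
```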
```cpp
  const llvm::ArrayRef<int64_t> padding = getOutPad();

  const int64_t outPadTop = padding[0];
  const int64_t outPadBottom = padding[1];
```

> **Review comment:** isn't […]?
>
> **Reply:** In the dialect these are currently int64 (see https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td#L216), though I agree these should probably be changed to conform to the spec at some point.
>
> **Reply:** Other dialects like Linalg (https://mlir.llvm.org/docs/Dialects/Linalg/#linalgconv_2d_nchw_fchw-linalgconv2dnchwfchwop) are using […].
>
> **Reply:** I kept it as […].
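For background on the int64-vs-int32 thread above: one reason the verifier keeps everything in `int64_t` is that MLIR models tensor dimensions as `int64_t`, reserving a sentinel (`ShapedType::kDynamic`) for unknown sizes, so narrower attribute types would force casts at every comparison. A rough standalone sketch (the sentinel value and names here are illustrative, not MLIR's actual definitions):

```cpp
#include <cstdint>
#include <cstdio>
#include <limits>

// Illustrative stand-in for MLIR's dynamic-dimension sentinel; the real one
// is ShapedType::kDynamic. Shape arithmetic therefore lives in int64_t.
constexpr int64_t kDynamicDim = std::numeric_limits<int64_t>::min();

int main() {
  // e.g. a tensor<1x?x32x8xf32> would carry these dimension values
  const int64_t dims[] = {1, kDynamicDim, 32, 8};
  for (int64_t d : dims) {
    if (d == kDynamicDim)
      std::printf("? ");
    else
      std::printf("%lld ", static_cast<long long>(d));
  }
  std::printf("\n");
  return 0;
}
```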
```cpp
  const int64_t kernelHeight = weightType.getDimSize(1);

  if (!ShapedType::isDynamic(kernelHeight)) {
    if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight, "out_pad_top",
                                        "KH")))
      return failure();

    if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
                                        "out_pad_bottom", "KH")))
      return failure();
  }

  const int64_t kernelWidth = weightType.getDimSize(2);

  const int64_t outPadLeft = padding[2];
  const int64_t outPadRight = padding[3];

  if (!ShapedType::isDynamic(kernelWidth)) {
    if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth, "out_pad_left",
                                        "KW")))
      return failure();

    if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
                                        "out_pad_right", "KW")))
      return failure();
  }

  // Rest of the checks depend on the output type being a RankedTensorType
  if (!outputType)
    return success();

  const int64_t inputHeight = inputType.getDimSize(1);
  const int64_t outputHeight = outputType.getDimSize(1);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(outputHeight)) {
    if (outputHeight !=
        (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
      return emitOpError(
                 "dimension mismatch: expected OH == (IH - 1) * stride_y "
                 "+ out_pad_top + out_pad_bottom + KH, but got ")
             << outputHeight << " != (" << inputHeight << " - 1) * " << strideY
             << " + " << outPadTop << " + " << outPadBottom << " + "
             << kernelHeight;
  }

  const int64_t inputWidth = inputType.getDimSize(2);
  const int64_t outputWidth = outputType.getDimSize(2);

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(outputWidth)) {
    if (outputWidth !=
        (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
      return emitOpError(
                 "dimension mismatch: expected OW == (IW - 1) * stride_x "
                 "+ out_pad_left + out_pad_right + KW, but got ")
             << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
             << " + " << outPadLeft << " + " << outPadRight << " + "
             << kernelWidth;
  }
```
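As a quick sanity check on the relation the verifier enforces, a standalone computation (plain C++; `expectedOutDim` and the sample shape values are made up): for IH = 8, stride_y = 2, out_pad_top = 0, out_pad_bottom = 1, KH = 3, the only accepted output height is (8 - 1) * 2 + 0 + 1 + 3 = 18.

```cpp
#include <cstdint>
#include <cstdio>

// Standalone check of the relation the verifier enforces:
//   OH == (IH - 1) * stride_y + out_pad_top + out_pad_bottom + KH
// (and the analogous OW relation). Sample values are made up.
static int64_t expectedOutDim(int64_t inDim, int64_t stride, int64_t padA,
                              int64_t padB, int64_t kernelDim) {
  return (inDim - 1) * stride + padA + padB + kernelDim;
}

int main() {
  const int64_t IH = 8, strideY = 2, outPadTop = 0, outPadBottom = 1, KH = 3;
  // (8 - 1) * 2 + 0 + 1 + 3 = 18; any other OH is a dimension mismatch.
  std::printf("OH must be %lld\n",
              static_cast<long long>(
                  expectedOutDim(IH, strideY, outPadTop, outPadBottom, KH)));
  return 0;
}
```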
```cpp
  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());

  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);

  // Skip further checks if bias is dynamic
  if (biasChannels == ShapedType::kDynamic)
    return success();

  const int64_t outputChannels = outputType.getDimSize(3);
  if (biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
```
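Finally, the bias rule allows either a per-channel bias matching the output channels or a single broadcast channel. A standalone restatement (`biasChannelsValid` and the channel counts are ours):

```cpp
#include <cstdint>
#include <cstdio>

// Standalone restatement of the bias-channels rule: the bias's channel
// dimension must equal the output channels, or be 1 (broadcast).
static bool biasChannelsValid(int64_t biasChannels, int64_t outputChannels) {
  return biasChannels == outputChannels || biasChannels == 1;
}

int main() {
  // Made-up counts: OC = 16.
  std::printf("BC=16 -> %s\n", biasChannelsValid(16, 16) ? "ok" : "error");
  std::printf("BC=1  -> %s\n", biasChannelsValid(1, 16) ? "ok" : "error");
  std::printf("BC=8  -> %s\n", biasChannelsValid(8, 16) ? "ok" : "error");
  return 0;
}
```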