Skip to content

Commit 933221e

Browse files
authored
[mlir][tosa] Rename RFFT2D input to input_real (#130614)
This is to align to the input name as defined in the specification. Signed-off-by: Luke Hutton <[email protected]>
1 parent c7f7ac7 commit 933221e

File tree

4 files changed

+7
-6
lines changed

4 files changed

+7
-6
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d", [
396396
}];
397397

398398
let arguments = (ins
399-
Tosa_Tensor3D:$input,
399+
Tosa_Tensor3D:$input_real,
400400
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
401401
);
402402

@@ -411,7 +411,7 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d", [
411411
];
412412

413413
let assemblyFormat = [{
414-
$input attr-dict `:` `(` type($input) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
414+
$input_real attr-dict `:` `(` type($input_real) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
415415
}];
416416

417417
let hasVerifier = 1;

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2614,7 +2614,7 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
26142614
}
26152615

26162616
auto loc = rfft2d.getLoc();
2617-
auto input = rfft2d.getInput();
2617+
auto input = rfft2d.getInputReal();
26182618
auto elementType =
26192619
dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
26202620
if (!elementType)

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
805805
MLIRContext *context, ::std::optional<Location> location,
806806
RFFT2dOp::Adaptor adaptor,
807807
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
808-
ShapeAdaptor inputShape(adaptor.getInput().getType());
808+
ShapeAdaptor inputShape(adaptor.getInputReal().getType());
809809

810810
if (!inputShape.hasRank())
811811
return failure();
@@ -842,7 +842,8 @@ LogicalResult tosa::RFFT2dOp::verify() {
842842
if (failed(verifyCompatibleShapes(outputTypes)))
843843
return emitOpError("expected output shapes to match, got ") << outputTypes;
844844

845-
const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
845+
const auto inputType =
846+
llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
846847
if (!inputType)
847848
return success();
848849

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ void ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
160160

161161
template <>
162162
void ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
163-
addValue(op.getInput());
163+
addValue(op.getInputReal());
164164
addValue(op.getOutputReal());
165165
addValue(op.getOutputImag());
166166
}

0 commit comments

Comments
 (0)