Skip to content

Commit 5661242

Browse files
authored
[TOSA] FFT2D operator (#77005)
This PR adds lowering for TOSA Fft2d operator down to Linalg.
1 parent 959a430 commit 5661242

File tree

2 files changed

+264
-0
lines changed

2 files changed

+264
-0
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2344,6 +2344,135 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
23442344
}
23452345
};
23462346

2347+
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2348+
using OpRewritePattern::OpRewritePattern;
2349+
2350+
LogicalResult matchAndRewrite(FFT2dOp fft2d,
2351+
PatternRewriter &rewriter) const override {
2352+
if (!llvm::all_of(fft2d->getOperandTypes(),
2353+
RFFT2dConverter::isRankedTensor) ||
2354+
!llvm::all_of(fft2d->getResultTypes(),
2355+
RFFT2dConverter::isRankedTensor)) {
2356+
return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2357+
}
2358+
2359+
Location loc = fft2d.getLoc();
2360+
Value input_real = fft2d.getInputReal();
2361+
Value input_imag = fft2d.getInputImag();
2362+
BoolAttr inverse = fft2d.getInverseAttr();
2363+
2364+
auto real_el_ty = cast<FloatType>(
2365+
cast<ShapedType>(input_real.getType()).getElementType());
2366+
auto imag_el_ty = cast<FloatType>(
2367+
cast<ShapedType>(input_imag.getType()).getElementType());
2368+
2369+
assert(real_el_ty == imag_el_ty);
2370+
2371+
// Compute the output type and set of dynamic sizes
2372+
SmallVector<Value> dynamicSizes;
2373+
2374+
// Get [N, H, W]
2375+
ArrayRef<OpFoldResult> dims =
2376+
tensor::getMixedSizes(rewriter, loc, input_real);
2377+
2378+
SmallVector<int64_t, 3> staticSizes;
2379+
dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2380+
2381+
auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2382+
2383+
// Iterator types for the linalg.generic implementation
2384+
SmallVector<utils::IteratorType, 5> iteratorTypes = {
2385+
utils::IteratorType::parallel, utils::IteratorType::parallel,
2386+
utils::IteratorType::parallel, utils::IteratorType::reduction,
2387+
utils::IteratorType::reduction};
2388+
2389+
// Inputs/outputs to the linalg.generic implementation
2390+
SmallVector<Value> genericOpInputs = {input_real, input_imag};
2391+
SmallVector<Value> genericOpOutputs = {
2392+
RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2393+
dynamicSizes),
2394+
RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2395+
dynamicSizes)};
2396+
2397+
// Indexing maps for input and output tensors
2398+
auto indexingMaps = AffineMap::inferFromExprList(
2399+
ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2400+
RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2401+
RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2402+
RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)});
2403+
2404+
// Width and height dimensions of the original input.
2405+
auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2406+
auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2407+
2408+
// Constants and dimension sizes
2409+
auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2410+
auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2411+
Value constH =
2412+
RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2413+
Value constW =
2414+
RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2415+
2416+
auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2417+
Value valReal = args[0];
2418+
Value valImag = args[1];
2419+
Value sumReal = args[2];
2420+
Value sumImag = args[3];
2421+
2422+
// Indices for angle computation
2423+
Value oy =
2424+
RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 1);
2425+
Value ox =
2426+
RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 2);
2427+
Value iy =
2428+
RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 3);
2429+
Value ix =
2430+
RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 4);
2431+
2432+
// float_t angle = sign_val * 2 * pi() * ((iy * oy) / H + (ix * ox) / W);
2433+
auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
2434+
auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
2435+
auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
2436+
auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
2437+
auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2438+
auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2439+
if (inverse.getValue()) {
2440+
angle = builder.create<arith::MulFOp>(
2441+
loc, angle,
2442+
rewriter.create<arith::ConstantOp>(
2443+
loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
2444+
}
2445+
2446+
// realComponent = val_real * cos(a) + val_imag * sin(a);
2447+
// imagComponent = -val_real * sin(a) + val_imag * cos(a);
2448+
auto cosAngle = builder.create<math::CosOp>(loc, angle);
2449+
auto sinAngle = builder.create<math::SinOp>(loc, angle);
2450+
2451+
auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2452+
auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
2453+
auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
2454+
2455+
auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
2456+
auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2457+
2458+
auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
2459+
2460+
// outReal = sumReal + realComponent
2461+
// outImag = sumImag - imagComponent
2462+
auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2463+
auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
2464+
2465+
builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2466+
};
2467+
2468+
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2469+
fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2470+
indexingMaps, iteratorTypes, buildBody);
2471+
2472+
return success();
2473+
}
2474+
};
2475+
23472476
} // namespace
23482477

23492478
void mlir::tosa::populateTosaToLinalgConversionPatterns(
@@ -2407,6 +2536,7 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
24072536
RescaleConverter,
24082537
ReverseConverter,
24092538
RFFT2dConverter,
2539+
FFT2dConverter,
24102540
TableConverter,
24112541
TileConverter>(patterns->getContext());
24122542
// clang-format on

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1739,3 +1739,137 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
17391739
%output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
17401740
return %output_real, %output_imag : tensor<?x?x?xf32>, tensor<?x?x?xf32>
17411741
}
1742+
1743+
// -----
1744+
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
1745+
1746+
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
1747+
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
1748+
1749+
// CHECK-LABEL: func.func @test_static_fft2d(
1750+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8x8xf32>,
1751+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8x8xf32>) -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>) {
1752+
// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<8x8x8xf32>
1753+
// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
1754+
// CHECK: %[[VAL_4:.*]] = linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_2]] : tensor<8x8x8xf32>) -> tensor<8x8x8xf32>
1755+
// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<8x8x8xf32>
1756+
// CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
1757+
// CHECK: %[[VAL_7:.*]] = linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_5]] : tensor<8x8x8xf32>) -> tensor<8x8x8xf32>
1758+
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
1759+
// CHECK: %[[VAL_9:.*]] = arith.constant 8 : index
1760+
// CHECK: %[[VAL_10:.*]] = arith.constant 2 : index
1761+
// CHECK: %[[VAL_11:.*]] = arith.constant 8 : index
1762+
// CHECK: %[[VAL_12:.*]] = arith.constant 6.28318548 : f32
1763+
// CHECK: %[[VAL_13:.*]] = arith.index_castui %[[VAL_9]] : index to i32
1764+
// CHECK: %[[VAL_14:.*]] = arith.uitofp %[[VAL_13]] : i32 to f32
1765+
// CHECK: %[[VAL_15:.*]] = arith.index_castui %[[VAL_11]] : index to i32
1766+
// CHECK: %[[VAL_16:.*]] = arith.uitofp %[[VAL_15]] : i32 to f32
1767+
// CHECK: %[[VAL_17:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<8x8x8xf32>, tensor<8x8x8xf32>) outs(%[[VAL_4]], %[[VAL_7]] : tensor<8x8x8xf32>, tensor<8x8x8xf32>) {
1768+
// CHECK: ^bb0(%[[VAL_18:.*]]: f32, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32):
1769+
// CHECK: %[[VAL_22:.*]] = linalg.index 1 : index
1770+
// CHECK: %[[VAL_23:.*]] = arith.index_castui %[[VAL_22]] : index to i32
1771+
// CHECK: %[[VAL_24:.*]] = arith.uitofp %[[VAL_23]] : i32 to f32
1772+
// CHECK: %[[VAL_25:.*]] = linalg.index 2 : index
1773+
// CHECK: %[[VAL_26:.*]] = arith.index_castui %[[VAL_25]] : index to i32
1774+
// CHECK: %[[VAL_27:.*]] = arith.uitofp %[[VAL_26]] : i32 to f32
1775+
// CHECK: %[[VAL_28:.*]] = linalg.index 3 : index
1776+
// CHECK: %[[VAL_29:.*]] = arith.index_castui %[[VAL_28]] : index to i32
1777+
// CHECK: %[[VAL_30:.*]] = arith.uitofp %[[VAL_29]] : i32 to f32
1778+
// CHECK: %[[VAL_31:.*]] = linalg.index 4 : index
1779+
// CHECK: %[[VAL_32:.*]] = arith.index_castui %[[VAL_31]] : index to i32
1780+
// CHECK: %[[VAL_33:.*]] = arith.uitofp %[[VAL_32]] : i32 to f32
1781+
// CHECK: %[[VAL_34:.*]] = arith.mulf %[[VAL_30]], %[[VAL_24]] : f32
1782+
// CHECK: %[[VAL_35:.*]] = arith.mulf %[[VAL_33]], %[[VAL_27]] : f32
1783+
// CHECK: %[[VAL_36:.*]] = arith.divf %[[VAL_34]], %[[VAL_14]] : f32
1784+
// CHECK: %[[VAL_37:.*]] = arith.divf %[[VAL_35]], %[[VAL_16]] : f32
1785+
// CHECK: %[[VAL_38:.*]] = arith.addf %[[VAL_36]], %[[VAL_37]] : f32
1786+
// CHECK: %[[VAL_39:.*]] = arith.mulf %[[VAL_12]], %[[VAL_38]] : f32
1787+
// CHECK: %[[VAL_40:.*]] = math.cos %[[VAL_39]] : f32
1788+
// CHECK: %[[VAL_41:.*]] = math.sin %[[VAL_39]] : f32
1789+
// CHECK: %[[VAL_42:.*]] = arith.mulf %[[VAL_18]], %[[VAL_40]] : f32
1790+
// CHECK: %[[VAL_43:.*]] = arith.mulf %[[VAL_19]], %[[VAL_41]] : f32
1791+
// CHECK: %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f32
1792+
// CHECK: %[[VAL_45:.*]] = arith.mulf %[[VAL_19]], %[[VAL_40]] : f32
1793+
// CHECK: %[[VAL_46:.*]] = arith.mulf %[[VAL_18]], %[[VAL_41]] : f32
1794+
// CHECK: %[[VAL_47:.*]] = arith.subf %[[VAL_45]], %[[VAL_46]] : f32
1795+
// CHECK: %[[VAL_48:.*]] = arith.addf %[[VAL_20]], %[[VAL_44]] : f32
1796+
// CHECK: %[[VAL_49:.*]] = arith.addf %[[VAL_21]], %[[VAL_47]] : f32
1797+
// CHECK: linalg.yield %[[VAL_48]], %[[VAL_49]] : f32, f32
1798+
// CHECK: } -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>)
1799+
// CHECK: return %[[VAL_50:.*]]#0, %[[VAL_50]]#1 : tensor<8x8x8xf32>, tensor<8x8x8xf32>
1800+
// CHECK: }
1801+
func.func @test_static_fft2d(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x8x8xf32>) -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>) {
1802+
%output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse=false} : (tensor<8x8x8xf32>, tensor<8x8x8xf32>) -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>)
1803+
return %output_real, %output_imag : tensor<8x8x8xf32>, tensor<8x8x8xf32>
1804+
}
1805+
1806+
// -----
1807+
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
1808+
1809+
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
1810+
// CHECK: #[[$ATTR_3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
1811+
1812+
// CHECK-LABEL: func.func @test_dynamic_fft2d(
1813+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xf32>,
1814+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
1815+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
1816+
// CHECK: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32>
1817+
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
1818+
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32>
1819+
// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
1820+
// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?x?xf32>
1821+
// CHECK: %[[VAL_8:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_5]], %[[VAL_7]]) : tensor<?x?x?xf32>
1822+
// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
1823+
// CHECK: %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_8]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
1824+
// CHECK: %[[VAL_11:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_5]], %[[VAL_7]]) : tensor<?x?x?xf32>
1825+
// CHECK: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
1826+
// CHECK: %[[VAL_13:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_11]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
1827+
// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
1828+
// CHECK: %[[VAL_15:.*]] = tensor.dim %[[VAL_0]], %[[VAL_14]] : tensor<?x?x?xf32>
1829+
// CHECK: %[[VAL_16:.*]] = arith.constant 2 : index
1830+
// CHECK: %[[VAL_17:.*]] = tensor.dim %[[VAL_0]], %[[VAL_16]] : tensor<?x?x?xf32>
1831+
// CHECK: %[[VAL_18:.*]] = arith.constant 6.28318548 : f32
1832+
// CHECK: %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
1833+
// CHECK: %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
1834+
// CHECK: %[[VAL_21:.*]] = arith.index_castui %[[VAL_17]] : index to i32
1835+
// CHECK: %[[VAL_22:.*]] = arith.uitofp %[[VAL_21]] : i32 to f32
1836+
// CHECK: %[[VAL_23:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_2]], #[[$ATTR_2]], #[[$ATTR_3]], #[[$ATTR_3]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[VAL_10]], %[[VAL_13]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
1837+
// CHECK: ^bb0(%[[VAL_24:.*]]: f32, %[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32):
1838+
// CHECK: %[[VAL_28:.*]] = linalg.index 1 : index
1839+
// CHECK: %[[VAL_29:.*]] = arith.index_castui %[[VAL_28]] : index to i32
1840+
// CHECK: %[[VAL_30:.*]] = arith.uitofp %[[VAL_29]] : i32 to f32
1841+
// CHECK: %[[VAL_31:.*]] = linalg.index 2 : index
1842+
// CHECK: %[[VAL_32:.*]] = arith.index_castui %[[VAL_31]] : index to i32
1843+
// CHECK: %[[VAL_33:.*]] = arith.uitofp %[[VAL_32]] : i32 to f32
1844+
// CHECK: %[[VAL_34:.*]] = linalg.index 3 : index
1845+
// CHECK: %[[VAL_35:.*]] = arith.index_castui %[[VAL_34]] : index to i32
1846+
// CHECK: %[[VAL_36:.*]] = arith.uitofp %[[VAL_35]] : i32 to f32
1847+
// CHECK: %[[VAL_37:.*]] = linalg.index 4 : index
1848+
// CHECK: %[[VAL_38:.*]] = arith.index_castui %[[VAL_37]] : index to i32
1849+
// CHECK: %[[VAL_39:.*]] = arith.uitofp %[[VAL_38]] : i32 to f32
1850+
// CHECK: %[[VAL_40:.*]] = arith.mulf %[[VAL_36]], %[[VAL_30]] : f32
1851+
// CHECK: %[[VAL_41:.*]] = arith.mulf %[[VAL_39]], %[[VAL_33]] : f32
1852+
// CHECK: %[[VAL_42:.*]] = arith.divf %[[VAL_40]], %[[VAL_20]] : f32
1853+
// CHECK: %[[VAL_43:.*]] = arith.divf %[[VAL_41]], %[[VAL_22]] : f32
1854+
// CHECK: %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f32
1855+
// CHECK: %[[VAL_45:.*]] = arith.mulf %[[VAL_18]], %[[VAL_44]] : f32
1856+
// CHECK: %[[VAL_46:.*]] = arith.constant -1.000000e+00 : f32
1857+
// CHECK: %[[VAL_47:.*]] = arith.mulf %[[VAL_45]], %[[VAL_46]] : f32
1858+
// CHECK: %[[VAL_48:.*]] = math.cos %[[VAL_47]] : f32
1859+
// CHECK: %[[VAL_49:.*]] = math.sin %[[VAL_47]] : f32
1860+
// CHECK: %[[VAL_50:.*]] = arith.mulf %[[VAL_24]], %[[VAL_48]] : f32
1861+
// CHECK: %[[VAL_51:.*]] = arith.mulf %[[VAL_25]], %[[VAL_49]] : f32
1862+
// CHECK: %[[VAL_52:.*]] = arith.addf %[[VAL_50]], %[[VAL_51]] : f32
1863+
// CHECK: %[[VAL_53:.*]] = arith.mulf %[[VAL_25]], %[[VAL_48]] : f32
1864+
// CHECK: %[[VAL_54:.*]] = arith.mulf %[[VAL_24]], %[[VAL_49]] : f32
1865+
// CHECK: %[[VAL_55:.*]] = arith.subf %[[VAL_53]], %[[VAL_54]] : f32
1866+
// CHECK: %[[VAL_56:.*]] = arith.addf %[[VAL_26]], %[[VAL_52]] : f32
1867+
// CHECK: %[[VAL_57:.*]] = arith.addf %[[VAL_27]], %[[VAL_55]] : f32
1868+
// CHECK: linalg.yield %[[VAL_56]], %[[VAL_57]] : f32, f32
1869+
// CHECK: } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
1870+
// CHECK: return %[[VAL_58:.*]]#0, %[[VAL_58]]#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
1871+
// CHECK: }
1872+
func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
1873+
%output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = true} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
1874+
return %output_real, %output_imag : tensor<?x?x?xf32>, tensor<?x?x?xf32>
1875+
}

0 commit comments

Comments
 (0)