|
13 | 13 | #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
|
14 | 14 | #include "mlir/Dialect/Arith/IR/Arith.h"
|
15 | 15 | #include "mlir/Dialect/Arith/Utils/Utils.h"
|
| 16 | +#include "mlir/Dialect/Index/IR/IndexOps.h" |
16 | 17 | #include "mlir/Dialect/Linalg/IR/Linalg.h"
|
17 | 18 | #include "mlir/Dialect/Math/IR/Math.h"
|
18 | 19 | #include "mlir/Dialect/SCF/IR/SCF.h"
|
@@ -2367,16 +2368,25 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
|
2367 | 2368 | Value sumImag = args[2];
|
2368 | 2369 |
|
2369 | 2370 | // Indices for angle computation
|
2370 |
| - auto oy = createLinalgIndex(builder, loc, elementType, 1); |
2371 |
| - auto ox = createLinalgIndex(builder, loc, elementType, 2); |
2372 |
| - auto iy = createLinalgIndex(builder, loc, elementType, 3); |
2373 |
| - auto ix = createLinalgIndex(builder, loc, elementType, 4); |
2374 |
| - |
2375 |
| - // angle = 2 * pi() * ((iy * oy) / H + (ix * ox) / W) |
2376 |
| - auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy); |
2377 |
| - auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox); |
2378 |
| - auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH); |
2379 |
| - auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW); |
| 2371 | + Value oy = builder.create<linalg::IndexOp>(loc, 1); |
| 2372 | + Value ox = builder.create<linalg::IndexOp>(loc, 2); |
| 2373 | + Value iy = builder.create<linalg::IndexOp>(loc, 3); |
| 2374 | + Value ix = builder.create<linalg::IndexOp>(loc, 4); |
| 2375 | + |
| 2376 | + // Calculating angle without integer parts of components as sin/cos are |
| 2377 | + // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W ) |
| 2378 | + // / W); |
| 2379 | + auto iyXoy = builder.create<index::MulOp>(loc, iy, oy); |
| 2380 | + auto ixXox = builder.create<index::MulOp>(loc, ix, ox); |
| 2381 | + |
| 2382 | + auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH); |
| 2383 | + auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW); |
| 2384 | + |
| 2385 | + auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem); |
| 2386 | + auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem); |
| 2387 | + |
| 2388 | + auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH); |
| 2389 | + auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW); |
2380 | 2390 | auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
|
2381 | 2391 | auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
|
2382 | 2392 |
|
@@ -2481,22 +2491,30 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
|
2481 | 2491 | Value sumImag = args[3];
|
2482 | 2492 |
|
2483 | 2493 | // Indices for angle computation
|
2484 |
| - Value oy = |
2485 |
| - RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 1); |
2486 |
| - Value ox = |
2487 |
| - RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 2); |
2488 |
| - Value iy = |
2489 |
| - RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 3); |
2490 |
| - Value ix = |
2491 |
| - RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 4); |
2492 |
| - |
2493 |
| - // float_t angle = sign_val * 2 * pi() * ((iy * oy) / H + (ix * ox) / W); |
2494 |
| - auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy); |
2495 |
| - auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox); |
2496 |
| - auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH); |
2497 |
| - auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW); |
| 2494 | + Value oy = builder.create<linalg::IndexOp>(loc, 1); |
| 2495 | + Value ox = builder.create<linalg::IndexOp>(loc, 2); |
| 2496 | + Value iy = builder.create<linalg::IndexOp>(loc, 3); |
| 2497 | + Value ix = builder.create<linalg::IndexOp>(loc, 4); |
| 2498 | + |
| 2499 | + // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * |
| 2500 | + // ox) % W ) / W); |
| 2501 | + auto iyXoy = builder.create<index::MulOp>(loc, iy, oy); |
| 2502 | + auto ixXox = builder.create<index::MulOp>(loc, ix, ox); |
| 2503 | + |
| 2504 | + auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH); |
| 2505 | + auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW); |
| 2506 | + |
| 2507 | + auto iyRemFloat = |
| 2508 | + RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem); |
| 2509 | + auto ixRemFloat = |
| 2510 | + RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem); |
| 2511 | + |
| 2512 | + auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH); |
| 2513 | + auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW); |
| 2514 | + |
2498 | 2515 | auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
|
2499 | 2516 | auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
|
| 2517 | + |
2500 | 2518 | if (inverse.getValue()) {
|
2501 | 2519 | angle = builder.create<arith::MulFOp>(
|
2502 | 2520 | loc, angle,
|
|
0 commit comments