|
13 | 13 | #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
|
14 | 14 | #include "mlir/Dialect/Arith/IR/Arith.h"
|
15 | 15 | #include "mlir/Dialect/Arith/Utils/Utils.h"
|
16 |
| -#include "mlir/Dialect/Index/IR/IndexOps.h" |
17 | 16 | #include "mlir/Dialect/Linalg/IR/Linalg.h"
|
18 | 17 | #include "mlir/Dialect/Math/IR/Math.h"
|
19 | 18 | #include "mlir/Dialect/SCF/IR/SCF.h"
|
@@ -2368,25 +2367,16 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
|
2368 | 2367 | Value sumImag = args[2];
|
2369 | 2368 |
|
2370 | 2369 | // Indices for angle computation
|
2371 |
| - Value oy = builder.create<linalg::IndexOp>(loc, 1); |
2372 |
| - Value ox = builder.create<linalg::IndexOp>(loc, 2); |
2373 |
| - Value iy = builder.create<linalg::IndexOp>(loc, 3); |
2374 |
| - Value ix = builder.create<linalg::IndexOp>(loc, 4); |
2375 |
| - |
2376 |
| - // Calculating angle without integer parts of components as sin/cos are |
2377 |
| - // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W ) |
2378 |
| - // / W); |
2379 |
| - auto iyXoy = builder.create<index::MulOp>(loc, iy, oy); |
2380 |
| - auto ixXox = builder.create<index::MulOp>(loc, ix, ox); |
2381 |
| - |
2382 |
| - auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH); |
2383 |
| - auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW); |
2384 |
| - |
2385 |
| - auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem); |
2386 |
| - auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem); |
2387 |
| - |
2388 |
| - auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH); |
2389 |
| - auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW); |
| 2370 | + auto oy = createLinalgIndex(builder, loc, elementType, 1); |
| 2371 | + auto ox = createLinalgIndex(builder, loc, elementType, 2); |
| 2372 | + auto iy = createLinalgIndex(builder, loc, elementType, 3); |
| 2373 | + auto ix = createLinalgIndex(builder, loc, elementType, 4); |
| 2374 | + |
| 2375 | + // angle = 2 * pi() * ((iy * oy) / H + (ix * ox) / W) |
| 2376 | + auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy); |
| 2377 | + auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox); |
| 2378 | + auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH); |
| 2379 | + auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW); |
2390 | 2380 | auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
|
2391 | 2381 | auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
|
2392 | 2382 |
|
@@ -2491,30 +2481,22 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
|
2491 | 2481 | Value sumImag = args[3];
|
2492 | 2482 |
|
2493 | 2483 | // Indices for angle computation
|
2494 |
| - Value oy = builder.create<linalg::IndexOp>(loc, 1); |
2495 |
| - Value ox = builder.create<linalg::IndexOp>(loc, 2); |
2496 |
| - Value iy = builder.create<linalg::IndexOp>(loc, 3); |
2497 |
| - Value ix = builder.create<linalg::IndexOp>(loc, 4); |
2498 |
| - |
2499 |
| - // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * |
2500 |
| - // ox) % W ) / W); |
2501 |
| - auto iyXoy = builder.create<index::MulOp>(loc, iy, oy); |
2502 |
| - auto ixXox = builder.create<index::MulOp>(loc, ix, ox); |
2503 |
| - |
2504 |
| - auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH); |
2505 |
| - auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW); |
2506 |
| - |
2507 |
| - auto iyRemFloat = |
2508 |
| - RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem); |
2509 |
| - auto ixRemFloat = |
2510 |
| - RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem); |
2511 |
| - |
2512 |
| - auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH); |
2513 |
| - auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW); |
2514 |
| - |
| 2484 | + Value oy = |
| 2485 | + RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 1); |
| 2486 | + Value ox = |
| 2487 | + RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 2); |
| 2488 | + Value iy = |
| 2489 | + RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 3); |
| 2490 | + Value ix = |
| 2491 | + RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 4); |
| 2492 | + |
| 2493 | + // float_t angle = sign_val * 2 * pi() * ((iy * oy) / H + (ix * ox) / W); |
| 2494 | + auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy); |
| 2495 | + auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox); |
| 2496 | + auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH); |
| 2497 | + auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW); |
2515 | 2498 | auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
|
2516 | 2499 | auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
|
2517 |
| - |
2518 | 2500 | if (inverse.getValue()) {
|
2519 | 2501 | angle = builder.create<arith::MulFOp>(
|
2520 | 2502 | loc, angle,
|
|
0 commit comments