Skip to content

Commit 4edeaff

Browse files
[mlir][tosa] Fix tosa.Resize-to-linalg lowering (#88514)
1 parent 06947b9 commit 4edeaff

File tree

2 files changed

+60
-73
lines changed

2 files changed

+60
-73
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,17 +1582,16 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
15821582
}
15831583
// x = x * scale_d + offset;
15841584
// ix = floor(x / scale_n)
1585-
// dx = x / scale_n - ix
1586-
Value val = b.create<arith::UIToFPOp>(floatTy, in);
1587-
scaleN = b.create<arith::UIToFPOp>(floatTy, scaleN);
1588-
scaleD = b.create<arith::UIToFPOp>(floatTy, scaleD);
1589-
offset = b.create<arith::SIToFPOp>(floatTy, offset);
1590-
val = b.create<arith::MulFOp>(val, scaleD);
1591-
val = b.create<arith::AddFOp>(val, offset);
1592-
val = b.create<arith::DivFOp>(val, scaleN);
1593-
index = b.create<math::FloorOp>(val);
1594-
delta = b.create<arith::SubFOp>(val, index);
1595-
index = b.create<arith::FPToSIOp>(b.getI32Type(), index);
1585+
Value val = b.create<arith::MulIOp>(in, scaleD);
1586+
val = b.create<arith::AddIOp>(val, offset);
1587+
index = b.create<arith::FloorDivSIOp>(val, scaleN);
1588+
1589+
// rx = x % scale_n
1590+
// dx = rx / scale_n
1591+
Value r = b.create<arith::RemSIOp>(val, scaleN);
1592+
Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
1593+
Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
1594+
delta = b.create<arith::DivFOp>(rFp, scaleNfp);
15961595
};
15971596

15981597
// Compute the ix and dx values for the X and Y dimensions - int case.

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir

Lines changed: 50 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -304,50 +304,44 @@ func.func @resize_nearest_fp32(%input: tensor<1x50x48x1xf32>) -> () {
304304
// CHECK-DAG: %[[XMAX:.*]] = arith.constant 47
305305
// CHECK: %[[Y:.+]] = arith.index_cast %[[IDX1]]
306306
// CHECK: %[[X:.+]] = arith.index_cast %[[IDX2]]
307-
// CHECK-DAG: %[[ISCALE_Y_N:.*]] = arith.constant 64
308-
// CHECK-DAG: %[[ISCALE_Y_D:.*]] = arith.constant 2
309-
// CHECK-DAG: %[[ISCALE_X_N:.*]] = arith.constant 64
310-
// CHECK-DAG: %[[ISCALE_X_D:.*]] = arith.constant 2
311-
// CHECK-DAG: %[[IOFFSET_Y:.*]] = arith.constant -31
312-
// CHECK-DAG: %[[IOFFSET_X:.*]] = arith.constant -31
313-
// CHECK-DAG: %[[IBORDER_Y:.*]] = arith.constant 31
314-
// CHECK-DAG: %[[IBORDER_X:.*]] = arith.constant 31
315-
316-
// CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]]
317-
// CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]]
318-
// CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]]
319-
// CHECK: %[[OFFSET_Y:.*]] = arith.sitofp %[[IOFFSET_Y]]
320-
// CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]]
321-
// CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]]
322-
// CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]]
323-
// CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]]
324-
// CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]]
325-
// CHECK: %[[VAL_39:.*]] = arith.fptosi %[[VAL_35]]
326-
327-
// CHECK: %[[X0:.+]] = arith.uitofp %[[X]]
328-
// CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]]
329-
// CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]]
330-
// CHECK: %[[OFFSET_X:.*]] = arith.sitofp %[[IOFFSET_X]]
331-
// CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]]
332-
// CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]]
333-
// CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]]
334-
// CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]]
335-
// CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]]
336-
// CHECK: %[[VAL_40:.*]] = arith.fptosi %[[VAL_36]]
307+
// CHECK-DAG: %[[SCALE_Y_N:.*]] = arith.constant 64
308+
// CHECK-DAG: %[[SCALE_Y_D:.*]] = arith.constant 2
309+
// CHECK-DAG: %[[SCALE_X_N:.*]] = arith.constant 64
310+
// CHECK-DAG: %[[SCALE_X_D:.*]] = arith.constant 2
311+
// CHECK-DAG: %[[OFFSET_Y:.*]] = arith.constant -31
312+
// CHECK-DAG: %[[OFFSET_X:.*]] = arith.constant -31
313+
// CHECK-DAG: %[[BORDER_Y:.*]] = arith.constant 31
314+
// CHECK-DAG: %[[BORDER_X:.*]] = arith.constant 31
315+
316+
// CHECK: %[[VAL_29:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]]
317+
// CHECK: %[[Y_TEMP:.*]] = arith.addi %[[VAL_29]], %[[OFFSET_Y]]
318+
// CHECK: %[[IY_TEMP:.*]] = arith.floordivsi %[[Y_TEMP]], %[[SCALE_Y_N]]
319+
// CHECK: %[[RY:.*]] = arith.remsi %[[Y_TEMP]], %[[SCALE_Y_N]]
320+
// CHECK: %[[RY_FP:.*]] = arith.sitofp %[[RY]]
321+
// CHECK: %[[SCALE_Y_N_FP:.*]] = arith.uitofp %[[SCALE_Y_N]]
322+
// CHECK: %[[D_Y:.*]] = arith.divf %[[RY_FP]], %[[SCALE_Y_N_FP]]
323+
324+
// CHECK: %[[VAL_30:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
325+
// CHECK: %[[X_TEMP:.*]] = arith.addi %[[VAL_30]], %[[OFFSET_X]]
326+
// CHECK: %[[IX_TEMP:.*]] = arith.floordivsi %[[X_TEMP]], %[[SCALE_X_N]]
327+
// CHECK: %[[RX:.*]] = arith.remsi %[[X_TEMP]], %[[SCALE_X_N]]
328+
// CHECK: %[[RX_FP:.*]] = arith.sitofp %[[RX]]
329+
// CHECK: %[[SCALE_X_N_FP:.*]] = arith.uitofp %[[SCALE_X_N]]
330+
// CHECK: %[[D_X:.*]] = arith.divf %[[RX_FP]], %[[SCALE_X_N_FP]]
337331

338332
// CHECK-DAG: %[[ONE:.*]] = arith.constant 1
339333
// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01
340334
// CHECK: %[[PRED_Y:.*]] = arith.cmpf oge, %[[D_Y]], %[[HALF]]
341335
// CHECK: %[[ROUND_Y:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]]
342-
// CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_39]], %[[ROUND_Y]]
336+
// CHECK: %[[VAL_48:.*]] = arith.addi %[[IY_TEMP]], %[[ROUND_Y]]
343337
// CHECK: %[[LOWER:.*]] = arith.maxsi %[[ZERO]], %[[VAL_48]]
344338
// CHECK: %[[CLAMPED:.*]] = arith.minsi %[[YMAX]], %[[LOWER]]
345339
// CHECK: %[[IDY:.*]] = arith.index_cast %[[CLAMPED]]
346340

347341
// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01
348342
// CHECK: %[[PRED_X:.*]] = arith.cmpf oge, %[[D_X]], %[[HALF]]
349343
// CHECK: %[[ROUND_X:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
350-
// CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_40]], %[[ROUND_X]]
344+
// CHECK: %[[VAL_49:.*]] = arith.addi %[[IX_TEMP]], %[[ROUND_X]]
351345
// CHECK: %[[LOWER:.*]] = arith.maxsi %[[ZERO]], %[[VAL_49]]
352346
// CHECK: %[[CLAMPED:.*]] = arith.minsi %[[XMAX]], %[[LOWER]]
353347
// CHECK: %[[IDX:.*]] = arith.index_cast %[[CLAMPED]]
@@ -374,36 +368,30 @@ func.func @resize_bilinear_fp(%input: tensor<1x23x24x1xf32>) -> () {
374368
// CHECK-DAG: %[[X_MAX:.*]] = arith.constant 23
375369
// CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]]
376370
// CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]]
377-
// CHECK-DAG: %[[ISCALE_Y_N:.*]] = arith.constant 4
378-
// CHECK-DAG: %[[ISCALE_Y_D:.*]] = arith.constant 1
379-
// CHECK-DAG: %[[ISCALE_X_N:.*]] = arith.constant 4
380-
// CHECK-DAG: %[[ISCALE_X_D:.*]] = arith.constant 1
381-
// CHECK-DAG: %[[IOFFSET_Y:.*]] = arith.constant 0
382-
// CHECK-DAG: %[[IOFFSET_X:.*]] = arith.constant 0
383-
// CHECK-DAG: %[[IBORDER_Y:.*]] = arith.constant 0
384-
// CHECK-DAG: %[[IBORDER_X:.*]] = arith.constant 0
385-
386-
// CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]]
387-
// CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]]
388-
// CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]]
389-
// CHECK: %[[OFFSET_Y:.*]] = arith.sitofp %[[IOFFSET_Y]]
390-
// CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]]
391-
// CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]]
392-
// CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]]
393-
// CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]]
394-
// CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]]
395-
// CHECK: %[[I_Y:.*]] = arith.fptosi %[[VAL_35]]
396-
397-
// CHECK: %[[X0:.+]] = arith.uitofp %[[X]]
398-
// CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]]
399-
// CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]]
400-
// CHECK: %[[OFFSET_X:.*]] = arith.sitofp %[[IOFFSET_X]]
401-
// CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]]
402-
// CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]]
403-
// CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]]
404-
// CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]]
405-
// CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]]
406-
// CHECK: %[[I_X:.*]] = arith.fptosi %[[VAL_36]]
371+
// CHECK-DAG: %[[SCALE_Y_N:.*]] = arith.constant 4
372+
// CHECK-DAG: %[[SCALE_Y_D:.*]] = arith.constant 1
373+
// CHECK-DAG: %[[SCALE_X_N:.*]] = arith.constant 4
374+
// CHECK-DAG: %[[SCALE_X_D:.*]] = arith.constant 1
375+
// CHECK-DAG: %[[OFFSET_Y:.*]] = arith.constant 0
376+
// CHECK-DAG: %[[OFFSET_X:.*]] = arith.constant 0
377+
// CHECK-DAG: %[[BORDER_Y:.*]] = arith.constant 0
378+
// CHECK-DAG: %[[BORDER_X:.*]] = arith.constant 0
379+
380+
// CHECK: %[[VAL_29:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]]
381+
// CHECK: %[[Y_TEMP:.*]] = arith.addi %[[VAL_29]], %[[OFFSET_Y]]
382+
// CHECK: %[[I_Y:.*]] = arith.floordivsi %[[Y_TEMP]], %[[SCALE_Y_N]]
383+
// CHECK: %[[RY:.*]] = arith.remsi %[[Y_TEMP]], %[[SCALE_Y_N]]
384+
// CHECK: %[[RY_FP:.*]] = arith.sitofp %[[RY]]
385+
// CHECK: %[[SCALE_Y_N_FP:.*]] = arith.uitofp %[[SCALE_Y_N]]
386+
// CHECK: %[[D_Y:.*]] = arith.divf %[[RY_FP]], %[[SCALE_Y_N_FP]]
387+
388+
// CHECK: %[[VAL_30:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
389+
// CHECK: %[[X_TEMP:.*]] = arith.addi %[[VAL_30]], %[[OFFSET_X]]
390+
// CHECK: %[[I_X:.*]] = arith.floordivsi %[[X_TEMP]], %[[SCALE_X_N]]
391+
// CHECK: %[[RX:.*]] = arith.remsi %[[X_TEMP]], %[[SCALE_X_N]]
392+
// CHECK: %[[RX_FP:.*]] = arith.sitofp %[[RX]]
393+
// CHECK: %[[SCALE_X_N_FP:.*]] = arith.uitofp %[[SCALE_X_N]]
394+
// CHECK: %[[D_X:.*]] = arith.divf %[[RX_FP]], %[[SCALE_X_N_FP]]
407395

408396
// Compute the left, right, and top indices for the bilinear interpolation.
409397

0 commit comments

Comments
 (0)