@@ -1407,37 +1407,178 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
1407
1407
}
1408
1408
};
1409
1409
1410
+ // Lowerings the TableOp to a series of gathers and numerica operations. This
1411
+ // includes interpolation between the high/low values. For the I8 varient, this
1412
+ // simplifies to a single gather operation.
1413
+ class TableConverter : public OpRewritePattern <tosa::TableOp> {
1414
+ public:
1415
+ using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
1416
+
1417
+ LogicalResult matchAndRewrite (tosa::TableOp op,
1418
+ PatternRewriter &rewriter) const final {
1419
+ auto loc = op.getLoc ();
1420
+ Value input = op.input ();
1421
+ Value table = op.table ();
1422
+ auto inputTy = input.getType ().cast <ShapedType>();
1423
+ auto tableTy = table.getType ().cast <ShapedType>();
1424
+ auto resultTy = op.getType ().cast <ShapedType>();
1425
+
1426
+ if (!inputTy.hasStaticShape ())
1427
+ return rewriter.notifyMatchFailure (
1428
+ op, " require input type to have static shape" );
1429
+
1430
+ auto inputElementTy = inputTy.getElementType ();
1431
+ auto tableElementTy = tableTy.getElementType ();
1432
+ auto resultElementTy = resultTy.getElementType ();
1433
+
1434
+ auto initTensor =
1435
+ rewriter
1436
+ .create <linalg::InitTensorOp>(loc, ArrayRef<Value>{},
1437
+ resultTy.getShape (), resultElementTy)
1438
+ .result ();
1439
+
1440
+ SmallVector<AffineMap, 2 > affineMaps = {
1441
+ rewriter.getMultiDimIdentityMap (resultTy.getRank ()),
1442
+ rewriter.getMultiDimIdentityMap (resultTy.getRank ())};
1443
+
1444
+ auto genericOp = rewriter.create <linalg::GenericOp>(
1445
+ loc, resultTy, ValueRange ({input}), ValueRange{initTensor}, affineMaps,
1446
+ getNParallelLoopsAttrs (resultTy.getRank ()));
1447
+ rewriter.replaceOp (op, genericOp.getResult (0 ));
1448
+
1449
+ {
1450
+ OpBuilder::InsertionGuard regionGuard (rewriter);
1451
+ Block *block =
1452
+ rewriter.createBlock (&genericOp.region (), genericOp.region ().end (),
1453
+ TypeRange ({inputElementTy, resultElementTy}));
1454
+
1455
+ auto inputValue = block->getArgument (0 );
1456
+ rewriter.setInsertionPointToStart (block);
1457
+ if (inputElementTy.isInteger (8 ) && tableElementTy.isInteger (8 ) &&
1458
+ resultElementTy.isInteger (8 )) {
1459
+ Value index = rewriter.create <IndexCastOp>(loc, rewriter.getIndexType (),
1460
+ inputValue);
1461
+ Value extract =
1462
+ rewriter.create <tensor::ExtractOp>(loc, table, ValueRange{index});
1463
+ rewriter.create <linalg::YieldOp>(loc, extract);
1464
+ return success ();
1465
+ }
1466
+
1467
+ if (inputElementTy.isInteger (16 ) && tableElementTy.isInteger (16 ) &&
1468
+ resultElementTy.isInteger (32 )) {
1469
+ Value extend = rewriter.create <SignExtendIOp>(
1470
+ loc, rewriter.getI32Type (), inputValue);
1471
+
1472
+ auto offset =
1473
+ rewriter.create <ConstantOp>(loc, rewriter.getI32IntegerAttr (32768 ));
1474
+ auto seven =
1475
+ rewriter.create <ConstantOp>(loc, rewriter.getI32IntegerAttr (7 ));
1476
+ auto one =
1477
+ rewriter.create <ConstantOp>(loc, rewriter.getI32IntegerAttr (1 ));
1478
+ auto b1111111 =
1479
+ rewriter.create <ConstantOp>(loc, rewriter.getI32IntegerAttr (127 ));
1480
+
1481
+ // Compute the index and fractional part from the input value:
1482
+ // value = value + 32768
1483
+ // index = value >> 7;
1484
+ // fraction = 0x01111111 & value
1485
+ auto extendAdd = rewriter.create <AddIOp>(loc, extend, offset);
1486
+ Value index =
1487
+ rewriter.create <UnsignedShiftRightOp>(loc, extendAdd, seven);
1488
+ Value fraction = rewriter.create <mlir::AndOp>(loc, extendAdd, b1111111);
1489
+
1490
+ // Extract the base and next values from the table.
1491
+ // base = (int32_t) table[index];
1492
+ // next = (int32_t) table[index + 1];
1493
+ Value indexPlusOne = rewriter.create <AddIOp>(loc, index, one);
1494
+
1495
+ index =
1496
+ rewriter.create <IndexCastOp>(loc, rewriter.getIndexType (), index);
1497
+ indexPlusOne = rewriter.create <IndexCastOp>(
1498
+ loc, rewriter.getIndexType (), indexPlusOne);
1499
+
1500
+ Value base =
1501
+ rewriter.create <tensor::ExtractOp>(loc, table, ValueRange{index});
1502
+ Value next = rewriter.create <tensor::ExtractOp>(
1503
+ loc, table, ValueRange{indexPlusOne});
1504
+
1505
+ base = rewriter.create <SignExtendIOp>(loc, rewriter.getI32Type (), base);
1506
+ next = rewriter.create <SignExtendIOp>(loc, rewriter.getI32Type (), next);
1507
+
1508
+ // Use the fractional part to interpolate between the input values:
1509
+ // result = (base << 7) + (next - base) * fraction
1510
+ Value baseScaled = rewriter.create <ShiftLeftOp>(loc, base, seven);
1511
+ Value diff = rewriter.create <SubIOp>(loc, next, base);
1512
+ Value diffScaled = rewriter.create <MulIOp>(loc, diff, fraction);
1513
+ Value result = rewriter.create <AddIOp>(loc, baseScaled, diffScaled);
1514
+
1515
+ rewriter.create <linalg::YieldOp>(loc, result);
1516
+
1517
+ return success ();
1518
+ }
1519
+ }
1520
+
1521
+ return rewriter.notifyMatchFailure (
1522
+ op, " unable to create body for tosa.table op" );
1523
+ }
1524
+ };
1525
+
1410
1526
} // namespace
1411
1527
1412
1528
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns (
1413
1529
RewritePatternSet *patterns) {
1414
1530
patterns->add <
1415
- PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
1416
- PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::ReciprocalOp>,
1417
- PointwiseConverter<tosa::NegateOp>, PointwiseConverter<tosa::PowOp>,
1418
- PointwiseConverter<tosa::RsqrtOp>, PointwiseConverter<tosa::LogOp>,
1419
- PointwiseConverter<tosa::ExpOp>, PointwiseConverter<tosa::AbsOp>,
1420
- PointwiseConverter<tosa::TanhOp>, PointwiseConverter<tosa::BitwiseAndOp>,
1531
+ // clang-format off
1532
+ PointwiseConverter<tosa::AddOp>,
1533
+ PointwiseConverter<tosa::SubOp>,
1534
+ PointwiseConverter<tosa::MulOp>,
1535
+ PointwiseConverter<tosa::NegateOp>,
1536
+ PointwiseConverter<tosa::PowOp>,
1537
+ PointwiseConverter<tosa::ReciprocalOp>,
1538
+ PointwiseConverter<tosa::RsqrtOp>,
1539
+ PointwiseConverter<tosa::LogOp>,
1540
+ PointwiseConverter<tosa::ExpOp>,
1541
+ PointwiseConverter<tosa::AbsOp>,
1542
+ PointwiseConverter<tosa::TanhOp>,
1543
+ PointwiseConverter<tosa::BitwiseAndOp>,
1421
1544
PointwiseConverter<tosa::BitwiseOrOp>,
1422
1545
PointwiseConverter<tosa::BitwiseNotOp>,
1423
1546
PointwiseConverter<tosa::BitwiseXorOp>,
1424
1547
PointwiseConverter<tosa::LogicalAndOp>,
1425
1548
PointwiseConverter<tosa::LogicalNotOp>,
1426
1549
PointwiseConverter<tosa::LogicalOrOp>,
1427
- PointwiseConverter<tosa::LogicalXorOp>, PointwiseConverter<tosa::CastOp>,
1550
+ PointwiseConverter<tosa::LogicalXorOp>,
1551
+ PointwiseConverter<tosa::CastOp>,
1428
1552
PointwiseConverter<tosa::LogicalLeftShiftOp>,
1429
1553
PointwiseConverter<tosa::LogicalRightShiftOp>,
1430
- PointwiseConverter<tosa::SelectOp>, PointwiseConverter<tosa::GreaterOp>,
1554
+ PointwiseConverter<tosa::SelectOp>,
1555
+ PointwiseConverter<tosa::GreaterOp>,
1431
1556
PointwiseConverter<tosa::GreaterEqualOp>,
1432
- PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
1433
- PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
1434
- PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
1435
- PointwiseConverter<tosa::SigmoidOp>, IdentityNConverter<tosa::IdentityOp>,
1436
- IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceAllOp>,
1437
- ReduceConverter<tosa::ReduceAnyOp>, ReduceConverter<tosa::ReduceMinOp>,
1438
- ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
1439
- ReduceConverter<tosa::ReduceProdOp>, ArgMaxConverter, ConcatConverter,
1440
- PadConverter, ReshapeConverter, RescaleConverter, ReverseConverter,
1441
- TileConverter, TransposeConverter, MatMulConverter,
1557
+ PointwiseConverter<tosa::MaximumOp>,
1558
+ PointwiseConverter<tosa::MinimumOp>,
1559
+ PointwiseConverter<tosa::CeilOp>,
1560
+ PointwiseConverter<tosa::FloorOp>,
1561
+ PointwiseConverter<tosa::ClampOp>,
1562
+ PointwiseConverter<tosa::ReluNOp>,
1563
+ PointwiseConverter<tosa::SigmoidOp>,
1564
+ IdentityNConverter<tosa::IdentityOp>,
1565
+ IdentityNConverter<tosa::IdentityNOp>,
1566
+ ReduceConverter<tosa::ReduceAllOp>,
1567
+ ReduceConverter<tosa::ReduceAnyOp>,
1568
+ ReduceConverter<tosa::ReduceMinOp>,
1569
+ ReduceConverter<tosa::ReduceMaxOp>,
1570
+ ReduceConverter<tosa::ReduceSumOp>,
1571
+ ReduceConverter<tosa::ReduceProdOp>,
1572
+ ArgMaxConverter,
1573
+ ConcatConverter,
1574
+ PadConverter,
1575
+ ReshapeConverter,
1576
+ RescaleConverter,
1577
+ ReverseConverter,
1578
+ TableConverter,
1579
+ TileConverter,
1580
+ TransposeConverter,
1581
+ MatMulConverter,
1442
1582
FullyConnectedConverter>(patterns->getContext ());
1583
+ // clang-format on
1443
1584
}
0 commit comments