@@ -2512,13 +2512,38 @@ struct Conv1DGenerator
2512
2512
.getOperation ();
2513
2513
}
2514
2514
2515
+ // Take a value and widen to have the same element type as `ty`.
2516
+ Value promote (RewriterBase &rewriter, Location loc, Value val, Type ty) {
2517
+ const Type srcElementType = getElementTypeOrSelf (val.getType ());
2518
+ const Type dstElementType = getElementTypeOrSelf (ty);
2519
+ assert (isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
2520
+ if (srcElementType == dstElementType)
2521
+ return val;
2522
+
2523
+ const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth ();
2524
+ const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth ();
2525
+ const Type dstType =
2526
+ cast<ShapedType>(val.getType ()).cloneWith (std::nullopt, dstElementType);
2527
+
2528
+ if (isa<FloatType>(dstElementType) && srcWidth < dstWidth)
2529
+ return rewriter.create <arith::ExtFOp>(loc, dstType, val);
2530
+
2531
+ if (isa<IntegerType>(dstElementType) && srcWidth < dstWidth)
2532
+ return rewriter.create <arith::ExtSIOp>(loc, dstType, val);
2533
+
2534
+ assert (false && " unhandled promotion case" );
2535
+ return nullptr ;
2536
+ }
2537
+
2515
2538
// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
2516
2539
Value conv1dSliceAsContraction (RewriterBase &rewriter, Location loc,
2517
2540
Value lhs, Value rhs, Value res) {
2518
2541
vector::IteratorType par = vector::IteratorType::parallel;
2519
2542
vector::IteratorType red = vector::IteratorType::reduction;
2520
2543
AffineExpr n, w, f, c;
2521
2544
bindDims (ctx, n, w, f, c);
2545
+ lhs = promote (rewriter, loc, lhs, res.getType ());
2546
+ rhs = promote (rewriter, loc, rhs, res.getType ());
2522
2547
return rewriter.create <vector::ContractionOp>(
2523
2548
loc, lhs, rhs, res,
2524
2549
/* indexingMaps=*/ MapList{{n, w, c}, {c, f}, {n, w, f}},
@@ -2666,24 +2691,6 @@ struct Conv1DGenerator
2666
2691
.getOperation ();
2667
2692
}
2668
2693
2669
- // Take a value of element type T and widen to the destination type.
2670
- Value promote (RewriterBase &rewriter, Location loc, Value val, Type ty) {
2671
- if (val.getType () == ty)
2672
- return val;
2673
-
2674
- const int64_t srcWidth =
2675
- getElementTypeOrSelf (val.getType ()).getIntOrFloatBitWidth ();
2676
- const int64_t destWidth = getElementTypeOrSelf (ty).getIntOrFloatBitWidth ();
2677
-
2678
- if (getElementTypeOrSelf (ty).isa <FloatType>() && srcWidth < destWidth)
2679
- return rewriter.create <arith::ExtFOp>(loc, ty, val);
2680
-
2681
- if (getElementTypeOrSelf (ty).isa <IntegerType>() && srcWidth < destWidth)
2682
- return rewriter.create <arith::ExtSIOp>(loc, ty, val);
2683
-
2684
- return nullptr ;
2685
- }
2686
-
2687
2694
// / Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
2688
2695
Value depthwiseConv1dSliceAsMulAcc (RewriterBase &rewriter, Location loc,
2689
2696
Value lhs, Value rhs, Value res) {
0 commit comments