@@ -119,10 +119,11 @@ static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
 }
 
 // Broadcast the source value to all the outer dimensions of the result value.
-// If required, the element type is expanded using an arith.extsi operation.
-static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
-                                                Location loc, Value source,
-                                                Value result) {
+// If required, the element type is expanded using an arith.extsi or arith.extf
+// operation as appropriate.
+static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
+                                              Location loc, Value source,
+                                              Value result) {
   ShapedType resultTy = cast<ShapedType>(result.getType());
   const int64_t resultRank = resultTy.getRank();
   // Creating maps for the input and output of the broacast-like generic op.
@@ -135,11 +136,16 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
           .create<linalg::GenericOp>(
               loc, resultTy, ValueRange({source}), result, indexingMaps,
               getNParallelLoopsAttrs(resultTy.getRank()),
-              [](OpBuilder &builder, Location loc, ValueRange args) {
+              [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
                 Value biasVal = args[0];
                 Type resType = args[1].getType();
                 if (resType != biasVal.getType()) {
-                  biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
+                  biasVal =
+                      resultTy.getElementType().isFloat()
+                          ? builder.create<arith::ExtFOp>(loc, resType, biasVal)
+                                .getResult()
+                          : builder.create<arith::ExtSIOp>(loc, resType, biasVal)
+                                .getResult();
                 }
                 builder.create<linalg::YieldOp>(loc, biasVal);
               })
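The lambda now selects the widening op from the accumulator's element type: `arith.extf` for floats, `arith.extsi` for signed integers. The distinction matters because the two widenings behave differently at the bit level. A minimal plain-C++ analogy (illustrative only, not part of this patch):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Integer widening must sign-extend, as arith.extsi does: the sign bit of
  // the narrow value is replicated into the new high bits.
  int16_t narrowBias = -42;
  int32_t wideBias = static_cast<int32_t>(narrowBias); // still -42

  // Float widening is a value-preserving format conversion, as arith.extf
  // does: every float is exactly representable as a double.
  float narrowF = 1.5f;
  double wideF = static_cast<double>(narrowF); // exactly 1.5

  std::printf("%d %f\n", static_cast<int>(wideBias), wideF);
  return 0;
}
```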
@@ -253,12 +259,14 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
 
     Type inputETy = inputTy.getElementType();
-    Type resultETy = resultTy.getElementType();
 
     DenseI64ArrayAttr padAttr = op.getPadAttr();
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
 
+    Type accETy = op.getAccType();
+    Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);
+
     // Get and verify zero points.
     FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
     if (failed(maybeIZp))
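`getAccType()` reads the op's `acc_type` attribute, which may be wider than the result element type (for example, f32 accumulation for f16 tensors). Accumulating in a wider type keeps long dot-product sums accurate. A runnable plain-C++ sketch of the effect, with made-up values:

```cpp
#include <cstdio>

int main() {
  // Accumulating many small products in the narrow type picks up rounding
  // error on every addition; a wider accumulator does not. This mirrors why
  // a TOSA conv over f16 tensors typically accumulates in f32.
  float narrowAcc = 0.0f;
  double wideAcc = 0.0;
  for (int i = 0; i < 1'000'000; ++i) {
    float product = 1e-4f; // stand-in for one input * weight product
    narrowAcc += product;
    wideAcc += static_cast<double>(product);
  }
  // wideAcc carries only the representation error of 1e-4f (~99.99999747);
  // narrowAcc accumulates additional rounding error on top of that.
  std::printf("narrow: %f  wide: %f\n", static_cast<double>(narrowAcc),
              wideAcc);
  return 0;
}
```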
@@ -385,10 +393,10 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     auto dilationAttr = rewriter.getI64TensorAttr(dilation);
 
     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, resultTy.getShape(), resultETy, filteredDims);
+        loc, resultTy.getShape(), accETy, filteredDims);
 
     Value broadcastBias =
-        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
+        linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
 
     if (hasZp) {
       auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
@@ -410,10 +418,15 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
 
     Value conv = rewriter
                      .create<LinalgConvOp>(
-                         loc, resultTy, ValueRange{input, weight},
+                         loc, accTy, ValueRange{input, weight},
                          ValueRange{broadcastBias}, strideAttr, dilationAttr)
                      ->getResult(0);
 
+    // We may need to truncate back to the result type if the accumulator was
+    // wider than the result.
+    if (resultTy != accTy)
+      conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);
+
     rewriter.replaceOp(op, conv);
     return success();
   }
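Both converters now end with the same accumulate-then-truncate pattern (the depthwise variant appears further down). If one wanted to factor it out, a shared helper might look like the sketch below. This is a hypothetical refactor, not part of the commit; it only combines APIs already used in the diff (`RankedTensorType::get`, `tosa::CastOp`) and assumes the surrounding rewrite-pattern context rather than being a standalone program:

```cpp
// Hypothetical helper, not in this patch: cast `tensor` to the same shape
// with `newETy` as the element type, or return it unchanged when the element
// types already match. tosa.cast covers both the float truncation case
// (f32 -> f16) and the integer narrowing case.
static Value truncateToElementType(PatternRewriter &rewriter, Location loc,
                                   Value tensor, Type newETy) {
  auto shapedTy = cast<ShapedType>(tensor.getType());
  if (shapedTy.getElementType() == newETy)
    return tensor; // No cast needed when the accumulator matches the result.
  auto newTy = RankedTensorType::get(shapedTy.getShape(), newETy);
  return rewriter.create<tosa::CastOp>(loc, newTy, tensor);
}
```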
@@ -444,6 +457,8 @@ class DepthwiseConvConverter
     auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
     auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));
 
+    Type accETy = op.getAccType();
+
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
           op, "tosa.depthwise_conv ops require static shapes");
@@ -516,11 +531,11 @@ class DepthwiseConvConverter
     ShapedType linalgConvTy =
         RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                                weightShape[2], weightShape[3]},
-                              resultETy);
+                              accETy);
 
-    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
+    auto resultZeroAttr = rewriter.getZeroAttr(accETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, linalgConvTy.getShape(), resultETy, filteredDims);
+        loc, linalgConvTy.getShape(), accETy, filteredDims);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
     Value zeroTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{zero},
@@ -543,6 +558,15 @@ class DepthwiseConvConverter
                          ValueRange{zeroTensor}, strideAttr, dilationAttr)
                      .getResult(0);
 
+    // We may need to truncate back to the result type if the accumulator was
+    // wider than the result.
+    if (accETy != resultETy)
+      conv = rewriter.create<tosa::CastOp>(
+          loc,
+          RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
+                                resultETy),
+          conv);
+
     SmallVector<ReassociationExprs, 4> reassociationMap;
     createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
     Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
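Unlike ConvConverter, the depthwise path guards the cast with an element-type comparison and rebuilds the tensor type from `conv`'s own shape. That is because `conv` still has the expanded five-dimensional depthwise shape (N, H, W, C, M) at this point; it only matches the four-dimensional TOSA result after the `tensor.collapse_shape` below. A toy illustration of that shape bookkeeping, with made-up dimensions:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical dimensions: batch 1, 8x8 spatial, 4 input channels, and a
  // depthwise channel multiplier of 3. The truncating cast runs while the
  // value still has the expanded shape; the collapse then merges (C, M)
  // into the final channel dimension.
  std::array<int64_t, 5> convShape = {1, 8, 8, 4, 3}; // (N, H, W, C, M)
  std::array<int64_t, 4> resultShape = {
      convShape[0], convShape[1], convShape[2],
      convShape[3] * convShape[4]};                   // (N, H, W, C * M)
  std::printf("collapsed channels: %lld\n",
              static_cast<long long>(resultShape[3])); // prints 12
  return 0;
}
```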
0 commit comments