@@ -3436,8 +3436,12 @@ static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3436
3436
return explicitSet == defaultSet;
3437
3437
}
3438
3438
3439
- // / Returns true if the \p explictMap is broadcasted with respect to the
3440
- // / \p defaultMap.
3439
+ // / Check if the user defined map is valid broadcast map. Here broadcast
3440
+ // / indexing maps are defined in context of corresponding default indexing maps
3441
+ // / for the given Op. This way the check becomes very simple i.e just check the
3442
+ // / number of result dims.
3443
+ // / Returns true if the explictMap is broadcasted with respect to the
3444
+ // / defaultMap.
3441
3445
static bool isBroadcasted (AffineMap explictMap, AffineMap defaultMap) {
3442
3446
return explictMap.getNumResults () < defaultMap.getNumResults ();
3443
3447
}
@@ -3458,10 +3462,6 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3458
3462
return matmulOp->emitOpError ()
3459
3463
<< " Unexpected dim expression in map result." ;
3460
3464
3461
- // Check if the user defined map is valid broadcast map. Here broadcast
3462
- // indexing maps are defined in context of corresponding default indexing maps
3463
- // for the given Op. This way the check becomes very simple i.e just check the
3464
- // number of result dims.
3465
3465
if (isBroadcasted (opIndexingMap, defaultIndexingMap)) {
3466
3466
if (!matmulOp.isValidLhsRhsBroadcastMap (opIndexingMap)) {
3467
3467
return matmulOp->emitOpError ()
@@ -3527,8 +3527,7 @@ static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
3527
3527
}
3528
3528
3529
3529
// / Verifies the broadcast and transpose semantic specified by the explicit
3530
- // / indexing map for the BatchMatmulOp \p op for each operand specified by \p
3531
- // / opIndex.
3530
+ // / indexing map for the BatchMatmulOp op for each operand specified by opIndex.
3532
3531
static LogicalResult
3533
3532
verifyExtendedBatchMatmulSemantic (BatchMatmulOp batchMatmulOp,
3534
3533
unsigned opIndex) {
@@ -3934,7 +3933,7 @@ bool BatchMatmulOp::hasUserDefinedMaps() {
3934
3933
return defaultMaps != explicitMaps;
3935
3934
}
3936
3935
3937
- // / Returns true if the given broadcast map \p bcastMap is valid for this op.
3936
+ // / Returns true if the given broadcast map bcastMap is valid for this op.
3938
3937
bool BatchMatmulOp::isValidLhsRhsBroadcastMap (AffineMap bcastMap, bool isLHS) {
3939
3938
assert (bcastMap.getNumResults () < 3 &&
3940
3939
" Expected less than 3 result dim expr." );
@@ -3960,16 +3959,15 @@ void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3960
3959
RegionBuilderHelper helper (b, block);
3961
3960
SmallVector<Value> yields;
3962
3961
3963
- Value value1 =
3964
- helper.buildTypeFn (TypeFn::cast_signed, block.getArgument (2 ).getType (),
3965
- block.getArgument (0 ));
3966
- Value value2 =
3967
- helper.buildTypeFn (TypeFn::cast_signed, block.getArgument (2 ).getType (),
3968
- block.getArgument (1 ));
3969
- Value value3 = helper.buildBinaryFn (BinaryFn::mul, value1, value2);
3970
- Value value4 =
3971
- helper.buildBinaryFn (BinaryFn::add, block.getArgument (2 ), value3);
3972
- yields.push_back (value4);
3962
+ auto toType = block.getArgument (2 ).getType ();
3963
+ Value castValA =
3964
+ helper.buildTypeFn (TypeFn::cast_signed, toType, block.getArgument (0 ));
3965
+ Value castValB =
3966
+ helper.buildTypeFn (TypeFn::cast_signed, toType, block.getArgument (1 ));
3967
+ Value mulVal = helper.buildBinaryFn (BinaryFn::mul, castValA, castValB);
3968
+ Value addVal =
3969
+ helper.buildBinaryFn (BinaryFn::add, block.getArgument (2 ), mulVal);
3970
+ yields.push_back (addVal);
3973
3971
helper.yieldOutputs (yields);
3974
3972
}
3975
3973
0 commit comments