Skip to content

Commit 048a481

Browse files
committed
*Renames few variables and updates few comments
1 parent 237e041 commit 048a481

File tree

1 file changed

+17
-19
lines changed

1 file changed

+17
-19
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3436,8 +3436,12 @@ static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
34363436
return explicitSet == defaultSet;
34373437
}
34383438

3439-
/// Returns true if the \p explictMap is broadcasted with respect to the
3440-
/// \p defaultMap.
3439+
/// Check if the user defined map is valid broadcast map. Here broadcast
3440+
/// indexing maps are defined in context of corresponding default indexing maps
3441+
/// for the given Op. This way the check becomes very simple i.e just check the
3442+
/// number of result dims.
3443+
/// Returns true if the explictMap is broadcasted with respect to the
3444+
/// defaultMap.
34413445
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
34423446
return explictMap.getNumResults() < defaultMap.getNumResults();
34433447
}
@@ -3458,10 +3462,6 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
34583462
return matmulOp->emitOpError()
34593463
<< "Unexpected dim expression in map result.";
34603464

3461-
// Check if the user defined map is valid broadcast map. Here broadcast
3462-
// indexing maps are defined in context of corresponding default indexing maps
3463-
// for the given Op. This way the check becomes very simple i.e just check the
3464-
// number of result dims.
34653465
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
34663466
if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
34673467
return matmulOp->emitOpError()
@@ -3527,8 +3527,7 @@ static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
35273527
}
35283528

35293529
/// Verifies the broadcast and transpose semantic specified by the explicit
3530-
/// indexing map for the BatchMatmulOp \p op for each operand specified by \p
3531-
/// opIndex.
3530+
/// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
35323531
static LogicalResult
35333532
verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
35343533
unsigned opIndex) {
@@ -3934,7 +3933,7 @@ bool BatchMatmulOp::hasUserDefinedMaps() {
39343933
return defaultMaps != explicitMaps;
39353934
}
39363935

3937-
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
3936+
/// Returns true if the given broadcast map bcastMap is valid for this op.
39383937
bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
39393938
assert(bcastMap.getNumResults() < 3 &&
39403939
"Expected less than 3 result dim expr.");
@@ -3960,16 +3959,15 @@ void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
39603959
RegionBuilderHelper helper(b, block);
39613960
SmallVector<Value> yields;
39623961

3963-
Value value1 =
3964-
helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
3965-
block.getArgument(0));
3966-
Value value2 =
3967-
helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
3968-
block.getArgument(1));
3969-
Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3970-
Value value4 =
3971-
helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
3972-
yields.push_back(value4);
3962+
auto toType = block.getArgument(2).getType();
3963+
Value castValA =
3964+
helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
3965+
Value castValB =
3966+
helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
3967+
Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
3968+
Value addVal =
3969+
helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
3970+
yields.push_back(addVal);
39733971
helper.yieldOutputs(yields);
39743972
}
39753973

0 commit comments

Comments
 (0)