Skip to content

Commit 3230a64

Browse files
[mlir][NVGPU] Fix incorrect API usage in RewritePatterns
Incorrect API usage was detected by D144552. Differential Revision: https://reviews.llvm.org/D145156
1 parent b13a197 commit 3230a64

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
3838
precision(precision) {}
3939

4040
LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
41-
PatternRewriter &rewrite) const override {
41+
PatternRewriter &rewriter) const override {
4242
Location location = op->getLoc();
4343

4444
if (op->hasAttr(op.getTf32EnabledAttrName()) ||
@@ -53,8 +53,10 @@ struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
5353
return emitError(location, "TF32x3 is not supported at the moment "
5454
"for nvgpu.mma.sync on f32 datatype");
5555

56-
if (precision == MmaSyncF32Lowering::TF32)
57-
op.setTf32EnabledAttr(rewrite.getUnitAttr());
56+
if (precision == MmaSyncF32Lowering::TF32) {
57+
rewriter.updateRootInPlace(
58+
op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
59+
}
5860

5961
return success();
6062
}

0 commit comments

Comments
 (0)