Skip to content

Commit 971b579

Browse files
authored
[MLIR] Don't drop attached discardable attributes (#111261)
The creation of pack op was dropping discardable attributes.
1 parent 5d372ea commit 971b579

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4337,11 +4337,16 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
43374337
dest =
43384338
rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
43394339
}
4340-
Value newOp = rewriter.create<tensor::PackOp>(
4341-
loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
4342-
packOp.getPaddingValue(), packOp.getOuterDimsPerm());
4340+
auto clonedPackOp = cast<PackOp>(rewriter.clone(*packOp));
4341+
Value res = clonedPackOp.getResult();
4342+
rewriter.startOpModification(clonedPackOp);
4343+
clonedPackOp.getSourceMutable().assign(source);
4344+
clonedPackOp.getDestMutable().assign(dest);
4345+
res.setType(dest.getType());
4346+
rewriter.finalizeOpModification(clonedPackOp);
4347+
43434348
rewriter.replaceOpWithNewOp<tensor::CastOp>(
4344-
packOp, packOp.getResult().getType(), newOp);
4349+
packOp, packOp.getResult().getType(), clonedPackOp);
43454350
return success();
43464351
}
43474352

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2357,7 +2357,7 @@ func.func @unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16>
23572357
%tensor_empty = tensor.empty() : tensor<4x16x64x32xbf16>
23582358
%tensor_empty1 = tensor.empty() : tensor<224x512xbf16>
23592359
%packed = tensor.pack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty : tensor<256x512xbf16> -> tensor<4x16x64x32xbf16>
2360-
%unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16>
2360+
%unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16>
23612361
return %unpacked : tensor<224x512xbf16>
23622362
}
23632363

@@ -2707,3 +2707,14 @@ func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tenso
27072707
%0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
27082708
return %0#1 : index
27092709
}
2710+
2711+
// -----
2712+
2713+
// CHECK-LABEL: func.func @pack_dont_drop_attributes(
2714+
// CHECK: tensor.pack {{.*}} {test_attr}
2715+
func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
2716+
%c32_i64 = arith.constant 32 : i64
2717+
%cst = arith.constant 0.000000e+00 : f16
2718+
%pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
2719+
return %pack : tensor<128x?x100x16x1xf16>
2720+
}

0 commit comments

Comments
 (0)