Skip to content

Commit 6b95dd4

Browse files
author
Xinyu Yang
authored
[Torch] Fix PrimNumToTensorScalarOp::fold (#3339)
In the constant folding process, a new constant op will be created according to the original op's result type. See the code in TorchDialect.cpp. ```cpp Operation *TorchDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto integerType = dyn_cast<Torch::IntType>(type)) return builder.create<Torch::ConstantIntOp>(loc, cast<IntegerAttr>(value)); if (auto floatType = dyn_cast<Torch::FloatType>(type)) return builder.create<Torch::ConstantFloatOp>(loc, cast<FloatAttr>(value)); if (auto numberType = dyn_cast<Torch::NumberType>(type)) { if (auto floatValue = dyn_cast<mlir::FloatAttr>(value)) { return builder.create<Torch::ConstantNumberOp>(loc, floatValue); } else if (auto intValue = dyn_cast<mlir::IntegerAttr>(value)) { return builder.create<Torch::ConstantNumberOp>(loc, intValue); } } if (isa<Torch::BoolType>(type)) { return builder.create<Torch::ConstantBoolOp>(loc, cast<IntegerAttr>(value)); } if (isa<Torch::NoneType>(type)) return builder.create<ConstantNoneOp>(loc); if (auto stringAttr = dyn_cast<StringAttr>(value)) return builder.create<ConstantStrOp>(loc, stringAttr); if (auto elementsAttr = dyn_cast<ElementsAttr>(value)) { // Only !torch.vtensor can be constant folded. !torch.tensor has // non-trivial aliasing semantics which prevent deduplicating it. assert(isa<ValueTensorType>(type) && "should be a vtensor type!"); return builder.create<ValueTensorLiteralOp>(loc, elementsAttr); } return nullptr; } ``` So when the op has a tensor result type, it must be a "ValueTensorType" because of the **assert** statement. However, many fold methods in TorchOps.cpp only check for "BaseTensorType".
1 parent 44fa6c3 commit 6b95dd4

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4471,10 +4471,10 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
44714471

44724472
OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) {
44734473
Attribute a = adaptor.getA();
4474-
auto resultTy = cast<BaseTensorType>(getType());
4474+
auto resultTy = dyn_cast<ValueTensorType>(getType());
44754475
if (!a)
44764476
return {};
4477-
if (!resultTy.hasDtype() || !resultTy.hasSizes())
4477+
if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes())
44784478
return {};
44794479

44804480
auto dty = resultTy.getDtype();

0 commit comments

Comments
 (0)