[mlir][tensor] Fix FoldTensorCastProducerOp for multiple result operations #93374
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: Prashant Kumar (pashu123)

Changes: For ops that produce results beyond those tied to dpsInits, this pattern fails. The PR assumes that the result list has the dpsInit-tied results first, followed by the non-DPS results.

Full diff: https://github.com/llvm/llvm-project/pull/93374.diff (1 file affected):
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8545c7b9af8f7..986008b9d379d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4531,17 +4531,17 @@ struct FoldTensorCastProducerOp
if (!hasTensorCastOperand)
return failure();
- SmallVector<Type, 4> newResultTypes;
- newResultTypes.reserve(op->getNumResults());
+ SmallVector<Type, 4> newResultTypes(op->getResultTypes());
SmallVector<Value, 4> newOperands;
newOperands.reserve(op->getNumOperands());
+ int64_t dpsInitIdx = 0;
for (OpOperand &opOperand : op->getOpOperands()) {
auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
if (op.isDpsInit(&opOperand) &&
!llvm::isa<MemRefType>(newOperands.back().getType()))
- newResultTypes.push_back(newOperands.back().getType());
+ newResultTypes[dpsInitIdx++] = newOperands.back().getType();
}
// Clone op.
Thanks for the fix!
Can you provide an upstream test please?
(I also tweaked your title to make it slightly more self-descriptive)
I didn't find any tensor dialect op producing multiple results, so I didn't add a test. What should I do in this case?
The pattern applies to any op with DestinationStyleOpInterface, so you can create one in the test dialect.
+1 to have a test. @pashu123 as we discussed offline, please add an op to https://github.com/llvm/llvm-project/tree/main/mlir/test/lib/Dialect/Test |
    SmallVector<Value, 4> newOperands;
    newOperands.reserve(op->getNumOperands());
    int64_t dpsInitIdx = 0;
    for (OpOperand &opOperand : op->getOpOperands()) {
I think it might be easier to split the dpsInputOperands and dpsInitOperands into separate loops.
I tried splitting the loops as mentioned, it throws an error here:
******************** TEST 'MLIR :: Dialect/Tensor/tiling.mlir' FAILED ********************
Exit Code: 1
Command Output (stdout):
--
# RUN: at line 1
/home/prashant/llvm-project/build/bin/mlir-opt /home/prashant/llvm-project/mlir/test/Dialect/Tensor/tiling.mlir -transform-interpreter -canonicalize -cse -split-input-file | /home/prashant/llvm-project/build/bin/FileCheck /home/prashant/llvm-project/mlir/test/Dialect/Tensor/tiling.mlir
# executed command: /home/prashant/llvm-project/build/bin/mlir-opt /home/prashant/llvm-project/mlir/test/Dialect/Tensor/tiling.mlir -transform-interpreter -canonicalize -cse -split-input-file
# .---command stderr------------
# | mlir-opt: /home/prashant/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From&) [with To = mlir::detail::TypedValue<mlir::RankedTensorType>; From = mlir::Value]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
# | PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
# | Stack dump:
# | 0. Program arguments: /home/prashant/llvm-project/build/bin/mlir-opt /home/prashant/llvm-project/mlir/test/Dialect/Tensor/tiling.mlir -transform-interpreter -canonicalize -cse -split-input-file
# | #0 0x0000582137a377d0 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/home/prashant/llvm-project/build/bin/mlir-opt+0x2f187d0)
# | #1 0x0000582137a34bef llvm::sys::RunSignalHandlers() (/home/prashant/llvm-project/build/bin/mlir-opt+0x2f15bef)
# | #2 0x0000582137a34d45 SignalHandler(int) Signals.cpp:0:0
# | #3 0x0000728d2cc42520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
# | #4 0x0000728d2cc969fc __pthread_kill_implementation ./nptl/pthread_kill.c:44:76
# | #5 0x0000728d2cc969fc __pthread_kill_internal ./nptl/pthread_kill.c:78:10
# | #6 0x0000728d2cc969fc pthread_kill ./nptl/pthread_kill.c:89:10
# | #7 0x0000728d2cc42476 gsignal ./signal/../sysdeps/posix/raise.c:27:6
# | #8 0x0000728d2cc287f3 abort ./stdlib/abort.c:81:7
# | #9 0x0000728d2cc2871b _nl_load_domain ./intl/loadmsgcat.c:1177:9
# | #10 0x0000728d2cc39e96 (/lib/x86_64-linux-gnu/libc.so.6+0x39e96)
# | #11 0x00005821387e1f02 (/home/prashant/llvm-project/build/bin/mlir-opt+0x3cc2f02)
# | #12 0x0000582139a6aced mlir::tensor::PackOp::fold(mlir::tensor::PackOpGenericAdaptor<llvm::ArrayRef<mlir::Attribute>>) (/home/prashant/llvm-project/build/bin/mlir-opt+0x4f4bced)
I can revisit it if you want or look into the bug.
@MaheshRavishankar I think the issue is that there are some operands that are not present in the DPS interface.
E.g., the pack op has padding_value, but it belongs to neither Inputs nor Inits. Thanks @hanhanW for the quick debug.
uggh!!
LGTM, thanks!
For patterns where there are multiple results apart from dpsInits, this fails. For example:

```
%13:2 = iree_codegen.ukernel.generic "iree_uk_unpack"
    ins(%extracted_slice : tensor<?x1x16x16xf32>)
    outs(%11 : tensor<?x16xf32>) ..
```

The above op has results apart from dpsInit and hence fails. The PR assumes that the result has dpsInits followed by non-dpsInits.
    SmallVector<Value, 4> newOperands;
    newOperands.reserve(op->getNumOperands());
    // Assumes that the result has dpsInits followed by nonDpsInits.
I had a chat with Prashant offline, and we found that this is actually not documented in the DPS interface. However, all the implementations make this assumption, so we ended up adding a comment here.
I am confused now. You don't need that assumption for what is done here, right?
I had the same confusion. Let me explain a bit more. The number of results is not necessarily the same as the number of init operands; that is not a requirement of DestinationStyleOpInterface. In this example, the op has one init tensor but two result types (i.e., tensor + i32 scalar).
We were confused about the mapping between result types and init tensors. The code would be wrong if the return types were ordered as i32, tensor<xxx>. After reading the doc again, I now think that the assumption is correct: the leading result types should match the init tensor types.
Each tensor init operand is tied to a corresponding tensor OpResult in a
1-to-1 fashion. The i-th init tensor is tied to the i-th OpResult. The op
may not have any additional OpResults. Init operands and their tied
OpResults have the same type.
(It is not verified in the implementation, so I thought that it's not documented.)
llvm-project/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
Lines 29 to 62 in 7476c20
LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
  DestinationStyleOpInterface dstStyleOp =
      cast<DestinationStyleOpInterface>(op);

  SmallVector<OpOperand *> outputTensorOperands;
  for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) {
    Type type = operand.get().getType();
    if (isa<TensorType>(type)) {
      outputTensorOperands.push_back(&operand);
    } else if (!isa<BaseMemRefType>(type)) {
      return op->emitOpError("expected that operand #")
             << operand.getOperandNumber() << " is a tensor or a memref";
    }
  }

  // Verify the number of tensor results matches the number of output tensors.
  if (getNumTensorResults(op) != outputTensorOperands.size())
    return op->emitOpError("expected the number of tensor results (")
           << getNumTensorResults(op)
           << ") to be equal to the number of output tensors ("
           << outputTensorOperands.size() << ")";

  for (OpOperand *opOperand : outputTensorOperands) {
    OpResult result = dstStyleOp.getTiedOpResult(opOperand);
    if (result.getType() != opOperand->get().getType())
      return op->emitOpError("expected type of operand #")
             << opOperand->getOperandNumber() << " ("
             << opOperand->get().getType() << ")"
             << " to match type of corresponding result (" << result.getType()
             << ")";
  }
  return success();
}