Skip to content

Commit 705f858

Browse files
authored
[Tosa] Disable tosa folder for non-int/float/index types (#71757)
In order to fold, we need to create DenseElementsAttr, which does not support quantized element types. This patch adds tests for folding quntized element types and disable tosa folders where appropriate. refactored canonicalize.mlir test to use --split-input-file also fixed verifier for trait MulOperandsAndResultElementType for quantized element types Signed-off-by: Tai Ly <[email protected]>
1 parent 53c06c5 commit 705f858

File tree

3 files changed

+258
-12
lines changed

3 files changed

+258
-12
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ class MulOperandsAndResultElementType
7878
return success();
7979
}
8080

81-
return failure();
81+
// In cases of all other types, op requires the same element
82+
// type for all operands and result.
83+
return impl::verifySameOperandsAndResultElementType(op);
8284
}
8385
};
8486

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,11 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
491491
if (!lhsTy || !rhsTy || !resultTy)
492492
return {};
493493

494+
// Cannot create an ElementsAttr from non-int/float/index types
495+
if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
496+
!rhsTy.getElementType().isIntOrIndexOrFloat())
497+
return {};
498+
494499
auto resultETy = resultTy.getElementType();
495500
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
496501
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
@@ -529,6 +534,7 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
529534
if (lhsTy != rhsTy)
530535
return {};
531536

537+
// IntDivOp inputs must be integer type, no need to check for quantized type
532538
auto resultETy = resultTy.getElementType();
533539
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
534540
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
@@ -626,6 +632,11 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
626632
if (!lhsTy || !rhsTy || !resultTy)
627633
return {};
628634

635+
// Cannot create an ElementsAttr from non-int/float/index types
636+
if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
637+
!rhsTy.getElementType().isIntOrIndexOrFloat())
638+
return {};
639+
629640
auto resultETy = resultTy.getElementType();
630641
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
631642
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
@@ -821,6 +832,10 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
821832
return getResult();
822833
}
823834

835+
// Cannot create an ElementsAttr from non-int/float/index types
836+
if (!inputTy.getElementType().isIntOrIndexOrFloat())
837+
return {};
838+
824839
// reshape(const(x)) -> const(reshape-attr(x))
825840
if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
826841
// Constants must have static shape.
@@ -956,13 +971,12 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
956971
}
957972

958973
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
959-
auto inputTy = llvm::cast<ShapedType>(getInput1().getType());
960974
auto resultTy = llvm::cast<ShapedType>(getType());
961975

962976
// Transposing splat values just means reshaping.
963977
if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
964978
if (input.isSplat() && resultTy.hasStaticShape() &&
965-
inputTy.getElementType() == resultTy.getElementType())
979+
input.getType().getElementType() == resultTy.getElementType())
966980
return input.reshape(resultTy);
967981
}
968982

0 commit comments

Comments
 (0)