Skip to content

Commit 3d22305

Browse files
committed
Add to folders
1 parent 0d83b20 commit 3d22305

File tree

5 files changed

+34
-13
lines changed

5 files changed

+34
-13
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1454,7 +1454,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14541454
def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
14551455
let summary = "Convert Vector dialect to SPIR-V dialect";
14561456
let constructor = "mlir::createConvertVectorToSPIRVPass()";
1457-
let dependentDialects = ["spirv::SPIRVDialect"];
1457+
let dependentDialects = [
1458+
"spirv::SPIRVDialect",
1459+
"ub::UBDialect"
1460+
];
14581461
}
14591462

14601463
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "mlir/Dialect/Arith/Transforms/Passes.h"
1919
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2020
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
21-
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
2221
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
2322
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2423
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -27,7 +26,6 @@
2726
#include "mlir/Pass/Pass.h"
2827
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
2928
#include "mlir/Transforms/DialectConversion.h"
30-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3129
#include <memory>
3230

3331
#define DEBUG_TYPE "convert-to-spirv"

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
16-
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1716
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1817
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1918
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
1616
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1717
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18+
#include "mlir/Dialect/UB/IR/UBOps.h"
1819
#include "mlir/Pass/Pass.h"
1920
#include "mlir/Transforms/DialectConversion.h"
2021

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1986,12 +1986,26 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
19861986
return fromElementsOp.getElements()[flatIndex];
19871987
}
19881988

1989-
OpFoldResult ExtractOp::fold(FoldAdaptor) {
1989+
/// Fold an insert or extract operation into an poison value when a poison index
1990+
/// is found at any dimension of the static position.
1991+
static ub::PoisonAttr
1992+
foldPoisonIndexInsertExtractOp(MLIRContext *context,
1993+
ArrayRef<int64_t> staticPos, int64_t poisonVal) {
1994+
if (!llvm::is_contained(staticPos, poisonVal))
1995+
return ub::PoisonAttr();
1996+
1997+
return ub::PoisonAttr::get(context);
1998+
}
1999+
2000+
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
19902001
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
19912002
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
19922003
// mismatch).
19932004
if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
19942005
return getVector();
2006+
if (auto res = foldPoisonIndexInsertExtractOp(
2007+
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2008+
return res;
19952009
if (succeeded(foldExtractOpFromExtractChain(*this)))
19962010
return getResult();
19972011
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2262,13 +2276,15 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
22622276
/// Fold an insert or extract operation into an poison value when a poison index
22632277
/// is found at any dimension of the static position.
22642278
template <typename OpTy>
2265-
LogicalResult foldPoisonIndexInsertExtractOp(OpTy op,
2266-
PatternRewriter &rewriter) {
2267-
if (!llvm::is_contained(op.getStaticPosition(), OpTy::kPoisonIndex))
2268-
return failure();
2279+
LogicalResult
2280+
canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) {
2281+
if (auto poisonAttr = foldPoisonIndexInsertExtractOp(
2282+
op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) {
2283+
rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType(), poisonAttr);
2284+
return success();
2285+
}
22692286

2270-
rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getResult().getType());
2271-
return success();
2287+
return failure();
22722288
}
22732289

22742290
} // namespace
@@ -2279,7 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
22792295
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
22802296
results.add(foldExtractFromShapeCastToShapeCast);
22812297
results.add(foldExtractFromFromElements);
2282-
results.add(foldPoisonIndexInsertExtractOp<ExtractOp>);
2298+
results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>);
22832299
}
22842300

22852301
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -3044,7 +3060,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
30443060
MLIRContext *context) {
30453061
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
30463062
InsertOpConstantFolder>(context);
3047-
results.add(foldPoisonIndexInsertExtractOp<InsertOp>);
3063+
results.add(canonicalizePoisonIndexInsertExtractOp<InsertOp>);
30483064
}
30493065

30503066
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
@@ -3053,6 +3069,10 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
30533069
// (type mismatch).
30543070
if (getNumIndices() == 0 && getSourceType() == getType())
30553071
return getSource();
3072+
if (auto res = foldPoisonIndexInsertExtractOp(
3073+
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3074+
return res;
3075+
30563076
return {};
30573077
}
30583078

0 commit comments

Comments
 (0)