Skip to content

Commit 3dae223

Browse files
adam-yangagozillon
authored andcommitted
[𝘀𝗽𝗿] initial version
Created using spr 1.3.4
2 parents f74879c + d93c9f9 commit 3dae223

File tree

8 files changed

+344
-28
lines changed

8 files changed

+344
-28
lines changed

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,5 @@ def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
8686
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
8787
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
8888
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
89+
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
8990
}

llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ static bool isIntrinsicExpansion(Function &F) {
6464
case Intrinsic::dx_udot:
6565
case Intrinsic::dx_sign:
6666
case Intrinsic::dx_step:
67+
case Intrinsic::dx_radians:
6768
return true;
6869
}
6970
return false;
@@ -442,6 +443,14 @@ static Value *expandStepIntrinsic(CallInst *Orig) {
442443
return Builder.CreateSelect(Cond, Zero, One);
443444
}
444445

446+
static Value *expandRadiansIntrinsic(CallInst *Orig) {
447+
Value *X = Orig->getOperand(0);
448+
Type *Ty = X->getType();
449+
IRBuilder<> Builder(Orig);
450+
Value *PiOver180 = ConstantFP::get(Ty, llvm::numbers::pi / 180.0);
451+
return Builder.CreateFMul(X, PiOver180);
452+
}
453+
445454
static Intrinsic::ID getMaxForClamp(Type *ElemTy,
446455
Intrinsic::ID ClampIntrinsic) {
447456
if (ClampIntrinsic == Intrinsic::dx_uclamp)
@@ -561,6 +570,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
561570
break;
562571
case Intrinsic::dx_step:
563572
Result = expandStepIntrinsic(Orig);
573+
case Intrinsic::dx_radians:
574+
Result = expandRadiansIntrinsic(Orig);
575+
break;
564576
}
565577
if (Result) {
566578
Orig->replaceAllUsesWith(Result);

llvm/test/CodeGen/DirectX/radians.ll

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -dxil-intrinsic-expansion -scalarizer -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
3+
4+
declare half @llvm.dx.radians.f16(half)
5+
declare float @llvm.dx.radians.f32(float)
6+
7+
declare <4 x half> @llvm.dx.radians.v4f16(<4 x half>)
8+
declare <4 x float> @llvm.dx.radians.v4f32(<4 x float>)
9+
10+
define noundef half @radians_half(half noundef %a) {
11+
; CHECK-LABEL: define noundef half @radians_half(
12+
; CHECK-SAME: half noundef [[A:%.*]]) {
13+
; CHECK-NEXT: [[ENTRY:.*:]]
14+
; CHECK-NEXT: [[TMP0:%.*]] = fmul half [[A]], 0xH2478
15+
; CHECK-NEXT: ret half [[TMP0]]
16+
;
17+
entry:
18+
%elt.radians = call half @llvm.dx.radians.f16(half %a)
19+
ret half %elt.radians
20+
}
21+
22+
define noundef float @radians_float(float noundef %a) {
23+
; CHECK-LABEL: define noundef float @radians_float(
24+
; CHECK-SAME: float noundef [[A:%.*]]) {
25+
; CHECK-NEXT: [[ENTRY:.*:]]
26+
; CHECK-NEXT: [[TMP0:%.*]] = fmul float [[A]], 0x3F91DF46A0000000
27+
; CHECK-NEXT: ret float [[TMP0]]
28+
;
29+
entry:
30+
%elt.radians = call float @llvm.dx.radians.f32(float %a)
31+
ret float %elt.radians
32+
}
33+
34+
define noundef <4 x half> @radians_half_vector(<4 x half> noundef %a) {
35+
; CHECK-LABEL: define noundef <4 x half> @radians_half_vector(
36+
; CHECK-SAME: <4 x half> noundef [[A:%.*]]) {
37+
; CHECK-NEXT: [[ENTRY:.*:]]
38+
; CHECK: [[ee0:%.*]] = extractelement <4 x half> [[A]], i64 0
39+
; CHECK: [[ie0:%.*]] = fmul half [[ee0]], 0xH2478
40+
; CHECK: [[ee1:%.*]] = extractelement <4 x half> [[A]], i64 1
41+
; CHECK: [[ie1:%.*]] = fmul half [[ee1]], 0xH2478
42+
; CHECK: [[ee2:%.*]] = extractelement <4 x half> [[A]], i64 2
43+
; CHECK: [[ie2:%.*]] = fmul half [[ee2]], 0xH2478
44+
; CHECK: [[ee3:%.*]] = extractelement <4 x half> [[A]], i64 3
45+
; CHECK: [[ie3:%.*]] = fmul half [[ee3]], 0xH2478
46+
; CHECK: [[TMP0:%.*]] = insertelement <4 x half> poison, half [[ie0]], i64 0
47+
; CHECK: [[TMP1:%.*]] = insertelement <4 x half> %[[TMP0]], half [[ie1]], i64 1
48+
; CHECK: [[TMP2:%.*]] = insertelement <4 x half> %[[TMP1]], half [[ie2]], i64 2
49+
; CHECK: [[TMP3:%.*]] = insertelement <4 x half> %[[TMP2]], half [[ie3]], i64 3
50+
; CHECK: ret <4 x half> [[TMP3]]
51+
;
52+
entry:
53+
%elt.radians = call <4 x half> @llvm.dx.radians.v4f16(<4 x half> %a)
54+
ret <4 x half> %elt.radians
55+
}
56+
57+
define noundef <4 x float> @radians_float_vector(<4 x float> noundef %a) {
58+
; CHECK-LABEL: define noundef <4 x float> @radians_float_vector(
59+
; CHECK-SAME: <4 x float> noundef [[A:%.*]]) {
60+
; CHECK-NEXT: [[ENTRY:.*:]]
61+
; CHECK: [[ee0:%.*]] = extractelement <4 x float> [[A]], i64 0
62+
; CHECK: [[ie0:%.*]] = fmul float [[ee0]], 0x3F91DF46A0000000
63+
; CHECK: [[ee1:%.*]] = extractelement <4 x float> [[A]], i64 1
64+
; CHECK: [[ie1:%.*]] = fmul float [[ee1]], 0x3F91DF46A0000000
65+
; CHECK: [[ee2:%.*]] = extractelement <4 x float> [[A]], i64 2
66+
; CHECK: [[ie2:%.*]] = fmul float [[ee2]], 0x3F91DF46A0000000
67+
; CHECK: [[ee3:%.*]] = extractelement <4 x float> [[A]], i64 3
68+
; CHECK: [[ie3:%.*]] = fmul float [[ee3]], 0x3F91DF46A0000000
69+
; CHECK: [[TMP0:%.*]] = insertelement <4 x float> poison, float [[ie0]], i64 0
70+
; CHECK: [[TMP1:%.*]] = insertelement <4 x float> %[[TMP0]], float [[ie1]], i64 1
71+
; CHECK: [[TMP2:%.*]] = insertelement <4 x float> %[[TMP1]], float [[ie2]], i64 2
72+
; CHECK: [[TMP3:%.*]] = insertelement <4 x float> %[[TMP2]], float [[ie3]], i64 3
73+
; CHECK: ret <4 x float> [[TMP3]]
74+
;
75+
entry:
76+
%elt.radians = call <4 x float> @llvm.dx.radians.v4f32(<4 x float> %a)
77+
ret <4 x float> %elt.radians
78+
}
79+

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,18 +295,23 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
295295
let description = [{
296296
Tiles the operations pointed to by the target handle and fuses their
297297
producers greedily using the options provided as attributes.
298+
299+
If `apply_cleanup` is true then slice canonicalization is applied between
300+
fusion steps.
298301
}];
299302

300303
let arguments =
301304
(ins TransformHandleTypeInterface:$target,
302305
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
303-
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
306+
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
307+
DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup);
304308
let results = (outs TransformHandleTypeInterface:$transformed,
305309
Variadic<TransformHandleTypeInterface>:$loops);
306310

307311
let assemblyFormat = [{
308312
$target ($tile_sizes^)? (`interchange` $tile_interchange^)?
309-
attr-dict `:` functional-type(operands, results)
313+
(`apply_cleanup` `=` $apply_cleanup^)? attr-dict
314+
`:` functional-type(operands, results)
310315
}];
311316
let hasVerifier = 1;
312317
}

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Interfaces/LoopLikeInterface.h"
1616
#include "mlir/Interfaces/TilingInterface.h"
1717
#include "mlir/Interfaces/ViewLikeInterface.h"
18+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
1819

1920
#include <deque>
2021

@@ -153,6 +154,11 @@ struct SCFTileAndFuseOptions {
153154
fusionControlFn = controlFn;
154155
return *this;
155156
}
157+
158+
/// An optional set of rewrite patterns to apply to the results of tiling
159+
/// before fusion. This will track deleted and newly inserted
160+
/// `tensor.extract_slice` ops and update the worklist.
161+
std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
156162
};
157163

158164
/// Fuse the producer of the source of `candidateSliceOp` by computing the

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,15 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
562562
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
563563
scf::SCFTileAndFuseOptions tileAndFuseOptions;
564564
tileAndFuseOptions.tilingOptions = tilingOptions;
565+
566+
if (getApplyCleanup()) {
567+
MLIRContext *context = rewriter.getContext();
568+
RewritePatternSet patterns(context);
569+
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
570+
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
571+
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
572+
}
573+
565574
LogicalResult result = applyTilingToAll(
566575
rewriter, getOperation(), state.getPayloadOps(getTarget()),
567576
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 130 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "mlir/IR/PatternMatch.h"
2525
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2626
#include "mlir/Interfaces/TilingInterface.h"
27+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
28+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2729
#include "llvm/ADT/TypeSwitch.h"
2830
#include "llvm/Support/Debug.h"
2931
#include <optional>
@@ -1315,6 +1317,104 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
13151317
return generatedSlices;
13161318
}
13171319

1320+
namespace {
1321+
1322+
//===----------------------------------------------------------------------===//
1323+
// SliceTrackingListener
1324+
//===----------------------------------------------------------------------===//
1325+
1326+
/// This class is a listener for tracking the insertion and removal of
1327+
/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1328+
/// fusion algorithm to apply cleanup patterns in between fusion steps.
1329+
class SliceTrackingListener : public RewriterBase::Listener {
1330+
public:
1331+
explicit SliceTrackingListener(
1332+
std::optional<FrozenRewritePatternSet> patterns);
1333+
SliceTrackingListener() = default;
1334+
1335+
/// Adds the given list of operations to the worklist, and if present, applies
1336+
/// the list of `patterns` to the newly added operations. This only processes
1337+
/// the given operations and any newly inserted ones by the pattern set.
1338+
LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1339+
1340+
/// Add to the new operation worklist if it is an extract_slice.
1341+
void notifyOperationInserted(Operation *op,
1342+
OpBuilder::InsertPoint previous) override;
1343+
1344+
/// Shared helper for operation removal from the worklist.
1345+
void removeOp(Operation *op);
1346+
1347+
/// Remove the operation from the worklist.
1348+
void notifyOperationErased(Operation *op) override;
1349+
1350+
/// Remove the operation from the worklist.
1351+
void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
1352+
1353+
/// The worklist for this transformation keeps track of the slices to visit
1354+
/// next for fusion.
1355+
std::deque<tensor::ExtractSliceOp> worklist;
1356+
1357+
private:
1358+
/// Optional pattern set to apply when adding new operations to the worklist.
1359+
std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1360+
};
1361+
1362+
SliceTrackingListener::SliceTrackingListener(
1363+
std::optional<FrozenRewritePatternSet> p) {
1364+
patterns = std::move(p);
1365+
}
1366+
1367+
LogicalResult
1368+
SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1369+
for (Operation *op : ops) {
1370+
if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1371+
worklist.push_back(slice);
1372+
}
1373+
1374+
if (!patterns)
1375+
return success();
1376+
1377+
GreedyRewriteConfig config;
1378+
config.listener = this;
1379+
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
1380+
return applyOpPatternsAndFold(ops, patterns.value(), config);
1381+
}
1382+
1383+
void SliceTrackingListener::notifyOperationInserted(
1384+
Operation *op, OpBuilder::InsertPoint previous) {
1385+
auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1386+
if (!slice)
1387+
return;
1388+
worklist.push_back(slice);
1389+
}
1390+
1391+
// Scan the worklist for the given op and remove it if present. The expectation
1392+
// is for the worklist to be small and for removal to be relatively rare.
1393+
void SliceTrackingListener::removeOp(Operation *op) {
1394+
if (!isa<tensor::ExtractSliceOp>(op))
1395+
return;
1396+
auto iter = worklist.begin();
1397+
while (iter != worklist.end()) {
1398+
if (*iter == op)
1399+
break;
1400+
iter++;
1401+
}
1402+
if (iter == worklist.end())
1403+
return;
1404+
1405+
worklist.erase(iter);
1406+
}
1407+
1408+
void SliceTrackingListener::notifyOperationErased(Operation *op) {
1409+
removeOp(op);
1410+
}
1411+
1412+
void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1413+
ValueRange replacement) {
1414+
removeOp(op);
1415+
}
1416+
} // namespace
1417+
13181418
/// Implementation of tile consumer and fuse producer greedily.
13191419
FailureOr<scf::SCFTileAndFuseResult>
13201420
mlir::scf::tileConsumerAndFuseProducersUsingSCF(
@@ -1370,33 +1470,32 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
13701470
tensor::ExtractSliceOp candidateSlice;
13711471
SCFTileAndFuseOptions::ControlFnResult controlFnResult;
13721472
};
1373-
std::deque<WorklistItem> worklist;
1374-
auto addCandidateSlices = [&worklist, &options,
1375-
&loops](ArrayRef<Operation *> candidates) {
1376-
for (auto candidate : candidates) {
1377-
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
1378-
if (!sliceOp || sliceOp.use_empty())
1379-
continue;
13801473

1381-
auto [fusableProducer, destinationInitArg] =
1382-
getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops);
1383-
if (!fusableProducer)
1384-
continue;
1385-
std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1386-
options.fusionControlFn(sliceOp, fusableProducer,
1387-
destinationInitArg.has_value());
1388-
if (!controlFnResult)
1389-
continue;
1390-
worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()});
1391-
}
1392-
};
1474+
SliceTrackingListener sliceTracker =
1475+
SliceTrackingListener(options.cleanupPatterns);
13931476

1394-
addCandidateSlices(tilingResult->generatedSlices);
1477+
if (failed(
1478+
sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1479+
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1480+
}
13951481
OpBuilder::InsertionGuard g(rewriter);
1396-
while (!worklist.empty()) {
1397-
// Traverse the slices in BFS fashion.
1398-
WorklistItem worklistItem = worklist.front();
1399-
worklist.pop_front();
1482+
while (!sliceTracker.worklist.empty()) {
1483+
auto candidateSlice = sliceTracker.worklist.front();
1484+
sliceTracker.worklist.pop_front();
1485+
1486+
auto [fusableProducer, destinationInitArg] =
1487+
getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1488+
loops);
1489+
if (!fusableProducer)
1490+
continue;
1491+
1492+
std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1493+
options.fusionControlFn(candidateSlice, fusableProducer,
1494+
destinationInitArg.has_value());
1495+
if (!controlFnResult)
1496+
continue;
1497+
1498+
WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
14001499

14011500
// The operands of the fused producer might themselved be slices of
14021501
// values produced by operations that implement the `TilingInterface`.
@@ -1407,6 +1506,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14071506
if (!fusedResult)
14081507
continue;
14091508

1509+
SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1510+
14101511
if (worklistItem.controlFnResult.yieldProducerReplacement) {
14111512
// Reconstruct and yield all opResult of fusableProducerOp by default. The
14121513
// caller can specific which one to yield by designating optional argument
@@ -1421,20 +1522,23 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14211522
fusableProducerOp, "failed to replacement value for this "
14221523
"operation from within the tiled loop");
14231524
}
1424-
addCandidateSlices(newSlices.value());
1525+
worklistCandidates.append(newSlices.value());
14251526
for (auto [index, result] :
14261527
llvm::enumerate(fusableProducerOp->getResults())) {
14271528
origValToResultNumber[result] = loops.front()->getNumResults() -
14281529
fusableProducerOp->getNumResults() +
14291530
index;
14301531
}
14311532
}
1432-
addCandidateSlices(fusedResult->generatedSlices);
14331533
if (Operation *tiledAndFusedOp =
14341534
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
14351535
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
14361536
tiledAndFusedOps.insert(tiledAndFusedOp);
14371537
}
1538+
1539+
if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1540+
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1541+
}
14381542
}
14391543

14401544
DenseMap<Value, Value> replacements;

0 commit comments

Comments
 (0)