
Commit ce8df36

gflegar authored and copybara-github committed
Do not use ConversionPatternRewriter outside of dialect conversion.
MLIR does not allow this after llvm/llvm-project#82244.

PiperOrigin-RevId: 610740193
1 parent 92475a1 commit ce8df36
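For context, the signature change the patch applies throughout can be sketched with a hypothetical emitZero helper (illustrative only, not from the patch): ConversionPatternRewriter derives, via PatternRewriter, from RewriterBase, so widening a utility's parameter to RewriterBase & keeps every existing conversion-pattern call site compiling while also admitting a plain IRRewriter used outside dialect conversion.

// Hypothetical helper, not from the patch, showing the widening pattern.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"  // RewriterBase

using namespace mlir;

// Before: usable only inside a dialect-conversion pattern.
//   static Value emitZero(ConversionPatternRewriter &rewriter, Location loc);

// After: usable from any rewriter. A ConversionPatternRewriter still binds to
// RewriterBase &, so conversion patterns keep calling this unchanged.
static Value emitZero(RewriterBase &rewriter, Location loc) {
  return rewriter.create<arith::ConstantIntOp>(loc, /*value=*/0, /*width=*/32);
}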

2 files changed: +247, -0 lines changed

third_party/triton/cl610740193.patch

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
Upstream PR: https://github.com/openai/triton/pull/3213

diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -157,9 +157,10 @@ getSharedMemoryObjectFromStruct(Location
           /*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
 }
 
-SmallVector<Value>
-getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
-                            Location loc, ConversionPatternRewriter &rewriter) {
+SmallVector<Value> getStridesFromShapeAndOrder(ArrayRef<int64_t> shape,
+                                               ArrayRef<unsigned> order,
+                                               Location loc,
+                                               RewriterBase &rewriter) {
   auto rank = shape.size();
   SmallVector<Value> strides(rank);
   int64_t stride = 1;
@@ -172,9 +173,8 @@ getStridesFromShapeAndOrder(ArrayRef<int
 
 // Convert an \param index to a multi-dim coordinate given \param shape and
 // \param order.
-SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
-                               Location loc, Value linear,
-                               ArrayRef<unsigned> shape,
+SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
+                               Value linear, ArrayRef<unsigned> shape,
                                ArrayRef<unsigned> order) {
   unsigned rank = shape.size();
   assert(rank == order.size());
@@ -194,9 +194,8 @@ SmallVector<Value> delinearize(Conversio
   return multiDim;
 }
 
-SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
-                               Location loc, unsigned linear,
-                               ArrayRef<unsigned> shape) {
+SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
+                               unsigned linear, ArrayRef<unsigned> shape) {
   unsigned rank = shape.size();
   assert(rank > 0);
   SmallVector<Value> multiDim(rank);
@@ -209,9 +208,8 @@ SmallVector<Value> delinearize(Conversio
   return multiDim;
 }
 
-SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
-                               Location loc, Value linear,
-                               ArrayRef<unsigned> shape) {
+SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
+                               Value linear, ArrayRef<unsigned> shape) {
   unsigned rank = shape.size();
   assert(rank > 0);
   SmallVector<Value> multiDim(rank);
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h
--- a/lib/Conversion/TritonGPUToLLVM/Utility.h
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.h
@@ -232,9 +232,10 @@ void createStoreDSmem(Location loc, Patt
                       Value ctaId, ArrayRef<Value> values);
 
 /// Helper function to get strides from a given shape and its order
-SmallVector<Value>
-getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
-                            Location loc, ConversionPatternRewriter &rewriter);
+SmallVector<Value> getStridesFromShapeAndOrder(ArrayRef<int64_t> shape,
+                                               ArrayRef<unsigned> order,
+                                               Location loc,
+                                               RewriterBase &rewriter);
 struct SharedMemoryObject {
   Value base; // i32 ptr. The start address of the shared memory object after
               // the initial allocation or the last slicing operation.
@@ -264,7 +265,7 @@ struct SharedMemoryObject {
 
   SharedMemoryObject(Value base, Type baseElemType, ArrayRef<int64_t> shape,
                      ArrayRef<unsigned> order, Location loc,
-                     ConversionPatternRewriter &rewriter)
+                     RewriterBase &rewriter)
       : base(base), baseElemType(baseElemType) {
     strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
     offsets.append(order.size(), i32_val(0));
@@ -311,18 +312,15 @@ getSharedMemoryObjectFromStruct(Location
 
 // Convert an \param index to a multi-dim coordinate given \param shape and
 // \param order.
-SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
-                               Location loc, Value linear,
-                               ArrayRef<unsigned> shape,
+SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
+                               Value linear, ArrayRef<unsigned> shape,
                                ArrayRef<unsigned> order);
 
-SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
-                               Location loc, unsigned linear,
-                               ArrayRef<unsigned> shape);
+SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
+                               unsigned linear, ArrayRef<unsigned> shape);
 
-SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
-                               Location loc, Value linear,
-                               ArrayRef<unsigned> shape);
+SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
+                               Value linear, ArrayRef<unsigned> shape);
 
 Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                 ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
@@ -380,22 +378,20 @@ static Value getSharedMemoryBase(Locatio
 
 /* ------------------------------------ */
 // Returns CTA level thread idx
-static Value getThreadIdInCTA(ConversionPatternRewriter &rewriter,
-                              Location loc) {
+static Value getThreadIdInCTA(RewriterBase &rewriter, Location loc) {
   Value tid =
       rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x);
   return rewriter.create<arith::IndexCastOp>(loc, i32_ty, tid);
 }
 
 // Returns CTA level thread idx.
-static Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) {
+static Value getThreadId(RewriterBase &rewriter, Location loc) {
   Value tid = getThreadIdInCTA(rewriter, loc);
   auto mod = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
   return tid;
 }
 
-static Value getClusterCTAId(ConversionPatternRewriter &rewriter,
-                             Location loc) {
+static Value getClusterCTAId(RewriterBase &rewriter, Location loc) {
   return rewriter.create<triton::nvgpu::ClusterCTAIdOp>(loc,
                                                         rewriter.getI32Type());
 }
@@ -413,8 +409,8 @@ using ::mlir::triton::gpu::DotOperandEnc
 using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
 
-static Value dot(ConversionPatternRewriter &rewriter, Location loc,
-                 ArrayRef<Value> offsets, ArrayRef<Value> strides) {
+static Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
+                 ArrayRef<Value> strides) {
   assert(offsets.size() == strides.size());
   Value ret = i32_val(0);
   for (auto [offset, stride] : llvm::zip(offsets, strides)) {
@@ -428,9 +424,10 @@ static Value dot(ConversionPatternRewrit
 // -----------------------------------------------------------------------
 
 // Get an index-base for each dimension for a \param blockedLayout.
-static SmallVector<Value> emitBaseIndexWithinCTAForBlockedLayout(
-    Location loc, ConversionPatternRewriter &rewriter,
-    const BlockedEncodingAttr &blockedLayout, RankedTensorType type) {
+static SmallVector<Value>
+emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter,
+                                       const BlockedEncodingAttr &blockedLayout,
+                                       RankedTensorType type) {
   auto shape = type.getShape();
   Value threadId = getThreadId(rewriter, loc);
   Value warpSize = i32_val(32);
@@ -511,9 +508,10 @@ emitOffsetForBlockedLayout(const Blocked
 // Mma layout indices
 // -----------------------------------------------------------------------
 
-static SmallVector<Value> emitBaseIndexWithinCTAForMmaLayoutV1(
-    Location loc, ConversionPatternRewriter &rewriter,
-    const NvidiaMmaEncodingAttr &mmaLayout, RankedTensorType type) {
+static SmallVector<Value>
+emitBaseIndexWithinCTAForMmaLayoutV1(Location loc, RewriterBase &rewriter,
+                                     const NvidiaMmaEncodingAttr &mmaLayout,
+                                     RankedTensorType type) {
   auto shape = type.getShape();
   auto wpt = mmaLayout.getWarpsPerCTA();
   static constexpr std::array<int, 3> fpw{{2, 2, 1}};
@@ -654,9 +652,10 @@ emitOffsetForMmaLayoutV2(const NvidiaMma
   return ret;
 }
 
-static SmallVector<Value> emitBaseIndexWithinCTAForMmaLayoutV2V3(
-    Location loc, ConversionPatternRewriter &rewriter,
-    const NvidiaMmaEncodingAttr &mmaLayout, RankedTensorType type) {
+static SmallVector<Value>
+emitBaseIndexWithinCTAForMmaLayoutV2V3(Location loc, RewriterBase &rewriter,
+                                       const NvidiaMmaEncodingAttr &mmaLayout,
+                                       RankedTensorType type) {
   auto shape = type.getShape();
   auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
   auto rank = shape.size();
@@ -776,9 +775,10 @@ emitOffsetForSliceLayout(const SliceEnco
 // Get offsets / indices for any layout
 // -----------------------------------------------------------------------
 
-static SmallVector<Value>
-emitCTAOffsetForLayout(Location loc, ConversionPatternRewriter &rewriter,
-                       Attribute layout, ArrayRef<int64_t> shape) {
+static SmallVector<Value> emitCTAOffsetForLayout(Location loc,
+                                                 RewriterBase &rewriter,
+                                                 Attribute layout,
+                                                 ArrayRef<int64_t> shape) {
   unsigned rank = shape.size();
   SmallVector<unsigned> CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout);
   SmallVector<unsigned> CTASplitNum = triton::gpu::getCTASplitNum(layout);
@@ -806,13 +806,12 @@ emitCTAOffsetForLayout(Location loc, Con
 }
 
 static SmallVector<Value>
-emitBaseIndexForLayout(Location loc, ConversionPatternRewriter &rewriter,
-                       Attribute layout, RankedTensorType type,
-                       bool withCTAOffset) {
+emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, Attribute layout,
+                       RankedTensorType type, bool withCTAOffset) {
   auto shape = type.getShape();
 
   SmallVector<Value> baseIndex;
-  ConversionPatternRewriter::InsertionGuard guard(rewriter);
+  RewriterBase::InsertionGuard guard(rewriter);
   SmallVector<Value> result;
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter,
@@ -866,7 +865,7 @@ emitOffsetForLayout(Attribute layout, Ra
 // Emit indices calculation within each ConversionPattern, and returns a
 // [elemsPerThread X rank] index matrix.
 static SmallVector<SmallVector<Value>>
-emitIndices(Location loc, ConversionPatternRewriter &rewriter, Attribute layout,
+emitIndices(Location loc, RewriterBase &rewriter, Attribute layout,
             RankedTensorType type, bool withCTAOffset) {
   // step 1, delinearize threadId to get the base index
   auto multiDimBase =
@@ -892,7 +891,7 @@ emitIndices(Location loc, ConversionPatt
 DenseMap<unsigned, Value> static getSwizzledSharedPtrs(
     Location loc, unsigned inVec, RankedTensorType srcTy,
     triton::gpu::SharedEncodingAttr resSharedLayout, Type resElemTy,
-    SharedMemoryObject smemObj, ConversionPatternRewriter &rewriter,
+    SharedMemoryObject smemObj, RewriterBase &rewriter,
     SmallVectorImpl<Value> &offsetVals, SmallVectorImpl<Value> &srcStrides) {
   // This utility computes the pointers for accessing the provided swizzled
   // shared memory layout `resSharedLayout`. More specifically, it computes,
diff --git a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp
--- a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp
+++ b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp
@@ -78,7 +78,7 @@ private:
   LowerToLLVMOptions option;
   TritonGPUToLLVMTypeConverter typeConverter;
   Block block;
-  ConversionPatternRewriter rewriter;
+  IRRewriter rewriter;
   Location loc;
 };
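The unittest hunk above is the part that actually trips the new MLIR restriction: a ConversionPatternRewriter may no longer be constructed free-standing outside the dialect-conversion driver, and IRRewriter is the drop-in replacement since both derive from RewriterBase. A minimal sketch of the now-required usage, with a hypothetical rewriteOutsideConversion function that is not part of the patch:

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"  // IRRewriter, RewriterBase

using namespace mlir;

// Hypothetical illustration, not from the patch.
void rewriteOutsideConversion(MLIRContext &ctx, Operation *op) {
  // Allowed anywhere: IRRewriter is the RewriterBase implementation intended
  // for use outside dialect conversion. Constructing a
  // ConversionPatternRewriter here is what llvm/llvm-project#82244 disallows.
  IRRewriter rewriter(&ctx);

  // The base-class guard works for every rewriter kind, which is why the
  // patch replaces ConversionPatternRewriter::InsertionGuard with
  // RewriterBase::InsertionGuard.
  RewriterBase::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);
  // ... build ops with rewriter.create<...>(loc, ...) ...
}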

third_party/triton/workspace.bzl

Lines changed: 1 addition & 0 deletions
@@ -16,5 +16,6 @@ def repo():
         patch_file = [
             "//third_party/triton:cl607293980.patch",  # long standing :(
             "//third_party/triton:cl610393680.patch",
+            "//third_party/triton:cl610740193.patch",
         ],
     )
