Skip to content

Commit f35318e

Browse files
authored
[mlir][amdgpu] Add support for multi-dim arith.truncf/extf fp8 lowering (#98074)
The existing `fp8` lowering from `arith` to `amdgpu` bails out on the multidimensional case. We can handle this by using `vector.shape_cast` to collapse the input to the 1-D case before extraction, then re-casting the result back to the desired output shape.
1 parent 9739df2 commit f35318e

File tree

2 files changed

+98
-16
lines changed

2 files changed

+98
-16
lines changed

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 40 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -67,9 +67,6 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
6767
if (auto inVecType = dyn_cast<VectorType>(inType)) {
6868
if (inVecType.isScalable())
6969
return failure();
70-
if (inVecType.getShape().size() > 1)
71-
// Multi-dimensional vectors are currently unsupported.
72-
return failure();
7370
inType = inVecType.getElementType();
7471
}
7572
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
@@ -80,28 +77,38 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
8077
Location loc = op.getLoc();
8178
Value in = op.getIn();
8279
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
83-
if (!isa<VectorType>(in.getType())) {
80+
auto inType = dyn_cast<VectorType>(in.getType());
81+
if (!inType) {
8482
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
8583
loc, rewriter.getF32Type(), in, 0);
8684
Value result = castF32To(outElemType, asFloat, loc, rewriter);
8785
return rewriter.replaceOp(op, result);
8886
}
89-
VectorType inType = cast<VectorType>(in.getType());
9087
int64_t numElements = inType.getNumElements();
9188
Value zero = rewriter.create<arith::ConstantOp>(
9289
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
93-
Value result =
94-
rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
9590
if (inType.getShape().empty()) {
9691
Value scalarIn =
9792
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
9893
// Recurse to send the 0-D vector case to the 1-D vector case
9994
Value scalarExt =
10095
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
101-
result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
102-
ArrayRef<int64_t>{});
96+
Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
97+
ArrayRef<int64_t>{});
10398
return rewriter.replaceOp(op, result);
10499
}
100+
101+
VectorType outType = cast<VectorType>(op.getOut().getType());
102+
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
103+
outType.getElementType());
104+
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
105+
106+
if (inType.getRank() > 1) {
107+
inType = VectorType::get(SmallVector<int64_t>{numElements},
108+
inType.getElementType());
109+
in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
110+
}
111+
105112
for (int64_t i = 0; i < numElements; i += 4) {
106113
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
107114
Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -113,6 +120,11 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
113120
result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
114121
}
115122
}
123+
124+
if (inType.getRank() != outType.getRank()) {
125+
result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
126+
}
127+
116128
rewriter.replaceOp(op, result);
117129
}
118130

@@ -181,9 +193,6 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
181193
if (auto outVecType = dyn_cast<VectorType>(outType)) {
182194
if (outVecType.isScalable())
183195
return failure();
184-
if (outVecType.getShape().size() > 1)
185-
// Multi-dimensional vectors are currently unsupported.
186-
return failure();
187196
outType = outVecType.getElementType();
188197
}
189198
auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
@@ -200,8 +209,9 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
200209
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
201210
if (saturateFP8)
202211
in = clampInput(rewriter, loc, outElemType, in);
212+
auto inVectorTy = dyn_cast<VectorType>(in.getType());
203213
VectorType truncResType = VectorType::get(4, outElemType);
204-
if (!isa<VectorType>(in.getType())) {
214+
if (!inVectorTy) {
205215
Value asFloat = castToF32(in, loc, rewriter);
206216
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
207217
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
@@ -213,18 +223,27 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
213223
int64_t numElements = outType.getNumElements();
214224
Value zero = rewriter.create<arith::ConstantOp>(
215225
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
216-
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
217226
if (outType.getShape().empty()) {
218227
Value scalarIn =
219228
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
220229
// Recurse to send the 0-D vector case to the 1-D vector case
221230
Value scalarTrunc =
222231
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
223-
result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
224-
ArrayRef<int64_t>{});
232+
Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
233+
ArrayRef<int64_t>{});
225234
return rewriter.replaceOp(op, result);
226235
}
227236

237+
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
238+
outType.getElementType());
239+
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
240+
241+
if (inVectorTy.getRank() > 1) {
242+
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
243+
inVectorTy.getElementType());
244+
in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
245+
}
246+
228247
for (int64_t i = 0; i < numElements; i += 4) {
229248
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
230249
Value thisResult = nullptr;
@@ -245,6 +264,11 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
245264
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
246265
result, i, 1);
247266
}
267+
268+
if (inVectorTy.getRank() != outType.getRank()) {
269+
result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
270+
}
271+
248272
rewriter.replaceOp(op, result);
249273
}
250274

mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir

Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,3 +115,61 @@ func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FNUZ> {
115115
%w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FNUZ>
116116
return %w : vector<9xf8E4M3FNUZ>
117117
}
118+
119+
// -----
120+
121+
// CHECK-LABEL: func.func @vector_trunc_long_2d
122+
// CHECK-SAME: ([[V:%.+]]: vector<1x9xf32>)
123+
// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FNUZ>
124+
// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
125+
// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
126+
// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}
127+
128+
// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
129+
// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
130+
// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}
131+
132+
// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
133+
// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
134+
// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
135+
// CHECK: [[RE:%.+]] = vector.shape_cast [[W]] : vector<9xf8E4M3FNUZ> to vector<1x9xf8E4M3FNUZ>
136+
// CHECK: return [[RE]]
137+
func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> {
138+
%w = arith.truncf %v : vector<1x9xf32> to vector<1x9xf8E4M3FNUZ>
139+
return %w : vector<1x9xf8E4M3FNUZ>
140+
}
141+
142+
// -----
143+
144+
// CHECK-LABEL: func.func @vector_ext_long_2d
145+
// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>)
146+
// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ>
147+
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
148+
// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
149+
// CHECK: [[W0:%.+]] = vector.insert [[F0]]
150+
// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
151+
// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
152+
// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
153+
// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
154+
// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
155+
// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
156+
157+
// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
158+
// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
159+
// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
160+
// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
161+
// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
162+
// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
163+
// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
164+
// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
165+
// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
166+
167+
// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
168+
// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
169+
// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
170+
// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
171+
// CHECK: return [[CAST]]
172+
func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> {
173+
%w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32>
174+
return %w : vector<1x9xf32>
175+
}

0 commit comments

Comments (0)