@@ -67,9 +67,6 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
67
67
if (auto inVecType = dyn_cast<VectorType>(inType)) {
68
68
if (inVecType.isScalable ())
69
69
return failure ();
70
- if (inVecType.getShape ().size () > 1 )
71
- // Multi-dimensional vectors are currently unsupported.
72
- return failure ();
73
70
inType = inVecType.getElementType ();
74
71
}
75
72
return success (inType.isFloat8E5M2FNUZ () || inType.isFloat8E4M3FNUZ ());
@@ -80,28 +77,38 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
80
77
Location loc = op.getLoc ();
81
78
Value in = op.getIn ();
82
79
Type outElemType = getElementTypeOrSelf (op.getOut ().getType ());
83
- if (!isa<VectorType>(in.getType ())) {
80
+ auto inType = dyn_cast<VectorType>(in.getType ());
81
+ if (!inType) {
84
82
Value asFloat = rewriter.create <amdgpu::ExtPackedFp8Op>(
85
83
loc, rewriter.getF32Type (), in, 0 );
86
84
Value result = castF32To (outElemType, asFloat, loc, rewriter);
87
85
return rewriter.replaceOp (op, result);
88
86
}
89
- VectorType inType = cast<VectorType>(in.getType ());
90
87
int64_t numElements = inType.getNumElements ();
91
88
Value zero = rewriter.create <arith::ConstantOp>(
92
89
loc, outElemType, rewriter.getFloatAttr (outElemType, 0.0 ));
93
- Value result =
94
- rewriter.createOrFold <vector::SplatOp>(loc, op.getOut ().getType (), zero);
95
90
if (inType.getShape ().empty ()) {
96
91
Value scalarIn =
97
92
rewriter.create <vector::ExtractOp>(loc, in, ArrayRef<int64_t >{});
98
93
// Recurse to send the 0-D vector case to the 1-D vector case
99
94
Value scalarExt =
100
95
rewriter.create <arith::ExtFOp>(loc, outElemType, scalarIn);
101
- result = rewriter.create <vector::InsertOp>(loc, scalarExt, zero,
102
- ArrayRef<int64_t >{});
96
+ Value result = rewriter.create <vector::InsertOp>(loc, scalarExt, zero,
97
+ ArrayRef<int64_t >{});
103
98
return rewriter.replaceOp (op, result);
104
99
}
100
+
101
+ VectorType outType = cast<VectorType>(op.getOut ().getType ());
102
+ VectorType flatTy = VectorType::get (SmallVector<int64_t >{numElements},
103
+ outType.getElementType ());
104
+ Value result = rewriter.createOrFold <vector::SplatOp>(loc, flatTy, zero);
105
+
106
+ if (inType.getRank () > 1 ) {
107
+ inType = VectorType::get (SmallVector<int64_t >{numElements},
108
+ inType.getElementType ());
109
+ in = rewriter.create <vector::ShapeCastOp>(loc, inType, in);
110
+ }
111
+
105
112
for (int64_t i = 0 ; i < numElements; i += 4 ) {
106
113
int64_t elemsThisOp = std::min (numElements, i + 4 ) - i;
107
114
Value inSlice = rewriter.create <vector::ExtractStridedSliceOp>(
@@ -113,6 +120,11 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
113
120
result = rewriter.create <vector::InsertOp>(loc, asType, result, i + j);
114
121
}
115
122
}
123
+
124
+ if (inType.getRank () != outType.getRank ()) {
125
+ result = rewriter.create <vector::ShapeCastOp>(loc, outType, result);
126
+ }
127
+
116
128
rewriter.replaceOp (op, result);
117
129
}
118
130
@@ -181,9 +193,6 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
181
193
if (auto outVecType = dyn_cast<VectorType>(outType)) {
182
194
if (outVecType.isScalable ())
183
195
return failure ();
184
- if (outVecType.getShape ().size () > 1 )
185
- // Multi-dimensional vectors are currently unsupported.
186
- return failure ();
187
196
outType = outVecType.getElementType ();
188
197
}
189
198
auto inType = dyn_cast<FloatType>(getElementTypeOrSelf (op.getIn ().getType ()));
@@ -200,8 +209,9 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
200
209
Type outElemType = getElementTypeOrSelf (op.getOut ().getType ());
201
210
if (saturateFP8)
202
211
in = clampInput (rewriter, loc, outElemType, in);
212
+ auto inVectorTy = dyn_cast<VectorType>(in.getType ());
203
213
VectorType truncResType = VectorType::get (4 , outElemType);
204
- if (!isa<VectorType>(in. getType ()) ) {
214
+ if (!inVectorTy ) {
205
215
Value asFloat = castToF32 (in, loc, rewriter);
206
216
Value asF8s = rewriter.create <amdgpu::PackedTrunc2xFp8Op>(
207
217
loc, truncResType, asFloat, /* sourceB=*/ nullptr , 0 ,
@@ -213,18 +223,27 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
213
223
int64_t numElements = outType.getNumElements ();
214
224
Value zero = rewriter.create <arith::ConstantOp>(
215
225
loc, outElemType, rewriter.getFloatAttr (outElemType, 0.0 ));
216
- Value result = rewriter.createOrFold <vector::SplatOp>(loc, outType, zero);
217
226
if (outType.getShape ().empty ()) {
218
227
Value scalarIn =
219
228
rewriter.create <vector::ExtractOp>(loc, in, ArrayRef<int64_t >{});
220
229
// Recurse to send the 0-D vector case to the 1-D vector case
221
230
Value scalarTrunc =
222
231
rewriter.create <arith::TruncFOp>(loc, outElemType, scalarIn);
223
- result = rewriter.create <vector::InsertOp>(loc, scalarTrunc, zero,
224
- ArrayRef<int64_t >{});
232
+ Value result = rewriter.create <vector::InsertOp>(loc, scalarTrunc, zero,
233
+ ArrayRef<int64_t >{});
225
234
return rewriter.replaceOp (op, result);
226
235
}
227
236
237
+ VectorType flatTy = VectorType::get (SmallVector<int64_t >{numElements},
238
+ outType.getElementType ());
239
+ Value result = rewriter.createOrFold <vector::SplatOp>(loc, flatTy, zero);
240
+
241
+ if (inVectorTy.getRank () > 1 ) {
242
+ inVectorTy = VectorType::get (SmallVector<int64_t >{numElements},
243
+ inVectorTy.getElementType ());
244
+ in = rewriter.create <vector::ShapeCastOp>(loc, inVectorTy, in);
245
+ }
246
+
228
247
for (int64_t i = 0 ; i < numElements; i += 4 ) {
229
248
int64_t elemsThisOp = std::min (numElements, i + 4 ) - i;
230
249
Value thisResult = nullptr ;
@@ -245,6 +264,11 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
245
264
result = rewriter.create <vector::InsertStridedSliceOp>(loc, thisResult,
246
265
result, i, 1 );
247
266
}
267
+
268
+ if (inVectorTy.getRank () != outType.getRank ()) {
269
+ result = rewriter.create <vector::ShapeCastOp>(loc, outType, result);
270
+ }
271
+
248
272
rewriter.replaceOp (op, result);
249
273
}
250
274
0 commit comments