9
9
#include " mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
10
10
11
11
#include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12
- #include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
13
12
#include " mlir/Dialect/Arith/IR/Arith.h"
14
13
#include " mlir/Dialect/Arith/Utils/Utils.h"
15
- #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
16
- #include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
17
14
#include " mlir/Dialect/Vector/IR/VectorOps.h"
18
15
#include " mlir/IR/BuiltinTypes.h"
19
16
#include " mlir/IR/PatternMatch.h"
@@ -27,7 +24,6 @@ namespace mlir {
27
24
} // namespace mlir
28
25
29
26
using namespace mlir ;
30
- using namespace mlir ::amdgpu;
31
27
32
28
namespace {
33
29
struct ArithToAMDGPUConversionPass final
@@ -47,25 +43,12 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
47
43
48
44
// Rewrite pattern anchored on arith::TruncFOp that carries a `saturateFP8`
// flag (defaulting to false).
//
// In this hunk the `-` lines drop the constructor overload that also took a
// `Chipset` and the `Chipset chipset` member that stored it; the `+` lines
// leave a constructor that forwards `ctx` to the OpRewritePattern base and
// stores only `saturateFP8`.
//
// match() and rewrite() are only declared here; their bodies live outside
// this block.
// NOTE(review): the flag name suggests it selects saturating behaviour in the
// truncation -- confirm against the rewrite() body, which is not shown here.
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
49
45
bool saturateFP8 = false ;
50
- TruncFToFloat8RewritePattern (MLIRContext *ctx, bool saturateFP8,
51
- Chipset chipset)
52
- : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
53
- chipset (chipset) {}
54
- Chipset chipset;
46
// Post-change constructor: no chipset is threaded through any more.
+ TruncFToFloat8RewritePattern (MLIRContext *ctx, bool saturateFP8)
47
+ : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
55
48
56
49
LogicalResult match (arith::TruncFOp op) const override ;
57
50
void rewrite (arith::TruncFOp op, PatternRewriter &rewriter) const override ;
58
51
};
59
-
60
// Declaration removed by this change (every line carries the `-` marker).
// It was a second pattern on arith::TruncFOp that inherited the base-class
// constructors through the using-declaration and declared match()/rewrite()
// overrides; the matching definitions are removed further down in the diff.
- struct TruncfToFloat16RewritePattern final
61
- : public OpRewritePattern<arith::TruncFOp> {
62
-
63
- using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
64
-
65
- LogicalResult match (arith::TruncFOp op) const override ;
66
- void rewrite (arith::TruncFOp op, PatternRewriter &rewriter) const override ;
67
- };
68
-
69
52
} // end namespace
70
53
71
54
static Value castF32To (Type elementType, Value f32 , Location loc,
@@ -289,105 +272,17 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
289
272
rewriter.replaceOp (op, result);
290
273
}
291
274
292
// Removed by this change (all lines are `-`).
// The predicate accepted a truncf only when the result element type is f16
// and the type obtained from getElementTypeOrSelf(op.getIn()) is f32.
// A vector result was first reduced to its element type, and a scalable
// vector result was rejected outright with failure().
- LogicalResult TruncfToFloat16RewritePattern::match (arith::TruncFOp op) const {
293
- Type outType = op.getOut ().getType ();
294
- Type inputType = getElementTypeOrSelf (op.getIn ());
295
- if (auto outVecType = dyn_cast<VectorType>(outType)) {
296
- if (outVecType.isScalable ())
297
- return failure ();
298
- outType = outVecType.getElementType ();
299
- }
300
- return success (outType.isF16 () && inputType.isF32 ());
301
- }
302
-
303
// Removed by this change (all lines are `-`).
//
// Summary of what the deleted rewrite built:
//  * Non-vector input: create an f32 LLVM::PoisonOp as the second operand,
//    emit ROCDL::CvtPkRtz with a 2-element vector result of the output
//    element type, extract element 0 and replace the op with it.
//  * Vector input: start from a zero splat of the output type, shape-cast the
//    input to a 1-D vector of `numElements` when its rank is > 1, then walk
//    the elements two at a time. Each step emits one ROCDL::CvtPkRtz (with a
//    poison second operand when only one element is left), slices
//    `elemsThisOp` lanes out of the 2-wide result and inserts them into the
//    accumulator at offset `i`. A final shape-cast restores `outType` when
//    the ranks differ.
- void TruncfToFloat16RewritePattern::rewrite (arith::TruncFOp op,
304
- PatternRewriter &rewriter) const {
305
- Location loc = op.getLoc ();
306
- Value in = op.getIn ();
307
- Type outElemType = getElementTypeOrSelf (op.getOut ().getType ());
308
- VectorType truncResType = VectorType::get (2 , outElemType);
309
- auto inVectorTy = dyn_cast<VectorType>(in.getType ());
310
-
311
- // Handle the case where input type is not a vector type
312
- if (!inVectorTy) {
313
- auto sourceB = rewriter.create <LLVM::PoisonOp>(loc, rewriter.getF32Type ());
314
- Value asF16s =
315
- rewriter.create <ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
316
- Value result = rewriter.create <vector::ExtractElementOp>(
317
- loc, asF16s, rewriter.createOrFold <arith::ConstantIndexOp>(loc, 0 ));
318
- return rewriter.replaceOp (op, result);
319
- }
320
- VectorType outType = cast<VectorType>(op.getOut ().getType ());
321
- int64_t numElements = outType.getNumElements ();
// Accumulator: a zero-filled vector of the output type that the loop below
// overwrites slice by slice.
322
- Value zero = rewriter.createOrFold <arith::ConstantOp>(
323
- loc, outElemType, rewriter.getFloatAttr (outElemType, 0.0 ));
324
- Value result = rewriter.createOrFold <vector::SplatOp>(loc, outType, zero);
325
-
// Multi-dimensional inputs are flattened to 1-D so that a single linear
// index can address every element.
326
- if (inVectorTy.getRank () > 1 ) {
327
- inVectorTy = VectorType::get (SmallVector<int64_t >{numElements},
328
- inVectorTy.getElementType ());
329
- in = rewriter.create <vector::ShapeCastOp>(loc, inVectorTy, in);
330
- }
331
-
332
- // Handle the vector case. We also handle the (uncommon) case where the vector
333
- // length is odd
334
- for (int64_t i = 0 ; i < numElements; i += 2 ) {
// elemsThisOp is 2 for a full pair and 1 for the trailing element of an
// odd-length vector.
335
- int64_t elemsThisOp = std::min (numElements, i + 2 ) - i;
336
- Value thisResult = nullptr ;
337
- Value elemA = rewriter.create <vector::ExtractElementOp>(
338
- loc, in, rewriter.create <arith::ConstantIndexOp>(loc, i));
// elemB defaults to poison and is replaced by the real element i + 1 only
// when a full pair is available.
339
- Value elemB = rewriter.create <LLVM::PoisonOp>(loc, rewriter.getF32Type ());
340
-
341
- if (elemsThisOp == 2 ) {
342
- elemB = rewriter.create <vector::ExtractElementOp>(
343
- loc, in, rewriter.createOrFold <arith::ConstantIndexOp>(loc, i + 1 ));
344
- }
345
-
346
- thisResult =
347
- rewriter.create <ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
348
- // Place back the truncated result into the possibly larger vector. If we
349
- // are operating on a size 2 vector, these operations should be folded away
350
- thisResult = rewriter.create <vector::ExtractStridedSliceOp>(
351
- loc, thisResult, 0 , elemsThisOp, 1 );
352
- result = rewriter.create <vector::InsertStridedSliceOp>(loc, thisResult,
353
- result, i, 1 );
354
- }
355
-
// Undo the earlier flattening so the replacement has the op's result type.
356
- if (inVectorTy.getRank () != outType.getRank ()) {
357
- result = rewriter.create <vector::ShapeCastOp>(loc, outType, result);
358
- }
359
-
360
- rewriter.replaceOp (op, result);
361
- }
362
-
363
275
// Registers the Arith-to-AMDGPU rewrite patterns into `patterns`.
//
// Before this change (`-` lines) the function also took
// `convertFP8Arithmetic`, `allowPackedF16Rtz` and a `Chipset`: the two FP8
// patterns were added only when `convertFP8Arithmetic` was set, and
// TruncfToFloat16RewritePattern only when `allowPackedF16Rtz` was set.
//
// After this change (`+` lines) the signature is
// (RewritePatternSet &patterns, bool saturateFP8TruncF): both
// ExtFOnFloat8RewritePattern and TruncFToFloat8RewritePattern are added
// unconditionally, the latter receiving `saturateFP8TruncF`.
void mlir::arith::populateArithToAMDGPUConversionPatterns (
364
- RewritePatternSet &patterns, bool convertFP8Arithmetic,
365
- bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
366
-
367
- if (convertFP8Arithmetic) {
368
- patterns.add <ExtFOnFloat8RewritePattern>(patterns.getContext ());
369
- patterns.add <TruncFToFloat8RewritePattern>(patterns.getContext (),
370
- saturateFP8Truncf, chipset);
371
- }
372
- if (allowPackedF16Rtz)
373
- patterns.add <TruncfToFloat16RewritePattern>(patterns.getContext ());
276
+ RewritePatternSet &patterns, bool saturateFP8TruncF) {
277
+ patterns.add <ExtFOnFloat8RewritePattern>(patterns.getContext ());
278
+ patterns.add <TruncFToFloat8RewritePattern>(patterns.getContext (),
279
+ saturateFP8TruncF);
374
280
}
375
281
376
282
// Pass entry point: builds a pattern set from the operation's context,
// populates it, and applies it greedily; a failed application signals pass
// failure.
//
// The `-` lines show the behaviour being removed: the `chipset` option was
// parsed with amdgpu::Chipset::parse, an unparsable name emitted an error and
// failed the pass, and FP8 conversion was enabled only when
// majorVersion == 9 and minorVersion >= 0x40.
//
// The `+` line replaces all of that with a single populate call that passes
// only `saturateFP8Truncf`.
void ArithToAMDGPUConversionPass::runOnOperation () {
377
283
Operation *op = getOperation ();
378
- MLIRContext *ctx = &getContext ();
379
284
RewritePatternSet patterns (op->getContext ());
380
- FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse (chipset);
381
- if (failed (maybeChipset)) {
382
- emitError (UnknownLoc::get (ctx), " Invalid chipset name: " + chipset);
383
- return signalPassFailure ();
384
- }
385
-
386
- bool convertFP8Arithmetic =
387
- (*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 0x40 ;
388
- arith::populateArithToAMDGPUConversionPatterns (
389
- patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
390
- *maybeChipset);
285
+ arith::populateArithToAMDGPUConversionPatterns (patterns, saturateFP8Truncf);
391
286
if (failed (applyPatternsAndFoldGreedily (op, std::move (patterns))))
392
287
return signalPassFailure ();
393
288
}
0 commit comments