Skip to content

Commit f6935c7

Browse files
Revert "[MLIR][AMDGPU] Introduce fp16 packed arithmetic (llvm#105688)"
This reverts commit 1387ba4.
1 parent bea0be3 commit f6935c7

File tree

12 files changed: 11 additions and 203 deletions.

mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,7 @@
99
#ifndef MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
1010
#define MLIR_CONVERSION_ARITHTOAMDGPU_ARITHTOAMDGPU_H
1111

12-
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1312
#include <memory>
14-
#include <string>
1513

1614
namespace mlir {
1715

@@ -28,10 +26,7 @@ namespace arith {
2826
/// to the largest value of that type instead of being rewritten to Inf (aka
2927
/// NaN).
3028
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
31-
bool convertFP8Arithmetic,
32-
bool saturateFP8Truncf,
33-
bool allowPackedF16Rtz,
34-
amdgpu::Chipset chipset);
29+
bool saturateFP8TruncF);
3530
} // namespace arith
3631
} // namespace mlir
3732

mlir/include/mlir/Conversion/Passes.td

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -153,15 +153,9 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
153153
let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
154154

155155
let options = [
156-
Option<"chipset", "chipset", "std::string",
157-
/*default=*/"\"gfx000\"",
158-
"Chipset that these operations will run on">,
159156
Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
160157
/*default=*/"false",
161158
"Use saturating truncation for 8-bit float types">,
162-
Option<"allowPackedF16Rtz", "allow-packed-f16-round-to-zero", "bool",
163-
/*default=*/"false",
164-
"Whether we should allow f32->f16 packed round-to-zero conversion">,
165159
];
166160
}
167161

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@ def AMDGPU_Dialect : Dialect {
2525

2626

2727
let dependentDialects = [
28-
"ROCDL::ROCDLDialect",
2928
"arith::ArithDialect",
3029
"gpu::GPUDialect"
3130
];

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 1 addition & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ def ROCDL_BallotOp :
166166
let summary = "Vote across thread group";
167167

168168
let description = [{
169-
Ballot provides a bit mask containing the 1-bit predicate value from each lane.
169+
Ballot provides a bit mask containing the 1-bit predicate value from each lane.
170170
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
171171
}];
172172

@@ -579,21 +579,6 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
579579
}];
580580
}
581581

582-
//===---------------------------------------------------------------------===//
583-
// 16-bit float intrinsics
584-
//===---------------------------------------------------------------------===//
585-
def ROCDL_CvtPkRtz:
586-
ROCDL_IntrOp<"cvt.pkrtz", [], [], [Pure], 1>,
587-
Arguments<(ins F32:$srcA, F32:$srcB)> {
588-
let summary = "Convert two f32 input into a vector<2xf16>";
589-
let description = [{
590-
Convert two f32 values into a packed vector<2xf16>.
591-
}];
592-
let assemblyFormat = [{
593-
attr-dict $srcA `,` $srcB `:` type($res)
594-
}];
595-
}
596-
597582
//===---------------------------------------------------------------------===//
598583
// 8-bit float intrinsics
599584
//===---------------------------------------------------------------------===//

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 7 additions & 112 deletions
Original file line number | Diff line number | Diff line change
@@ -9,11 +9,8 @@
99
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
1010

1111
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12-
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1312
#include "mlir/Dialect/Arith/IR/Arith.h"
1413
#include "mlir/Dialect/Arith/Utils/Utils.h"
15-
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16-
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
1714
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1815
#include "mlir/IR/BuiltinTypes.h"
1916
#include "mlir/IR/PatternMatch.h"
@@ -27,7 +24,6 @@ namespace mlir {
2724
} // namespace mlir
2825

2926
using namespace mlir;
30-
using namespace mlir::amdgpu;
3127

3228
namespace {
3329
struct ArithToAMDGPUConversionPass final
@@ -47,25 +43,12 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
4743

4844
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
4945
bool saturateFP8 = false;
50-
TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
51-
Chipset chipset)
52-
: OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
53-
chipset(chipset) {}
54-
Chipset chipset;
46+
TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
47+
: OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
5548

5649
LogicalResult match(arith::TruncFOp op) const override;
5750
void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
5851
};
59-
60-
struct TruncfToFloat16RewritePattern final
61-
: public OpRewritePattern<arith::TruncFOp> {
62-
63-
using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
64-
65-
LogicalResult match(arith::TruncFOp op) const override;
66-
void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
67-
};
68-
6952
} // end namespace
7053

7154
static Value castF32To(Type elementType, Value f32, Location loc,
@@ -289,105 +272,17 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
289272
rewriter.replaceOp(op, result);
290273
}
291274

292-
LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
293-
Type outType = op.getOut().getType();
294-
Type inputType = getElementTypeOrSelf(op.getIn());
295-
if (auto outVecType = dyn_cast<VectorType>(outType)) {
296-
if (outVecType.isScalable())
297-
return failure();
298-
outType = outVecType.getElementType();
299-
}
300-
return success(outType.isF16() && inputType.isF32());
301-
}
302-
303-
void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
304-
PatternRewriter &rewriter) const {
305-
Location loc = op.getLoc();
306-
Value in = op.getIn();
307-
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
308-
VectorType truncResType = VectorType::get(2, outElemType);
309-
auto inVectorTy = dyn_cast<VectorType>(in.getType());
310-
311-
// Handle the case where input type is not a vector type
312-
if (!inVectorTy) {
313-
auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
314-
Value asF16s =
315-
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
316-
Value result = rewriter.create<vector::ExtractElementOp>(
317-
loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
318-
return rewriter.replaceOp(op, result);
319-
}
320-
VectorType outType = cast<VectorType>(op.getOut().getType());
321-
int64_t numElements = outType.getNumElements();
322-
Value zero = rewriter.createOrFold<arith::ConstantOp>(
323-
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
324-
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
325-
326-
if (inVectorTy.getRank() > 1) {
327-
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
328-
inVectorTy.getElementType());
329-
in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
330-
}
331-
332-
// Handle the vector case. We also handle the (uncommon) case where the vector
333-
// length is odd
334-
for (int64_t i = 0; i < numElements; i += 2) {
335-
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
336-
Value thisResult = nullptr;
337-
Value elemA = rewriter.create<vector::ExtractElementOp>(
338-
loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
339-
Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
340-
341-
if (elemsThisOp == 2) {
342-
elemB = rewriter.create<vector::ExtractElementOp>(
343-
loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
344-
}
345-
346-
thisResult =
347-
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
348-
// Place back the truncated result into the possibly larger vector. If we
349-
// are operating on a size 2 vector, these operations should be folded away
350-
thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
351-
loc, thisResult, 0, elemsThisOp, 1);
352-
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
353-
result, i, 1);
354-
}
355-
356-
if (inVectorTy.getRank() != outType.getRank()) {
357-
result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
358-
}
359-
360-
rewriter.replaceOp(op, result);
361-
}
362-
363275
void mlir::arith::populateArithToAMDGPUConversionPatterns(
364-
RewritePatternSet &patterns, bool convertFP8Arithmetic,
365-
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
366-
367-
if (convertFP8Arithmetic) {
368-
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
369-
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
370-
saturateFP8Truncf, chipset);
371-
}
372-
if (allowPackedF16Rtz)
373-
patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
276+
RewritePatternSet &patterns, bool saturateFP8TruncF) {
277+
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
278+
patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
279+
saturateFP8TruncF);
374280
}
375281

376282
void ArithToAMDGPUConversionPass::runOnOperation() {
377283
Operation *op = getOperation();
378-
MLIRContext *ctx = &getContext();
379284
RewritePatternSet patterns(op->getContext());
380-
FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
381-
if (failed(maybeChipset)) {
382-
emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
383-
return signalPassFailure();
384-
}
385-
386-
bool convertFP8Arithmetic =
387-
(*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 0x40;
388-
arith::populateArithToAMDGPUConversionPatterns(
389-
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
390-
*maybeChipset);
285+
arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
391286
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
392287
return signalPassFailure();
393288
}

mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@ add_mlir_conversion_library(MLIRArithToAMDGPU
1212

1313
LINK_LIBS PUBLIC
1414
MLIRAMDGPUDialect
15-
MLIRAMDGPUUtils
1615
MLIRArithDialect
1716
MLIRArithUtils
1817
MLIRVectorDialect

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17-
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
1817
#include "mlir/IR/Builders.h"
1918
#include "mlir/IR/BuiltinTypes.h"
2019
#include "mlir/IR/Diagnostics.h"

mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
1111

1212
LINK_LIBS PUBLIC
1313
MLIRArithDialect
14-
MLIRROCDLDialect
1514
# Needed for GPU address space enum definition
1615
MLIRGPUDialect
1716
MLIRIR

mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir

Lines changed: 0 additions & 51 deletions
This file was deleted.

mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
// RUN: mlir-opt --split-input-file %s \
2-
// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{chipset=gfx940 saturate-fp8-truncf=true}))' \
2+
// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{saturate-fp8-truncf=true}))' \
33
// RUN: | FileCheck %s
44

55
// CHECK-LABEL: func.func @scalar_trunc

mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx940" | FileCheck %s
1+
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu | FileCheck %s
22

33
// CHECK-LABEL: func.func @scalar_ext
44
// CHECK-SAME: ([[V:%.+]]: f8E5M2FNUZ)

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -530,12 +530,6 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
530530
llvm.return %source5 : i32
531531
}
532532

533-
llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf16> {
534-
// CHECK: call <2 x half> @llvm.amdgcn.cvt.pkrtz(float {{.*}}, float {{.*}})
535-
%source = rocdl.cvt.pkrtz %sourceA, %sourceB : vector<2xf16>
536-
llvm.return %source : vector<2xf16>
537-
}
538-
539533
// CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "uniform-work-group-size"="true" }
540534
// CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
541535
// CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"

0 commit comments

Comments
 (0)