Skip to content

Commit 2b23e6c

Browse files
authored
[mlir][nvgpu] Add nvgpu.rcp OP (llvm#100965)
This PR introduces a new OP for reciprocal calculation for `vector` types using `nvvm.rcp` OPs. Currently, it supports only f32 types --------- Co-authored-by: jingzec <[email protected]>
1 parent abc2fe3 commit 2b23e6c

File tree

6 files changed

+118
-1
lines changed

6 files changed

+118
-1
lines changed

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#ifndef NVGPU
2121
#define NVGPU
2222

23+
include "mlir/Interfaces/InferTypeOpInterface.td"
2324
include "mlir/Interfaces/SideEffectInterfaces.td"
2425
include "mlir/IR/AttrTypeBase.td"
2526
include "mlir/IR/OpBase.td"
@@ -109,10 +110,22 @@ def TensorMapInterleaveKind : I32EnumAttr<"TensorMapInterleaveKind",
109110
let cppNamespace = "::mlir::nvgpu";
110111
}
111112

113+
def RcpApprox : I32EnumAttrCase<"APPROX", 0, "approx">;
114+
def RcpRN : I32EnumAttrCase<"RN", 1, "rn">;
115+
def RcpRZ : I32EnumAttrCase<"RZ", 2, "rz">;
116+
def RcpRM : I32EnumAttrCase<"RM", 3, "rm">;
117+
def RcpRP : I32EnumAttrCase<"RP", 4, "rp">;
118+
def RcpRoundingMode : I32EnumAttr<"RcpRoundingMode", "Rounding mode of rcp",
119+
[RcpApprox, RcpRN, RcpRZ, RcpRM, RcpRP]> {
120+
let genSpecializedAttr = 0;
121+
let cppNamespace = "::mlir::nvgpu";
122+
}
123+
112124
def TensorMapSwizzleAttr : EnumAttr<NVGPU_Dialect, TensorMapSwizzleKind, "swizzle">;
113125
def TensorMapL2PromoAttr : EnumAttr<NVGPU_Dialect, TensorMapL2PromoKind, "l2promo">;
114126
def TensorMapOOBAttr : EnumAttr<NVGPU_Dialect, TensorMapOOBKind, "oob">;
115127
def TensorMapInterleaveAttr : EnumAttr<NVGPU_Dialect, TensorMapInterleaveKind, "interleave">;
128+
def RcpRoundingModeAttr : EnumAttr<NVGPU_Dialect, RcpRoundingMode, "rcp_rounding_mode">;
116129

117130
//===----------------------------------------------------------------------===//
118131
// NVGPU Type Definitions
@@ -802,4 +815,24 @@ def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulat
802815
let hasVerifier = 1;
803816
}
804817

818+
def NVGPU_RcpOp : NVGPU_Op<"rcp", [Pure,
819+
SameOperandsAndResultType]> {
820+
let summary = "The reciprocal calculation for vector types";
821+
let description = [{
822+
Reciprocal calculation for `vector` types using `nvvm.rcp` OPs.
823+
824+
Currently, only the `approx` rounding mode and `ftz` are supported, and only for the `f32` type.
825+
826+
The input and output must be of the same vector type and shape.
827+
}];
828+
let arguments = (ins VectorOf<[F32]>:$in,
829+
DefaultValuedAttr<RcpRoundingModeAttr, "RcpRoundingMode::APPROX">:$rounding,
830+
UnitAttr:$ftz);
831+
let results = (outs VectorOf<[F32]>:$out);
832+
let assemblyFormat = [{
833+
$in `{` `rounding` `=` $rounding (`,` `ftz` $ftz^)? `}`
834+
attr-dict `:` type($out)
835+
}];
836+
let hasVerifier = 1;
837+
}
805838
#endif // NVGPU

mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/IR/BuiltinTypes.h"
1818
#include "mlir/IR/Dialect.h"
1919
#include "mlir/IR/OpDefinition.h"
20+
#include "mlir/Interfaces/InferTypeOpInterface.h"
2021
#include "mlir/Interfaces/SideEffectInterfaces.h"
2122

2223
#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc"

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1212
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1313
#include "mlir/Conversion/LLVMCommon/Pattern.h"
14+
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1415
#include "mlir/Dialect/Arith/IR/Arith.h"
1516
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1617
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -1666,6 +1667,40 @@ struct NVGPUTmaPrefetchOpLowering
16661667
}
16671668
};
16681669

1670+
struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
1671+
using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
1672+
LogicalResult
1673+
matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1674+
ConversionPatternRewriter &rewriter) const override {
1675+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1676+
auto i64Ty = b.getI64Type();
1677+
auto f32Ty = b.getF32Type();
1678+
VectorType inTy = op.getIn().getType();
1679+
// apply rcp.approx.ftz.f on each element in vector.
1680+
auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1681+
Value ret1DVec = b.create<LLVM::UndefOp>(llvm1DVectorTy);
1682+
int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1683+
for (int i = 0; i < numElems; i++) {
1684+
Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
1685+
Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
1686+
Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
1687+
ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
1688+
}
1689+
return ret1DVec;
1690+
};
1691+
if (inTy.getRank() == 1) {
1692+
rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1693+
return success();
1694+
}
1695+
return LLVM::detail::handleMultidimensionalVectors(
1696+
op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1697+
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
1698+
OpAdaptor adaptor(operands);
1699+
return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1700+
},
1701+
rewriter);
1702+
}
1703+
};
16691704
} // namespace
16701705

16711706
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1688,5 +1723,5 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
16881723
NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
16891724
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
16901725
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1691-
NVGPUMmaSparseSyncLowering>(converter);
1726+
NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
16921727
}

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,21 @@ LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
644644
return success();
645645
}
646646

647+
//===----------------------------------------------------------------------===//
648+
// RcpOp
649+
//===----------------------------------------------------------------------===//
650+
651+
LogicalResult RcpOp::verify() {
652+
RcpRoundingModeAttr rounding = getRoundingAttr();
653+
bool ftz = getFtz();
654+
// Currently, only `rcp_approx` and `ftz` is supported.
655+
if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
656+
return emitOpError() << "has a limitation. " << rounding
657+
<< " or non-ftz is not supported yet.";
658+
}
659+
return success();
660+
}
661+
647662
//===----------------------------------------------------------------------===//
648663
// TableGen'd dialect, type, and op definitions
649664
//===----------------------------------------------------------------------===//

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,3 +1339,18 @@ module attributes {transform.with_named_sequence} {
13391339
transform.yield
13401340
}
13411341
}
1342+
1343+
// CHECK-LABEL: @rcp_approx_ftz_f32
1344+
// CHECK-SAME: %[[IN:.*]]: vector<32x16xf32>
1345+
func.func @rcp_approx_ftz_f32(%in: vector<32x16xf32>) {
1346+
// CHECK: %[[IN_LLVM:.*]] = builtin.unrealized_conversion_cast %[[IN]] : vector<32x16xf32> to !llvm.array<32 x vector<16xf32>>
1347+
// CHECK: %[[IN1DVEC:.*]] = llvm.extractvalue %[[IN_LLVM]][0] : !llvm.array<32 x vector<16xf32>>
1348+
// CHECK: %[[OUT1DVEC:.*]] = llvm.mlir.undef : vector<16xf32>
1349+
// CHECK: %[[IDX_0:.+]] = llvm.mlir.constant(0 : i64) : i64
1350+
// CHECK: %[[ELEM_0:.*]] = llvm.extractelement %[[IN1DVEC]][%[[IDX_0]] : i64]
1351+
// CHECK: %[[ELEM_RCP0:.*]] = nvvm.rcp.approx.ftz.f %[[ELEM_0]] : f32
1352+
// CHECK: llvm.insertelement %[[ELEM_RCP0]], %[[OUT1DVEC]][%[[IDX_0]] : i64] : vector<16xf32>
1353+
// CHECK-COUNT-511: nvvm.rcp.approx.ftz.f
1354+
%out = nvgpu.rcp %in {rounding = approx, ftz} : vector<32x16xf32>
1355+
return
1356+
}

mlir/test/Dialect/NVGPU/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,21 @@ func.func @tma_generate_descriptor_incorrect_last_dim(%desc: !desc, %buffer2: m
336336
nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<64x128xf32,3>
337337
return
338338
}
339+
// -----
340+
341+
func.func @rcp_unsupported_rounding_0(%in : vector<16xf32>) {
342+
// expected-error @+1 {{'nvgpu.rcp' op has a limitation. #nvgpu<rcp_rounding_mode rn> or non-ftz is not supported yet.}}
343+
%out = nvgpu.rcp %in {rounding = rn, ftz} : vector<16xf32>
344+
}
345+
// -----
346+
347+
func.func @rcp_unsupported_rounding_1(%in : vector<16xf32>) {
348+
// expected-error @+1 {{'nvgpu.rcp' op has a limitation. #nvgpu<rcp_rounding_mode rz> or non-ftz is not supported yet.}}
349+
%out = nvgpu.rcp %in {rounding = rz} : vector<16xf32>
350+
}
351+
// -----
352+
353+
func.func @rcp_unsupported_ftz(%in : vector<16xf32>) {
354+
// expected-error @+1 {{'nvgpu.rcp' op has a limitation. #nvgpu<rcp_rounding_mode approx> or non-ftz is not supported yet.}}
355+
%out = nvgpu.rcp %in {rounding = approx} : vector<16xf32>
356+
}

0 commit comments

Comments
 (0)