[mlir][nvgpu] Add nvgpu.rcp OP #100965


Merged: 6 commits, Jul 30, 2024
33 changes: 33 additions & 0 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -20,6 +20,7 @@
#ifndef NVGPU
#define NVGPU

include "mlir/Interfaces/InferTypeOpInterface.td"
Member:

Nit: In this dialect, we do not infer types. It is clearer to see the input and output directly in the IR. Otherwise, one would need to read the TableGen or, even more challenging, the C++ implementation.

Contributor Author:

I want to check that the input/output types of rcp are equal via SameOperandsAndResultType. So you suggest we should switch to a verifier instead?

Member:

My point was about inferring the result type. When someone sees IR like %out = nvgpu.rcp_approx %in : vector<32x16xf32>, the return type is unclear, making the IR unreadable. To understand the return type, one would need to look at the TableGen definition of the operation or read the C++ implementation.

// Current OP, unclear what is %out
%out = nvgpu.rcp_approx %in : vector<32x16xf32> 
// This is my proposal, it's clear what is %out
%out = nvgpu.rcp_approx %in : vector<32x16xf32> -> vector<32x16xf32>

You could use PredOpTrait or SameOperandsAndResultType, or handle it in the verifier. Does SameOperandsAndResultType force inferring the result type?

Contributor Author:

Ok, got it. Let me update the assembly format.

You could use PredOpTrait or SameOperandsAndResultType, or handle it in the verifier. Does SameOperandsAndResultType force inferring the result type?

Yes, SameOperandsAndResultType relies on InferTypeOpInterface, and it forces creating a build function with InferReturnTypes.

Member:

What about PredOpTrait?

Contributor Author:

Thanks @joker-eph and @grypp, both formats are good to me. Since most ops of the nvgpu dialect have explicit result types, I personally think it is fine to follow this style.

Collaborator:

Consistency is important, so that's reasonable.

That said, it's not a matter of majority: we can also see the nvgpu dialect as "not using the common upstream practices" and in need of an upgrade.

Contributor Author:

Restored the original style. Please check whether you have any other questions about this PR, thanks a lot. :-)

%out = nvgpu.rcp_approx %in {rounding=approx, ftz}: vector<32x16xf32>

Member:

That said, it's not a matter of majority: we can also see the nvgpu dialect as "not using the common upstream practices" and in need of an upgrade.

We can upgrade it, but what's the guideline for inferring result types? It makes sense when the input and output types are the same, as with arith. Otherwise, imho, it makes the IR obscure and unreadable.

Collaborator:

It makes sense when the input and output types are the same, as with arith.

I thought that is the case we're talking about here? The example provided above was %out = nvgpu.rcp_approx %in : vector<32x16xf32> -> vector<32x16xf32>

As for generality, it is slightly more subtle; for example, what about %res = memref.load %ptr[%idx] : memref<10xf32>?

include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
@@ -109,10 +110,22 @@ def TensorMapInterleaveKind : I32EnumAttr<"TensorMapInterleaveKind",
let cppNamespace = "::mlir::nvgpu";
}

def RcpApprox : I32EnumAttrCase<"APPROX", 0, "approx">;
def RcpRN : I32EnumAttrCase<"RN", 1, "rn">;
def RcpRZ : I32EnumAttrCase<"RZ", 2, "rz">;
def RcpRM : I32EnumAttrCase<"RM", 3, "rm">;
def RcpRP : I32EnumAttrCase<"RP", 4, "rp">;
def RcpRoundingMode : I32EnumAttr<"RcpRoundingMode", "Rounding mode of rcp",
[RcpApprox, RcpRN, RcpRZ, RcpRM, RcpRP]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::nvgpu";
}

def TensorMapSwizzleAttr : EnumAttr<NVGPU_Dialect, TensorMapSwizzleKind, "swizzle">;
def TensorMapL2PromoAttr : EnumAttr<NVGPU_Dialect, TensorMapL2PromoKind, "l2promo">;
def TensorMapOOBAttr : EnumAttr<NVGPU_Dialect, TensorMapOOBKind, "oob">;
def TensorMapInterleaveAttr : EnumAttr<NVGPU_Dialect, TensorMapInterleaveKind, "interleave">;
def RcpRoundingModeAttr : EnumAttr<NVGPU_Dialect, RcpRoundingMode, "rcp_rounding_mode">;

//===----------------------------------------------------------------------===//
// NVGPU Type Definitions
@@ -802,4 +815,24 @@ def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulat
let hasVerifier = 1;
}

def NVGPU_RcpOp : NVGPU_Op<"rcp", [Pure,
SameOperandsAndResultType]> {
let summary = "The reciprocal calculation for vector types";
let description = [{
Reciprocal calculation for `vector` types using `nvvm.rcp` OPs.

Currently, only the `approx` rounding mode and `ftz` are supported, and only for the `f32` type.

The input and output must be of the same vector type and shape.
}];
let arguments = (ins VectorOf<[F32]>:$in,
DefaultValuedAttr<RcpRoundingModeAttr, "RcpRoundingMode::APPROX">:$rounding,
UnitAttr:$ftz);
let results = (outs VectorOf<[F32]>:$out);
let assemblyFormat = [{
$in `{` `rounding` `=` $rounding (`,` `ftz` $ftz^)? `}`
attr-dict `:` type($out)
}];
let hasVerifier = 1;
}
#endif // NVGPU
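
For reference, the assembly format above produces IR of the following shape. This is a minimal usage sketch based on the tests added in this PR; the function name @rcp_example is illustrative:

func.func @rcp_example(%in : vector<16xf32>) -> vector<16xf32> {
  %out = nvgpu.rcp %in {rounding = approx, ftz} : vector<16xf32>
  return %out : vector<16xf32>
}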
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc"
37 changes: 36 additions & 1 deletion mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,6 +11,7 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -1666,6 +1667,40 @@ struct NVGPUTmaPrefetchOpLowering
}
};

struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
auto i64Ty = b.getI64Type();
auto f32Ty = b.getF32Type();
VectorType inTy = op.getIn().getType();
// Apply rcp.approx.ftz.f on each element of the vector.
auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
Value ret1DVec = b.create<LLVM::UndefOp>(llvm1DVectorTy);
int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
for (int i = 0; i < numElems; i++) {
Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
}
return ret1DVec;
};
if (inTy.getRank() == 1) {
rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
return success();
}
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
OpAdaptor adaptor(operands);
return convert1DVec(llvm1DVectorTy, adaptor.getIn());
},
rewriter);
}
};
} // namespace

void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1688,5 +1723,5 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
}
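
For a 1-D input, the pattern above expands the vector element by element. A rough sketch of the emitted IR for a vector<2xf32> input (assuming %in has already been converted to the LLVM dialect; SSA names are illustrative):

%v0 = llvm.mlir.undef : vector<2xf32>
%i0 = llvm.mlir.constant(0 : i64) : i64
%e0 = llvm.extractelement %in[%i0 : i64] : vector<2xf32>
%r0 = nvvm.rcp.approx.ftz.f %e0 : f32
%v1 = llvm.insertelement %r0, %v0[%i0 : i64] : vector<2xf32>
%i1 = llvm.mlir.constant(1 : i64) : i64
%e1 = llvm.extractelement %in[%i1 : i64] : vector<2xf32>
%r1 = nvvm.rcp.approx.ftz.f %e1 : f32
%out = llvm.insertelement %r1, %v1[%i1 : i64] : vector<2xf32>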
15 changes: 15 additions & 0 deletions mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -644,6 +644,21 @@ LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// RcpOp
//===----------------------------------------------------------------------===//

LogicalResult RcpOp::verify() {
RcpRoundingModeAttr rounding = getRoundingAttr();
bool ftz = getFtz();
// Currently, only the `approx` rounding mode with `ftz` is supported.
if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
return emitOpError() << "has a limitation. " << rounding
<< " or non-ftz is not supported yet.";
}
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//
15 changes: 15 additions & 0 deletions mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1339,3 +1339,18 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// CHECK-LABEL: @rcp_approx_ftz_f32
// CHECK-SAME: %[[IN:.*]]: vector<32x16xf32>
func.func @rcp_approx_ftz_f32(%in: vector<32x16xf32>) {
// CHECK: %[[IN_LLVM:.*]] = builtin.unrealized_conversion_cast %[[IN]] : vector<32x16xf32> to !llvm.array<32 x vector<16xf32>>
// CHECK: %[[IN1DVEC:.*]] = llvm.extractvalue %[[IN_LLVM]][0] : !llvm.array<32 x vector<16xf32>>
// CHECK: %[[OUT1DVEC:.*]] = llvm.mlir.undef : vector<16xf32>
// CHECK: %[[IDX_0:.+]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[ELEM_0:.*]] = llvm.extractelement %[[IN1DVEC]][%[[IDX_0]] : i64]
// CHECK: %[[ELEM_RCP0:.*]] = nvvm.rcp.approx.ftz.f %[[ELEM_0]] : f32
// CHECK: llvm.insertelement %[[ELEM_RCP0]], %[[OUT1DVEC]][%[[IDX_0]] : i64] : vector<16xf32>
// CHECK-COUNT-511: nvvm.rcp.approx.ftz.f
%out = nvgpu.rcp %in {rounding = approx, ftz} : vector<32x16xf32>
return
}
18 changes: 18 additions & 0 deletions mlir/test/Dialect/NVGPU/invalid.mlir
@@ -336,3 +336,21 @@ func.func @tma_generate_descriptor_incorrect_last_dim(%desc: !desc, %buffer2: m
nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<64x128xf32,3>
return
}
// -----

func.func @rcp_unsupported_rounding_0(%in : vector<16xf32>) {
// expected-error @+1 {{'nvgpu.rcp' op has a limitation. #nvgpu<rcp_rounding_mode rn> or non-ftz is not supported yet.}}
%out = nvgpu.rcp %in {rounding = rn, ftz} : vector<16xf32>
}
// -----

func.func @rcp_unsupported_rounding_1(%in : vector<16xf32>) {
// expected-error @+1 {{'nvgpu.rcp' op has a limitation. #nvgpu<rcp_rounding_mode rz> or non-ftz is not supported yet.}}
%out = nvgpu.rcp %in {rounding = rz} : vector<16xf32>
}
// -----

func.func @rcp_unsupported_ftz(%in : vector<16xf32>) {
// expected-error @+1 {{'nvgpu.rcp' op has a limitation. #nvgpu<rcp_rounding_mode approx> or non-ftz is not supported yet.}}
%out = nvgpu.rcp %in {rounding = approx} : vector<16xf32>
}