Skip to content

[mlir][spirv] Add IsInf/IsNan expansion for WebGPU #86903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@
namespace mlir {
namespace spirv {

/// Appends to a pattern list additional patterns to expand extended
/// multiplication ops into regular arithmetic ops. Extended multiplication ops
/// are not supported by the WebGPU Shading Language (WGSL).
/// Appends patterns to expand extended multiplication and addition ops into
/// regular arithmetic ops. Extended arithmetic ops are not supported by the
/// WebGPU Shading Language (WGSL).
void populateSPIRVExpandExtendedMultiplicationPatterns(
RewritePatternSet &patterns);

/// Appends patterns to expand non-finite arithmetic ops `IsNan` and `IsInf`.
/// These are not supported by the WebGPU Shading Language (WGSL). We follow
/// fast math assumptions and assume that all floating point values are finite.
void populateSPIRVExpandNonFiniteArithmeticPatterns(
RewritePatternSet &patterns);

} // namespace spirv
} // namespace mlir

Expand Down
54 changes: 41 additions & 13 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,17 @@ namespace {
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
Attribute getScalarOrSplatAttr(Type type, int64_t value) {
static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
if (auto intTy = dyn_cast<IntegerType>(type))
return IntegerAttr::get(intTy, sizedValue);

return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
}

Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
Value lhs, Value rhs,
bool signExtendArguments) {
static Value lowerExtendedMultiplication(Operation *mulOp,
PatternRewriter &rewriter, Value lhs,
Value rhs, bool signExtendArguments) {
Location loc = mulOp->getLoc();
Type argTy = lhs.getType();
// Emulate 64-bit multiplication by splitting each input element of type i32
Expand Down Expand Up @@ -203,15 +203,39 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
}
};

/// Rewrites `spirv.IsInf` to a constant `false`. WGSL has no `isInf`
/// builtin; under the fast-math assumption that all floating-point values
/// are finite, `IsInf` can never be true.
struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(IsInfOp op,
                                PatternRewriter &rewriter) const override {
    // We assume values to be finite and turn `IsInf` into `false`.
    // getScalarOrSplatAttr yields a scalar `false` for i1 results and a
    // dense splat of `false` for vector-of-i1 results.
    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
        op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
    return success();
  }
};

/// Rewrites `spirv.IsNan` to a constant `false`. WGSL has no `isNan`
/// builtin; under the fast-math assumption that no floating-point value
/// is NaN, `IsNan` can never be true.
struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(IsNanOp op,
                                PatternRewriter &rewriter) const override {
    // We assume values to be finite and turn `IsNan` into `false`.
    // getScalarOrSplatAttr yields a scalar `false` for i1 results and a
    // dense splat of `false` for vector-of-i1 results.
    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
        op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
class WebGPUPreparePass
: public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
public:
struct WebGPUPreparePass final
: impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);

if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
Expand All @@ -227,12 +251,16 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
RewritePatternSet &patterns) {
// WGSL currently does not support extended multiplication ops, see:
// https://github.com/gpuweb/gpuweb/issues/1565.
patterns.add<
// clang-format off
ExpandSMulExtendedPattern,
ExpandUMulExtendedPattern,
ExpandAddCarryPattern
>(patterns.getContext());
patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
ExpandAddCarryPattern>(patterns.getContext());
}

/// Registers the patterns that replace the non-finite classification ops
/// `spirv.IsInf` and `spirv.IsNan` with constant `false`, following
/// fast-math assumptions.
void populateSPIRVExpandNonFiniteArithmeticPatterns(
    RewritePatternSet &patterns) {
  // WGSL currently does not support `isInf` and `isNan`, see:
  // https://github.com/gpuweb/gpuweb/pull/2311.
  auto *context = patterns.getContext();
  patterns.add<ExpandIsInfPattern>(context);
  patterns.add<ExpandIsNanPattern>(context);
}

} // namespace spirv
} // namespace mlir
32 changes: 32 additions & 0 deletions mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,36 @@ spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None
spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
}

// Scalar IsInf is folded to a constant `false` by the WebGPU prepare pass
// (fast-math: all values assumed finite).
// CHECK-LABEL: func @is_inf_f32
// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false
// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1
spirv.func @is_inf_f32(%a : f32) -> i1 "None" {
  %0 = spirv.IsInf %a : f32
  spirv.ReturnValue %0 : i1
}

// Vector IsInf is folded to a dense splat of `false`, matching the result
// vector shape.
// CHECK-LABEL: func @is_inf_4xf32
// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1>
spirv.func @is_inf_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
  %0 = spirv.IsInf %a : vector<4xf32>
  spirv.ReturnValue %0 : vector<4xi1>
}

// Scalar IsNan is folded to a constant `false` by the WebGPU prepare pass
// (fast-math: no value assumed to be NaN).
// CHECK-LABEL: func @is_nan_f32
// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false
// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1
spirv.func @is_nan_f32(%a : f32) -> i1 "None" {
  %0 = spirv.IsNan %a : f32
  spirv.ReturnValue %0 : i1
}

// Vector IsNan is folded to a dense splat of `false`, matching the result
// vector shape.
// CHECK-LABEL: func @is_nan_4xf32
// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1>
spirv.func @is_nan_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
  %0 = spirv.IsNan %a : vector<4xf32>
  spirv.ReturnValue %0 : vector<4xi1>
}

} // end module