Skip to content

Commit 330a232

Browse files
committed
[mlir][gpu] Add i64 & f64 support to gpu.shuffle
This patch adds support for i64 and f64 values in `gpu.shuffle` by rewriting each 64-bit shuffle into two 32-bit shuffles. The reason behind this change is that both CUDA and HIP support this kind of shuffling. The implementation provided by this patch is based on the LLVM IR emitted by clang for 64-bit shuffles when using `-O3`. Reviewed By: makslevental. Differential Revision: https://reviews.llvm.org/D148974
1 parent ee740b7 commit 330a232

File tree

6 files changed

+163
-5
lines changed

6 files changed

+163
-5
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -936,14 +936,17 @@ def GPU_ShuffleMode : I32EnumAttr<"ShuffleMode",
936936
def GPU_ShuffleModeAttr : EnumAttr<GPU_Dialect, GPU_ShuffleMode,
937937
"shuffle_mode">;
938938

939-
def I32OrF32 : TypeConstraint<Or<[I32.predicate, F32.predicate]>,
940-
"i32 or f32">;
939+
// Type constraint satisfied by exactly four scalar types: i32, i64, f32 and
// f64. The quoted string is the description used in verifier diagnostics
// (e.g. "operand #0 must be i32, i64, f32 or f64").
def I32I64F32OrF64 : TypeConstraint<Or<[I32.predicate,
940+
I64.predicate,
941+
F32.predicate,
942+
F64.predicate]>,
943+
"i32, i64, f32 or f64">;
941944

942945
def GPU_ShuffleOp : GPU_Op<
943946
"shuffle", [Pure, AllTypesMatch<["value", "shuffleResult"]>]>,
944-
Arguments<(ins I32OrF32:$value, I32:$offset, I32:$width,
947+
Arguments<(ins I32I64F32OrF64:$value, I32:$offset, I32:$width,
945948
GPU_ShuffleModeAttr:$mode)>,
946-
Results<(outs I32OrF32:$shuffleResult, I1:$valid)> {
949+
Results<(outs I32I64F32OrF64:$shuffleResult, I1:$valid)> {
947950
let summary = "Shuffles values within a subgroup.";
948951
let description = [{
949952
The "shuffle" op moves values to a different invocation within the same

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,17 @@ std::unique_ptr<OperationPass<func::FuncOp>> createGpuMapParallelLoopsPass();
5656
/// Collect a set of patterns to rewrite GlobalIdOp op within the GPU dialect.
5757
void populateGpuGlobalIdPatterns(RewritePatternSet &patterns);
5858

59+
/// Collect a set of patterns to rewrite shuffle ops within the GPU dialect.
60+
void populateGpuShufflePatterns(RewritePatternSet &patterns);
61+
5962
/// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
6063
void populateGpuAllReducePatterns(RewritePatternSet &patterns);
6164

6265
/// Collect all patterns to rewrite ops within the GPU dialect.
6366
// Convenience entry point: forwards `patterns` to the all-reduce, global-id
// and shuffle populate functions, in that order.
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
6467
populateGpuAllReducePatterns(patterns);
6568
populateGpuGlobalIdPatterns(patterns);
69+
populateGpuShufflePatterns(patterns);
6670
}
6771

6872
namespace gpu {

mlir/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
5050
Transforms/KernelOutlining.cpp
5151
Transforms/MemoryPromotion.cpp
5252
Transforms/ParallelLoopMapper.cpp
53+
Transforms/ShuffleRewriter.cpp
5354
Transforms/SerializeToBlob.cpp
5455
Transforms/SerializeToCubin.cpp
5556
Transforms/SerializeToHsaco.cpp
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
//===- ShuffleRewriter.cpp - Implementation of shuffle rewriting ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements in-dialect rewriting of the shuffle op for types i64 and
10+
// f64, rewriting 64bit shuffles into two 32bit shuffles. This particular
11+
// implementation using shifts and truncations can be obtained using clang: by
12+
// emitting IR for shuffle operations with `-O3`.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
#include "mlir/Dialect/Arith/IR/Arith.h"
17+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
19+
#include "mlir/IR/Builders.h"
20+
#include "mlir/IR/PatternMatch.h"
21+
#include "mlir/Pass/Pass.h"
22+
23+
using namespace mlir;
24+
25+
namespace {
26+
/// Rewrites a `gpu.shuffle` whose value is a 64-bit scalar (i64 or f64) into
/// two `gpu.shuffle` ops on the low and high 32-bit halves, then reassembles
/// the 64-bit result. Shuffles of any other type are left untouched.
struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
  using OpRewritePattern<gpu::ShuffleOp>::OpRewritePattern;

  void initialize() {
    // Required as the pattern will replace the Op with 2 additional ShuffleOps.
    setHasBoundedRewriteRecursion();
  }

  /// Splits `op` into two 32-bit shuffles when its value is a 64-bit integer
  /// or float.
  ///
  /// Returns success() after replacing `op` with the recombined value and the
  /// conjunction of both partial validity bits; returns a match failure (and
  /// creates no ops) for every other value type.
  LogicalResult matchAndRewrite(gpu::ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value value = op.getValue();
    Type valueType = value.getType();
    Location valueLoc = value.getLoc();
    Type i32 = rewriter.getI32Type();
    Type i64 = rewriter.getI64Type();

    // Only 64-bit scalars need splitting. Checking for exactly 64 bits (rather
    // than "not 32") keeps the pattern from emitting invalid trunc/ext ops if
    // the op ever accepts other widths, and the isIntOrFloat() test guards the
    // bit-width query, which is only defined for integer and float types.
    if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64)
      return rewriter.notifyMatchFailure(
          op, "only 64-bit integer or float values are rewritten");

    const bool isFloat = isa<FloatType>(valueType);

    // Float types must be converted to i64 to extract the bits.
    if (isFloat)
      value = rewriter.create<arith::BitcastOp>(valueLoc, i64, value);

    // Get the low bits by trunc(value).
    Value lo = rewriter.create<arith::TruncIOp>(valueLoc, i32, value);

    // Get the high bits by trunc(value >> 32). The shift must be logical so
    // that no sign bits leak into the half that is later shifted back.
    Value c32 = rewriter.create<arith::ConstantOp>(
        valueLoc, rewriter.getIntegerAttr(i64, 32));
    Value hi = rewriter.create<arith::ShRUIOp>(valueLoc, value, c32);
    hi = rewriter.create<arith::TruncIOp>(valueLoc, i32, hi);

    // Shuffle both halves with the original offset, width and mode.
    auto loShuffle = rewriter.create<gpu::ShuffleOp>(
        loc, lo, op.getOffset(), op.getWidth(), op.getMode());
    auto hiShuffle = rewriter.create<gpu::ShuffleOp>(
        loc, hi, op.getOffset(), op.getWidth(), op.getMode());

    // Convert lo back to i64 (zero-extended so the upper half stays clear).
    lo = rewriter.create<arith::ExtUIOp>(valueLoc, i64,
                                         loShuffle.getShuffleResult());

    // Convert hi back to i64 and move it into the upper 32 bits.
    hi = rewriter.create<arith::ExtUIOp>(valueLoc, i64,
                                         hiShuffle.getShuffleResult());
    hi = rewriter.create<arith::ShLIOp>(valueLoc, hi, c32);

    // Obtain the shuffled bits hi | lo.
    value = rewriter.create<arith::OrIOp>(loc, hi, lo);

    // Convert the value back to float.
    if (isFloat)
      value = rewriter.create<arith::BitcastOp>(valueLoc, valueType, value);

    // The 64-bit shuffle is valid only if both 32-bit shuffles were valid.
    Value validity = rewriter.create<arith::AndIOp>(loc, loShuffle.getValid(),
                                                    hiShuffle.getValid());

    // Replace the op.
    rewriter.replaceOp(op, {value, validity});
    return success();
  }
};
95+
} // namespace
96+
97+
/// Registers the `gpu.shuffle` rewrite pattern (GpuShuffleRewriter) into
/// `patterns`, using the context that `patterns` was created with.
void mlir::populateGpuShufflePatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<GpuShuffleRewriter>(context);
}

mlir/test/Dialect/GPU/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ func.func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
318318
// -----
319319

320320
func.func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) {
321-
// expected-error@+1 {{operand #0 must be i32 or f32}}
321+
// expected-error@+1 {{operand #0 must be i32, i64, f32 or f64}}
322322
%shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : index
323323
return
324324
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s
2+
3+
// Verifies the rewrite of a gpu.shuffle on an f64 value: the value is bitcast
// to i64, split into low/high i32 halves, each half is shuffled separately,
// and the halves are recombined and bitcast back to f64.
module {
4+
// CHECK-LABEL: func.func @shuffleF64
5+
// CHECK-SAME: (%[[SZ:.*]]: index, %[[VALUE:.*]]: f64, %[[OFF:.*]]: i32, %[[WIDTH:.*]]: i32, %[[MEM:.*]]: memref<f64, 1>) {
6+
func.func @shuffleF64(%sz : index, %value: f64, %offset: i32, %width: i32, %mem: memref<f64, 1>) {
7+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
8+
threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
9+
// CHECK: %[[INTVAL:.*]] = arith.bitcast %[[VALUE]] : f64 to i64
10+
// CHECK-NEXT: %[[LO:.*]] = arith.trunci %[[INTVAL]] : i64 to i32
11+
// CHECK-NEXT: %[[HI64:.*]] = arith.shrui %[[INTVAL]], %[[C32:.*]] : i64
12+
// CHECK-NEXT: %[[HI:.*]] = arith.trunci %[[HI64]] : i64 to i32
13+
// CHECK-NEXT: %[[SH1:.*]], %[[V1:.*]] = gpu.shuffle xor %[[LO]], %[[OFF]], %[[WIDTH]] : i32
14+
// CHECK-NEXT: %[[SH2:.*]], %[[V2:.*]] = gpu.shuffle xor %[[HI]], %[[OFF]], %[[WIDTH]] : i32
15+
// CHECK-NEXT: %[[LOSH:.*]] = arith.extui %[[SH1]] : i32 to i64
16+
// CHECK-NEXT: %[[HISHTMP:.*]] = arith.extui %[[SH2]] : i32 to i64
17+
// CHECK-NEXT: %[[HISH:.*]] = arith.shli %[[HISHTMP]], %[[C32]] : i64
18+
// CHECK-NEXT: %[[SHFLINT:.*]] = arith.ori %[[HISH]], %[[LOSH]] : i64
19+
// CHECK-NEXT: = arith.bitcast %[[SHFLINT]] : i64 to f64
// The 64-bit shuffle under test. Only %shfl is used afterwards (stored to
// %mem); %pred has no uses in this function.
20+
%shfl, %pred = gpu.shuffle xor %value, %offset, %width : f64
21+
memref.store %shfl, %mem[] : memref<f64, 1>
22+
gpu.terminator
23+
}
24+
return
25+
}
26+
}
27+
28+
// -----
29+
30+
// Verifies the rewrite of a gpu.shuffle on an i64 value: the value is split
// into low/high i32 halves, each half is shuffled separately, and the halves
// are recombined with a shift and an or. No bitcast is expected for integers.
module {
31+
// CHECK-LABEL: func.func @shuffleI64
32+
// CHECK-SAME: (%[[SZ:.*]]: index, %[[VALUE:.*]]: i64, %[[OFF:.*]]: i32, %[[WIDTH:.*]]: i32, %[[MEM:.*]]: memref<i64, 1>) {
33+
func.func @shuffleI64(%sz : index, %value: i64, %offset: i32, %width: i32, %mem: memref<i64, 1>) {
34+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
35+
threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
36+
// CHECK: %[[LO:.*]] = arith.trunci %[[VALUE]] : i64 to i32
37+
// CHECK-NEXT: %[[HI64:.*]] = arith.shrui %[[VALUE]], %[[C32:.*]] : i64
38+
// CHECK-NEXT: %[[HI:.*]] = arith.trunci %[[HI64]] : i64 to i32
39+
// CHECK-NEXT: %[[SH1:.*]], %[[V1:.*]] = gpu.shuffle xor %[[LO]], %[[OFF]], %[[WIDTH]] : i32
40+
// CHECK-NEXT: %[[SH2:.*]], %[[V2:.*]] = gpu.shuffle xor %[[HI]], %[[OFF]], %[[WIDTH]] : i32
41+
// CHECK-NEXT: %[[LOSH:.*]] = arith.extui %[[SH1]] : i32 to i64
42+
// CHECK-NEXT: %[[HISHTMP:.*]] = arith.extui %[[SH2]] : i32 to i64
43+
// CHECK-NEXT: %[[HISH:.*]] = arith.shli %[[HISHTMP]], %[[C32]] : i64
44+
// CHECK-NEXT: %[[SHFLINT:.*]] = arith.ori %[[HISH]], %[[LOSH]] : i64
// The 64-bit shuffle under test. Only %shfl is used afterwards (stored to
// %mem); %pred has no uses in this function.
45+
%shfl, %pred = gpu.shuffle xor %value, %offset, %width : i64
46+
memref.store %shfl, %mem[] : memref<i64, 1>
47+
gpu.terminator
48+
}
49+
return
50+
}
51+
}

0 commit comments

Comments
 (0)