|
| 1 | +//===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting ---===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// This file implements in-dialect rewriting of the gpu.subgroup_id op for archs |
| 10 | +// where: |
| 11 | +// subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size |
| 12 | +// |
| 13 | +//===----------------------------------------------------------------------===// |
| 14 | + |
| 15 | +#include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 16 | +#include "mlir/Dialect/GPU/Transforms/Passes.h" |
| 17 | +#include "mlir/Dialect/Index/IR/IndexOps.h" |
| 18 | +#include "mlir/IR/Builders.h" |
| 19 | +#include "mlir/IR/PatternMatch.h" |
| 20 | +#include "mlir/Pass/Pass.h" |
| 21 | + |
| 22 | +using namespace mlir; |
| 23 | + |
| 24 | +namespace { |
| 25 | +struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> { |
| 26 | + using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern; |
| 27 | + |
| 28 | + LogicalResult |
| 29 | + matchAndRewrite(gpu::SubgroupIdOp op, |
| 30 | + PatternRewriter &rewriter) const override { |
| 31 | + // Calculation of the thread's subgroup identifier. |
| 32 | + // |
| 33 | + // The process involves mapping the thread's 3D identifier within its |
| 34 | + // block (b_id.x, b_id.y, b_id.z) to a 1D linear index. |
| 35 | + // This linearization assumes a layout where the x-dimension (w_dim.x) |
| 36 | + // varies most rapidly (i.e., it is the innermost dimension). |
| 37 | + // |
| 38 | + // The formula for the linearized thread index is: |
| 39 | + // L = tid.x + dim.x * (tid.y + (dim.y * tid.z)) |
| 40 | + // |
| 41 | + // Subsequently, the range of linearized indices [0, N_threads-1] is |
| 42 | + // divided into consecutive, non-overlapping segments, each representing |
| 43 | + // a subgroup of size 'subgroup_size'. |
| 44 | + // |
| 45 | + // Example Partitioning (N = subgroup_size): |
| 46 | + // | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... | |
| 47 | + // | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... | |
| 48 | + // |
| 49 | + // The subgroup identifier is obtained via integer division of the |
| 50 | + // linearized thread index by the predefined 'subgroup_size'. |
| 51 | + // |
| 52 | + // subgroup_id = floor( L / subgroup_size ) |
| 53 | + // = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / |
| 54 | + // subgroup_size |
| 55 | + |
| 56 | + auto loc = op->getLoc(); |
| 57 | + |
| 58 | + Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x); |
| 59 | + Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y); |
| 60 | + Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x); |
| 61 | + Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y); |
| 62 | + Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z); |
| 63 | + |
| 64 | + Value dimYxIdZ = |
| 65 | + rewriter.create<index::MulOp>(loc, dimY, tidZ); |
| 66 | + Value dimYxIdZPlusIdY = |
| 67 | + rewriter.create<index::AddOp>(loc, dimYxIdZ, tidY); |
| 68 | + Value dimYxIdZPlusIdYTimesDimX = |
| 69 | + rewriter.create<index::MulOp>(loc, dimX, dimYxIdZPlusIdY); |
| 70 | + Value IdXPlusDimYxIdZPlusIdYTimesDimX = |
| 71 | + rewriter.create<index::AddOp>(loc, tidX, |
| 72 | + dimYxIdZPlusIdYTimesDimX); |
| 73 | + Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>( |
| 74 | + loc, rewriter.getIndexType(), /*upper_bound = */ nullptr); |
| 75 | + Value subgroupIdOp = rewriter.create<index::DivUOp>( |
| 76 | + loc, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize); |
| 77 | + rewriter.replaceOp(op, {subgroupIdOp}); |
| 78 | + return success(); |
| 79 | + } |
| 80 | +}; |
| 81 | + |
| 82 | +} // namespace |
| 83 | + |
| 84 | +void mlir::populateGpuSubgroupIdPatterns(RewritePatternSet &patterns) { |
| 85 | + patterns.add<GpuSubgroupIdRewriter>(patterns.getContext()); |
| 86 | +} |
0 commit comments