[MLIR][XEGPU] Add blocking support for scatter ops #144766
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu

Author: Jianhui Li (Jianhui-Li)

Changes: Add blocking support for the scatter ops create_tdesc, update_offset, prefetch, load, and store. It also enables load/store with a chunk size.

Full diff: https://github.com/llvm/llvm-project/pull/144766.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index a3826c56e1f62..02bfab8263847 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -134,11 +134,13 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(Operation *op) const {
- if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
+ if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
+ xegpu::UpdateOffsetOp>(op))
return getTileShape(op->getOpResult(0));
- if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
+ if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
+ xegpu::LoadGatherOp>(op))
return getTileShape(op->getOpOperand(0));
- if (isa<xegpu::StoreNdOp>(op))
+ if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
return getTileShape(op->getOpOperand(1));
if (isa<xegpu::DpasOp>(op)) {
@@ -295,12 +297,39 @@ void XeGPUBlockingPass::runOnOperation() {
Type elemTy = type.getElementType();
Type newTy;
- if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
- newTy = xegpu::TensorDescType::get(
- ctx, tileShape, elemTy, tdescTy.getEncoding(),
- tdescTy.getLayoutAttr().dropInstData());
- else
+ if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
+
+ Attribute encoding = tdescTy.getEncoding();
+ // If the encoding is a ScatterTensorDescAttr, we need to
+ // potentially adjust the chunk size based on the inst_data.
+ if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
+ auto scatterAttr =
+ mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+ int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+
+ if (chunkSize > 1) {
+ int64_t blockedChunkSize = chunkSize;
+ auto instData = tdescTy.getLayoutAttr().getInstData();
+ if (!instData.empty())
+ blockedChunkSize = instData.asArrayRef().back();
+
+ auto chunkSizeAttr = mlir::IntegerAttr::get(
+ mlir::IntegerType::get(ctx, 64), blockedChunkSize);
+
+ // To create a new attribute with a different chunk_size:
+ auto newEncoding = xegpu::ScatterTensorDescAttr::get(
+ ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
+
+ encoding = newEncoding;
+ }
+ }
+
+ newTy =
+ xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
+ tdescTy.getLayoutAttr().dropInstData());
+ } else {
newTy = type.clone(tileShape, elemTy);
+ }
std::optional<SmallVector<int64_t>> ratio =
computeShapeRatio(type.getShape(), tileShape);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 67d3bd9b393c0..f977ba3c11bcf 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -250,8 +250,7 @@ gpu.module @test_kernel {
// -----
#l = #xegpu.layout<inst_data = [16, 16]>
#r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%acc = arith.constant dense<0.0> : vector<64xf32>
%c64 = arith.constant 64 : index
@@ -271,8 +270,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
// -----
#l = #xegpu.layout<inst_data = [16, 16]>
#r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
@@ -299,8 +297,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
// -----
#r = #xegpu.layout<inst_data = [16]>
#l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c64 = arith.constant 64 : index
@@ -319,8 +316,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
// -----
#r = #xegpu.layout<inst_data = [16]>
#l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c32 = arith.constant 32 : index
@@ -340,8 +336,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
// -----
#l = #xegpu.layout<inst_data = [16, 8]>
#t = #xegpu.layout<inst_data = [8, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c32 = arith.constant 32 : index
@@ -355,4 +350,100 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
gpu.return
}
-}
\ No newline at end of file
+}
+
+// -----
+gpu.module @test_kernel {
+ // CHECK-LABEL: test_prefetch_load_store_update
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+ // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+ // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
+ gpu.func @test_prefetch_load_store_update(%src: ui64) {
+
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+ %delta = arith.constant dense<[
+ 32, 32, 32, 32, 32, 32, 32, 32,
+ 32, 32, 32, 32, 32, 32, 32, 64,
+ 128, 128, 128, 128, 128, 128, 128, 128,
+ 128, 128, 128, 128, 128, 128, 128, 256
+ ]> : vector<32xindex>
+ %new_tdesc = xegpu.update_offset %tdesc, %delta
+ : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+ %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+ xegpu.store %st_vec, %tdesc, %mask:
+ vector<32xf32>,
+ !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
+ vector<32xi1>
+
+ gpu.return
+ }
+
+}
+
+// -----
+
+gpu.module @test_kernel {
+ // CHECK-LABEL: test_prefetch_load_store_update_chunk
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
+ // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
+ // CHECK-COUNT-4: xegpu.store {{.*}} : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
+
+ gpu.func @test_prefetch_load_store_update_chunk(%src: ui64) {
+
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+
+ %delta = arith.constant dense<[
+ 32, 32, 32, 32, 32, 32, 32, 32,
+ 32, 32, 32, 32, 32, 32, 32, 64,
+ 128, 128, 128, 128, 128, 128, 128, 128,
+ 128, 128, 128, 128, 128, 128, 128, 256
+ ]> : vector<32xindex>
+ %new_tdesc = xegpu.update_offset %tdesc, %delta
+ : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
+
+ %st_vec = arith.addf %ld_vec, %ld_vec : vector<4x32xf32>
+ xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>:
+ vector<4x32xf32>,
+ !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>,
+ vector<32xi1>
+
+ gpu.return
+ }
+}
+
+
// -----
gpu.module @test_kernel {
No need for a gpu.module or gpu.func for testing purposes; we just use func.func. We use gpu.module only if some GPU op is necessary for the test case (like thread IDs, etc.). Is that the case here?
The test doesn't require a GPU op, but currently almost all test files use this style. We may consider changing it for all test cases in a separate PR.
Can someone clarify the difference between XeGPU blocking and unrolling? Why do we need both?
The unrolling pass just tests the patterns, so its blocking strategy can be very naive. The XeGPU blocking pass is built on top of the unrolling patterns with a more sophisticated blocking strategy. @chencha3
XeGPU::UnrollPattern follows the design of vector::UnrollPattern. They provide a very generic interface, verifying that targetShape is valid and then unrolling the op based on it. xegpu-blocking is built on top of it: it simply focuses on passing the inst_data to the targetShape of the unroll pattern, with some logic for robustness.
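In other words, the blocking pass is mostly a driver around the generic unroll patterns. A minimal sketch of that wiring is below, assuming the upstream xegpu::UnrollOptions / populateXeGPUUnrollPatterns API; getInstDataTileShape is a hypothetical stand-in for the pass's own getTileShape logic shown in the diff above.

```cpp
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>

using namespace mlir;

// Hypothetical stand-in for XeGPUBlockingPass::getTileShape: derive the
// per-instruction tile shape from the op's inst_data layout, or return
// std::nullopt to leave the op untouched.
static std::optional<SmallVector<int64_t>> getInstDataTileShape(Operation *op) {
  return std::nullopt; // the real pass reads inst_data from the layout attribute
}

// The blocking pass feeds inst_data-derived tile shapes to the unroll
// patterns as their target ("native") shape.
static void populateBlockingPatterns(RewritePatternSet &patterns) {
  xegpu::UnrollOptions options;
  options.setNativeShapeFn(
      [](Operation *op) { return getInstDataTileShape(op); });
  xegpu::populateXeGPUUnrollPatterns(patterns, options);
}
```

On top of this generic machinery, the pass adds the op-specific tile-shape queries and the tensor_desc type adjustment shown in the diff.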
Attribute encoding = tdescTy.getEncoding();
// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
maybe use isScattered?
// potentially adjust the chunk size based on the inst_data.
if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
  auto scatterAttr =
      mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
Use getEncodingAsScatterTensorDescAttr?
Does this read better? I could use this code sequence.
Attribute encoding = tdescTy.getEncoding();
// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
if (tdescTy.isScattered()) {
auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
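Putting the pieces of this thread together, the whole adjustment would look roughly like the sketch below. It is an illustration, not the final committed code: getBlockedTensorDescType is a hypothetical helper name, and it assumes the accessors above plus a ScatterTensorDescAttr::get overload that takes an integer chunk size (fragment in the same context as XeGPUBlocking.cpp).

```cpp
// Hypothetical helper wrapping the chunk-size adjustment from the diff, using
// the accessors suggested in this thread.
static xegpu::TensorDescType
getBlockedTensorDescType(MLIRContext *ctx, xegpu::TensorDescType tdescTy,
                         ArrayRef<int64_t> tileShape) {
  Attribute encoding = tdescTy.getEncoding();
  // For scattered descriptors, the chunk size may need to shrink to the last
  // inst_data dimension.
  if (tdescTy.isScattered()) {
    auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
    int64_t chunkSize = scatterAttr.getChunkSize().getInt();
    if (chunkSize > 1) {
      int64_t blockedChunkSize = chunkSize;
      auto instData = tdescTy.getLayoutAttr().getInstData();
      if (!instData.empty())
        blockedChunkSize = instData.asArrayRef().back();
      encoding = xegpu::ScatterTensorDescAttr::get(
          ctx, scatterAttr.getMemorySpace(), blockedChunkSize);
    }
  }
  return xegpu::TensorDescType::get(ctx, tileShape, tdescTy.getElementType(),
                                    encoding,
                                    tdescTy.getLayoutAttr().dropInstData());
}
```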
gpu.func @test_prefetch_load_store_update(%src: ui64) {

  %cst = arith.constant dense<[
What happens with constants? Do they get blocked in this pass?
Yes.
// potentially adjust the chunk size based on the inst_data.
if (tdescTy.isScattered()) {
  auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
  // mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
Nit: this line can be removed. But I like to use mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding) here; it is clearer about what happens to the encoding.
Changed back to using "dyn_cast_if_present".
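That is, roughly the following fragment (same context as the diff above; llvm::dyn_cast_if_present yields a null attribute for a missing or non-scatter encoding, so no separate null check is needed):

```cpp
// Sketch of the reverted check using dyn_cast_if_present.
if (auto scatterAttr =
        llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding)) {
  int64_t chunkSize = scatterAttr.getChunkSize().getInt();
  // ... block the chunk size based on inst_data, as in the diff above ...
}
```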
int64_t blockedChunkSize = chunkSize;
auto instData = tdescTy.getLayoutAttr().getInstData();
if (!instData.empty())
  blockedChunkSize = instData.asArrayRef().back();
Nit: maybe we can reuse the chunkSize variable here, so line 311 can be removed.
The code purposely introduces the variable name "blockedChunkSize". It makes the code easy to understand: the logic here is to block the chunk size and use the blocked one (e.g. chunk_size = 4 with inst_data = [16, 2] becomes a blocked chunk size of 2, as in the test below).
// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
    ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
Nit: there is no need to create chunkSizeAttr; you can simply pass blockedChunkSize here.
fixed.
✅ With the latest revision this PR passed the C/C++ code formatter.