[MLIR][XEGPU] Add blocking support for scatter ops #144766
@@ -134,11 +134,13 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {

 std::optional<SmallVector<int64_t>>
 XeGPUBlockingPass::getTileShape(Operation *op) const {
-  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
+  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
+          xegpu::UpdateOffsetOp>(op))
     return getTileShape(op->getOpResult(0));
-  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
+  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
+          xegpu::LoadGatherOp>(op))
     return getTileShape(op->getOpOperand(0));
-  if (isa<xegpu::StoreNdOp>(op))
+  if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
     return getTileShape(op->getOpOperand(1));

   if (isa<xegpu::DpasOp>(op)) {
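In IR terms, the tile shape for the new scatter ops is read from whichever value carries the inst_data layout. A short illustration using the ops exercised by the new tests below (%src, %offsets, %val, and %mask are placeholder names, not taken from the patch):

// create_tdesc / update_offset: tile shape comes from result #0.
%tdesc = xegpu.create_tdesc %src, %offsets : ui64, vector<32xindex>
    -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
// prefetch / load: tile shape comes from operand #0, the descriptor.
xegpu.prefetch %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
// store: tile shape comes from operand #1, the descriptor (operand #0 is the stored value).
xegpu.store %val, %tdesc, %mask : vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>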
@@ -295,12 +297,36 @@ void XeGPUBlockingPass::runOnOperation() {
     Type elemTy = type.getElementType();
     Type newTy;

-    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
-      newTy = xegpu::TensorDescType::get(
-          ctx, tileShape, elemTy, tdescTy.getEncoding(),
-          tdescTy.getLayoutAttr().dropInstData());
-    else
+    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
+
+      Attribute encoding = tdescTy.getEncoding();
+      // If the encoding is a ScatterTensorDescAttr, we need to
+      // potentially adjust the chunk size based on the inst_data.
+      if (tdescTy.isScattered()) {
+        auto scatterAttr =
+            llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
+        int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+
+        if (chunkSize > 1) {
+          int64_t blockedChunkSize = chunkSize;
+          auto instData = tdescTy.getLayoutAttr().getInstData();
+          if (!instData.empty())
+            blockedChunkSize = instData.asArrayRef().back();

Review comment: nit: maybe we can reuse the chunkSize variable here, so line 311 can be removed.

Reply: The code is written purposely to introduce the variable name "blockedChunkSize". It makes the code easy to understand: the logic here is to block the chunk size and use the blocked one.

+
+          // To create a new attribute with a different chunk_size:
+          auto newEncoding = xegpu::ScatterTensorDescAttr::get(
+              ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
+
+          encoding = newEncoding;
+        }
+      }
+
+      newTy =
+          xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
+                                     tdescTy.getLayoutAttr().dropInstData());
+    } else {
       newTy = type.clone(tileShape, elemTy);
+    }

     std::optional<SmallVector<int64_t>> ratio =
         computeShapeRatio(type.getShape(), tileShape);
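As a worked example of the chunk-size adjustment above, using the descriptor type from the new chunked test below: the blocked chunk size is the last element of inst_data, the inst_data layout itself is dropped, and the tile shape becomes the descriptor shape.

// Before blocking: 32 offsets x chunk of 4, instruction tile of 16 x 2.
!xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4 : i64>, #xegpu.layout<inst_data = [16, 2]>>
// After blocking: tile shape [16, 2], chunk_size rewritten to 2, inst_data dropped.
!xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>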
@@ -250,8 +250,7 @@ gpu.module @test_kernel {
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
 #r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %acc = arith.constant dense<0.0> : vector<64xf32>
     %c64 = arith.constant 64 : index
@@ -271,8 +270,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
 #r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %c1 = arith.constant 1 : index
     %c32 = arith.constant 32 : index
@@ -299,8 +297,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #r = #xegpu.layout<inst_data = [16]>
 #l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {

     %c64 = arith.constant 64 : index
@@ -319,8 +316,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #r = #xegpu.layout<inst_data = [16]>
 #l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {

     %c32 = arith.constant 32 : index
@@ -340,8 +336,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #l = #xegpu.layout<inst_data = [16, 8]>
 #t = #xegpu.layout<inst_data = [8, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
   gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {

     %c32 = arith.constant 32 : index
@@ -355,4 +350,100 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
     xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
     gpu.return
   }
 }

// -----
Review comment: There is no need for a gpu.module or gpu.func for testing purposes; we just use func.func. We use gpu.module if some gpu op is necessary for the test case (like thread ids etc.). Is that the case here?

Reply: The test doesn't require a gpu op, but currently almost all test files use this style. We may consider changing it for all test cases in a separate PR.
gpu.module @test_kernel {
  // CHECK-LABEL: test_prefetch_load_store_update
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
  // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
  // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
  // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>

  gpu.func @test_prefetch_load_store_update(%src: ui64) {
Review comment: What happens with constants? Do they get blocked in this pass?

Reply: Yes.
    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
      128, 136, 144, 152, 160, 168, 176, 184,
      192, 200, 208, 216, 224, 232, 240, 248
    ]> : vector<32xindex>

    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

    %delta = arith.constant dense<[
      32, 32, 32, 32, 32, 32, 32, 32,
      32, 32, 32, 32, 32, 32, 32, 64,
      128, 128, 128, 128, 128, 128, 128, 128,
      128, 128, 128, 128, 128, 128, 128, 256
    ]> : vector<32xindex>
    %new_tdesc = xegpu.update_offset %tdesc, %delta
        : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>

    %c17 = arith.constant 17: index
    %mask = vector.create_mask %c17: vector<32xi1>

    %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>

    %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
    xegpu.store %st_vec, %tdesc, %mask:
        vector<32xf32>,
        !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
        vector<32xi1>

    gpu.return
  }
}
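On the question above about constants: a hand-written sketch of the blocked form of the offset constant and its descriptors, matching the CHECK-COUNT-2 lines in this test. The SSA names and the exact way the pass materializes the split constants are assumptions, not copied pass output.

// The vector<32xindex> offsets are split into one vector<16xindex> constant per
// 16-lane tile, and each tile gets its own descriptor.
%cst_lo = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56,
                                64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
%cst_hi = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184,
                                192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
%tdesc_lo = xegpu.create_tdesc %src, %cst_lo : ui64, vector<16xindex>
    -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
%tdesc_hi = xegpu.create_tdesc %src, %cst_hi : ui64, vector<16xindex>
    -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>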

// -----

gpu.module @test_kernel {
  // CHECK-LABEL: test_prefetch_load_store_update_chunk
  // CHECK-SAME: [[arg0:%.+]]: ui64
  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
  // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
  // CHECK-COUNT-4: xegpu.store {{.*}} : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>

  gpu.func @test_prefetch_load_store_update_chunk(%src: ui64) {

    %cst = arith.constant dense<[
      0, 8, 16, 24, 32, 40, 48, 56,
      64, 72, 80, 88, 96, 104, 112, 120,
      128, 136, 144, 152, 160, 168, 176, 184,
      192, 200, 208, 216, 224, 232, 240, 248
    ]> : vector<32xindex>

    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>

    %delta = arith.constant dense<[
      32, 32, 32, 32, 32, 32, 32, 32,
      32, 32, 32, 32, 32, 32, 32, 64,
      128, 128, 128, 128, 128, 128, 128, 128,
      128, 128, 128, 128, 128, 128, 128, 256
    ]> : vector<32xindex>
    %new_tdesc = xegpu.update_offset %tdesc, %delta
        : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>

    %c17 = arith.constant 17: index
    %mask = vector.create_mask %c17: vector<32xi1>

    %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>

    %st_vec = arith.addf %ld_vec, %ld_vec : vector<4x32xf32>
    xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>:
        vector<4x32xf32>,
        !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>,
        vector<32xi1>

    gpu.return
  }
}
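For the chunked test, the CHECK-COUNT-4 counts follow from the shape ratio the pass computes: computeShapeRatio([32, 4], [16, 2]) = [2, 2], i.e. two tiles along the offsets times two along the chunk, so each op is blocked into four instances. Because the load and store use transpose, each blocked vector is vector<2x16xf32> rather than vector<16x2xf32>.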
Review comment: Can someone clarify the difference between xegpu blocking and unrolling? Why do we need two?

Reply: The unrolling is there to test the patterns, so its blocking strategy can be very naive. The xegpu-blocking pass is built on top of the unrolling patterns with a more sophisticated blocking strategy. @chencha3

Reply: XeGPU::UnrollPattern follows the design of vector::UnrollPattern. It provides a very generic interface, verifying that the targetShape is valid and then unrolling the op based on it. xegpu-blocking is built on top of it; it simply focuses on passing the inst_data to the targetShape of the unroll patterns, with some logic for robustness.
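To make that relationship concrete, here is a rough sketch of what feeding inst_data to the unroll target shape means for a simple nd load. The shapes and op choice are illustrative and not taken from this patch or its tests.

// The layout asks for 16x16 instructions on a 32x32 descriptor.
%0 = xegpu.create_nd_tdesc %a[0, 0] : memref<32x32xf32>
    -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<inst_data = [16, 16]>>
%1 = xegpu.load_nd %0
    : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<32x32xf32>
// xegpu-blocking reads inst_data = [16, 16] and hands it to the unroll patterns as the
// target shape, so the pair above is rewritten into four create_nd_tdesc/load_nd pairs
// on !xegpu.tensor_desc<16x16xf32>, each producing a vector<16x16xf32>.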