
[MLIR][XeGPU] Add sg_map attribute to support Work Item level semantics #108864


Merged (6 commits) on Oct 2, 2024
32 changes: 32 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -142,4 +142,36 @@ def XeGPU_FenceScopeAttr:
let assemblyFormat = "$value";
}

def XeGPU_SGMapAttr : XeGPUAttr<"SGMap", "sg_map"> {
let summary = [{
    Describes the mapping between work items (WI) and the 2D tensor specified by the tensor descriptor.
}];
let description = [{
To distribute the XeGPU operation to work items, the tensor_desc must be specified with the sg_map
attribute at the tensor description creation time.
Within the `sg_map`, `wi_layout` specifies the layout of work items,
describing the mapping of work items to the tensor.
wi_layout[0] x wi_layout[1] must be equal to the total number of work items within a subgroup.
`wi_data` specifies the minimum number of data elements assigned to each work item for a single distribution.

E.g., #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
In this example, the subgroup has 16 work items in wi_layout=[1, 16],
each accessing 1 element as specified by wi_data=[1, 1].

    `wi_data[0] * wi_data[1]` can be greater than 1, meaning that each work item operates on multiple elements,
    which are eventually lowered to a "SIMT-flavor" vector, such as a SPIR-V or LLVM vector, or packed into a storage data type.
    The multiple elements indicated by `wi_data` must come from a single dimension and be contiguous in memory along that dimension.
}];
let parameters = (ins
ArrayRefParameter<"uint32_t">:$wi_layout,
ArrayRefParameter<"uint32_t">:$wi_data);

let builders = [
AttrBuilder<(ins)>
];

let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
}

#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
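The distribution semantics that `wi_layout` and `wi_data` describe can be made concrete with a small sketch. The helper below is illustrative only (the function name and the round-robin chunk assignment are assumptions, not something this patch defines): it lists the elements of a 2D tensor that one work item touches in a single distribution.

```python
def wi_fragment(shape, wi_layout, wi_data, wi_id):
    """List the (row, col) elements owned by work item `wi_id` for one
    distribution of a 2D tensor under an xegpu.sg_map.

    Illustrative sketch: assumes a row-major linearization of wi_layout
    and a cyclic assignment of wi_data-sized chunks across the tensor.
    """
    assert len(shape) == len(wi_layout) == len(wi_data) == 2
    # Position of this work item inside the wi_layout grid (row-major).
    pos = (wi_id // wi_layout[1], wi_id % wi_layout[1])
    elems = []
    # Each WI owns wi_data-sized chunks, strided by the full WI grid.
    for i in range(pos[0] * wi_data[0], shape[0], wi_layout[0] * wi_data[0]):
        for j in range(pos[1] * wi_data[1], shape[1], wi_layout[1] * wi_data[1]):
            for di in range(wi_data[0]):
                for dj in range(wi_data[1]):
                    elems.append((i + di, j + dj))
    return elems

# 8x16 tensor with sg_map<wi_layout = [1, 16], wi_data = [1, 1]>:
# work item 0 owns column 0, one element per row.
print(wi_fragment((8, 16), (1, 16), (1, 1), 0))
```

With 16 work items each owning `8 * 16 / 16 = 8` elements, the whole tile is covered with no overlap, which is the invariant the `wi_layout` product requirement expresses.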
18 changes: 14 additions & 4 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -63,7 +63,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
element-type ::= float-type | integer-type | index-type
dim-list := (static-dim-list `x`)?
static-dim-list ::= decimal-literal `x` decimal-literal
attr-list = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
attr-list = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)? (, sg_map `<` wi_layout = value, wi_data = value `>`)?
```

Examples:
@@ -77,25 +77,31 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",

// A TensorDesc with 8x16 f32 elements for a memory region in shared memory space.
xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_space = slm>>

// A TensorDesc with a sg_map
xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
```
}];

let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
"mlir::Type": $elementType,
OptionalParameter<"mlir::Attribute">: $encoding);
OptionalParameter<"mlir::Attribute">: $encoding,
OptionalParameter<"mlir::Attribute">: $sg_map);
Contributor:

Is this also available for the scattered TensorDesc (created by create_tdesc), which has slightly different semantics than the blocked TensorDesc created by create_nd_tdesc? If so, are there any restrictions on it?

Contributor (Author):

> Is it also available to scattered TensorDesc (created by create_tdesc)

Yes, it produces the same type, so the map can (and should) be attached to the scattered variant as well. It works now, but I'll add a test to cover it.

> If so, is there any restrictions on it?

There are no additional restrictions on the layout; it only has to match the subgroup size. The `wi_data` field represents the minimum requirement for the distribution (e.g., to cover packing). We may want to bind it to the chunk size somehow, but I don't think that is necessary at the moment.
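For the scattered case discussed in this thread, the distribution can be sketched as follows. The helper and the one-offset-per-work-item assumption are illustrative, modeled on the test this PR adds (4x2xf32, chunk_size = 2, wi_layout = [4, 1]); nothing here is defined by the dialect itself.

```python
def scattered_fragments(offsets, chunk_size, wi_layout):
    """Sketch of how a scattered tensor_desc with an sg_map distributes:
    each work item in an [N, 1] wi_layout takes the chunk_size elements
    starting at its own offset. Assumes wi_layout[1] == 1 and one offset
    per work item, matching this PR's test case.
    """
    assert wi_layout[1] == 1 and len(offsets) == wi_layout[0]
    return {wi: [offsets[wi] + k for k in range(chunk_size)]
            for wi in range(wi_layout[0])}

# Offsets from the test: [0, 8, 16, 24], chunk_size = 2 — each of the
# four work items reads the two contiguous elements at its offset.
print(scattered_fragments([0, 8, 16, 24], 2, [4, 1]))
```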


let builders = [
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>": $shape,
"mlir::Type": $elementType,
CArg<"int", "1">: $array_length,
CArg<"bool", "true">: $boundary_check,
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>,
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>,
Comment on lines -93 to +98
Contributor (Author):

I had to slightly modify the signature since the two new parameters conflict with the default parameters (boundary_check gets implicitly cast to memory_space). I think it's better anyway.

TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>": $shape,
"mlir::Type": $elementType,
CArg<"int", "1">: $chunk_size,
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>
CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>
];

let extraClassDeclaration = [{
@@ -121,6 +127,10 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
}

SGMapAttr getSGMapAttr() const {
return llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
}

xegpu::MemorySpace getMemorySpace() const {
auto block_attr = getEncodingAsBlockTensorDescAttr();
if (block_attr && block_attr.getMemorySpace())
110 changes: 98 additions & 12 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -55,6 +55,77 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
return Base::get(context, scopeAttr, chunkSizeAttr);
}

//===----------------------------------------------------------------------===//
// XeGPU_SGMapAttr
//===----------------------------------------------------------------------===//
namespace {
template <typename T, unsigned N>
LogicalResult parseIntArrayField(::mlir::AsmParser &parser,
llvm::SmallVector<T, N> &result,
llvm::StringRef fieldName) {
if (failed(parser.parseKeyword(fieldName))) {
parser.emitError(parser.getCurrentLocation(),
"unexpected field name. Expected " + fieldName + ".");
return failure();
}

if (failed(parser.parseEqual())) {
parser.emitError(parser.getCurrentLocation(), "expected '=' sign.");
return failure();
}

auto elemParser = [&]() -> llvm::ParseResult {
uint32_t elem = 0;
auto res = parser.parseInteger(elem);
result.push_back(elem);
return res;
};

return parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
elemParser, fieldName);
}
} // namespace

mlir::Attribute SGMapAttr::parse(::mlir::AsmParser &parser,
::mlir::Type attrType) {
if (failed(parser.parseLess()))
return {};

llvm::SmallVector<uint32_t, 2> wi_layout, wi_data;
if (failed(parseIntArrayField(parser, wi_layout, "wi_layout")))
return {};

if (failed(parser.parseComma()))
return {};

if (failed(parseIntArrayField(parser, wi_data, "wi_data")))
return {};

return SGMapAttr::getChecked(
[&]() { return parser.emitError(parser.getNameLoc()); },
parser.getContext(), wi_layout, wi_data);
}

void SGMapAttr::print(::mlir::AsmPrinter &printer) const {
printer << "<";
printer.printKeywordOrString("wi_layout");
printer << " = [" << getWiLayout() << "], ";
printer.printKeywordOrString("wi_data");
printer << " = [" << getWiData() << "]";
printer << ">";
}

LogicalResult
SGMapAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
llvm::ArrayRef<uint32_t> wi_layout,
llvm::ArrayRef<uint32_t> wi_data) {
if (wi_layout.size() != 2)
return emitError() << "expected wi_layout of size 2";
if (wi_data.size() != 2)
return emitError() << "expected wi_data of size 2";
return success();
}
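A Python mirror of the verifier above makes the checked invariants explicit. The optional subgroup-size check is a hypothetical extension (raised in review), not part of the committed verifier, which only validates the array ranks.

```python
def verify_sg_map(wi_layout, wi_data, subgroup_size=None):
    """Sketch of SGMapAttr::verify. Returns an error string, or None on
    success. The committed verifier only checks that both arrays have
    exactly two entries; the subgroup-size check is a hypothetical
    extension discussed in review.
    """
    if len(wi_layout) != 2:
        return "expected wi_layout of size 2"
    if len(wi_data) != 2:
        return "expected wi_data of size 2"
    if subgroup_size is not None and wi_layout[0] * wi_layout[1] != subgroup_size:
        return "wi_layout must cover the full subgroup"
    return None
```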

//===----------------------------------------------------------------------===//
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//
@@ -63,6 +134,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
llvm::SmallVector<int64_t> shape;
mlir::Type elementType;
mlir::FailureOr<mlir::Attribute> encoding;
mlir::FailureOr<mlir::Attribute> sg_map;

// Parse literal '<'
if (parser.parseLess())
@@ -81,22 +153,31 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
}

// parse optional attributes
if (mlir::succeeded(parser.parseOptionalComma())) {
encoding = mlir::FieldParser<mlir::Attribute>::parse(parser);
if (mlir::failed(encoding)) {
parser.emitError(
parser.getCurrentLocation(),
"Failed to parse the attribute field for TensorDescType.\n");
return {};
while (mlir::succeeded(parser.parseOptionalComma())) {
mlir::Attribute attr;
ParseResult res = parser.parseAttribute(attr);
if (mlir::succeeded(res)) {
if (mlir::isa<SGMapAttr>(attr)) {
sg_map = attr;
continue;
}
if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
encoding = attr;
continue;
}
}
parser.emitError(parser.getCurrentLocation(),
"Failed to parse the attribute.\n");
return {};
}

// Parse literal '>'
if (parser.parseGreater())
return {};

return TensorDescType::get(parser.getContext(), shape, elementType,
encoding.value_or(mlir::Attribute()));
encoding.value_or(mlir::Attribute()),
sg_map.value_or(mlir::Attribute()));
}
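The rewritten parse loop accepts the trailing attributes of a tensor_desc in any order, dispatching on the attribute kind. A sketch of that dispatch logic (plain strings stand in for the MLIR `isa<>` checks; none of this is real MLIR API):

```python
def classify_tdesc_attrs(attrs):
    """Route tensor_desc trailing attributes to their slots, mirroring
    the parse loop above: sg_map-like attributes and encoding-like
    attributes may appear in any order; anything else is an error.
    """
    encoding, sg_map = None, None
    for kind in attrs:
        if kind == "sg_map":
            sg_map = kind
        elif kind in ("block_tdesc_attr", "scatter_tdesc_attr"):
            encoding = kind
        else:
            raise ValueError("Failed to parse the attribute.")
    return encoding, sg_map
```

Either ordering — `#xegpu.block_tdesc_attr<...>, #xegpu.sg_map<...>` or the reverse — yields the same (encoding, sg_map) pair, which is the point of looping rather than parsing a fixed sequence.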

void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -116,25 +197,30 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
if (auto encoding = getEncoding())
printer << ", " << encoding;

if (auto sg_map = getSgMap())
printer << ", " << sg_map;

printer << ">";
}

TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
mlir::Type elementType, int array_length,
bool boundary_check,
MemorySpace memory_space) {
MemorySpace memory_space,
mlir::Attribute sg_map) {
auto context = elementType.getContext();
auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
boundary_check);
return Base::get(context, shape, elementType, attr);
return Base::get(context, shape, elementType, attr, sg_map);
}

TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
mlir::Type elementType, int chunk_size,
MemorySpace memory_space) {
MemorySpace memory_space,
mlir::Attribute sg_map) {
auto context = elementType.getContext();
auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
return Base::get(context, shape, elementType, attr);
return Base::get(context, shape, elementType, attr, sg_map);
}

} // namespace xegpu
22 changes: 22 additions & 0 deletions mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -13,6 +13,14 @@ gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
gpu.return
}

// CHECK: gpu.func @test_create_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
gpu.return
}

// CHECK: gpu.func @test_create_nd_tdesc_vc_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
//CHECK: %[[C:.*]] = arith.constant 1 : index
@@ -43,6 +51,13 @@ gpu.func @test_create_nd_tdesc_vc_5(%src: memref<2x24x32xf32, 3>) {
gpu.return
}

// CHECK: gpu.func @test_create_nd_tdesc_vc_6(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @test_create_nd_tdesc_vc_6(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
gpu.return
}

// CHECK: gpu.func @test_prefetch_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -116,6 +131,13 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
gpu.return
}

// CHECK: gpu.func @test_create_tdesc_vc_with_sg_map(%[[arg0:.*]]: ui64) {
gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
%1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
gpu.return
}

// CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
gpu.func @test_prefetch_vc(%src: ui64) {
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>