
Commit 6c783e1

[MLIR][XeGPU] Refine XeGPU definitions (#100763)
This PR makes the following changes/fixes to the XeGPU definition:
- Fix the type print format for atomic_rmw
- Remove 2D support for MaskType
- Update the LoadNd definition
- Add 1D TensorDesc support
- Replace the vnni_axis attribute with a packed attribute
- Update the DPAS op definition, limiting A to a 2D vector and B to either a 2D or 3D vector
1 parent e96687a commit 6c783e1

5 files changed: +146 −90 lines


mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 43 additions & 37 deletions
@@ -53,47 +53,56 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
   let summary = "Create nd-tensor descriptor operation";
   let description = [{
     The "create_nd_tdesc" operation creates a TensorDescType which represents
-    a sub-view of a 2D memory region (It can be extended to support n-D memory
-    region if needed in future). Elements in the subview continuous in each
-    dimension. It encodes the following important information for supporting
-    Intel hardware features:
-
-    * source: an object representing (starting address/pointer of) a 2D memory region.
-      It can be either a 2D memref object, or simply a pointer represented by uint64_t type.
-      for the later case, the shape and layout information of the 2D memory region should
-      be explicitly passed via `shape` and `strides` parameters.
-    * offsets: two index values represents offsets from the "source" at the each dimension
-      at which the subview of the target memory will be created. It is encoded via two
-      variables, including "offsets" and "const_offsets", such that it can
-      accept various forms, such as, operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).
-    * shape: the shape information of the memory region pointed by the "source". It is
-      typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>.
+    a sub-view of a 1D/2D memory region inside the one or two innermost dimensions
+    of the source. (It can be extended to support n-D memory region if needed in
+    future). Elements in the subview continuous in each dimension. It encodes the
+    following important information for supporting Intel hardware features:
+
+    * source: an object representing (starting address/pointer of) a memory region.
+      It can be either a memref object, or simply a pointer represented by uint64_t type.
+      For the case of dynamic memrefs or pointer, the shape and layout information of the
+      memory region should be explicitly passed via `shape` and `strides` parameters.
+
+    * offsets: index values represents offsets from the "source" at the each dimension
+      at which the subview of the target memory will be created. It is encoded via
+      "offsets" and "const_offsets", such that it can accept various forms, such as,
+      operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).
+
+    * shape: the shape information of the memory region pointed by the "source". It is
+      typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>.
     But if "source" is simply a pointer represented as uint64_t type, or a memref
     type without shape information e.g., memref<?x?xf16>, the shape information has
     to be explicitly passed via the "shape" and "const_shape" arguments.
+
     * strides: the strides of the memory region pointed by the "source". Similar to shape,
       it is typically encoded via the MemRefType of the source too. But if "source" is
       simply a pointer represented as uint64_t type, or a memref type without shape
       information e.g., memref<?x?xf16>, the strides information has to be explicitly
       passed via the "strides" and "const_strides" argument.

     Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
+    ```mlir
     %0 = memref.alloc() : memref<1024x1024xf32>
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
     %1 = xegpu.create_nd_tdesc %0[%c0, %c0]: memref<1024x1024xf32> -> TensorDesc<8x16xf32>
+    ```

     Example 2 (suppose the tensor shape inferred by the compiler is 8x16):
+    ```mlir
     %0 = memref.alloc(%h, %w) : memref<?x?xf32>
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
     %1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: memref<?x?xf32> -> TensorDesc<8x16xf32>
+    ```

     Example 3 (suppose the tensor shape inferred by the compiler is 8x16):
+    ```mlir
     %0 = ... : ui64
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
     %1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: ui64 -> TensorDesc<8x16xf32>
+    ```
   }];

   let arguments = (ins
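For illustration (not part of the diff), a minimal sketch of the newly allowed 1D case, mirroring Example 1 above; the `TensorDesc` shorthand and value names follow the examples in the description and are illustrative only:

```mlir
%0 = memref.alloc() : memref<1024xf32>
%c0 = arith.constant 0 : index
// A 1D sub-view of a 1D source is now accepted by create_nd_tdesc.
%1 = xegpu.create_nd_tdesc %0[%c0]: memref<1024xf32> -> TensorDesc<16xf32>
```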
@@ -219,7 +228,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
    memory regions to each level of the cache based on their cache policy.

    Example:
-   ```
+   ```mlir
      xegpu.prefetch_nd %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                                l2_hint = #xegpu.cache_hint<cached>,
                                l3_hint = #xegpu.cache_hint<cached>}
@@ -245,8 +254,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
  }


- def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "TensorDesc"]>,
-                                           AllElementCountsMatch<["value", "TensorDesc"]>]> {
+ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
    let summary = "loads a n-D block from memory (represented by TensorDesc)"
                  "to registers (represented by vector)";
    let description = [{
@@ -263,7 +271,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
    same time.

    Example:
-   ```
+   ```mlir
      xegpu.load_nd %1 {transpose = [1, 0],
                        l1_hint = #xegpu.cache_hint<cached>,
                        l2_hint = #xegpu.cache_hint<uncached>,
@@ -275,7 +283,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
    }];

    let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
-                        OptionalAttr<I64Attr>: $vnni_axis,
+                        OptionalAttr<UnitAttr>: $packed,
                         OptionalAttr<DenseI64ArrayAttr>: $transpose,
                         OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                         OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
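For illustration (not part of the diff), a hypothetical use of the new `packed` unit attribute that replaces `vnni_axis`; the value names (`%b`, `%btdesc`) and the exact attribute/type syntax are schematic:

```mlir
// Loading a 16x16xf16 B tile in VNNI (packed) layout: the trailing dimension is
// the VNNI factor 32 / 16 = 2, so the loaded value is an 8x16x2 vector.
%b = xegpu.load_nd %btdesc {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
```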
@@ -309,7 +317,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc
    Corresponding cache hint attribute will be masked.

    Example:
-   ```
+   ```mlir
      xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                             l2_hint = #xegpu.cache_hint<write_back>,
                             l3_hint = #xegpu.cache_hint<write_through>}
@@ -407,21 +415,21 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
    elements accessed for each offset, default is 1.

    Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
-   ```
+   ```mlir
    %a = memref.alloc() : memref<1024xf32>
    %1 = xegpu.create_tdesc %a[0, 16, 32, 64]: memref<1024xf32> -> TensorDesc<4xf32>
    ```

    Example 2. It assumes subgroup size is 4, and each workitem access 8 elements.
    It will access totally 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71]
-   ```
+   ```mlir
    %0 = memref.alloc() : memref<1024xf32>
    %1 = xegpu.create_tdesc %0[0, 16, 32, 64] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
    ```

    Example 3. It is similar to Example 2, but there is some overlaps among workitems.
    It accesses: a[0:7], a[4:11], a[8:15], a[12:19]
-   ```
+   ```mlir
    %0 = memref.alloc() : memref<1024xf32>
    %1 = xegpu.create_tdesc %0[0, 4, 8, 12] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
    ```
@@ -480,7 +488,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
    it works on scattered TensorDesc instead.

    Example:
-   ```
+   ```mlir
      xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<cached>,
                             l3_hint = #xegpu.cache_hint<cached>}
@@ -520,7 +528,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
    addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.

    Example:
-   ```
+   ```mlir
      %2 = xegpu.load %1, %0 {transpose = [1, 0],
                              l1_hint = #xegpu.cache_hint<cached>,
                              l2_hint = #xegpu.cache_hint<uncached>,
@@ -572,7 +580,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllShapesMatch<["value", "TensorDe
    It has similar semantic to `load_gather`.

    Example:
-   ```
+   ```mlir
      %3 = xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                                   l2_hint = #xegpu.cache_hint<write_back>,
                                   l3_hint = #xegpu.cache_hint<write_through>}
@@ -621,7 +629,7 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
    shifts for each work-item.

    Example:
-   ```
+   ```mlir
      %2 = xegpu.update_offset %1, [32, 32, 32, 32]
           : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
      ```
@@ -668,14 +676,12 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
    data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`,
    and `C/D: vector<8x16xf32>`. Besides the matrix size requirements, DPAS
    also requires A and B to be loaded with the required data layout. Specially,
-   VNNI layout is required for B operand. It is achieved via setting `vnni_axis = 0`
-   of the corresponding `load_nd` operator. To keep both operands as 3D vector,
-   operand A is loaded via setting `vnni_axis = 1` without impacting the
-   physical layouts change in register. Due to the VNNI transformation, A and B operands
-   are represented as 3D vector, with the last dimension representing the VNNI factor,
-   which is computed as `32/bit_width_of_elem_type`. Therefore, `A: vector<8x16xf16>`
-   is represented as `A: vector<8x8x2xf16>`, and `B: vector<16x16xf16>` is
-   represented as `B: vector<8x16x2xf16>`.
+
+   VNNI layout is required for B operand. It is achieved via adding `packed`
+   attribute to the `load_nd` operator. Due to the VNNI transformation, B operands
+   can be represented as a 3D vector, with the last dimension representing the VNNI
+   factor, which is computed as `32/bit_width_of_elem_type`. Thus, `B: vector<16x16xf16>`
+   can be represented as `B: vector<8x16x2xf16>`.

    Note: on PVC, the hardware can perform load with VNNI transformation when data
    element type is 16-bit or lower precision, taking 2 or 4 elements from
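For illustration (not part of the diff), a schematic DPAS use matching the shapes described above; the assembly format is not shown in this hunk, so operand names and type syntax are illustrative only:

```mlir
// A: plain 2D tile; B: packed (VNNI) 3D form produced by a `packed` load_nd.
// The VNNI factor for f16 is 32 / 16 = 2, so B's 16x16 tile appears as 8x16x2.
%d = xegpu.dpas %a, %b : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
```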
@@ -739,7 +745,7 @@ def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", [Pure,
    let assemblyFormat = [{
      $kind $tensorDesc `,` $mask `,` $value attr-dict `:`
-       type($tensorDesc) `,` type($mask) `,` type($value) `->` type($result)
+       qualified(type($tensorDesc)) `,` type($mask) `,` type($value) `->` type($result)
    }];
  }

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 2 additions & 2 deletions
@@ -16,10 +16,10 @@ include "mlir/IR/BuiltinTypes.td"
  def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
  def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
  def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
- def XeGPU_BaseAddrType: AnyTypeOf<[MemRefRankOf<[XeGPU_ScalarType], [1, 2]>, UI64, UI32, I64, I32]>;
+ def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
  def XeGPU_DpasOpType: VectorOfRankAndType<[2, 3], [XeGPU_ScalarType]>;
  def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
- def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1,2], [I1]>, I1]>;
+ def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
  def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
  def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>;
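For illustration (not part of the diff), what the narrowed `XeGPU_MaskType` now admits; the constant value is illustrative:

```mlir
// Masks are now scalar i1 or a 1D vector of i1 (one element per lane);
// 2D mask vectors are no longer valid.
%mask = arith.constant dense<[true, true, false, true]> : vector<4xi1>
```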
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 38 additions & 26 deletions
@@ -122,7 +122,7 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,

 LogicalResult CreateNdDescOp::verify() {
   auto rank = (int64_t)getMixedOffsets().size();
-  bool invalidRank = (rank != 2);
+  bool invalidRank = false;
   bool invalidElemTy = false;

   // check source type matches the rank if it is a memref.
@@ -133,17 +133,21 @@ LogicalResult CreateNdDescOp::verify() {
     invalidElemTy |= memrefTy.getElementType() != getElementType();
   }

-  // check result type matches the rank
-  invalidRank = (getType().getRank() != rank);
-
   // mismatches among shape, strides, and offsets are
   // already handeled by OffsetSizeAndStrideOpInterface.
   // So they are not check here.
   if (invalidRank)
     return emitOpError(
-        "Expecting the rank of shape, strides, offsets, "
-        "source memref type (if source is a memref) and TensorDesc "
-        "should match with each other. They currenlty are 2D.");
+        "Expecting the rank of shape, strides, offsets, and source (if source "
+        "is a memref) should match with each other.");
+
+  // check result TensorDesc rank
+  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);
+
+  if (invalidRank)
+    return emitOpError(
+        "Expecting the TensorDesc rank is up to 2 and not greater than the "
+        "ranks of shape, strides, offsets or the memref source.");

   if (invalidElemTy)
     return emitOpError("TensorDesc should have the same element "
@@ -182,8 +186,8 @@ LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();

-  if (tdescTy.getRank() != 2)
-    return emitOpError("Expecting a 2D TensorDesc.\n");
+  if (tdescTy.getRank() > 2)
+    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

   if (tdescTy.getScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -206,17 +210,28 @@ LogicalResult LoadNdOp::verify() {
   if (getTranspose()) {
     auto trans = getTranspose().value();
-    if (tdescShape.size() >= trans.size())
+
+    // Make sure the transpose value is valid.
+    bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) {
+      return t >= 0 && t < tdescTy.getRank();
+    });
+
+    if (valid)
       transpose(trans, tdescShape);
     else
       emitWarning("Invalid transpose attr. It is ignored.");
   }

-  if (getVnniAxis()) {
-    auto axis = getVnniAxis().value();
-    auto vnni_factor = valueShape.back();
-    tdescShape[axis] /= vnni_factor;
-    tdescShape.push_back(vnni_factor);
+  if (getPacked()) {
+    if (tdescTy.getRank() == 2) {
+      const int axis = 0;
+      auto vnni_factor = valueShape.back();
+      tdescShape[axis] /= vnni_factor;
+      tdescShape.push_back(vnni_factor);
+    } else {
+      return emitWarning("Invalid Packed Attr. It is ignored (available for 2D "
+                         "TensorDesc only).");
+    }
   }

   if (array_len > 1) {
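For illustration (not part of the diff), how the verifier's shape bookkeeping above plays out; the assembly is schematic:

```mlir
// transpose = [1, 0] permutes the TensorDesc shape before it is compared with
// the value type, so an 8x16 tile is expected to load as a 16x8 vector.
%v = xegpu.load_nd %td {transpose = [1, 0]} : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
// With `packed` on a 2D f16 descriptor, tdescShape[0] is divided by the VNNI
// factor (valueShape.back()) and the factor is appended: 16x16 -> 8x16x2.
```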
@@ -239,8 +254,8 @@ LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector

-  if (dstTy.getRank() != 2)
-    return emitOpError("Expecting a 2D TensorDesc.\n");
+  if (dstTy.getRank() > 2)
+    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

   if (dstTy.getScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -413,18 +428,15 @@ LogicalResult DpasOp::verify() {
   int64_t lhsRank = getLhsType().getRank();
   int64_t rhsRank = getRhsType().getRank();

-  if (lhsRank != rhsRank || lhsRank != 3)
-    return emitOpError(
-        "lhs and rhs rank does not match for dpas op, or their rank is not 3.");
-
-  if (getAcc() && getAccType() != getResultType())
-    return emitOpError("Accumulator and Result for dpas op should have the "
-                       "same type (both shape and element type).");
+  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
+    return emitOpError("expecting lhs to be a 2D vector, and rhs to be either "
+                       "2D or 3D (packed) vector.");

   auto lhsShape = getLhsType().getShape();
   auto rhsShape = getRhsType().getShape();
-  if (lhsShape[1] != rhsShape[0] || lhsShape[2] != rhsShape[2])
-    return emitOpError("K-dimension or vnni-factor mismatch.");
+  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
+  if (bK != lhsShape[1])
+    return emitOpError("K-dimension mismatch.");

   return success();
 }
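For illustration (not part of the diff), the K-dimension rule with concrete shapes; the dpas assembly is schematic:

```mlir
// 2D rhs: bK = rhsShape[0] = 16, which matches lhsShape[1] = 16.
%r0 = xegpu.dpas %a, %b2d : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
// 3D (packed) rhs: bK = rhsShape[0] * rhsShape[2] = 8 * 2 = 16, which also matches.
%r1 = xegpu.dpas %a, %b3d : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
```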
