Skip to content

Commit 9c697b3

Browse files
authored
[MLIR][XeGPU] Update the type of offsets for CreateDescOp and UpdateOffsetOp (#110741)
This PR changes the type of `offsets` operand of CreateDescOp and UpdateOffsetOp to 1D Vector of index, for convenience of users.
1 parent c2601f1 commit 9c697b3

File tree

5 files changed

+170
-78
lines changed

5 files changed

+170
-78
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
14+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1415
#include "mlir/IR/BuiltinTypes.h"
1516
#include "mlir/IR/Dialect.h"
1617
#include "mlir/IR/TypeUtilities.h"

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 65 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -424,9 +424,9 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
424424
It accepts the following parameters:
425425

426426
* source: a 1D memref or pointer (uint64_t) represents the flattened memory object.
427-
* offsets: a array containing offsets of each access point. Its size
427+
* offsets: a vector containing offsets of each access point. Its size
428428
is fixed to the hardware supported subgroup size, e.g., 16 on PVC,
429-
implying each element in the array corresponds to a work-item (SIMT lane)
429+
implying each element in the vector corresponds to a work-item (SIMT lane)
430430
in the subgroup.
431431

432432
The first dimension of the result TensorDesc corresponds to work-items, so it should
@@ -436,56 +436,59 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
436436
Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
437437
```mlir
438438
%a = memref.alloc() : memref<1024xf32>
439-
%1 = xegpu.create_tdesc %a[0, 16, 32, 64]: memref<1024xf32> -> TensorDesc<4xf32>
439+
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
440+
%1 = xegpu.create_tdesc %a, %0: memref<1024xf32>, vector<4xindex> -> TensorDesc<4xf32>
440441
```
441442

442443
Example 2. It assumes subgroup size is 4, and each work-item accesses 8 elements.
443444
It will access totally 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71]
444445
```mlir
445446
%0 = memref.alloc() : memref<1024xf32>
446-
%1 = xegpu.create_tdesc %0[0, 16, 32, 64] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size = 8>
447+
%off = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
448+
%1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
449+
-> TensorDesc<4x8xf32, #xegpu.scattered_tdesc_attr<chunk_size = 8>>
447450
```
448451

449452
Example 3. It is similar to Example 2, but there is some overlaps among workitems.
450453
It accesses: a[0:7], a[4:11], a[8:15], a[12:19]
451454
```mlir
452455
%0 = memref.alloc() : memref<1024xf32>
453-
%1 = xegpu.create_tdesc %0[0, 4, 8, 12] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size = 8>>
456+
%off = arith.constant dense<[0, 4, 8, 12]> : vector<4xindex>
457+
%1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
458+
-> TensorDesc<4x8xf32, #xegpu.scattered_tdesc_attr<chunk_size = 8>>
454459
```
455460
}];
456461

457462
let arguments = (ins XeGPU_BaseAddrType: $source,
458-
Variadic<Index>: $offsets,
459-
DenseI64ArrayAttr: $const_offsets);
463+
XeGPU_OffsetType: $offsets);
460464
let results = (outs XeGPU_TensorDesc:$TensorDesc);
461465

466+
let builders = [
467+
OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "mlir::Value": $source,
468+
"llvm::ArrayRef<OpFoldResult>": $offsets)>,
469+
OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "mlir::Value": $source,
470+
"llvm::ArrayRef<int64_t>": $offsets)>,
471+
];
472+
462473
let assemblyFormat = [{
463-
$source
464-
custom<DynamicIndexList>($offsets, $const_offsets)
465-
attr-dict `:` type($source) `->` qualified(type($TensorDesc))
474+
$source `,` $offsets attr-dict `:` type($source) `,` type($offsets) `->` qualified(type($TensorDesc))
466475
}];
467476

468-
let extraClassDeclaration = extraBaseClassDeclaration # [{
477+
let extraClassDeclaration = [{
469478
xegpu::TensorDescType getTensorDescType() {
470479
return getTensorDesc().getType();
471480
}
472481

473-
SmallVector<OpFoldResult> getMixedOffsets() {
474-
Builder b(getContext());
475-
return getMixedValues(getConstOffsets(), getOffsets(), b);
482+
mlir::VectorType getOffsetsType() {
483+
return getOffsets().getType();
476484
}
477485

478486
size_t getNumOffsets() {
479-
return getMixedOffsets().size();
487+
return getOffsetsType().getNumElements();
480488
}
481489

482490
mlir::Value getViewSource() { return getSource(); }
483491

484-
OpFoldResult getOffset(unsigned idx) {
485-
assert(idx < getNumOffsets() && "Invalid out of bound access.");
486-
return getMixedOffsets()[idx];
487-
}
488-
489492
unsigned getSourceMemorySpace() {
490493
auto srcTy = getSource().getType();
491494
if (auto memrefTy = llvm::dyn_cast<mlir::MemRefType>(srcTy)) {
@@ -550,24 +553,33 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
550553
describes the data being loaded at the subgroup level, so its size is
551554
consistent with the number of work-items in a subgroup. When the chunk size
552555
is larger than 2, the output vector is a 2D vector, with dim-1 corresponding
553-
to work-items, and dim-0 corresponding to the chunk_size loaded by each work-item.
556+
to work-items, and dim-0 corresponding to the chunk size loaded by each work-item.
554557
Specially, there is a transpose effect on the result (as compared to the TensorDesc)
555558
due to the hardware implementation. Therefore, a transpose attribute is introduced
556559
on purpose, making sure users are aware of this implicit transformation.
557560

558561
The mask operand masks out memory access so that it is safe to pass out-of-boundary
559562
addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
560563

561-
Example:
564+
Example 1:
562565
```mlir
563-
%2 = xegpu.load %1, %0 {transpose,
564-
l1_hint = #xegpu.cache_hint<cached>,
566+
%2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
565567
l2_hint = #xegpu.cache_hint<uncached>,
566568
l3_hint = #xegpu.cache_hint<uncached>}
567569
: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space=global>>,
568570
vector<16xi1> -> vector<16xf32>
569571
```
570572

573+
Example 2:
574+
```mlir
575+
%2 = xegpu.load %1, %0 {transpose,
576+
l1_hint = #xegpu.cache_hint<cached>,
577+
l2_hint = #xegpu.cache_hint<uncached>,
578+
l3_hint = #xegpu.cache_hint<uncached>}
579+
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
580+
vector<16xi1> -> vector<8x16xf32>
581+
```
582+
571583
}];
572584

573585
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
@@ -610,17 +622,27 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementCountsMatch<["value", "T
610622
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
611623
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
612624
a 2D vector instead. For the latter case, dim-1 of the value corresponds to the SIMD lanes
613-
and the dim-0 of the value corresponds to the chunk_size stored per lane. So `store_scatter`
625+
and the dim-0 of the value corresponds to the chunk size stored per lane. So `store_scatter`
614626
has transpose effect, which is similar to `load_gather`. Therefore, a transpose attribute is
615627
introduced on purpose, making sure users are aware of this implicit transformation.
616628

617-
Example:
629+
Example 1:
618630
```mlir
619631
%3 = xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
620632
l2_hint = #xegpu.cache_hint<write_back>,
621633
l3_hint = #xegpu.cache_hint<write_through>}
622-
: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered=true>>, vector<16xi1>
634+
: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered_tdesc_attr<>>, vector<16xi1>
635+
```
636+
637+
Example 2:
638+
```mlir
639+
%3 = xegpu.store %0, %1, %2 {transpose,
640+
l1_hint = #xegpu.cache_hint<uncached>,
641+
l2_hint = #xegpu.cache_hint<write_back>,
642+
l3_hint = #xegpu.cache_hint<write_through>}
643+
: vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
623644
```
645+
624646
}];
625647

626648
let arguments = (ins
@@ -666,40 +688,39 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
666688

667689
Example:
668690
```mlir
669-
%2 = xegpu.update_offset %1, [32, 32, 32, 32]
670-
: !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
691+
%off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
692+
%2 = xegpu.update_offset %1, %off :
693+
!xegpu.tensor_desc<4x2xf32, #xegpu.scattered_tdesc_attr<>>, vector<4xindex>
671694
```
672695
}];
673696

674697
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
675-
Variadic<Index>: $offsets,
676-
DenseI64ArrayAttr: $const_offsets);
698+
XeGPU_OffsetType: $offsets);
677699
let results = (outs XeGPU_TensorDesc: $result);
678700

679-
let extraClassDeclaration = extraBaseClassDeclaration # [{
701+
let builders = [
702+
OpBuilder<(ins "mlir::Value": $TensorDesc,
703+
"llvm::ArrayRef<OpFoldResult>": $offsets)>,
704+
OpBuilder<(ins "mlir::Value": $TensorDesc,
705+
"llvm::ArrayRef<int64_t>": $offsets)>
706+
];
707+
708+
let extraClassDeclaration = [{
680709
xegpu::TensorDescType getTensorDescType() {
681710
return getTensorDesc().getType();
682711
}
683712

684-
SmallVector<OpFoldResult> getMixedOffsets() {
685-
Builder b(getContext());
686-
return getMixedValues(getConstOffsets(), getOffsets(), b);
713+
mlir::VectorType getOffsetsType() {
714+
return getOffsets().getType();
687715
}
688716

689717
size_t getNumOffsets() {
690-
return getMixedOffsets().size();
691-
}
692-
693-
OpFoldResult getOffset(unsigned idx) {
694-
assert(idx < getNumOffsets() && "Invalid out of bound access.");
695-
return getMixedOffsets()[idx];
718+
return getOffsetsType().getNumElements();
696719
}
697720
}];
698721

699722
let assemblyFormat = [{
700-
$TensorDesc `,`
701-
custom<DynamicIndexList>($offsets, $const_offsets)
702-
attr-dict `:` qualified(type($TensorDesc))
723+
$TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` type($offsets)
703724
}];
704725
}
705726

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Arith/Utils/Utils.h"
910
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1011
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1112
#include "mlir/IR/Builders.h"
@@ -308,6 +309,24 @@ LogicalResult UpdateNdOffsetOp::verify() {
308309
// XeGPU_CreateDescOp
309310
//===----------------------------------------------------------------------===//
310311

312+
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
313+
TensorDescType TensorDesc, Value source,
314+
llvm::ArrayRef<OpFoldResult> offsets) {
315+
auto loc = source.getLoc();
316+
int64_t size = static_cast<int64_t>(offsets.size());
317+
auto type = VectorType::get(size, builder.getIndexType());
318+
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
319+
auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
320+
build(builder, state, TensorDesc, source, offset);
321+
}
322+
323+
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
324+
TensorDescType TensorDesc, Value source,
325+
llvm::ArrayRef<int64_t> offsets) {
326+
auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
327+
build(builder, state, TensorDesc, source, ofrs);
328+
}
329+
311330
LogicalResult CreateDescOp::verify() {
312331
auto tdescTy = getTensorDescType();
313332

@@ -473,6 +492,29 @@ LogicalResult StoreScatterOp::verify() {
473492

474493
return success();
475494
}
495+
496+
//===----------------------------------------------------------------------===//
497+
// XeGPU_UpdateOffsetOp
498+
//===----------------------------------------------------------------------===//
499+
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
500+
mlir::Value tensorDesc,
501+
llvm::ArrayRef<OpFoldResult> offsets) {
502+
auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
503+
assert(tdescTy && "Expecting the source is a TensorDescType value.");
504+
auto loc = tensorDesc.getLoc();
505+
int64_t size = static_cast<int64_t>(offsets.size());
506+
auto type = VectorType::get({size}, builder.getIndexType());
507+
auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
508+
auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
509+
build(builder, state, tdescTy, tensorDesc, offset);
510+
}
511+
512+
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
513+
Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
514+
auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
515+
build(builder, state, tensorDesc, ofrs);
516+
}
517+
476518
//===----------------------------------------------------------------------===//
477519
// XeGPU_DpasOp
478520
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)