@@ -10,6 +10,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <numeric>
 
 namespace mlir {
 namespace xegpu {
@@ -336,32 +337,30 @@ LogicalResult TensorDescType::verify(
 // [n_distribution_units, lane_data_size]
 FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
   auto layout = llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
-  // If no layout is provided, tensor desc is not used in SIMT mode.
-  if (!layout)
+  // This only works for a subgroup-level layout, which has only lane_layout
+  // and lane_data, and is used to distribute SIMD code into SIMT code.
+  if (!layout || !layout.isSgLayout())
     return failure();
 
   SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
   SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef());
   auto tdescShape = getShape();
 
-  auto laneDataSize = 1, sgSize = 1;
-  for (auto [laneDim, laneDataDim] : llvm::zip_equal(laneLayout, laneData)) {
-    laneDataSize *= laneDataDim;
-    sgSize *= laneDim;
-  }
+  // Compute sgSize by multiplying the elements of laneLayout,
+  // e.g. for a 2D layout, sgSize = laneLayout[0] * laneLayout[1];
+  // for a 1D layout, sgSize = laneLayout[0].
+  auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1,
+                                std::multiplies<int64_t>());
 
   // Case 1: regular loads/stores
   auto scatterAttr = getEncodingAsScatterTensorDescAttr();
   if (scatterAttr) {
     auto chunkSize = scatterAttr.getChunkSize().getInt();
     // Verify if the first dimension of the tensor descriptor shape is
     // distributable.
-    assert(tdescShape[0] % (laneLayout[0]) == 0 &&
+    assert(tdescShape[0] == laneLayout[0] &&
            "tensor descriptor shape is not distributable");
-    if (chunkSize > 1)
-      return VectorType::get({chunkSize / laneDataSize, laneDataSize},
-                             getElementType());
-    return VectorType::get({laneDataSize}, getElementType());
+    return VectorType::get({chunkSize}, getElementType());
   }
 
   // Case 2: block loads/stores
@@ -376,8 +375,7 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
   // tensorSize must be adjusted for array_length.
   tensorSize *= getArrayLength();
 
-  return VectorType::get({tensorSize / (sgSize * laneDataSize), laneDataSize},
-                         getElementType());
+  return VectorType::get({tensorSize / sgSize}, getElementType());
 }
 
 } // namespace xegpu
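
A minimal standalone sketch (not part of the patch) of the distribution arithmetic the change relies on: sgSize is the product of the lane_layout dimensions, mirroring the std::accumulate call added above, and a block tensor descriptor of tensorSize elements is distributed as tensorSize / sgSize elements per lane. The shapes below are hypothetical example values, and tensorSize is assumed here to be the product of the descriptor shape.

// Sketch only: tdescShape and laneLayout are made-up example values,
// not taken from the patch.
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> tdescShape = {32, 16}; // hypothetical 2-D block shape
  std::vector<int64_t> laneLayout = {1, 16};  // 16 lanes along the inner dim

  // tensorSize = product of the descriptor shape (assumed);
  // sgSize = product of the lane layout, as computed in the patch.
  int64_t tensorSize =
      std::accumulate(tdescShape.begin(), tdescShape.end(), int64_t{1},
                      std::multiplies<int64_t>());
  int64_t sgSize =
      std::accumulate(laneLayout.begin(), laneLayout.end(), int64_t{1},
                      std::multiplies<int64_t>());

  // Each lane owns tensorSize / sgSize elements, so for this example the
  // distributed type would be a 32-element vector per lane.
  assert(tensorSize % sgSize == 0);
  return tensorSize / sgSize == 32 ? 0 : 1;
}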