#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+ #include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+ #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
@@ -29,6 +31,29 @@ using namespace mlir;

namespace {

+ static std::pair<SmallVector<int64_t>, int>
+ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+   int count = 1;
+   SmallVector<int64_t> sgShape(shape);
+
+   if (layout && layout.isWgLayout()) {
+     DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
+     auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+     if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
+       sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+     else
+       sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+     SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
+     // Clamp distUnit to the original shape to handle cases where data is
+     // shared among subgroups, which may cause distUnit to exceed the
+     // original shape.
+     for (size_t i = 0; i < distUnit.size(); ++i)
+       distUnit[i] = std::min(shape[i], distUnit[i]);
+     count = computeProduct(shape) / computeProduct(distUnit);
+   }
+   return std::make_pair(sgShape, count);
+ }
+
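To make the distribution arithmetic concrete, here is a minimal standalone sketch using assumed shapes (a 256x128 workgroup shape, sg_layout = [4, 4], sg_data = [32, 32]; these numbers are illustrative, not taken from this change). With them the distribution unit is 128x128, so each subgroup owns two 32x32 tiles:

```
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Illustrative workgroup-level shape and layout (assumed values).
  std::vector<int64_t> shape = {256, 128};  // workgroup data shape
  std::vector<int64_t> sgLayout = {4, 4};   // subgroups per dimension
  std::vector<int64_t> sgShape = {32, 32};  // sg_data: per-subgroup tile

  int64_t shapeProd = 1, distProd = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    // The distribution unit is clamped to the original shape, mirroring the
    // handling of data shared among subgroups in getSgShapeAndCount.
    int64_t distUnit = std::min(shape[i], sgLayout[i] * sgShape[i]);
    shapeProd *= shape[i];
    distProd *= distUnit;
  }
  int64_t count = shapeProd / distProd;
  // Prints: sgShape = 32x32, count = 2 (each subgroup owns two tiles).
  std::cout << "sgShape = " << sgShape[0] << "x" << sgShape[1]
            << ", count = " << count << "\n";
  return 0;
}
```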
/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with
/// appropriate values for the subgroup.
@@ -129,18 +154,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
      return rewriter.notifyMatchFailure(
          op, "sgLayout attribute is required in layout");

-     SmallVector<int64_t> sgShape;
-     if (auto sgDataAttr = layout.getSgData()) {
-       sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-     } else {
-       assert(wgShape.size() == sgLayout.size() &&
-              "sgLayout and wgShape must have the same rank");
-       sgShape.reserve(wgShape.size());
-       for (size_t i = 0; i < wgShape.size(); ++i) {
-         assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
-         sgShape.push_back(wgShape[i] / sgLayout[i]);
-       }
-     }
+     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    // TODO: Handle order attribute
    // Get the subgroup ID
@@ -266,15 +280,15 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
    if (resultTy.getRank() != 2)
      return failure();

-     auto originalLayout =
-         llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+     auto originalLayout = xegpu::getLayoutAttr(op.getResult());
    if (!originalLayout)
      return failure();

-     SmallVector<Value> newDpasOps;
    size_t i = 0;
+     SmallVector<Value> newDpasOps;
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
+
        llvm::SmallVector<Value> operands({aVec, bVec});
        Value tmpC;
        if (op.getAcc()) {
@@ -288,10 +302,10 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
            llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
-         tmpC = rewriter.create<xegpu::DpasOp>(
-             loc, resTy, operands,
-             llvm::ArrayRef<NamedAttribute>(
-                 {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
+         tmpC = rewriter.create<xegpu::DpasOp>(loc, resTy, operands);
+         xegpu::setLayoutAttr(cast<OpResult>(tmpC),
+                              originalLayout.dropSgLayoutAndData());
+
        newDpasOps.push_back(tmpC);
      }
    }
@@ -314,14 +328,90 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  }
};

+ // Handles UnrealizedConversionCastOp generated during
+ // SCFStructuralTypeConversions (step 1). This op may appear as either a
+ // target or source materialization for Vector values, e.g.:
+ //   1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
+ //   2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
+ // It can be either a 1:N or an N:1 cast. In both cases, the pattern simply
+ // forwards the inputs to the outputs using the 1:1 or 1:N interface.
+ // For example, the following scf::ForOp
+ // ```
+ // %for = scf.for ... iter_args(%arg1 = %0) -> (vector<128x128xf16>) {
+ //   %n = use(%arg1): vector<128x128xf16>
+ //   scf.yield %n : vector<128x128xf16>
+ // }
+ // ```
+ // could be converted to:
+ // ```
+ // %1 = unrealized_conversion_cast %0
+ //     : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
+ // %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
+ //     -> (vector<16x16xf16>, vector<16x16xf16>) {
+ //   %m = unrealized_conversion_cast %arg1, %arg2
+ //       : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
+ //   %n = use(%m): vector<128x128xf16>
+ //   %b = unrealized_conversion_cast %n
+ //       : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
+ //   scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
+ // }
+ // %cast = unrealized_conversion_cast %for:2
+ //     : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
+ // ```
+ // TODO: remove it when a context-aware type converter is ready.
+ struct UnrealizedConversionCastOpPattern
+     : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
+   using OpConversionPattern<
+       mlir::UnrealizedConversionCastOp>::OpConversionPattern;
+
+   mlir::LogicalResult
+   matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
+                   ConversionPatternRewriter &rewriter) const override {
+     SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
+
+     auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
+     auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
+
+     if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
+         !llvm::all_equal(ValueRange(inputs).getTypes()))
+       return failure();
+
+     // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
+     // It is generated by source materialization (e.g., inits to scf forOp).
+     // The input values provided by the adaptor should already be distributed,
+     // and their types should correspond exactly to the result types of the
+     // operation.
+     if (op.getNumOperands() == 1 &&
+         llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
+       rewriter.replaceOp(op, inputs);
+       return success();
+     }
+
+     // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
+     // It is generated by target materialization (e.g., arguments/results
+     // of scf forOp). All input values must have the same vector type, and
+     // their shape must evenly divide the output vector's shape (determined
+     // by the nature of the workgroup to subgroup distribution).
+     // TODO: it is not safe to do such forwarding, since such an N:1 cast
+     // could also come from other sources.
+     if (op.getNumResults() == 1 &&
+         computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
+       rewriter.replaceOpWithMultiple(op, {inputs});
+       return success();
+     }
+
+     return mlir::failure();
+   }
+ };
+
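The N:1 branch is gated by computeShapeRatio, which only succeeds when the distributed shape evenly divides the workgroup shape. A small sketch of that guard with assumed 1-D shapes (not values from this change):

```
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include <cassert>

// computeShapeRatio(shape, subShape) returns how many subShape tiles fit in
// shape, or std::nullopt when the division is not exact.
static void demoShapeRatioGuard() {
  // vector<256xf32> split into vector<16xf32> pieces: ratio {16}, so the
  // cast is forwarded with replaceOpWithMultiple.
  assert(mlir::computeShapeRatio({256}, {16}).has_value());
  // 24 does not divide 256 evenly: std::nullopt, the pattern bails out.
  assert(!mlir::computeShapeRatio({256}, {24}).has_value());
}
```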
} // namespace

namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-              WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
-     patterns.getContext());
+              WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+              UnrealizedConversionCastOpPattern>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -334,9 +424,68 @@ struct XeGPUWgToSgDistributePass
} // namespace

void XeGPUWgToSgDistributePass::runOnOperation() {
+   // Track existing UnrealizedConversionCastOps.
+   SmallVector<Operation *> existingCastOps;
+   getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
+     existingCastOps.push_back(castOp.getOperation());
+   });
+
+   {
+     // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
+     // VectorType operands. This first converts such operands to
+     // RankedTensorType, propagates the layout attribute into the encoding
+     // attribute, and finally converts the RankedTensorType to VectorType
+     // based on the encoding.
+
+     TypeConverter converter;
+     converter.addConversion([&](Type type) -> Type { return type; });
+     converter.addConversion(
+         [&](RankedTensorType type,
+             SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+           Type elemTy = type.getElementType();
+           ArrayRef<int64_t> shape = type.getShape();
+
+           int count;
+           SmallVector<int64_t> subShape;
+           std::tie(subShape, count) = getSgShapeAndCount(
+               shape,
+               dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
+
+           auto newTy = VectorType::get(subShape, elemTy);
+           result.append(count, newTy);
+           return success();
+         });
+
+     xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
+                                                        converter);
+   }
+
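Step 1 relies on a one-to-many type conversion: a single workgroup-level type expands into `count` copies of the distributed type. The following sketch shows the same mechanism on builtin vector types only, with an assumed 8-subgroup, 16-element-per-subgroup split (so one vector<256xf32> becomes two vector<16xf32> values); it is an illustration, not code from this change:

```
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

static void demoOneToManyTypeConversion() {
  mlir::MLIRContext ctx;
  mlir::TypeConverter converter;
  converter.addConversion([](mlir::Type type) { return type; });
  // Pretend a vector<256xf32> is distributed over 8 subgroups with
  // sg_data = 16, i.e. distUnit = 128 and count = 2 tiles per subgroup.
  converter.addConversion(
      [](mlir::VectorType type, llvm::SmallVectorImpl<mlir::Type> &result)
          -> std::optional<mlir::LogicalResult> {
        if (type.getRank() != 1 || type.getNumElements() != 256)
          return std::nullopt; // let other conversion rules handle this type
        auto subTy = mlir::VectorType::get({16}, type.getElementType());
        result.append(2, subTy); // one value becomes two distributed values
        return mlir::success();
      });

  auto wgTy = mlir::VectorType::get({256}, mlir::Float32Type::get(&ctx));
  llvm::SmallVector<mlir::Type> converted;
  if (mlir::succeeded(converter.convertType(wgTy, converted))) {
    // converted now holds two copies of vector<16xf32>.
  }
}
```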
+   // Step 2: Perform workgroup to subgroup distribution for TensorDesc
+   // values, as well as XeGPU, Arith, and Vector operations.
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);
+   TypeConverter converter;
+   converter.addConversion([&](Type type) -> Type { return type; });
+   converter.addConversion(
+       [&](xegpu::TensorDescType type,
+           SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+         Type elemTy = type.getElementType();
+         ArrayRef<int64_t> shape = type.getShape();
+
+         int count;
+         SmallVector<int64_t> subShape;
+         xegpu::LayoutAttr layout = type.getLayoutAttr();
+         std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
+
+         if (layout)
+           layout = layout.dropSgLayoutAndData();
+
+         auto newTy = xegpu::TensorDescType::get(
+             type.getContext(), subShape, elemTy, type.getEncoding(), layout);
+         result.append(count, newTy);
+         return success();
+       });

  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
@@ -353,26 +502,49 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
  };

  auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
-     return !layout || layout.getSgLayout() == nullptr;
+     return !layout || !layout.isWgLayout();
  };

  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
-     auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+     auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
-     auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+     auto layout = xegpu::getLayoutAttr(op.getResult());
    return isLegal(layout);
  });

+   target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
+       [=](UnrealizedConversionCastOp op) {
+         return llvm::is_contained(existingCastOps, op.getOperation());
+       });
+
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

+   scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                        target);
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();
+
+   // Remove the sg_layout and sg_data fields from the layout attribute
+   // attached to each VectorType result of an operation. For structured
+   // control-flow ops the attribute is dropped entirely, since in the 1:N
+   // case the layouts for the new results are missing; a layout propagation
+   // pass will be activated afterwards to recover them.
+   getOperation()->walk([](Operation *op) {
+     for (OpResult result : op->getOpResults()) {
+       std::string name = xegpu::getLayoutName(result);
+       if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
+         op->removeAttr(name);
+         if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op))
+           op->setAttr(name, layout.dropSgLayoutAndData());
+       }
+     }
+   });
}