@@ -67,6 +67,21 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
  return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}

+// Helper that returns data layout alignment of a vector.
+LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
+                                 VectorType vectorType, unsigned &align) {
+  Type convertedVectorTy = typeConverter.convertType(vectorType);
+  if (!convertedVectorTy)
+    return failure();
+
+  llvm::LLVMContext llvmContext;
+  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
+              .getPreferredAlignment(convertedVectorTy,
+                                     typeConverter.getDataLayout());
+
+  return success();
+}
+
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
@@ -82,6 +97,28 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
  return success();
}

+// Helper to resolve the alignment for vector load/store, gather and scatter
+// ops. If useVectorAlignment is true, get the preferred alignment for the
+// vector type in the operation. This option is used for hardware backends with
+// vectorization. Otherwise, use the preferred alignment of the element type of
+// the memref. Note that if you choose to use vector alignment, the shape of the
+// vector type must be resolved before the ConvertVectorToLLVM pass is run.
+LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
+                                       VectorType vectorType,
+                                       MemRefType memrefType, unsigned &align,
+                                       bool useVectorAlignment) {
+  if (useVectorAlignment) {
+    if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
+      return failure();
+    }
+  } else {
+    if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
// Check if the last stride is non-unit and has a valid memory space.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                           const LLVMTypeConverter &converter) {
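
To make the two modes of the new helper concrete, here is a minimal sketch (illustrative only, not part of the commit) that queries both alignments for a vector<8xf32> access into a memref<?xf32>, written as if it lived in the same file as the helpers above so the file's existing includes and `using namespace mlir` apply; the byte values in the comments assume a typical 64-bit data layout.

// Illustrative sketch only: exercising both modes of getVectorToLLVMAlignment.
static void alignmentModesSketch(const LLVMTypeConverter &typeConverter,
                                 MLIRContext *ctx) {
  auto f32 = Float32Type::get(ctx);
  auto vecTy = VectorType::get({8}, f32);                        // vector<8xf32>
  auto memrefTy = MemRefType::get({ShapedType::kDynamic}, f32);  // memref<?xf32>

  unsigned align = 0;
  // useVectorAlignment = true: preferred alignment of the converted
  // <8 x float> type (target-dependent, often 32 bytes).
  (void)getVectorToLLVMAlignment(typeConverter, vecTy, memrefTy, align,
                                 /*useVectorAlignment=*/true);
  // useVectorAlignment = false: preferred alignment of the f32 element type
  // of the memref (typically 4 bytes).
  (void)getVectorToLLVMAlignment(typeConverter, vecTy, memrefTy, align,
                                 /*useVectorAlignment=*/false);
}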
@@ -224,6 +261,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
+  explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
+                                     bool useVectorAlign)
+      : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
+        useVectorAlignment(useVectorAlign) {}
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
@@ -240,8 +281,10 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {

    // Resolve alignment.
    unsigned align;
-    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
-      return failure();
+    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
+                                        memRefTy, align, useVectorAlignment)))
+      return rewriter.notifyMatchFailure(loadOrStoreOp,
+                                         "could not resolve alignment");

    // Resolve address.
    auto vtype = cast<VectorType>(
@@ -252,12 +295,23 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
                         rewriter);
    return success();
  }
+
+private:
+  // If true, use the preferred alignment of the vector type.
+  // If false, use the preferred alignment of the element type
+  // of the memref. This flag is intended for use with hardware
+  // backends that require alignment of vector operations.
+  const bool useVectorAlignment;
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
+  explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
+                                    bool useVectorAlign)
+      : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
+        useVectorAlignment(useVectorAlign) {}
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
@@ -278,10 +332,9 @@ class VectorGatherOpConversion

    // Resolve alignment.
    unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
-      return rewriter.notifyMatchFailure(gather,
-                                         "could not resolve memref alignment");
-    }
+    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+                                        memRefType, align, useVectorAlignment)))
+      return rewriter.notifyMatchFailure(gather, "could not resolve alignment");

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -297,12 +350,24 @@ class VectorGatherOpConversion
        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
+
+private:
+  // If true, use the preferred alignment of the vector type.
+  // If false, use the preferred alignment of the element type
+  // of the memref. This flag is intended for use with hardware
+  // backends that require alignment of vector operations.
+  const bool useVectorAlignment;
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
+  explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
+                                     bool useVectorAlign)
+      : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
+        useVectorAlignment(useVectorAlign) {}
+
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
@@ -322,10 +387,10 @@ class VectorScatterOpConversion

    // Resolve alignment.
    unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
+    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+                                        memRefType, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(scatter,
-                                         "could not resolve memref alignment");
-    }
+                                         "could not resolve alignment");

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -340,6 +405,13 @@ class VectorScatterOpConversion
        rewriter.getI32IntegerAttr(align));
    return success();
  }
+
+private:
+  // If true, use the preferred alignment of the vector type.
+  // If false, use the preferred alignment of the element type
+  // of the memref. This flag is intended for use with hardware
+  // backends that require alignment of vector operations.
+  const bool useVectorAlignment;
};

/// Conversion pattern for a vector.expandload.
@@ -1928,21 +2000,23 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions, bool force32BitVectorIndices) {
+    bool reassociateFPReductions, bool force32BitVectorIndices,
+    bool useVectorAlignment) {
  // This function populates only ConversionPatterns, not RewritePatterns.
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
+  patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
+               VectorLoadStoreConversion<vector::MaskedLoadOp>,
+               VectorLoadStoreConversion<vector::StoreOp>,
+               VectorLoadStoreConversion<vector::MaskedStoreOp>,
+               VectorGatherOpConversion, VectorScatterOpConversion>(
+      converter, useVectorAlignment);
  patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
               VectorExtractElementOpConversion, VectorExtractOpConversion,
               VectorFMAOp1DConversion, VectorInsertElementOpConversion,
               VectorInsertOpConversion, VectorPrintOpConversion,
               VectorTypeCastOpConversion, VectorScaleOpConversion,
-               VectorLoadStoreConversion<vector::LoadOp>,
-               VectorLoadStoreConversion<vector::MaskedLoadOp>,
-               VectorLoadStoreConversion<vector::StoreOp>,
-               VectorLoadStoreConversion<vector::MaskedStoreOp>,
-               VectorGatherOpConversion, VectorScatterOpConversion,
               VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
               VectorSplatOpLowering, VectorSplatNdOpLowering,
               VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
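
For callers, the only visible change in this last hunk is the extra trailing parameter. A minimal usage sketch follows; the wrapper function name and flag plumbing are hypothetical, and only populateVectorToLLVMConversionPatterns and its parameter list come from the diff above.

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/IR/PatternMatch.h"

// Hypothetical helper showing how a pass might thread the new flag through.
void addVectorToLLVMPatterns(const mlir::LLVMTypeConverter &converter,
                             mlir::RewritePatternSet &patterns,
                             bool useVectorAlignment) {
  // The first two options keep their existing meaning; the new flag selects
  // whether load/store/gather/scatter use the vector type's preferred
  // alignment (true) or the memref element type's preferred alignment (false).
  mlir::populateVectorToLLVMConversionPatterns(
      converter, patterns,
      /*reassociateFPReductions=*/false,
      /*force32BitVectorIndices=*/false,
      /*useVectorAlignment=*/useVectorAlignment);
}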