@@ -67,29 +67,33 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
67
67
return rewriter.create <LLVM::ExtractValueOp>(loc, val, pos);
68
68
}
69
69
70
+ // Helper that returns data layout alignment of a vector.
71
+ LogicalResult getVectorAlignment (const LLVMTypeConverter &typeConverter,
72
+ VectorType vectorType, unsigned &align) {
73
+ Type convertedVectorTy = typeConverter.convertType (vectorType);
74
+ if (!convertedVectorTy)
75
+ return failure ();
76
+
77
+ llvm::LLVMContext llvmContext;
78
+ align = LLVM::TypeToLLVMIRTranslator (llvmContext)
79
+ .getPreferredAlignment (convertedVectorTy,
80
+ typeConverter.getDataLayout ());
81
+
82
+ return success ();
83
+ }
84
+
70
85
// Helper that returns data layout alignment of a memref.
71
86
LogicalResult getMemRefAlignment (const LLVMTypeConverter &typeConverter,
72
87
MemRefType memrefType, unsigned &align) {
73
- // If shape is statically known, assign MemRefTypes to the alignment of a
74
- // VectorType with the same size and dtype. Otherwise, fall back to the
75
- // alignment of the element type.
76
- Type convertedType;
77
- if (memrefType.hasStaticShape ()) {
78
- convertedType = typeConverter.convertType (VectorType::get (
79
- memrefType.getNumElements (), memrefType.getElementType ()));
80
- } else {
81
- convertedType = typeConverter.convertType (memrefType.getElementType ());
82
- }
83
-
84
- if (!convertedType)
88
+ Type elementTy = typeConverter.convertType (memrefType.getElementType ());
89
+ if (!elementTy)
85
90
return failure ();
86
91
87
92
// TODO: this should use the MLIR data layout when it becomes available and
88
93
// stop depending on translation.
89
94
llvm::LLVMContext llvmContext;
90
- align =
91
- LLVM::TypeToLLVMIRTranslator (llvmContext)
92
- .getPreferredAlignment (convertedType, typeConverter.getDataLayout ());
95
+ align = LLVM::TypeToLLVMIRTranslator (llvmContext)
96
+ .getPreferredAlignment (elementTy, typeConverter.getDataLayout ());
93
97
return success ();
94
98
}
95
99
@@ -235,6 +239,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
235
239
template <class LoadOrStoreOp >
236
240
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern <LoadOrStoreOp> {
237
241
public:
242
+ explicit VectorLoadStoreConversion (const LLVMTypeConverter &typeConv,
243
+ bool useVectorAlign)
244
+ : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
245
+ useVectorAlignment(useVectorAlign) {}
238
246
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
239
247
240
248
LogicalResult
@@ -251,8 +259,17 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
251
259
252
260
// Resolve alignment.
253
261
unsigned align;
254
- if (failed (getMemRefAlignment (*this ->getTypeConverter (), memRefTy, align)))
255
- return failure ();
262
+ if (useVectorAlignment) {
263
+ if (failed (
264
+ getVectorAlignment (*this ->getTypeConverter (), vectorTy, align)))
265
+ return rewriter.notifyMatchFailure (
266
+ loadOrStoreOp, " could not resolve vector alignment" );
267
+ } else {
268
+ if (failed (
269
+ getMemRefAlignment (*this ->getTypeConverter (), memRefTy, align)))
270
+ return rewriter.notifyMatchFailure (
271
+ loadOrStoreOp, " could not resolve memref alignment" );
272
+ }
256
273
257
274
// Resolve address.
258
275
auto vtype = cast<VectorType>(
@@ -263,12 +280,19 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
263
280
rewriter);
264
281
return success ();
265
282
}
283
+
284
+ private:
285
+ const bool useVectorAlignment;
266
286
};
267
287
268
288
// / Conversion pattern for a vector.gather.
269
289
class VectorGatherOpConversion
270
290
: public ConvertOpToLLVMPattern<vector::GatherOp> {
271
291
public:
292
+ explicit VectorGatherOpConversion (const LLVMTypeConverter &typeConv,
293
+ bool useVectorAlign)
294
+ : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
295
+ useVectorAlignment(useVectorAlign) {}
272
296
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
273
297
274
298
LogicalResult
@@ -289,9 +313,15 @@ class VectorGatherOpConversion
289
313
290
314
// Resolve alignment.
291
315
unsigned align;
292
- if (failed (getMemRefAlignment (*getTypeConverter (), memRefType, align))) {
293
- return rewriter.notifyMatchFailure (gather,
294
- " could not resolve memref alignment" );
316
+ if (useVectorAlignment) {
317
+ if (failed (getVectorAlignment (*this ->getTypeConverter (), vType, align)))
318
+ return rewriter.notifyMatchFailure (
319
+ gather, " could not resolve vector alignment" );
320
+ } else {
321
+ if (failed (
322
+ getMemRefAlignment (*this ->getTypeConverter (), memRefType, align)))
323
+ return rewriter.notifyMatchFailure (
324
+ gather, " could not resolve memref alignment" );
295
325
}
296
326
297
327
// Resolve address.
@@ -308,12 +338,20 @@ class VectorGatherOpConversion
308
338
adaptor.getPassThru (), rewriter.getI32IntegerAttr (align));
309
339
return success ();
310
340
}
341
+
342
+ private:
343
+ const bool useVectorAlignment;
311
344
};
312
345
313
346
// / Conversion pattern for a vector.scatter.
314
347
class VectorScatterOpConversion
315
348
: public ConvertOpToLLVMPattern<vector::ScatterOp> {
316
349
public:
350
+ explicit VectorScatterOpConversion (const LLVMTypeConverter &typeConv,
351
+ bool useVectorAlign)
352
+ : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
353
+ useVectorAlignment(useVectorAlign) {}
354
+
317
355
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
318
356
319
357
LogicalResult
@@ -333,9 +371,18 @@ class VectorScatterOpConversion
333
371
334
372
// Resolve alignment.
335
373
unsigned align;
336
- if (failed (getMemRefAlignment (*getTypeConverter (), memRefType, align))) {
337
- return rewriter.notifyMatchFailure (scatter,
338
- " could not resolve memref alignment" );
374
+
375
+ // Resolve alignment.
376
+ unsigned align;
377
+ if (useVectorAlignment) {
378
+ if (failed (getVectorAlignment (*this ->getTypeConverter (), vType, align)))
379
+ return rewriter.notifyMatchFailure (
380
+ scatter, " could not resolve vector alignment" );
381
+ } else {
382
+ if (failed (
383
+ getMemRefAlignment (*this ->getTypeConverter (), memRefType, align)))
384
+ return rewriter.notifyMatchFailure (
385
+ scatter, " could not resolve memref alignment" );
339
386
}
340
387
341
388
// Resolve address.
@@ -351,6 +398,9 @@ class VectorScatterOpConversion
351
398
rewriter.getI32IntegerAttr (align));
352
399
return success ();
353
400
}
401
+
402
+ private:
403
+ const bool useVectorAlignment;
354
404
};
355
405
356
406
// / Conversion pattern for a vector.expandload.
@@ -1939,7 +1989,8 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
1939
1989
// / Populate the given list with patterns that convert from Vector to LLVM.
1940
1990
void mlir::populateVectorToLLVMConversionPatterns (
1941
1991
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
1942
- bool reassociateFPReductions, bool force32BitVectorIndices) {
1992
+ bool reassociateFPReductions, bool force32BitVectorIndices,
1993
+ bool useVectorAlignment) {
1943
1994
// This function populates only ConversionPatterns, not RewritePatterns.
1944
1995
MLIRContext *ctx = converter.getDialect ()->getContext ();
1945
1996
patterns.add <VectorReductionOpConversion>(converter, reassociateFPReductions);
@@ -1948,18 +1999,23 @@ void mlir::populateVectorToLLVMConversionPatterns(
1948
1999
VectorExtractElementOpConversion, VectorExtractOpConversion,
1949
2000
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1950
2001
VectorInsertOpConversion, VectorPrintOpConversion,
1951
- VectorTypeCastOpConversion, VectorScaleOpConversion,
1952
- VectorLoadStoreConversion<vector::LoadOp>,
1953
- VectorLoadStoreConversion<vector::MaskedLoadOp>,
1954
- VectorLoadStoreConversion<vector::StoreOp>,
1955
- VectorLoadStoreConversion<vector::MaskedStoreOp>,
1956
- VectorGatherOpConversion, VectorScatterOpConversion,
1957
- VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1958
- VectorSplatOpLowering, VectorSplatNdOpLowering,
1959
- VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1960
- MaskedReductionOpConversion, VectorInterleaveOpLowering,
1961
- VectorDeinterleaveOpLowering, VectorFromElementsLowering,
1962
- VectorScalableStepOpLowering>(converter);
2002
+ VectorTypeCastOpConversion, VectorScaleOpConversion>(ctx),
2003
+ patterns.add <VectorLoadStoreConversion<vector::LoadOp>>(
2004
+ ctx, useVectorAlignment),
2005
+ patterns.add <VectorLoadStoreConversion<vector::MaskedLoadOp>>(
2006
+ ctx, useVectorAlignment),
2007
+ patterns.add <VectorLoadStoreConversion<vector::StoreOp>>(
2008
+ ctx, useVectorAlignment),
2009
+ patterns.add <VectorLoadStoreConversion<vector::MaskedStoreOp>>(
2010
+ ctx, useVectorAlignment),
2011
+ patterns.add <VectorGatherOpConversion, VectorScatterOpConversion,
2012
+ VectorExpandLoadOpConversion,
2013
+ VectorCompressStoreOpConversion, VectorSplatOpLowering,
2014
+ VectorSplatNdOpLowering, VectorScalableInsertOpLowering,
2015
+ VectorScalableExtractOpLowering, MaskedReductionOpConversion,
2016
+ VectorInterleaveOpLowering, VectorDeinterleaveOpLowering,
2017
+ VectorFromElementsLowering, VectorScalableStepOpLowering>(
2018
+ converter);
1963
2019
}
1964
2020
1965
2021
void mlir::populateVectorToLLVMMatrixConversionPatterns (
0 commit comments