@@ -37,40 +37,38 @@ std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
37
37
rewriter.create <LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
38
38
}
39
39
40
- // / Verifies if the stride matches proper tile access.
41
- LogicalResult verifyStride (MemRefType mType ) {
42
- if (mType .getRank () < 2 )
43
- return failure ();
44
- int64_t last = mType .getRank () - 1 ;
45
- int64_t offset;
46
- SmallVector<int64_t , 4 > strides;
47
- if (failed (getStridesAndOffset (mType , strides, offset)) || strides[last] != 1 )
48
- return failure ();
49
- return success ();
50
- }
51
-
52
40
// / Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
53
41
// / shape may "envelop" the actual tile shape, and may be dynamically sized.
54
- Value getStride (ConversionPatternRewriter &rewriter,
55
- const LLVMTypeConverter &typeConverter, MemRefType mType ,
56
- Value base, Location loc) {
57
- assert (mType .getRank () >= 2 );
58
- int64_t last = mType .getRank () - 1 ;
42
+ // / Returns failure if proper stride couldn't be found.
43
+ FailureOr<Value> getStride (ConversionPatternRewriter &rewriter,
44
+ const LLVMTypeConverter &typeConverter,
45
+ MemRefType mType , Value base, Location loc) {
46
+ if (mType .getRank () < 2 )
47
+ return failure ();
48
+ int64_t preLast = mType .getRank () - 2 ;
59
49
Type llvmInt64Type = IntegerType::get (&typeConverter.getContext (), 64 );
60
50
unsigned width = mType .getElementType ().getIntOrFloatBitWidth ();
61
51
assert (llvm::isPowerOf2_64 (width) && width >= 8 );
62
52
unsigned bytes = width >> 3 ;
63
- if (mType .isDynamicDim (last)) {
64
- // Dynamic size needs code to compute the stride at runtime.
53
+ int64_t offset;
54
+ SmallVector<int64_t , 4 > strides;
55
+ if (failed (getStridesAndOffset (mType , strides, offset)) ||
56
+ strides.back () != 1 )
57
+ return failure ();
58
+ if (strides[preLast] == ShapedType::kDynamic ) {
59
+ // Dynamic stride needs code to compute the stride at runtime.
65
60
MemRefDescriptor memrefDescriptor (base);
66
61
auto attr = rewriter.getI64IntegerAttr (bytes);
67
62
Value scale = rewriter.create <LLVM::ConstantOp>(loc, llvmInt64Type, attr);
68
- return rewriter.create <LLVM::MulOp>(
69
- loc, llvmInt64Type, scale, memrefDescriptor.size (rewriter, loc, last));
63
+ return rewriter
64
+ .create <LLVM::MulOp>(loc, llvmInt64Type, scale,
65
+ memrefDescriptor.stride (rewriter, loc, preLast))
66
+ .getResult ();
70
67
}
71
- // Use direct constant for static size.
72
- auto attr = rewriter.getI64IntegerAttr (mType .getDimSize (last) * bytes);
73
- return rewriter.create <LLVM::ConstantOp>(loc, llvmInt64Type, attr);
68
+ // Use direct constant for static stride.
69
+ auto attr = rewriter.getI64IntegerAttr (strides[preLast] * bytes);
70
+ return rewriter.create <LLVM::ConstantOp>(loc, llvmInt64Type, attr)
71
+ .getResult ();
74
72
}
75
73
76
74
struct TileZeroConversion : public ConvertOpToLLVMPattern <TileZeroOp> {
@@ -102,16 +100,16 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
102
100
std::pair<Value, Value> tsz =
103
101
getTileSizes (rewriter, *getTypeConverter (), vType, op.getLoc ());
104
102
// Determine stride.
105
- if (failed (verifyStride (mType )))
103
+ auto stride = getStride (rewriter, *getTypeConverter (), mType ,
104
+ adaptor.getBase (), op.getLoc ());
105
+ if (failed (stride))
106
106
return failure ();
107
- Value stride = getStride (rewriter, *getTypeConverter (), mType ,
108
- adaptor.getBase (), op.getLoc ());
109
107
// Replace operation with intrinsic.
110
108
Value ptr = getStridedElementPtr (op.getLoc (), mType , adaptor.getBase (),
111
109
adaptor.getIndices (), rewriter);
112
110
Type resType = typeConverter->convertType (vType);
113
111
rewriter.replaceOpWithNewOp <amx::x86_amx_tileloadd64>(
114
- op, resType, tsz.first , tsz.second , ptr, stride);
112
+ op, resType, tsz.first , tsz.second , ptr, stride. value () );
115
113
return success ();
116
114
}
117
115
};
@@ -128,15 +126,15 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
128
126
std::pair<Value, Value> tsz =
129
127
getTileSizes (rewriter, *getTypeConverter (), vType, op.getLoc ());
130
128
// Determine stride.
131
- if (failed (verifyStride (mType )))
129
+ auto stride = getStride (rewriter, *getTypeConverter (), mType ,
130
+ adaptor.getBase (), op.getLoc ());
131
+ if (failed (stride))
132
132
return failure ();
133
- Value stride = getStride (rewriter, *getTypeConverter (), mType ,
134
- adaptor.getBase (), op.getLoc ());
135
133
// Replace operation with intrinsic.
136
134
Value ptr = getStridedElementPtr (op.getLoc (), mType , adaptor.getBase (),
137
135
adaptor.getIndices (), rewriter);
138
136
rewriter.replaceOpWithNewOp <amx::x86_amx_tilestored64>(
139
- op, tsz.first , tsz.second , ptr, stride, adaptor.getVal ());
137
+ op, tsz.first , tsz.second , ptr, stride. value () , adaptor.getVal ());
140
138
return success ();
141
139
}
142
140
};
0 commit comments