
Commit d210964

[MLIR] [AMX] Fix strides used by AMX lowering for tile loads and stores. (#113476)
1 parent e4dfb51 commit d210964
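
Previously the lowering derived the "stride" passed to the AMX tileloadd64/tilestored64 intrinsics from the size of the innermost memref dimension (the row length), which is only correct when rows are laid out contiguously. With this change, getStride reads the actual second-to-last stride from the memref's strided layout, so buffers whose rows are farther apart than they are long (for example a strided<[64, 1]> layout) get the correct byte stride, and the separate verifyStride check is folded into getStride as a FailureOr return.

A small illustration, adapted from the test added below (not itself part of the commit): for a plain memref<16x32xbf16> the byte stride between rows is 32 * 2 = 64, but with the strided<[64, 1]> layout the rows are 64 elements apart, so the lowering must emit 128; the old code would have emitted 64 in both cases.

func.func @stride_vs_size(%arg0: memref<16x32xbf16>,
                          %arg1: memref<16x32xbf16, strided<[64, 1]>>) {
  %c0 = arith.constant 0 : index
  // Contiguous rows: expected byte stride 32 elements * 2 bytes = 64.
  %a = amx.tile_load %arg0[%c0, %c0] : memref<16x32xbf16> into vector<16x32xbf16>
  // Strided rows: expected byte stride 64 elements * 2 bytes = 128.
  %b = amx.tile_load %arg1[%c0, %c0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
  amx.tile_store %arg0[%c0, %c0], %b : memref<16x32xbf16>, vector<16x32xbf16>
  amx.tile_store %arg1[%c0, %c0], %a : memref<16x32xbf16, strided<[64, 1]>>, vector<16x32xbf16>
  return
}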

2 files changed: +58 −32 lines changed


mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 30 additions & 32 deletions
@@ -37,40 +37,38 @@ std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
 }
 
-/// Verifies if the stride matches proper tile access.
-LogicalResult verifyStride(MemRefType mType) {
-  if (mType.getRank() < 2)
-    return failure();
-  int64_t last = mType.getRank() - 1;
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1)
-    return failure();
-  return success();
-}
-
 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
 /// shape may "envelop" the actual tile shape, and may be dynamically sized.
-Value getStride(ConversionPatternRewriter &rewriter,
-                const LLVMTypeConverter &typeConverter, MemRefType mType,
-                Value base, Location loc) {
-  assert(mType.getRank() >= 2);
-  int64_t last = mType.getRank() - 1;
+/// Returns failure if proper stride couldn't be found.
+FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
+                           const LLVMTypeConverter &typeConverter,
+                           MemRefType mType, Value base, Location loc) {
+  if (mType.getRank() < 2)
+    return failure();
+  int64_t preLast = mType.getRank() - 2;
   Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
   unsigned width = mType.getElementType().getIntOrFloatBitWidth();
   assert(llvm::isPowerOf2_64(width) && width >= 8);
   unsigned bytes = width >> 3;
-  if (mType.isDynamicDim(last)) {
-    // Dynamic size needs code to compute the stride at runtime.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(mType, strides, offset)) ||
+      strides.back() != 1)
+    return failure();
+  if (strides[preLast] == ShapedType::kDynamic) {
+    // Dynamic stride needs code to compute the stride at runtime.
     MemRefDescriptor memrefDescriptor(base);
     auto attr = rewriter.getI64IntegerAttr(bytes);
     Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
-    return rewriter.create<LLVM::MulOp>(
-        loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last));
+    return rewriter
+        .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
+                             memrefDescriptor.stride(rewriter, loc, preLast))
+        .getResult();
   }
-  // Use direct constant for static size.
-  auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes);
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
+  // Use direct constant for static stride.
+  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
+  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
+      .getResult();
 }
 
 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
@@ -102,16 +100,16 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
     std::pair<Value, Value> tsz =
         getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
     // Determine stride.
-    if (failed(verifyStride(mType)))
+    auto stride = getStride(rewriter, *getTypeConverter(), mType,
+                            adaptor.getBase(), op.getLoc());
+    if (failed(stride))
       return failure();
-    Value stride = getStride(rewriter, *getTypeConverter(), mType,
-                             adaptor.getBase(), op.getLoc());
     // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Type resType = typeConverter->convertType(vType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
-        op, resType, tsz.first, tsz.second, ptr, stride);
+        op, resType, tsz.first, tsz.second, ptr, stride.value());
     return success();
   }
 };
@@ -128,15 +126,15 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
     std::pair<Value, Value> tsz =
        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
     // Determine stride.
-    if (failed(verifyStride(mType)))
+    auto stride = getStride(rewriter, *getTypeConverter(), mType,
+                            adaptor.getBase(), op.getLoc());
+    if (failed(stride))
       return failure();
-    Value stride = getStride(rewriter, *getTypeConverter(), mType,
-                             adaptor.getBase(), op.getLoc());
     // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
-        op, tsz.first, tsz.second, ptr, stride, adaptor.getVal());
+        op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
     return success();
   }
 };
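
For a statically known row stride, getStride now folds the byte stride into a single llvm.mlir.constant; when the leading stride is dynamic, it is read from the stride field of the lowered memref descriptor and scaled by the element size. A sketch of the IR shape the dynamic path produces for a rank-2 bf16 memref (the function and value names here are illustrative only; compare the llvm.mlir.constant(2 : i64), llvm.extractvalue [4, 0], and llvm.mul CHECK lines in the test below):

llvm.func @dynamic_stride_sketch(%desc: !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>) -> i64 {
  // Element size in bytes (2 for bf16), emitted as a constant.
  %scale = llvm.mlir.constant(2 : i64) : i64
  // Leading (row) stride in elements, read from the memref descriptor:
  // field 4 holds the strides array, index 0 is the outermost stride.
  %stride = llvm.extractvalue %desc[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  // Byte stride handed to the tileloadd64 / tilestored64 intrinsics.
  %bytes = llvm.mul %scale, %stride : i64
  llvm.return %bytes : i64
}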

mlir/test/Dialect/AMX/legalize-for-llvm.mlir

Lines changed: 28 additions & 0 deletions
@@ -43,3 +43,31 @@ func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
   return
 }
+
+// CHECK-LABEL: strides(
+// CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
+// CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_1:.+]] = llvm.mul
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
+// CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
+// CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_2:.+]] = llvm.mul
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
+func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
+  %0 = arith.constant 0 : index
+  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into vector<16x32xbf16>
+  amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, vector<16x32xbf16>
+  amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, vector<16x32xbf16>
+  amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, vector<16x32xbf16>
+  return
+}
