[MLIR] [AMX] Fix strides used by AMX lowering for tile loads and stores. #113476

Merged 2 commits on Oct 30, 2024
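Previously the lowering derived the 64-bit stride operand of tileloadd64/tilestored64 from the size of the innermost memref dimension (mType.getDimSize(last) * bytes), which is only correct when rows are laid out contiguously. With this change the stride is taken from the memref's actual layout: the second-to-last stride, either as a compile-time constant or, when dynamic, read from the memref descriptor at runtime. Below is a minimal sketch of the kind of IR affected, assuming the strided view is produced by memref.subview; the function and value names are illustrative and not part of the patch.

// Rows of the 16x32 view are 64 bf16 elements (128 bytes) apart in the parent
// buffer, while the view's own row length is only 32 elements (64 bytes).
// The old lowering emitted a 64-byte stride here; the fixed lowering emits
// 128 bytes, taken from the leading stride of the layout.
func.func @tile_from_subview(%buf: memref<64x64xbf16>) -> vector<16x32xbf16> {
  %c0 = arith.constant 0 : index
  %view = memref.subview %buf[0, 0] [16, 32] [1, 1]
      : memref<64x64xbf16> to memref<16x32xbf16, strided<[64, 1]>>
  %tile = amx.tile_load %view[%c0, %c0]
      : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
  return %tile : vector<16x32xbf16>
}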
62 changes: 30 additions & 32 deletions mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -37,40 +37,38 @@ std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
       rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
 }
 
-/// Verifies if the stride matches proper tile access.
-LogicalResult verifyStride(MemRefType mType) {
-  if (mType.getRank() < 2)
-    return failure();
-  int64_t last = mType.getRank() - 1;
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1)
-    return failure();
-  return success();
-}
-
 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
 /// shape may "envelop" the actual tile shape, and may be dynamically sized.
-Value getStride(ConversionPatternRewriter &rewriter,
-                const LLVMTypeConverter &typeConverter, MemRefType mType,
-                Value base, Location loc) {
-  assert(mType.getRank() >= 2);
-  int64_t last = mType.getRank() - 1;
+/// Returns failure if proper stride couldn't be found.
+FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
+                           const LLVMTypeConverter &typeConverter,
+                           MemRefType mType, Value base, Location loc) {
+  if (mType.getRank() < 2)
+    return failure();
+  int64_t preLast = mType.getRank() - 2;
   Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
   unsigned width = mType.getElementType().getIntOrFloatBitWidth();
   assert(llvm::isPowerOf2_64(width) && width >= 8);
   unsigned bytes = width >> 3;
-  if (mType.isDynamicDim(last)) {
-    // Dynamic size needs code to compute the stride at runtime.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(mType, strides, offset)) ||
+      strides.back() != 1)
+    return failure();
+  if (strides[preLast] == ShapedType::kDynamic) {
+    // Dynamic stride needs code to compute the stride at runtime.
     MemRefDescriptor memrefDescriptor(base);
     auto attr = rewriter.getI64IntegerAttr(bytes);
     Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
-    return rewriter.create<LLVM::MulOp>(
-        loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last));
+    return rewriter
+        .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
+                             memrefDescriptor.stride(rewriter, loc, preLast))
+        .getResult();
   }
-  // Use direct constant for static size.
-  auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes);
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
+  // Use direct constant for static stride.
+  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
+  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
+      .getResult();
 }
 
 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
@@ -102,16 +100,16 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
     std::pair<Value, Value> tsz =
         getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
     // Determine stride.
-    if (failed(verifyStride(mType)))
+    auto stride = getStride(rewriter, *getTypeConverter(), mType,
+                            adaptor.getBase(), op.getLoc());
+    if (failed(stride))
       return failure();
-    Value stride = getStride(rewriter, *getTypeConverter(), mType,
-                             adaptor.getBase(), op.getLoc());
     // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     Type resType = typeConverter->convertType(vType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
-        op, resType, tsz.first, tsz.second, ptr, stride);
+        op, resType, tsz.first, tsz.second, ptr, stride.value());
     return success();
   }
 };
@@ -128,15 +126,15 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
     std::pair<Value, Value> tsz =
         getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
     // Determine stride.
-    if (failed(verifyStride(mType)))
+    auto stride = getStride(rewriter, *getTypeConverter(), mType,
+                            adaptor.getBase(), op.getLoc());
+    if (failed(stride))
      return failure();
-    Value stride = getStride(rewriter, *getTypeConverter(), mType,
-                             adaptor.getBase(), op.getLoc());
     // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
-        op, tsz.first, tsz.second, ptr, stride, adaptor.getVal());
+        op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
     return success();
   }
 };
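For a dynamic leading stride (strided<[?, 1]>), getStride multiplies the element size by the stride value read out of the memref descriptor at runtime via MemRefDescriptor::stride. A rough sketch of the LLVM-dialect IR this produces for a rank-2 bf16 memref, written as a standalone function for illustration; the descriptor type and value names are assumptions, but the constant, extractvalue, and mul correspond to the CHECK lines in the test below.

// Field 4 of a rank-2 memref descriptor holds the strides array; index 0 is
// the row (second-to-last) stride in elements. bf16 elements are 2 bytes wide.
llvm.func @row_stride_in_bytes(
    %desc: !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>) -> i64 {
  %elem_bytes = llvm.mlir.constant(2 : i64) : i64
  %row_stride = llvm.extractvalue %desc[4, 0]
      : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  %bytes = llvm.mul %elem_bytes, %row_stride : i64
  llvm.return %bytes : i64
}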
28 changes: 28 additions & 0 deletions mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -43,3 +43,31 @@ func.func @mulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
   amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, vector<16x16xf32>
   return
 }
+
+// CHECK-LABEL: strides(
+// CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
+// CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_1:.+]] = llvm.mul
+// CHECK: "amx.tileloadd64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
+// CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
+// CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
+// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STRIDE_2:.+]] = llvm.mul
+// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
+func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
+  %0 = arith.constant 0 : index
+  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into vector<16x32xbf16>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into vector<16x32xbf16>
+  amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, vector<16x32xbf16>
+  amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, vector<16x32xbf16>
+  amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, vector<16x32xbf16>
+  return
+}
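For the two static cases above, the checked constants are simply the leading stride in elements times the 2-byte bf16 element size: 32 x 2 = 64 bytes for the contiguous memref<16x32xbf16>, and 64 x 2 = 128 bytes for the strided<[64, 1]> view; the strided<[?, 1]> arguments take the runtime extractvalue/mul path instead. Layouts whose innermost stride is not 1 are rejected by getStride, so the conversion pattern simply returns failure for them. A hypothetical example of such a layout, not part of this test:

// Column-major-style view: consecutive elements of a row are 16 elements
// apart, so the innermost stride is not 1 and no single row stride in bytes
// exists for the AMX tile instructions. getStride fails for this memref type.
func.func @transposed_view(%arg0: memref<16x32xbf16, strided<[1, 16]>>) -> vector<16x32xbf16> {
  %c0 = arith.constant 0 : index
  %t = amx.tile_load %arg0[%c0, %c0]
      : memref<16x32xbf16, strided<[1, 16]>> into vector<16x32xbf16>
  return %t : vector<16x32xbf16>
}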