Skip to content

[mlir][SPIRV] Add support for dense_resource in arith to spirv #91318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -229,16 +230,41 @@ struct ConstantCompositeOpPattern final
if (!srcType || srcType.getNumElements() == 1)
return failure();

// arith.constant should only have vector or tenor types.
assert((isa<VectorType, RankedTensorType>(srcType)));
// arith.constant should only have vector or tensor types. This is a MLIR
// wide problem at the moment.
if (!isa<VectorType, RankedTensorType>(srcType))
return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");

Type dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
return failure();

auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
if (!dstElementsAttr)
return failure();
// Import the resource into the IR to make use of the special handling of
// element types later on.
mlir::DenseElementsAttr dstElementsAttr;
if (auto denseElementsAttr =
dyn_cast<DenseElementsAttr>(constOp.getValue())) {
dstElementsAttr = denseElementsAttr;
} else if (auto resourceAttr =
dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {

AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
if (!blob)
return constOp->emitError("could not find resource blob");

ArrayRef<char> ptr = blob->getData();

// Check that the buffer meets the requirements to get converted to a
// DenseElementsAttr
bool detectedSplat = false;
if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
return constOp->emitError("resource is not a valid buffer");

dstElementsAttr =
DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
} else {
return constOp->emitError("unsupported elements attribute");
}

ShapedType dstAttrType = dstElementsAttr.getType();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: mlir-opt -split-input-file -convert-arith-to-spirv -verify-diagnostics %s | FileCheck %s


//===----------------------------------------------------------------------===//
// arith.constant dense_resource
//
// The decoding of dense_resource differs between little and big endian
// machines. At the moment only litte endian is supported.
// See https://github.com/llvm/llvm-project/issues/63469 for more infos.
//
//===----------------------------------------------------------------------===//

// XFAIL: target=s390x-{{.*}}

module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, #spirv.resource_limits<>>
} {
func.func @constant_dense_resource() {
// CHECK: %{{.*}} = spirv.Constant dense<[0.203224242, -0.254296064, -0.365104556, -0.469196141, 0.466041982]> : tensor<5xf32> : !spirv.array<5 x f32>
%0 = arith.constant dense_resource<dense_resource_test_5xf32> : tensor<5xf32>
// CHECK: %{{.*}} = spirv.Constant dense<[1, 2]> : vector<2xi32>
%1 = arith.constant dense_resource<dense_resource_test_2xi32> : vector<2xi32>
// CHECK: %{{.*}} = spirv.Constant dense<[0.35476172, 0.351080596, -0.0795008316, 0.366843373]> : tensor<4xf32> : !spirv.array<4 x f32>
%2 = arith.constant dense_resource<dense_resource_test_2x2xf32> : tensor<1x2x2xf32>
return
}
}
// Resources are kept at end of file. New tests should be added above this.
{-#
dialect_resources: {
builtin: {
dense_resource_test_2xi32: "0x400000000100000002000000",
dense_resource_test_5xf32: "0x08000000041A503E183382BEFCEEBABE7A3AF0BE0E9DEE3E",
dense_resource_test_2x2xf32: "0x0800000054A3B53ED6C0B33E55D1A2BDE5D2BB3E"
}
}
#-}
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,34 @@ func.func @unsupported_constant_tensor_2xf64_0() {
return
}

// -----

func.func @constant_dense_resource_non_existant() {
// expected-error @+2 {{failed to legalize operation 'arith.constant'}}
// expected-error @+1 {{could not find resource blob}}
%0 = arith.constant dense_resource<non_existant> : tensor<5xf32>
return
}

// -----

module {
func.func @constant_dense_resource_invalid_buffer() {
// expected-error @+2 {{failed to legalize operation 'arith.constant'}}
// expected-error @+1 {{resource is not a valid buffer}}
%0 = arith.constant dense_resource<dense_resource_test_2xi32> : vector<2xi32>
return
}
}
// This is a buffer of wrong type and shape
{-#
dialect_resources: {
builtin: {
dense_resource_test_2xi32: "0x0800000054A3B53ED6C0B33E55D1A2BDE5D2BB3E"
}
}
#-}

///===----------------------------------------------------------------------===//
// Type emulation
//===----------------------------------------------------------------------===//
Expand Down
Loading