Skip to content

Commit 9ba0a77

Browse files
authored
[mlir][spirv] Add support for dense_resource in arith to spirv (#91318)
This adds support for `dense_resource` in arith to spirv. Note that this inlines the blob into the IR. Another possibility would be to add proper dense_resource support to spirv, but there is a lot of special handling going on to convert a `DenseElementsAttr` to the correct SPIRV type. Some of that even iterates over all the values in the Attribute. For proper support of a `DenseResourceElementsAttr` this probably needs a redesign. I would like to hear some opinions on that! The test is disabled on non little Endian machines. See #63469 for more information.
1 parent 30d0850 commit 9ba0a77

File tree

3 files changed

+97
-5
lines changed

3 files changed

+97
-5
lines changed

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
1818
#include "mlir/IR/BuiltinAttributes.h"
1919
#include "mlir/IR/BuiltinTypes.h"
20+
#include "mlir/IR/DialectResourceBlobManager.h"
2021
#include "llvm/ADT/APInt.h"
2122
#include "llvm/ADT/ArrayRef.h"
2223
#include "llvm/ADT/STLExtras.h"
@@ -229,16 +230,41 @@ struct ConstantCompositeOpPattern final
229230
if (!srcType || srcType.getNumElements() == 1)
230231
return failure();
231232

232-
// arith.constant should only have vector or tenor types.
233-
assert((isa<VectorType, RankedTensorType>(srcType)));
233+
// arith.constant should only have vector or tensor types. This is a MLIR
234+
// wide problem at the moment.
235+
if (!isa<VectorType, RankedTensorType>(srcType))
236+
return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
234237

235238
Type dstType = getTypeConverter()->convertType(srcType);
236239
if (!dstType)
237240
return failure();
238241

239-
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
240-
if (!dstElementsAttr)
241-
return failure();
242+
// Import the resource into the IR to make use of the special handling of
243+
// element types later on.
244+
mlir::DenseElementsAttr dstElementsAttr;
245+
if (auto denseElementsAttr =
246+
dyn_cast<DenseElementsAttr>(constOp.getValue())) {
247+
dstElementsAttr = denseElementsAttr;
248+
} else if (auto resourceAttr =
249+
dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
250+
251+
AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
252+
if (!blob)
253+
return constOp->emitError("could not find resource blob");
254+
255+
ArrayRef<char> ptr = blob->getData();
256+
257+
// Check that the buffer meets the requirements to get converted to a
258+
// DenseElementsAttr
259+
bool detectedSplat = false;
260+
if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
261+
return constOp->emitError("resource is not a valid buffer");
262+
263+
dstElementsAttr =
264+
DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
265+
} else {
266+
return constOp->emitError("unsupported elements attribute");
267+
}
242268

243269
ShapedType dstAttrType = dstElementsAttr.getType();
244270

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: mlir-opt -split-input-file -convert-arith-to-spirv -verify-diagnostics %s | FileCheck %s
2+
3+
4+
//===----------------------------------------------------------------------===//
5+
// arith.constant dense_resource
6+
//
7+
// The decoding of dense_resource differs between little and big endian
8+
// machines. At the moment only litte endian is supported.
9+
// See https://github.com/llvm/llvm-project/issues/63469 for more infos.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
// XFAIL: target=s390x-{{.*}}
14+
15+
module attributes {
16+
spirv.target_env = #spirv.target_env<
17+
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, #spirv.resource_limits<>>
18+
} {
19+
func.func @constant_dense_resource() {
20+
// CHECK: %{{.*}} = spirv.Constant dense<[0.203224242, -0.254296064, -0.365104556, -0.469196141, 0.466041982]> : tensor<5xf32> : !spirv.array<5 x f32>
21+
%0 = arith.constant dense_resource<dense_resource_test_5xf32> : tensor<5xf32>
22+
// CHECK: %{{.*}} = spirv.Constant dense<[1, 2]> : vector<2xi32>
23+
%1 = arith.constant dense_resource<dense_resource_test_2xi32> : vector<2xi32>
24+
// CHECK: %{{.*}} = spirv.Constant dense<[0.35476172, 0.351080596, -0.0795008316, 0.366843373]> : tensor<4xf32> : !spirv.array<4 x f32>
25+
%2 = arith.constant dense_resource<dense_resource_test_2x2xf32> : tensor<1x2x2xf32>
26+
return
27+
}
28+
}
29+
// Resources are kept at end of file. New tests should be added above this.
30+
{-#
31+
dialect_resources: {
32+
builtin: {
33+
dense_resource_test_2xi32: "0x400000000100000002000000",
34+
dense_resource_test_5xf32: "0x08000000041A503E183382BEFCEEBABE7A3AF0BE0E9DEE3E",
35+
dense_resource_test_2x2xf32: "0x0800000054A3B53ED6C0B33E55D1A2BDE5D2BB3E"
36+
}
37+
}
38+
#-}

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,34 @@ func.func @unsupported_constant_tensor_2xf64_0() {
9595
return
9696
}
9797

98+
// -----
99+
100+
func.func @constant_dense_resource_non_existant() {
101+
// expected-error @+2 {{failed to legalize operation 'arith.constant'}}
102+
// expected-error @+1 {{could not find resource blob}}
103+
%0 = arith.constant dense_resource<non_existant> : tensor<5xf32>
104+
return
105+
}
106+
107+
// -----
108+
109+
module {
110+
func.func @constant_dense_resource_invalid_buffer() {
111+
// expected-error @+2 {{failed to legalize operation 'arith.constant'}}
112+
// expected-error @+1 {{resource is not a valid buffer}}
113+
%0 = arith.constant dense_resource<dense_resource_test_2xi32> : vector<2xi32>
114+
return
115+
}
116+
}
117+
// This is a buffer of wrong type and shape
118+
{-#
119+
dialect_resources: {
120+
builtin: {
121+
dense_resource_test_2xi32: "0x0800000054A3B53ED6C0B33E55D1A2BDE5D2BB3E"
122+
}
123+
}
124+
#-}
125+
98126
///===----------------------------------------------------------------------===//
99127
// Type emulation
100128
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)