Skip to content

Commit d7eb941

Browse files
committed
[mlir][tosa] Support DenseResourceElementsAttr in TOSA transpose folder
Handle dense resource attributes in the transpose TOSA folder. Currently their interface does not align with the rest of the `ElementsAttr` when it comes to data accessing hence the special handling. Signed-off-by: Georgios Pinitas <[email protected]>
1 parent 87103a0 commit d7eb941

File tree

2 files changed

+59
-4
lines changed

2 files changed

+59
-4
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Utils/IndexingUtils.h"
1919
#include "mlir/IR/BuiltinAttributes.h"
2020
#include "mlir/IR/BuiltinTypes.h"
21+
#include "mlir/IR/DialectResourceBlobManager.h"
2122
#include "mlir/IR/Matchers.h"
2223
#include "mlir/Pass/Pass.h"
2324
#include "llvm/ADT/APFloat.h"
@@ -176,13 +177,38 @@ DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
176177
llvm::ArrayRef<ElementType>(outputValues));
177178
}
178179

180+
// Function that tries to wrap the DenseResourceElementsAttr data access
181+
// handling as unfortunately at the moment don't share the same interface
182+
// with DenseElementsAttr
183+
template <typename T>
184+
std::optional<ArrayRef<T>> tryGetDenseResourceValues(ElementsAttr attr) {
185+
if (auto denseResource = dyn_cast<DenseResourceElementsAttr>(attr)) {
186+
// Check that the resource memory blob exists
187+
AsmResourceBlob *blob = denseResource.getRawHandle().getBlob();
188+
if (!blob)
189+
return std::nullopt;
190+
191+
// Check that the data are in a valid form
192+
bool isSplat = false;
193+
if (!DenseElementsAttr::isValidRawBuffer(attr.getShapedType(),
194+
blob->getData(), isSplat)) {
195+
return std::nullopt;
196+
}
197+
198+
return blob->template getDataAs<T>();
199+
}
200+
201+
return std::nullopt;
202+
}
203+
179204
// A type specialized transposition of an ElementsAttr.
180205
// This implementation tries to operate on the underlying data in its raw
181206
// representation when possible to avoid allocating a large number of Attribute
182207
// objects.
183208
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
184209
ShapedType outputType,
185210
llvm::ArrayRef<int64_t> permValues) {
211+
// Handle generic ElementsAttr
186212
if (auto data = attr.tryGetValues<bool>())
187213
return transposeType(*data, inputType, outputType, permValues);
188214

@@ -204,6 +230,35 @@ DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
204230
if (auto data = attr.tryGetValues<APFloat>())
205231
return transposeType(*data, inputType, outputType, permValues);
206232

233+
// Handle DenseResourceElementsAttr
234+
if (isa<DenseResourceElementsAttr>(attr)) {
235+
auto elementTy = attr.getElementType();
236+
237+
if (auto data = tryGetDenseResourceValues<bool>(attr);
238+
data && elementTy.isInteger(1))
239+
return transposeType(*data, inputType, outputType, permValues);
240+
241+
if (auto data = tryGetDenseResourceValues<int8_t>(attr);
242+
data && elementTy.isInteger(8))
243+
return transposeType(*data, inputType, outputType, permValues);
244+
245+
if (auto data = tryGetDenseResourceValues<int16_t>(attr);
246+
data && elementTy.isInteger(16))
247+
return transposeType(*data, inputType, outputType, permValues);
248+
249+
if (auto data = tryGetDenseResourceValues<int32_t>(attr);
250+
data && elementTy.isInteger(32))
251+
return transposeType(*data, inputType, outputType, permValues);
252+
253+
if (auto data = tryGetDenseResourceValues<int64_t>(attr);
254+
data && elementTy.isInteger(64))
255+
return transposeType(*data, inputType, outputType, permValues);
256+
257+
if (auto data = tryGetDenseResourceValues<float>(attr);
258+
data && elementTy.isF32())
259+
return transposeType(*data, inputType, outputType, permValues);
260+
}
261+
207262
return nullptr;
208263
}
209264

mlir/test/Dialect/Tosa/constant-op-fold.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,19 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>)
117117
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
118118
}
119119

120-
// CHECK-LABEL: @transpose_nofold_dense_resource
121-
func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
120+
// CHECK-LABEL: @transpose_fold_dense_resource
121+
func.func @transpose_fold_dense_resource() -> tensor<2x2xf32> {
122122
%0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
123123
%1 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
124124

125-
// CHECK: tosa.transpose
125+
// CHECK-NOT: tosa.transpose
126126
%2 = tosa.transpose %0, %1 : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xf32>
127127
return %2 : tensor<2x2xf32>
128128
}
129129
{-#
130130
dialect_resources: {
131131
builtin: {
132-
resource: "0x08000000010000000000000002000000000000000300000000000000"
132+
resource: "0x040000003f800000400000004040000040800000"
133133
}
134134
}
135135
#-}

0 commit comments

Comments
 (0)