Skip to content

Commit 3df9219

Browse files
authored
[mlir][tosa] Support DenseResourceElementsAttr in TOSA transpose folders (#124532)
Handle dense resource attributes in the transpose TOSA folder. Currently their interface does not align with the rest of the `ElementsAttr` when it comes to data accessing hence the special handling. Signed-off-by: Georgios Pinitas <[email protected]>
1 parent 5a668bd commit 3df9219

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Utils/IndexingUtils.h"
1919
#include "mlir/IR/BuiltinAttributes.h"
2020
#include "mlir/IR/BuiltinTypes.h"
21+
#include "mlir/IR/DialectResourceBlobManager.h"
2122
#include "mlir/IR/Matchers.h"
2223
#include "mlir/Pass/Pass.h"
2324
#include "llvm/ADT/APFloat.h"
@@ -176,13 +177,36 @@ DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
176177
llvm::ArrayRef<ElementType>(outputValues));
177178
}
178179

180+
// Try to get the values of a DenseResourceElementsAttr construct
181+
template <typename T>
182+
std::optional<ArrayRef<T>> tryGetDenseResourceValues(ElementsAttr attr) {
183+
if (auto denseResource = dyn_cast<DenseResourceElementsAttr>(attr)) {
184+
// Check that the resource memory blob exists
185+
AsmResourceBlob *blob = denseResource.getRawHandle().getBlob();
186+
if (!blob)
187+
return std::nullopt;
188+
189+
// Check that the data are in a valid form
190+
bool isSplat = false;
191+
if (!DenseElementsAttr::isValidRawBuffer(attr.getShapedType(),
192+
blob->getData(), isSplat)) {
193+
return std::nullopt;
194+
}
195+
196+
return blob->template getDataAs<T>();
197+
}
198+
199+
return std::nullopt;
200+
}
201+
179202
// A type specialized transposition of an ElementsAttr.
180203
// This implementation tries to operate on the underlying data in its raw
181204
// representation when possible to avoid allocating a large number of Attribute
182205
// objects.
183206
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
184207
ShapedType outputType,
185208
llvm::ArrayRef<int64_t> permValues) {
209+
// Handle generic ElementsAttr
186210
if (auto data = attr.tryGetValues<bool>())
187211
return transposeType(*data, inputType, outputType, permValues);
188212

@@ -204,6 +228,35 @@ DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
204228
if (auto data = attr.tryGetValues<APFloat>())
205229
return transposeType(*data, inputType, outputType, permValues);
206230

231+
// Handle DenseResourceElementsAttr
232+
if (isa<DenseResourceElementsAttr>(attr)) {
233+
auto elementTy = attr.getElementType();
234+
235+
if (auto data = tryGetDenseResourceValues<bool>(attr);
236+
data && elementTy.isInteger(1))
237+
return transposeType(*data, inputType, outputType, permValues);
238+
239+
if (auto data = tryGetDenseResourceValues<int8_t>(attr);
240+
data && elementTy.isInteger(8))
241+
return transposeType(*data, inputType, outputType, permValues);
242+
243+
if (auto data = tryGetDenseResourceValues<int16_t>(attr);
244+
data && elementTy.isInteger(16))
245+
return transposeType(*data, inputType, outputType, permValues);
246+
247+
if (auto data = tryGetDenseResourceValues<int32_t>(attr);
248+
data && elementTy.isInteger(32))
249+
return transposeType(*data, inputType, outputType, permValues);
250+
251+
if (auto data = tryGetDenseResourceValues<int64_t>(attr);
252+
data && elementTy.isInteger(64))
253+
return transposeType(*data, inputType, outputType, permValues);
254+
255+
if (auto data = tryGetDenseResourceValues<float>(attr);
256+
data && elementTy.isF32())
257+
return transposeType(*data, inputType, outputType, permValues);
258+
}
259+
207260
return nullptr;
208261
}
209262

mlir/test/Dialect/Tosa/constant-op-fold.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,18 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i
108108
return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
109109
}
110110

111-
// CHECK-LABEL: @transpose_nofold_dense_resource
112-
func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
111+
// CHECK-LABEL: @transpose_fold_dense_resource
112+
func.func @transpose_fold_dense_resource() -> tensor<2x2xf32> {
113113
%0 = "tosa.const"() <{values = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
114114

115-
// CHECK: tosa.transpose
115+
// CHECK-NOT: tosa.transpose
116116
%2 = tosa.transpose %0 { perms = array<i32: 1, 0> }: (tensor<2x2xf32>) -> tensor<2x2xf32>
117117
return %2 : tensor<2x2xf32>
118118
}
119119
{-#
120120
dialect_resources: {
121121
builtin: {
122-
resource: "0x08000000010000000000000002000000000000000300000000000000"
122+
resource: "0x040000003f800000400000004040000040800000"
123123
}
124124
}
125125
#-}

0 commit comments

Comments
 (0)