Skip to content

Commit 9261ab7

Browse files
authored
[mlir][Target] Teach dense_resource conversion to LLVMIR Target (#78958)
This patch adds support for translating dense_resource attributes to LLVMIR Target. The support added is similar to how DenseElementsAttr is handled, except we don't need to handle splats. Another possible way of doing this is adding iteration on dense_resource, but that is non-trivial as DenseResourceAttr is not meant to be something you should directly access. It has subclasses which you are supposed to use to iterate on it.
1 parent 2b8649f commit 9261ab7

File tree

3 files changed

+174
-0
lines changed

3 files changed

+174
-0
lines changed

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/IR/Attributes.h"
2828
#include "mlir/IR/BuiltinOps.h"
2929
#include "mlir/IR/BuiltinTypes.h"
30+
#include "mlir/IR/DialectResourceBlobManager.h"
3031
#include "mlir/IR/RegionGraphTraits.h"
3132
#include "mlir/Support/LLVM.h"
3233
#include "mlir/Support/LogicalResult.h"
@@ -446,6 +447,99 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
446447
return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
447448
}
448449

450+
/// Convert a dense resource elements attribute to an LLVM IR constant using its
451+
/// raw data storage if possible. This supports elements attributes of tensor or
452+
/// vector type and avoids constructing separate objects for individual values
453+
/// of the innermost dimension. Constants for other dimensions are still
454+
/// constructed recursively. Returns nullptr on failure and emits errors at
455+
/// `loc`.
456+
static llvm::Constant *convertDenseResourceElementsAttr(
457+
Location loc, DenseResourceElementsAttr denseResourceAttr,
458+
llvm::Type *llvmType, const ModuleTranslation &moduleTranslation) {
459+
assert(denseResourceAttr && "expected non-null attribute");
460+
461+
llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
462+
if (!llvm::ConstantDataSequential::isElementTypeCompatible(
463+
innermostLLVMType)) {
464+
emitError(loc, "no known conversion for innermost element type");
465+
return nullptr;
466+
}
467+
468+
ShapedType type = denseResourceAttr.getType();
469+
assert(type.getNumElements() > 0 && "Expected non-empty elements attribute");
470+
471+
AsmResourceBlob *blob = denseResourceAttr.getRawHandle().getBlob();
472+
if (!blob) {
473+
emitError(loc, "resource does not exist");
474+
return nullptr;
475+
}
476+
477+
ArrayRef<char> rawData = blob->getData();
478+
479+
// Check that the raw data size matches what is expected for the scalar size.
480+
// TODO: in theory, we could repack the data here to keep constructing from
481+
// raw data.
482+
// TODO: we may also need to consider endianness when cross-compiling to an
483+
// architecture where it is different.
484+
int64_t numElements = denseResourceAttr.getType().getNumElements();
485+
int64_t elementByteSize = rawData.size() / numElements;
486+
if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) {
487+
emitError(loc, "raw data size does not match element type size");
488+
return nullptr;
489+
}
490+
491+
// Compute the shape of all dimensions but the innermost. Note that the
492+
// innermost dimension may be that of the vector element type.
493+
bool hasVectorElementType = isa<VectorType>(type.getElementType());
494+
int64_t numAggregates =
495+
numElements / (hasVectorElementType
496+
? 1
497+
: denseResourceAttr.getType().getShape().back());
498+
ArrayRef<int64_t> outerShape = type.getShape();
499+
if (!hasVectorElementType)
500+
outerShape = outerShape.drop_back();
501+
502+
// Create a constructor for the innermost constant from a piece of raw data.
503+
std::function<llvm::Constant *(StringRef)> buildCstData;
504+
if (isa<TensorType>(type)) {
505+
auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
506+
if (vectorElementType && vectorElementType.getRank() == 1) {
507+
buildCstData = [&](StringRef data) {
508+
return llvm::ConstantDataVector::getRaw(
509+
data, vectorElementType.getShape().back(), innermostLLVMType);
510+
};
511+
} else if (!vectorElementType) {
512+
buildCstData = [&](StringRef data) {
513+
return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
514+
innermostLLVMType);
515+
};
516+
}
517+
} else if (isa<VectorType>(type)) {
518+
buildCstData = [&](StringRef data) {
519+
return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
520+
innermostLLVMType);
521+
};
522+
}
523+
if (!buildCstData) {
524+
emitError(loc, "unsupported dense_resource type");
525+
return nullptr;
526+
}
527+
528+
// Create innermost constants and defer to the default constant creation
529+
// mechanism for other dimensions.
530+
SmallVector<llvm::Constant *> constants;
531+
int64_t aggregateSize = denseResourceAttr.getType().getShape().back() *
532+
(innermostLLVMType->getScalarSizeInBits() / 8);
533+
constants.reserve(numAggregates);
534+
for (unsigned i = 0; i < numAggregates; ++i) {
535+
StringRef data(rawData.data() + i * aggregateSize, aggregateSize);
536+
constants.push_back(buildCstData(data));
537+
}
538+
539+
ArrayRef<llvm::Constant *> constantsRef = constants;
540+
return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
541+
}
542+
449543
/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
450544
/// This currently supports integer, floating point, splat and dense element
451545
/// attributes and combinations thereof. Also, an array attribute with two
@@ -546,6 +640,11 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
546640
return result;
547641
}
548642

643+
if (auto denseResourceAttr = dyn_cast<DenseResourceElementsAttr>(attr)) {
644+
return convertDenseResourceElementsAttr(loc, denseResourceAttr, llvmType,
645+
moduleTranslation);
646+
}
647+
549648
// Fall back to element-by-element construction otherwise.
550649
if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
551650
assert(elementsAttr.getShapedType().hasStaticShape());

mlir/test/Target/LLVMIR/llvmir-invalid.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,55 @@ llvm.func @foo() {
313313
// expected-error @below{{must appear at the module level}}
314314
llvm.linker_options ["test"]
315315
}
316+
317+
// -----
318+
319+
module @does_not_exist {
320+
// expected-error @below{{resource does not exist}}
321+
llvm.mlir.global internal constant @constant(dense_resource<test0> : tensor<4xf32>) : !llvm.array<4 x f32>
322+
}
323+
324+
// -----
325+
326+
module @raw_data_does_not_match_element_type_size {
327+
// expected-error @below{{raw data size does not match element type size}}
328+
llvm.mlir.global internal constant @constant(dense_resource<test1> : tensor<5xf32>) : !llvm.array<4 x f32>
329+
}
330+
331+
{-#
332+
dialect_resources: {
333+
builtin: {
334+
test1: "0x0800000054A3B53ED6C0B33E55D1A2BDE5D2BB3E"
335+
}
336+
}
337+
#-}
338+
339+
// -----
340+
341+
module @does_not_exist {
342+
// expected-error @below{{unsupported dense_resource type}}
343+
llvm.mlir.global internal constant @constant(dense_resource<test1> : memref<4xf32>) : !llvm.array<4 x f32>
344+
}
345+
346+
{-#
347+
dialect_resources: {
348+
builtin: {
349+
test1: "0x0800000054A3B53ED6C0B33E55D1A2BDE5D2BB3E"
350+
}
351+
}
352+
#-}
353+
354+
// -----
355+
356+
module @no_known_conversion_innermost_eltype {
357+
// expected-error @below{{no known conversion for innermost element type}}
358+
llvm.mlir.global internal constant @constant(dense_resource<test0> : tensor<4xi4>) : !llvm.array<4 x i4>
359+
}
360+
361+
{-#
362+
dialect_resources: {
363+
builtin: {
364+
test1: "0x0800000054A3B53ED6C0B33E55D1A2BDE5D2BB3E"
365+
}
366+
}
367+
#-}

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,19 @@ llvm.mlir.global internal @dense_float_vector_3d(dense<[[[1.0, 2.0], [3.0, 4.0]]
101101
// CHECK{LITERAL}: @splat_float_vector_3d = internal global [2 x [2 x <2 x float>]] [[2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>], [2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>]]
102102
llvm.mlir.global internal @splat_float_vector_3d(dense<42.0> : vector<2x2x2xf32>) : !llvm.array<2 x !llvm.array<2 x vector<2xf32>>>
103103

104+
// CHECK{LITERAL}: @dense_resource_tensor_constant = internal constant [5 x float] [float 0x3FCA034080000000, float 0xBFD0466300000000, float 0xBFD75DDF80000000, float 0xBFDE074F40000000, float 0x3FDDD3A1C0000000]
105+
llvm.mlir.global internal constant @dense_resource_tensor_constant(dense_resource<dense_resource_test_5xf32> : tensor<5xf32>) : !llvm.array<5 x f32>
106+
107+
// CHECK{LITERAL}: @dense_resource_vector_constant = internal constant <5 x float> <float 0x3FCA034080000000, float 0xBFD0466300000000, float 0xBFD75DDF80000000, float 0xBFDE074F40000000, float 0x3FDDD3A1C0000000>
108+
llvm.mlir.global internal constant @dense_resource_vector_constant(dense_resource<dense_resource_test_5xf32> : vector<5xf32>) : vector<5xf32>
109+
110+
111+
// CHECK{LITERAL}: @dense_resource_multidim_tensor_constant = internal constant [1 x [2 x [2 x float]]] [[2 x [2 x float]] [[2 x float] [float 0x3FD6B46A80000000, float 0x3FD6781AC0000000], [2 x float] [float 0xBFB45A2AA0000000, float 0x3FD77A5CA0000000]]]
112+
llvm.mlir.global internal constant @dense_resource_multidim_tensor_constant(dense_resource<dense_resource_test_2x2xf32> : tensor<1x2x2xf32>) : !llvm.array<1 x !llvm.array<2 x !llvm.array<2 x f32>>>
113+
114+
// CHECK{LITERAL}: @dense_resource_multidim_vector_constant = internal constant [1 x [2 x <2 x float>]] [[2 x <2 x float>] [<2 x float> <float 0x3FD6B46A80000000, float 0x3FD6781AC0000000>, <2 x float> <float 0xBFB45A2AA0000000, float 0x3FD77A5CA0000000>]]
115+
llvm.mlir.global internal constant @dense_resource_multidim_vector_constant(dense_resource<dense_resource_test_2x2xf32> : vector<1x2x2xf32>) : !llvm.array<1 x !llvm.array<2 x vector<2 x f32>>>
116+
104117
//
105118
// Linkage attribute.
106119
//
@@ -1577,6 +1590,16 @@ llvm.func @invokeLandingpad() -> i32 attributes { personality = @__gxx_personali
15771590
llvm.invoke %9(%6, %0) to ^bb2 unwind ^bb1 vararg(!llvm.func<void (ptr, ...)>) : !llvm.ptr, (!llvm.ptr, i32) -> ()
15781591
}
15791592

1593+
// Resources are kept at end of file. New tests should be added above this.
1594+
{-#
1595+
dialect_resources: {
1596+
builtin: {
1597+
dense_resource_test_5xf32: "0x08000000041A503E183382BEFCEEBABE7A3AF0BE0E9DEE3E",
1598+
dense_resource_test_2x2xf32: "0x0800000054A3B53ED6C0B33E55D1A2BDE5D2BB3E"
1599+
}
1600+
}
1601+
#-}
1602+
15801603
// -----
15811604

15821605
llvm.func @foo() -> i8

0 commit comments

Comments
 (0)