Skip to content

[tosa] Change VariableOp to align with spec #142240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,16 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
input, paddings);
}]>;

// This builder is called on the TOSA variable operator with a variable type
// and optional initial value. The builder will extract var_shape and element type
// attributes from variable type.
def Tosa_VariableOpBuilder : OpBuilder<
(ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value),
[{
buildVariableOp($_builder, $_state, name, variable_type, initial_value);
}]>;


// Wrapper over base I32EnumAttr to set common fields.
class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
: I32EnumAttr<name, description, cases> {
Expand Down
15 changes: 11 additions & 4 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,14 @@ class PatternRewriter;

namespace tosa {

ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
Attribute &attr);
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
Attribute attr);
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser,
DenseElementsAttr &varShapeAttr,
TypeAttr &typeAttr,
Attribute &initialValueAttr);
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op,
DenseElementsAttr varShapeAttr,
TypeAttr typeAttr,
Attribute initialValueAttr);

#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"

Expand Down Expand Up @@ -172,6 +176,9 @@ std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src,
int32_t val = 0);

// returns type of variable op
RankedTensorType getVariableType(VariableOp variableOp);

} // namespace tosa
} // namespace mlir

Expand Down
7 changes: 6 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {

let arguments = (ins
SymbolNameAttr:$name,
IndexElementsAttr:$var_shape,
TypeAttr:$type,
OptionalAttr<AnyAttr>:$initial_value
);
Expand All @@ -101,12 +102,16 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
Extension<[Tosa_EXT_VARIABLE]>,
];

let hasCustomAssemblyFormat = 1;

let assemblyFormat = [{
$name
attr-dict
custom<TypeOrAttr>($type, $initial_value)
custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
}];

let builders = [Tosa_VariableOpBuilder];

let hasVerifier = 1;
}

Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {

LogicalResult matchAndRewrite(tosa::VariableOp op,
PatternRewriter &rewriter) const final {
auto variableType = tosa::getVariableType(op);
auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
op.getLoc(), op.getName(), variableType, /*is_mutable=*/true,
op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
newVariable.setPrivate();
rewriter.replaceOp(op, newVariable);
Expand Down
132 changes: 105 additions & 27 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,24 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
return {&getBodyGraph()};
}

//===----------------------------------------------------------------------===//
// TOSA variable operator support.
//===----------------------------------------------------------------------===//

static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
return to_vector(llvm::map_range(shape, [](int64_t dim) {
return dim == -1 ? ShapedType::kDynamic : dim;
}));
}

// returns type of variable op
RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
Type elementType = variableOp.getType();
DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
return RankedTensorType::get(shape, elementType);
}

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -177,42 +195,80 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
// Parsers and printers
//===----------------------------------------------------------------------===//

ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
Attribute &attr) {
namespace {

ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
DenseElementsAttr &varShapeAttr,
TypeAttr &typeAttr) {
if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
if (!shapedType.hasRank())
return parser.emitError(parser.getCurrentLocation())
<< "expected ranked type";

auto elementType = shapedType.getElementType();
typeAttr = TypeAttr::get(elementType);
ArrayRef<int64_t> shape = shapedType.getShape();
Builder builder(parser.getContext());
varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
return success();
}
return parser.emitError(parser.getCurrentLocation())
<< "expected shaped type";
}

} // namespace

// parses the optional initial value or type for a tosa variable
// with initial value:
// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
//
// without initial value:
// tosa.variable @name : tensor<1x8xf32>
ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
Attribute &initialValueAttr) {
if (succeeded(parser.parseOptionalEqual())) {
if (failed(parser.parseAttribute(attr))) {
if (failed(parser.parseAttribute(initialValueAttr))) {
return parser.emitError(parser.getCurrentLocation())
<< "expected attribute";
}
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
typeAttr = TypeAttr::get(typedAttr.getType());
if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
typeAttr);
}
return success();
return parser.emitError(parser.getCurrentLocation())
<< "expected Typed attr";
}

Type type;
if (failed(parser.parseColonType(type))) {
return parser.emitError(parser.getCurrentLocation()) << "expected type";
initialValueAttr = nullptr;
Type parsedType;
if (failed(parser.parseColonType(parsedType))) {
return parser.emitError(parser.getCurrentLocation())
<< "expected type after colon";
}
typeAttr = TypeAttr::get(type);

return success();
return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
}

void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
Attribute attr) {
void mlir::tosa::printVariableOpTypeOrInitialValue(
OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
TypeAttr typeAttr, Attribute initialValueAttr) {
bool needsSpace = false;
auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
if (!typedAttr || typedAttr.getType() != type.getValue()) {
if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
auto shape =
convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
Type elementType = typeAttr.getValue();
RankedTensorType tensorType =
RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
auto tensorTypeAttr = TypeAttr::get(tensorType);
p << ": ";
p.printAttribute(type);
p.printAttribute(tensorTypeAttr);
needsSpace = true; // subsequent attr value needs a space separator
}
if (attr) {
if (initialValueAttr) {
if (needsSpace)
p << ' ';
p << "= ";
p.printAttribute(attr);
p.printAttribute(initialValueAttr);
}
}

Expand Down Expand Up @@ -657,8 +713,9 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
<< symName << "' has not been declared by 'tosa.variable'";

// Verify type and shape
Type varType = cast<tosa::VariableOp>(varOp.value()).getType();
if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
auto variableType = getVariableType(varOp.value());
if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
"the input tensor")
.failed())
return failure();

Expand Down Expand Up @@ -1103,6 +1160,33 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
result.types.push_back(outputType);
}

static void buildVariableOp(OpBuilder &builder, OperationState &result,
StringRef name, Type variableType,
Attribute initialValue) {
const Location loc{result.location};
auto nameAttr = builder.getStringAttr(name);

auto shapedType = dyn_cast<ShapedType>(variableType);
if (!shapedType) {
(void)emitError(loc, "variable type must be a shaped type");
return;
}
if (!shapedType.hasRank()) {
(void)emitError(loc, "variable type must be a ranked type");
return;
}

auto elementType = shapedType.getElementType();
auto elementTypeAttr = TypeAttr::get(elementType);
ArrayRef<int64_t> shape = shapedType.getShape();
auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));

result.addAttribute("name", nameAttr);
result.addAttribute("var_shape", varShapeAttr);
result.addAttribute("type", elementTypeAttr);
result.addAttribute("initial_value", initialValue);
}

//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1676,12 +1760,6 @@ LogicalResult tosa::PadOp::verify() {
return success();
}

static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
return to_vector(llvm::map_range(shape, [](int64_t dim) {
return dim == -1 ? ShapedType::kDynamic : dim;
}));
}

LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
SliceOp::Adaptor adaptor,
Expand Down
11 changes: 2 additions & 9 deletions mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,15 +215,8 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {

template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
::mlir::Attribute attr = op.getInitialValueAttr();
if (attr == nullptr)
return failure();

if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
addType(getElementTypeOrSelf(typedAttr));
return success();
}
return failure();
addType(op.getType());
return success();
}

template <>
Expand Down
Loading
Loading