Skip to content

Commit e493abc

Browse files
author
KareemErgawy
committed
[MLIR][SPIRV] Use getAsmResultName(...) hook for ConstantOp.
Implements better naming for results of `spv.Constant` ops by making it inherit from OpAsmOpInterface and implementing the associated getAsmResultName(...) hook. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D103152
1 parent ffc4d3e commit e493abc

File tree

5 files changed

+74
-2
lines changed

5 files changed

+74
-2
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
1717
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1818
#include "mlir/IR/BuiltinOps.h"
19+
#include "mlir/IR/OpImplementation.h"
1920
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2021
#include "mlir/Interfaces/SideEffectInterfaces.h"
2122
#include "llvm/Support/PointerLikeTypeTraits.h"

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define MLIR_DIALECT_SPIRV_IR_STRUCTURE_OPS
1717

1818
include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
19+
include "mlir/IR/OpAsmInterface.td"
1920
include "mlir/IR/SymbolInterfaces.td"
2021
include "mlir/Interfaces/CallInterfaces.td"
2122
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -67,7 +68,8 @@ def SPV_AddressOfOp : SPV_Op<"mlir.addressof", [InFunctionScope, NoSideEffect]>
6768

6869
// -----
6970

70-
def SPV_ConstantOp : SPV_Op<"Constant", [ConstantLike, NoSideEffect]> {
71+
def SPV_ConstantOp : SPV_Op<"Constant",
72+
[ConstantLike, DeclareOpInterfaceMethods<OpAsmOpInterface>, NoSideEffect]> {
7173
let summary = "The op that declares a SPIR-V normal constant";
7274

7375
let description = [{

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,6 +1650,46 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
16501650
llvm_unreachable("unimplemented types for ConstantOp::getOne()");
16511651
}
16521652

1653+
void mlir::spirv::ConstantOp::getAsmResultNames(
1654+
llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
1655+
Type type = getType();
1656+
1657+
SmallString<32> specialNameBuffer;
1658+
llvm::raw_svector_ostream specialName(specialNameBuffer);
1659+
specialName << "cst";
1660+
1661+
IntegerType intTy = type.dyn_cast<IntegerType>();
1662+
1663+
if (IntegerAttr intCst = value().dyn_cast<IntegerAttr>()) {
1664+
if (intTy && intTy.getWidth() == 1) {
1665+
return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1666+
}
1667+
1668+
if (intTy.isSignless()) {
1669+
specialName << intCst.getInt();
1670+
} else {
1671+
specialName << intCst.getSInt();
1672+
}
1673+
}
1674+
1675+
if (intTy || type.isa<FloatType>()) {
1676+
specialName << '_' << type;
1677+
}
1678+
1679+
if (auto vecType = type.dyn_cast<VectorType>()) {
1680+
specialName << "_vec_";
1681+
specialName << vecType.getDimSize(0);
1682+
1683+
Type elementType = vecType.getElementType();
1684+
1685+
if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
1686+
specialName << "x" << elementType;
1687+
}
1688+
}
1689+
1690+
setNameFn(getResult(), specialName.str());
1691+
}
1692+
16531693
//===----------------------------------------------------------------------===//
16541694
// spv.EntryPoint
16551695
//===----------------------------------------------------------------------===//
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: mlir-opt %s | FileCheck %s
2+
3+
func @const() -> () {
4+
// CHECK: %true
5+
%0 = spv.Constant true
6+
// CHECK: %false
7+
%1 = spv.Constant false
8+
9+
// CHECK: %cst42_i32
10+
%2 = spv.Constant 42 : i32
11+
// CHECK: %cst-42_i32
12+
%-2 = spv.Constant -42 : i32
13+
// CHECK: %cst43_i64
14+
%3 = spv.Constant 43 : i64
15+
16+
// CHECK: %cst_f32
17+
%4 = spv.Constant 0.5 : f32
18+
// CHECK: %cst_f64
19+
%5 = spv.Constant 0.5 : f64
20+
21+
// CHECK: %cst_vec_3xi32
22+
%6 = spv.Constant dense<[1, 2, 3]> : vector<3xi32>
23+
24+
// CHECK: %cst
25+
%8 = spv.Constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>>
26+
27+
return
28+
}

mlir/test/Dialect/SPIRV/IR/memory-ops.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,8 +487,9 @@ func @variable(%arg0: f32) -> () {
487487
// -----
488488

489489
func @variable_init_normal_constant() -> () {
490+
// CHECK: %[[cst:.*]] = spv.Constant
490491
%0 = spv.Constant 4.0 : f32
491-
// CHECK: spv.Variable init(%0) : !spv.ptr<f32, Function>
492+
// CHECK: spv.Variable init(%[[cst]]) : !spv.ptr<f32, Function>
492493
%1 = spv.Variable init(%0) : !spv.ptr<f32, Function>
493494
return
494495
}

0 commit comments

Comments
 (0)