Skip to content

Commit eeeef09

Browse files
River707jpienaar
authored andcommitted
Set the namespace of the StandardOps dialect to "std", but add a special case to the parser to allow parsing standard operations without the "std" prefix. This will now allow for the standard dialect to be looked up dynamically by name.
PiperOrigin-RevId: 236493865
1 parent eee8536 commit eeeef09

File tree

18 files changed

+309
-258
lines changed

18 files changed

+309
-258
lines changed

mlir/bindings/python/test/test_py2and3.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ def testCustomOp(self):
3434
with E.ContextManager():
3535
a, b = (E.Expr(E.Bindable(self.i32Type)) for _ in range(2))
3636
c1 = self.module.op(
37-
"constant",
37+
"std.constant",
3838
self.i32Type, [],
3939
value=self.module.integerAttr(self.i32Type, 42))
40-
expr = self.module.op("addi", self.i32Type, [c1, b])
40+
expr = self.module.op("std.addi", self.i32Type, [c1, b])
4141
str = expr.__str__()
4242
self.assertIn("addi(42, $2)", str)
4343

@@ -99,7 +99,7 @@ def testIndexed(self):
9999
A, B, C = list(map(E.Indexed, [E.Bindable(memrefType) for _ in range(3)]))
100100
stmt = C.store([i, j], A.load([i, k]) * B.load([k, j]))
101101
str = stmt.__str__()
102-
self.assertIn(" = store(", str)
102+
self.assertIn(" = std.store(", str)
103103

104104
def testMatmul(self):
105105
with E.ContextManager():
@@ -118,7 +118,7 @@ def testMatmul(self):
118118
self.assertIn("for($1 = $4 to $7 step $10) {", str)
119119
self.assertIn("for($2 = $5 to $8 step $11) {", str)
120120
self.assertIn("for($3 = $6 to $9 step $12) {", str)
121-
self.assertIn(" = store", str)
121+
self.assertIn(" = std.store", str)
122122

123123
def testArithmetic(self):
124124
with E.ContextManager():
@@ -353,10 +353,10 @@ def testCustomOpCompilation(self):
353353
emitter = E.MLIRFunctionEmitter(f)
354354
funcArg, = emitter.bind_function_arguments()
355355
c1 = self.module.op(
356-
"constant",
356+
"std.constant",
357357
self.i32Type, [],
358358
value=self.module.integerAttr(self.i32Type, 42))
359-
expr = self.module.op("addi", self.i32Type, [c1, funcArg])
359+
expr = self.module.op("std.addi", self.i32Type, [c1, funcArg])
360360
block = E.Block([E.Stmt(expr), E.Return()])
361361
emitter.emit_inplace(block)
362362
self.module.compile()

mlir/g3doc/LangRef.md

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,8 +1391,8 @@ Examples:
13911391
%y = dim %A, 1 : tensor<4 x ? x f32>
13921392
13931393
// Equivalent generic form:
1394-
%x = "dim"(%A){index: 0} : (tensor<4 x ? x f32>) -> index
1395-
%y = "dim"(%A){index: 1} : (tensor<4 x ? x f32>) -> index
1394+
%x = "std.dim"(%A){index: 0} : (tensor<4 x ? x f32>) -> index
1395+
%y = "std.dim"(%A){index: 1} : (tensor<4 x ? x f32>) -> index
13961396
```
13971397

13981398
#### 'reshape' operation {#'reshape'-operation}
@@ -1461,7 +1461,7 @@ Example:
14611461
// sized base memref with the slice size being half the base memref's
14621462
// dynamic size and with an offset of %0
14631463
#map_c = (i, j)[s0, s1]->(i + s0, j) size(4, s1)
1464-
%s1 = "divi"(%n1, 2) : (i32, i32) -> i32
1464+
%s1 = "std.divi"(%n1, 2) : (i32, i32) -> i32
14651465
%C = view memref<16x?xf32, #map_a, hbm> -> memref<4x?xf32, #map_c, hbm>
14661466
(%s1) [%0, %n1] %A : memref<16x?xf32, #map_a, hbm>
14671467
```
@@ -1807,13 +1807,13 @@ Examples:
18071807
%x = cmpi "slt", %lhs, %rhs : i32
18081808
18091809
// Generic form of the same operation.
1810-
%x = "cmpi"(%lhs, %rhs){predicate: 2} : (i32, i32) -> i1
1810+
%x = "std.cmpi"(%lhs, %rhs){predicate: 2} : (i32, i32) -> i1
18111811
18121812
// Custom form of vector equality comparison.
18131813
%x = cmpi "eq", %lhs, %rhs : vector<4xi64>
18141814
18151815
// Generic form of the same operation.
1816-
%x = "cmpi"(%lhs, %rhs){predicate: 0}
1816+
%x = "std.cmpi"(%lhs, %rhs){predicate: 0}
18171817
: (vector<4xi64>, vector<4xi64> -> vector<4xi1>
18181818
```
18191819

@@ -1885,8 +1885,8 @@ Examples:
18851885
%3 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32>
18861886
18871887
// Equivalent generic forms
1888-
%1 = "constant"(){value: 42} : i32
1889-
%3 = "constant"(){value: @myfn}
1888+
%1 = "std.constant"(){value: 42} : i32
1889+
%3 = "std.constant"(){value: @myfn}
18901890
: () -> (tensor<16xf32>, f32) -> tensor<16xf32>
18911891
18921892
```
@@ -2089,10 +2089,10 @@ Examples:
20892089
%x = select %cond, %true, %false : i32
20902090
20912091
// Generic form of the same operation.
2092-
%x = "select"(%cond, %true, %false) : (i1, i32, i32) -> i32
2092+
%x = "std.select"(%cond, %true, %false) : (i1, i32, i32) -> i32
20932093
20942094
// Vector selection is element-wise
2095-
%vx = "select"(%vcond, %vtrue, %vfalse)
2095+
%vx = "std.select"(%vcond, %vtrue, %vfalse)
20962096
: (vector<42xi1>, vector<42xf32>, vector<42xf32>) -> vector<42xf32>
20972097
```
20982098

@@ -2120,15 +2120,15 @@ Examples:
21202120

21212121
```mlir {.mlir}
21222122
// Convert from unknown rank to rank 2 with unknown dimension sizes.
2123-
%2 = "tensor_cast"(%1) : (tensor<*xf32>) -> tensor<?x?xf32>
2123+
%2 = "std.tensor_cast"(%1) : (tensor<*xf32>) -> tensor<?x?xf32>
21242124
%2 = tensor_cast %1 : tensor<*xf32> to tensor<?x?xf32>
21252125
21262126
// Convert to a type with more known dimensions.
2127-
%3 = "tensor_cast"(%2) : (tensor<?x?xf32>) -> tensor<4x?xf32>
2127+
%3 = "std.tensor_cast"(%2) : (tensor<?x?xf32>) -> tensor<4x?xf32>
21282128
21292129
// Discard static dimension and rank information.
2130-
%4 = "tensor_cast"(%3) : (tensor<4x?xf32>) -> tensor<?x?xf32>
2131-
%5 = "tensor_cast"(%4) : (tensor<?x?xf32>) -> tensor<*xf32>
2130+
%4 = "std.tensor_cast"(%3) : (tensor<4x?xf32>) -> tensor<?x?xf32>
2131+
%5 = "std.tensor_cast"(%4) : (tensor<?x?xf32>) -> tensor<*xf32>
21322132
```
21332133

21342134
Convert a tensor from one type to an equivalent type without changing any data

mlir/include/mlir/StandardOps/Ops.h

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ namespace mlir {
3232
class AffineMap;
3333
class Builder;
3434

35+
namespace detail {
36+
/// A custom binary operation printer that omits the "std." prefix from the
37+
/// operation names.
38+
void printStandardBinaryOp(const Instruction *op, OpAsmPrinter *p);
39+
} // namespace detail
40+
3541
class StandardOpsDialect : public Dialect {
3642
public:
3743
StandardOpsDialect(MLIRContext *context);
@@ -67,7 +73,7 @@ class AllocOp
6773
return getResult()->getType().cast<MemRefType>();
6874
}
6975

70-
static StringRef getOperationName() { return "alloc"; }
76+
static StringRef getOperationName() { return "std.alloc"; }
7177

7278
// Hooks to customize behavior of this op.
7379
static void build(Builder *builder, OperationState *result,
@@ -96,7 +102,7 @@ class AllocOp
96102
class BranchOp : public Op<BranchOp, OpTrait::VariadicOperands,
97103
OpTrait::ZeroResult, OpTrait::IsTerminator> {
98104
public:
99-
static StringRef getOperationName() { return "br"; }
105+
static StringRef getOperationName() { return "std.br"; }
100106

101107
static void build(Builder *builder, OperationState *result, Block *dest,
102108
ArrayRef<Value *> operands = {});
@@ -129,7 +135,7 @@ class BranchOp : public Op<BranchOp, OpTrait::VariadicOperands,
129135
class CallOp
130136
: public Op<CallOp, OpTrait::VariadicOperands, OpTrait::VariadicResults> {
131137
public:
132-
static StringRef getOperationName() { return "call"; }
138+
static StringRef getOperationName() { return "std.call"; }
133139

134140
static void build(Builder *builder, OperationState *result, Function *callee,
135141
ArrayRef<Value *> operands);
@@ -173,7 +179,7 @@ class CallOp
173179
class CallIndirectOp : public Op<CallIndirectOp, OpTrait::VariadicOperands,
174180
OpTrait::VariadicResults> {
175181
public:
176-
static StringRef getOperationName() { return "call_indirect"; }
182+
static StringRef getOperationName() { return "std.call_indirect"; }
177183

178184
static void build(Builder *builder, OperationState *result, Value *callee,
179185
ArrayRef<Value *> operands);
@@ -245,7 +251,7 @@ enum class CmpIPredicate {
245251
///
246252
/// %r1 = cmpi "eq" %0, %1 : i32
247253
/// %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64>
248-
/// %r3 = "cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1
254+
/// %r3 = "std.cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1
249255
class CmpIOp
250256
: public Op<CmpIOp, OpTrait::OperandsAreIntegerLike,
251257
OpTrait::SameTypeOperands, OpTrait::NOperands<2>::Impl,
@@ -257,7 +263,7 @@ class CmpIOp
257263
.getInt();
258264
}
259265

260-
static StringRef getOperationName() { return "cmpi"; }
266+
static StringRef getOperationName() { return "std.cmpi"; }
261267
static StringRef getPredicateAttrName() { return "predicate"; }
262268
static CmpIPredicate getPredicateByName(StringRef name);
263269

@@ -296,7 +302,7 @@ class CondBranchOp : public Op<CondBranchOp, OpTrait::AtLeastNOperands<1>::Impl,
296302
/// follows:
297303
/// { condition, [true_operands], [false_operands] }
298304
public:
299-
static StringRef getOperationName() { return "cond_br"; }
305+
static StringRef getOperationName() { return "std.cond_br"; }
300306

301307
static void build(Builder *builder, OperationState *result, Value *condition,
302308
Block *trueDest, ArrayRef<Value *> trueOperands,
@@ -416,8 +422,8 @@ class CondBranchOp : public Op<CondBranchOp, OpTrait::AtLeastNOperands<1>::Impl,
416422
/// The "constant" operation requires a single attribute named "value".
417423
/// It returns its value as an SSA value. For example:
418424
///
419-
/// %1 = "constant"(){value: 42} : i32
420-
/// %2 = "constant"(){value: @foo} : (f32)->f32
425+
/// %1 = "std.constant"(){value: 42} : i32
426+
/// %2 = "std.constant"(){value: @foo} : (f32)->f32
421427
///
422428
class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
423429
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
@@ -432,7 +438,7 @@ class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
432438

433439
Attribute getValue() const { return getAttr("value"); }
434440

435-
static StringRef getOperationName() { return "constant"; }
441+
static StringRef getOperationName() { return "std.constant"; }
436442

437443
// Hooks to customize behavior of this op.
438444
static bool parse(OpAsmParser *parser, OperationState *result);
@@ -449,7 +455,7 @@ class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
449455
/// This is a refinement of the "constant" op for the case where it is
450456
/// returning a float value of FloatType.
451457
///
452-
/// %1 = "constant"(){value: 42.0} : bf16
458+
/// %1 = "std.constant"(){value: 42.0} : bf16
453459
///
454460
class ConstantFloatOp : public ConstantOp {
455461
public:
@@ -471,7 +477,7 @@ class ConstantFloatOp : public ConstantOp {
471477
/// This is a refinement of the "constant" op for the case where it is
472478
/// returning an integer value of IntegerType.
473479
///
474-
/// %1 = "constant"(){value: 42} : i32
480+
/// %1 = "std.constant"(){value: 42} : i32
475481
///
476482
class ConstantIntOp : public ConstantOp {
477483
public:
@@ -498,7 +504,7 @@ class ConstantIntOp : public ConstantOp {
498504
/// This is a refinement of the "constant" op for the case where it is
499505
/// returning an integer value of Index type.
500506
///
501-
/// %1 = "constant"(){value: 99} : () -> index
507+
/// %1 = "std.constant"(){value: 99} : () -> index
502508
///
503509
class ConstantIndexOp : public ConstantOp {
504510
public:
@@ -533,7 +539,7 @@ class DeallocOp
533539
const Value *getMemRef() const { return getOperand(); }
534540
void setMemRef(Value *value) { setOperand(value); }
535541

536-
static StringRef getOperationName() { return "dealloc"; }
542+
static StringRef getOperationName() { return "std.dealloc"; }
537543

538544
// Hooks to customize behavior of this op.
539545
static void build(Builder *builder, OperationState *result, Value *memref);
@@ -568,7 +574,7 @@ class DimOp : public Op<DimOp, OpTrait::OneOperand, OpTrait::OneResult,
568574
return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
569575
}
570576

571-
static StringRef getOperationName() { return "dim"; }
577+
static StringRef getOperationName() { return "std.dim"; }
572578

573579
// Hooks to customize behavior of this op.
574580
bool verify() const;
@@ -706,7 +712,7 @@ class DmaStartOp
706712
return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
707713
}
708714

709-
static StringRef getOperationName() { return "dma_start"; }
715+
static StringRef getOperationName() { return "std.dma_start"; }
710716
static bool parse(OpAsmParser *parser, OperationState *result);
711717
void print(OpAsmPrinter *p) const;
712718
bool verify() const;
@@ -761,7 +767,7 @@ class DmaWaitOp
761767
static void build(Builder *builder, OperationState *result, Value *tagMemRef,
762768
ArrayRef<Value *> tagIndices, Value *numElements);
763769

764-
static StringRef getOperationName() { return "dma_wait"; }
770+
static StringRef getOperationName() { return "std.dma_wait"; }
765771

766772
// Returns the Tag MemRef associated with the DMA operation being waited on.
767773
const Value *getTagMemRef() const { return getOperand(0); }
@@ -825,7 +831,7 @@ class ExtractElementOp
825831
getInstruction()->operand_end()};
826832
}
827833

828-
static StringRef getOperationName() { return "extract_element"; }
834+
static StringRef getOperationName() { return "std.extract_element"; }
829835

830836
// Hooks to customize behavior of this op.
831837
bool verify() const;
@@ -871,7 +877,7 @@ class LoadOp
871877
getInstruction()->operand_end()};
872878
}
873879

874-
static StringRef getOperationName() { return "load"; }
880+
static StringRef getOperationName() { return "std.load"; }
875881

876882
bool verify() const;
877883
static bool parse(OpAsmParser *parser, OperationState *result);
@@ -901,13 +907,15 @@ class LoadOp
901907
///
902908
class MemRefCastOp : public CastOp<MemRefCastOp> {
903909
public:
904-
static StringRef getOperationName() { return "memref_cast"; }
910+
static StringRef getOperationName() { return "std.memref_cast"; }
905911

906912
/// The result of a memref_cast is always a memref.
907913
MemRefType getType() const {
908914
return getResult()->getType().cast<MemRefType>();
909915
}
910916

917+
void print(OpAsmPrinter *p) const;
918+
911919
bool verify() const;
912920

913921
private:
@@ -920,14 +928,14 @@ class MemRefCastOp : public CastOp<MemRefCastOp> {
920928
/// The operand number and types must match the signature of the function
921929
/// that contains the operation. For example:
922930
///
923-
/// mlfunc @foo() : (i32, f8) {
931+
/// func @foo() : (i32, f8) {
924932
/// ...
925933
/// return %0, %1 : i32, f8
926934
///
927935
class ReturnOp : public Op<ReturnOp, OpTrait::VariadicOperands,
928936
OpTrait::ZeroResult, OpTrait::IsTerminator> {
929937
public:
930-
static StringRef getOperationName() { return "return"; }
938+
static StringRef getOperationName() { return "std.return"; }
931939

932940
static void build(Builder *builder, OperationState *result,
933941
ArrayRef<Value *> results = {});
@@ -956,7 +964,7 @@ class ReturnOp : public Op<ReturnOp, OpTrait::VariadicOperands,
956964
class SelectOp : public Op<SelectOp, OpTrait::NOperands<3>::Impl,
957965
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
958966
public:
959-
static StringRef getOperationName() { return "select"; }
967+
static StringRef getOperationName() { return "std.select"; }
960968
static void build(Builder *builder, OperationState *result, Value *condition,
961969
Value *trueValue, Value *falseValue);
962970
static bool parse(OpAsmParser *parser, OperationState *result);
@@ -1015,7 +1023,7 @@ class StoreOp
10151023
getInstruction()->operand_end()};
10161024
}
10171025

1018-
static StringRef getOperationName() { return "store"; }
1026+
static StringRef getOperationName() { return "std.store"; }
10191027

10201028
bool verify() const;
10211029
static bool parse(OpAsmParser *parser, OperationState *result);
@@ -1041,13 +1049,15 @@ class StoreOp
10411049
///
10421050
class TensorCastOp : public CastOp<TensorCastOp> {
10431051
public:
1044-
static StringRef getOperationName() { return "tensor_cast"; }
1052+
static StringRef getOperationName() { return "std.tensor_cast"; }
10451053

10461054
/// The result of a tensor_cast is always a tensor.
10471055
TensorType getType() const {
10481056
return getResult()->getType().cast<TensorType>();
10491057
}
10501058

1059+
void print(OpAsmPrinter *p) const;
1060+
10511061
bool verify() const;
10521062

10531063
private:

0 commit comments

Comments
 (0)