
Commit ad96d6e

Ops with CallOpInterface must have two new optional attrs, arg_attrs and res_attrs
llvm/llvm-project#123176
1 parent 26ffeb0

5 files changed: +27 −10 lines
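For context (not part of the commit), the new attributes are ordinary ODS optional attributes, so each op gains the usual generated accessors. Below is a minimal hedged C++ sketch of reading them on a CallbackCallOp, assuming the standard ODS naming for $arg_attrs/$res_attrs (getArgAttrsAttr/getResAttrsAttr returning a possibly-null ArrayAttr of DictionaryAttr), an assumed Catalyst/IR/CatalystOps.h header path, and the op living in the catalyst namespace:

#include "Catalyst/IR/CatalystOps.h"   // assumed header for CallbackCallOp
#include "llvm/Support/raw_ostream.h"

// Print any per-argument / per-result dictionaries attached to a call.
// Accessor names are assumptions derived from the $arg_attrs/$res_attrs
// declarations below, not something this commit adds explicitly.
static void dumpCallAttrs(catalyst::CallbackCallOp call)
{
    if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr()) {
        // Conventionally one DictionaryAttr per call operand.
        for (mlir::DictionaryAttr dict : argAttrs.getAsRange<mlir::DictionaryAttr>())
            llvm::errs() << "arg attrs: " << dict << "\n";
    }
    if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
        llvm::errs() << "res attrs: " << resAttrs << "\n";
}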

mlir/include/Catalyst/IR/CatalystOps.td

Lines changed: 6 additions & 2 deletions

@@ -170,7 +170,9 @@ def CallbackCallOp : Catalyst_Op<"callback_call",

     let arguments = (ins
         FlatSymbolRefAttr:$callee,
-        Variadic<AnyTypeOf<[AnyRankedTensor, MemRefOf<[AnyType]>]>>:$inputs
+        Variadic<AnyTypeOf<[AnyRankedTensor, MemRefOf<[AnyType]>]>>:$inputs,
+        OptionalAttr<DictArrayAttr>:$arg_attrs,
+        OptionalAttr<DictArrayAttr>:$res_attrs
     );

     let results = (outs Variadic<AnyType>);

@@ -188,7 +190,9 @@ def LaunchKernelOp : Catalyst_Op<"launch_kernel",

     let arguments = (ins
         SymbolRefAttr:$callee,
-        Variadic<AnyTypeOf<[AnyRankedTensor, MemRefOf<[AnyType]>]>>:$inputs
+        Variadic<AnyTypeOf<[AnyRankedTensor, MemRefOf<[AnyType]>]>>:$inputs,
+        OptionalAttr<DictArrayAttr>:$arg_attrs,
+        OptionalAttr<DictArrayAttr>:$res_attrs
     );

     let results = (outs Variadic<AnyType>);

mlir/include/Gradient/IR/GradientOps.td

Lines changed: 12 additions & 4 deletions

@@ -58,7 +58,9 @@ def GradOp : Gradient_Op<"grad", [
         SymbolRefAttr:$callee,
         Variadic<AnyType>:$operands,
         OptionalAttr<AnyIntElementsAttr>:$diffArgIndices,
-        OptionalAttr<Builtin_FloatAttr>:$finiteDiffParam
+        OptionalAttr<Builtin_FloatAttr>:$finiteDiffParam,
+        OptionalAttr<DictArrayAttr>:$arg_attrs,
+        OptionalAttr<DictArrayAttr>:$res_attrs
     );
     let results = (outs Variadic<AnyTypeOf<[AnyFloat, RankedTensorOf<[AnyFloat]>]>>);

@@ -82,7 +84,9 @@ def ValueAndGradOp : Gradient_Op<"value_and_grad", [
         SymbolRefAttr:$callee,
         Variadic<AnyType>:$operands,
         OptionalAttr<AnyIntElementsAttr>:$diffArgIndices,
-        OptionalAttr<Builtin_FloatAttr>:$finiteDiffParam
+        OptionalAttr<Builtin_FloatAttr>:$finiteDiffParam,
+        OptionalAttr<DictArrayAttr>:$arg_attrs,
+        OptionalAttr<DictArrayAttr>:$res_attrs
     );

     let results = (outs

@@ -184,7 +188,9 @@ def JVPOp : Gradient_Op<"jvp", [
         Variadic<AnyType>:$params,
         Variadic<AnyType>:$tangents,
         OptionalAttr<AnyIntElementsAttr>:$diffArgIndices,
-        OptionalAttr<Builtin_FloatAttr>:$finiteDiffParam
+        OptionalAttr<Builtin_FloatAttr>:$finiteDiffParam,
+        OptionalAttr<DictArrayAttr>:$arg_attrs,
+        OptionalAttr<DictArrayAttr>:$res_attrs
     );

     let results = (outs

@@ -217,7 +223,9 @@ def VJPOp : Gradient_Op<"vjp", [
         Variadic<AnyType>:$params,
         Variadic<AnyType>:$cotangents,
         OptionalAttr<AnyIntElementsAttr>:$diffArgIndices,
-        OptionalAttr<Builtin_FloatAttr>:$finiteDiffParam
+        OptionalAttr<Builtin_FloatAttr>:$finiteDiffParam,
+        OptionalAttr<DictArrayAttr>:$arg_attrs,
+        OptionalAttr<DictArrayAttr>:$res_attrs
     );

     let results = (outs

mlir/include/Mitigation/IR/MitigationOps.td

Lines changed: 3 additions & 1 deletion

@@ -48,7 +48,9 @@ def ZneOp : Mitigation_Op<"zne", [DeclareOpInterfaceMethods<CallOpInterface>,
         SymbolRefAttr:$callee,
         Variadic<AnyType>:$args,
         FoldingAttr:$folding,
-        RankedTensorOf<[AnySignlessIntegerOrIndex]>:$numFolds
+        RankedTensorOf<[AnySignlessIntegerOrIndex]>:$numFolds,
+        OptionalAttr<DictArrayAttr>:$arg_attrs,
+        OptionalAttr<DictArrayAttr>:$res_attrs
     );
     let results = (outs Variadic<AnyTypeOf<[AnyFloat, RankedTensorOf<[AnyFloat]>]>>);
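DictArrayAttr, used for both new attributes in all of the definitions above, is an ArrayAttr whose elements are DictionaryAttr, conventionally one dictionary per callee argument (arg_attrs) or per result (res_attrs). A hedged sketch of constructing such a value with the standard Builder helpers, given some mlir::OpBuilder builder in scope (the "example.tag" entry is purely illustrative):

// Hedged sketch: building a DictArrayAttr-compatible value.
// One DictionaryAttr per argument; "example.tag" is a made-up attribute name.
mlir::ArrayAttr argAttrs = builder.getArrayAttr({
    builder.getDictionaryAttr({}),                                   // argument 0: no attributes
    builder.getDictionaryAttr(
        builder.getNamedAttr("example.tag", builder.getUnitAttr())), // argument 1
});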

mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 2 additions & 1 deletion

@@ -321,7 +321,8 @@ struct CallbackCallOpInterface
        }

        SmallVector<Type> emptyRets;
-       rewriter.create<CallbackCallOp>(loc, emptyRets, callOp.getCallee(), newInputs);
+       rewriter.create<CallbackCallOp>(loc, emptyRets, callOp.getCallee(), newInputs,
+                                       /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
        bufferization::replaceOpWithBufferizedValues(rewriter, op, outmemrefs);
        return success();
    }
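Since both attributes are optional, call sites with nothing to attach simply pass nullptr, as in the hunk above. A pass that does want to attach per-argument dictionaries could set them after creating the op; a hedged sketch (not in this commit) reusing the variables from the hunk above, with the ODS-style setter name for $arg_attrs assumed and a made-up "example.tag" entry:

// Hedged sketch: attach per-argument dictionaries after creating the call.
auto call = rewriter.create<CallbackCallOp>(loc, emptyRets, callOp.getCallee(), newInputs,
                                            /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
llvm::SmallVector<mlir::Attribute> argDicts(newInputs.size(), rewriter.getDictionaryAttr({}));
if (!argDicts.empty())
    argDicts.front() = rewriter.getDictionaryAttr(
        rewriter.getNamedAttr("example.tag", rewriter.getUnitAttr()));
call.setArgAttrsAttr(rewriter.getArrayAttr(argDicts));  // setter name assumed from ODS conventions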

mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp

Lines changed: 4 additions & 2 deletions

@@ -97,7 +97,8 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew

    auto gradOp = rewriter.create<GradOp>(loc, grad_result_types, op.getMethod(), op.getCallee(),
                                          calleeOperands, op.getDiffArgIndicesAttr(),
-                                         op.getFiniteDiffParamAttr());
+                                         op.getFiniteDiffParamAttr(), /*arg_attrs=*/nullptr,
+                                         /*res_attrs=*/nullptr);

    std::vector<Value> einsumResults;
    for (size_t nout = 0; nout < funcResultTypes.size(); nout++) {

@@ -219,7 +220,8 @@ LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rew

    auto gradOp = rewriter.create<GradOp>(loc, grad_result_types, op.getMethod(), op.getCallee(),
                                          calleeOperands, op.getDiffArgIndicesAttr(),
-                                         op.getFiniteDiffParamAttr());
+                                         op.getFiniteDiffParamAttr(), /*arg_attrs=*/nullptr,
+                                         /*res_attrs=*/nullptr);

    std::vector<Value> einsumResults;
    for (size_t nparam = 0; nparam < func_diff_operand_indices.size(); nparam++) {
