Skip to content

Commit d00e6d0

Browse files
authored
[mlir][sparse] refine sparse assembler strategy (#80521)
Rewrite *all* public methods, making the originals internal, private methods, and exposing wrappers under the original names. This works a bit better in practice (for example, when combined with the c-interface mechanism of torch-mlir).
1 parent 032a70e commit d00e6d0

File tree

4 files changed

+113
-48
lines changed

4 files changed

+113
-48
lines changed

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def SparseAssembler : Pass<"sparse-assembler", "ModuleOp"> {
1515
let summary = "Add [dis]assemble operations on external sparse tensors";
1616
let description = [{
1717
A pass that converts public entry methods that use sparse tensors as
18-
input parameters and/or output return values into wrapper functions
18+
input parameters and/or output return values into wrapper methods
1919
that [dis]assemble the individual tensors that constitute the actual
2020
storage used externally into MLIR sparse tensors. This pass can be used
2121
to prepare the public entry methods of a program that is compiled by the

mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -132,29 +132,29 @@ void convVals(OpBuilder &builder, Location loc, TypeRange types,
132132
namespace {
133133

134134
// A rewriting rules that converts public entry methods that use sparse tensors
135-
// as input parameters and/or output return values into wrapper functions
136-
// that [dis]assemble the individual tensors that constitute the actual
137-
// storage used externally into MLIR sparse tensors.
135+
// as input parameters and/or output return values into wrapper methods that
136+
// [dis]assemble the individual tensors that constitute the actual storage used
137+
// externally into MLIR sparse tensors before calling the original method.
138138
//
139139
// In particular, each sparse tensor input
140140
//
141141
// void foo(..., t, ...) { }
142142
//
143-
// adds the following strucuture in a wrapper
143+
// makes the original foo() internal and adds the following wrapper method
144144
//
145-
// void spiface_foo(..., t1..tn, ...) {
145+
// void foo(..., t1..tn, ...) {
146146
// t = assemble t1..tn
147-
// foo(..., t, ...)
147+
// _internal_foo(..., t, ...)
148148
// }
149149
//
150150
// and likewise, each output tensor
151151
//
152152
// ... T ... bar(...) { return ..., t, ...; }
153153
//
154-
// adds the following structure in a wrapper
154+
// makes the original bar() internal and adds the following wrapper method
155155
//
156-
// ... T1..TN ... spiface_bar(..., t1'..tn') {
157-
// ..., t, ... = bar(...)
156+
// ... T1..TN ... bar(..., t1'..tn') {
157+
// ..., t, ... = _internal_bar(...)
158158
// t1..tn = disassemble t, t1'..tn'
159159
// return ..., t1..tn, ...
160160
// }
@@ -168,9 +168,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
168168

169169
LogicalResult matchAndRewrite(func::FuncOp funcOp,
170170
PatternRewriter &rewriter) const override {
171-
// Only a rewrite an entry with the c-interface requested.
172-
if (!funcOp->getAttrOfType<UnitAttr>(
173-
LLVM::LLVMDialect::getEmitCWrapperAttrName()))
171+
// Only rewrite public entry methods.
172+
if (funcOp.isPrivate())
174173
return failure();
175174

176175
// Translate sparse tensor types to external types.
@@ -180,29 +179,29 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
180179
convTypes(funcOp.getArgumentTypes(), inputTypes);
181180
convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
182181

183-
// Only sparse inputs or outputs need a wrapper function.
182+
// Only sparse inputs or outputs need a wrapper method.
184183
if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
185184
outputTypes.size() == funcOp.getResultTypes().size())
186185
return failure();
187186

188-
// Start the new wrapper function. Together with the c-interface mangling,
189-
// a sparse external entry point eventually will have a name like:
190-
// _mlir_ciface_spiface_XXX(...)
187+
// Modify the original method into an internal, private method.
188+
auto orgName = funcOp.getName();
189+
std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
190+
funcOp.setName(wrapper);
191+
funcOp.setPrivate();
192+
193+
// Start the new public wrapper method with original name.
191194
Location loc = funcOp.getLoc();
192195
ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
193196
MLIRContext *context = modOp.getContext();
194197
OpBuilder moduleBuilder(modOp.getBodyRegion());
195-
std::string wrapper = llvm::formatv("spiface_{0}", funcOp.getName()).str();
196198
unsigned extra = inputTypes.size();
197199
inputTypes.append(extraTypes);
198200
auto func = moduleBuilder.create<func::FuncOp>(
199-
loc, wrapper, FunctionType::get(context, inputTypes, outputTypes));
201+
loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
200202
func.setPublic();
201-
func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
202-
UnitAttr::get(context));
203203

204-
// Construct new wrapper function body.
205-
auto org = SymbolRefAttr::get(context, funcOp.getName());
204+
// Construct new wrapper method body.
206205
OpBuilder::InsertionGuard insertionGuard(rewriter);
207206
Block *body = func.addEntryBlock();
208207
rewriter.setInsertionPointToStart(body);
@@ -212,7 +211,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
212211
convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
213212
ValueRange(), inputs, 0, /*isIn=*/true);
214213

215-
// Call original function.
214+
// Call original, now internal method.
215+
auto org = SymbolRefAttr::get(context, wrapper);
216216
auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
217217
inputs);
218218

@@ -222,8 +222,13 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
222222
body->getArguments(), outputs, extra, /*isIn=*/false);
223223
rewriter.create<func::ReturnOp>(loc, outputs);
224224

225-
// Strip the c-interface attribute from the original function.
226-
funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
225+
// Finally, migrate a potential c-interface property.
226+
if (funcOp->getAttrOfType<UnitAttr>(
227+
LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
228+
func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
229+
UnitAttr::get(context));
230+
funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
231+
}
227232
return success();
228233
}
229234
};

mlir/test/Dialect/SparseTensor/external.mlir

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,95 +3,100 @@
33
// -----
44

55
// CHECK-LABEL: func.func @nop(
6-
// CHECK-SAME: %[[A:.*]]: tensor<100xf32>) -> tensor<100xf32> attributes {llvm.emit_c_interface} {
6+
// CHECK-SAME: %[[A:.*]]: tensor<100xf32>) -> tensor<100xf32> {
77
// CHECK: return %[[A]] : tensor<100xf32>
88
// CHECK: }
9-
func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes { llvm.emit_c_interface } {
9+
func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
1010
return %arg0 : tensor<100xf32>
1111
}
1212

1313
// -----
1414

15-
// CHECK-LABEL: func.func @spiface_sparse_in(
15+
// CHECK-LABEL: func.func @sparse_in(
1616
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
1717
// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
18-
// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
18+
// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
1919
// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
20-
// CHECK: %[[F:.*]] = call @sparse_in(%[[I]])
20+
// CHECK: %[[F:.*]] = call @_internal_sparse_in(%[[I]])
2121
// CHECK: return %[[F]] : tensor<64x64xf32>
2222
// CHECK: }
23+
// CHECK: func.func private @_internal_sparse_in
2324
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
24-
func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } {
25+
func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
2526
%0 = sparse_tensor.convert %arg0 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
2627
return %0 : tensor<64x64xf32>
2728
}
2829

2930
// -----
3031

31-
// CHECK-LABEL: func.func @spiface_sparse_in2(
32+
// CHECK-LABEL: func.func @sparse_in2(
3233
// CHECK-SAME: %[[X:.*]]: tensor<100xf32>,
3334
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
3435
// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
35-
// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
36+
// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
3637
// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
37-
// CHECK: %[[F:.*]] = call @sparse_in2(%[[X]], %[[I]])
38+
// CHECK: %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
3839
// CHECK: return %[[F]] : tensor<64x64xf32>
3940
// CHECK: }
41+
// CHECK: func.func private @_internal_sparse_in2
4042
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
41-
func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> attributes { llvm.emit_c_interface } {
43+
func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
4244
%0 = sparse_tensor.convert %arg1 : tensor<64x64xf32, #sparse> to tensor<64x64xf32>
4345
return %0 : tensor<64x64xf32>
4446
}
4547

4648
// -----
4749

48-
// CHECK-LABEL: func.func @spiface_sparse_out(
50+
// CHECK-LABEL: func.func @sparse_out(
4951
// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
5052
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
5153
// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
52-
// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
53-
// CHECK: %[[F:.*]] = call @sparse_out(%[[X]])
54+
// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
55+
// CHECK: %[[F:.*]] = call @_internal_sparse_out(%[[X]])
5456
// CHECK: sparse_tensor.disassemble %[[F]]
5557
// CHECK: return
5658
// CHECK: }
59+
// CHECK: func.func private @_internal_sparse_out
5760
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
58-
func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } {
61+
func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
5962
%0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
6063
return %0 : tensor<64x64xf32, #sparse>
6164
}
6265

6366
// -----
6467

65-
// CHECK-LABEL: func.func @spiface_sparse_out2(
68+
// CHECK-LABEL: func.func @sparse_out2(
6669
// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
6770
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
6871
// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
69-
// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
70-
// CHECK: %[[F:.*]]:2 = call @sparse_out2(%[[X]])
72+
// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
73+
// CHECK: %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
7174
// CHECK: sparse_tensor.disassemble %[[F]]#1
7275
// CHECK: return %[[F]]#0
7376
// CHECK: }
77+
// CHECK: func.func private @_internal_sparse_out2
7478
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
75-
func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) attributes { llvm.emit_c_interface } {
79+
func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64x64xf32, #sparse>) {
7680
%0 = sparse_tensor.convert %arg0 : tensor<64x64xf32> to tensor<64x64xf32, #sparse>
7781
return %arg0, %0 : tensor<64x64xf32>, tensor<64x64xf32, #sparse>
7882
}
7983

8084
// -----
8185

82-
// CHECK-LABEL: func.func @spiface_sparse_inout(
86+
// CHECK-LABEL: func.func @sparse_inout(
8387
// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
8488
// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
8589
// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
8690
// CHECK-SAME: %[[D:.*3]]: tensor<?xf32>,
8791
// CHECK-SAME: %[[E:.*4]]: tensor<?xindex>,
88-
// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) attributes {llvm.emit_c_interface} {
92+
// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
8993
// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
90-
// CHECK: %[[F:.*]] = call @sparse_inout(%[[I]])
94+
// CHECK: %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
9195
// CHECK: sparse_tensor.disassemble %[[F]]
9296
// CHECK: return
9397
// CHECK: }
98+
// CHECK: func.func private @_internal_sparse_inout
9499
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
95-
func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> attributes { llvm.emit_c_interface } {
100+
func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
96101
return %arg0 : tensor<64x64xf32, #sparse>
97102
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// RUN: mlir-opt %s --sparse-assembler | FileCheck %s --check-prefix=CHECK-HI
2+
// RUN: mlir-opt %s --sparse-assembler \
3+
// RUN: --linalg-generalize-named-ops \
4+
// RUN: --linalg-fuse-elementwise-ops \
5+
// RUN: --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-MID
6+
// RUN: mlir-opt %s --sparse-assembler \
7+
// RUN: --sparsifier | FileCheck %s --check-prefix=CHECK-LOW
8+
9+
//
10+
// An example of a module generated by torch-mlir with a sparse tensor from
11+
// torch.sparse. The MLIR sparsifier should be able to provide the external
12+
// API through a wrapper method (spiface and ciface). Various passes should
13+
// compose without trouble.
14+
//
15+
16+
// CHECK-HI-LABEL: func.func @main
17+
// CHECK-HI: sparse_tensor.assemble
18+
// CHECK-HI: call @_internal_main
19+
// CHECK-HI: return
20+
// CHECK-HI: func.func private @_internal_main
21+
// CHECK-HI: linalg.matmul
22+
// CHECK-HI: return
23+
//
24+
// CHECK-MID-LABEL: func.func @main
25+
// CHECK-MID: memref.load
26+
// CHECK-MID: call @_internal_main
27+
// CHECK-MID: return
28+
// CHECK-MID: func.func private @_internal_main
29+
// CHECK-MID: scf.for
30+
// CHECK-MID: scf.for
31+
// CHECK-MID: return
32+
33+
// CHECK-LOW-LABEL: llvm.func @main
34+
// CHECK-LOW: llvm.call @_internal_main
35+
// CHECK-LOW: llvm.return
36+
// CHECK-LOW: llvm.func @_mlir_ciface_main
37+
// CHECK-LOW: llvm.call @main
38+
// CHECK-LOW: llvm.return
39+
// CHECK-LOW: llvm.func @_internal_main
40+
// CHECK-SAME: {sym_visibility = "private"}
41+
// CHECK-LOW: llvm.return
42+
43+
#csc = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
44+
module {
45+
func.func @main(%arg0: tensor<64x64xf32, #csc>,
46+
%arg1: tensor<64x64xf32>) -> tensor<64x64xf32> attributes {llvm.emit_c_interface} {
47+
%cst = arith.constant 0.000000e+00 : f32
48+
%0 = tensor.empty() : tensor<64x64xf32>
49+
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64x64xf32>) -> tensor<64x64xf32>
50+
%2 = linalg.matmul
51+
ins(%arg0, %arg1 : tensor<64x64xf32, #csc>, tensor<64x64xf32>)
52+
outs(%1 : tensor<64x64xf32>) -> tensor<64x64xf32>
53+
return %2 : tensor<64x64xf32>
54+
}
55+
}

0 commit comments

Comments
 (0)