Skip to content

Commit 4e032a0

Browse files
[SYCL-MLIR] Mimic codegen argument passing in operations definitions (#8333)
Do enforce SSA usage when defining SYCL dialect operations, but pass by pointer when appropriate. Signed-off-by: Victor Perez <[email protected]> Co-authored-by: Whitney Tsang <[email protected]>
1 parent 0391711 commit 4e032a0

File tree

15 files changed

+253
-473
lines changed

15 files changed

+253
-473
lines changed

mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOps.td

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ def SYCLCallOp : SYCL_Op<"call", [CallOpInterface]> {
687687
// accessor.subscript OPERATION
688688
////////////////////////////////////////////////////////////////////////////////
689689

690-
def SYCLAccessorSubscriptIndex : AnyTypeOf<[IndexType, SYCL_IDType]>;
690+
def SYCLAccessorSubscriptIndex : AnyTypeOf<[IndexType, IDMemRef]>;
691691

692692
def SYCLAccessorSubscriptOp
693693
: SYCLMethodOpInterfaceImpl<"accessor.subscript", "AccessorType",
@@ -697,8 +697,10 @@ def SYCLAccessorSubscriptOp
697697
This operation represents a call to the accessor::operator[] function.
698698
}];
699699

700-
let arguments = (ins SYCL_AccessorType:$Acc,
701-
SYCLAccessorSubscriptIndex:$Index,
700+
let arguments = (ins Arg<AccessorMemRef, "The offsetted accessor",
701+
[MemRead]>:$Acc,
702+
Arg<SYCLAccessorSubscriptIndex, "The offset",
703+
[MemRead]>:$Index,
702704
TypeArrayAttr:$ArgumentTypes,
703705
FlatSymbolRefAttr:$FunctionName,
704706
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -726,7 +728,7 @@ def SYCLRangeGetOp
726728
This operation represents a call to the range::get/operator[] functions.
727729
}];
728730

729-
let arguments = (ins SYCL_RangeType:$Range,
731+
let arguments = (ins Arg<RangeMemRef, "The input range", [MemRead]>:$Range,
730732
I32:$Index,
731733
TypeArrayAttr:$ArgumentTypes,
732734
FlatSymbolRefAttr:$FunctionName,
@@ -752,7 +754,7 @@ def SYCLRangeSizeOp
752754
This operation represents a call to the range::size[] function.
753755
}];
754756

755-
let arguments = (ins SYCL_RangeType:$Range,
757+
let arguments = (ins Arg<RangeMemRef, "The input range", [MemRead]>:$Range,
756758
TypeArrayAttr:$ArgumentTypes,
757759
FlatSymbolRefAttr:$FunctionName,
758760
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -773,7 +775,7 @@ def SYCLNdRangeGetGlobalRange
773775
This operation represents a call to the nd_range::get_global_range function.
774776
}];
775777

776-
let arguments = (ins SYCL_NdRangeType:$ND,
778+
let arguments = (ins Arg<NDRangeMemRef, "The input ND-range", [MemRead]>:$ND,
777779
TypeArrayAttr:$ArgumentTypes,
778780
FlatSymbolRefAttr:$FunctionName,
779781
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -794,7 +796,7 @@ def SYCLNdRangeGetLocalRange
794796
This operation represents a call to the nd_range::get_local_range function.
795797
}];
796798

797-
let arguments = (ins SYCL_NdRangeType:$ND,
799+
let arguments = (ins Arg<NDRangeMemRef, "The input ND-range", [MemRead]>:$ND,
798800
TypeArrayAttr:$ArgumentTypes,
799801
FlatSymbolRefAttr:$FunctionName,
800802
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -815,7 +817,7 @@ def SYCLNdRangeGetGroupRange
815817
This operation represents a call to the nd_range::get_group_range function.
816818
}];
817819

818-
let arguments = (ins SYCL_NdRangeType:$ND,
820+
let arguments = (ins Arg<NDRangeMemRef, "The input ND-range", [MemRead]>:$ND,
819821
TypeArrayAttr:$ArgumentTypes,
820822
FlatSymbolRefAttr:$FunctionName,
821823
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -837,7 +839,7 @@ def SYCLIDGetOp
837839
This operation represents a call to the id::get/operator[]/operator size_t functions.
838840
}];
839841

840-
let arguments = (ins SYCL_IDType:$ID,
842+
let arguments = (ins Arg<IDMemRef, "The input ID", [MemRead]>:$ID,
841843
Optional<I32>:$Index,
842844
TypeArrayAttr:$ArgumentTypes,
843845
FlatSymbolRefAttr:$FunctionName,
@@ -864,7 +866,7 @@ def SYCLItemGetIDOp
864866
This operation represents a call to the item::get_id/operator[]/operator size_t functions.
865867
}];
866868

867-
let arguments = (ins SYCL_ItemType:$Item,
869+
let arguments = (ins Arg<ItemMemRef, "The input item", [MemRead]>:$Item,
868870
Optional<I32>:$Index,
869871
TypeArrayAttr:$ArgumentTypes,
870872
FlatSymbolRefAttr:$FunctionName,
@@ -886,7 +888,7 @@ def SYCLItemGetRangeOp
886888
This operation represents a call to the item::get_range function.
887889
}];
888890

889-
let arguments = (ins SYCL_ItemType:$Item,
891+
let arguments = (ins Arg<ItemMemRef, "The input item", [MemRead]>:$Item,
890892
Optional<I32>:$Index,
891893
TypeArrayAttr:$ArgumentTypes,
892894
FlatSymbolRefAttr:$FunctionName,
@@ -908,7 +910,7 @@ def SYCLItemGetLinearIDOp
908910
This operation represents a call to the item::get_linear_id function.
909911
}];
910912

911-
let arguments = (ins SYCL_ItemType:$Item,
913+
let arguments = (ins Arg<ItemMemRef, "The input item", [MemRead]>:$Item,
912914
TypeArrayAttr:$ArgumentTypes,
913915
FlatSymbolRefAttr:$FunctionName,
914916
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -929,7 +931,7 @@ def SYCLNDItemGetGlobalIDOp
929931
This operation represents a call to the nd_item::get_global_id function.
930932
}];
931933

932-
let arguments = (ins SYCL_NdItemType:$NDItem,
934+
let arguments = (ins Arg<NDItemMemRef, "The input ND-item", [MemRead]>:$NDItem,
933935
Optional<I32>:$Index,
934936
TypeArrayAttr:$ArgumentTypes,
935937
FlatSymbolRefAttr:$FunctionName,
@@ -951,7 +953,7 @@ def SYCLNDItemGetGlobalLinearIDOp
951953
This operation represents a call to the nd_item::get_global_linear_id function.
952954
}];
953955

954-
let arguments = (ins SYCL_NdItemType:$NDItem,
956+
let arguments = (ins Arg<NDItemMemRef, "The input ND-item", [MemRead]>:$NDItem,
955957
TypeArrayAttr:$ArgumentTypes,
956958
FlatSymbolRefAttr:$FunctionName,
957959
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -972,7 +974,7 @@ def SYCLNDItemGetLocalIDOp
972974
This operation represents a call to the nd_item::get_local_id function.
973975
}];
974976

975-
let arguments = (ins SYCL_NdItemType:$NDItem,
977+
let arguments = (ins Arg<NDItemMemRef, "The input ND-item", [MemRead]>:$NDItem,
976978
Optional<I32>:$Index,
977979
TypeArrayAttr:$ArgumentTypes,
978980
FlatSymbolRefAttr:$FunctionName,
@@ -994,7 +996,7 @@ def SYCLNDItemGetLocalLinearIDOp
994996
This operation represents a call to the nd_item::get_local_linear_id function.
995997
}];
996998

997-
let arguments = (ins SYCL_NdItemType:$NDItem,
999+
let arguments = (ins Arg<NDItemMemRef, "The input ND-item", [MemRead]>:$NDItem,
9981000
TypeArrayAttr:$ArgumentTypes,
9991001
FlatSymbolRefAttr:$FunctionName,
10001002
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -1015,7 +1017,7 @@ def SYCLNDItemGetGroupOp
10151017
This operation represents a call to the nd_item::get_group function.
10161018
}];
10171019

1018-
let arguments = (ins SYCL_NdItemType:$NDItem,
1020+
let arguments = (ins Arg<NDItemMemRef, "The input ND-item", [MemRead]>:$NDItem,
10191021
Optional<I32>:$Index,
10201022
TypeArrayAttr:$ArgumentTypes,
10211023
FlatSymbolRefAttr:$FunctionName,
@@ -1037,7 +1039,7 @@ def SYCLNDItemGetGroupLinearIDOp
10371039
This operation represents a call to the nd_item::get_group_linear_id function.
10381040
}];
10391041

1040-
let arguments = (ins SYCL_NdItemType:$NDItem,
1042+
let arguments = (ins Arg<NDItemMemRef, "The input ND-item", [MemRead]>:$NDItem,
10411043
TypeArrayAttr:$ArgumentTypes,
10421044
FlatSymbolRefAttr:$FunctionName,
10431045
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -1059,7 +1061,7 @@ def SYCLNDItemGetGroupRangeOp
10591061
This operation represents a call to the nd_item::get_group_range function.
10601062
}];
10611063

1062-
let arguments = (ins SYCL_NdItemType:$NDItem,
1064+
let arguments = (ins Arg<NDItemMemRef, "The input ND-item", [MemRead]>:$NDItem,
10631065
Optional<I32>:$Index,
10641066
TypeArrayAttr:$ArgumentTypes,
10651067
FlatSymbolRefAttr:$FunctionName,
@@ -1082,7 +1084,7 @@ def SYCLNDItemGetGlobalRangeOp
10821084
This operation represents a call to the nd_item::get_global_range function.
10831085
}];
10841086

1085-
let arguments = (ins SYCL_NdItemType:$NDItem,
1087+
let arguments = (ins Arg<NDItemMemRef, "The input ND-item", [MemRead]>:$NDItem,
10861088
Optional<I32>:$Index,
10871089
TypeArrayAttr:$ArgumentTypes,
10881090
FlatSymbolRefAttr:$FunctionName,
@@ -1105,7 +1107,7 @@ def SYCLNDItemGetLocalRangeOp
11051107
This operation represents a call to the nd_item::get_local_range function.
11061108
}];
11071109

1108-
let arguments = (ins SYCL_NdItemType:$NDItem,
1110+
let arguments = (ins Arg<NDItemMemRef, "The input ND-item", [MemRead]>:$NDItem,
11091111
Optional<I32>:$Index,
11101112
TypeArrayAttr:$ArgumentTypes,
11111113
FlatSymbolRefAttr:$FunctionName,
@@ -1127,7 +1129,7 @@ def SYCLNDItemGetNdRangeOp
11271129
This operation represents a call to the nd_item::get_nd_range function.
11281130
}];
11291131

1130-
let arguments = (ins SYCL_NdItemType:$NDItem,
1132+
let arguments = (ins Arg<NDItemMemRef, "The input ND-item", [MemRead]>:$NDItem,
11311133
TypeArrayAttr:$ArgumentTypes,
11321134
FlatSymbolRefAttr:$FunctionName,
11331135
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -1149,7 +1151,7 @@ def SYCLGroupGetGroupIDOp
11491151
This operation represents a call to the group::get_group_id/operator[] functions.
11501152
}];
11511153

1152-
let arguments = (ins SYCL_GroupType:$Group,
1154+
let arguments = (ins Arg<GroupMemRef, "The input group", [MemRead]>:$Group,
11531155
Optional<I32>:$Index,
11541156
TypeArrayAttr:$ArgumentTypes,
11551157
FlatSymbolRefAttr:$FunctionName,
@@ -1172,7 +1174,7 @@ def SYCLGroupGetLocalIDOp
11721174
This operation represents a call to the group::get_local_id function.
11731175
}];
11741176

1175-
let arguments = (ins SYCL_GroupType:$Group,
1177+
let arguments = (ins Arg<GroupMemRef, "The input group", [MemRead]>:$Group,
11761178
Optional<I32>:$Index,
11771179
TypeArrayAttr:$ArgumentTypes,
11781180
FlatSymbolRefAttr:$FunctionName,
@@ -1195,7 +1197,7 @@ def SYCLGroupGetLocalRangeOp
11951197
This operation represents a call to the group::get_local_range function.
11961198
}];
11971199

1198-
let arguments = (ins SYCL_GroupType:$Group,
1200+
let arguments = (ins Arg<GroupMemRef, "The input group", [MemRead]>:$Group,
11991201
Optional<I32>:$Index,
12001202
TypeArrayAttr:$ArgumentTypes,
12011203
FlatSymbolRefAttr:$FunctionName,
@@ -1218,7 +1220,7 @@ def SYCLGroupGetGroupRangeOp
12181220
This operation represents a call to the group::get_group_range function.
12191221
}];
12201222

1221-
let arguments = (ins SYCL_GroupType:$Group,
1223+
let arguments = (ins Arg<GroupMemRef, "The input group", [MemRead]>:$Group,
12221224
Optional<I32>:$Index,
12231225
TypeArrayAttr:$ArgumentTypes,
12241226
FlatSymbolRefAttr:$FunctionName,
@@ -1240,7 +1242,7 @@ def SYCLGroupGetMaxLocalRangeOp
12401242
This operation represents a call to the group::get_max_local_range function.
12411243
}];
12421244

1243-
let arguments = (ins SYCL_GroupType:$Group,
1245+
let arguments = (ins Arg<GroupMemRef, "The input group", [MemRead]>:$Group,
12441246
TypeArrayAttr:$ArgumentTypes,
12451247
FlatSymbolRefAttr:$FunctionName,
12461248
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -1261,7 +1263,7 @@ def SYCLGroupGetGroupLinearIDOp
12611263
This operation represents a call to the group::get_group_linear_id function.
12621264
}];
12631265

1264-
let arguments = (ins SYCL_GroupType:$Group,
1266+
let arguments = (ins Arg<GroupMemRef, "The input group", [MemRead]>:$Group,
12651267
TypeArrayAttr:$ArgumentTypes,
12661268
FlatSymbolRefAttr:$FunctionName,
12671269
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -1282,7 +1284,7 @@ def SYCLGroupGetLocalLinearIDOp
12821284
This operation represents a call to the group::get_local_linear_id function.
12831285
}];
12841286

1285-
let arguments = (ins SYCL_GroupType:$Group,
1287+
let arguments = (ins Arg<GroupMemRef, "The input group", [MemRead]>:$Group,
12861288
TypeArrayAttr:$ArgumentTypes,
12871289
FlatSymbolRefAttr:$FunctionName,
12881290
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -1303,7 +1305,7 @@ def SYCLGroupGetLocalLinearIDOp
13031305
This operation represents a call to the group::get_group_linear_range function.
13041306
}];
13051307

1306-
let arguments = (ins SYCL_GroupType:$Group,
1308+
let arguments = (ins Arg<GroupMemRef, "The input group", [MemRead]>:$Group,
13071309
TypeArrayAttr:$ArgumentTypes,
13081310
FlatSymbolRefAttr:$FunctionName,
13091311
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
@@ -1324,7 +1326,7 @@ def SYCLGroupGetLocalLinearRangeOp
13241326
This operation represents a call to the group::get_local_linear_range function.
13251327
}];
13261328

1327-
let arguments = (ins SYCL_GroupType:$Group,
1329+
let arguments = (ins Arg<GroupMemRef, "The input group", [MemRead]>:$Group,
13281330
TypeArrayAttr:$ArgumentTypes,
13291331
FlatSymbolRefAttr:$FunctionName,
13301332
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,

mlir-sycl/include/mlir/Dialect/SYCL/MethodUtils.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,6 @@ class OpBuilder;
2323
class ValueRange;
2424
namespace sycl {
2525
class SYCLMethodOpInterface;
26-
/// Generates a list of values to be used as arguments to a
27-
/// SYCLMethodOpInterface instance from \p Original.
28-
SmallVector<Value> adaptSYCLMethodOpArguments(OpBuilder &Builder, Location Loc,
29-
ValueRange Original);
3026
/// Abstracts different cast operations from which \p Original may have
3127
/// originated.
3228
Value abstractCasts(Value Original);

mlir-sycl/lib/Conversion/SYCLToGPU/SYCLToGPU.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ using gpu_counterpart_operation_t =
6666
Value createGetOp(OpBuilder &builder, Location loc, Type underlyingArrTy,
6767
Value res, Value index, ArrayAttr argumentTypes,
6868
FlatSymbolRefAttr functionName) {
69-
return TypeSwitch<Type, Value>(res.getType())
69+
return TypeSwitch<Type, Value>(
70+
res.getType().cast<MemRefType>().getElementType())
7071
.Case<IDType, RangeType>([&](auto arg) {
7172
// `this` type
7273
using ArgTy = decltype(arg);
@@ -143,13 +144,11 @@ void convertToFullObject(ConversionPatternRewriter &rewriter, StringRef opName,
143144
MemRefType::get(dimensions, targetIndexTy, {}, genericAddressSpace);
144145
// Allocate
145146
const auto resTy = op->getResultTypes()[0];
146-
const auto alloca = static_cast<Value>(
147-
rewriter.create<memref::AllocaOp>(loc, MemRefType::get(1, resTy)));
147+
const Value res =
148+
rewriter.create<memref::AllocaOp>(loc, MemRefType::get(1, resTy));
148149
// Load
149150
const auto zero =
150151
static_cast<Value>(rewriter.create<arith::ConstantIndexOp>(loc, 0));
151-
const auto res = static_cast<Value>(
152-
rewriter.replaceOpWithNewOp<AffineLoadOp>(op, alloca, zero));
153152
const auto argumentTypes = rewriter.getTypeArrayAttr(
154153
{MemRefType::get(1, resTy, {}, genericAddressSpace), getIndexTy});
155154
const auto functionName = rewriter.getAttr<FlatSymbolRefAttr>("operator[]");
@@ -164,6 +163,7 @@ void convertToFullObject(ConversionPatternRewriter &rewriter, StringRef opName,
164163
argumentTypes, functionName);
165164
rewriter.create<AffineStoreOp>(loc, val, ptr, zero);
166165
}
166+
rewriter.replaceOpWithNewOp<AffineLoadOp>(op, res, zero);
167167
}
168168

169169
template <typename OpTy, typename GPUOpTy = gpu_counterpart_operation_t<OpTy>>

0 commit comments

Comments
 (0)