Skip to content

Commit 9eb3039

Browse files
[SYCL-MLIR] Add memref of sub_group and minimum to list of SYCL memrefs (#8079)
This PR makes the following tests pass: ``` SYCL :: SubGroup/broadcast_fp64.cpp SYCL :: SubGroup/reduce_spirv13_fp16.cpp SYCL :: SubGroup/reduce_spirv13_fp64.cpp SYCL :: SubGroup/scan_spirv13_fp16.cpp ``` Signed-off-by: Tsang, Whitney <[email protected]>
1 parent d356268 commit 9eb3039

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,12 +367,14 @@ def ItemMemRef : MemRefOf<[SYCL_ItemType]>;
367367
def LocalAccessorBaseDeviceMemRef : MemRefOf<[SYCL_LocalAccessorBaseDeviceType]>;
368368
def LocalAccessorBaseMemRef : MemRefOf<[SYCL_LocalAccessorBaseType]>;
369369
def LocalAccessorMemRef : MemRefOf<[SYCL_LocalAccessorType]>;
370+
def MinimumMemRef : MemRefOf<[SYCL_MinimumType]>;
370371
def MultiPtrMemRef : MemRefOf<[SYCL_MultiPtrType]>;
371372
def NDItemMemRef : MemRefOf<[SYCL_NdItemType]>;
372373
def NDRangeMemRef : MemRefOf<[SYCL_NdRangeType]>;
373374
def OwnerLessBaseMemRef : MemRefOf<[SYCL_OwnerLessBaseType]>;
374375
def RangeMemRef : MemRefOf<[SYCL_RangeType]>;
375376
def StreamMemRef : MemRefOf<[SYCL_StreamType]>;
377+
def SubGroupMemRef : MemRefOf<[SYCL_SubGroupType]>;
376378
def SwizzledVecMemRef : MemRefOf<[SYCL_SwizzledVecType]>;
377379
def VecMemRef : MemRefOf<[SYCL_VecType]>;
378380

@@ -393,13 +395,15 @@ def SYCLMemref : AnyTypeOf<[
393395
LocalAccessorBaseDeviceMemRef,
394396
LocalAccessorBaseMemRef,
395397
LocalAccessorMemRef,
398+
MinimumMemRef,
396399
MultiPtrMemRef,
397400
NDItemMemRef,
398401
NDRangeMemRef,
399402
OwnerLessBaseMemRef,
400403
RangeMemRef,
401-
SwizzledVecMemRef,
402404
StreamMemRef,
405+
SubGroupMemRef,
406+
SwizzledVecMemRef,
403407
VecMemRef
404408
]>;
405409

mlir-sycl/test/Dialect/IR/SYCL/constructor.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,15 @@ func.func @TestConstructorII32Ptr(%arg0: memref<?x!sycl_id_1_, 4>, %arg1: memref
1515
sycl.constructor(%arg0, %arg1) {MangledFunctionName = @_ZN4sycl3_V19multi_ptrIjLNS0_6access13address_spaceE1ELNS2_9decoratedE1EEC1EPU3AS1j, TypeName = @multi_ptr} : (memref<?x!sycl_id_1_, 4>, memref<?xi32, 1>) -> ()
1616
return
1717
}
18+
19+
// CHECK-LABEL: func.func @SubGroupConstructor
20+
func.func @SubGroupConstructor(%arg0: memref<?x!sycl.sub_group, 4>, %arg1: memref<?x!sycl.sub_group, 4>) {
21+
sycl.constructor(%arg0, %arg1) {MangledFunctionName = @_ZN4sycl3_V13ext6oneapi9sub_groupC1ERKS3_, TypeName = @sub_group} : (memref<?x!sycl.sub_group, 4>, memref<?x!sycl.sub_group, 4>) -> ()
22+
return
23+
}
24+
25+
// CHECK-LABEL: func.func @MinimumConstructor
26+
func.func @MinimumConstructor(%arg0: memref<?x!sycl.minimum<i32>, 4>, %arg1: memref<?x!sycl.minimum<i32>, 4>) {
27+
sycl.constructor(%arg0, %arg1) {MangledFunctionName = @_ZN4sycl3_V17minimumIiEC1ERKS2_, TypeName = @minimum} : (memref<?x!sycl.minimum<i32>, 4>, memref<?x!sycl.minimum<i32>, 4>) -> ()
28+
return
29+
}

0 commit comments

Comments
 (0)