Skip to content

Commit 78c4974

Browse files
authored
[MLIR][Vector] Allow non-default memory spaces in gather/scatter lowerings (#67500)
GPU targets can gather on non-default address spaces (e.g. global), so this removes the check for the default memory space.
1 parent 7ddf7d8 commit 78c4974

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,12 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
8787
return success();
8888
}
8989

90-
// Check if the last stride is non-unit or the memory space is not zero.
90+
// Check that the last stride is unit and the memref has a valid memory space.
9191
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
9292
const LLVMTypeConverter &converter) {
9393
if (!isLastMemrefDimUnitStride(memRefType))
9494
return failure();
95-
FailureOr<unsigned> addressSpace =
96-
converter.getMemRefAddressSpace(memRefType);
97-
if (failed(addressSpace) || *addressSpace != 0)
95+
if (failed(converter.getMemRefAddressSpace(memRefType)))
9896
return failure();
9997
return success();
10098
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2108,6 +2108,20 @@ func.func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3
21082108

21092109
// -----
21102110

2111+
func.func @gather_op_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
2112+
%0 = arith.constant 0: index
2113+
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32, 1>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
2114+
return %1 : vector<3xf32>
2115+
}
2116+
2117+
// CHECK-LABEL: func @gather_op
2118+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> !llvm.vec<3 x ptr<1>>, f32
2119+
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
2120+
// CHECK: return %[[G]] : vector<3xf32>
2121+
2122+
// -----
2123+
2124+
21112125
func.func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) -> vector<3xindex> {
21122126
%0 = arith.constant 0: index
21132127
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xindex>, vector<3xindex>, vector<3xi1>, vector<3xindex> into vector<3xindex>

0 commit comments

Comments
 (0)