Skip to content

Commit 7889090

Browse files
authored
[mlir][math] Propagate scalability in convert-math-to-llvm (#82635)
This also generally increases the coverage of scalable vector types in the math-to-llvm tests.
1 parent: b39f566 · commit: 7889090

File tree

2 files changed: 90 additions and 9 deletions.

2 files changed: 90 additions and 9 deletions.

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,10 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
148148
return LLVM::detail::handleMultidimensionalVectors(
149149
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
150150
[&](Type llvm1DVectorTy, ValueRange operands) {
151+
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
151152
auto splatAttr = SplatElementsAttr::get(
152-
mlir::VectorType::get(
153-
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
154-
floatType),
153+
mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
154+
{numElements.isScalable()}),
155155
floatOne);
156156
auto one =
157157
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
@@ -207,10 +207,10 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
207207
return LLVM::detail::handleMultidimensionalVectors(
208208
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
209209
[&](Type llvm1DVectorTy, ValueRange operands) {
210+
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
210211
auto splatAttr = SplatElementsAttr::get(
211-
mlir::VectorType::get(
212-
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
213-
floatType),
212+
mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
213+
{numElements.isScalable()}),
214214
floatOne);
215215
auto one =
216216
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
@@ -266,10 +266,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
266266
return LLVM::detail::handleMultidimensionalVectors(
267267
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
268268
[&](Type llvm1DVectorTy, ValueRange operands) {
269+
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
269270
auto splatAttr = SplatElementsAttr::get(
270-
mlir::VectorType::get(
271-
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
272-
floatType),
271+
mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
272+
{numElements.isScalable()}),
273273
floatOne);
274274
auto one =
275275
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);

mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,18 @@ func.func @log1p_2dvector_fmf(%arg0 : vector<4x3xf32>) {
7777

7878
// -----
7979

80+
// CHECK-LABEL: func @log1p_scalable_vector(
81+
// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
82+
func.func @log1p_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
83+
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
84+
// CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[VEC]] : vector<[4]xf32>
85+
// CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) : (vector<[4]xf32>) -> vector<[4]xf32>
86+
%0 = math.log1p %arg0 : vector<[4]xf32>
87+
func.return %0 : vector<[4]xf32>
88+
}
89+
90+
// -----
91+
8092
// CHECK-LABEL: func @expm1(
8193
// CHECK-SAME: f32
8294
func.func @expm1(%arg0 : f32) {
@@ -113,6 +125,18 @@ func.func @expm1_vector(%arg0 : vector<4xf32>) {
113125

114126
// -----
115127

128+
// CHECK-LABEL: func @expm1_scalable_vector(
129+
// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
130+
func.func @expm1_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
131+
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
132+
// CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[VEC]]) : (vector<[4]xf32>) -> vector<[4]xf32>
133+
// CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<[4]xf32>
134+
%0 = math.expm1 %arg0 : vector<[4]xf32>
135+
func.return %0 : vector<[4]xf32>
136+
}
137+
138+
// -----
139+
116140
// CHECK-LABEL: func @expm1_vector_fmf(
117141
// CHECK-SAME: vector<4xf32>
118142
func.func @expm1_vector_fmf(%arg0 : vector<4xf32>) {
@@ -177,6 +201,16 @@ func.func @cttz_vec(%arg0 : vector<4xi32>) {
177201

178202
// -----
179203

204+
// CHECK-LABEL: func @cttz_scalable_vec(
205+
// CHECK-SAME: %[[VEC:.*]]: vector<[4]xi32>
206+
func.func @cttz_scalable_vec(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
207+
// CHECK: "llvm.intr.cttz"(%[[VEC]]) <{is_zero_poison = false}> : (vector<[4]xi32>) -> vector<[4]xi32>
208+
%0 = math.cttz %arg0 : vector<[4]xi32>
209+
func.return %0 : vector<[4]xi32>
210+
}
211+
212+
// -----
213+
180214
// CHECK-LABEL: func @ctpop(
181215
// CHECK-SAME: i32
182216
func.func @ctpop(%arg0 : i32) {
@@ -197,6 +231,16 @@ func.func @ctpop_vector(%arg0 : vector<3xi32>) {
197231

198232
// -----
199233

234+
// CHECK-LABEL: func @ctpop_scalable_vector(
235+
// CHECK-SAME: %[[VEC:.*]]: vector<[4]xi32>
236+
func.func @ctpop_scalable_vector(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
237+
// CHECK: llvm.intr.ctpop(%[[VEC]]) : (vector<[4]xi32>) -> vector<[4]xi32>
238+
%0 = math.ctpop %arg0 : vector<[4]xi32>
239+
func.return %0 : vector<[4]xi32>
240+
}
241+
242+
// -----
243+
200244
// CHECK-LABEL: func @rsqrt_double(
201245
// CHECK-SAME: f64
202246
func.func @rsqrt_double(%arg0 : f64) {
@@ -233,6 +277,18 @@ func.func @rsqrt_vector(%arg0 : vector<4xf32>) {
233277

234278
// -----
235279

280+
// CHECK-LABEL: func @rsqrt_scalable_vector(
281+
// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
282+
func.func @rsqrt_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32>{
283+
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
284+
// CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[VEC]]) : (vector<[4]xf32>) -> vector<[4]xf32>
285+
// CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<[4]xf32>
286+
%0 = math.rsqrt %arg0 : vector<[4]xf32>
287+
func.return %0 : vector<[4]xf32>
288+
}
289+
290+
// -----
291+
236292
// CHECK-LABEL: func @rsqrt_vector_fmf(
237293
// CHECK-SAME: vector<4xf32>
238294
func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) {
@@ -245,6 +301,18 @@ func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) {
245301

246302
// -----
247303

304+
// CHECK-LABEL: func @rsqrt_scalable_vector_fmf(
305+
// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
306+
func.func @rsqrt_scalable_vector_fmf(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
307+
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
308+
// CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[VEC]]) {fastmathFlags = #llvm.fastmath<fast>} : (vector<[4]xf32>) -> vector<[4]xf32>
309+
// CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath<fast>} : vector<[4]xf32>
310+
%0 = math.rsqrt %arg0 fastmath<fast> : vector<[4]xf32>
311+
func.return %0 : vector<[4]xf32>
312+
}
313+
314+
// -----
315+
248316
// CHECK-LABEL: func @rsqrt_multidim_vector(
249317
func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
250318
// CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
@@ -258,6 +326,19 @@ func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
258326

259327
// -----
260328

329+
// CHECK-LABEL: func @rsqrt_multidim_scalable_vector(
330+
func.func @rsqrt_multidim_scalable_vector(%arg0 : vector<4x[4]xf32>) -> vector<4x[4]xf32> {
331+
// CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<[4]xf32>>
332+
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
333+
// CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[EXTRACT]]) : (vector<[4]xf32>) -> vector<[4]xf32>
334+
// CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<[4]xf32>
335+
// CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %{{.*}}[0] : !llvm.array<4 x vector<[4]xf32>>
336+
%0 = math.rsqrt %arg0 : vector<4x[4]xf32>
337+
func.return %0 : vector<4x[4]xf32>
338+
}
339+
340+
// -----
341+
261342
// CHECK-LABEL: func @fpowi(
262343
// CHECK-SAME: f64
263344
func.func @fpowi(%arg0 : f64, %arg1 : i32) {

0 commit comments

Comments
 (0)