Skip to content

Commit 14858cf

Browse files
[mlir][Conversion/GPUCommon] Fix bug in conversion of math ops
The common GPU operation transformation that lowers `math` operations to function calls in the `gpu-to-nvvm` and `gpu-to-rocdl` passes handles `vector` types by applying the function to each scalar and returning a new vector. However, there was a typo that results in incorrectly accumulating the result vector, and the rewrite returns an `llvm.mlir.undef` result instead of the correct vector. A patch is added and tests are strengthened. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D154269
1 parent 6a66673 commit 14858cf

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,8 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
485485
auto scalarOperands = llvm::map_to_vector(operands, extractElement);
486486
Operation *scalarOp =
487487
rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
488-
rewriter.create<LLVM::InsertElementOp>(loc, result, scalarOp->getResult(0),
489-
index);
488+
result = rewriter.create<LLVM::InsertElementOp>(
489+
loc, result, scalarOp->getResult(0), index);
490490
}
491491

492492
rewriter.replaceOp(op, result);

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -516,10 +516,16 @@ gpu.module @test_module {
516516
// CHECK-LABEL: func @gpu_unroll
517517
func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
518518
%result = math.exp %arg0 : vector<4xf32>
519-
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
520-
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
521-
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
522-
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
519+
// CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<4xf32>
520+
// CHECK: %[[CL:.+]] = llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
521+
// CHECK: %[[V1:.+]] = llvm.insertelement %[[CL]], %[[V0]]
522+
// CHECK: %[[CL:.+]] = llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
523+
// CHECK: %[[V2:.+]] = llvm.insertelement %[[CL]], %[[V1]]
524+
// CHECK: %[[CL:.+]] = llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
525+
// CHECK: %[[V3:.+]] = llvm.insertelement %[[CL]], %[[V2]]
526+
// CHECK: %[[CL:.+]] = llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
527+
// CHECK: %[[V4:.+]] = llvm.insertelement %[[CL]], %[[V3]]
528+
// CHECK: return %[[V4]]
523529
func.return %result : vector<4xf32>
524530
}
525531
}

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -456,10 +456,16 @@ gpu.module @test_module {
456456
// CHECK-LABEL: func @gpu_unroll
457457
func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
458458
%result = math.exp %arg0 : vector<4xf32>
459-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
460-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
461-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
462-
// CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
459+
// CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<4xf32>
460+
// CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
461+
// CHECK: %[[V1:.+]] = llvm.insertelement %[[CL]], %[[V0]]
462+
// CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
463+
// CHECK: %[[V2:.+]] = llvm.insertelement %[[CL]], %[[V1]]
464+
// CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
465+
// CHECK: %[[V3:.+]] = llvm.insertelement %[[CL]], %[[V2]]
466+
// CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
467+
// CHECK: %[[V4:.+]] = llvm.insertelement %[[CL]], %[[V3]]
468+
// CHECK: return %[[V4]]
463469
func.return %result : vector<4xf32>
464470
}
465471
}

0 commit comments

Comments
 (0)