Skip to content

Commit 6b9c186

Browse files
authored
[mlir][spirv] Handle non-innerprod float vector add reductions (#73476)
Instead of extracting all individual vector components and performing a scalar summation, use `spirv.Dot` with the original reduction operand and a vector constant of all ones.
1 parent c599b8e commit 6b9c186

File tree

2 files changed

+88
-12
lines changed

2 files changed

+88
-12
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 36 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1919
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
2020
#include "mlir/Dialect/Vector/IR/VectorOps.h"
21+
#include "mlir/IR/Attributes.h"
2122
#include "mlir/IR/BuiltinAttributes.h"
2223
#include "mlir/IR/BuiltinTypes.h"
2324
#include "mlir/IR/Location.h"
@@ -755,16 +756,43 @@ struct VectorReductionToFPDotProd final
755756
if (!resultType)
756757
return rewriter.notifyMatchFailure(op, "result is not a float");
757758

758-
auto mul = adaptor.getVector().getDefiningOp<arith::MulFOp>();
759-
if (!mul)
760-
return rewriter.notifyMatchFailure(
761-
op, "reduction operand is not 'arith.mulf'");
759+
Value vec = adaptor.getVector();
760+
Value acc = adaptor.getAcc();
761+
762+
auto vectorType = dyn_cast<VectorType>(vec.getType());
763+
if (!vectorType) {
764+
assert(isa<FloatType>(vec.getType()) &&
765+
"Expected the vector to be scalarized");
766+
if (acc) {
767+
rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
768+
return success();
769+
}
770+
771+
rewriter.replaceOp(op, vec);
772+
return success();
773+
}
762774

763775
Location loc = op.getLoc();
764-
Value res = rewriter.create<spirv::DotOp>(loc, resultType, mul.getLhs(),
765-
mul.getRhs());
766-
if (op.getAcc())
767-
res = rewriter.create<spirv::FAddOp>(loc, adaptor.getAcc(), res);
776+
Value lhs;
777+
Value rhs;
778+
if (auto mul = vec.getDefiningOp<arith::MulFOp>()) {
779+
lhs = mul.getLhs();
780+
rhs = mul.getRhs();
781+
} else {
782+
// If the operand is not a mul, use a vector of ones for the dot operand
783+
// to just sum up all values.
784+
lhs = vec;
785+
Attribute oneAttr =
786+
rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
787+
oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
788+
rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr);
789+
}
790+
assert(lhs);
791+
assert(rhs);
792+
793+
Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs);
794+
if (acc)
795+
res = rewriter.create<spirv::FAddOp>(loc, acc, res);
768796

769797
rewriter.replaceOp(op, res);
770798
return success();

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 52 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -500,31 +500,79 @@ func.func @reduction_add(%v : vector<4xi32>) -> i32 {
500500

501501
// -----
502502

503-
// CHECK-LABEL: func @reduction_addf
503+
// CHECK-LABEL: func @reduction_addf_mulf
504504
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>)
505505
// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
506506
// CHECK: return %[[DOT]] : f32
507-
func.func @reduction_addf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
507+
func.func @reduction_addf_mulf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
508508
%mul = arith.mulf %arg0, %arg1 : vector<4xf32>
509509
%red = vector.reduction <add>, %mul : vector<4xf32> into f32
510510
return %red : f32
511511
}
512512

513513
// -----
514514

515-
// CHECK-LABEL: func @reduction_addf_acc
515+
// CHECK-LABEL: func @reduction_addf_acc_mulf
516516
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
517517
// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
518518
// CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
519519
// CHECK: return %[[RES]] : f32
520-
func.func @reduction_addf_acc(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc: f32) -> f32 {
520+
func.func @reduction_addf_acc_mulf(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc: f32) -> f32 {
521521
%mul = arith.mulf %arg0, %arg1 : vector<4xf32>
522522
%red = vector.reduction <add>, %mul, %acc : vector<4xf32> into f32
523523
return %red : f32
524524
}
525525

526526
// -----
527527

528+
// CHECK-LABEL: func @reduction_addf
529+
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
530+
// CHECK: %[[ONE:.+]] = spirv.Constant dense<1.0{{.+}}> : vector<4xf32>
531+
// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ONE]] : vector<4xf32> -> f32
532+
// CHECK: return %[[DOT]] : f32
533+
func.func @reduction_addf(%arg0: vector<4xf32>) -> f32 {
534+
%red = vector.reduction <add>, %arg0 : vector<4xf32> into f32
535+
return %red : f32
536+
}
537+
538+
// -----
539+
540+
// CHECK-LABEL: func @reduction_addf_acc
541+
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
542+
// CHECK: %[[ONE:.+]] = spirv.Constant dense<1.0{{.*}}> : vector<4xf32>
543+
// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ONE]] : vector<4xf32> -> f32
544+
// CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
545+
// CHECK: return %[[RES]] : f32
546+
func.func @reduction_addf_acc(%arg0: vector<4xf32>, %acc: f32) -> f32 {
547+
%red = vector.reduction <add>, %arg0, %acc : vector<4xf32> into f32
548+
return %red : f32
549+
}
550+
551+
// -----
552+
553+
// CHECK-LABEL: func @reduction_addf_one_elem
554+
// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>)
555+
// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
556+
// CHECK: return %[[RES]] : f32
557+
func.func @reduction_addf_one_elem(%arg0: vector<1xf32>) -> f32 {
558+
%red = vector.reduction <add>, %arg0 : vector<1xf32> into f32
559+
return %red : f32
560+
}
561+
562+
// -----
563+
564+
// CHECK-LABEL: func @reduction_addf_one_elem_acc
565+
// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ACC:.+]]: f32)
566+
// CHECK: %[[RHS:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
567+
// CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[RHS]] : f32
568+
// CHECK: return %[[RES]] : f32
569+
func.func @reduction_addf_one_elem_acc(%arg0: vector<1xf32>, %acc: f32) -> f32 {
570+
%red = vector.reduction <add>, %arg0, %acc : vector<1xf32> into f32
571+
return %red : f32
572+
}
573+
574+
// -----
575+
528576
// CHECK-LABEL: func @reduction_mul
529577
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
530578
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>

0 commit comments

Comments
 (0)