Skip to content

Commit d558540

Browse files
committed
[mlir][Vector] Add return type inference for multi_reduction
This subsumes the builder and verifier.
1 parent 3ba42a5 commit d558540

File tree

3 files changed

+30
-44
lines changed

3 files changed

+30
-44
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,8 @@ def Vector_ReductionOp :
313313
def Vector_MultiDimReductionOp :
314314
Vector_Op<"multi_reduction", [NoSideEffect,
315315
PredOpTrait<"source operand and result have same element type",
316-
TCresVTEtIsSameAsOpBase<0, 0>>]>,
316+
TCresVTEtIsSameAsOpBase<0, 0>>,
317+
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
317318
Arguments<(ins Vector_CombiningKindAttr:$kind,
318319
AnyVector:$source,
319320
I64ArrayAttr:$reduction_dims)>,
@@ -367,31 +368,10 @@ def Vector_MultiDimReductionOp :
367368
res[idx] = true;
368369
return res;
369370
}
370-
371-
static SmallVector<int64_t> inferDestShape(
372-
ArrayRef<int64_t> sourceShape, ArrayRef<bool> reducedDimsMask) {
373-
assert(sourceShape.size() == reducedDimsMask.size() &&
374-
"sourceShape and maks of different sizes");
375-
SmallVector<int64_t> res;
376-
for (auto it : llvm::zip(reducedDimsMask, sourceShape))
377-
if (!std::get<0>(it))
378-
res.push_back(std::get<1>(it));
379-
return res;
380-
}
381-
382-
static Type inferDestType(
383-
ArrayRef<int64_t> sourceShape, ArrayRef<bool> reducedDimsMask, Type elementType) {
384-
auto targetShape = inferDestShape(sourceShape, reducedDimsMask);
385-
// TODO: update to also allow 0-d vectors when available.
386-
if (targetShape.empty())
387-
return elementType;
388-
return VectorType::get(targetShape, elementType);
389-
}
390371
}];
391372
let assemblyFormat =
392373
"$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
393374
let hasFolder = 1;
394-
let hasVerifier = 1;
395375
}
396376

397377
def Vector_BroadcastOp :

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -336,32 +336,31 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
336336
OperationState &result, Value source,
337337
ArrayRef<bool> reductionMask,
338338
CombiningKind kind) {
339-
result.addOperands(source);
340-
auto sourceVectorType = source.getType().cast<VectorType>();
341-
auto targetType = MultiDimReductionOp::inferDestType(
342-
sourceVectorType.getShape(), reductionMask,
343-
sourceVectorType.getElementType());
344-
result.addTypes(targetType);
345-
346339
SmallVector<int64_t> reductionDims;
347340
for (const auto &en : llvm::enumerate(reductionMask))
348341
if (en.value())
349342
reductionDims.push_back(en.index());
350-
result.addAttribute(getReductionDimsAttrStrName(),
351-
builder.getI64ArrayAttr(reductionDims));
352-
result.addAttribute(getKindAttrStrName(),
353-
CombiningKindAttr::get(kind, builder.getContext()));
354-
}
355-
356-
LogicalResult MultiDimReductionOp::verify() {
357-
auto reductionMask = getReductionMask();
358-
auto targetType = MultiDimReductionOp::inferDestType(
359-
getSourceVectorType().getShape(), reductionMask,
360-
getSourceVectorType().getElementType());
361-
// TODO: update to support 0-d vectors when available.
362-
if (targetType != getDestType())
363-
return emitError("invalid output vector type: ")
364-
<< getDestType() << " (expected: " << targetType << ")";
343+
build(builder, result, kind, source, builder.getI64ArrayAttr(reductionDims));
344+
}
345+
346+
LogicalResult MultiDimReductionOp::inferReturnTypes(
347+
MLIRContext *, Optional<Location>, ValueRange operands,
348+
DictionaryAttr attributes, RegionRange,
349+
SmallVectorImpl<Type> &inferredReturnTypes) {
350+
MultiDimReductionOp::Adaptor op(operands, attributes);
351+
auto vectorType = op.source().getType().cast<VectorType>();
352+
SmallVector<int64_t> targetShape;
353+
for (auto it : llvm::enumerate(vectorType.getShape()))
354+
if (!llvm::any_of(op.reduction_dims().getValue(), [&](Attribute attr) {
355+
return attr.cast<IntegerAttr>().getValue() == it.index();
356+
}))
357+
targetShape.push_back(it.value());
358+
// TODO: update to also allow 0-d vectors when available.
359+
if (targetShape.empty())
360+
inferredReturnTypes.push_back(vectorType.getElementType());
361+
else
362+
inferredReturnTypes.push_back(
363+
VectorType::get(targetShape, vectorType.getElementType()));
365364
return success();
366365
}
367366

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,6 +1137,13 @@ func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
11371137

11381138
// -----
11391139

1140+
func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>) -> f32 {
1141+
// expected-error@+1 {{'vector.multi_reduction' op inferred type(s) 'vector<4xf32>' are incompatible with return type(s) of operation 'vector<16xf32>'}}
1142+
%0 = vector.multi_reduction <mul>, %arg0 [1] : vector<4x16xf32> to vector<16xf32>
1143+
}
1144+
1145+
// -----
1146+
11401147
func @transpose_rank_mismatch(%arg0: vector<4x16x11xf32>) {
11411148
// expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
11421149
%0 = vector.transpose %arg0, [2, 1, 0] : vector<4x16x11xf32> to vector<100xf32>

0 commit comments

Comments
 (0)