Skip to content

Commit bc62629

Browse files
committed
distance op verify + distance op tests
1 parent 560862b commit bc62629

File tree

2 files changed

+62
-11
lines changed

2 files changed

+62
-11
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,20 +2104,29 @@ LogicalResult spirv::GLDistanceOp::verify() {
21042104
auto p1Type = getP1().getType();
21052105
auto resultType = getResult().getType();
21062106

2107-
auto p0VectorType = p0Type.dyn_cast<VectorType>();
2108-
auto p1VectorType = p1Type.dyn_cast<VectorType>();
2109-
if (!p0VectorType || !p1VectorType)
2110-
return emitOpError("operands must be vectors");
2107+
auto getFloatType = [](Type type) -> FloatType {
2108+
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
2109+
return llvm::dyn_cast<FloatType>(vectorType.getElementType());
2110+
return llvm::dyn_cast<FloatType>(type);
2111+
};
21112112

2112-
if (p0VectorType.getShape() != p1VectorType.getShape())
2113-
return emitOpError("operands must have same shape");
2113+
FloatType p0FloatType = getFloatType(p0Type);
2114+
FloatType p1FloatType = getFloatType(p1Type);
2115+
FloatType resultFloatType = llvm::dyn_cast<FloatType>(resultType);
21142116

2115-
if (!resultType.isa<FloatType>())
2116-
return emitOpError("result must be scalar float");
2117+
if (!p0FloatType || !p1FloatType || !resultFloatType)
2118+
return emitOpError("operands and result must be float scalar or vector of float");
21172119

2118-
if (p0VectorType.getElementType() != resultType ||
2119-
p1VectorType.getElementType() != resultType)
2120-
return emitOpError("operand vector elements must match result type");
2120+
if (p0FloatType != resultFloatType || p1FloatType != resultFloatType)
2121+
return emitOpError("operand and result element types must match");
2122+
2123+
if (auto p0Vec = llvm::dyn_cast<VectorType>(p0Type)) {
2124+
if (!llvm::dyn_cast<VectorType>(p1Type) ||
2125+
p0Vec.getShape() != llvm::dyn_cast<VectorType>(p1Type).getShape())
2126+
return emitOpError("vector operands must have same shape");
2127+
} else if (llvm::isa<VectorType>(p1Type)) {
2128+
return emitOpError("expected both operands to be scalars or both to be vectors");
2129+
}
21212130

21222131
return success();
21232132
}

mlir/test/Dialect/SPIRV/IR/gl-ops.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,3 +541,45 @@ func.func @findumsb(%arg0 : i64) -> () {
541541
%2 = spirv.GL.FindUMsb %arg0 : i64
542542
return
543543
}
544+
545+
// -----
546+
547+
//===----------------------------------------------------------------------===//
548+
// spirv.GL.Distance
549+
//===----------------------------------------------------------------------===//
550+
551+
func.func @distance_scalar(%arg0 : f32, %arg1 : f32) {
552+
// CHECK: spirv.GL.Distance {{%.*}}, {{%.*}} : f32, f32 -> f32
553+
%0 = spirv.GL.Distance %arg0, %arg1 : f32, f32 -> f32
554+
return
555+
}
556+
557+
func.func @distance_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
558+
// CHECK: spirv.GL.Distance {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xf32> -> f32
559+
%0 = spirv.GL.Distance %arg0, %arg1 : vector<3xf32>, vector<3xf32> -> f32
560+
return
561+
}
562+
563+
// -----
564+
565+
func.func @distance_invalid_type(%arg0 : i32, %arg1 : i32) {
566+
// expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
567+
%0 = spirv.GL.Distance %arg0, %arg1 : i32, i32 -> f32
568+
return
569+
}
570+
571+
// -----
572+
573+
func.func @distance_invalid_vector_size(%arg0 : vector<5xf32>, %arg1 : vector<5xf32>) {
574+
// expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
575+
%0 = spirv.GL.Distance %arg0, %arg1 : vector<5xf32>, vector<5xf32> -> f32
576+
return
577+
}
578+
579+
// -----
580+
581+
func.func @distance_invalid_result(%arg0 : f32, %arg1 : f32) {
582+
// expected-error @+1 {{'spirv.GL.Distance' op result #0 must be 16/32/64-bit float}}
583+
%0 = spirv.GL.Distance %arg0, %arg1 : f32, f32 -> i32
584+
return
585+
}

0 commit comments

Comments
 (0)