@@ -2104,20 +2104,29 @@ LogicalResult spirv::GLDistanceOp::verify() {
2104
2104
auto p1Type = getP1 ().getType ();
2105
2105
auto resultType = getResult ().getType ();
2106
2106
2107
- auto p0VectorType = p0Type.dyn_cast <VectorType>();
2108
- auto p1VectorType = p1Type.dyn_cast <VectorType>();
2109
- if (!p0VectorType || !p1VectorType)
2110
- return emitOpError (" operands must be vectors" );
2107
+ auto getFloatType = [](Type type) -> FloatType {
2108
+ if (auto vectorType = llvm::dyn_cast<VectorType>(type))
2109
+ return llvm::dyn_cast<FloatType>(vectorType.getElementType ());
2110
+ return llvm::dyn_cast<FloatType>(type);
2111
+ };
2111
2112
2112
- if (p0VectorType.getShape () != p1VectorType.getShape ())
2113
- return emitOpError (" operands must have same shape" );
2113
+ FloatType p0FloatType = getFloatType (p0Type);
2114
+ FloatType p1FloatType = getFloatType (p1Type);
2115
+ FloatType resultFloatType = llvm::dyn_cast<FloatType>(resultType);
2114
2116
2115
- if (!resultType. isa <FloatType>() )
2116
- return emitOpError (" result must be scalar float" );
2117
+ if (!p0FloatType || !p1FloatType || !resultFloatType )
2118
+ return emitOpError (" operands and result must be float scalar or vector of float" );
2117
2119
2118
- if (p0VectorType.getElementType () != resultType ||
2119
- p1VectorType.getElementType () != resultType)
2120
- return emitOpError (" operand vector elements must match result type" );
2120
+ if (p0FloatType != resultFloatType || p1FloatType != resultFloatType)
2121
+ return emitOpError (" operand and result element types must match" );
2122
+
2123
+ if (auto p0Vec = llvm::dyn_cast<VectorType>(p0Type)) {
2124
+ if (!llvm::dyn_cast<VectorType>(p1Type) ||
2125
+ p0Vec.getShape () != llvm::dyn_cast<VectorType>(p1Type).getShape ())
2126
+ return emitOpError (" vector operands must have same shape" );
2127
+ } else if (llvm::isa<VectorType>(p1Type)) {
2128
+ return emitOpError (" expected both operands to be scalars or both to be vectors" );
2129
+ }
2121
2130
2122
2131
return success ();
2123
2132
}
0 commit comments