Skip to content

[mlir][spirv] Add support for SPV_ARM_tensors #144667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m

def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;

def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;

def SPIRV_ExtensionAttr :
SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group,
Expand All @@ -445,6 +447,7 @@ def SPIRV_ExtensionAttr :
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
SPV_EXT_mesh_shader,
SPV_ARM_tensors,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
Expand Down Expand Up @@ -1311,6 +1314,24 @@ def SPIRV_C_GeometryStreams : I32EnumAttrCase<"Geome
def SPIRV_C_MultiViewport : I32EnumAttrCase<"MultiViewport", 57> {
list<I32EnumAttrCase> implies = [SPIRV_C_Geometry];
}
def SPIRV_C_TensorsARM : I32EnumAttrCase<"TensorsARM", 4174> {
list<I32EnumAttrCase> implies = [SPIRV_C_Int8];
list<Availability> availability = [
Extension<[SPV_ARM_tensors]>
];
}
def SPIRV_C_StorageTensorArrayDynamicIndexingEXT : I32EnumAttrCase<"StorageTensorArrayDynamicIndexingEXT", 4175> {
list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader];
list<Availability> availability = [
Extension<[SPV_ARM_tensors]>
];
}
def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"StorageTensorArrayNonUniformIndexingEXT", 4176> {
list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_ShaderNonUniform];
list<Availability> availability = [
Extension<[SPV_ARM_tensors]>
];
}
def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
list<Availability> availability = [
Expand Down Expand Up @@ -1523,6 +1544,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_IntegerFunctions2INTEL, SPIRV_C_TessellationPointSize,
SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
Expand Down Expand Up @@ -4179,7 +4202,7 @@ def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;

def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;

// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
// for the definition of the following types and type categories.
Expand Down Expand Up @@ -4217,6 +4240,8 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
"any SPIR-V struct type">;
def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
"any SPIR-V sampled image type">;
def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
"any SPIR-V tensorArm type">;

def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
Expand All @@ -4228,7 +4253,7 @@ def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
SPIRV_AnyImage
SPIRV_AnyImage, SPIRV_AnyTensorArm
]>;

def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
Expand Down Expand Up @@ -4525,6 +4550,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
Expand Down Expand Up @@ -4638,7 +4664,9 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
SPIRV_OC_OpGroupNonUniformLogicalXor,
SPIRV_OC_OpTypeTensorARM,
SPIRV_OC_OpSubgroupBallotKHR,
SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
Expand Down
43 changes: 42 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace spirv {
namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct TensorArmTypeStorage;
struct ImageTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
Expand Down Expand Up @@ -96,7 +97,8 @@ class ScalarType : public SPIRVType {
std::optional<int64_t> getSizeInBytes();
};

// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
// StructType, or SPIR-V TensorArmType.
class CompositeType : public SPIRVType {
public:
using SPIRVType::SPIRVType;
Expand Down Expand Up @@ -477,6 +479,45 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
std::optional<StorageClass> storage = std::nullopt);
};

/// SPIR-V TensorARM Type
class TensorArmType
: public Type::TypeBase<TensorArmType, CompositeType,
detail::TensorArmTypeStorage, ShapedType::Trait> {
public:
using Base::Base;

using ShapedType::Trait<TensorArmType>::getElementTypeBitWidth;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: If you want, you can use a typedef for this trait and then use that in all of these using declarations. Something like:

using ShapedTypeTraits = ShapedType::Trait<TensorArmType>;
using ShapedTypeTraits::getElementTypeBitWidth;

using ShapedType::Trait<TensorArmType>::getRank;
using ShapedType::Trait<TensorArmType>::getNumElements;
using ShapedType::Trait<TensorArmType>::isDynamicDim;
using ShapedType::Trait<TensorArmType>::hasStaticShape;
using ShapedType::Trait<TensorArmType>::getNumDynamicDims;
using ShapedType::Trait<TensorArmType>::getDimSize;
using ShapedType::Trait<TensorArmType>::getDynamicDimIndex;

static constexpr StringLiteral name = "spirv.arm.tensor";

// TensorArm supports minimum rank of 1, hence an empty shape here means
// unranked.
static TensorArmType get(ArrayRef<int64_t> shape, Type elementType);
TensorArmType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const;

static LogicalResult
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType);

Type getElementType() const;
ArrayRef<int64_t> getShape() const;
bool hasRank() const { return !getShape().empty(); }
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};

} // namespace spirv
} // namespace mlir

Expand Down
75 changes: 74 additions & 1 deletion mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,13 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
<< t.getNumElements();
return Type();
}
} else if (auto t = dyn_cast<TensorArmType>(type)) {
if (!isa<ScalarType>(t.getElementType())) {
parser.emitError(
typeLoc, "only scalar element type allowed in tensor type but found ")
<< t.getElementType();
return Type();
}
} else {
parser.emitError(typeLoc, "cannot use ")
<< type << " to compose SPIR-V types";
Expand Down Expand Up @@ -363,6 +370,52 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
}

// tensor-arm-type ::=
// `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
static Type parseTensorArmType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
return {};

bool unranked = false;
SmallVector<int64_t, 4> dims;
SMLoc countLoc = parser.getCurrentLocation();

if (parser.parseOptionalStar().succeeded()) {
unranked = true;
if (parser.parseXInDimensionList())
return {};
} else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
return {};
}

if (!unranked && dims.empty()) {
parser.emitError(countLoc, "arm.tensors do not support rank zero");
return {};
}

if (llvm::is_contained(dims, 0)) {
parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
return {};
}

if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
"fully dynamic or completed shaped");
return {};
}

auto elementTy = parseAndVerifyType(dialect, parser);
if (!elementTy)
return {};

if (parser.parseGreater())
return {};

return TensorArmType::get(dims, elementTy);
}

// TODO: Reorder methods to be utilities first and parse*Type
// methods in alphabetical order
//
Expand Down Expand Up @@ -759,6 +812,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseStructType(*this, parser);
if (keyword == "matrix")
return parseMatrixType(*this, parser);
if (keyword == "arm.tensor")
return parseTensorArmType(*this, parser);
parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
return Type();
}
Expand Down Expand Up @@ -855,10 +910,28 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
os << ">";
}

static void print(TensorArmType type, DialectAsmPrinter &os) {
os << "arm.tensor<";

llvm::interleave(
type.getShape(), os,
[&](int64_t dim) {
if (ShapedType::isDynamic(dim))
os << '?';
else
os << dim;
},
"x");
if (!type.hasRank()) {
os << "*";
}
os << "x" << type.getElementType() << ">";
}

void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
ImageType, SampledImageType, StructType, MatrixType>(
ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
[&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,12 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
return failure();
}

if (llvm::isa<TensorArmType>(type)) {
if (parser.parseOptionalColon().succeeded())
if (parser.parseType(type))
return failure();
}

return parser.addTypeToList(type, result.types);
}

Expand Down
Loading
Loading