Skip to content

[MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV conversion #119675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 9, 2025

Conversation

PietroGhg
Copy link
Contributor

This PR adds support to the bf16 and i1 data types when converting gpu::shuffle to the LLVMSPV dialect, by inserting bitcast to/from i16 (for bf16) and extending/truncating to i8 (for i1).

@llvmbot
Copy link
Member

llvmbot commented Dec 12, 2024

@llvm/pr-subscribers-mlir

Author: Pietro Ghiglio (PietroGhg)

Changes

This PR adds support to the bf16 and i1 data types when converting gpu::shuffle to the LLVMSPV dialect, by inserting bitcast to/from i16 (for bf16) and extending/truncating to i8 (for i1).


Full diff: https://github.com/llvm/llvm-project/pull/119675.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp (+76-7)
  • (modified) mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir (+15-4)
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 03745f4537e99e..415e67aebab978 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
         .Default([](auto) { return std::nullopt; });
   }
 
-  static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
-    StringRef baseName = getBaseName(op.getMode());
-    std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
+  static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
+                                                Type type) {
+    StringRef baseName = getBaseName(mode);
+    std::optional<StringRef> typeMangling = getTypeMangling(type);
     if (!typeMangling)
       return std::nullopt;
     return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
                          typeMangling.value());
   }
 
+  static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
+    return getFuncName(op.getMode(), op.getType(0));
+  }
+
   /// Get the subgroup size from the target or return a default.
   static std::optional<int> getSubgroupSize(Operation *op) {
     auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -286,6 +291,51 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
            val == getSubgroupSize(op);
   }
 
+  static bool needsBitCastOrExt(gpu::ShuffleOp op) {
+    Type type = op.getType(0);
+    return isa<BFloat16Type>(type) || type.isInteger(1);
+  }
+
+  static Type getBitCastOrExtTy(Type oldTy,
+                                ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Type>(oldTy)
+        .Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
+        .Case<IntegerType>([&](auto intTy) -> Type {
+          if (intTy.getWidth() == 1)
+            return rewriter.getIntegerType(8);
+          return Type{};
+        })
+        .Default([](auto) { return Type{}; });
+  }
+
+  static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
+                              ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Value>(oldVal.getType())
+        .Case<BFloat16Type>([&](auto) {
+          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+        })
+        .Case<IntegerType>([&](auto intTy) -> Value {
+          if (intTy.getWidth() == 1)
+            return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
+          return Value{};
+        })
+        .Default([](auto) { return Value{}; });
+  }
+
+  static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
+                                ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Value>(newTy)
+        .Case<BFloat16Type>([&](auto) {
+          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+        })
+        .Case<IntegerType>([&](auto intTy) -> Value {
+          if (intTy.getWidth() == 1)
+            return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+          return Value{};
+        })
+        .Default([](auto) { return Value{}; });
+  }
+
   LogicalResult
   matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
@@ -293,23 +343,42 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
       return rewriter.notifyMatchFailure(
           op, "shuffle width and subgroup size mismatch");
 
-    std::optional<std::string> funcName = getFuncName(op);
+    Location loc = op->getLoc();
+    Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
+    std::optional<std::string> funcName;
+    Value inValue;
+    if (bitcastOrExtDestTy) {
+      Value newVal =
+          doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
+      assert(newVal && "Unhandled op type in bitcastorext");
+      funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
+      inValue = newVal;
+    } else {
+      funcName = getFuncName(op);
+      inValue = adaptor.getValue();
+    }
     if (!funcName)
       return rewriter.notifyMatchFailure(op, "unsupported value type");
 
     Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
     assert(moduleOp && "Expecting module");
-    Type valueType = adaptor.getValue().getType();
+    Type valueType = inValue.getType();
     Type offsetType = adaptor.getOffset().getType();
     Type resultType = valueType;
     LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
         moduleOp, funcName.value(), {valueType, offsetType}, resultType,
         /*isMemNone=*/false, /*isConvergent=*/true);
 
-    Location loc = op->getLoc();
-    std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+    std::array<Value, 2> args{inValue, adaptor.getOffset()};
     Value result =
         createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+    if (bitcastOrExtDestTy) {
+      Value newVal =
+          doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
+      assert(newVal && "Unhandled op type in bitcastortrunc");
+      result = newVal;
+    }
+
     Value trueVal =
         rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
     rewriter.replaceOp(op, {result, trueVal});
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
index 16b692b9689398..6fab647cb35681 100644
--- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
+++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
@@ -277,7 +277,8 @@ gpu.module @shuffles {
   // CHECK-SAME:              (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
   // CHECK-SAME:               %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
   // CHECK-SAME:               %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
-  // CHECK-SAME:               %[[F64_VAL:.*]]: f64,  %[[OFFSET:.*]]: i32)
+  // CHECK-SAME:               %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16,
+  // CHECK-SAME:               %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32)
   llvm.func @gpu_shuffles(%i8_val: i8,
                           %i16_val: i16,
                           %i32_val: i32,
@@ -285,6 +286,8 @@ gpu.module @shuffles {
                           %f16_val: f16,
                           %f32_val: f32,
                           %f64_val: f64,
+                          %bf16_val: bf16,
+                          %i1_val: i1,
                           %offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} {
     %width = arith.constant 16 : i32
     // CHECK:         llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
@@ -301,6 +304,14 @@ gpu.module @shuffles {
     // CHECK:         llvm.mlir.constant(true) : i1
     // CHECK:         llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
     // CHECK:         llvm.mlir.constant(true) : i1
+    // CHECK:         %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16
+    // CHECK:         %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]])
+    // CHECK:         llvm.bitcast %[[BF16_CALL]] : i16 to bf16
+    // CHECK:         llvm.mlir.constant(true) : i1
+    // CHECK:         %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8
+    // CHECK:         %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%18, %arg9)
+    // CHECK:         llvm.trunc %[[I1_CALL:.*]] : i8 to i1
+    // CHECK:         llvm.mlir.constant(true) : i1
     %shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
     %shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
     %shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
@@ -308,6 +319,8 @@ gpu.module @shuffles {
     %shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
     %shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
     %shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
+    %shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16
+    %shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1
     llvm.return
   }
 }
@@ -342,10 +355,8 @@ gpu.module @shuffles_mismatch {
 // Cannot convert due to value type not being supported by the conversion
 
 gpu.module @not_supported_lowering {
-  llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
+  llvm.func @gpu_shuffles(%id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
     %width = arith.constant 32 : i32
-    // expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
-    %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
     llvm.return
   }
 }

@llvmbot
Copy link
Member

llvmbot commented Dec 12, 2024

@llvm/pr-subscribers-mlir-gpu

Author: Pietro Ghiglio (PietroGhg)

Changes

This PR adds support to the bf16 and i1 data types when converting gpu::shuffle to the LLVMSPV dialect, by inserting bitcast to/from i16 (for bf16) and extending/truncating to i8 (for i1).


Full diff: https://github.com/llvm/llvm-project/pull/119675.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp (+76-7)
  • (modified) mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir (+15-4)
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 03745f4537e99e..415e67aebab978 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
         .Default([](auto) { return std::nullopt; });
   }
 
-  static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
-    StringRef baseName = getBaseName(op.getMode());
-    std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
+  static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
+                                                Type type) {
+    StringRef baseName = getBaseName(mode);
+    std::optional<StringRef> typeMangling = getTypeMangling(type);
     if (!typeMangling)
       return std::nullopt;
     return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
                          typeMangling.value());
   }
 
+  static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
+    return getFuncName(op.getMode(), op.getType(0));
+  }
+
   /// Get the subgroup size from the target or return a default.
   static std::optional<int> getSubgroupSize(Operation *op) {
     auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -286,6 +291,51 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
            val == getSubgroupSize(op);
   }
 
+  static bool needsBitCastOrExt(gpu::ShuffleOp op) {
+    Type type = op.getType(0);
+    return isa<BFloat16Type>(type) || type.isInteger(1);
+  }
+
+  static Type getBitCastOrExtTy(Type oldTy,
+                                ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Type>(oldTy)
+        .Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
+        .Case<IntegerType>([&](auto intTy) -> Type {
+          if (intTy.getWidth() == 1)
+            return rewriter.getIntegerType(8);
+          return Type{};
+        })
+        .Default([](auto) { return Type{}; });
+  }
+
+  static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
+                              ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Value>(oldVal.getType())
+        .Case<BFloat16Type>([&](auto) {
+          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+        })
+        .Case<IntegerType>([&](auto intTy) -> Value {
+          if (intTy.getWidth() == 1)
+            return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
+          return Value{};
+        })
+        .Default([](auto) { return Value{}; });
+  }
+
+  static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
+                                ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Value>(newTy)
+        .Case<BFloat16Type>([&](auto) {
+          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+        })
+        .Case<IntegerType>([&](auto intTy) -> Value {
+          if (intTy.getWidth() == 1)
+            return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+          return Value{};
+        })
+        .Default([](auto) { return Value{}; });
+  }
+
   LogicalResult
   matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
@@ -293,23 +343,42 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
       return rewriter.notifyMatchFailure(
           op, "shuffle width and subgroup size mismatch");
 
-    std::optional<std::string> funcName = getFuncName(op);
+    Location loc = op->getLoc();
+    Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
+    std::optional<std::string> funcName;
+    Value inValue;
+    if (bitcastOrExtDestTy) {
+      Value newVal =
+          doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
+      assert(newVal && "Unhandled op type in bitcastorext");
+      funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
+      inValue = newVal;
+    } else {
+      funcName = getFuncName(op);
+      inValue = adaptor.getValue();
+    }
     if (!funcName)
       return rewriter.notifyMatchFailure(op, "unsupported value type");
 
     Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
     assert(moduleOp && "Expecting module");
-    Type valueType = adaptor.getValue().getType();
+    Type valueType = inValue.getType();
     Type offsetType = adaptor.getOffset().getType();
     Type resultType = valueType;
     LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
         moduleOp, funcName.value(), {valueType, offsetType}, resultType,
         /*isMemNone=*/false, /*isConvergent=*/true);
 
-    Location loc = op->getLoc();
-    std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+    std::array<Value, 2> args{inValue, adaptor.getOffset()};
     Value result =
         createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+    if (bitcastOrExtDestTy) {
+      Value newVal =
+          doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
+      assert(newVal && "Unhandled op type in bitcastortrunc");
+      result = newVal;
+    }
+
     Value trueVal =
         rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
     rewriter.replaceOp(op, {result, trueVal});
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
index 16b692b9689398..6fab647cb35681 100644
--- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
+++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
@@ -277,7 +277,8 @@ gpu.module @shuffles {
   // CHECK-SAME:              (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
   // CHECK-SAME:               %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
   // CHECK-SAME:               %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
-  // CHECK-SAME:               %[[F64_VAL:.*]]: f64,  %[[OFFSET:.*]]: i32)
+  // CHECK-SAME:               %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16,
+  // CHECK-SAME:               %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32)
   llvm.func @gpu_shuffles(%i8_val: i8,
                           %i16_val: i16,
                           %i32_val: i32,
@@ -285,6 +286,8 @@ gpu.module @shuffles {
                           %f16_val: f16,
                           %f32_val: f32,
                           %f64_val: f64,
+                          %bf16_val: bf16,
+                          %i1_val: i1,
                           %offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} {
     %width = arith.constant 16 : i32
     // CHECK:         llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
@@ -301,6 +304,14 @@ gpu.module @shuffles {
     // CHECK:         llvm.mlir.constant(true) : i1
     // CHECK:         llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
     // CHECK:         llvm.mlir.constant(true) : i1
+    // CHECK:         %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16
+    // CHECK:         %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]])
+    // CHECK:         llvm.bitcast %[[BF16_CALL]] : i16 to bf16
+    // CHECK:         llvm.mlir.constant(true) : i1
+    // CHECK:         %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8
+    // CHECK:         %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%18, %arg9)
+    // CHECK:         llvm.trunc %[[I1_CALL:.*]] : i8 to i1
+    // CHECK:         llvm.mlir.constant(true) : i1
     %shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
     %shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
     %shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
@@ -308,6 +319,8 @@ gpu.module @shuffles {
     %shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
     %shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
     %shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
+    %shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16
+    %shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1
     llvm.return
   }
 }
@@ -342,10 +355,8 @@ gpu.module @shuffles_mismatch {
 // Cannot convert due to value type not being supported by the conversion
 
 gpu.module @not_supported_lowering {
-  llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
+  llvm.func @gpu_shuffles(%id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
     %width = arith.constant 32 : i32
-    // expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
-    %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
     llvm.return
   }
 }

Comment on lines 346 to 359
Location loc = op->getLoc();
Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
std::optional<std::string> funcName;
Value inValue;
if (bitcastOrExtDestTy) {
Value newVal =
doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
assert(newVal && "Unhandled op type in bitcastorext");
funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
inValue = newVal;
} else {
funcName = getFuncName(op);
inValue = adaptor.getValue();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Location loc = op->getLoc();
Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
std::optional<std::string> funcName;
Value inValue;
if (bitcastOrExtDestTy) {
Value newVal =
doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
assert(newVal && "Unhandled op type in bitcastorext");
funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
inValue = newVal;
} else {
funcName = getFuncName(op);
inValue = adaptor.getValue();
}
Location loc = op->getLoc();
Value inValue = bitcastOrExtForShuffle(loc, adaptor.getValue(), rewriter);
std::optional<std::string> funcName = getFuncName(op.getMode(), inValue.getType());

Would this work too with a TypeSwitch in bitcastOrExtForShuffle?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thank you

Comment on lines 375 to 380
if (bitcastOrExtDestTy) {
Value newVal =
doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
assert(newVal && "Unhandled op type in bitcastortrunc");
result = newVal;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (bitcastOrExtDestTy) {
Value newVal =
doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
assert(newVal && "Unhandled op type in bitcastortrunc");
result = newVal;
}
result = bitcastOrTruncAfterShuffle(loc, result, rewriter);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thank you

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some drive by nits

static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
Type type) {
StringRef baseName = getBaseName(mode);
std::optional<StringRef> typeMangling = getTypeMangling(type);
if (!typeMangling)
return std::nullopt;
return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, you don't need these indices, "_Z{}{}{}" should work too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I've removed the explicit indices

Comment on lines 302 to 303
.Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
.Case<IntegerType>([&](auto intTy) -> Type {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I don't see why these need to be generic over the argument type

Suggested change
.Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
.Case<IntegerType>([&](auto intTy) -> Type {
.Case<BFloat16Type>([&](Type) { return rewriter.getIntegerType(16); })
.Case<IntegerType>([&](IntegerType intTy) -> Type {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure that would compile.

        .Case([&](BFloat16Type) { return rewriter.getIntegerType(16); })
        .Case([&](IntegerType intTy) -> Type {

Is nicer IMO tho

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I went with @victor-eds 's suggestion

return rewriter.getIntegerType(8);
return Type{};
})
.Default([](auto) { return Type{}; });
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.Default([](auto) { return Type{}; });
.Default([](Type) { return Type{}; });

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also below

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thank you

Comment on lines 357 to 365
gpu.module @not_supported_lowering {
llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
llvm.func @gpu_shuffles(%id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
%width = arith.constant 32 : i32
// expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
%shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
llvm.return
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have removed the reason for this module/function to exist. Maybe replace the i1 value with another unsupported type like a large float.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I've added a test for f128

@@ -286,30 +291,94 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
val == getSubgroupSize(op);
}

static bool needsBitCastOrExt(gpu::ShuffleOp op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed it, thanks for spotting it

@PietroGhg PietroGhg force-pushed the pietro/fp16_i1_shuffle branch from 9e3af90 to 09e2222 Compare January 7, 2025 16:00
Copy link
Contributor

@victor-eds victor-eds left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a couple NITs. LGTM!

@PietroGhg PietroGhg force-pushed the pietro/fp16_i1_shuffle branch from 09e2222 to 7ae5428 Compare January 8, 2025 11:41
@PietroGhg PietroGhg force-pushed the pietro/fp16_i1_shuffle branch from 7ae5428 to 60ef4fd Compare January 8, 2025 13:37
@sommerlukas sommerlukas merged commit cdd652e into llvm:main Jan 9, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants