Skip to content

Commit 6f9d25c

Browse files
authored
[CIR] Backport Allow different Int types together in Vec ShiftOp (#1643)
Backport improvements in ShiftOp for vectors from llvm/llvm-project#141111.
1 parent 25dda94 commit 6f9d25c

File tree

4 files changed

+64
-8
lines changed

4 files changed

+64
-8
lines changed

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3939,15 +3939,33 @@ LogicalResult cir::BinOp::verify() {
39393939
//===----------------------------------------------------------------------===//
39403940
LogicalResult cir::ShiftOp::verify() {
39413941
mlir::Operation *op = getOperation();
3942-
mlir::Type resType = getResult().getType();
3943-
bool isOp0Vec = mlir::isa<cir::VectorType>(op->getOperand(0).getType());
3944-
bool isOp1Vec = mlir::isa<cir::VectorType>(op->getOperand(1).getType());
3945-
if (isOp0Vec != isOp1Vec)
3942+
auto op0VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(0).getType());
3943+
auto op1VecTy = mlir::dyn_cast<cir::VectorType>(op->getOperand(1).getType());
3944+
if (!op0VecTy ^ !op1VecTy)
3945+
39463946
return emitOpError() << "input types cannot be one vector and one scalar";
3947-
if (isOp1Vec && op->getOperand(1).getType() != resType) {
3948-
return emitOpError() << "shift amount must have the type of the result "
3949-
<< "if it is vector shift";
3947+
3948+
if (op0VecTy) {
3949+
if (op0VecTy.getSize() != op1VecTy.getSize())
3950+
return emitOpError() << "input vector types must have the same size";
3951+
3952+
auto opResultTy = mlir::dyn_cast<cir::VectorType>(getResult().getType());
3953+
if (!opResultTy)
3954+
return emitOpError() << "the type of the result must be a vector "
3955+
<< "if it is vector shift";
3956+
3957+
auto op0VecEleTy = mlir::cast<cir::IntType>(op0VecTy.getElementType());
3958+
auto op1VecEleTy = mlir::cast<cir::IntType>(op1VecTy.getElementType());
3959+
if (op0VecEleTy.getWidth() != op1VecEleTy.getWidth())
3960+
return emitOpError()
3961+
<< "vector operands do not have the same elements sizes";
3962+
3963+
auto resVecEleTy = mlir::cast<cir::IntType>(opResultTy.getElementType());
3964+
if (op0VecEleTy.getWidth() != resVecEleTy.getWidth())
3965+
return emitOpError() << "vector operands and result type do not have the "
3966+
"same elements sizes";
39503967
}
3968+
39513969
return mlir::success();
39523970
}
39533971

clang/test/CIR/CodeGen/vectype-ext.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM
55

66
typedef int vi4 __attribute__((ext_vector_type(4)));
7+
typedef unsigned int uvi4 __attribute__((ext_vector_type(4)));
78
typedef int vi3 __attribute__((ext_vector_type(3)));
89
typedef int vi2 __attribute__((ext_vector_type(2)));
910
typedef double vd2 __attribute__((ext_vector_type(2)));
@@ -535,3 +536,13 @@ void test_vec3() {
535536
// LLVM-NEXT: %[[#RES:]] = add <3 x i32> %[[#V3]], splat (i32 1)
536537

537538
}
539+
540+
void vector_integers_shifts_test() {
541+
vi4 a = {1, 2, 3, 4};
542+
uvi4 b = {5u, 6u, 7u, 8u};
543+
544+
vi4 shl = a << b;
545+
// CHECK: %{{[0-9]+}} = cir.shift(left, %{{[0-9]+}} : !cir.vector<!s32i x 4>, %{{[0-9]+}} : !cir.vector<!u32i x 4>) -> !cir.vector<!s32i x 4>
546+
uvi4 shr = b >> a;
547+
// CHECK: %{{[0-9]+}} = cir.shift(right, %{{[0-9]+}} : !cir.vector<!u32i x 4>, %{{[0-9]+}} : !cir.vector<!s32i x 4>) -> !cir.vector<!u32i x 4>
548+
}

clang/test/CIR/CodeGen/vectype.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s
22

33
typedef int vi4 __attribute__((vector_size(16)));
4+
typedef unsigned int uvi4 __attribute__((vector_size(16)));
45
typedef double vd2 __attribute__((vector_size(16)));
56
typedef long long vll2 __attribute__((vector_size(16)));
67
typedef unsigned short vus2 __attribute__((vector_size(4)));
@@ -198,3 +199,13 @@ void vector_double_test(int x, double y) {
198199
vus2 w = __builtin_convertvector(a, vus2);
199200
// CHECK: %{{[0-9]+}} = cir.cast(float_to_int, %{{[0-9]+}} : !cir.vector<!cir.double x 2>), !cir.vector<!u16i x 2>
200201
}
202+
203+
void vector_integers_shifts_test() {
204+
vi4 a = {1, 2, 3, 4};
205+
uvi4 b = {5u, 6u, 7u, 8u};
206+
207+
vi4 shl = a << b;
208+
// CHECK: %{{[0-9]+}} = cir.shift(left, %{{[0-9]+}} : !cir.vector<!s32i x 4>, %{{[0-9]+}} : !cir.vector<!u32i x 4>) -> !cir.vector<!s32i x 4>
209+
uvi4 shr = b >> a;
210+
// CHECK: %{{[0-9]+}} = cir.shift(right, %{{[0-9]+}} : !cir.vector<!u32i x 4>, %{{[0-9]+}} : !cir.vector<!s32i x 4>) -> !cir.vector<!u32i x 4>
211+
}

clang/test/CIR/IR/invalid.cir

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1434,11 +1434,27 @@ module {
14341434
%0 = cir.alloca !cir.vector<!s32i x 2>, !cir.ptr<!cir.vector<!s32i x 2>>, ["a", init] {alignment = 8 : i64}
14351435
%1 = cir.load %0 : !cir.ptr<!cir.vector<!s32i x 2>>, !cir.vector<!s32i x 2>
14361436
%4 = cir.const #cir.const_vector<[#cir.int<12> : !s16i, #cir.int<12> : !s16i]> : !cir.vector<!s16i x 2>
1437-
// expected-error@+1 {{'cir.shift' op shift amount must have the type of the result if it is vector shift}}
1437+
// expected-error@+1 {{'cir.shift' op vector operands do not have the same elements sizes}}
14381438
%5 = cir.shift(left, %1 : !cir.vector<!s32i x 2>, %4 : !cir.vector<!s16i x 2>) -> !cir.vector<!s32i x 2>
14391439
cir.return
14401440
}
14411441
}
1442+
1443+
// -----
1444+
1445+
!s32i = !cir.int<s, 32>
1446+
!s16i = !cir.int<s, 16>
1447+
module {
1448+
cir.func @test_shift_vec2() {
1449+
%0 = cir.alloca !cir.vector<!s32i x 2>, !cir.ptr<!cir.vector<!s32i x 2>>, ["a", init] {alignment = 8 : i64}
1450+
%1 = cir.load %0 : !cir.ptr<!cir.vector<!s32i x 2>>, !cir.vector<!s32i x 2>
1451+
%4 = cir.const #cir.const_vector<[#cir.int<12> : !s16i, #cir.int<12> : !s16i]> : !cir.vector<!s16i x 2>
1452+
// expected-error@+1 {{'cir.shift' op vector operands do not have the same elements sizes}}
1453+
%5 = cir.shift(left, %4 : !cir.vector<!s16i x 2>, %1 : !cir.vector<!s32i x 2>) -> !cir.vector<!s32i x 2>
1454+
cir.return
1455+
}
1456+
}
1457+
14421458
// -----
14431459

14441460
// Type of the attribute must be a CIR floating point type

0 commit comments

Comments
 (0)