Skip to content

Commit af65eca

Browse files
committed
Address code review comments
1 parent a52f94c commit af65eca

File tree

2 files changed

+29
-30
lines changed

2 files changed

+29
-30
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,14 +1401,15 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
14011401
The `cir.shift` operation performs a bitwise shift, either to the left or to
14021402
the right, based on the first operand. The second operand specifies the
14031403
value to be shifted, and the third operand determines the number of
1404-
positions by which the shift is applied, they must be either all vector of
1404+
positions by which the shift is applied, They must be either all vector of
14051405
integer type, or all integer type. If they are vectors, each vector element of
14061406
the shift target is shifted by the corresponding shift amount in
14071407
the shift amount vector.
14081408

14091409
```mlir
14101410
%res = cir.shift(left, %lhs : !u64i, %amount : !s32i) -> !u64i
1411-
%new_vec = cir.shift(left, %lhs : !cir.vector<2 x !s32i>, %rhs : !cir.vector<2 x !s32i>) -> !cir.vector<2 x !s32i>
1411+
%new_vec = cir.shift(left, %lhs : !cir.vector<2 x !s32i>, %rhs :
1412+
!cir.vector<2 x !s32i>) -> !cir.vector<2 x !s32i>
14121413
```
14131414
}];
14141415

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,44 +1372,42 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
13721372
mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
13731373
cir::ShiftOp op, OpAdaptor adaptor,
13741374
mlir::ConversionPatternRewriter &rewriter) const {
1375-
auto cirAmtTy = mlir::dyn_cast<cir::IntType>(op.getAmount().getType());
1376-
auto cirValTy = mlir::dyn_cast<cir::IntType>(op.getValue().getType());
1375+
assert((op.getValue().getType() == op.getType()) &&
1376+
"inconsistent operands' types NYI");
13771377

1378-
// Operands could also be vector type
1379-
auto cirAmtVTy = mlir::dyn_cast<cir::VectorType>(op.getAmount().getType());
1380-
auto cirValVTy = mlir::dyn_cast<cir::VectorType>(op.getValue().getType());
1381-
mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
1378+
const mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
13821379
mlir::Value amt = adaptor.getAmount();
13831380
mlir::Value val = adaptor.getValue();
13841381

1385-
assert(((cirValTy && cirAmtTy) || (cirAmtVTy && cirValVTy)) &&
1386-
"shift input type must be integer or vector type, otherwise NYI");
1387-
1388-
assert((cirValTy == op.getType() || cirValVTy == op.getType()) &&
1389-
"inconsistent operands' types NYI");
1390-
1391-
// Ensure shift amount is the same type as the value. Some undefined
1392-
// behavior might occur in the casts below as per [C99 6.5.7.3].
1393-
// Vector type shift amount needs no cast as type consistency is expected to
1394-
// be already be enforced at CIRGen.
1395-
if (cirAmtTy)
1396-
amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
1397-
true, cirAmtTy.getWidth(), cirValTy.getWidth());
1382+
auto cirAmtTy = mlir::dyn_cast<cir::IntType>(op.getAmount().getType());
1383+
bool isUnsigned;
1384+
if (cirAmtTy) {
1385+
auto cirValTy = mlir::cast<cir::IntType>(op.getValue().getType());
1386+
isUnsigned = cirValTy.isUnsigned();
1387+
1388+
// Ensure shift amount is the same type as the value. Some undefined
1389+
// behavior might occur in the casts below as per [C99 6.5.7.3].
1390+
// Vector type shift amount needs no cast as type consistency is expected to
1391+
// be already be enforced at CIRGen.
1392+
if (cirAmtTy)
1393+
amt = getLLVMIntCast(rewriter, amt, llvmTy, true, cirAmtTy.getWidth(),
1394+
cirValTy.getWidth());
1395+
} else {
1396+
auto cirValVTy = mlir::cast<cir::VectorType>(op.getValue().getType());
1397+
isUnsigned =
1398+
mlir::cast<cir::IntType>(cirValVTy.getElementType()).isUnsigned();
1399+
}
13981400

13991401
// Lower to the proper LLVM shift operation.
14001402
if (op.getIsShiftleft()) {
14011403
rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
1402-
} else {
1403-
const bool isUnsigned =
1404-
cirValTy
1405-
? !cirValTy.isSigned()
1406-
: !mlir::cast<cir::IntType>(cirValVTy.getElementType()).isSigned();
1407-
if (isUnsigned)
1408-
rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
1409-
else
1410-
rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt);
1404+
return mlir::success();
14111405
}
14121406

1407+
if (isUnsigned)
1408+
rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
1409+
else
1410+
rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt);
14131411
return mlir::success();
14141412
}
14151413

0 commit comments

Comments
 (0)