@@ -824,7 +824,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
-                       ISD::BUILD_VECTOR});
+                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5536,6 +5536,21 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
 }
 
+static SDValue combineADDRSPACECAST(SDNode *N,
+                                    TargetLowering::DAGCombinerInfo &DCI) {
+  auto *ASCN1 = cast<AddrSpaceCastSDNode>(N);
+
+  if (auto *ASCN2 = dyn_cast<AddrSpaceCastSDNode>(ASCN1->getOperand(0))) {
+    assert(ASCN2->getDestAddressSpace() == ASCN1->getSrcAddressSpace());
+
+    // Fold asc[B -> A](asc[A -> B](x)) -> x
+    if (ASCN1->getDestAddressSpace() == ASCN2->getSrcAddressSpace())
+      return ASCN2->getOperand(0);
+  }
+
+  return SDValue();
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5570,6 +5585,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformVSELECTCombine(N, DCI);
   case ISD::BUILD_VECTOR:
     return PerformBUILD_VECTORCombine(N, DCI);
+  case ISD::ADDRSPACECAST:
+    return combineADDRSPACECAST(N, DCI);
   }
   return SDValue();
 }
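
The new combine folds a pair of address-space casts that undo each other. It runs on SelectionDAG nodes (ISD::ADDRSPACECAST), but the pattern is easiest to show at the IR level. The sketch below is illustrative only and is not a test from this patch; the function name and the choice of address spaces (1 = global, 0 = generic on NVPTX) are assumptions picked for the example.

; Hypothetical LLVM IR sketch: %p is cast from global (addrspace(1)) to the
; generic address space and straight back. The back-to-back ADDRSPACECAST
; nodes this produces during NVPTX selection now fold away, leaving just %p.
define ptr addrspace(1) @roundtrip(ptr addrspace(1) %p) {
  %gen  = addrspacecast ptr addrspace(1) %p to ptr
  %back = addrspacecast ptr %gen to ptr addrspace(1)
  ret ptr addrspace(1) %back
}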