@@ -541,12 +541,10 @@ static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op,
541
541
mlir::OperandRange operands,
542
542
mlir::TypeRange types,
543
543
std::optional<mlir::ArrayAttr> attributes) {
544
- for (unsigned i = 0 , e = attributes->size (); i < e; ++i) {
545
- if (i != 0 )
546
- p << " , " ;
547
- p << (*attributes)[i] << " -> " << operands[i] << " : "
548
- << operands[i].getType ();
549
- }
544
+ llvm::interleaveComma (llvm::zip (*attributes, operands), p, [&](auto it) {
545
+ p << std::get<0 >(it) << " -> " << std::get<1 >(it) << " : "
546
+ << std::get<1 >(it).getType ();
547
+ });
550
548
}
551
549
552
550
// ===----------------------------------------------------------------------===//
@@ -852,27 +850,27 @@ static ParseResult parseNumGangs(
852
850
return success ();
853
851
}
854
852
853
+ static void printSingleDeviceType (mlir::OpAsmPrinter &p, mlir::Attribute attr) {
854
+ auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
855
+ if (deviceTypeAttr.getValue () != mlir::acc::DeviceType::None)
856
+ p << " [" << attr << " ]" ;
857
+ }
858
+
855
859
static void printNumGangs (mlir::OpAsmPrinter &p, mlir::Operation *op,
856
860
mlir::OperandRange operands, mlir::TypeRange types,
857
861
std::optional<mlir::ArrayAttr> deviceTypes,
858
862
std::optional<mlir::DenseI32ArrayAttr> segments) {
859
863
unsigned opIdx = 0 ;
860
- for (unsigned i = 0 ; i < deviceTypes->size (); ++i) {
861
- if (i != 0 )
862
- p << " , " ;
864
+ llvm::interleaveComma (llvm::enumerate (*deviceTypes), p, [&](auto it) {
863
865
p << " {" ;
864
- for (int32_t j = 0 ; j < (*segments)[i]; ++j) {
865
- if (j != 0 )
866
- p << " , " ;
867
- p << operands[opIdx] << " : " << operands[opIdx].getType ();
868
- ++opIdx;
869
- }
866
+ llvm::interleaveComma (
867
+ llvm::seq<int32_t >(0 , (*segments)[it.index ()]), p, [&](auto it) {
868
+ p << operands[opIdx] << " : " << operands[opIdx].getType ();
869
+ ++opIdx;
870
+ });
870
871
p << " }" ;
871
- auto deviceTypeAttr =
872
- mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
873
- if (deviceTypeAttr.getValue () != mlir::acc::DeviceType::None)
874
- p << " [" << (*deviceTypes)[i] << " ]" ;
875
- }
872
+ printSingleDeviceType (p, it.value ());
873
+ });
876
874
}
877
875
878
876
static ParseResult parseDeviceTypeOperandsWithSegment (
@@ -921,30 +919,21 @@ static ParseResult parseDeviceTypeOperandsWithSegment(
921
919
return success ();
922
920
}
923
921
924
- static void printSingleDeviceType (mlir::OpAsmPrinter &p, mlir::Attribute attr) {
925
- auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
926
- if (deviceTypeAttr.getValue () != mlir::acc::DeviceType::None)
927
- p << " [" << attr << " ]" ;
928
- }
929
-
930
922
static void printDeviceTypeOperandsWithSegment (
931
923
mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
932
924
mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
933
925
std::optional<mlir::DenseI32ArrayAttr> segments) {
934
926
unsigned opIdx = 0 ;
935
- for (unsigned i = 0 ; i < deviceTypes->size (); ++i) {
936
- if (i != 0 )
937
- p << " , " ;
927
+ llvm::interleaveComma (llvm::enumerate (*deviceTypes), p, [&](auto it) {
938
928
p << " {" ;
939
- for (int32_t j = 0 ; j < (*segments)[i]; ++j) {
940
- if (j != 0 )
941
- p << " , " ;
942
- p << operands[opIdx] << " : " << operands[opIdx].getType ();
943
- ++opIdx;
944
- }
929
+ llvm::interleaveComma (
930
+ llvm::seq<int32_t >(0 , (*segments)[it.index ()]), p, [&](auto it) {
931
+ p << operands[opIdx] << " : " << operands[opIdx].getType ();
932
+ ++opIdx;
933
+ });
945
934
p << " }" ;
946
- printSingleDeviceType (p, (*deviceTypes)[i] );
947
- }
935
+ printSingleDeviceType (p, it. value () );
936
+ });
948
937
}
949
938
950
939
static ParseResult parseDeviceTypeOperands (
@@ -977,12 +966,10 @@ static void
977
966
printDeviceTypeOperands (mlir::OpAsmPrinter &p, mlir::Operation *op,
978
967
mlir::OperandRange operands, mlir::TypeRange types,
979
968
std::optional<mlir::ArrayAttr> deviceTypes) {
980
- for (unsigned i = 0 , e = deviceTypes->size (); i < e; ++i) {
981
- if (i != 0 )
982
- p << " , " ;
983
- p << operands[i] << " : " << operands[i].getType ();
984
- printSingleDeviceType (p, (*deviceTypes)[i]);
985
- }
969
+ llvm::interleaveComma (llvm::zip (*deviceTypes, operands), p, [&](auto it) {
970
+ p << std::get<1 >(it) << " : " << std::get<1 >(it).getType ();
971
+ printSingleDeviceType (p, std::get<0 >(it));
972
+ });
986
973
}
987
974
988
975
static ParseResult parseDeviceTypeOperandsWithKeywordOnly (
@@ -1056,14 +1043,10 @@ static void printDeviceTypes(mlir::OpAsmPrinter &p,
1056
1043
std::optional<mlir::ArrayAttr> deviceTypes) {
1057
1044
if (!hasDeviceTypeValues (deviceTypes))
1058
1045
return ;
1046
+
1059
1047
p << " [" ;
1060
- for (unsigned i = 0 ; i < deviceTypes.value ().size (); ++i) {
1061
- if (i != 0 )
1062
- p << " , " ;
1063
- auto deviceTypeAttr =
1064
- mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
1065
- p << deviceTypeAttr;
1066
- }
1048
+ llvm::interleaveComma (*deviceTypes, p,
1049
+ [&](mlir::Attribute attr) { p << attr; });
1067
1050
p << " ]" ;
1068
1051
}
1069
1052
@@ -1081,19 +1064,11 @@ static void printDeviceTypeOperandsWithKeywordOnly(
1081
1064
}
1082
1065
1083
1066
p << " (" ;
1084
-
1085
1067
printDeviceTypes (p, keywordOnlyDeviceTypes);
1086
-
1087
1068
if (hasDeviceTypeValues (keywordOnlyDeviceTypes) &&
1088
1069
hasDeviceTypeValues (deviceTypes))
1089
1070
p << " , " ;
1090
-
1091
- for (unsigned i = 0 , e = deviceTypes->size (); i < e; ++i) {
1092
- if (i != 0 )
1093
- p << " , " ;
1094
- p << operands[i] << " : " << operands[i].getType ();
1095
- printSingleDeviceType (p, (*deviceTypes)[i]);
1096
- }
1071
+ printDeviceTypeOperands (p, op, operands, types, deviceTypes);
1097
1072
p << " )" ;
1098
1073
}
1099
1074
@@ -1483,49 +1458,33 @@ void printGangClause(OpAsmPrinter &p, Operation *op,
1483
1458
}
1484
1459
1485
1460
p << " (" ;
1486
- if (hasDeviceTypeValues (gangOnlyDeviceTypes)) {
1487
- p << " [" ;
1488
- for (unsigned i = 0 ; i < gangOnlyDeviceTypes.value ().size (); ++i) {
1489
- if (i != 0 )
1490
- p << " , " ;
1491
- auto deviceTypeAttr =
1492
- mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gangOnlyDeviceTypes)[i]);
1493
- p << deviceTypeAttr;
1494
- }
1495
- p << " ]" ;
1496
- }
1461
+ printDeviceTypes (p, gangOnlyDeviceTypes);
1497
1462
1498
1463
if (hasDeviceTypeValues (gangOnlyDeviceTypes) &&
1499
1464
hasDeviceTypeValues (deviceTypes))
1500
1465
p << " , " ;
1501
1466
1502
1467
if (deviceTypes) {
1503
1468
unsigned opIdx = 0 ;
1504
- for (unsigned i = 0 ; i < deviceTypes->size (); ++i) {
1505
- if (i != 0 )
1506
- p << " , " ;
1469
+ llvm::interleaveComma (llvm::enumerate (*deviceTypes), p, [&](auto it) {
1507
1470
p << " {" ;
1508
- for (int32_t j = 0 ; j < (*segments)[i]; ++j) {
1509
- if (j != 0 )
1510
- p << " , " ;
1511
- auto gangArgTypeAttr =
1512
- mlir::dyn_cast<mlir::acc::GangArgTypeAttr>((*gangArgTypes)[opIdx]);
1513
- if (gangArgTypeAttr.getValue () == mlir::acc::GangArgType::Num)
1514
- p << LoopOp::getGangNumKeyword ();
1515
- else if (gangArgTypeAttr.getValue () == mlir::acc::GangArgType::Dim)
1516
- p << LoopOp::getGangDimKeyword ();
1517
- else if (gangArgTypeAttr.getValue () == mlir::acc::GangArgType::Static)
1518
- p << LoopOp::getGangStaticKeyword ();
1519
- p << " =" << operands[opIdx] << " : " << operands[opIdx].getType ();
1520
- ++opIdx;
1521
- }
1522
-
1471
+ llvm::interleaveComma (
1472
+ llvm::seq<int32_t >(0 , (*segments)[it.index ()]), p, [&](auto it) {
1473
+ auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
1474
+ (*gangArgTypes)[opIdx]);
1475
+ if (gangArgTypeAttr.getValue () == mlir::acc::GangArgType::Num)
1476
+ p << LoopOp::getGangNumKeyword ();
1477
+ else if (gangArgTypeAttr.getValue () == mlir::acc::GangArgType::Dim)
1478
+ p << LoopOp::getGangDimKeyword ();
1479
+ else if (gangArgTypeAttr.getValue () ==
1480
+ mlir::acc::GangArgType::Static)
1481
+ p << LoopOp::getGangStaticKeyword ();
1482
+ p << " =" << operands[opIdx] << " : " << operands[opIdx].getType ();
1483
+ ++opIdx;
1484
+ });
1523
1485
p << " }" ;
1524
- auto deviceTypeAttr =
1525
- mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
1526
- if (deviceTypeAttr.getValue () != mlir::acc::DeviceType::None)
1527
- p << " [" << (*deviceTypes)[i] << " ]" ;
1528
- }
1486
+ printSingleDeviceType (p, it.value ());
1487
+ });
1529
1488
}
1530
1489
p << " )" ;
1531
1490
}
0 commit comments