@@ -90,39 +90,31 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
90
90
Value one = rewriter.create <ConstantIndexOp>(loc, 1 );
91
91
92
92
// Find smaller and greater rank and extent tensor.
93
- Value lhsRank = rewriter.create <DimOp>(loc, transformed .lhs (), zero);
94
- Value rhsRank = rewriter.create <DimOp>(loc, transformed .rhs (), zero);
95
- Value lhsSmaller =
93
+ Value lhsRank = rewriter.create <DimOp>(loc, op .lhs (), zero);
94
+ Value rhsRank = rewriter.create <DimOp>(loc, op .rhs (), zero);
95
+ Value lhsRankULE =
96
96
rewriter.create <CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
97
97
Type indexTy = rewriter.getIndexType ();
98
- Type extentTensorTy = op.getType ();
99
- auto ifOp = rewriter.create <IfOp>(
100
- loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
101
- lhsSmaller,
102
- [&](OpBuilder &b, Location loc) {
103
- b.create <scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs (),
104
- rhsRank, transformed.rhs ()});
105
- },
106
- [&](OpBuilder &b, Location loc) {
107
- b.create <scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs (),
108
- lhsRank, transformed.lhs ()});
109
- });
110
- Value smallerRank = ifOp.getResult (0 );
111
- Value smallerOperand = ifOp.getResult (1 );
112
- Value greaterRank = ifOp.getResult (2 );
113
- Value greaterOperand = ifOp.getResult (3 );
98
+ Value lesserRank =
99
+ rewriter.create <SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
100
+ Value greaterRank =
101
+ rewriter.create <SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
102
+ Value lesserRankOperand =
103
+ rewriter.create <SelectOp>(loc, lhsRankULE, op.lhs (), op.rhs ());
104
+ Value greaterRankOperand =
105
+ rewriter.create <SelectOp>(loc, lhsRankULE, op.rhs (), op.lhs ());
114
106
115
107
// Allocate stack memory for the broadcasted extent tensor.
116
108
Type memTy = MemRefType::get ({ShapedType::kDynamicSize }, indexTy);
117
109
Value mem = rewriter.create <AllocaOp>(loc, memTy, ValueRange{greaterRank});
118
110
119
111
// Copy extents from greater operand that are not challenged.
120
112
Value rankDiff =
121
- rewriter.create <SubIOp>(loc, indexTy, greaterRank, smallerRank );
113
+ rewriter.create <SubIOp>(loc, indexTy, greaterRank, lesserRank );
122
114
rewriter.create <ForOp>(loc, zero, rankDiff, one, llvm::None,
123
115
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
124
116
Value extent = b.create <ExtractElementOp>(
125
- loc, greaterOperand , ValueRange{iv});
117
+ loc, greaterRankOperand , ValueRange{iv});
126
118
b.create <StoreOp>(loc, extent, mem, ValueRange{iv});
127
119
b.create <scf::YieldOp>(loc);
128
120
});
@@ -132,16 +124,16 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
132
124
loc, rankDiff, greaterRank, one, llvm::None,
133
125
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
134
126
Value greaterOperandExtent =
135
- b.create <ExtractElementOp>(loc, greaterOperand , ValueRange{iv});
127
+ b.create <ExtractElementOp>(loc, greaterRankOperand , ValueRange{iv});
136
128
Value greaterOperandExtentIsOne =
137
129
b.create <CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
138
130
auto ifOp = b.create <IfOp>(
139
131
loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
140
132
[&](OpBuilder &b, Location loc) {
141
133
Value ivShifted = b.create <SubIOp>(loc, indexTy, iv, rankDiff);
142
- Value smallerOperandExtent = b.create <ExtractElementOp>(
143
- loc, smallerOperand , ValueRange{ivShifted});
144
- b.create <scf::YieldOp>(loc, smallerOperandExtent );
134
+ Value lesserRankOperandExtent = b.create <ExtractElementOp>(
135
+ loc, lesserRankOperand , ValueRange{ivShifted});
136
+ b.create <scf::YieldOp>(loc, lesserRankOperandExtent );
145
137
},
146
138
[&](OpBuilder &b, Location loc) {
147
139
b.create <scf::YieldOp>(loc, greaterOperandExtent);
0 commit comments