File tree Expand file tree Collapse file tree 3 files changed +26
-9
lines changed Expand file tree Collapse file tree 3 files changed +26
-9
lines changed Original file line number Diff line number Diff line change @@ -155,9 +155,15 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
155
155
// operations.
156
156
auto valueShape = memRefType.getShape ();
157
157
SmallVector<Value, 8 > constantIndices;
158
- for (auto i : llvm::seq<int64_t >(
159
- 0 , *std::max_element (valueShape.begin (), valueShape.end ())))
160
- constantIndices.push_back (rewriter.create <ConstantIndexOp>(loc, i));
158
+
159
+ if (!valueShape.empty ()) {
160
+ for (auto i : llvm::seq<int64_t >(
161
+ 0 , *std::max_element (valueShape.begin (), valueShape.end ())))
162
+ constantIndices.push_back (rewriter.create <ConstantIndexOp>(loc, i));
163
+ } else {
164
+ // This is the case of a tensor of rank 0.
165
+ constantIndices.push_back (rewriter.create <ConstantIndexOp>(loc, 0 ));
166
+ }
161
167
162
168
// The constant operation represents a multi-dimensional constant, so we
163
169
// will need to generate a store for each of the elements. The following
Original file line number Diff line number Diff line change @@ -155,10 +155,15 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
155
155
// operations.
156
156
auto valueShape = memRefType.getShape ();
157
157
SmallVector<Value, 8 > constantIndices;
158
- for (auto i : llvm::seq<int64_t >(
159
- 0 , *std::max_element (valueShape.begin (), valueShape.end ())))
160
- constantIndices.push_back (rewriter.create <ConstantIndexOp>(loc, i));
161
158
159
+ if (!valueShape.empty ()) {
160
+ for (auto i : llvm::seq<int64_t >(
161
+ 0 , *std::max_element (valueShape.begin (), valueShape.end ())))
162
+ constantIndices.push_back (rewriter.create <ConstantIndexOp>(loc, i));
163
+ } else {
164
+ // This is the case of a tensor of rank 0.
165
+ constantIndices.push_back (rewriter.create <ConstantIndexOp>(loc, 0 ));
166
+ }
162
167
// The constant operation represents a multi-dimensional constant, so we
163
168
// will need to generate a store for each of the elements. The following
164
169
// functor recursively walks the dimensions of the constant shape,
Original file line number Diff line number Diff line change @@ -155,9 +155,15 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
155
155
// operations.
156
156
auto valueShape = memRefType.getShape ();
157
157
SmallVector<Value, 8 > constantIndices;
158
- for (auto i : llvm::seq<int64_t >(
159
- 0 , *std::max_element (valueShape.begin (), valueShape.end ())))
160
- constantIndices.push_back (rewriter.create <ConstantIndexOp>(loc, i));
158
+
159
+ if (!valueShape.empty ()) {
160
+ for (auto i : llvm::seq<int64_t >(
161
+ 0 , *std::max_element (valueShape.begin (), valueShape.end ())))
162
+ constantIndices.push_back (rewriter.create <ConstantIndexOp>(loc, i));
163
+ } else {
164
+ // This is the case of a tensor of rank 0.
165
+ constantIndices.push_back (rewriter.create <ConstantIndexOp>(loc, 0 ));
166
+ }
161
167
162
168
// The constant operation represents a multi-dimensional constant, so we
163
169
// will need to generate a store for each of the elements. The following
You can’t perform that action at this time.
0 commit comments