@@ -242,21 +242,25 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
242
242
// 1.a. Emit std_load from input views.
243
243
for (unsigned i = 0 ; i < nInputs; ++i) {
244
244
Value input = genericOp.getInput (i);
245
- if (!input.getType ().cast <ShapedType>().getRank ()) {
246
- indexedValues[i] = std_load (input);
247
- } else {
245
+ if (input.getType ().cast <ShapedType>().getRank ()) {
248
246
ValueHandleArray indexing (makeCanonicalAffineApplies (
249
247
b, loc, genericOp.getInputIndexingMap (i), allIvs));
250
248
indexedValues[i] = std_load (input, indexing);
249
+ } else {
250
+ indexedValues[i] = std_load (input);
251
251
}
252
252
}
253
253
254
254
// 1.b. Emit std_load from output views.
255
255
for (unsigned i = 0 ; i < nOutputs; ++i) {
256
- ValueHandleArray indexing (makeCanonicalAffineApplies (
257
- b, loc, genericOp.getOutputIndexingMap (i), allIvs));
258
- indexedValues[nInputs + i] =
259
- std_load (genericOp.getOutputBuffer (i), indexing);
256
+ Value output = genericOp.getOutputBuffer (i);
257
+ if (output.getType ().cast <ShapedType>().getRank ()) {
258
+ ValueHandleArray indexing (makeCanonicalAffineApplies (
259
+ b, loc, genericOp.getOutputIndexingMap (i), allIvs));
260
+ indexedValues[nInputs + i] = std_load (output, indexing);
261
+ } else {
262
+ indexedValues[nInputs + i] = std_load (output);
263
+ }
260
264
}
261
265
262
266
auto funcOp = genericOp.getFunction ();
@@ -267,9 +271,14 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
267
271
268
272
// 3. Emit std_store.
269
273
for (unsigned i = 0 ; i < nOutputs; ++i) {
270
- ValueHandleArray indexing (makeCanonicalAffineApplies (
271
- b, loc, genericOp.getOutputIndexingMap (i), allIvs));
272
- std_store (callOp->getResult (i), genericOp.getOutputBuffer (i), indexing);
274
+ Value output = genericOp.getOutputBuffer (i);
275
+ if (output.getType ().cast <ShapedType>().getRank ()) {
276
+ ValueHandleArray indexing (makeCanonicalAffineApplies (
277
+ b, loc, genericOp.getOutputIndexingMap (i), allIvs));
278
+ std_store (callOp->getResult (i), output, indexing);
279
+ } else {
280
+ std_store (callOp->getResult (i), output);
281
+ }
273
282
}
274
283
return ;
275
284
}
@@ -288,10 +297,15 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
288
297
auto *yieldOp = cast<YieldOp>(block.back ()).getOperation ();
289
298
assert (yieldOp->getNumOperands () == nOutputs);
290
299
for (unsigned i = 0 ; i < nOutputs; ++i) {
291
- ValueHandleArray indexing (makeCanonicalAffineApplies (
292
- b, loc, genericOp.getOutputIndexingMap (i), allIvs));
293
- std_store (map.lookup (yieldOp->getOperand (i)),
294
- genericOp.getOutputBuffer (i), indexing);
300
+ Value output = genericOp.getOutputBuffer (i);
301
+ if (output.getType ().cast <ShapedType>().getRank ()) {
302
+ ValueHandleArray indexing (makeCanonicalAffineApplies (
303
+ b, loc, genericOp.getOutputIndexingMap (i), allIvs));
304
+ std_store (map.lookup (yieldOp->getOperand (i)),
305
+ genericOp.getOutputBuffer (i), indexing);
306
+ } else {
307
+ std_store (map.lookup (yieldOp->getOperand (i)), output);
308
+ }
295
309
}
296
310
}
297
311
};
@@ -348,21 +362,25 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
348
362
// 1.a. Emit std_load from input views.
349
363
for (unsigned i = 0 ; i < nInputs; ++i) {
350
364
Value input = indexedGenericOp.getInput (i);
351
- if (!input.getType ().cast <ShapedType>().getRank ()) {
352
- indexedValues[nLoops + i] = std_load (input);
353
- } else {
365
+ if (input.getType ().cast <ShapedType>().getRank ()) {
354
366
ValueHandleArray indexing (makeCanonicalAffineApplies (
355
367
b, loc, indexedGenericOp.getInputIndexingMap (i), allIvs));
356
368
indexedValues[nLoops + i] = std_load (input, indexing);
369
+ } else {
370
+ indexedValues[nLoops + i] = std_load (input);
357
371
}
358
372
}
359
373
360
374
// 1.b. Emit std_load from output views.
361
375
for (unsigned i = 0 ; i < nOutputs; ++i) {
362
- ValueHandleArray indexing (makeCanonicalAffineApplies (
363
- b, loc, indexedGenericOp.getOutputIndexingMap (i), allIvs));
364
- indexedValues[nLoops + nInputs + i] =
365
- std_load (indexedGenericOp.getOutputBuffer (i), indexing);
376
+ Value output = indexedGenericOp.getOutputBuffer (i);
377
+ if (output.getType ().cast <ShapedType>().getRank ()) {
378
+ ValueHandleArray indexing (makeCanonicalAffineApplies (
379
+ b, loc, indexedGenericOp.getOutputIndexingMap (i), allIvs));
380
+ indexedValues[nLoops + nInputs + i] = std_load (output, indexing);
381
+ } else {
382
+ indexedValues[nLoops + nInputs + i] = std_load (output);
383
+ }
366
384
}
367
385
368
386
if (auto funcOp = indexedGenericOp.getFunction ()) {
@@ -372,10 +390,14 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
372
390
373
391
// 3. Emit std_store.
374
392
for (unsigned i = 0 ; i < nOutputs; ++i) {
375
- ValueHandleArray indexing (makeCanonicalAffineApplies (
376
- b, loc, indexedGenericOp.getOutputIndexingMap (i), allIvs));
377
- std_store (callOp->getResult (i), indexedGenericOp.getOutputBuffer (i),
378
- indexing);
393
+ Value output = indexedGenericOp.getOutputBuffer (i);
394
+ if (output.getType ().cast <ShapedType>().getRank ()) {
395
+ ValueHandleArray indexing (makeCanonicalAffineApplies (
396
+ b, loc, indexedGenericOp.getOutputIndexingMap (i), allIvs));
397
+ std_store (callOp->getResult (i), output, indexing);
398
+ } else {
399
+ std_store (callOp->getResult (i), output);
400
+ }
379
401
}
380
402
return ;
381
403
}
@@ -394,10 +416,14 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
394
416
auto *yieldOp = cast<YieldOp>(block.back ()).getOperation ();
395
417
assert (yieldOp->getNumOperands () == nOutputs);
396
418
for (unsigned i = 0 ; i < nOutputs; ++i) {
397
- ValueHandleArray indexing (makeCanonicalAffineApplies (
398
- b, loc, indexedGenericOp.getOutputIndexingMap (i), allIvs));
399
- std_store (map.lookup (yieldOp->getOperand (i)),
400
- indexedGenericOp.getOutputBuffer (i), indexing);
419
+ Value output = indexedGenericOp.getOutputBuffer (i);
420
+ if (output.getType ().cast <ShapedType>().getRank ()) {
421
+ ValueHandleArray indexing (makeCanonicalAffineApplies (
422
+ b, loc, indexedGenericOp.getOutputIndexingMap (i), allIvs));
423
+ std_store (map.lookup (yieldOp->getOperand (i)), output, indexing);
424
+ } else {
425
+ std_store (map.lookup (yieldOp->getOperand (i)), output);
426
+ }
401
427
}
402
428
}
403
429
};
0 commit comments