@@ -313,6 +313,116 @@ def MemRefDataFlowOpt : FunctionPass<"memref-dataflow-opt"> {
def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
  let summary = "Normalize memrefs";
+  let description = [{
+    This pass transforms memref types with a non-trivial
+    [layout map](https://mlir.llvm.org/docs/LangRef/#layout-map) into
+    memref types with an identity layout map, e.g. (i, j) -> (i, j). This
+    pass is inter-procedural, in the sense that it can modify function
+    interfaces and call sites that pass memref types. To modify memref
+    types while preserving the original behavior, users of those memref
+    types are also modified to incorporate the resulting layout map. For
+    instance, an [AffineLoadOp](https://mlir.llvm.org/docs/Dialects/Affine/#affineload-affineloadop)
+    will be updated to compose the layout map with the affine expression
+    contained in the op. Operations marked with the
+    [MemRefsNormalizable](https://mlir.llvm.org/docs/Traits/#memrefsnormalizable)
+    trait are expected to be normalizable. Supported operations include
+    affine operations, std.alloc, std.dealloc, and std.return.
+
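+    As a minimal sketch of the layout-map composition described above
+    (distilled from the first full example below), with
+    `#tile = affine_map<(i) -> (i floordiv 4, i mod 4)>` a load such as
+
+    ```mlir
+    %v = affine.load %A[%i] : memref<16xf64, #tile>
+    ```
+
+    is rewritten against the normalized memref as
+
+    ```mlir
+    %v = affine.load %A[%i floordiv 4, %i mod 4] : memref<4x4xf64>
+    ```
+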
+    Given an appropriate layout map specified in the code, this
+    transformation can express tiled or linearized access to
+    multi-dimensional data structures, but will not modify memref types
+    without an explicit layout map.
+
+    Currently, this pass is limited to functions whose memref types can
+    all be normalized. If a function contains any operation that is not
+    MemRefsNormalizable, then the function, and any functions that call
+    or are called by it, will not be modified. A sketch of such a case
+    appears below.
+
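+    For illustration, `"foo.opaque"` below stands for a hypothetical
+    operation that does not declare the MemRefsNormalizable trait; its
+    presence is assumed to block normalization of the enclosing function
+    (and of its callers and callees):
+
+    ```mlir
+    #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+    func @blocked(%A: memref<16xf64, #tile>) {
+      // Hypothetical op without the MemRefsNormalizable trait.
+      "foo.opaque"(%A) : (memref<16xf64, #tile>) -> ()
+      return
+    }
+    ```
+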
+    Input
+
+    ```mlir
+    #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+    func @matmul(%A: memref<16xf64, #tile>,
+                 %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) {
+      affine.for %arg3 = 0 to 16 {
+        %a = affine.load %A[%arg3] : memref<16xf64, #tile>
+        %p = mulf %a, %a : f64
+        affine.store %p, %A[%arg3] : memref<16xf64, #tile>
+      }
+      %c = alloc() : memref<16xf64, #tile>
+      %d = affine.load %c[0] : memref<16xf64, #tile>
+      return %A: memref<16xf64, #tile>
+    }
+    ```
+
+    Output
+
+    ```mlir
+    func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
+      -> memref<4x4xf64> {
+      affine.for %arg3 = 0 to 16 {
+        %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+        %4 = mulf %3, %3 : f64
+        affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+      }
+      %0 = alloc() : memref<4x4xf64>
+      %1 = affine.apply #map1()
+      %2 = affine.load %0[0, 0] : memref<4x4xf64>
+      return %arg0 : memref<4x4xf64>
+    }
+    ```
+
+    Input
+
+    ```mlir
+    #linear8 = affine_map<(i, j) -> (i * 8 + j)>
+    func @linearize(%arg0: memref<8x8xi32, #linear8>,
+                    %arg1: memref<8x8xi32, #linear8>,
+                    %arg2: memref<8x8xi32, #linear8>) {
+      %c8 = constant 8 : index
+      %c0 = constant 0 : index
+      %c1 = constant 1 : index
+      affine.for %arg3 = %c0 to %c8 {
+        affine.for %arg4 = %c0 to %c8 {
+          affine.for %arg5 = %c0 to %c8 {
+            %0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8>
+            %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8>
+            %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
+            %3 = muli %0, %1 : i32
+            %4 = addi %2, %3 : i32
+            affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
+          }
+        }
+      }
+      return
+    }
+    ```
+
+    Output
+
+    ```mlir
+    func @linearize(%arg0: memref<64xi32>,
+                    %arg1: memref<64xi32>,
+                    %arg2: memref<64xi32>) {
+      %c8 = constant 8 : index
+      %c0 = constant 0 : index
+      affine.for %arg3 = %c0 to %c8 {
+        affine.for %arg4 = %c0 to %c8 {
+          affine.for %arg5 = %c0 to %c8 {
+            %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32>
+            %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32>
+            %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
+            %3 = muli %0, %1 : i32
+            %4 = addi %2, %3 : i32
+            affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
+          }
+        }
+      }
+      return
+    }
+    ```
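+
+    To run the pass standalone (a usage sketch; this assumes the standard
+    `mlir-opt` driver, which exposes the pass under the flag declared
+    above):
+
+    ```
+    mlir-opt -normalize-memrefs input.mlir
+    ```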
+  }];
  let constructor = "mlir::createNormalizeMemRefsPass()";
}