// RUN: mlir-opt -split-input-file -canonicalize -cse %s | FileCheck %s

// This test verifies the simplification of IR patterns that emerge when
// lowering high-level element-wise ops with unranked tensor inputs. Consider
// the following function, which increments and doubles the values of an
// unranked input tensor using ops in a hypothetical high-level dialect called
// 'hl':
//
//   func.func @f(%input: tensor<*xf32>) -> tensor<*xf32> {
//     %0 = hl.inc %input : tensor<*xf32>
//     %1 = hl.double %0 : tensor<*xf32>
//     return %1 : tensor<*xf32>
//   }
//
// A possible strategy to lower 'hl.inc' consists of reshaping its operand into
// a 1D tensor, creating a 1D tensor splat with the same total size as the
// input operand and with value 1.0, adding both 1D tensors with 'arith.addf',
// and reshaping the result back into the original input shape. A similar
// process applies to 'hl.double', except with a tensor splat of value 2.0 and
// an 'arith.mulf' op. The body of the function in the test below contains the
// full sequence; the sketch that follows shows the 'hl.inc' lowering in
// isolation.
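//
// As an informal sketch (value names are illustrative, not part of the checked
// test), lowering 'hl.inc' alone would emit:
//
//   %shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
//   %size = shape.num_elements %shape : tensor<?xindex> -> index
//   %flat_shape = tensor.from_elements %size : tensor<1xindex>
//   %flat = tensor.reshape %input(%flat_shape)
//       : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
//   %one = arith.constant 1.0 : f32
//   %one_splat = tensor.splat %one[%size] : tensor<?xf32>
//   %flat_sum = arith.addf %flat, %one_splat : tensor<?xf32>
//   %sum = tensor.reshape %flat_sum(%shape)
//       : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>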
//
// Since such a lowering process operates on individual 'hl' ops in a
// context-oblivious manner, the emitted code contains a redundant IR pattern:
// the result of 'arith.addf' is reshaped into an unranked tensor, just for it
// to be immediately reshaped back into the 1D tensor consumed by 'arith.mulf'.
// This entails the overhead of re-computing the unranked tensor shape
// ('shape.shape_of') and size ('shape.num_elements'), as shown in the snippet
// below.
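//
// Concretely, the redundant round-trip between the two lowered ops, taken
// verbatim from the test body below, is:
//
//   %sum = tensor.reshape %sum_collapsed(%input_shape)
//       : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
//   %sum_shape = shape.shape_of %sum : tensor<*xf32> -> tensor<?xindex>
//   %sum_size = shape.num_elements %sum_shape : tensor<?xindex> -> index
//   %sum_collapsed_shape = tensor.from_elements %sum_size : tensor<1xindex>
//   %sum_collapsed_0 = tensor.reshape %sum(%sum_collapsed_shape)
//       : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>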
//
// This test verifies that the consecutive application of a canonicalization
// and a CSE pass successfully simplifies this pattern, leading to a version of
// the code in which the result of the emitted 'arith.addf' op associated with
// 'hl.inc' is directly consumed by the 'arith.mulf' op associated with
// 'hl.double', as captured by the FileCheck directives below. The main rewrite
// patterns at play are 'shape.shape_of' canonicalization, 'tensor.reshape'
// canonicalization, and 'shape.num_elements' subexpression elimination.
//
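// As an informal sketch of the rewrite chain (not a verbatim trace of the
// passes), 'shape.shape_of' of a 'tensor.reshape' result canonicalizes to the
// reshape's shape operand, so:
//
//   %sum_shape = shape.shape_of %sum : tensor<*xf32> -> tensor<?xindex>
//
// becomes a direct reuse of %input_shape. CSE then deduplicates the dependent
// 'shape.num_elements' and 'tensor.from_elements' ops, and 'tensor.reshape'
// canonicalization removes the remaining reshape round-trip, leaving
// 'arith.mulf' to consume the 'arith.addf' result directly.
//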

// CHECK-LABEL: @unranked_tensor_lowering
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>

// CHECK-DAG: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32

// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[INPUT]] : tensor<*xf32> -> tensor<?xindex>
// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor<?xindex> -> index
// CHECK: %[[INPUT_COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex>
// CHECK: %[[INPUT_COLLAPSED:.*]] = tensor.reshape %[[INPUT]](%[[INPUT_COLLAPSED_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

// CHECK: %[[ONE_SPLAT:.*]] = tensor.splat %[[ONE]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
// CHECK: %[[SUM_COLLAPSED:.*]] = arith.addf %[[INPUT_COLLAPSED]], %[[ONE_SPLAT]] : tensor<?xf32>

// CHECK: %[[TWO_SPLAT:.*]] = tensor.splat %[[TWO]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
// CHECK: %[[PRODUCT_COLLAPSED:.*]] = arith.mulf %[[SUM_COLLAPSED]], %[[TWO_SPLAT]] : tensor<?xf32>

// CHECK: %[[PRODUCT:.*]] = tensor.reshape %[[PRODUCT_COLLAPSED]](%[[INPUT_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[PRODUCT]] : tensor<*xf32>
func.func @unranked_tensor_lowering(%input: tensor<*xf32>) -> tensor<*xf32> {

  // Collapse input
  %input_shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
  %input_size = shape.num_elements %input_shape : tensor<?xindex> -> index
  %input_collapsed_shape = tensor.from_elements %input_size : tensor<1xindex>
  %input_collapsed = tensor.reshape %input(%input_collapsed_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

  // Second operand for sum
  %one = arith.constant 1.0 : f32
  %one_splat = tensor.splat %one[%input_size] : tensor<?xf32>

  // Compute sum and expand it
  %sum_collapsed = arith.addf %input_collapsed, %one_splat : tensor<?xf32>
  %sum = tensor.reshape %sum_collapsed(%input_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>

  // Collapse sum
  %sum_shape = shape.shape_of %sum : tensor<*xf32> -> tensor<?xindex>
  %sum_size = shape.num_elements %sum_shape : tensor<?xindex> -> index
  %sum_collapsed_shape = tensor.from_elements %sum_size : tensor<1xindex>
  %sum_collapsed_0 = tensor.reshape %sum(%sum_collapsed_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

  // Second operand for product
  %two = arith.constant 2.0 : f32
  %two_splat = tensor.splat %two[%sum_size] : tensor<?xf32>

  // Compute product and expand it
  %product_collapsed = arith.mulf %sum_collapsed_0, %two_splat : tensor<?xf32>
  %product = tensor.reshape %product_collapsed(%sum_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>

  return %product : tensor<*xf32>
}