| 1 | +# AutoHeuristic |
| 2 | +AutoHeuristic is a framework that uses results from autotuning to learn a heuristic as a decision tree, which can then be generated as code and shipped with the compiler. |
| 3 | + |
| 4 | +## How to use AutoHeuristic |
| 5 | +In general, the following steps have to be performed: |
| 6 | +- The AutoHeuristic constructor has to be called. |
| 7 | +- A script that runs benchmarks in order to collect training data has to be implemented. |
| 8 | +- The train_decision.py (if you want to learn a decision tree) or train_regression.py (if you want to learn a regression tree) script has to be run in order to learn the heuristic and generate it to code. |
| 9 | + |
| 10 | +## Step 1: Calling the AutoHeuristic constructor |
| 11 | +Currently, two use cases are supported: |
| 12 | + |
| 13 | +### Use case 1: Local autotuning |
| 14 | +When your feedback function can return a result immediately, you can call the AutoHeuristic constructor directly. This is done, e.g., for pad_mm: |
| 15 | +``` |
| 16 | +autoheuristic = AutoHeuristic( |
| 17 | + fallback=fallback, |
| 18 | + choices=choices, |
| 19 | + feedback=feedback, |
| 20 | + context=context, |
| 21 | + name=name, |
| 22 | + augment_context=pad_mm_operations(), |
| 23 | + precondition=pad_mm_precondition, |
| 24 | +) |
| 25 | +``` |
| 26 | +Here, `feedback` is a function that benchmarks a given choice and returns the execution time. For an example, see: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/pad_mm.py. |
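
To make the shape of a feedback function concrete, here is a minimal, self-contained sketch. It is not the pad_mm implementation: the helper `make_feedback`, the two toy choices, and the use of wall-clock timing via `time.perf_counter` are all assumptions made for illustration. The only property taken from the text above is that the function receives a choice and returns its execution time.

```python
import time
from typing import Callable, Dict


def make_feedback(
    candidates: Dict[str, Callable[[], object]], repeats: int = 5
) -> Callable[[str], float]:
    """Build a feedback function mapping a choice name to its best runtime in seconds.

    `candidates` is a hypothetical mapping from choice name to a zero-argument
    callable that executes that choice.
    """

    def feedback(choice: str) -> float:
        fn = candidates[choice]
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            fn()
            # Keep the fastest of several runs to reduce timing noise.
            best = min(best, time.perf_counter() - start)
        return best

    return feedback


# Toy workload standing in for real kernels: two ways of summing a list.
data = list(range(10_000))


def loop_sum() -> int:
    total = 0
    for x in data:
        total += x
    return total


candidates = {"builtin_sum": lambda: sum(data), "python_loop": loop_sum}
feedback = make_feedback(candidates)
timings = {name: feedback(name) for name in candidates}
```

For GPU kernels, wall-clock timing like this would not be adequate on its own; the linked pad_mm code shows how the real feedback function measures execution time.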
| 27 | + |
| 28 | +### Use case 2: Kernel choice selection |
| 29 | +If you want to use AutoHeuristic for kernel choice selection, you have to call the AutoHeuristicSelectAlgorithm constructor. This is done, e.g., for mixed_mm: |
| 30 | +``` |
| 31 | +autoheuristic = AutoHeuristicSelectAlgorithm( |
| 32 | + fallback=fallback, |
| 33 | + choices=choices, |
| 34 | + input_nodes=input_nodes, |
| 35 | + context=context, |
| 36 | + name=name, |
| 37 | + augment_context=ops, |
| 38 | + precondition=precondition, |
| 39 | +) |
| 40 | +``` |
| 41 | +This call has to be followed by a call to `autotune_select_algorithm()`: |
| 42 | +``` |
| 43 | +autotune_select_algorithm(name, choices, input_nodes, layout) |
| 44 | +``` |
| 45 | +Note that `choices`, `input_nodes`, and `name` in the `AutoHeuristicSelectAlgorithm()` and `autotune_select_algorithm()` calls have to match when you want to use AutoHeuristic to collect data. |
| 46 | + |
| 47 | +For an example, see: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/mm.py |
| 48 | + |
| 49 | +## Step 2: Collecting training data |
| 50 | +After adding the call to the AutoHeuristic constructor, you need to collect training data in order to learn a heuristic. Let's say you have a script `run.py` that triggers the AutoHeuristic constructor that you just added. Run the following command in order to store data into file `train.txt`: |
| 51 | +``` |
| 52 | +TORCHINDUCTOR_AUTOHEURISTIC_LOG_PATH="train.txt" \ |
| 53 | + TORCHINDUCTOR_AUTOHEURISTIC_COLLECT="pad_mm" python run.py |
| 54 | +``` |
| 55 | +Replace "pad_mm" with the name you provided in the call to the AutoHeuristic constructor. |
| 56 | + |
| 57 | +AutoHeuristic provides a `BenchmarkRunner` class (https://github.com/pytorch/pytorch/blob/main/torchgen/_autoheuristic/benchmark_runner.py) that simplifies the process of collecting data. To use it, create a new class that subclasses `BenchmarkRunner`, and implements the `run_benchmark()` and `create_input()` methods. |
| 58 | + |
| 59 | +These examples might be helpful: |
| 60 | +- https://github.com/pytorch/pytorch/blob/main/torchgen/_autoheuristic/pad_mm/gen_data_pad_mm.py |
| 61 | +- https://github.com/pytorch/pytorch/blob/main/torchgen/_autoheuristic/mixed_mm/gen_data_mixed_mm.py |
| 62 | + |
| 63 | + |
| 64 | +## Step 3: Learning a heuristic and using it |
| 65 | +Once you have collected enough training data, you are ready to learn a heuristic: |
| 66 | +``` |
| 67 | +python torchgen/_autoheuristic/train_decision.py train.txt --heuristic-name SimpleHeuristic |
| 68 | +``` |
| 69 | +will learn a heuristic and generate it to `torch/_inductor/autoheuristic/artifacts/_SimpleHeuristic.py`. |
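
To give a rough idea of what "a decision tree generated as code" means, the sketch below hand-writes a tiny tree as nested conditionals. It is purely illustrative: the function name `choose`, the thresholds, and the choice names are made up, and the file that `train_decision.py` actually generates will have a different structure.

```python
from typing import Dict


def choose(features: Dict[str, float]) -> str:
    """Hypothetical stand-in for a learned tree: each branch tests one feature."""
    if features["m*k"] <= 4096:
        return "choice_a"
    if features["k"] <= 512:
        return "choice_b"
    return "choice_a"


small_problem = choose({"m*k": 1024, "k": 32})
large_narrow = choose({"m*k": 65536, "k": 256})
large_wide = choose({"m*k": 1048576, "k": 1024})
```

Because the heuristic is plain code with no model file to load, it can be shipped with the compiler and evaluated cheaply at compile time.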
| 70 | + |
| 71 | +You can now use your learned heuristic: |
| 72 | +``` |
| 73 | +TORCHINDUCTOR_AUTOHEURISTIC_USE="pad_mm" python run.py |
| 74 | +``` |
| 75 | +Here, you again have to replace "pad_mm" with the name you provided in the call to the AutoHeuristic constructor. |
| 76 | + |
| 77 | +Instead of just running the `train_decision.py` script, you will probably want to customize the training process in some way. To do this, create a new class that subclasses `AHTrainDecision` and overrides the methods you want to customize. Here are some examples: |
| 78 | +- https://github.com/pytorch/pytorch/blob/main/torchgen/_autoheuristic/mixed_mm/train_decision_mixedmm.py |
| 79 | +- https://github.com/pytorch/pytorch/blob/main/torchgen/_autoheuristic/pad_mm/train_decision_pad_mm.py |
| 80 | + |
| 81 | +## Other |
| 82 | + |
| 83 | +### How do I specify features that the heuristic is going to use to make a decision? |
| 84 | +The AutoHeuristic constructor requires a `context` argument of type `AHContext`, which will contain all features. You specify features in the following way: |
| 85 | +``` |
| 86 | +context = AHContext() |
| 87 | + |
| 88 | +# adding numerical features |
| 89 | +context.add_feature("m", mat1.shape[0]) |
| 90 | +context.add_feature("k", mat1.shape[1]) |
| 91 | + |
| 92 | +# adding a categorical feature |
| 93 | +context.add_feature("mat1_dtype", mat1.dtype, is_categorical=True) |
| 94 | +``` |
| 95 | + |
| 96 | +You might want to use features that are a combination of other features, such as `m*k`. You can of course add such features in the same way as above, i.e., |
| 97 | +``` |
| 98 | +context.add_feature("m*k", mat1.shape[0] * mat1.shape[1]) |
| 99 | +``` |
| 100 | +but AutoHeuristic also provides a way to 'augment' features. Augmented features are not stored when data is collected; instead, they are computed from the stored features before a heuristic is learned or before a learned heuristic is used. You can specify such augmented features by creating a list of `AHOperation` objects: |
| 101 | +``` |
| 102 | +def m_times_k(data: Any) -> float: |
| 103 | + return data['m'] * data['k'] |
| 104 | + |
| 105 | +m_times_k_op = AHOperation("m*k", m_times_k) |
| 106 | +ah_operations = [m_times_k_op] |
| 107 | + |
| 108 | +# specify augmented features by setting `augment_context` to `ah_operations` |
| 109 | +autoheuristic = AutoHeuristic(..., augment_context=ah_operations, ...) |
| 110 | +``` |
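
The idea behind augmentation can be sketched without AutoHeuristic itself. In the example below, plain dictionaries and `(name, function)` tuples stand in for the collected data and for `AHOperation` objects; the `augment` helper is hypothetical and only illustrates that derived features are computed from the stored ones on demand rather than being logged.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical stand-in for an AHOperation: a feature name plus a function
# that computes the feature from the already stored ones.
Operation = Tuple[str, Callable[[Dict[str, Any]], Any]]


def m_times_k(data: Dict[str, Any]) -> float:
    return data["m"] * data["k"]


def augment(row: Dict[str, Any], operations: List[Operation]) -> Dict[str, Any]:
    """Return a copy of `row` extended with every derived feature."""
    augmented = dict(row)
    for name, func in operations:
        augmented[name] = func(augmented)
    return augmented


# Only "m" and "k" are stored at collection time.
stored_row = {"m": 128, "k": 64}
operations = [("m*k", m_times_k)]

# "m*k" is derived later, before training or before the heuristic is queried.
augmented_row = augment(stored_row, operations)
```

Since the same operations have to be applied both at training time and when the heuristic is used, keeping them in one list that both places share avoids the two drifting apart.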
| 111 | + |
| 112 | +Note that you also have to specify these operations when you want to learn a heuristic. Look at the `add_new_features()` method in these examples, to see how it is done: |
| 113 | +- https://github.com/pytorch/pytorch/blob/main/torchgen/_autoheuristic/mixed_mm/train_decision_mixedmm.py |
| 114 | +- https://github.com/pytorch/pytorch/blob/main/torchgen/_autoheuristic/pad_mm/train_decision_pad_mm.py |
| 115 | + |
| 116 | +### Where has AutoHeuristic already been used? |
| 117 | +Take a look at the following PRs in which AutoHeuristic has been enabled for various optimizations. |
| 118 | +Looking at these examples may be helpful if you want to use AutoHeuristic yourself. |
| 119 | +- pad_mm: https://github.com/pytorch/pytorch/pull/128643 |
| 120 | +- mixed_mm: |
| 121 | + - Enabling of AutoHeuristic: https://github.com/pytorch/pytorch/pull/131610 |
| 122 | + - Script to collect data: https://github.com/pytorch/pytorch/pull/131611 |
| 123 | + - A100 heuristic: https://github.com/pytorch/pytorch/pull/131613 |
| 124 | + - H100 heuristic: https://github.com/pytorch/pytorch/pull/132685 |
| 125 | +- flex_attention: https://github.com/pytorch/pytorch/pull/130398 |
| 126 | +- mm (heuristic for ranking choices): |
| 127 | + - https://github.com/pytorch/pytorch/pull/131615 |
| 128 | + - https://github.com/pytorch/pytorch/pull/131617 |
| 129 | + - https://github.com/pytorch/pytorch/pull/131705 |
| 130 | + - https://github.com/pytorch/pytorch/pull/131714 |