Skip to content

Commit c3bb6d6

Browse files
tarun292 authored and facebook-github-bot committed
Improve example used to demonstrate delegate debug mapping generation (#519)
Summary: Improving the example model used to demonstrate the delegate mapping generation and logging logic. Reviewed By: cccclai Differential Revision: D49620454
1 parent a6ec5dc commit c3bb6d6

File tree

3 files changed

+72
-28
lines changed

3 files changed

+72
-28
lines changed

exir/backend/test/backend_with_delegate_mapping_demo.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -159,21 +159,51 @@ def preprocess(
159159
@staticmethod
160160
# The sample model that will work with BackendWithDelegateMapping show above.
161161
def get_test_model_and_inputs():
162-
class ConvReLUAddModel(nn.Module):
162+
class SimpleConvNet(nn.Module):
163163
def __init__(self):
164-
super(ConvReLUAddModel, self).__init__()
165-
# Define a convolutional layer
166-
self.conv_layer = nn.Conv2d(
167-
in_channels=1, out_channels=64, kernel_size=3, padding=1
164+
super(SimpleConvNet, self).__init__()
165+
166+
# First convolutional layer
167+
self.conv1 = nn.Conv2d(
168+
in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
169+
)
170+
self.relu1 = nn.ReLU()
171+
172+
# Second convolutional layer
173+
self.conv2 = nn.Conv2d(
174+
in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1
168175
)
176+
self.relu2 = nn.ReLU()
177+
178+
def forward(self, x):
179+
# Forward pass through the first convolutional layer
180+
x = self.conv1(x)
181+
x = self.relu1(x)
182+
183+
# Forward pass through the second convolutional layer
184+
x = self.conv2(x)
185+
x = self.relu2(x)
186+
187+
return x
188+
189+
class ConvReLUTanModel(nn.Module):
190+
def __init__(self):
191+
super(ConvReLUTanModel, self).__init__()
192+
193+
# Define a convolutional layer
194+
self.conv_layer = SimpleConvNet()
169195

170196
def forward(self, x):
171197
# Forward pass through convolutional layer
172198
conv_output = self.conv_layer(x)
173-
# Apply ReLU activation
174-
relu_output = nn.functional.relu(conv_output)
175-
# Perform tan on relu output
176-
added_output = torch.tan(relu_output)
177-
return added_output
178199

179-
return (ConvReLUAddModel(), (torch.randn(1, 1, 32, 32),))
200+
# Perform tan on conv_output
201+
tan_output = torch.tan(conv_output)
202+
203+
return tan_output
204+
205+
batch_size = 4
206+
channels = 3
207+
height = 64
208+
width = 64
209+
return (ConvReLUTanModel(), (torch.randn(batch_size, channels, height, width),))

exir/backend/test/test_delegate_map_builder.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,24 @@ def test_backend_with_delegate_mapping(self) -> None:
121121
debug_handle_map = lowered_module.meta.get("debug_handle_map")
122122
self.assertIsNotNone(debug_handle_map)
123123
# There should be 3 backend ops in this model.
124-
self.assertEqual(len(debug_handle_map), 4)
124+
self.assertEqual(len(debug_handle_map), 5)
125125
# Check to see that all the delegate debug indexes in the range [0,2] are present.
126126
self.assertTrue(
127127
all(element in debug_handle_map.keys() for element in [0, 1, 2, 3])
128128
)
129-
lowered_module.program()
129+
130+
class CompositeModule(torch.nn.Module):
131+
def __init__(self):
132+
super().__init__()
133+
self.lowered_module = lowered_module
134+
135+
def forward(self, x):
136+
return self.lowered_module(x)
137+
138+
composite_model = CompositeModule()
139+
exir.capture(
140+
composite_model, inputs, exir.CaptureConfig()
141+
).to_edge().to_executorch()
130142

131143
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
132144

runtime/executor/test/test_backend_with_delegate_mapping.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <executorch/runtime/backend/interface.h>
1010
#include <executorch/runtime/core/error.h>
1111
#include <executorch/runtime/core/evalue.h>
12+
#include <executorch/runtime/core/event_tracer_hooks_delegate.h>
1213

1314
#include <cstdlib> /* strtol */
1415
#include <cstring>
@@ -125,22 +126,23 @@ class BackendWithDelegateMapping final : public PyTorchBackendInterface {
125126
"Op name = %s Delegate debug index = %ld",
126127
op_list->ops[index].name,
127128
op_list->ops[index].debug_handle);
129+
event_tracer_log_profiling_delegate(
130+
context.event_tracer(),
131+
nullptr,
132+
op_list->ops[index].debug_handle,
133+
0,
134+
1);
135+
/**
136+
If you used string based delegate debug identifiers then the profiling
137+
call would be as below.
138+
event_tracer_log_profiling_delegate(
139+
context.event_tracer(),
140+
pointer_to_delegate_debug_string,
141+
-1,
142+
0,
143+
1);
144+
*/
128145
}
129-
// The below API's are not available yet but they are a representative
130-
// example of what we'll be enabling.
131-
/*
132-
Option 1: Log performance event with an ID. An integer ID must have been
133-
provided to DelegateMappingBuilder during AOT compilation.
134-
*/
135-
// EVENT_TRACER_LOG_DELEGATE_PROFILING_EVENT_ID(op_list->ops[index].debug_handle,
136-
// start_time, end_time);
137-
/*
138-
Option 2: Log performance event with a name. A string
139-
name must have been provided to DelegateMappingBuilder during AOT
140-
compilation.
141-
*/
142-
// EVENT_TRACER_LOG_DELEGATE_PROFILING_EVENT_NAME(op_list->ops[index].name,
143-
// start_time, end_time);
144146

145147
return Error::Ok;
146148
}

0 commit comments

Comments (0)