@@ -148,8 +148,9 @@ class ArmBackend final : public PyTorchBackendInterface {
       if (both_char and permuted_input_shape) {
         // permuted byte copy CHW to HWC
         permute_CHW_to_HWC(
-            scratch_addr,
             tensor_in.mutable_data_ptr<char>(),
+            scratch_addr,
+            tensor_in.size(1),
             tensor_in.size(2),
             tensor_in.size(3));
       } else if (both_char or both_int) {
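
In this first hunk, the call's first two arguments are swapped so the tensor's CHW bytes become the source and scratch_addr the destination, and the channel count tensor_in.size(1) is now passed explicitly to match the generalized permute_CHW_to_HWC(input, output, C, H, W) signature introduced in the final hunk. For a C x H x W char buffer the copy amounts to dst[(h * W + w) * C + c] = src[c * H * W + h * W + w].
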
@@ -204,13 +205,31 @@ class ArmBackend final : public PyTorchBackendInterface {
       // Process input EValue into scratch
       // Outputs are in the index immediately after inputs
       auto tensor_out = args[handles.inputs->count + i]->toTensor();
-      for (int j = 0; j < tensor_out.numel(); j++) {
-        if (tensor_out.scalar_type() == ScalarType::Char) {
-          char* output_address = (char*)output_addr;
-          tensor_out.mutable_data_ptr<char>()[j] = output_address[j];
-        } else {
-          int* output_address = (int*)output_addr;
-          tensor_out.mutable_data_ptr<int>()[j] = output_address[j];
+      bool permuted_output_shape;
+      ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
+          i,
+          tensor_out,
+          &handles.outputs->io[i],
+          execution_handle->permuted_io_flag,
+          &permuted_output_shape));
+      if (tensor_out.scalar_type() == ScalarType::Char and
+          permuted_output_shape) {
+        char* output_address = (char*)output_addr;
+        permute_HWC_to_CHW(
+            output_address,
+            tensor_out.mutable_data_ptr<char>(),
+            tensor_out.size(1),
+            tensor_out.size(2),
+            tensor_out.size(3));
+      } else {
+        for (int j = 0; j < tensor_out.numel(); j++) {
+          if (tensor_out.scalar_type() == ScalarType::Char) {
+            char* output_address = (char*)output_addr;
+            tensor_out.mutable_data_ptr<char>()[j] = output_address[j];
+          } else {
+            int* output_address = (int*)output_addr;
+            tensor_out.mutable_data_ptr<int>()[j] = output_address[j];
+          }
         }
       }
     }
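
The second hunk gives outputs the same treatment inputs already receive: each output tensor is first run through check_requires_permute, and a 4-D char output flagged as permuted is converted from the NPU's HWC layout back to CHW in one pass via the new permute_HWC_to_CHW; everything else falls through to the original elementwise copy, choosing char or int by scalar type.
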
@@ -225,51 +244,62 @@ class ArmBackend final : public PyTorchBackendInterface {
  private:
   Error check_requires_permute(
       int index,
-      const exec_aten::Tensor tensor_in,
-      VelaIO* input,
+      const exec_aten::Tensor tensor,
+      VelaIO* io,
       bool permuted_io_flag,
       bool* is_permuted) const {
-    bool permuted_input_shape = false;
-    if (tensor_in.dim() == 4) {
+    bool permuted_shape = false;
+    if (tensor.dim() == 4) {
       // special case for NHWC workaround in AOT; as the compilation has
       // permuted to channel last in an undetectable way, we assume here
-      // that the application has similarly permuted any input tensors.
-      permuted_input_shape = tensor_in.size(0) == input->shape[0] &&
-          tensor_in.size(1) == input->shape[3] &&
-          tensor_in.size(2) == input->shape[1] &&
-          tensor_in.size(3) == input->shape[2];
-      if (permuted_input_shape) {
-        ET_LOG(Info, "Tensor input %d will be permuted", index);
+      // that the application has similarly permuted any input/output tensors.
+      permuted_shape = tensor.size(0) == io->shape[0] &&
+          tensor.size(1) == io->shape[3] && tensor.size(2) == io->shape[1] &&
+          tensor.size(3) == io->shape[2];
+      if (permuted_shape) {
+        ET_LOG(Info, "Tensor input/output %d will be permuted", index);
       }
-      if (permuted_io_flag != permuted_input_shape) {
-        ET_LOG(Error, "Permute compile flag and permuted input don't agree");
+      if (permuted_io_flag != permuted_shape) {
+        ET_LOG(
+            Error,
+            "Permute compile flag and permuted input/output don't agree");
         return Error::InvalidProgram;
       }
     }
-    if (!permuted_input_shape) {
+    if (!permuted_shape) {
       // Error check matching shapes in the general case
-      for (int i = 0; i < tensor_in.dim(); i++) {
-        if (tensor_in.size(i) != input->shape[i]) {
-          ET_LOG(Error, "Tensor input %d mismatched shape", index);
+      for (int i = 0; i < tensor.dim(); i++) {
+        if (tensor.size(i) != io->shape[i]) {
+          ET_LOG(Error, "Tensor input/output %d mismatched shape", index);
           ET_LOG(
               Error,
               "dimension %d mismatch, %zd != %d",
               index,
-              tensor_in.size(i),
-              input->shape[i]);
+              tensor.size(i),
+              io->shape[i]);
           return Error::InvalidProgram;
         }
       }
     }
-    *is_permuted = permuted_input_shape;
+    *is_permuted = permuted_shape;
     return Error::Ok;
   }
 
-  void permute_CHW_to_HWC(char* input, char* output, int H, int W) const {
+  void permute_CHW_to_HWC(char* input, char* output, int C, int H, int W)
+      const {
     for (int i = 0; i != H * W; ++i) {
-      output[i * 3 + 0] = input[i + 0 * W * H];
-      output[i * 3 + 1] = input[i + 1 * W * H];
-      output[i * 3 + 2] = input[i + 2 * W * H];
+      for (int j = 0; j < C; ++j) {
+        output[i * C + j] = input[i + j * W * H];
+      }
+    }
+  }
+
+  void permute_HWC_to_CHW(char* input, char* output, int C, int H, int W)
+      const {
+    for (int i = 0; i != H * W; ++i) {
+      for (int j = 0; j < C; ++j) {
+        output[i + j * W * H] = input[i * C + j];
+      }
     }
   }
 };
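
As a sanity check on the generalized helpers, here is a minimal standalone sketch; the loop bodies mirror the diff, while main(), the 2x3x4 shape, and the const/static qualifiers are illustrative additions, not part of the commit:

    #include <cassert>
    #include <cstdio>

    // Copies of the two helpers from the diff (free functions for the sketch).
    static void permute_CHW_to_HWC(const char* input, char* output, int C, int H, int W) {
      for (int i = 0; i != H * W; ++i) {
        for (int j = 0; j < C; ++j) {
          output[i * C + j] = input[i + j * W * H]; // HWC <- CHW
        }
      }
    }

    static void permute_HWC_to_CHW(const char* input, char* output, int C, int H, int W) {
      for (int i = 0; i != H * W; ++i) {
        for (int j = 0; j < C; ++j) {
          output[i + j * W * H] = input[i * C + j]; // CHW <- HWC
        }
      }
    }

    int main() {
      const int C = 2, H = 3, W = 4;
      char chw[C * H * W], hwc[C * H * W], back[C * H * W];
      for (int k = 0; k < C * H * W; ++k)
        chw[k] = static_cast<char>(k);
      permute_CHW_to_HWC(chw, hwc, C, H, W);
      permute_HWC_to_CHW(hwc, back, C, H, W);
      for (int k = 0; k < C * H * W; ++k)
        assert(back[k] == chw[k]); // the two permutations are inverses
      std::printf("round trip OK\n");
      return 0;
    }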