@@ -273,134 +273,6 @@ __attribute__((noinline)) void conv2d_nhwc_core_generic(
   }
 }
 
-void convolution_nchw(
-    const Tensor& input,
-    const Tensor& weight,
-    const Tensor& bias,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    int16_t groups,
-    Tensor& output) {
-  bool conv1d = input.dim() == 3;
-  // input = [n, c, h, w]
-  const int n = input.size(0);
-  const int c = input.size(1);
-  const int h = conv1d ? 1 : input.size(2);
-  const int w = conv1d ? input.size(2) : input.size(3);
-  // weight = [oc, wc, wh, ww]
-  const int oc = weight.size(0);
-  const int wc = weight.size(1);
-  const int wh = conv1d ? 1 : weight.size(2);
-  const int ww = conv1d ? weight.size(2) : weight.size(3);
-  // output = [n, oc, oh, ow]
-  const int oh = conv1d ? 1 : output.size(2);
-  const int ow = conv1d ? output.size(2) : output.size(3);
-
-  float* __restrict__ p_out = output.mutable_data_ptr<float>();
-  const float* __restrict__ p_in = input.const_data_ptr<float>();
-  const float* __restrict__ p_weight = weight.const_data_ptr<float>();
-  const float* __restrict__ p_bias = bias.const_data_ptr<float>();
-
-  conv2d_nchw_core_generic<>(
-      p_in,
-      p_weight,
-      p_bias,
-      p_out,
-      n,
-      c,
-      h,
-      w,
-      oc,
-      wc,
-      wh,
-      ww,
-      oh,
-      ow,
-      conv1d ? 1 : stride[0],
-      conv1d ? stride[0] : stride[1],
-      conv1d ? 0 : padding[0],
-      conv1d ? padding[0] : padding[1],
-      conv1d ? 1 : dilation[0],
-      conv1d ? dilation[0] : dilation[1],
-      groups);
-}
-
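The `conv1d` ternaries above fold a 1D convolution into the 2D kernel by pinning the height axis: the lone spatial size is treated as `w`, while the height-side stride/padding/dilation collapse to 1/0/1 so the core loops visit exactly one row. A minimal standalone sketch of that remapping, using plain ints in place of `Tensor`/`IntArrayRef` (the `Conv2dParams` struct and `map_conv1d_to_2d` helper are hypothetical, for illustration only):

```cpp
#include <cassert>
#include <cstdio>

// Hypothetical POD mirroring the values convolution_nchw derives before
// calling conv2d_nchw_core_generic. Not part of the original source.
struct Conv2dParams {
  int h, w;  // input spatial extents
  int stride_h, stride_w;
  int pad_h, pad_w;
  int dil_h, dil_w;
};

// Reproduces the `conv1d ? ... : ...` remapping from the deleted code: a
// 3-d input [n, c, w] is viewed as [n, c, 1, w], and the height-axis
// parameters degenerate to stride 1, padding 0, dilation 1.
Conv2dParams map_conv1d_to_2d(int spatial, int stride, int padding, int dilation) {
  return Conv2dParams{
      /*h=*/1,        /*w=*/spatial,
      /*stride_h=*/1, /*stride_w=*/stride,
      /*pad_h=*/0,    /*pad_w=*/padding,
      /*dil_h=*/1,    /*dil_w=*/dilation};
}

int main() {
  // conv1d over length 16 with stride 2, padding 1, dilation 1.
  Conv2dParams p = map_conv1d_to_2d(16, 2, 1, 1);
  assert(p.h == 1 && p.stride_h == 1 && p.pad_h == 0 && p.dil_h == 1);
  std::printf("runs as a 1 x %d image, stride (%d, %d)\n", p.w, p.stride_h, p.stride_w);
}
```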
-void convolution_nhwc(
-    const Tensor& input,
-    const Tensor& weight,
-    const Tensor& bias,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    int16_t groups,
-    Tensor& output) {
-  bool conv1d = input.dim() == 3;
-  // input = [n, h, w, c]
-  const int n = input.size(0);
-  const int h = conv1d ? 1 : input.size(1);
-  const int w = conv1d ? input.size(1) : input.size(2);
-  const int c = conv1d ? input.size(2) : input.size(3);
-
-  // weight = [oc, wh, ww, wc]
-  const int oc = weight.size(0);
-  const int wh = conv1d ? 1 : weight.size(1);
-  const int ww = conv1d ? weight.size(1) : weight.size(2);
-  const int wc = conv1d ? weight.size(2) : weight.size(3);
-
-  // output = [n, oh, ow, oc]
-  const int oh = conv1d ? 1 : output.size(1);
-  const int ow = conv1d ? output.size(1) : output.size(2);
-
-  float* __restrict__ p_out = output.mutable_data_ptr<float>();
-  const float* __restrict__ p_in = input.const_data_ptr<float>();
-  const float* __restrict__ p_weight = weight.const_data_ptr<float>();
-  const float* __restrict__ p_bias = bias.const_data_ptr<float>();
-
-  conv2d_nhwc_core_generic<>(
-      p_in,
-      p_weight,
-      p_bias,
-      p_out,
-      n,
-      h,
-      w,
-      c,
-      oc,
-      wh,
-      ww,
-      wc,
-      oh,
-      ow,
-      conv1d ? 1 : stride[0],
-      conv1d ? stride[0] : stride[1],
-      conv1d ? 0 : padding[0],
-      conv1d ? padding[0] : padding[1],
-      conv1d ? 1 : dilation[0],
-      conv1d ? dilation[0] : dilation[1],
-      groups);
-}
-
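`convolution_nhwc` extracts the same quantities as `convolution_nchw`, just at shifted dimension indices, because the two layouts differ only in which axis varies fastest in memory. A hedged illustration of the flat-index arithmetic the layout-specific core kernels are built around (the `offset_*` helpers are illustrative names, not from this file):

```cpp
#include <cassert>
#include <cstddef>

// Flat offset of element (n, c, y, x) when channels are outermost (NCHW).
std::size_t offset_nchw(int n, int c, int y, int x, int C, int H, int W) {
  return ((static_cast<std::size_t>(n) * C + c) * H + y) * W + x;
}

// Flat offset of the same element when channels vary fastest (NHWC).
std::size_t offset_nhwc(int n, int c, int y, int x, int C, int H, int W) {
  return ((static_cast<std::size_t>(n) * H + y) * W + x) * C + c;
}

int main() {
  // In NHWC, stepping to the next channel moves 1 element; in NCHW it
  // jumps a whole H*W plane -- which is why the wrappers dispatch to
  // layout-specific core kernels instead of sharing one loop nest.
  const int C = 8, H = 4, W = 4;
  assert(offset_nhwc(0, 1, 0, 0, C, H, W) - offset_nhwc(0, 0, 0, 0, C, H, W) == 1);
  assert(offset_nchw(0, 1, 0, 0, C, H, W) - offset_nchw(0, 0, 0, 0, C, H, W) == H * W);
}
```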
-void convolution_out(
-    __ET_UNUSED KernelRuntimeContext& ctx,
-    const Tensor& input,
-    const Tensor& weight,
-    const Tensor& bias,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    int64_t groups,
-    bool channel_last,
-    Tensor& output) {
-  if (channel_last) {
-    convolution_nhwc(
-        input, weight, bias, stride, padding, dilation, groups, output);
-  } else {
-    convolution_nchw(
-        input, weight, bias, stride, padding, dilation, groups, output);
-  }
-}
-
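Note that all of the removed wrappers read `oh`/`ow` from the preallocated `output` tensor rather than recomputing them, so they implicitly assume `output` was sized by the usual convolution extent formula. A small sketch of that relation (the `conv_out_extent` helper is hypothetical, for illustration only):

```cpp
#include <cassert>

// Standard convolution output-extent formula that the deleted wrappers
// implicitly assume holds for the preallocated output tensor.
int conv_out_extent(int in, int kernel, int stride, int pad, int dil) {
  return (in + 2 * pad - dil * (kernel - 1) - 1) / stride + 1;
}

int main() {
  // 2d example: 32x32 input, 3x3 kernel, stride 2, padding 1, dilation 1
  // -> 16x16 output, i.e. what convolution_out expects output.size(2/3) to be.
  assert(conv_out_extent(32, 3, 2, 1, 1) == 16);
  // conv1d branch: the height axis collapses to a single output row.
  assert(conv_out_extent(1, 1, 1, 0, 1) == 1);
}
```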
 // The quantized convolution kernel. in_scale and weight_scale are implicit in
 // bias_scale, since it is a product of the two. The kernel will branch to
 // quantized::conv1d or quantized::conv2d based on the dimensionality of