@@ -439,8 +439,7 @@ class usm_ndarray : public py::object
439
439
440
440
char * get_data () const
441
441
{
442
- PyObject * raw_o = this -> ptr ();
443
- PyUSMArrayObject * raw_ar = reinterpret_cast < PyUSMArrayObject * > (raw_o );
442
+ PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
444
443
445
444
return UsmNDArray_GetData (raw_ar );
446
445
}
@@ -452,16 +451,14 @@ class usm_ndarray : public py::object
452
451
453
452
int get_ndim () const
454
453
{
455
- PyObject * raw_o = this -> ptr ();
456
- PyUSMArrayObject * raw_ar = reinterpret_cast < PyUSMArrayObject * > (raw_o );
454
+ PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
457
455
458
456
return UsmNDArray_GetNDim (raw_ar );
459
457
}
460
458
461
459
const py ::ssize_t * get_shape_raw () const
462
460
{
463
- PyObject * raw_o = this -> ptr ();
464
- PyUSMArrayObject * raw_ar = reinterpret_cast < PyUSMArrayObject * > (raw_o );
461
+ PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
465
462
466
463
return UsmNDArray_GetShape (raw_ar );
467
464
}
@@ -474,16 +471,14 @@ class usm_ndarray : public py::object
474
471
475
472
const py ::ssize_t * get_strides_raw () const
476
473
{
477
- PyObject * raw_o = this -> ptr ();
478
- PyUSMArrayObject * raw_ar = reinterpret_cast < PyUSMArrayObject * > (raw_o );
474
+ PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
479
475
480
476
return UsmNDArray_GetStrides (raw_ar );
481
477
}
482
478
483
479
py ::ssize_t get_size () const
484
480
{
485
- PyObject * raw_o = this -> ptr ();
486
- PyUSMArrayObject * raw_ar = reinterpret_cast < PyUSMArrayObject * > (raw_o );
481
+ PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
487
482
488
483
int ndim = UsmNDArray_GetNDim (raw_ar );
489
484
const py ::ssize_t * shape = UsmNDArray_GetShape (raw_ar );
@@ -499,8 +494,7 @@ class usm_ndarray : public py::object
499
494
500
495
std ::pair < py ::ssize_t , py ::ssize_t > get_minmax_offsets () const
501
496
{
502
- PyObject * raw_o = this -> ptr ();
503
- PyUSMArrayObject * raw_ar = reinterpret_cast < PyUSMArrayObject * > (raw_o );
497
+ PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
504
498
505
499
int nd = UsmNDArray_GetNDim (raw_ar );
506
500
const py ::ssize_t * shape = UsmNDArray_GetShape (raw_ar );
@@ -533,33 +527,29 @@ class usm_ndarray : public py::object
533
527
534
528
sycl ::queue get_queue () const
535
529
{
536
- PyObject * raw_o = this -> ptr ();
537
- PyUSMArrayObject * raw_ar = reinterpret_cast < PyUSMArrayObject * > (raw_o );
530
+ PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
538
531
539
532
DPCTLSyclQueueRef QRef = UsmNDArray_GetQueueRef (raw_ar );
540
533
return * (reinterpret_cast < sycl ::queue * > (QRef ));
541
534
}
542
535
543
536
int get_typenum () const
544
537
{
545
- PyObject * raw_o = this -> ptr ();
546
- PyUSMArrayObject * raw_ar = reinterpret_cast < PyUSMArrayObject * > (raw_o );
538
+ PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
547
539
548
540
return UsmNDArray_GetTypenum (raw_ar );
549
541
}
550
542
551
543
int get_flags () const
552
544
{
553
- PyObject * raw_o = this -> ptr ();
554
- PyUSMArrayObject * raw_ar = reinterpret_cast < PyUSMArrayObject * > (raw_o );
545
+ PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
555
546
556
547
return UsmNDArray_GetFlags (raw_ar );
557
548
}
558
549
559
550
int get_elemsize () const
560
551
{
561
- PyObject * raw_o = this -> ptr ();
562
- PyUSMArrayObject * raw_ar = reinterpret_cast < PyUSMArrayObject * > (raw_o );
552
+ PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
563
553
564
554
return UsmNDArray_GetElementSize (raw_ar );
565
555
}
@@ -575,6 +565,12 @@ class usm_ndarray : public py::object
575
565
int flags = this -> get_flags ();
576
566
return static_cast < bool > (flags & USM_ARRAY_F_CONTIGUOUS );
577
567
}
568
+
569
+ private :
570
+ PyUSMArrayObject * usm_array_ptr () const
571
+ {
572
+ return reinterpret_cast < PyUSMArrayObject * > (m_ptr );
573
+ }
578
574
};
579
575
580
576
} // end namespace tensor
0 commit comments