Skip to content

Commit 4ab007d

Browse files
Brox Chenrolandschulz
andauthored
[SYCL] Add operator overloading for aggregate types in annotated_ref (#11971)
Added several operator overloading for aggregate types for annotated_ref class. This https://godbolt.org/z/h5cTTr17K would be a good example to show why this is not working without this fix. This PR includes several changes: 1. Propogate operators including all binaries and unary operators, including arithmetic, comparator, and logical. 2. Using perfecting forwarding for binaries operators, compound operators, and unary operators. This covers cases in which the sequence of conversions will be correct when implicit conversion is involved. i.e ``` annotated_ref<int> a; double b; auto p = a + b; // expected to be (double)a + b, and p should be double ``` without this fix, ``` T operator+(T a) const; annotated_ref<int> a; double b; auto p = a + b; // this become a + (int)b, and p will be int ``` --------- Co-authored-by: Roland Schulz <[email protected]>
1 parent 19c17e2 commit 4ab007d

File tree

4 files changed

+560
-32
lines changed

4 files changed

+560
-32
lines changed

sycl/doc/extensions/experimental/sycl_ext_oneapi_annotated_ptr.asciidoc

Lines changed: 116 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -549,14 +549,32 @@ class annotated_ref {
549549
public:
550550
annotated_ref(const annotated_ref&) = delete;
551551
operator T() const;
552-
T operator=(T) const;
552+
553+
template <typename O> //available only if O is not an annotated_ref type
554+
T operator=(O&&) const;
553555
T operator=(const annotated_ref&) const;
556+
554557
// OP is: +=, -=, *=, /=, %=, <<=, >>=, &=, |=, ^=
555-
T operatorOP(T) const;
558+
template <typename O> //available only if O is not an annotated_ref type
559+
T operatorOP(O&& a) const;
560+
T operatorOP(const annotated_ref &b) const;
561+
556562
T operator++() const;
557563
T operator++(int) const;
558564
T operator--() const;
559565
T operator--(int) const;
566+
567+
// OP is: +, -, *, /, %, <<, >>, &, |, ^, <, <=, >, >=, ==, ~=, &&, ||
568+
template <typename O>
569+
auto friend operatorOP(O&& a, const annotated_ref& b) ->
570+
decltype(std::forward<O>(a) OP std::declval<T>());
571+
template <typename O> //available only if O is not an annotated_ref type
572+
auto friend operatorOP(const annotated_ref& a, O&& b) ->
573+
decltype(std::declval<T>() OP std::forward<O>(b));
574+
575+
// OP is: +, -, !, ~
576+
template <typename O=T>
577+
auto operatorOP() -> decltype(OP std::declval<O>());
560578
};
561579
} // namespace sycl::ext::oneapi::experimental
562580
```
@@ -581,10 +599,13 @@ annotations when the object is loaded from memory.
581599
a|
582600
[source,c++]
583601
----
584-
T operator=(T) const;
602+
template <typename O>
603+
T operator=(O&&) const;
585604
----
586605
|
587-
Writes an object of type `T` to the location referenced by this wrapper,
606+
Writes an object of type `O` to the location referenced by this wrapper.
607+
`O` cannot be a type of `annotated_ref`.
608+
588609
applying the annotations when the object is stored to memory.
589610

590611
// --- ROW BREAK ---
@@ -608,18 +629,47 @@ Does not rebind the reference!
608629
a|
609630
[source,c++]
610631
----
611-
T operatorOP(T) const;
632+
template <typename O>
633+
T operatorOP(O&& a) const;
612634
----
613635
a|
614636
Where [code]#OP# is: [code]#pass:[+=]#, [code]#-=#,[code]#*=#, [code]#/=#, [code]#%=#, [code]#+<<=+#, [code]#>>=#, [code]#&=#, [code]#\|=#, [code]#^=#.
615637

616-
Compound assignment operators. Return result by value.
638+
Compound assignment operators for type `O`. `O` cannot be a type of `annotated_ref`.
639+
640+
Return result by value.
641+
642+
Available only if the corresponding assignment operator OP is available for `T` taking a type of `O`.
643+
Equivalent to:
644+
```c++
645+
T tmp = *this; // Reads from memory
646+
// with annotations
647+
tmp OP std::forward<O>(a);
648+
*this = tmp; // Writes to memory
649+
// with annotations
650+
return tmp;
651+
```
652+
// --- ROW BREAK ---
653+
a|
654+
[source,c++]
655+
----
656+
T operatorOP(const annotated_ref &b) const;
657+
----
658+
a|
659+
Where [code]#OP# is: [code]#pass:[+=]#, [code]#-=#,[code]#*=#, [code]#/=#, [code]#%=#, [code]#+<<=+#, [code]#>>=#, [code]#&=#, [code]#\|=#, [code]#^=#.
660+
661+
Compound assignment operators for type `annotated_ref`.
662+
663+
Return result by value.
664+
617665
Available only if the corresponding assignment operator OP is available for `T`.
618666
Equivalent to:
619667
```c++
620668
T tmp = *this; // Reads from memory
621669
// with annotations
622-
tmp OP val;
670+
T tmp2 = b; // Reads from memory
671+
// with annotations
672+
tmp OP b;
623673
*this = tmp; // Writes to memory
624674
// with annotations
625675
return tmp;
@@ -638,6 +688,64 @@ Increment and decrement operator of annotated_ref. Increment/Decrement the objec
638688
referenced by this wrapper via ``T``'s Increment/Decrement operator.
639689

640690
The annotations are applied when the object `T` is loaded and stored to the memory.
691+
692+
a|
693+
[source,c++]
694+
----
695+
template <typename O>
696+
auto friend operatorOP(O&& a, const annotated_ref& b) ->
697+
decltype(std::forward<O>(a) OP std::declval<T>());
698+
----
699+
a|
700+
Where [code]#OP# is: [code]#pass:[+]#, [code]#-#,[code]#*#, [code]#/#, [code]#%#, [code]#+<<+#, [code]#>>#, [code]#&#, [code]#\|#, [code]#\^#, [code]#<#, [code]#<=#, [code]#>#, [code]#>=#, [code]#==#, [code]#!=#, [code]#&&#, [code]#\|\|#.
701+
702+
Defines a hidden friend operator `OP` overload for type `O` and `annotated_ref`.
703+
704+
Let `operatorOP` denotes the operator used. The overloaded operator `operatorOP` utilizes
705+
`operatorOP(O&&, T&&)` and is available only if `operatorOP(O&&, T&&)` is well formed. The value and result
706+
is the same as the result of `operatorOP(O&&, T&&)` applied to the objects of
707+
type `O` and `T`.
708+
709+
The annotations from `PropertyListT` are applied when the object `b` is loaded from memory.
710+
711+
a|
712+
[source,c++]
713+
----
714+
template <typename O>
715+
auto friend operatorOP(const annotated_ref& a, O&& b) ->
716+
decltype(std::declval<T>() OP std::forward<O>(b));
717+
----
718+
a|
719+
Where [code]#OP# is: [code]#pass:[+]#, [code]#-#,[code]#*#, [code]#/#, [code]#%#, [code]#+<<+#, [code]#>>#, [code]#&#, [code]#\|#, [code]#\^#, [code]#<#, [code]#<=#, [code]#>#, [code]#>=#, [code]#==#, [code]#!=#, [code]#&&#, [code]#\|\|#.
720+
721+
Defines a hidden friend operator `OP` overload for type `annotated_ref` and `O`. `O` cannot be
722+
a type of `annotated_ref`.
723+
724+
Let `operatorOP` denotes the operator used. The overloaded operator `operatorOP` utilizes
725+
`operatorOP(T&&, O&&)` and is available only if `operatorOP(T&&, O&&)` is well formed. The value and result
726+
is the same as the result of `operatorOP(T&&, O&&)` applied to the objects of
727+
type `T` and `O`.
728+
729+
The annotations from `PropertyListT` are applied when the object `a` is loaded from memory.
730+
731+
a|
732+
[source,c++]
733+
----
734+
template <typename O=T>
735+
auto operatorOP() -> decltype(OP std::declval<O>());
736+
----
737+
a|
738+
Where [code]#OP# is: [code]#pass:[+]#, [code]#-#, [code]#!#, [code]#~#.
739+
740+
Defines a operator `OP` overload for types `O` where the default type is `T`.
741+
742+
Let `operatorOP` denotes the operator used. The overloaded operator
743+
`operatorOP` utilizes `operatorOP(O)` and is available only if `operatorOP(O)`
744+
is well formed. The value and result is the same as the result of `operatorOP(O)`
745+
applied to the objects of type `O`.
746+
747+
The annotations from `PropertyListT` are applied when the object `a` is loaded from memory.
748+
641749
|===
642750

643751
== Issues related to `annotated_ptr`
@@ -685,6 +793,7 @@ the alignment is set up.
685793
[options="header"]
686794
|========================================
687795
|Rev|Date|Author|Changes
796+
|5|2023-11-30|Brox Chen|API fixes: operators fowarding for annnotated_ref
688797
|4|2023-06-28|Roland Schulz|API fixes: constructors and annotated_ref assignment
689798
|3|2022-04-05|Abhishek Tiwari|*Addressed review comments*
690799
|2|2022-03-07|Abhishek Tiwari|*Corrected API and updated description*

sycl/include/sycl/ext/oneapi/experimental/annotated_ptr/annotated_ptr.hpp

Lines changed: 94 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
//
12
//==----------- annotated_ptr.hpp - SYCL annotated_ptr extension -----------==//
23
//
34
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -31,14 +32,6 @@ namespace oneapi {
3132
namespace experimental {
3233

3334
namespace {
34-
#define PROPAGATE_OP(op) \
35-
T operator op##=(T rhs) const { \
36-
T t = *this; \
37-
t op## = rhs; \
38-
*this = t; \
39-
return t; \
40-
}
41-
4235
// compare strings on compile time
4336
constexpr bool compareStrs(const char *Str1, const char *Str2) {
4437
return std::string_view(Str1) == Str2;
@@ -66,6 +59,7 @@ struct PropertiesFilter {
6659
std::tuple<>>::type...>;
6760
};
6861
} // namespace
62+
6963
template <typename T, typename PropertyListT = empty_properties_t>
7064
class annotated_ref {
7165
// This should always fail when instantiating the unspecialized version.
@@ -74,6 +68,17 @@ class annotated_ref {
7468
static_assert(is_valid_property_list, "Property list is invalid.");
7569
};
7670

71+
namespace detail {
72+
template <class T> struct is_ann_ref_impl : std::false_type {};
73+
template <class T, class P>
74+
struct is_ann_ref_impl<annotated_ref<T, P>> : std::true_type {};
75+
template <class T, class P>
76+
struct is_ann_ref_impl<const annotated_ref<T, P>> : std::true_type {};
77+
template <class T>
78+
constexpr bool is_ann_ref_v =
79+
is_ann_ref_impl<std::remove_reference_t<T>>::value;
80+
} // namespace detail
81+
7782
template <typename T, typename... Props>
7883
class annotated_ref<T, detail::properties_t<Props...>> {
7984
using property_list_t = detail::properties_t<Props...>;
@@ -84,11 +89,12 @@ class annotated_ref<T, detail::properties_t<Props...>> {
8489

8590
private:
8691
T *m_Ptr;
87-
annotated_ref(T *Ptr) : m_Ptr(Ptr) {}
92+
explicit annotated_ref(T *Ptr) : m_Ptr(Ptr) {}
8893

8994
public:
9095
annotated_ref(const annotated_ref &) = delete;
9196

97+
// implicit conversion with annotaion
9298
operator T() const {
9399
#ifdef __SYCL_DEVICE_ONLY__
94100
return *__builtin_intel_sycl_ptr_annotation(
@@ -99,30 +105,100 @@ class annotated_ref<T, detail::properties_t<Props...>> {
99105
#endif
100106
}
101107

102-
T operator=(T Obj) const {
108+
// assignment operator with annotaion
109+
template <class O, typename = std::enable_if_t<!detail::is_ann_ref_v<O>>>
110+
T operator=(O &&Obj) const {
103111
#ifdef __SYCL_DEVICE_ONLY__
104-
*__builtin_intel_sycl_ptr_annotation(
105-
m_Ptr, detail::PropertyMetaInfo<Props>::name...,
106-
detail::PropertyMetaInfo<Props>::value...) = Obj;
112+
return *__builtin_intel_sycl_ptr_annotation(
113+
m_Ptr, detail::PropertyMetaInfo<Props>::name...,
114+
detail::PropertyMetaInfo<Props>::value...) =
115+
std::forward<O>(Obj);
107116
#else
108-
*m_Ptr = Obj;
117+
return *m_Ptr = std::forward<O>(Obj);
109118
#endif
110-
return Obj;
111119
}
112120

113-
T operator=(const annotated_ref &Ref) const { return *this = T(Ref); }
121+
template <class O, class P>
122+
T operator=(const annotated_ref<O, P> &Ref) const {
123+
O t2 = Ref;
124+
return *this = t2;
125+
}
114126

127+
// propagate compound operators
128+
#define PROPAGATE_OP(op) \
129+
template <class O, typename = std::enable_if_t<!detail::is_ann_ref_v<O>>> \
130+
T operator op(O &&rhs) const { \
131+
T t = *this; \
132+
t op std::forward<O>(rhs); \
133+
*this = t; \
134+
return t; \
135+
} \
136+
template <class O, class P> \
137+
T operator op(const annotated_ref<O, P> &rhs) const { \
138+
T t = *this; \
139+
O t2 = rhs; \
140+
t op t2; \
141+
*this = t; \
142+
return t; \
143+
}
144+
PROPAGATE_OP(+=)
145+
PROPAGATE_OP(-=)
146+
PROPAGATE_OP(*=)
147+
PROPAGATE_OP(/=)
148+
PROPAGATE_OP(%=)
149+
PROPAGATE_OP(^=)
150+
PROPAGATE_OP(&=)
151+
PROPAGATE_OP(|=)
152+
PROPAGATE_OP(<<=)
153+
PROPAGATE_OP(>>=)
154+
#undef PROPAGATE_OP
155+
156+
// propagate binary operators
157+
#define PROPAGATE_OP(op) \
158+
template <class O> \
159+
friend auto operator op(O &&a, const annotated_ref &b) \
160+
->decltype(std::forward<O>(a) op std::declval<T>()) { \
161+
return std::forward<O>(a) op T(b); \
162+
} \
163+
template <class O, typename = std::enable_if_t<!detail::is_ann_ref_v<O>>> \
164+
friend auto operator op(const annotated_ref &a, O &&b) \
165+
->decltype(std::declval<T>() op std::forward<O>(b)) { \
166+
return T(a) op std::forward<O>(b); \
167+
}
115168
PROPAGATE_OP(+)
116169
PROPAGATE_OP(-)
117170
PROPAGATE_OP(*)
118171
PROPAGATE_OP(/)
119172
PROPAGATE_OP(%)
120-
PROPAGATE_OP(^)
121-
PROPAGATE_OP(&)
122173
PROPAGATE_OP(|)
174+
PROPAGATE_OP(&)
175+
PROPAGATE_OP(^)
123176
PROPAGATE_OP(<<)
124177
PROPAGATE_OP(>>)
178+
PROPAGATE_OP(<)
179+
PROPAGATE_OP(<=)
180+
PROPAGATE_OP(>)
181+
PROPAGATE_OP(>=)
182+
PROPAGATE_OP(==)
183+
PROPAGATE_OP(!=)
184+
PROPAGATE_OP(&&)
185+
PROPAGATE_OP(||)
186+
#undef PROPAGATE_OP
125187

188+
// Propagate unary operators
189+
// by setting a default template we get SFINAE to kick in
190+
#define PROPAGATE_OP(op) \
191+
template <typename O = T> \
192+
auto operator op() const->decltype(op std::declval<O>()) { \
193+
return op O(*this); \
194+
}
195+
PROPAGATE_OP(+)
196+
PROPAGATE_OP(-)
197+
PROPAGATE_OP(!)
198+
PROPAGATE_OP(~)
199+
#undef PROPAGATE_OP
200+
201+
// Propagate inc/dec operators
126202
T operator++() const {
127203
T t = *this;
128204
++t;
@@ -156,8 +232,6 @@ class annotated_ref<T, detail::properties_t<Props...>> {
156232
template <class T2, class P2> friend class annotated_ptr;
157233
};
158234

159-
#undef PROPAGATE_OP
160-
161235
#ifdef __cpp_deduction_guides
162236
template <typename T, typename... Args>
163237
annotated_ptr(T *, Args...)

0 commit comments

Comments
 (0)