@@ -1143,7 +1143,7 @@ def __getitem__(self, item):
1143
1143
units = getattr (self , "units" , None )
1144
1144
ret = super (QuantityND , self ).__getitem__ (item )
1145
1145
if isinstance (ret , QuantityND ) or units is not None :
1146
- return QuantityND (ret , units )
1146
+ ret = QuantityND (ret , units )
1147
1147
return ret
1148
1148
1149
1149
def __array_ufunc__ (self , ufunc , method , * inputs , ** kwargs ):
@@ -1180,10 +1180,27 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
1180
1180
out_arr = QuantityND (out_arr , unit )
1181
1181
return out_arr
1182
1182
1183
+ @property
1184
+ def v (self ):
1185
+ return self .view (np .ndarray )
1186
+
1187
+
1188
+ def test_quantitynd ():
1189
+ q = QuantityND ([1 , 2 ], "m" )
1190
+ q0 , q1 = q [:]
1191
+ assert np .all (q .v == np .asarray ([1 , 2 ]))
1192
+ assert q .units == "m"
1193
+ assert np .all ((q0 + q1 ).v == np .asarray ([3 ]))
1194
+ assert (q0 * q1 ).units == "m*m"
1195
+ assert (q1 / q0 ).units == "m/(m)"
1196
+ with pytest .raises (ValueError ):
1197
+ q0 + QuantityND (1 , "s" )
1198
+
1199
+
1183
1200
def test_imshow_quantitynd ():
1184
1201
# generate a dummy ndarray subclass
1185
- arr = QuantityND (np .ones ((2 ,2 )), "m" )
1202
+ arr = QuantityND (np .ones ((2 , 2 )), "m" )
1186
1203
fig , ax = plt .subplots ()
1187
1204
ax .imshow (arr )
1188
1205
# executing the draw should not raise an exception
1189
- fig .canvas .draw ()
1206
+ fig .canvas .draw ()
0 commit comments