@@ -22,8 +22,8 @@ Base.mapreduce(::typeof(identity), ::typeof(Base.mul_prod), maps::AbstractVector
22
22
MulStyle (A:: CompositeMap ) = MulStyle (A. maps... ) === TwoArg () ? TwoArg () : ThreeArg ()
23
23
24
24
# basic methods
25
- Base. size (A:: CompositeMap ) = (size (A. maps[ end ] , 1 ), size (A. maps[ 1 ] , 2 ))
26
- Base. axes (A:: CompositeMap ) = (axes (A. maps[ end ]) [1 ], axes (A. maps[ 1 ] )[2 ])
25
+ Base. size (A:: CompositeMap ) = (size (last ( A. maps) , 1 ), size (first ( A. maps) , 2 ))
26
+ Base. axes (A:: CompositeMap ) = (axes (last ( A. maps)) [1 ], axes (first ( A. maps) )[2 ])
27
27
Base. isreal (A:: CompositeMap ) = all (isreal, A. maps) # sufficient but not necessary
28
28
29
29
# the following rules are sufficient but not necessary
@@ -32,17 +32,17 @@ for (f, _f, g) in ((:issymmetric, :_issymmetric, :transpose),
32
32
@eval begin
33
33
LinearAlgebra.$ f (A:: CompositeMap ) = $ _f (A. maps)
34
34
$ _f (maps:: Tuple{} ) = true
35
- $ _f (maps:: Tuple{<:LinearMap} ) = $ f (maps[ 1 ] )
35
+ $ _f (maps:: Tuple{<:LinearMap} ) = $ f (first ( maps) )
36
36
$ _f (maps:: LinearMapTuple ) =
37
- maps[end ] == $ g (maps[ 1 ] ) && $ _f (Base. front (Base. tail (maps)))
37
+ maps[end ] == $ g (first ( maps) ) && $ _f (Base. front (Base. tail (maps)))
38
38
function $_f (maps:: LinearMapVector )
39
39
n = length (maps)
40
40
if n == 0
41
41
return true
42
42
elseif n == 1
43
- return ($ f (maps[ 1 ] )):: Bool
43
+ return ($ f (first ( maps) )):: Bool
44
44
else
45
- return ((maps[ end ] == $ g (maps[ 1 ])) :: Bool && $ _f (@views maps[2 : end - 1 ]))
45
+ return ((last ( maps) == $ g (first ( maps))) :: Bool && $ _f (@views maps[begin + 1 : end - 1 ]))
46
46
end
47
47
end
48
48
# since the introduction of ScaledMap, the following cases cannot occur
@@ -71,10 +71,10 @@ function _isposdef(maps::LinearMapVector)
71
71
if n == 0
72
72
return true
73
73
elseif n == 1
74
- return isposdef (maps[ 1 ] )
74
+ return isposdef (first ( maps) )
75
75
else
76
- return (maps[ end ] == adjoint (maps[ 1 ]) || maps[ end ] == maps[ 1 ] ) &&
77
- isposdef (maps[ 1 ]) && _isposdef (maps[2 : end - 1 ])
76
+ return (last ( maps) == adjoint (first ( maps)) || last ( maps) == first ( maps) ) &&
77
+ isposdef (first ( maps)) && _isposdef (maps[begin + 1 : end - 1 ])
78
78
end
79
79
end
80
80
@@ -180,12 +180,11 @@ function _unsafe_mul!(y, A::CompositeMap, x::AbstractVector)
180
180
return y
181
181
end
182
182
_unsafe_mul! (y, A:: CompositeMap , x:: AbstractMatrix ) = _compositemul! (y, A, x)
183
+ _unsafe_mul! (y, A:: CompositeMap{<:Any,<:Tuple{LinearMap}} , x:: AbstractVector ) =
184
+ _unsafe_mul! (y, A. maps[1 ], x)
185
+ _unsafe_mul! (y, A:: CompositeMap{<:Any,<:Tuple{LinearMap}} , X:: AbstractMatrix ) =
186
+ _unsafe_mul! (y, A. maps[1 ], X)
183
187
184
- function _compositemul! (y, A:: CompositeMap{<:Any,<:Tuple{LinearMap}} , x,
185
- source = nothing ,
186
- dest = nothing )
187
- return _unsafe_mul! (y, A. maps[1 ], x)
188
- end
189
188
function _compositemul! (y, A:: CompositeMap{<:Any,<:Tuple{LinearMap,LinearMap}} , x,
190
189
source = nothing ,
191
190
dest = nothing )
@@ -206,9 +205,9 @@ function _compositemul!(y, A::CompositeMap{<:Any,<:LinearMapVector}, x,
206
205
dest = nothing )
207
206
N = length (A. maps)
208
207
if N == 1
209
- return _unsafe_mul! (y, A. maps[1 ], x)
208
+ return _unsafe_mul! (y, A. maps[begin ], x)
210
209
elseif N == 2
211
- return _unsafe_mul! (y, A. maps[2 ] * A. maps[1 ], x)
210
+ return _unsafe_mul! (y, A. maps[end ] * A. maps[begin ], x)
212
211
else
213
212
return _compositemulN! (y, A, x, source, dest)
214
213
end
@@ -218,19 +217,21 @@ function _compositemulN!(y, A::CompositeMap, x,
218
217
src = nothing ,
219
218
dst = nothing )
220
219
N = length (A. maps) # ≥ 3
220
+ n = n0 = firstindex (A. maps)
221
221
source = isnothing (src) ?
222
- convert (AbstractArray, A. maps[1 ] * x) :
223
- _unsafe_mul! (src, A. maps[1 ], x)
222
+ convert (AbstractArray, A. maps[n] * x) :
223
+ _unsafe_mul! (src, A. maps[n], x)
224
+ n += 1
224
225
dest = isnothing (dst) ?
225
- convert (AbstractArray, A. maps[2 ] * source) :
226
- _unsafe_mul! (dst, A. maps[2 ], source)
226
+ convert (AbstractArray, A. maps[n ] * source) :
227
+ _unsafe_mul! (dst, A. maps[n ], source)
227
228
dest, source = source, dest # alternate dest and source
228
- for n in 3 : N- 1
229
+ for n in (n0 + 2 ) : N- 1
229
230
dest = _resize (dest, (size (A. maps[n], 1 ), size (x)[2 : end ]. .. ))
230
231
_unsafe_mul! (dest, A. maps[n], source)
231
232
dest, source = source, dest # alternate dest and source
232
233
end
233
- _unsafe_mul! (y, A. maps[N] , source)
234
+ _unsafe_mul! (y, last ( A. maps) , source)
234
235
return y
235
236
end
236
237
0 commit comments