Skip to content

Commit 13c41a4

Browse files
committed
let broadcast scoping
1 parent 20c89b0 commit 13c41a4

28 files changed

+219
-180
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.137"
4+
version = "0.12.138"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/LoopVectorization.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@ using Static: StaticInt, gt, static, Zero, One, reduce_tup
55
using VectorizationBase,
66
SLEEFPirates, UnPack, OffsetArrays, ArrayInterfaceOffsetArrays, ArrayInterfaceStaticArrays
77
using LayoutPointers:
8-
AbstractStridedPointer, StridedPointer, StridedBitPointer, grouped_strided_pointer,
9-
stridedpointer_preserve, GroupedStridedPointers
8+
AbstractStridedPointer,
9+
StridedPointer,
10+
StridedBitPointer,
11+
grouped_strided_pointer,
12+
stridedpointer_preserve,
13+
GroupedStridedPointers
1014
import LayoutPointers
1115

1216
using SIMDTypes: NativeTypes
@@ -114,8 +118,7 @@ using Base.Meta: isexpr
114118
using DocStringExtensions
115119
import LinearAlgebra # for check_args
116120

117-
using Base:
118-
unsafe_trunc
121+
using Base: unsafe_trunc
119122

120123
using Base.FastMath:
121124
add_fast,

src/broadcast.jl

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
@generated function append_true(::Val{D},::Val{N}) where {D,N}
1+
@generated function append_true(::Val{D}, ::Val{N}) where {D,N}
22
length(D) == N && return D
33
t = Expr(:tuple)
4-
for d = D
4+
for d in D
55
push!(t.args, d)
66
end
77
for n = length(D)+1:N
@@ -12,14 +12,16 @@ end
1212
struct LowDimArray{D,T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
1313
data::A
1414
function LowDimArray{D}(data::A) where {D,T,N,A<:AbstractArray{T,N}}
15-
new{append_true(Val{D}(),Val{N}()),T,N,A}(data)
15+
new{append_true(Val{D}(), Val{N}()),T,N,A}(data)
1616
end
1717
function LowDimArray{D,T,N,A}(data::A) where {D,T,N,A<:AbstractArray{T,N}}
18-
new{append_true(Val{D}(),Val{N}()),T,N,A}(data)
18+
new{append_true(Val{D}(), Val{N}()),T,N,A}(data)
1919
end
2020
end
21-
function LowDimArray{D0}(data::LowDimArray{D1,T,N,A}) where {D0,T,N,D1,A<:AbstractArray{T,N}}
22-
LowDimArray{map(|,D0,D1),T,N,A}(parent(data))
21+
function LowDimArray{D0}(
22+
data::LowDimArray{D1,T,N,A},
23+
) where {D0,T,N,D1,A<:AbstractArray{T,N}}
24+
LowDimArray{map(|, D0, D1),T,N,A}(parent(data))
2325
end
2426
Base.@propagate_inbounds Base.getindex(
2527
A::LowDimArray,
@@ -115,7 +117,9 @@ end
115117
end
116118
Expr(:block, Expr(:meta, :inline), staticexpr(Cnew))
117119
end
118-
function ArrayInterface.contiguous_axis(::Type{LowDimArrayForBroadcast{D,T,N,A}}) where {D,T,N,A}
120+
function ArrayInterface.contiguous_axis(
121+
::Type{LowDimArrayForBroadcast{D,T,N,A}},
122+
) where {D,T,N,A}
119123
ArrayInterface.contiguous_axis(A)
120124
end
121125
@inline function ArrayInterface.stride_rank(
@@ -180,8 +184,8 @@ function _strides_expr(@nospecialize(s), @nospecialize(x), R::Vector{Int}, D::Ve
180184
use_stride_acc = true
181185
stride_acc::Int = 1
182186
if is_column_major(R)
183-
# elseif is_row_major(R)
184-
# Nrange = reverse(Nrange)
187+
# elseif is_row_major(R)
188+
# Nrange = reverse(Nrange)
185189
else # not worth my time optimizing this case at the moment...
186190
# will write something generic stride-rank agnostic eventually
187191
use_stride_acc = false
@@ -323,14 +327,8 @@ function add_broadcast!(
323327
mA = gensym!(ls, "Aₘₖ")
324328
mB = gensym!(ls, "Bₖₙ")
325329
gf = GlobalRef(Core, :getfield)
326-
pushprepreamble!(
327-
ls,
328-
Expr(:(=), mA, Expr(:(.), bcname, QuoteNode(:a))),
329-
)
330-
pushprepreamble!(
331-
ls,
332-
Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b))),
333-
)
330+
pushprepreamble!(ls, Expr(:(=), mA, Expr(:(.), bcname, QuoteNode(:a))))
331+
pushprepreamble!(ls, Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b))))
334332
pushprepreamble!(ls, Expr(:(=), Klen, Expr(:call, gf, Expr(:call, :size, mB), 1, false)))
335333
pushpreamble!(ls, Expr(:(=), Krange, Expr(:call, :(:), staticexpr(1), Klen)))
336334
k = gensym!(ls, "k")
@@ -432,7 +430,7 @@ function add_broadcast!(
432430
pushprepreamble!(ls, Expr(:(=), bcname2, Expr(:call, forbroadcast, lda)))
433431
ArrayReference(bcname2, fulldims)
434432
end
435-
433+
436434
loadop = add_simple_load!(ls, destname, ref, elementbytes, true)::Operation
437435
doaddref!(ls, loadop)
438436
end
@@ -486,10 +484,7 @@ function add_broadcast!(
486484
gf = GlobalRef(Core, :getfield)
487485
for (i, arg) enumerate(args)
488486
argname = gensym!(ls, "arg")
489-
pushprepreamble!(
490-
ls,
491-
Expr(:(=), argname, Expr(:call, gf, bcargs, i, false)),
492-
)
487+
pushprepreamble!(ls, Expr(:(=), argname, Expr(:call, gf, bcargs, i, false)))
493488
# dynamic dispatch
494489
parent = add_broadcast!(
495490
ls,
@@ -542,7 +537,7 @@ end
542537
bc::BC,
543538
::Val{Mod},
544539
::Val{UNROLL},
545-
::Val{dontbc}
540+
::Val{dontbc},
546541
) where {T<:NativeTypes,N,BC<:Union{Broadcasted,Product},Mod,UNROLL,dontbc}
547542
# 2 + 1
548543
# we have an N dimensional loop.
@@ -580,8 +575,16 @@ end
580575
bc::BC,
581576
::Val{Mod},
582577
::Val{UNROLL},
583-
::Val{dontbc}
584-
) where {T<:NativeTypes,N,A<:AbstractArray{T,N},BC<:Union{Broadcasted,Product},Mod,UNROLL,dontbc}
578+
::Val{dontbc},
579+
) where {
580+
T<:NativeTypes,
581+
N,
582+
A<:AbstractArray{T,N},
583+
BC<:Union{Broadcasted,Product},
584+
Mod,
585+
UNROLL,
586+
dontbc,
587+
}
585588
# we have an N dimensional loop.
586589
# need to construct the LoopSet
587590
ls = LoopSet(Mod)
@@ -626,14 +629,14 @@ end
626629
bc::Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(identity),Tuple{T2}},
627630
::Val{Mod},
628631
::Val{UNROLL},
629-
::Val{dontbc}
632+
::Val{dontbc},
630633
) where {T<:NativeTypes,N,T2<:Number,Mod,UNROLL,dontbc}
631634
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads, warncheckarg, safe = UNROLL
632635
quote
633636
$(Expr(:meta, :inline))
634637
arg = T(first(bc.args))
635638
@turbo inline = $inline unroll = ($u₁, $u₂) thread = $threads vectorize = $v for i
636-
eachindex(
639+
eachindex(
637640
dest,
638641
)
639642
dest[i] = arg
@@ -646,7 +649,7 @@ end
646649
bc::Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(identity),Tuple{T2}},
647650
::Val{Mod},
648651
::Val{UNROLL},
649-
::Val{dontbc}
652+
::Val{dontbc},
650653
) where {T<:NativeTypes,N,A<:AbstractArray{T,N},T2<:Number,Mod,UNROLL,dontbc}
651654
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads, warncheckarg, safe = UNROLL
652655
quote
@@ -680,4 +683,4 @@ end
680683
end
681684

682685
# vmaterialize!(dest, bc, ::Val, ::Val, ::StaticInt, ::StaticInt, ::StaticInt) =
683-
# Base.Broadcast.materialize!(dest, bc)
686+
# Base.Broadcast.materialize!(dest, bc)

src/codegen/lower_threads.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,10 @@ end
157157
scale_cost(c) = @fastmath c * (Sys.ARCH === :x86_64 ? 0.0225 : 0.005625)
158158
scale_cost(c, looplen) = scale_cost(@fastmath c / looplen)
159159
@inline function choose_num_threads(
160-
C::T,
161-
NT::UInt,
162-
x::Base.BitInteger,
163-
) where {T<:Union{Float32,Float64}}
160+
C::T,
161+
NT::UInt,
162+
x::Base.BitInteger,
163+
) where {T<:Union{Float32,Float64}}
164164
_choose_num_threads(scale_cost(T(C)), NT, x)
165165
end
166166
@inline function _choose_num_threads(
@@ -529,12 +529,8 @@ function thread_one_loops_expr(
529529
end
530530
else# eliminate undef var errors that the compiler should be able to figure out are unreachable, but doesn't
531531
var"#torelease#tuple#" = (zero(UInt),)
532-
var"#threads#tuple#" = (
533-
PolyesterWeave.UnsignedIteratorEarlyStop(
534-
zero(UInt),
535-
0x00000000,
536-
),
537-
)
532+
var"#threads#tuple#" =
533+
(PolyesterWeave.UnsignedIteratorEarlyStop(zero(UInt), 0x00000000),)
538534
end
539535
end
540536
var"#avx#call#args#" = $avxcall_args
@@ -808,10 +804,8 @@ function thread_two_loops_expr(
808804
end
809805
else# eliminate undef var errors that the compiler should be able to figure out are unreachable, but doesn't
810806
var"#torelease#tuple#" = (zero(UInt),)
811-
var"#threads#tuple#" = PolyesterWeave.UnsignedIteratorEarlyStop(
812-
zero(UInt),
813-
0x00000000,
814-
)
807+
var"#threads#tuple#" =
808+
PolyesterWeave.UnsignedIteratorEarlyStop(zero(UInt), 0x00000000)
815809
end
816810
end
817811
# @show $lastboundexpr

src/codegen/lowering.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,11 @@ end
543543
end
544544
end
545545
@inline function of_same_size(::Type{T}, ::Type{S}) where {T,S}
546-
of_same_size(T, S, VectorizationBase.register_size() ÷ VectorizationBase.simd_integer_register_size())
546+
of_same_size(
547+
T,
548+
S,
549+
VectorizationBase.register_size() ÷ VectorizationBase.simd_integer_register_size(),
550+
)
547551
end
548552
function outer_reduction_zero(
549553
op::Operation,

src/condense_loopset.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@ function parents_uint(op::Operation)
187187
N = length(opv)
188188
@assert N 32
189189
p0 = parents_uint(view(opv, 1:min(8, N)))
190-
p1 = N > 8 ? parents_uint(view(opv, 9:min(16,N))) : zero(p0)
191-
p2 = N > 16 ? parents_uint(view(opv, 17:min(24,N))) : zero(p0)
190+
p1 = N > 8 ? parents_uint(view(opv, 9:min(16, N))) : zero(p0)
191+
p2 = N > 16 ? parents_uint(view(opv, 17:min(24, N))) : zero(p0)
192192
p3 = N > 24 ? parents_uint(view(opv, 25:N)) : zero(p0)
193193
p0, p1, p2, p3
194194
end
@@ -361,8 +361,12 @@ end
361361
val(x) = Expr(:call, Expr(:curly, :Val, x))
362362

363363
@inline gespf1(x, i) = gesp(x, i)
364-
@inline gespf1(x::StridedPointer{T,1}, i::Tuple{I}) where {T,I<:Union{Integer,StaticInt}} = gesp(x, i)
365-
@inline gespf1(x::StridedBitPointer{T,1}, i::Tuple{I}) where {T,I<:Union{Integer,StaticInt}} = gesp(x, i)
364+
@inline gespf1(x::StridedPointer{T,1}, i::Tuple{I}) where {T,I<:Union{Integer,StaticInt}} =
365+
gesp(x, i)
366+
@inline gespf1(
367+
x::StridedBitPointer{T,1},
368+
i::Tuple{I},
369+
) where {T,I<:Union{Integer,StaticInt}} = gesp(x, i)
366370
@inline gespf1(x::StridedPointer{T,1}, i::Tuple{Zero}) where {T} = x
367371
@inline gespf1(x::StridedBitPointer{T,1}, i::Tuple{Zero}) where {T} = x
368372
@generated function gespf1(
@@ -571,7 +575,7 @@ end
571575
register_size(),
572576
available_registers(),
573577
lv_max_num_threads(),
574-
cache_linesize()
578+
cache_linesize(),
575579
)
576580
end
577581
function find_samename_constparent(op::Operation, opname::Symbol)
@@ -924,7 +928,7 @@ can_turbo(::typeof(Base.literal_pow), ::Val{3}) = true
924928
can_turbo(::typeof(Base.FastMath.pow_fast), ::Val{2}) = true
925929

926930
for f (convert, reinterpret, trunc, unsafe_trunc, round, ceil, floor)
927-
@eval can_turbo(::typeof($f), ::Val{2}) = true
931+
@eval can_turbo(::typeof($f), ::Val{2}) = true
928932
end
929933

930934
"""

src/constructors.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function substitute_broadcast(
5656
)
5757
ci = first(Meta.lower(LoopVectorization, q).args).code
5858
nargs = length(ci) - 1
59-
ex = Expr(:block)
59+
lb = Expr(:block)
6060
syms = Vector{Symbol}(undef, nargs)
6161
configarg = (inline, u₁, u₂, v, true, threads, warncheckarg, safe)
6262
unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), staticexpr(0))
@@ -70,17 +70,21 @@ function substitute_broadcast(
7070
ciₙargs = ciₙ.args
7171
f = first(ciₙargs)
7272
if ciₙ.head === :(=)
73-
push!(ex.args, Expr(:(=), f, syms[((ciₙargs[2])::Core.SSAValue).id]))
73+
push!(lb.args, Expr(:(=), f, syms[((ciₙargs[2])::Core.SSAValue).id]))
7474
elseif isglobalref(f, Base, :materialize!)
75-
add_ci_call!(ex, lv(:vmaterialize!), ciₙargs, syms, n, unroll_param_tup, mod)
75+
add_ci_call!(lb, lv(:vmaterialize!), ciₙargs, syms, n, unroll_param_tup, mod)
7676
elseif isglobalref(f, Base, :materialize)
77-
add_ci_call!(ex, lv(:vmaterialize), ciₙargs, syms, n, unroll_param_tup, mod)
77+
add_ci_call!(lb, lv(:vmaterialize), ciₙargs, syms, n, unroll_param_tup, mod)
7878
else
79-
add_ci_call!(ex, f, ciₙargs, syms, n)
79+
add_ci_call!(lb, f, ciₙargs, syms, n)
8080
end
8181
end
8282
end
83-
esc(ex)
83+
ret::Expr = pop!(lb.args)
84+
if Meta.isexpr(ret, :(=), 2)
85+
ret = ret.args[2]
86+
end
87+
esc(Expr(:let, lb, Expr(:block, ret)))
8488
end
8589

8690

src/getconstindexes.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ const EXTRACTFUNS = (
3737
:thirtysixth,
3838
:thirtyseventh,
3939
:thirtyeighth,
40-
:last)
40+
:last,
41+
)
4142

4243
for (i, f) enumerate(EXTRACTFUNS)
4344
(i == 1 || i == length(EXTRACTFUNS)) && continue

src/modeling/costs.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ const COST = Dict{Symbol,InstructionCost}(
266266
:significand => InstructionCost(8, 1.0),
267267
)
268268

269-
for f = EXTRACTFUNS
269+
for f in EXTRACTFUNS
270270
COST[f] = InstructionCost(0, 0.0, 0.0, 0)
271271
end
272272

@@ -351,7 +351,8 @@ end
351351
VectorizationBase.fmap(ier, VectorizationBase.data(a), VectorizationBase.data(b)),
352352
)
353353

354-
@inline (iec::IfElseCollapser)(a) = VectorizationBase.contract(IfElseOp(iec.f), a, StaticInt{1}())
354+
@inline (iec::IfElseCollapser)(a) =
355+
VectorizationBase.contract(IfElseOp(iec.f), a, StaticInt{1}())
355356
@inline (iec::IfElseCollapser)(a, ::StaticInt{C}) where {C} =
356357
VectorizationBase.contract(IfElseOp(iec.f), a, StaticInt{C}())
357358

src/modeling/graphs.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -509,12 +509,7 @@ end
509509
available_registers() =
510510
ifelse(has_opmask_registers(), register_count(), register_count() - One())
511511
function set_hw!(ls::LoopSet)
512-
set_hw!(
513-
ls,
514-
Int(register_size()),
515-
Int(available_registers()),
516-
Int(cache_linesize())
517-
)
512+
set_hw!(ls, Int(register_size()), Int(available_registers()), Int(cache_linesize()))
518513
end
519514
reg_size(ls::LoopSet) = ls.register_size
520515
reg_count(ls::LoopSet) = ls.register_count
@@ -924,7 +919,7 @@ function add_block!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
924919
end
925920
function makestatic!(expr)
926921
expr isa Expr || return expr
927-
for i = eachindex(expr.args)
922+
for i in eachindex(expr.args)
928923
ex = expr.args[i]
929924
if ex isa Int
930925
expr.args[i] = staticexpr(ex)
@@ -1481,7 +1476,7 @@ function add_operation!(
14811476
add_comparison!(ls, LHS_sym, RHS, elementbytes, position)
14821477
else
14831478
throw(LoopError("Expression not recognized.", RHS))
1484-
end
1479+
end
14851480
end
14861481

14871482
function prepare_rhs_for_storage!(

src/parse/add_compute.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ function add_compute!(
466466
arg1 = args[1]
467467
arg2 = args[2]
468468
if arg1 isa Number && convert(Float64, arg1) === -1.0
469-
return add_compute!(ls, var, :(2iseven($arg2)-1), elementbytes, position, mpref)
469+
return add_compute!(ls, var, :(2iseven($arg2) - 1), elementbytes, position, mpref)
470470
end
471471
if arg2 isa Number
472472
return add_pow!(ls, var, args[1], arg2, elementbytes, position)

0 commit comments

Comments
 (0)