Skip to content

Commit 18bb772

Browse files
committed
Fix broken precompile statements, add kwargs to matmul_params for loop lengths and vector width
1 parent 7c2b4b1 commit 18bb772

File tree

5 files changed

+38
-11
lines changed

5 files changed

+38
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.83"
4+
version = "0.12.84"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/parse/memory_ops_common.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function subset_vptr!(ls::LoopSet, vptr::Symbol, indnum::Int, ind, previndices,
9393
subsetvptr
9494
end
9595

96-
function gesp_const_offset!(ls::LoopSet, vptrarray, ninds, indices, loopedindex, mlt::Integer, sym, D)
96+
function gesp_const_offset!(ls::LoopSet, vptrarray::Symbol, ninds::Int, indices::Vector{Symbol}, loopedindex::Vector{Bool}, mlt::Integer, sym, D::Int)
9797
if isone(mlt)
9898
subset_vptr!(ls, vptrarray, ninds, sym, indices, loopedindex, D)
9999
else
@@ -102,7 +102,7 @@ function gesp_const_offset!(ls::LoopSet, vptrarray, ninds, indices, loopedindex,
102102
subset_vptr!(ls, vptrarray, ninds, mltsym, indices, loopedindex, D)
103103
end
104104
end
105-
function gesp_const_offsets!(ls::LoopSet, vptrarray, ninds, indices, loopedindex, mltsyms, D)
105+
function gesp_const_offsets!(ls::LoopSet, vptrarray::Symbol, ninds::Int, indices::Vector{Symbol}, loopedindex::Vector{Bool}, mltsyms::Vector{Tuple{Int,Symbol}}, D::Int)
106106
length(mltsyms) > 1 && sort!(mltsyms, by = last) # if multiple have same combination of syms, make sure they match even if order is different
107107
for (mlt,sym) mltsyms
108108
vptrarray = gesp_const_offset!(ls, vptrarray, ninds, indices, loopedindex, mlt, sym, D)

src/precompile.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function _precompile_()
1515
Base.precompile(Tuple{typeof(substitute_broadcast),Expr,Symbol,Bool,Int8,Int8,Int8,Int,Int}) # time: 0.02281322
1616
Base.precompile(Tuple{typeof(push!),LoopSet,Expr,Int,Int}) # time: 0.022659862
1717
Base.precompile(Tuple{typeof(add_compute!),LoopSet,Symbol,Expr,Int,Int,Nothing}) # time: 0.02167476
18-
Base.precompile(Tuple{typeof(checkforoffset!),LoopSet,Symbol,Int,Vector{Operation},Vector{Symbol},Vector{Int8},Vector{Int8},Vector{Bool},Vector{Symbol},Vector{Symbol},Expr}) # time: 0.020454278
18+
Base.precompile(Tuple{typeof(checkforoffset!),LoopSet,Symbol,Int,Vector{Operation},Vector{Symbol},Vector{Int8},Vector{Int8},Vector{Bool},Vector{Symbol},Vector{Symbol},Expr,Int}) # time: 0.020454278
1919
Base.precompile(Tuple{typeof(generate_call),LoopSet,Tuple{Bool, Int8, Int8, Int8},UInt,Bool}) # time: 0.020274462
2020
Base.precompile(Tuple{typeof(expandbyoffset!),Vector{Tuple{Int, Tuple{Int, Int32, Bool}}},Vector{Any},Vector{Int}}) # time: 0.019860294
2121
Base.precompile(Tuple{typeof(isscopedname),Symbol,Symbol,Symbol}) # time: 0.016642524
@@ -43,19 +43,19 @@ function _precompile_()
4343
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float32, 4, 1, 0, (1, 2, 3, 4), Tuple{StaticInt{4}, Int, Int, Int}, NTuple{4, StaticInt{1}}},Tuple{StaticInt{-1}, StaticInt{-1}, StaticInt{1}, StaticInt{1}}}) # time: 0.006164707
4444
Base.precompile(Tuple{typeof(add_ci_call!),Expr,Any,Vector{Any},Vector{Symbol},Int,Expr,Symbol}) # time: 0.006148137
4545
Base.precompile(Tuple{typeof(add_ci_call!),Expr,Any,Vector{Any},Vector{Symbol},Int}) # time: 0.006063301
46-
Base.precompile(Tuple{typeof(mem_offset),Operation,UnrollArgs,Vector{Bool},Bool,LoopSet}) # time: 0.005945972
46+
Base.precompile(Tuple{typeof(mem_offset),Operation,UnrollArgs,Vector{Bool},Bool,LoopSet,Bool}) # time: 0.005945972
4747
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 3, 1, 0, (1, 2), Tuple{StaticInt{8}, StaticInt{16}, Int}, Tuple{StaticInt{1}, StaticInt{1}, StaticInt{1}}},Tuple{StaticInt{1}, StaticInt{1}, StaticInt{0}}}) # time: 0.005927015
4848
Base.precompile(Tuple{typeof(sizeofeltypes),Core.SimpleVector}) # time: 0.005828176
4949
Base.precompile(Tuple{typeof(cse_constant_offsets!),LoopSet,Vector{ArrayReferenceMeta},Int,Vector{Vector{Int}},Vector{Vector{Tuple{Int, Int, Int}}}}) # time: 0.005694307
5050
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 4, 1, 0, (1, 2, 3, 4), Tuple{StaticInt{8}, StaticInt{16}, Int, Int}, NTuple{4, StaticInt{1}}},Tuple{StaticInt{1}, VectorizationBase.NullStep, StaticInt{2}, VectorizationBase.NullStep}}) # time: 0.005314204
5151
Base.precompile(Tuple{typeof(indices_loop!),LoopSet,Expr,Symbol}) # time: 0.005283243
5252
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 5, 1, 0, (1, 2, 3, 4, 5), Tuple{StaticInt{8}, Int, Int, Int, Int}, NTuple{5, StaticInt{1}}},Tuple{VectorizationBase.CartesianVIndex{0, Tuple{}}, VectorizationBase.NullStep, VectorizationBase.CartesianVIndex{4, NTuple{4, StaticInt{1}}}}}) # time: 0.005256126
53-
Base.precompile(Tuple{typeof(gesp_const_offsets!),LoopSet,Symbol,Int,Vector{Symbol},Vector{Bool},Vector{Tuple{Int, Symbol}}}) # time: 0.005168524
53+
Base.precompile(Tuple{typeof(gesp_const_offsets!),LoopSet,Symbol,Int,Vector{Symbol},Vector{Bool},Vector{Tuple{Int, Symbol}},Int}) # time: 0.005168524
5454
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 5, 1, 0, (1, 2, 3, 4, 5), Tuple{StaticInt{8}, Int, Int, Int, Int}, NTuple{5, StaticInt{1}}},Tuple{VectorizationBase.CartesianVIndex{4, NTuple{4, StaticInt{1}}}, VectorizationBase.NullStep, VectorizationBase.CartesianVIndex{0, Tuple{}}}}) # time: 0.005122315
5555
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 5, 1, 0, (1, 2, 3, 4, 5), Tuple{StaticInt{8}, Int, Int, Int, Int}, NTuple{5, StaticInt{1}}},Tuple{VectorizationBase.CartesianVIndex{2, Tuple{StaticInt{1}, StaticInt{1}}}, VectorizationBase.NullStep, VectorizationBase.CartesianVIndex{2, Tuple{StaticInt{1}, StaticInt{1}}}}}) # time: 0.005078802
5656
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float32, 2, 1, 0, (1, 2), Tuple{StaticInt{4}, Int}, Tuple{StaticInt{1}, StaticInt{1}}},Tuple{StaticInt{0}, StaticInt{0}}}) # time: 0.005036135
5757
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 4, 2, 0, (3, 1, 4, 2), Tuple{Int, StaticInt{8}, Int, Int}, NTuple{4, StaticInt{1}}},Tuple{VectorizationBase.CartesianVIndex{4, NTuple{4, StaticInt{1}}}}}) # time: 0.004968671
58-
Base.precompile(Tuple{typeof(subset_vptr!),LoopSet,Symbol,Int,Symbol,Vector{Symbol},Vector{Bool},Bool}) # time: 0.004904486
58+
Base.precompile(Tuple{typeof(subset_vptr!),LoopSet,Symbol,Int,Symbol,Vector{Symbol},Vector{Bool},Int}) # time: 0.004904486
5959
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 5, 1, 0, (1, 2, 3, 4, 5), Tuple{StaticInt{8}, Int, Int, Int, Int}, NTuple{5, StaticInt{1}}},Tuple{VectorizationBase.CartesianVIndex{3, Tuple{StaticInt{1}, StaticInt{1}, StaticInt{1}}}, VectorizationBase.NullStep, VectorizationBase.CartesianVIndex{1, Tuple{StaticInt{1}}}}}) # time: 0.004722758
6060
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 5, 1, 0, (1, 2, 3, 4, 5), Tuple{StaticInt{8}, Int, Int, Int, Int}, NTuple{5, StaticInt{1}}},Tuple{VectorizationBase.CartesianVIndex{0, Tuple{}}, Int, VectorizationBase.CartesianVIndex{4, NTuple{4, StaticInt{1}}}}}) # time: 0.004705647
6161
Base.precompile(Tuple{typeof(gespf1),StridedPointer{Float64, 5, 1, 0, (1, 2, 3, 4, 5), Tuple{StaticInt{8}, Int, Int, Int, Int}, NTuple{5, StaticInt{1}}},Tuple{VectorizationBase.CartesianVIndex{1, Tuple{StaticInt{1}}}, VectorizationBase.NullStep, VectorizationBase.CartesianVIndex{3, Tuple{StaticInt{1}, StaticInt{1}, StaticInt{1}}}}}) # time: 0.00464261

src/user_api_conveniences.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,31 @@ const GEMMLOOPSET = loopset(
1212
);
1313

1414

15-
function matmul_params(rs::Int, rc::Int, cls::Int)
16-
set_hw!(GEMMLOOPSET, rs, rc, cls, Int(cache_size(StaticInt(1))), Int(cache_size(StaticInt(2))), Int(cache_size(StaticInt(3))))
17-
order = choose_order(GEMMLOOPSET)
18-
order[5], last(order)
15+
# function matmul_params(rs::Int, rc::Int, cls::Int)
16+
# set_hw!(GEMMLOOPSET, rs, rc, cls, Int(cache_size(StaticInt(1))), Int(cache_size(StaticInt(2))), Int(cache_size(StaticInt(3))))
17+
# order = choose_order(GEMMLOOPSET)
18+
# order[5], last(order)
19+
# end
20+
function matmul_params(rs::Int, rc::Int, cls::Int; M = nothing, K = nothing, N = nothing, W = 0)
21+
set_hw!(GEMMLOOPSET, rs, rc, cls, Int(cache_size(StaticInt(1))), Int(cache_size(StaticInt(2))), Int(cache_size(StaticInt(3))))
22+
if N nothing
23+
nloop = GEMMLOOPSET.loops[1]
24+
GEMMLOOPSET.loops[1] = Loop(:n, MaybeKnown(1), MaybeKnown(N), MaybeKnown(1), nloop.rangesym, nloop.lensym)
25+
end
26+
if M nothing
27+
mloop = GEMMLOOPSET.loops[2]
28+
GEMMLOOPSET.loops[2] = Loop(:m, MaybeKnown(1), MaybeKnown(M), MaybeKnown(1), mloop.rangesym, mloop.lensym)
29+
end
30+
if K nothing
31+
kloop = GEMMLOOPSET.loops[3]
32+
GEMMLOOPSET.loops[3] = Loop(:k, MaybeKnown(1), MaybeKnown(K), MaybeKnown(1), kloop.rangesym, kloop.lensym)
33+
end
34+
GEMMLOOPSET.vector_width = W
35+
order = choose_order(GEMMLOOPSET)
36+
(N nothing) && (GEMMLOOPSET.loops[1] = nloop)
37+
(M nothing) && (GEMMLOOPSET.loops[2] = mloop)
38+
(K nothing) && (GEMMLOOPSET.loops[3] = kloop)
39+
order[5], last(order)
1940
end
2041
@generated function matmul_params(::StaticInt{RS}, ::StaticInt{RC}, ::StaticInt{CLS}) where {RS,RC,CLS}
2142
mᵣ, nᵣ = matmul_params(RS, RC, CLS)

test/gemm.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
if LoopVectorization.register_count() != 8
1010
@test @inferred(LoopVectorization.matmul_params()) == (Unum, Tnum)
1111
end
12+
13+
@test LoopVectorization.matmul_params(64, 32, 64; M=8, K=100, N=100, W=8) == (1, 25)
14+
@test LoopVectorization.matmul_params(64, 32, 64; M=8, K=100, N= 96, W=8) == (1, 24)
15+
@test LoopVectorization.matmul_params(64, 32, 64; M=8, K=100, N= 92, W=8) == (1, 23)
16+
@test LoopVectorization.matmul_params(64, 32, 64; M=8, K=100, N= 95, W=8) == (1, 10)
17+
1218
AmulBtq1 = :(for m axes(A,1), n axes(B,2)
1319
C[m,n] = zeroB
1420
for k axes(A,2)

0 commit comments

Comments
 (0)