Skip to content

Commit a0fbb18

Browse files
authored
experimental: configure custom memory allocator (#2177)
Signed-off-by: Nuno Cruces <[email protected]>
1 parent 891e470 commit a0fbb18

File tree

9 files changed

+153
-35
lines changed

9 files changed

+153
-35
lines changed

experimental/memory.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package experimental
2+
3+
import (
4+
"context"
5+
6+
"github.com/tetratelabs/wazero/internal/ctxkey"
7+
)
8+
9+
// MemoryAllocator is a memory allocation hook which is invoked
10+
// to create a new MemoryBuffer, with the given specification:
11+
// min is the initial and minimum length (in bytes) of the backing []byte,
12+
// cap a suggested initial capacity, and max the maximum length
13+
// that will ever be requested.
14+
type MemoryAllocator func(min, cap, max uint64) MemoryBuffer
15+
16+
// MemoryBuffer is a memory buffer that backs a Wasm memory.
17+
type MemoryBuffer interface {
18+
// Return the backing []byte for the memory buffer.
19+
Buffer() []byte
20+
// Grow the backing memory buffer to size bytes in length.
21+
// To back a shared memory, Grow can't change the address
22+
// of the backing []byte (only its length/capacity may change).
23+
Grow(size uint64) []byte
24+
// Free the backing memory buffer.
25+
Free()
26+
}
27+
28+
// WithMemoryAllocator registers the given MemoryAllocator into the given
29+
// context.Context.
30+
func WithMemoryAllocator(ctx context.Context, allocator MemoryAllocator) context.Context {
31+
if allocator != nil {
32+
return context.WithValue(ctx, ctxkey.MemoryAllocatorKey{}, allocator)
33+
}
34+
return ctx
35+
}

imports/assemblyscript/assemblyscript_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ import (
1616
"github.com/tetratelabs/wazero/api"
1717
. "github.com/tetratelabs/wazero/experimental"
1818
"github.com/tetratelabs/wazero/experimental/logging"
19+
"github.com/tetratelabs/wazero/experimental/wazerotest"
1920
. "github.com/tetratelabs/wazero/internal/assemblyscript"
2021
"github.com/tetratelabs/wazero/internal/testing/proxy"
2122
"github.com/tetratelabs/wazero/internal/testing/require"
2223
"github.com/tetratelabs/wazero/internal/u64"
23-
"github.com/tetratelabs/wazero/internal/wasm"
2424
"github.com/tetratelabs/wazero/sys"
2525
)
2626

@@ -376,7 +376,7 @@ func Test_readAssemblyScriptString(t *testing.T) {
376376
tc := tt
377377

378378
t.Run(tc.name, func(t *testing.T) {
379-
mem := wasm.NewMemoryInstance(&wasm.Memory{Min: 1, Cap: 1, Max: 1})
379+
mem := wazerotest.NewFixedMemory(wazerotest.PageSize)
380380
tc.memory(mem)
381381

382382
s, ok := readAssemblyScriptString(mem, uint32(tc.offset))

internal/ctxkey/memory.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package ctxkey
2+
3+
type MemoryAllocatorKey struct{}

internal/wasm/memory.go

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"unsafe"
1313

1414
"github.com/tetratelabs/wazero/api"
15+
"github.com/tetratelabs/wazero/experimental"
1516
"github.com/tetratelabs/wazero/internal/internalapi"
1617
"github.com/tetratelabs/wazero/internal/wasmruntime"
1718
)
@@ -57,12 +58,22 @@ type MemoryInstance struct {
5758
// waiters implements atomic wait and notify. It is implemented similarly to golang.org/x/sync/semaphore,
5859
// with a fixed weight of 1 and no spurious notifications.
5960
waiters sync.Map
61+
62+
expBuffer experimental.MemoryBuffer
6063
}
6164

6265
// NewMemoryInstance creates a new instance based on the parameters in the SectionIDMemory.
63-
func NewMemoryInstance(memSec *Memory) *MemoryInstance {
64-
var size uint64
65-
if memSec.IsShared {
66+
func NewMemoryInstance(memSec *Memory, allocator experimental.MemoryAllocator) *MemoryInstance {
67+
minBytes := MemoryPagesToBytesNum(memSec.Min)
68+
capBytes := MemoryPagesToBytesNum(memSec.Cap)
69+
maxBytes := MemoryPagesToBytesNum(memSec.Max)
70+
71+
var buffer []byte
72+
var expBuffer experimental.MemoryBuffer
73+
if allocator != nil {
74+
expBuffer = allocator(minBytes, capBytes, maxBytes)
75+
buffer = expBuffer.Buffer()
76+
} else if memSec.IsShared {
6677
// Shared memory needs a fixed buffer, so allocate with the maximum size.
6778
//
6879
// The rationale as to why we can simply use make([]byte) to a fixed buffer is that Go's GC is non-relocating.
@@ -73,18 +84,17 @@ func NewMemoryInstance(memSec *Memory) *MemoryInstance {
7384
// the memory buffer allocation here is virtual and doesn't consume physical memory until it's used.
7485
// * https://github.com/golang/go/blob/8121604559035734c9677d5281bbdac8b1c17a1e/src/runtime/malloc.go#L1059
7586
// * https://github.com/golang/go/blob/8121604559035734c9677d5281bbdac8b1c17a1e/src/runtime/malloc.go#L1165
76-
size = MemoryPagesToBytesNum(memSec.Max)
87+
buffer = make([]byte, minBytes, maxBytes)
7788
} else {
78-
size = MemoryPagesToBytesNum(memSec.Cap)
89+
buffer = make([]byte, minBytes, capBytes)
7990
}
80-
81-
buffer := make([]byte, MemoryPagesToBytesNum(memSec.Min), size)
8291
return &MemoryInstance{
83-
Buffer: buffer,
84-
Min: memSec.Min,
85-
Cap: memoryBytesNumToPages(uint64(cap(buffer))),
86-
Max: memSec.Max,
87-
Shared: memSec.IsShared,
92+
Buffer: buffer,
93+
Min: memSec.Min,
94+
Cap: memoryBytesNumToPages(uint64(cap(buffer))),
95+
Max: memSec.Max,
96+
Shared: memSec.IsShared,
97+
expBuffer: expBuffer,
8898
}
8999
}
90100

@@ -222,6 +232,22 @@ func (m *MemoryInstance) Grow(delta uint32) (result uint32, ok bool) {
222232
newPages := currentPages + delta
223233
if newPages > m.Max || int32(delta) < 0 {
224234
return 0, false
235+
} else if m.expBuffer != nil {
236+
buffer := m.expBuffer.Grow(MemoryPagesToBytesNum(newPages))
237+
if m.Shared {
238+
if unsafe.SliceData(buffer) != unsafe.SliceData(m.Buffer) {
239+
panic("shared memory cannot move, this is a bug in the memory allocator")
240+
}
241+
// We assume grow is called under a guest lock.
242+
// But the memory length is accessed elsewhere,
243+
// so use atomic to make the new length visible across threads.
244+
atomicStoreLength(&m.Buffer, uintptr(len(buffer)))
245+
m.Cap = memoryBytesNumToPages(uint64(cap(buffer)))
246+
} else {
247+
m.Buffer = buffer
248+
m.Cap = newPages
249+
}
250+
return currentPages, true
225251
} else if newPages > m.Cap { // grow the memory.
226252
if m.Shared {
227253
panic("shared memory cannot be grown, this is a bug in wazero")
@@ -231,9 +257,10 @@ func (m *MemoryInstance) Grow(delta uint32) (result uint32, ok bool) {
231257
return currentPages, true
232258
} else { // We already have the capacity we need.
233259
if m.Shared {
234-
sp := (*reflect.SliceHeader)(unsafe.Pointer(&m.Buffer))
235-
// Use atomic write to ensure new length is visible across threads.
236-
atomic.StoreUintptr((*uintptr)(unsafe.Pointer(&sp.Len)), uintptr(MemoryPagesToBytesNum(newPages)))
260+
// We assume grow is called under a guest lock.
261+
// But the memory length is accessed elsewhere,
262+
// so use atomic to make the new length visible across threads.
263+
atomicStoreLength(&m.Buffer, uintptr(MemoryPagesToBytesNum(newPages)))
237264
} else {
238265
m.Buffer = m.Buffer[:MemoryPagesToBytesNum(newPages)]
239266
}
@@ -267,6 +294,13 @@ func PagesToUnitOfBytes(pages uint32) string {
267294

268295
// Below are raw functions used to implement the api.Memory API:
269296

297+
// Uses atomic write to update the length of a slice.
298+
func atomicStoreLength(slice *[]byte, length uintptr) {
299+
slicePtr := (*reflect.SliceHeader)(unsafe.Pointer(slice))
300+
lenPtr := (*uintptr)(unsafe.Pointer(&slicePtr.Len))
301+
atomic.StoreUintptr(lenPtr, length)
302+
}
303+
270304
// memoryBytesNumToPages converts the given number of bytes into the number of pages.
271305
func memoryBytesNumToPages(bytesNum uint64) (pages uint32) {
272306
return uint32(bytesNum >> MemoryPageSizeInBits)

internal/wasm/memory_test.go

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"unsafe"
1010

1111
"github.com/tetratelabs/wazero/api"
12+
"github.com/tetratelabs/wazero/experimental"
1213
"github.com/tetratelabs/wazero/internal/testing/require"
1314
)
1415

@@ -34,9 +35,11 @@ func TestMemoryInstance_Grow_Size(t *testing.T) {
3435
tests := []struct {
3536
name string
3637
capEqualsMax bool
38+
expAllocator bool
3739
}{
3840
{name: ""},
3941
{name: "capEqualsMax", capEqualsMax: true},
42+
{name: "expAllocator", expAllocator: true},
4043
}
4144

4245
for _, tt := range tests {
@@ -46,10 +49,14 @@ func TestMemoryInstance_Grow_Size(t *testing.T) {
4649
max := uint32(10)
4750
maxBytes := MemoryPagesToBytesNum(max)
4851
var m *MemoryInstance
49-
if tc.capEqualsMax {
50-
m = &MemoryInstance{Cap: max, Max: max, Buffer: make([]byte, 0, maxBytes)}
51-
} else {
52+
switch {
53+
default:
5254
m = &MemoryInstance{Max: max, Buffer: make([]byte, 0)}
55+
case tc.capEqualsMax:
56+
m = &MemoryInstance{Cap: max, Max: max, Buffer: make([]byte, 0, maxBytes)}
57+
case tc.expAllocator:
58+
expBuffer := sliceAllocator(0, 0, maxBytes)
59+
m = &MemoryInstance{Max: max, Buffer: expBuffer.Buffer(), expBuffer: expBuffer}
5360
}
5461

5562
res, ok := m.Grow(5)
@@ -814,6 +821,13 @@ func BenchmarkWriteString(b *testing.B) {
814821
}
815822
}
816823

824+
func Test_atomicStoreLength(t *testing.T) {
825+
// Doesn't verify atomicity, but at least we're updating the correct thing.
826+
slice := make([]byte, 10, 20)
827+
atomicStoreLength(&slice, 15)
828+
require.Equal(t, 15, len(slice))
829+
}
830+
817831
func TestNewMemoryInstance_Shared(t *testing.T) {
818832
tests := []struct {
819833
name string
@@ -832,7 +846,7 @@ func TestNewMemoryInstance_Shared(t *testing.T) {
832846
for _, tc := range tests {
833847
tc := tc
834848
t.Run(tc.name, func(t *testing.T) {
835-
m := NewMemoryInstance(tc.mem)
849+
m := NewMemoryInstance(tc.mem, nil)
836850
require.Equal(t, tc.mem.Min, m.Min)
837851
require.Equal(t, tc.mem.Max, m.Max)
838852
require.True(t, m.Shared)
@@ -979,3 +993,25 @@ func requireChannelEmpty(t *testing.T, ch chan string) {
979993
// fallthrough
980994
}
981995
}
996+
997+
func sliceAllocator(min, cap, max uint64) experimental.MemoryBuffer {
998+
return &sliceBuffer{make([]byte, min, cap), max}
999+
}
1000+
1001+
type sliceBuffer struct {
1002+
buf []byte
1003+
max uint64
1004+
}
1005+
1006+
func (b *sliceBuffer) Free() {}
1007+
1008+
func (b *sliceBuffer) Buffer() []byte { return b.buf }
1009+
1010+
func (b *sliceBuffer) Grow(size uint64) []byte {
1011+
if cap := uint64(cap(b.buf)); size > cap {
1012+
b.buf = append(b.buf[:cap], make([]byte, size-cap)...)
1013+
} else {
1014+
b.buf = b.buf[:size]
1015+
}
1016+
return b.buf
1017+
}

internal/wasm/module.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,10 +652,10 @@ func paramNames(localNames IndirectNameMap, funcIdx uint32, paramLen int) []stri
652652
return nil
653653
}
654654

655-
func (m *ModuleInstance) buildMemory(module *Module) {
655+
func (m *ModuleInstance) buildMemory(module *Module, allocator experimental.MemoryAllocator) {
656656
memSec := module.MemorySection
657657
if memSec != nil {
658-
m.MemoryInstance = NewMemoryInstance(memSec)
658+
m.MemoryInstance = NewMemoryInstance(memSec, allocator)
659659
m.MemoryInstance.definition = &module.MemoryDefinitionSection[0]
660660
}
661661
}

internal/wasm/module_instance.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -151,20 +151,24 @@ func (m *ModuleInstance) ensureResourcesClosed(ctx context.Context) (err error)
151151
}
152152

153153
if sysCtx := m.Sys; sysCtx != nil { // nil if from HostModuleBuilder
154-
if err = sysCtx.FS().Close(); err != nil {
155-
return err
156-
}
154+
err = sysCtx.FS().Close()
157155
m.Sys = nil
158156
}
159157

160-
if m.CodeCloser == nil {
161-
return
158+
if mem := m.MemoryInstance; mem != nil {
159+
if mem.expBuffer != nil {
160+
mem.expBuffer.Free()
161+
mem.expBuffer = nil
162+
}
162163
}
163-
if e := m.CodeCloser.Close(ctx); e != nil && err == nil {
164-
err = e
164+
165+
if m.CodeCloser != nil {
166+
if e := m.CodeCloser.Close(ctx); err == nil {
167+
err = e
168+
}
169+
m.CodeCloser = nil
165170
}
166-
m.CodeCloser = nil
167-
return
171+
return err
168172
}
169173

170174
// Memory implements the same method as documented on api.Module.

internal/wasm/module_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -839,7 +839,7 @@ func TestModule_buildGlobals(t *testing.T) {
839839
func TestModule_buildMemoryInstance(t *testing.T) {
840840
t.Run("nil", func(t *testing.T) {
841841
m := ModuleInstance{}
842-
m.buildMemory(&Module{})
842+
m.buildMemory(&Module{}, nil)
843843
require.Nil(t, m.MemoryInstance)
844844
})
845845
t.Run("non-nil", func(t *testing.T) {
@@ -850,7 +850,7 @@ func TestModule_buildMemoryInstance(t *testing.T) {
850850
m.buildMemory(&Module{
851851
MemorySection: &Memory{Min: min, Cap: min, Max: max},
852852
MemoryDefinitionSection: []MemoryDefinition{mDef},
853-
})
853+
}, nil)
854854
mem := m.MemoryInstance
855855
require.Equal(t, min, mem.Min)
856856
require.Equal(t, max, mem.Max)

internal/wasm/store.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"sync/atomic"
99

1010
"github.com/tetratelabs/wazero/api"
11+
"github.com/tetratelabs/wazero/experimental"
1112
"github.com/tetratelabs/wazero/internal/ctxkey"
1213
"github.com/tetratelabs/wazero/internal/internalapi"
1314
"github.com/tetratelabs/wazero/internal/leb128"
@@ -362,8 +363,13 @@ func (s *Store) instantiate(
362363
return nil, err
363364
}
364365

366+
var allocator experimental.MemoryAllocator
367+
if ctx != nil {
368+
allocator, _ = ctx.Value(ctxkey.MemoryAllocatorKey{}).(experimental.MemoryAllocator)
369+
}
370+
365371
m.buildGlobals(module, m.Engine.FunctionInstanceReference)
366-
m.buildMemory(module)
372+
m.buildMemory(module, allocator)
367373
m.Exports = module.Exports
368374
for _, exp := range m.Exports {
369375
if exp.Type == ExternTypeTable {

0 commit comments

Comments
 (0)