Skip to content

Commit b85923a

Browse files
authored
GH-119726: Deduplicate AArch64 trampolines within a trace (GH-123872)
1 parent 7a178b7 commit b85923a

File tree

5 files changed

+147
-59
lines changed

5 files changed

+147
-59
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
The JIT now generates more efficient code for calls to C functions resulting
2+
in up to 0.8% memory savings and 1.5% speed improvement on AArch64. Patch by Diego Russo.

Python/jit.c

Lines changed: 84 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "Python.h"
44

55
#include "pycore_abstract.h"
6+
#include "pycore_bitutils.h"
67
#include "pycore_call.h"
78
#include "pycore_ceval.h"
89
#include "pycore_critical_section.h"
@@ -113,6 +114,21 @@ mark_executable(unsigned char *memory, size_t size)
113114

114115
// JIT compiler stuff: /////////////////////////////////////////////////////////
115116

117+
#define SYMBOL_MASK_WORDS 4
118+
119+
typedef uint32_t symbol_mask[SYMBOL_MASK_WORDS];
120+
121+
typedef struct {
122+
unsigned char *mem;
123+
symbol_mask mask;
124+
size_t size;
125+
} trampoline_state;
126+
127+
typedef struct {
128+
trampoline_state trampolines;
129+
uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH];
130+
} jit_state;
131+
116132
// Warning! AArch64 requires you to get your hands dirty. These are your gloves:
117133

118134
// value[value_start : value_start + len]
@@ -390,66 +406,126 @@ patch_x86_64_32rx(unsigned char *location, uint64_t value)
390406
patch_32r(location, value);
391407
}
392408

409+
void patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state);
410+
393411
#include "jit_stencils.h"
394412

413+
#if defined(__aarch64__) || defined(_M_ARM64)
414+
#define TRAMPOLINE_SIZE 16
415+
#else
416+
#define TRAMPOLINE_SIZE 0
417+
#endif
418+
419+
// Generate and patch AArch64 trampolines. The symbols to jump to are stored
420+
// in the jit_stencils.h in the symbols_map.
421+
void
422+
patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state)
423+
{
424+
// Masking is done modulo 32 as the mask is stored as an array of uint32_t
425+
const uint32_t symbol_mask = 1 << (ordinal % 32);
426+
const uint32_t trampoline_mask = state->trampolines.mask[ordinal / 32];
427+
assert(symbol_mask & trampoline_mask);
428+
429+
// Count the number of set bits in the trampoline mask lower than ordinal,
430+
// this gives the index into the array of trampolines.
431+
int index = _Py_popcount32(trampoline_mask & (symbol_mask - 1));
432+
for (int i = 0; i < ordinal / 32; i++) {
433+
index += _Py_popcount32(state->trampolines.mask[i]);
434+
}
435+
436+
uint32_t *p = (uint32_t*)(state->trampolines.mem + index * TRAMPOLINE_SIZE);
437+
assert((size_t)(index + 1) * TRAMPOLINE_SIZE <= state->trampolines.size);
438+
439+
uint64_t value = (uintptr_t)symbols_map[ordinal];
440+
441+
/* Generate the trampoline
442+
0: 58000048 ldr x8, 8
443+
4: d61f0100 br x8
444+
8: 00000000 // The next two words contain the 64-bit address to jump to.
445+
c: 00000000
446+
*/
447+
p[0] = 0x58000048;
448+
p[1] = 0xD61F0100;
449+
p[2] = value & 0xffffffff;
450+
p[3] = value >> 32;
451+
452+
patch_aarch64_26r(location, (uintptr_t)p);
453+
}
454+
455+
static void
456+
combine_symbol_mask(const symbol_mask src, symbol_mask dest)
457+
{
458+
// Calculate the union of the trampolines required by each StencilGroup
459+
for (size_t i = 0; i < SYMBOL_MASK_WORDS; i++) {
460+
dest[i] |= src[i];
461+
}
462+
}
463+
395464
// Compiles executor in-place. Don't forget to call _PyJIT_Free later!
396465
int
397466
_PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], size_t length)
398467
{
399468
const StencilGroup *group;
400469
// Loop once to find the total compiled size:
401-
uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH];
402470
size_t code_size = 0;
403471
size_t data_size = 0;
472+
jit_state state = {};
404473
group = &trampoline;
405474
code_size += group->code_size;
406475
data_size += group->data_size;
407476
for (size_t i = 0; i < length; i++) {
408477
const _PyUOpInstruction *instruction = &trace[i];
409478
group = &stencil_groups[instruction->opcode];
410-
instruction_starts[i] = code_size;
479+
state.instruction_starts[i] = code_size;
411480
code_size += group->code_size;
412481
data_size += group->data_size;
482+
combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
413483
}
414484
group = &stencil_groups[_FATAL_ERROR];
415485
code_size += group->code_size;
416486
data_size += group->data_size;
487+
combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
488+
// Calculate the size of the trampolines required by the whole trace
489+
for (size_t i = 0; i < Py_ARRAY_LENGTH(state.trampolines.mask); i++) {
490+
state.trampolines.size += _Py_popcount32(state.trampolines.mask[i]) * TRAMPOLINE_SIZE;
491+
}
417492
// Round up to the nearest page:
418493
size_t page_size = get_page_size();
419494
assert((page_size & (page_size - 1)) == 0);
420-
size_t padding = page_size - ((code_size + data_size) & (page_size - 1));
421-
size_t total_size = code_size + data_size + padding;
495+
size_t padding = page_size - ((code_size + data_size + state.trampolines.size) & (page_size - 1));
496+
size_t total_size = code_size + data_size + state.trampolines.size + padding;
422497
unsigned char *memory = jit_alloc(total_size);
423498
if (memory == NULL) {
424499
return -1;
425500
}
426501
// Update the offsets of each instruction:
427502
for (size_t i = 0; i < length; i++) {
428-
instruction_starts[i] += (uintptr_t)memory;
503+
state.instruction_starts[i] += (uintptr_t)memory;
429504
}
430505
// Loop again to emit the code:
431506
unsigned char *code = memory;
432507
unsigned char *data = memory + code_size;
508+
state.trampolines.mem = memory + code_size + data_size;
433509
// Compile the trampoline, which handles converting between the native
434510
// calling convention and the calling convention used by jitted code
435511
// (which may be different for efficiency reasons). On platforms where
436512
// we don't change calling conventions, the trampoline is empty and
437513
// nothing is emitted here:
438514
group = &trampoline;
439-
group->emit(code, data, executor, NULL, instruction_starts);
515+
group->emit(code, data, executor, NULL, &state);
440516
code += group->code_size;
441517
data += group->data_size;
442518
assert(trace[0].opcode == _START_EXECUTOR);
443519
for (size_t i = 0; i < length; i++) {
444520
const _PyUOpInstruction *instruction = &trace[i];
445521
group = &stencil_groups[instruction->opcode];
446-
group->emit(code, data, executor, instruction, instruction_starts);
522+
group->emit(code, data, executor, instruction, &state);
447523
code += group->code_size;
448524
data += group->data_size;
449525
}
450526
// Protect against accidental buffer overrun into data:
451527
group = &stencil_groups[_FATAL_ERROR];
452-
group->emit(code, data, executor, NULL, instruction_starts);
528+
group->emit(code, data, executor, NULL, &state);
453529
code += group->code_size;
454530
data += group->data_size;
455531
assert(code == memory + code_size);

Tools/jit/_stencils.py

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import dataclasses
44
import enum
5-
import sys
65
import typing
76

87
import _schema
@@ -103,8 +102,8 @@ class HoleValue(enum.Enum):
103102
HoleValue.OPERAND_HI: "(instruction->operand >> 32)",
104103
HoleValue.OPERAND_LO: "(instruction->operand & UINT32_MAX)",
105104
HoleValue.TARGET: "instruction->target",
106-
HoleValue.JUMP_TARGET: "instruction_starts[instruction->jump_target]",
107-
HoleValue.ERROR_TARGET: "instruction_starts[instruction->error_target]",
105+
HoleValue.JUMP_TARGET: "state->instruction_starts[instruction->jump_target]",
106+
HoleValue.ERROR_TARGET: "state->instruction_starts[instruction->error_target]",
108107
HoleValue.ZERO: "",
109108
}
110109

@@ -125,6 +124,7 @@ class Hole:
125124
symbol: str | None
126125
# ...plus this addend:
127126
addend: int
127+
need_state: bool = False
128128
func: str = dataclasses.field(init=False)
129129
# Convenience method:
130130
replace = dataclasses.replace
@@ -157,10 +157,12 @@ def as_c(self, where: str) -> str:
157157
if value:
158158
value += " + "
159159
value += f"(uintptr_t)&{self.symbol}"
160-
if _signed(self.addend):
160+
if _signed(self.addend) or not value:
161161
if value:
162162
value += " + "
163163
value += f"{_signed(self.addend):#x}"
164+
if self.need_state:
165+
return f"{self.func}({location}, {value}, state);"
164166
return f"{self.func}({location}, {value});"
165167

166168

@@ -175,7 +177,6 @@ class Stencil:
175177
body: bytearray = dataclasses.field(default_factory=bytearray, init=False)
176178
holes: list[Hole] = dataclasses.field(default_factory=list, init=False)
177179
disassembly: list[str] = dataclasses.field(default_factory=list, init=False)
178-
trampolines: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
179180

180181
def pad(self, alignment: int) -> None:
181182
"""Pad the stencil to the given alignment."""
@@ -184,39 +185,6 @@ def pad(self, alignment: int) -> None:
184185
self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}")
185186
self.body.extend([0] * padding)
186187

187-
def emit_aarch64_trampoline(self, hole: Hole, alignment: int) -> Hole:
188-
"""Even with the large code model, AArch64 Linux insists on 28-bit jumps."""
189-
assert hole.symbol is not None
190-
reuse_trampoline = hole.symbol in self.trampolines
191-
if reuse_trampoline:
192-
# Re-use the base address of the previously created trampoline
193-
base = self.trampolines[hole.symbol]
194-
else:
195-
self.pad(alignment)
196-
base = len(self.body)
197-
new_hole = hole.replace(addend=base, symbol=None, value=HoleValue.DATA)
198-
199-
if reuse_trampoline:
200-
return new_hole
201-
202-
self.disassembly += [
203-
f"{base + 4 * 0:x}: 58000048 ldr x8, 8",
204-
f"{base + 4 * 1:x}: d61f0100 br x8",
205-
f"{base + 4 * 2:x}: 00000000",
206-
f"{base + 4 * 2:016x}: R_AARCH64_ABS64 {hole.symbol}",
207-
f"{base + 4 * 3:x}: 00000000",
208-
]
209-
for code in [
210-
0x58000048.to_bytes(4, sys.byteorder),
211-
0xD61F0100.to_bytes(4, sys.byteorder),
212-
0x00000000.to_bytes(4, sys.byteorder),
213-
0x00000000.to_bytes(4, sys.byteorder),
214-
]:
215-
self.body.extend(code)
216-
self.holes.append(hole.replace(offset=base + 8, kind="R_AARCH64_ABS64"))
217-
self.trampolines[hole.symbol] = base
218-
return new_hole
219-
220188
def remove_jump(self, *, alignment: int = 1) -> None:
221189
"""Remove a zero-length continuation jump, if it exists."""
222190
hole = max(self.holes, key=lambda hole: hole.offset)
@@ -282,18 +250,32 @@ class StencilGroup:
282250
default_factory=dict, init=False
283251
)
284252
_got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
285-
286-
def process_relocations(self, *, alignment: int = 1) -> None:
253+
_trampolines: set[int] = dataclasses.field(default_factory=set, init=False)
254+
255+
def process_relocations(
256+
self,
257+
known_symbols: dict[str, int],
258+
*,
259+
alignment: int = 1,
260+
) -> None:
287261
"""Fix up all GOT and internal relocations for this stencil group."""
288262
for hole in self.code.holes.copy():
289263
if (
290264
hole.kind
291265
in {"R_AARCH64_CALL26", "R_AARCH64_JUMP26", "ARM64_RELOC_BRANCH26"}
292266
and hole.value is HoleValue.ZERO
293267
):
294-
new_hole = self.data.emit_aarch64_trampoline(hole, alignment)
295-
self.code.holes.remove(hole)
296-
self.code.holes.append(new_hole)
268+
hole.func = "patch_aarch64_trampoline"
269+
hole.need_state = True
270+
assert hole.symbol is not None
271+
if hole.symbol in known_symbols:
272+
ordinal = known_symbols[hole.symbol]
273+
else:
274+
ordinal = len(known_symbols)
275+
known_symbols[hole.symbol] = ordinal
276+
self._trampolines.add(ordinal)
277+
hole.addend = ordinal
278+
hole.symbol = None
297279
self.code.remove_jump(alignment=alignment)
298280
self.code.pad(alignment)
299281
self.data.pad(8)
@@ -348,9 +330,20 @@ def _emit_global_offset_table(self) -> None:
348330
)
349331
self.data.body.extend([0] * 8)
350332

333+
def _get_trampoline_mask(self) -> str:
334+
bitmask: int = 0
335+
trampoline_mask: list[str] = []
336+
for ordinal in self._trampolines:
337+
bitmask |= 1 << ordinal
338+
while bitmask:
339+
word = bitmask & ((1 << 32) - 1)
340+
trampoline_mask.append(f"{word:#04x}")
341+
bitmask >>= 32
342+
return "{" + ", ".join(trampoline_mask) + "}"
343+
351344
def as_c(self, opname: str) -> str:
352345
"""Dump this hole as a StencilGroup initializer."""
353-
return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}}}"
346+
return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {self._get_trampoline_mask()}}}"
354347

355348

356349
def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]:

Tools/jit/_targets.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class _Target(typing.Generic[_S, _R]):
4444
stable: bool = False
4545
debug: bool = False
4646
verbose: bool = False
47+
known_symbols: dict[str, int] = dataclasses.field(default_factory=dict)
4748

4849
def _compute_digest(self, out: pathlib.Path) -> str:
4950
hasher = hashlib.sha256()
@@ -95,7 +96,9 @@ async def _parse(self, path: pathlib.Path) -> _stencils.StencilGroup:
9596
if group.data.body:
9697
line = f"0: {str(bytes(group.data.body)).removeprefix('b')}"
9798
group.data.disassembly.append(line)
98-
group.process_relocations(alignment=self.alignment)
99+
group.process_relocations(
100+
known_symbols=self.known_symbols, alignment=self.alignment
101+
)
99102
return group
100103

101104
def _handle_section(self, section: _S, group: _stencils.StencilGroup) -> None:
@@ -231,7 +234,7 @@ def build(
231234
if comment:
232235
file.write(f"// {comment}\n")
233236
file.write("\n")
234-
for line in _writer.dump(stencil_groups):
237+
for line in _writer.dump(stencil_groups, self.known_symbols):
235238
file.write(f"{line}\n")
236239
try:
237240
jit_stencils_new.replace(jit_stencils)

0 commit comments

Comments
 (0)