Create present #427

Open
wants to merge 2 commits into main
Conversation

JosefaOrtiz

No description provided.

@xxh031203 commented May 29, 2025 via email

@JosefaOrtiz (Author)

```python
# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import bisect
import functools
import logging
import math
import re
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Optional, Tuple

import haiku as hk
import jax
import jax.experimental.pjit as pjit
import jax.numpy as jnp
import numpy as np
import sentencepiece
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec as P
from jax.typing import ArrayLike

import checkpoint as xai_checkpoint
from model import (
    LanguageModelConfig,
    LanguageModelOutput,
    TrainingState,
    apply_rules,
    Memory,
    KVMemory,
)

logger = logging.getLogger(__name__)
rank_logger = logging.getLogger("rank")

TOP_K = 8

class SampleSettings(NamedTuple):
    temperature: ArrayLike
    nucleus_p: ArrayLike
    mask: ArrayLike
    # Whether a given batch element is actively used. [B]
    active: ArrayLike

class SampleOutput(NamedTuple):
    token_id: ArrayLike
    prob: ArrayLike
    top_k_token_ids: ArrayLike
    top_k_probs: ArrayLike

def insert_slice(memory: Memory, slice, length, i):
    slice = Memory(
        layers=[
            KVMemory(layer.k, layer.v, step=jnp.array([length]))
            for layer in slice.layers
        ],
    )

    return jax.tree_map(
        lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),
        memory,
        slice,
    )
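
# Note: insert_slice() assumes `slice` holds a single sample (leading batch dim
# of 1, hence the u[0] above) and writes its KV cache into batch row `i` of the
# shared memory, stamping every layer's `step` counter with the prompt length.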

def pad_to_size(x, size):
    if x.shape[0] > size:
        # Left truncate if the context is too long.
        x = x[-size:]
    return np.pad(x, [0, size - x.shape[0]], mode="constant", constant_values=0)
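
# For example (hypothetical values): pad_to_size(np.array([1, 2, 3]), 5) returns
# array([1, 2, 3, 0, 0]), while pad_to_size(np.arange(10), 4) first left-truncates
# to the last four tokens and returns array([6, 7, 8, 9]).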

def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:
    """Performs nucleus filtering on logits."""
    assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}"
    sorted_logits = jax.lax.sort(logits, is_stable=False)
    sorted_probs = jax.nn.softmax(sorted_logits)
    threshold_idx = jnp.argmax(jnp.cumsum(sorted_probs, -1) >= 1 - top_p, axis=-1)
    threshold_largest_logits = jnp.take_along_axis(
        sorted_logits, threshold_idx[..., jnp.newaxis], axis=-1
    )
    assert threshold_largest_logits.shape == logits.shape[:-1] + (1,)
    mask = logits >= threshold_largest_logits
    # Set unused logits to -inf.
    logits = jnp.where(mask, logits, -1e10)
    return logits
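
# Worked example (hypothetical numbers): with logits [2.0, 1.0, 0.0, -1.0] and
# top_p = 0.9, the softmax probabilities are roughly [0.64, 0.24, 0.09, 0.03].
# The smallest set of tokens whose cumulative probability reaches 0.9 is the top
# three, so only the last logit is replaced with -1e10 and is effectively never
# sampled.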

def sample_token(
    rngs: jax.random.PRNGKey,
    lm_outputs: LanguageModelOutput,
    settings: SampleSettings,
) -> SampleOutput:
    # Expand the settings shape to match the logit shape.
    settings = SampleSettings(
        temperature=jnp.expand_dims(settings.temperature, (1, 2)),  # Input [B], output [B, 1, 1].
        nucleus_p=jnp.expand_dims(settings.nucleus_p, (1, 2)),  # Input [B], output [B, 1, 1].
        mask=jnp.expand_dims(settings.mask, 1),  # Input [B, V], output [B, 1, V].
        active=settings.active,  # [B].
    )
    logits = lm_outputs.logits / settings.temperature.astype(lm_outputs.logits.dtype)
    # Mask out all disallowed tokens by assigning them a near-zero probability.
    logits = jnp.where(settings.mask, logits, -1e10)
    # Mask out all tokens that don't fall into the p-th percentile.
    logits = top_p_filter(logits, settings.nucleus_p.astype(logits.dtype))

    new_token = jax.vmap(jax.random.categorical)(rngs, logits)

    probabilities = jax.nn.softmax(logits)
    token_prob = jnp.take_along_axis(probabilities, jnp.expand_dims(new_token, 1), axis=2)
    token_prob = jnp.squeeze(token_prob, 1)

    # Gather the top-k tokens and probabilities.
    top_k_probs, top_k_token_ids = jax.lax.top_k(probabilities, TOP_K)
    top_k_probs = jnp.squeeze(top_k_probs, 1)
    top_k_token_ids = jnp.squeeze(top_k_token_ids, 1)
    return SampleOutput(
        new_token,
        token_prob,
        top_k_token_ids,
        top_k_probs,
    )

@dataclass
class ModelRunner:
    model: LanguageModelConfig

    bs_per_device: float = 2.0

    load_rename_rules: Optional[list[tuple[str, str]]] = None
    load_exclude_rules: Optional[list[str]] = None

    rng_seed: int = 42  # Initial rng seed.
    transform_forward: bool = False

    checkpoint_path: str = ""

    def make_forward_fn(self, mesh: Any):
        def forward(tokens):
            out = self.model.make(mesh=mesh)(tokens)
            return out, None

        if self.transform_forward:
            forward = hk.transform(forward)
        return forward

    def initialize(
        self,
        init_data,
        local_mesh_config: tuple[int, int],
        between_hosts_config: tuple[int, int],
    ):
        num_replicas = math.prod(between_hosts_config)
        self.model.initialize()
        self.model.fprop_dtype = jnp.bfloat16
        num_local_gpus = len(jax.local_devices())

        # Calculate the global batch size from the local batch size.
        self.batch_size = int(self.bs_per_device * num_local_gpus * num_replicas)

        # Calculate the batch size per host from the global batch size.
        self.local_batch_size = self.batch_size // jax.process_count()

        self.local_mesh_config = local_mesh_config
        self.between_hosts_config = between_hosts_config
        rank_logger.info(
            f"Initializing mesh for {self.local_mesh_config=} {self.between_hosts_config=}..."
        )
        self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
        self.forward = self.make_forward_fn(mesh=self.mesh)
        self.logits_fn = hk.transform(lambda tokens: self.forward(tokens)[0])

        self.eval_forward = self.make_forward_fn(mesh=self.mesh)
        self.logits_eval_fn = hk.transform(lambda tokens: self.eval_forward(tokens)[0])

        if self.transform_forward:
            self.state_sharding = self.get_state_sharding(init_data)
            rank_logger.info(f"State sharding type: {type(self.state_sharding)}")
            self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)

    def init(self, rng: jax.Array, data) -> TrainingState:
        assert self.transform_forward
        rng, init_rng = jax.random.split(rng)
        params = self.forward.init(init_rng, data["inputs"])
        return TrainingState(params=params)

    def get_state_sharding(self, init_data):
        assert self.transform_forward
        rng = jax.random.PRNGKey(self.rng_seed)
        rank_logger.info(f"partition rules: {self.model.partition_rules}")

        with self.mesh:
            shapes = jax.eval_shape(self.init, rng, init_data)
            sharding = jax.tree_util.tree_map_with_path(
                apply_rules(self.model.partition_rules()),
                shapes,
            )
        return sharding

    def load_or_init(
        self,
        init_data: Any,
        from_checkpoint: bool = True,
        init_fn: Optional[Callable] = None,
    ):
        rng = jax.random.PRNGKey(self.rng_seed)

        if not self.checkpoint_path or not from_checkpoint:
            rank_logger.info("Initializing model...")
            with self.mesh:
                if init_fn is not None:
                    state = init_fn(rng, init_data)
                else:
                    assert self.transform_forward
                    state = self.init_fn(rng, init_data)
            rank_logger.info("Model state is newly initialized.")
        else:
            with self.mesh:
                if init_fn:
                    state_shapes = jax.eval_shape(init_fn, rng, init_data)
                else:
                    assert self.transform_forward
                    state_shapes = jax.eval_shape(self.init_fn, rng, init_data)
            init_state = None

            state = xai_checkpoint.restore(
                checkpoint_path=self.checkpoint_path,
                state_shapes=state_shapes,
                mesh=self.mesh,
                between_hosts_config=self.between_hosts_config,
                state_sharding=self.state_sharding,
                init_state=init_state,
                params_only=True,
            )

            del init_state
        return state

@dataclass
class Request:
    prompt: str
    temperature: float
    nucleus_p: float
    rng_seed: int
    max_len: int

@dataclass
class InferenceRunner:
    name: str
    runner: Any
    load: str
    tokenizer_path: str = "/tmp/xai_data/tokenizer.model"
    local_mesh_config: Tuple[int, int] = (1, 1)
    between_hosts_config: Tuple[int, int] = (1, 1)
    pad_sizes: tuple[int] = (1024,)

    def get_pad_bucket(self, size):
        i = bisect.bisect_left(self.pad_sizes, size)
        return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]
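
    # For example (hypothetical buckets): with pad_sizes=(512, 1024),
    # get_pad_bucket(300) returns 512 and get_pad_bucket(600) returns 1024;
    # prompts longer than the largest bucket fall back to it and are then
    # left-truncated by pad_to_size().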

    def initialize(self):
        runner = self.runner
        self.runner.transform_forward = True
        dummy_data = dict(
            inputs=np.zeros((1, 256), dtype=np.int32),
            targets=np.zeros((1, 256), dtype=np.int32),
        )
        runner.initialize(
            dummy_data,
            local_mesh_config=self.local_mesh_config,
            between_hosts_config=self.between_hosts_config,
        )

        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=self.tokenizer_path)

        max_len = runner.model.sequence_len

        self.vocab_size = self.runner.model.vocab_size

        params = runner.load_or_init(dummy_data)
        self.params = params

        def pad_to_max_len(x):
            if len(x.shape) > 1:
                pad_width = max_len - x.shape[1]
                return jnp.pad(x, [(0, 0), (0, pad_width), (0, 0), (0, 0)])
            else:
                return x

        @functools.lru_cache
        def lm():
            return runner.model.make(mesh=runner.mesh)

        def hk_forward(
            tokens,
            memory=None,
            length=None,
            active=None,
        ) -> LanguageModelOutput:
            if memory is not None:
                assert active is not None
                layers = []
                for l in memory.layers:
                    # Reset steps to 0 for inactive requests to avoid unnecessary computations.
                    step = jnp.where(active, l.step, jnp.zeros_like(l.step))
                    layers.append(l._replace(step=step))
                memory = memory._replace(layers=layers)
            return lm()(tokens, memory, length=length)

        def hk_sample_step(rngs, last_output: SampleOutput, memory, settings):
            rngs, rngs_ = jax.vmap(jax.random.split, out_axes=1)(rngs)
            lm_outputs = hk_forward(last_output.token_id, memory=memory, active=settings.active)
            sample_result = sample_token(rngs_, lm_outputs, settings)
            return rngs, sample_result, lm_outputs.model_state

        def hk_new_memory(batch_size, sequence_len):
            return lm().init_memory(batch_size, sequence_len)

        def hk_prefill_memory(
            rngs,
            memory,
            settings,
            last_output,
            prompt,
            length,
            rng_seed,
            new_settings,
            i,
        ):
            rng = jax.random.PRNGKey(seed=rng_seed)
            rng, rng_ = jax.random.split(rng)

            # Allocate new memory for this sample. The memory length is equal to the length of the
            # prompt.
            slice = hk_new_memory(1, prompt.shape[0])

            # Move the settings for this individual batch entry into the joint settings tensor.
            settings = jax.tree_map(
                lambda o, v: jax.lax.dynamic_update_index_in_dim(o, v, i, axis=0),
                settings,
                new_settings,
            )

            # Get the settings for the batch entry from the joint settings tensor.
            settings_slice = jax.tree_map(lambda t: jnp.expand_dims(t[i], axis=0), settings)

            # Process the first n-1 tokens of the prompt.
            lm_outputs = hk_forward(
                jnp.expand_dims(prompt, 0),
                memory=slice,
                length=jnp.expand_dims(length, 0),
                active=settings_slice.active,
            )

            # The forward pass doesn't correctly set the `step` counter inside the memory. Manually
            # override it so `hk_forward` uses the correct context length in the next call.
            slice = lm_outputs.model_state
            slice = slice._replace(
                layers=[l._replace(step=jnp.array([length])) for l in slice.layers]
            )

            # Sample the actual output token.
            rng_ = jnp.expand_dims(rng_, 0)
            new_output = sample_token(rng_, lm_outputs, settings_slice)

            # Update the KV cache/memory.
            slice = jax.tree_map(pad_to_max_len, slice)
            memory = insert_slice(memory, slice, length, i)

            rng = jnp.expand_dims(rng, 0)
            rngs = jax.lax.dynamic_update_index_in_dim(rngs, rng, i, axis=0)

            # Move the network outputs for this batch entry into the joint output tensor.
            last_output = jax.tree_util.tree_map(
                lambda last, new: jax.lax.dynamic_update_index_in_dim(last, new, i, axis=0),
                last_output,
                new_output,
            )
            return rngs, last_output, memory, settings

        sample_step_ = hk.without_apply_rng(hk.transform(hk_sample_step))
        prefill_memory_ = hk.without_apply_rng(hk.transform(hk_prefill_memory))
        new_memory_ = hk.without_apply_rng(hk.transform(hk_new_memory))
        forward_ = hk.without_apply_rng(hk.transform(hk_forward))

        rng = jax.random.PRNGKey(42)
        dummy_tokens = jnp.zeros((1, max_len), jnp.int32)

        with runner.mesh:
            shapes = jax.eval_shape(forward_.init, rng, dummy_tokens)

        self.params_sharding = jax.tree_util.tree_map_with_path(
            apply_rules(runner.model.partition_rules()),
            shapes,
        )

        ds = P("data")
        ms = runner.model.model.get_memory_sharding()
        self.sample_step = pjit.pjit(
            sample_step_.apply,
            in_shardings=(self.params_sharding, None, ds, ms, None),
            out_shardings=(None, ds, ms),
            donate_argnums=3,
        )
        self.prefill_memory = pjit.pjit(
            functools.partial(prefill_memory_.apply),
            in_shardings=(
                self.params_sharding,
                None,
                ms,
                None,
                ds,
                None,
                None,
                None,
                None,
                None,
            ),
            out_shardings=(None, ds, ms, None),
            donate_argnums=(2,),
        )
        self.new_memory = pjit.pjit(
            new_memory_.apply,
            static_argnums=(1, 2),
            out_shardings=ms,
        )
    def run(self):
        """Generator that accepts prompts."""
        runner = self.runner
        mesh = runner.mesh
        max_len = runner.model.sequence_len
        batch_size = runner.batch_size
        params = self.params
        rngs = jax.random.split(jax.random.PRNGKey(1), batch_size)
        with mesh:
            memory = self.new_memory(params, batch_size, max_len)
            settings = SampleSettings(
                temperature=np.zeros((batch_size,), dtype=np.float32),
                nucleus_p=np.zeros((batch_size,), dtype=np.float32),
                mask=np.ones((batch_size, self.vocab_size), dtype=np.int32),
                active=np.zeros((batch_size), dtype=np.int32),
            )
            last_output = SampleOutput(
                token_id=np.zeros((batch_size, 1), dtype=np.int32),
                prob=np.zeros((batch_size, 1), dtype=jnp.bfloat16),
                top_k_token_ids=np.zeros((batch_size, TOP_K), dtype=np.int32),
                top_k_probs=np.zeros((batch_size, TOP_K), dtype=jnp.bfloat16),
            )

            prompt = np.array([300, 400, 500, 600, 600, 700, 800])

            new_settings = SampleSettings(
                temperature=np.float32(1),
                nucleus_p=np.float32(1),
                mask=np.ones((self.vocab_size,), dtype=np.int32),
                active=np.zeros((), dtype=np.int32),
            )
            rng_seed = np.uint64(1)

            for size in self.pad_sizes:
                if size > runner.model.sequence_len:
                    break
                logger.info("Precompile {}".format(size))
                prompt_len = len(prompt)
                prompt = pad_to_size(prompt, size)
                rngs, last_output, memory, settings = self.prefill_memory(
                    params,
                    rngs,
                    memory,
                    settings,
                    last_output,
                    prompt,
                    prompt_len,
                    rng_seed,
                    new_settings,
                    0,
                )
        with runner.mesh:
            logger.info("Compiling...")
            rngs, last_output, memory = self.sample_step(
                params, rngs, last_output, memory, settings
            )
            logger.info("Done compiling.")

        all_tokens = []
        free_slots = list(range(batch_size))
        requests = [None] * batch_size
        first_output = [None] * batch_size
        jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
        prev_token = last_output
        step = 0
        total_num_tokens = 0
        total_num_sequences = 0
        with mesh:
            while True:
                while free_slots:
                    request: Optional[Request] = yield
                    tokens = self.tokenizer.encode(request.prompt)
                    temperature = request.temperature
                    nucleus_p = request.nucleus_p
                    rng_seed = request.rng_seed

                    i = free_slots.pop()
                    prompt = np.array(tokens, dtype=np.int32)
                    prompt_len = len(prompt)
                    prompt = pad_to_size(prompt, self.get_pad_bucket(prompt.shape[0]))
                    # All tokens are allowed.
                    mask = np.ones((self.vocab_size,), dtype=np.int32)

                    new_settings = SampleSettings(
                        temperature=np.float32(temperature),
                        nucleus_p=np.float32(nucleus_p),
                        mask=mask,
                        active=np.ones((), dtype=np.int32),
                    )
                    rng_seed = np.uint64(rng_seed)
                    rngs, last_output, memory, settings = self.prefill_memory(
                        params,
                        rngs,
                        memory,
                        settings,
                        last_output,
                        prompt,
                        prompt_len,
                        rng_seed,
                        new_settings,
                        i,
                    )
                    jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
                    first_output[i] = last_output
                    requests[i] = request
                    total_num_sequences += 1

                rngs, last_output, memory = self.sample_step(
                    params, rngs, last_output, memory, settings
                )
                total_num_tokens += batch_size - len(free_slots)

                # prev_token should already be on the host.
                prev_token = jax.tree_map(np.array, prev_token)
                for i in range(batch_size):
                    if requests[i] is not None:
                        if first_output[i] is not None:
                            first_output_i = jax.tree_map(np.array, first_output[i])
                            all_tokens.append(int(first_output_i.token_id[i][0]))
                            first_output[i] = None
                            continue

                        all_tokens.append(int(prev_token.token_id[i][0]))
                        cont = len(all_tokens) < requests[i].max_len

                        if not cont:
                            output_str = self.tokenizer.decode(all_tokens)
                            requests[i] = None
                            free_slots.append(i)
                            all_tokens = []
                            settings = settings._replace(active=settings.active.at[i].set(0))
                            yield output_str

                jax.tree_map(lambda x: x.copy_to_host_async(), last_output)
                prev_token = last_output
                step += 1

def make_mesh(
    local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...]
) -> jax.sharding.Mesh:
    assert len(local_mesh_config) == 2
    assert len(between_hosts_config) == 2
    rank_logger.info("Detected %s devices in mesh", jax.device_count())
    device_mesh = mesh_utils.create_hybrid_device_mesh(
        local_mesh_config,
        between_hosts_config,
        devices=jax.devices(),
        process_is_granule=True,
    )
    rank_logger.debug(re.sub("\n+", "\n", f"Job device mesh is:\n{device_mesh}"))
    return jax.sharding.Mesh(device_mesh, ("data", "model"))
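
# For example (hypothetical topology): on a single host with 8 accelerators,
# make_mesh((2, 4), (1, 1)) yields a 2x4 device grid with axes ("data", "model").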

def sample_from_model(server, prompt, max_len, temperature):
    next(server)
    inp = Request(
        prompt=prompt,
        temperature=temperature,
        nucleus_p=1.0,
        rng_seed=42,
        max_len=max_len,
    )
    return server.send(inp)
```
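
For reviewers, a sketch of how these pieces are meant to be driven may help. The block below is illustrative and not part of the diff: `serve` and its arguments are hypothetical, the mesh shapes assume a single 8-device host, and a fully populated `LanguageModelConfig` (the upstream entry point defines the real hyperparameters) must be supplied by the caller.

```python
# Illustrative driver (not part of this PR). Assumes the code above lives in
# `runners.py`; the paths, mesh shapes, and prompt are placeholders.
from model import LanguageModelConfig
from runners import InferenceRunner, ModelRunner, sample_from_model


def serve(model_config: LanguageModelConfig, checkpoint_path: str, tokenizer_path: str) -> str:
    runner = ModelRunner(
        model=model_config,
        bs_per_device=0.125,  # fractional per-device batch; the global size is derived in initialize()
        checkpoint_path=checkpoint_path,
    )
    inference_runner = InferenceRunner(
        name="local",
        runner=runner,
        load=checkpoint_path,
        tokenizer_path=tokenizer_path,
        local_mesh_config=(1, 8),     # assumed: 8 accelerators on one host
        between_hosts_config=(1, 1),  # assumed: single-host job
    )
    inference_runner.initialize()
    gen = inference_runner.run()  # generator that accepts Request objects

    # sample_from_model() primes the generator with next() and sends one Request.
    return sample_from_model(gen, "The answer to life is", max_len=100, temperature=0.01)
```

The generator protocol is the design worth noting here: `run()` suspends on `yield` to accept a `Request`, prefills it into a free batch slot, and later yields the decoded completion, so a single compiled `sample_step` serves many prompts concurrently.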
