Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,15 @@ research_*.json
research_*.jsonc
daemon_logs*
paper
val_results.md
cocktail_vs_separate*
cocktail_results_*
PRINCIPLE.md
TODO.md
plot_mean_reward.py
appworld_swarm_results_token
appworld_swarm_results_text
appworld_swarm_results
TODO
.trash.*
tutorial/opencode_build_aime/auto_research/results/*
13 changes: 9 additions & 4 deletions ajet/backbone/main_verl.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ def run_ppo(config: DictConfig) -> None:
if not ray.is_initialized():
# this is for local ray cluster
runtime_env = get_runtime_env(config)
ray.init(
runtime_env=runtime_env,
)
ray.init(runtime_env=runtime_env)

def on_shutdown():
if ray.is_initialized():
Expand Down Expand Up @@ -132,9 +130,16 @@ def run(self, config):

# Instantiate the tokenizer and processor.
from verl.utils import hf_processor, hf_tokenizer
from ajet.tokenizer.service import start_tokenizer_service

trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
local_tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# Cache hot tokenization calls (encode / decode / apply_chat_template)
# in a sidecar process; every other tokenizer attribute is served by
# the local instance directly.
tokenizer = start_tokenizer_service(
local_tokenizer, local_path, trust_remote_code=trust_remote_code
)
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)

Expand Down
185 changes: 173 additions & 12 deletions ajet/backbone/trainer_verl.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataP
task_id_counter[tid] += 1
else:
task_id_counter[tid] = 1
current_id = task_id_counter[tid]
gen_batch_output.non_tensor_batch['rollout_ids'][i] = f"T{tid}R{current_id}"
logger.info(f'task_id_counter: {task_id_counter}')
return gen_batch_output

Expand Down Expand Up @@ -166,6 +164,117 @@ def import_or_export_data_proto(batch: DataProto, direction: str = "export", fil
else:
raise ValueError(f"direction must be 'import' or 'export', got '{direction}'")

def compute_grpo_episode_level_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
episode_index: np.ndarray,
norm_adv_by_std_in_grpo: bool = True,
epsilon: float = 1e-6,
) -> tuple[torch.Tensor, torch.Tensor]:
"""GRPO outcome advantage with the baseline computed at *episode* scope.

Mirrors ``verl.trainer.ppo.core_algos.compute_grpo_outcome_advantage`` but,
instead of treating every sample equally when forming the per-task (``uid``)
baseline, it first reduces every episode (``episode_uuids``) to its mean
scalar reward and then computes the task baseline mean/std over those
per-episode means. This way an episode that produced many samples does not
dominate the baseline of an episode that produced few.

Example (matches the documented behaviour):
task T -> episode 1 (2 samples, reward 1) + episode 2 (1 sample, reward 0)
sample scope baseline = (1 + 1 + 0) / 3 = 0.667
episode scope baseline = (mean[1, 1] + mean[0]) / 2 = (1 + 0) / 2 = 0.5

Args:
token_level_rewards: (bsz, response_length) reward tensor.
response_mask: (bsz, response_length) mask of trainable response tokens.
index: per-sample task id (``non_tensor_batch["uid"]``).
episode_index: per-sample episode id (``non_tensor_batch["episode_uuids"]``).
norm_adv_by_std_in_grpo: divide the centred reward by the (episode-level)
group std when True, otherwise only subtract the group mean.
epsilon: numerical-stability term added to the std denominator.

Returns:
(advantages, returns) - both (bsz, response_length); identical, as in GRPO.
"""
scores = (token_level_rewards * response_mask).sum(dim=-1) # (bsz,) scalar reward
bsz = scores.shape[0]

# 1) reduce each episode to its mean scalar reward
episode_score_sum: dict = defaultdict(float)
episode_score_cnt: dict = defaultdict(int)
for i in range(bsz):
ep = episode_index[i]
episode_score_sum[ep] += scores[i].item()
episode_score_cnt[ep] += 1
episode_mean = {ep: episode_score_sum[ep] / episode_score_cnt[ep] for ep in episode_score_sum}

# 2) collect, per task, the set of distinct episodes it produced
task2episodes: dict = defaultdict(dict) # use dict as ordered set
for i in range(bsz):
task2episodes[index[i]][episode_index[i]] = None

# 3) per-task baseline = mean/std over the per-episode means.
# Single-episode tasks are degenerate -> follow verl's convention
# (mean=0, std=1) so the advantage reduces to the raw score.
task_mean: dict = {}
task_std: dict = {}
for task, episodes in task2episodes.items():
vals = torch.tensor([episode_mean[ep] for ep in episodes], dtype=torch.float32)
if vals.numel() == 1:
task_mean[task] = torch.tensor(0.0)
task_std[task] = torch.tensor(1.0)
else:
task_mean[task] = vals.mean()
task_std[task] = vals.std()

# 4) centre (and optionally normalise) every sample against its task baseline
adv = scores.clone()
for i in range(bsz):
task = index[i]
if norm_adv_by_std_in_grpo:
adv[i] = (scores[i] - task_mean[task]) / (task_std[task] + epsilon)
else:
adv[i] = scores[i] - task_mean[task]

adv = adv.unsqueeze(-1) * response_mask
return adv, adv


def compute_episode_level_loss_weight(data: DataProto) -> torch.Tensor:
"""Per-token loss weight that makes every episode contribute equally.

Each sample belonging to an episode (same ``non_tensor_batch["episode_uuids"]``)
that produced ``N`` samples receives weight ``1 / N``. The weights of all
samples of one episode therefore sum to 1, so an episode that emitted many
samples does not contribute more to the loss than one that emitted few.

The weight is broadcast across the response dimension so it has the **same
shape as ``advantages``** ((bsz, response_length)); this lets it multiply
both the per-token policy-gradient term and the per-token KL term directly.

Returns:
A (bsz, response_length) tensor (matching ``data.batch["advantages"]``
dtype/device) of per-token loss weights, constant along the response
dimension for a given sample.
"""
episode_index = data.non_tensor_batch["episode_uuids"]
bsz = len(episode_index)
episode_count: dict = defaultdict(int)
for ep in episode_index:
episode_count[ep] += 1
advantages = data.batch["advantages"] # (bsz, response_length)
per_sample = torch.tensor(
[1.0 / episode_count[episode_index[i]] for i in range(bsz)],
dtype=advantages.dtype,
device=advantages.device,
)
# broadcast per-sample weight to the same shape as advantages
weights = per_sample.view(-1, 1) * torch.ones_like(advantages)
return weights


def compute_advantage(
data: DataProto,
adv_estimator: AdvantageEstimator,
Expand All @@ -174,6 +283,7 @@ def compute_advantage(
num_repeat: int = 1,
norm_adv_by_std_in_grpo: bool = True,
config: Optional[AlgoConfig] = None,
advantage_estimation_episode_level: bool = False,
) -> DataProto:
"""Compute advantage estimates for policy optimization.

Expand All @@ -189,13 +299,21 @@ def compute_advantage(
norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in
GRPO. Defaults to True.
config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None.
advantage_estimation_episode_level (bool, optional): When True (and using the GRPO estimator),
the GRPO baseline is computed at episode scope instead of sample scope so every episode
contributes equally regardless of how many samples it produced. Defaults to False.

Returns:
DataProto: The updated data with computed advantages and returns.
"""
# Back-compatible with trainers that do not compute response mask in fit
if "response_mask" not in data.batch.keys():
data.batch["response_mask"] = compute_response_mask(data)
if advantage_estimation_episode_level and adv_estimator != AdvantageEstimator.GRPO:
raise NotImplementedError(
"ajet.trainer_common.advantage_estimation_episode_level is only "
f"supported with the GRPO advantage estimator, got {adv_estimator}."
)
# prepare response group
if adv_estimator == AdvantageEstimator.GAE:
# Compute advantages and returns using Generalized Advantage Estimation (GAE)
Expand All @@ -222,13 +340,30 @@ def compute_advantage(
response_length = grpo_calculation_mask.size(1)
# This mask is the one intended for GRPO
grpo_calculation_mask = data.batch["loss_mask"][:, -response_length:]
# Call compute_grpo_outcome_advantage with parameters matching its definition
advantages, returns = core_algos.compute_grpo_outcome_advantage(
token_level_rewards=data.batch["token_level_rewards"],
response_mask=grpo_calculation_mask,
index=data.non_tensor_batch["uid"],
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
)
if advantage_estimation_episode_level:
# Episode-scope baseline: every episode contributes equally to the
# per-task baseline regardless of how many samples it produced.
if "episode_uuids" not in data.non_tensor_batch:
raise KeyError(
"advantage_estimation_episode_level is enabled but "
"non_tensor_batch['episode_uuids'] is missing; cannot identify "
"same-episode samples."
)
advantages, returns = compute_grpo_episode_level_outcome_advantage(
token_level_rewards=data.batch["token_level_rewards"],
response_mask=grpo_calculation_mask,
index=data.non_tensor_batch["uid"],
episode_index=data.non_tensor_batch["episode_uuids"],
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
)
else:
# Call compute_grpo_outcome_advantage with parameters matching its definition
advantages, returns = core_algos.compute_grpo_outcome_advantage(
token_level_rewards=data.batch["token_level_rewards"],
response_mask=grpo_calculation_mask,
index=data.non_tensor_batch["uid"],
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
else:
Expand Down Expand Up @@ -511,15 +646,16 @@ def fit(self): # noqa: C901
]
)
)
logger.info("start fit rollout")
logger.info("start batch rollout")
self.parallel_env.current_global_steps = self.global_steps
# rollout stage begin ✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨
context_tracker_arr: List[SingleAgentContextTracker] = self.parallel_env.rollout(
tasks, mode="sample", epoch=f"train.{epoch}"
)

# from ajet import bp; bp("BATCH")

logger.info("end fit rollout")
logger.info("end batch rollout")
gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr)
logger.info("end dataproto convertion")

Expand Down Expand Up @@ -691,6 +827,13 @@ def fit(self): # noqa: C901
"norm_adv_by_std_in_grpo", True
) # GRPO adv normalization factor

# [AJET] episode-scope advantage baseline (disabled by default)
advantage_estimation_episode_level = bool(
self.config.ajet.trainer_common.get(
"advantage_estimation_episode_level", False
)
)

batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
Expand All @@ -699,8 +842,26 @@ def fit(self): # noqa: C901
num_repeat=self.config.ajet.rollout.num_repeat,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
config=self.config.algorithm,
advantage_estimation_episode_level=advantage_estimation_episode_level,
)

# [AJET] per-sample loss weight that makes every episode
# contribute equally to the policy-gradient update
# (disabled by default). Consumed in
# AjetDataParallelPPOActor.update_policy.
if bool(
self.config.ajet.trainer_common.get(
"loss_weight_normalization_episode_level", False
)
):
if "episode_uuids" not in batch.non_tensor_batch:
raise KeyError(
"loss_weight_normalization_episode_level is enabled but "
"non_tensor_batch['episode_uuids'] is missing; cannot "
"identify same-episode samples."
)
batch.batch["loss_weight"] = compute_episode_level_loss_weight(batch)

# update critic
if self.use_critic:
with marked_timer("update_critic", timing_raw, color="pink"):
Expand All @@ -710,7 +871,7 @@ def fit(self): # noqa: C901

# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
# update actor ✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨
with marked_timer("update_actor", timing_raw, color="red"):
actor_output = self._update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
Expand Down
25 changes: 25 additions & 0 deletions ajet/backbone/verl/dp_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ def update_policy(self, data: DataProto):
# Include rollout_log_probs for computing rollout_corr metrics in bypass mode
if "rollout_log_probs" in data.batch.keys():
select_keys.append("rollout_log_probs")
# [AJET] per-sample loss weight (episode-level loss normalization).
# Present only when ajet.trainer_common.loss_weight_normalization_episode_level
# is enabled; absent => every sample weighted equally (default behaviour).
if "loss_weight" in data.batch.keys():
select_keys.append("loss_weight")

has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
non_tensor_select_keys = []
Expand Down Expand Up @@ -208,6 +213,20 @@ def update_policy(self, data: DataProto):
response_mask = model_inputs["response_mask"]
old_log_prob = model_inputs["old_log_probs"]
advantages = model_inputs["advantages"]
# [AJET] Episode-level loss-weight normalization.
# When ajet.trainer_common.loss_weight_normalization_episode_level
# is enabled, every sample carries a per-token weight (1/N for
# an episode that produced N samples), same shape as advantages.
# Scaling the advantages by this positive weight scales each
# sample's policy-gradient contribution by the same factor (the
# clip/ratio behaviour is unchanged since the weight is a
# positive per-sample constant); the same weight is applied to
# the per-token KL term below, so every episode contributes
# equally to the total loss.
loss_weight = model_inputs.get("loss_weight", None)
if loss_weight is not None:
loss_weight = loss_weight.to(advantages.dtype)
advantages = advantages * loss_weight
# [AJET] Debug logging for tensor shapes
input_ids = model_inputs["input_ids"]
_shape_msg = f'[Update Policy] -> Micro batch shape, input_ids {input_ids.shape}, response {response_mask.shape} @{micro_batch_idx}/{num_micro_batches}'
Expand Down Expand Up @@ -294,6 +313,12 @@ def update_policy(self, data: DataProto):
kld = kl_penalty(
logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type
)
# [AJET] apply the per-token episode-level loss weight to
# the KL term as well (same weight/shape used for the
# policy-gradient term above), so each episode contributes
# equally to the KL loss too.
if loss_weight is not None:
kld = kld * loss_weight
kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
Expand Down
11 changes: 5 additions & 6 deletions ajet/backbone/warm_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import asyncio
import logging
import os
from datetime import datetime
from ajet.utils.async_utils import (
apply_httpx_aclose_patch,
silence_hermes_tool_parser_loggers,
Expand All @@ -14,21 +15,19 @@
apply_httpx_aclose_patch()
suppress_httpx_aclose_exception()

def init_parallel_rollout_logger(experiment_name, experiment_dir):
def init_parallel_rollout_logger(experiment_dir):
"""Initialize the logger with the given configuration."""
if "PROCESS_LEVEL_WARMUP_INIT_LOGGER" in os.environ:
return
os.environ["PROCESS_LEVEL_WARMUP_INIT_LOGGER"] = "1"

from datetime import datetime
os.environ["PROCESS_LEVEL_WARMUP_INIT_LOGGER"] = "1"

from beast_logger import register_logger

final_log_path = os.path.join(
experiment_dir,
datetime.now().strftime("%Y_%m_%d_%H_%M"),
# machine host name
os.uname().nodename,
os.uname().nodename, # machine host name
)
os.environ["BEST_LOGGER_PATH"] = final_log_path
non_console_mods = ["rollout", "token_clip", "bad_case"]
Expand Down Expand Up @@ -102,6 +101,6 @@ def warm_up_process(config):
os.environ["PROCESS_LEVEL_WARMUP_INIT"] = "1"
experiment_name = config.ajet.experiment_name
experiment_dir = config.ajet.experiment_dir
init_parallel_rollout_logger(experiment_name, experiment_dir)
init_parallel_rollout_logger(experiment_dir)
warm_up_task_judge_when_needed(config)
clean_up_tmp_ajet_dir(config)
Loading
Loading