80 changes: 12 additions & 68 deletions torchtitan/experiments/rl/unified/actors/generator.py
@@ -8,7 +8,6 @@
import logging
import os

from dataclasses import dataclass
from typing import List

import torch
@@ -19,43 +18,15 @@
# Import unified module - this automatically registers TorchTitan models with vLLM
from torchtitan.experiments.rl import unified # noqa: F401

from torchtitan.experiments.rl.unified.actors.scorer import TrajectoryData
from torchtitan.experiments.rl.unified.job_config import JobConfig

from torchtitan.experiments.rl.vllm_compat.simple_rl import (
compute_grpo_advantages,
compute_grpo_advantages_stable,
trivial_reward_function,
)
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.sampling_params import RequestOutputKind

logger = logging.getLogger(__name__)


@dataclass
class TrajectoryData:
"""
Data from one generation batch.

Attributes:
policy_version: Version of policy that produced this batch
completions: List of completion strings
vllm_token_ids: List of token ID lists for each completion
vllm_token_log_probs: List of per-token log prob lists
prompt_token_ids: List of prompt token ID lists
rewards: Computed rewards for each completion
advantages: Computed advantages for each completion
"""

policy_version: int
completions: List[str]
vllm_token_ids: List[List[int]]
vllm_token_log_probs: List[List[float]]
prompt_token_ids: List[List[int]]
rewards: torch.Tensor
advantages: torch.Tensor


class VLLMGenerator:
"""
vLLM engine for fast rollouts with weight updates.
@@ -258,7 +229,7 @@ class Generator(Actor):

Maintains a vLLM engine that is synchronized with the Trainer
via weight sync. Generates completions for given prompts and
computes rewards/advantages.
returns unscored trajectory data for the Scorer to process.

Args:
job_config: JobConfig dataclass containing all configuration
@@ -284,8 +255,6 @@ def __init__(
self.max_new_tokens = job_config.generation.sampling.max_tokens
self.temperature = job_config.generation.sampling.temperature
self.group_size = job_config.rl.grpo_group_size
self.grpo_beta = job_config.rl.grpo_beta
self.use_stable_grpo = job_config.rl.use_stable_grpo

# Initialize distributed environment for SPMD generator
world_size = dist_utils.init_distributed(
@@ -299,14 +268,15 @@ def __init__(
self.cond = asyncio.Condition()
self.policy_version = 0

# Reward function. TODO: Use a real reward function
self.reward_fn = trivial_reward_function

logger.info("Generator initialized with vLLM engine")

@endpoint
async def generate(self) -> None:
"""Generate trajectories and compute rewards/advantages."""
async def generate(self) -> TrajectoryData:
"""Generate trajectories without computing rewards/advantages.

Returns:
TrajectoryData for the Scorer to process (rewards initialized to zeros)
"""
logger.info(
f"{os.getpid()=} Generating start generate (policy v{self.policy_version})..."
)
@@ -330,43 +300,17 @@ async def generate(self) -> None:
n_samples_per_prompt=self.group_size,
)

# Compute rewards
logger.info(
f"Computing rewards: {len(completions)} completions, "
f"{len(self.expected_answers)} expected answers, "
f"group_size={self.group_size}"
)
rewards = self.reward_fn(
completions, self.expected_answers, self.group_size
)
logger.info(f"Generated {len(completions)} completions for scoring")

# Normalize rewards
reward_mean = rewards.mean()
reward_std = rewards.std()
if reward_std > 1e-8:
rewards_normalized = (rewards - reward_mean) / reward_std
else:
rewards_normalized = rewards - reward_mean

# Compute advantages using GRPO
if self.use_stable_grpo:
advantages = compute_grpo_advantages_stable(
rewards_normalized, self.group_size
)
else:
advantages = compute_grpo_advantages(
rewards_normalized, self.group_size, beta=self.grpo_beta
)

# Create trajectory data
# Create trajectory data (rewards initialized to zeros, filled by Scorer)
trajectory = TrajectoryData(
policy_version=self.policy_version,
completions=completions,
vllm_token_ids=vllm_token_ids,
vllm_token_log_probs=vllm_token_log_probs,
prompt_token_ids=prompt_token_ids,
rewards=rewards,
advantages=advantages,
expected_answers=self.expected_answers,
rewards=torch.zeros(len(completions)),
)

# Signal ready for update
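For illustration, a minimal sketch of the unscored batch the refactored `generate()` now hands off. The field names and the import path come from this diff; the toy token ids, log probs, and answers are made up, and the one-expected-answer-per-prompt layout is an assumption:

```python
import torch

# TrajectoryData now lives in scorer.py (see the new file below).
from torchtitan.experiments.rl.unified.actors.scorer import TrajectoryData

# Toy unscored batch: 1 prompt, group_size=2 samples per prompt.
# Rewards start as zeros and are filled in later by the Scorer.
unscored = TrajectoryData(
    policy_version=0,
    completions=["The answer is 4.", "It is five."],
    vllm_token_ids=[[791, 19], [2181, 20]],              # toy token ids
    vllm_token_log_probs=[[-0.11, -0.42], [-0.35, -1.2]],
    prompt_token_ids=[[3923, 374], [3923, 374]],
    expected_answers=["4"],  # assumption: one expected answer per prompt
    rewards=torch.zeros(2),
)
```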
114 changes: 114 additions & 0 deletions torchtitan/experiments/rl/unified/actors/scorer.py
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Callable, List, Optional

import torch
from monarch.actor import Actor, endpoint
from torchtitan.experiments.rl.unified.job_config import JobConfig
from torchtitan.experiments.rl.vllm_compat.simple_rl import trivial_reward_function

logger = logging.getLogger(__name__)


@dataclass
class TrajectoryData:
Review comment (Contributor): I thought we deprecated the name "trajectory", which is intrinsically ambiguous, but I don't know what we replaced it with. Episode?

"""
Data from one generation batch.

Attributes:
policy_version: Version of policy that produced this batch
completions: List of completion strings
vllm_token_ids: List of token ID lists for each completion
vllm_token_log_probs: List of per-token log prob lists
prompt_token_ids: List of prompt token ID lists
expected_answers: List of expected answers for reward computation
rewards: Rewards for each completion (initialized to zeros, filled by Scorer)
"""

policy_version: int
completions: List[str]
vllm_token_ids: List[List[int]]
vllm_token_log_probs: List[List[float]]
prompt_token_ids: List[List[int]]
expected_answers: List[str]
rewards: torch.Tensor


class Scorer(Actor):
Review comment (Contributor): I thought you chose to use Grader. Not sure what the difference is, but aligned.

"""
Evaluates completions and assigns rewards to trajectory data.

The Scorer receives trajectory data from the Generator
and computes rewards using a reward function. Advantage computation
is done by the Trainer.

Args:
job_config: JobConfig dataclass containing all configuration
reward_fn: Optional custom reward function. If not provided,
uses trivial_reward_function from simple_rl.
"""

def __init__(
self,
job_config: JobConfig,
reward_fn: Optional[Callable] = None,
):
# Extract needed fields from job_config
self.group_size = job_config.rl.grpo_group_size

# Set reward function
self.reward_fn = reward_fn if reward_fn is not None else trivial_reward_function

logger.info(f"Scorer initialized with group_size={self.group_size}")

@endpoint
async def score(self, trajectory: TrajectoryData) -> TrajectoryData:
"""
Score a trajectory by computing rewards.

Args:
trajectory: Trajectory data (with or without rewards)

Returns:
TrajectoryData with computed rewards
"""
logger.info(
f"Scorer scoring trajectory (policy v{trajectory.policy_version})..."
)

# Compute rewards using reward function
rewards = self.reward_fn(
trajectory.completions,
trajectory.expected_answers,
self.group_size,
)

reward_mean = rewards.mean()
reward_std = rewards.std()

# Update trajectory with rewards
trajectory.rewards = rewards

logger.info(
f"Scorer finished scoring: "
f"reward_mean={reward_mean.item():.4f}, reward_std={reward_std.item():.4f}"
)

return trajectory

@endpoint
async def set_reward_fn(self, reward_fn: Callable) -> None:
"""
Update the reward function.

Args:
reward_fn: New reward function to use
"""
self.reward_fn = reward_fn
logger.info("Scorer reward function updated")