58 changes: 45 additions & 13 deletions src/fairseq2/recipes/_trainer.py
@@ -35,7 +35,7 @@
from fairseq2.datasets import DataReader, DataReadError
from fairseq2.device import CPU, SupportsDeviceTransfer
from fairseq2.error import InternalError, InvalidOperationError
from fairseq2.gang import GangError, Gangs, broadcast_flag
from fairseq2.gang import GangError, Gangs, all_sum, broadcast_flag
from fairseq2.logging import log
from fairseq2.metrics import (
Mean,
@@ -504,6 +504,8 @@ def __init__(

self._last_lr = 0.0

self._multi_loss_norm = getattr(self._unit, "multi_loss_norm", False)

@override
def run(self) -> None:
if self._state != _TrainerState.NOT_STARTED:
@@ -774,19 +776,49 @@ def _do_run_step(self, progress_task: ProgressTask) -> _TrainerState:
batch.to(gangs.root.device, non_blocking=True)

with self._maybe_no_sync(batch_nr, num_batches):
with record_function(f"step_{step_nr}_{batch_nr}_forward"):
loss, num_batch_targets = self._compute_loss(batch)

# If the unit does not return the number of logit targets
# of this batch, we assume that the loss is the mean loss
# and that each batch in this step has the same number of
# logit targets. In this case, we don't need to normalize
# the gradients at the end of the step, but we still have
# to take gradient accumulation into account.
if num_batch_targets is None:
loss = loss / num_batches
if self._multi_loss_norm:
assert (
num_batches == 1
), "microbatching is not supported for multiple loss norm yet"
with record_function(f"step_{step_nr}_{batch_nr}_forward"):
loss_target_count_dict = self._compute_loss(batch)

all_losses = []
for name, (
curr_loss,
target_count,
) in loss_target_count_dict.items():
if target_count is None:
target_sum = num_batches
else:
# We do the all-sum here rather than during gradient
# scaling so that a different normalization can be
# applied to each loss component.
# TODO(lidli): double-check whether the world-size factor used in grad scaling needs to be accounted for here.
target_sum = all_sum(gangs.dp, target_count)
curr_loss = curr_loss * gangs.dp.size / target_sum
all_losses.append(curr_loss)
log.info(f"{name}_loss={curr_loss}, {target_sum=}")
self._metric_bag.get(Mean, f"{name}_after_norm").update(
curr_loss / batch.batch_size,
weight=batch.batch_size,
)
log.info(f"{all_losses=}")
loss = sum(all_losses)
else:
num_targets += num_batch_targets
with record_function(f"step_{step_nr}_{batch_nr}_forward"):
loss, num_batch_targets = self._compute_loss(batch)

# If the unit does not return the number of logit targets
# of this batch, we assume that the loss is the mean loss
# and that each batch in this step has the same number of
# logit targets. In this case, we don't need to normalize
# the gradients at the end of the step, but we still have
# to take gradient accumulation into account.
if num_batch_targets is None:
loss = loss / num_batches
else:
num_targets += num_batch_targets

with record_function(f"step_{step_nr}_{batch_nr}_backward"):
self._loss_scaler.backward(loss)
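The trainer changes above assume a new contract with the train unit: when the unit exposes multi_loss_norm = True, _compute_loss returns a dict mapping each loss name to a (loss, target_count) pair, and the trainer normalizes each component by its own all-summed target count before adding them (falling back to num_batches when no target count is given). The sketch below illustrates that normalization on a single rank; normalize_losses and the dp_size stand-in for all_sum are illustrative assumptions, not code from the diff.

from typing import Dict, Optional, Tuple

import torch
from torch import Tensor


def normalize_losses(
    loss_target_count_dict: Dict[str, Tuple[Tensor, Optional[int]]],
    dp_size: int,
    num_batches: int = 1,
) -> Tensor:
    # Scale each loss component by its own target count, as the trainer does
    # when multi_loss_norm is set.
    all_losses = []
    for name, (curr_loss, target_count) in loss_target_count_dict.items():
        if target_count is None:
            # Fall back to the gradient-accumulation count, mirroring the
            # mean-loss convention in the trainer.
            target_sum = curr_loss.new_tensor(num_batches)
        else:
            # Single-rank stand-in for all_sum(gangs.dp, target_count).
            target_sum = curr_loss.new_tensor(dp_size * target_count)
        all_losses.append(curr_loss * dp_size / target_sum)
    return sum(all_losses)


# Toy usage: two components with different target counts.
losses = {
    "grpo_loss": (torch.tensor(4.0), 8),
    "ntp_loss": (torch.tensor(6.0), 3),
}
print(normalize_losses(losses, dp_size=1))  # tensor(2.5) = 4/8 + 6/3
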
9 changes: 9 additions & 0 deletions src/fairseq2/recipes/lm/_online_finetune/_common.py
@@ -707,6 +707,15 @@ def update_grpo_loss(
metric_bag.get(Mean, "tis_imp_ratio").update(tis_imp_ratio)


@torch.inference_mode()
def update_ntp_loss(
metric_bag: MetricBag, batch: PromptBatch, ntp_loss: Tensor
) -> None:
metric_bag.get(Mean, "ntp_loss").update(
ntp_loss / batch.batch_size, weight=batch.batch_size
)


def compute_reference_logps(
gangs: Gangs,
reference_model: RemoteVllmModel,
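update_ntp_loss divides the summed loss by batch_size and passes weight=batch_size, so the Mean metric aggregates to a per-prompt average even when batches differ in size. A small arithmetic check of that weighted-mean identity (numbers are made up):

# Two batches with summed NTP losses 12.0 (4 prompts) and 3.0 (2 prompts).
loss1, bsz1 = 12.0, 4
loss2, bsz2 = 3.0, 2

weighted_mean = ((loss1 / bsz1) * bsz1 + (loss2 / bsz2) * bsz2) / (bsz1 + bsz2)

assert weighted_mean == (loss1 + loss2) / (bsz1 + bsz2) == 2.5
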
138 changes: 124 additions & 14 deletions src/fairseq2/recipes/lm/_online_finetune/_grpo.py
@@ -8,7 +8,7 @@

from copy import copy
from dataclasses import dataclass, field
from typing import Any, Dict, Final, List, Union, cast, final
from typing import Any, Dict, Final, List, Union, cast, final, Literal

import torch
from torch import Tensor
@@ -47,6 +47,7 @@
update_batch_metrics,
update_grpo_batch_metrics,
update_grpo_loss,
update_ntp_loss,
update_logit_entropy,
update_std_reward,
)
@@ -211,6 +212,38 @@ def prepare_grpo_batch(
return grpo_batch


def prepare_prompt_completion_seq_batch(
prompt_batch: PromptBatch,
prompt_key: str,
completion_key: str,
tokenizer: AutoTokenizer,
gangs: Gang,
) -> SequenceBatch:
prompt_comp_token_batch, prompt_lens = [], []
for prefix_text, completion_text in zip(
prompt_batch.meta_info.get(prompt_key),
prompt_batch.meta_info.get(completion_key),
):
prefix_text = prefix_text.removesuffix(" <think>").removesuffix("<think>")

if not (prefix_text[-1].isspace() or completion_text[0].isspace()):
prefix_text += " "
prefix_tokens = tokenizer.encode(prefix_text, add_special_tokens=False)
completion_tokens = tokenizer.encode(completion_text, add_special_tokens=False)
prompt_lens.append(len(prefix_tokens))
prompt_comp_token_batch.append(
torch.tensor(
prefix_tokens + completion_tokens,
device=gangs.dp.device,
)
)

prompt_completion_batch: SequenceBatch = collate_with_target_mask(
prompt_comp_token_batch, prompt_lens, device=gangs.dp.device
)
return prompt_completion_batch
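
For reference, prepare_prompt_completion_seq_batch assumes prompt_batch.meta_info holds parallel lists of prompt and completion strings under the configured keys. The layout below is only an illustration; the key names "prompt_text" and "answer_text" are hypothetical stand-ins for whatever prompt_key/answer_key are set to in the recipe, not values from this diff.

example_meta_info = {
    "prompt_text": [  # hypothetical prompt_key
        "Question: 2 + 2 = ? <think>",
        "Question: What is the capital of France? <think>",
    ],
    "answer_text": [  # hypothetical answer_key
        "4",
        "Paris",
    ],
}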


@final
class GrpoFinetuneUnit(TrainUnit[SequenceBatch]):
"""Represents the language model DPO-finetuning unit with online generations. Paper: https://arxiv.org/abs/2305.18290."""
@@ -250,13 +283,16 @@ def __init__(
)
)

if self._config.rollout_tokenizer is not None:
self._rollout_tokenizer = AutoTokenizer.from_pretrained(
self._config.rollout_tokenizer
)
if self._config.tokenizer is not None:
self._tokenizer = AutoTokenizer.from_pretrained(self._config.tokenizer)
self.prompt_key = self._config.prompt_key
self.completion_key = self._config.answer_key

self._display_name = "GRPO"

# This flag tells the trainer to normalize each loss component independently.
self._multi_loss_norm = self._config.loss_config.ntp_loss_weight > 0

@property
@override
def display_name(self) -> str | None:
@@ -273,6 +309,10 @@ def finalize(self, metric_bag: MetricBag) -> None:
def name(self) -> str | None:
return self._display_name

@property
def multi_loss_norm(self) -> bool:
return self._multi_loss_norm

def validate_reward(
self, prompt_batch: PromptBatch, metric_bag
) -> tuple[Tensor, int]:
@@ -357,25 +397,25 @@ def __call__(
)
if self._config.clip_rollout_after_think is not None:
prompt_batch.meta_info["suffix"] = [
self._rollout_tokenizer.decode(
self._rollout_tokenizer.encode(text, add_special_tokens=False)[
self._tokenizer.decode(
self._tokenizer.encode(text, add_special_tokens=False)[
: self._config.clip_reference
]
)
for text in prompt_batch.meta_info.get("suffix")
]
prompt_batch.meta_info["suffix_ids"] = [
self._rollout_tokenizer.encode(text, add_special_tokens=False)[
self._tokenizer.encode(text, add_special_tokens=False)[
: self._config.clip_reference
]
for text in prompt_batch.meta_info.get("suffix")
]
think_tokens = self._rollout_tokenizer.encode(
think_tokens = self._tokenizer.encode(
"</think>", add_special_tokens=False
)
rollouts = clip_outputs_after_think_token(
rollouts,
self._rollout_tokenizer,
self._tokenizer,
think_tokens,
self._config.clip_rollout_after_think,
)
@@ -536,12 +576,74 @@ def __call__(
update_std_reward(metric_bag, std_reward)
update_avg_reward(metric_bag, avg_reward)

loss = grpo_loss
# The NTP loss is computed per prompt rather than per rollout, so it is added only once, in the first microbatch of each training step, to avoid redundant computation.
if (
self._config.loss_config.ntp_loss_weight > 0
and self._rollout_bag.bag_step - 1 == 0
):
prompt_completion_seq_batch: SequenceBatch = (
prepare_prompt_completion_seq_batch(
prompt_batch,
self.prompt_key,
self.completion_key,
self._tokenizer,
self._gangs,
)
)
ntp_input_batch, ntp_target_batch = (
prompt_completion_seq_batch.as_auto_regressive()
)
ntp_input_batch_seqs, ntp_input_batch_seqs_layout = (
ntp_input_batch.as_input()
)

ntp_model_logits: Tensor = self._model.module(
ntp_input_batch_seqs, ntp_input_batch_seqs_layout
)
ntp_loss: Tensor = -self._gather_lprobs(
ntp_model_logits, ntp_target_batch
) # (bsz, s_len)

ntp_loss *= ntp_target_batch.target_mask # mask: (bsz, s_len)

if (
self._config.loss_config.group_size
/ self._config.loss_config.forward_group_size
!= 1
):
# TODO(lidli): to support the NTP loss with multiple microbatches,
# gradient scaling must be handled in the trainer code.
raise NotImplementedError(
"Microbatching is not currently supported for the NTP loss"
)

if self._config.loss_config.ntp_loss_norm == "length":
ntp_loss = (
ntp_loss.sum(-1) / ntp_target_batch.target_mask.sum(dim=-1)
).sum()
ntp_num_batch_targets: int = prompt_batch.batch_size
elif self._config.loss_config.ntp_loss_norm == "all_tokens":
ntp_loss = ntp_loss.sum()
ntp_num_batch_targets: int = ntp_target_batch.target_mask.sum().item()
elif self._config.loss_config.ntp_loss_norm == "none":
ntp_loss = ntp_loss.sum()
ntp_num_batch_targets: int = prompt_batch.batch_size
else:
raise ValueError("Invalid ntp_loss_norm value")

ntp_loss *= self._config.loss_config.ntp_loss_weight

update_ntp_loss(metric_bag, prompt_batch, ntp_loss)

if self._config.loss_config.loss_token_mean:
return loss, total_tokens
grpo_result = (grpo_loss, total_tokens)
else:
return loss, prompt_batch.batch_size
grpo_result = (grpo_loss, prompt_batch.batch_size)

if self._config.loss_config.ntp_loss_weight == 0:
return grpo_result

return {"grpo_loss": grpo_result, "ntp_loss": (ntp_loss, ntp_num_batch_targets)}

def _gather_lprobs(self, logits: Tensor, target: SequenceBatch) -> Tensor:
assert target.target_mask is not None
@@ -673,6 +775,10 @@ class GrpoLossConfig:
tis_imp_ratio_cap: float = 2.0
"""Maximum cap for the truncated importance sampling ratio. If <= 0, no cap is applied."""

ntp_loss_weight: float = 0.0
"""The weight of the auxiliary NTP loss. The NTP loss is enabled when this is greater than 0."""

ntp_loss_norm: Literal["length", "all_tokens", "none"] = "length"


@dataclass(kw_only=True)
class GrpoFinetuneConfig:
@@ -706,7 +812,11 @@ class GrpoFinetuneConfig:

clip_reference: int | None = None

rollout_tokenizer: str | None = None
tokenizer: str | None = None

prompt_key: str | None = None

answer_key: str | None = None
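
A hypothetical configuration sketch showing how the new fields fit together; the values are illustrative and assume the remaining GrpoLossConfig fields keep their defaults (this is not a tested recipe):

loss_config = GrpoLossConfig(
    ntp_loss_weight=0.1,     # > 0 enables the NTP loss and multi-loss normalization
    ntp_loss_norm="length",  # one of "length", "all_tokens", "none"
)

# The matching GrpoFinetuneConfig fields would then point at the tokenizer and
# the meta_info keys used to build the NTP batch, e.g.
# tokenizer="<hf-tokenizer-id>", prompt_key="<prompt key>", answer_key="<answer key>".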


@final
3 changes: 3 additions & 0 deletions src/fairseq2/setup/_metrics.py
@@ -78,6 +78,9 @@ def register(name: str, *args: Any, **kwargs: Any) -> None:
register("orpo_loss", "ORPO Loss", 0, format_as_float)
register("simpo_loss", "SimPO Loss", 0, format_as_float)
register("grpo_loss", "GRPO Loss", 0, format_as_float)
register("ntp_loss", "NTP Loss", 0, format_as_float)
register("ntp_loss_after_norm", "NTP Loss After Norm", 0, format_as_float)
register("grpo_loss_after_norm", "GRPO Loss After Norm", 0, format_as_float)
register("tis_imp_ratio", "Truncted Importance Sampling Ratio", 0, format_as_float)
register("avg_reward", "Reward", 1, format_as_float)
register("std_reward", "StdDev Reward", 1, format_as_float)