Commit 7c845f8

add ut and unify training input to WorkerLogItem
1 parent 1011b6f commit 7c845f8

4 files changed: +34 additions, -19 deletions

xtuner/v1/rl/base/__init__.py

Lines changed: 9 additions & 1 deletion

@@ -1,6 +1,13 @@
 from .controller import TrainingController, TrainingControllerProxy, TrainingStepTimeLog
 from .loss import BaseRLLossConfig, RLLossContextInputItem
-from .worker import TrainingWorker, TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig, WorkerLogItem
+from .worker import (
+    TrainingWorker,
+    TrainingWorkerClass,
+    TrainingWorkerProxy,
+    WorkerConfig,
+    WorkerInputItem,
+    WorkerLogItem,
+)
 
 
 __all__ = [
@@ -14,4 +21,5 @@
     "RLLossContextInputItem",
     "WorkerLogItem",
     "TrainingStepTimeLog",
+    "WorkerInputItem",
 ]
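
Note: the diffs in this commit type training inputs as WorkerInputItem, but its definition (in xtuner/v1/rl/base/worker.py) is not part of the changed lines. The following is a hypothetical sketch inferred from the keys used below ("seq_ctx", "shifted_labels", "advantages", "rollout_logprobs"); the field types are assumptions, not the actual definition:

# Hypothetical sketch of WorkerInputItem, inferred from usage in this commit.
# The real definition in xtuner/v1/rl/base/worker.py may differ.
from typing import Any, Optional, TypedDict

import torch


class WorkerInputItem(TypedDict):
    seq_ctx: Any                              # packed sequence context for one sample
    shifted_labels: torch.Tensor              # labels shifted for next-token loss
    advantages: torch.Tensor                  # advantage value(s) kept as a tensor, not a float
    rollout_logprobs: Optional[torch.Tensor]  # rollout-engine logprobs, or None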

xtuner/v1/rl/base/controller.py

Lines changed: 10 additions & 4 deletions

@@ -17,10 +17,10 @@
 from xtuner.v1.train.trainer import LoadCheckpointConfig
 from xtuner.v1.utils import get_logger, ray_method
 
+from .worker import TrainingWorker, WorkerInputItem, WorkerLogItem
 
-TRAIN_RAY_GET_TIMEOUT = os.getenv("XTUNER_TRAIN_RAY_GET_TIMEOUT", 5 * 3600)  # default 5 hours
 
-from .worker import TrainingWorker, WorkerInputItem, WorkerLogItem
+TRAIN_RAY_GET_TIMEOUT = os.getenv("XTUNER_TRAIN_RAY_GET_TIMEOUT", 5 * 3600)  # default 5 hours
 
 
 class TrainingStepTimeLog(TypedDict):

@@ -314,7 +314,10 @@ def _set_data_batches_properties(self, data_batches: list[WorkerInputItem]):
     def _pad_and_pack_batches(self, batch4pack: list[WorkerInputItem], pack_max_length: int) -> WorkerInputItem:
         seq_ctx_list = [item["seq_ctx"] for item in batch4pack]
         label_list = [item["shifted_labels"] for item in batch4pack]
-        advantage_list = [torch.tensor([item["advantages"]]).float().unsqueeze(0) for item in batch4pack]
+        advantage_list = []
+        for item in batch4pack:
+            advantages = item["advantages"].reshape(1, -1)
+            advantage_list.append(advantages)
         rollout_logprobs_list = [
             item["rollout_logprobs"] if self.has_rollout_logprobs else None for item in batch4pack
         ]
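
Note: the advantage change above assumes item["advantages"] now arrives as a tensor, so reshape(1, -1) replaces the old round-trip through torch.tensor([...]).float().unsqueeze(0). A standalone shape check (not xtuner code) showing that both paths agree for scalars while the new one also admits per-token advantages:

import torch

# Old path: "advantages" was a Python float, wrapped into shape (1, 1).
old = torch.tensor([0.5]).float().unsqueeze(0)
assert old.shape == (1, 1)

# New path: "advantages" is already a tensor and is flattened into one row.
new = torch.tensor(0.5).reshape(1, -1)
assert new.shape == (1, 1)

# A vector of per-token advantages also reshapes cleanly to one row.
assert torch.tensor([0.1, 0.2, 0.3]).reshape(1, -1).shape == (1, 3)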
@@ -366,6 +369,7 @@ def _pad_to_max_packs_across_workes(
             padding_item = self._create_padding_item(pack_max_length, pack_max_length)
             padding_items = [padding_item for _ in range(num_padding_packs)]
             packed_data_batches[dp_rank][step_idx].extend(padding_items)
+        return packed_data_batches
 
     @ray_method
     def fit(

@@ -428,7 +432,9 @@ def fit(
         # padding for each worker to have same number of packs in each optimizer step
         for step_idx in range(optimizer_steps):
             max_packs = max_packs_per_step[step_idx]
-            self._pad_to_max_packs_across_workes(packed_data_batches, step_idx, max_packs, pack_max_length)
+            packed_data_batches = self._pad_to_max_packs_across_workes(
+                packed_data_batches, step_idx, max_packs, pack_max_length
+            )
 
         pack_end_time = time.perf_counter()
         self.logger.info(f"Data packing took {pack_end_time - pack_start_time:.2f} seconds.")

xtuner/v1/rl/base/worker.py

Lines changed: 12 additions & 12 deletions

@@ -516,13 +516,13 @@ def _apply_rollout_is_correction(
             all_rollout_is_metrics.append(rollout_is_metrics)
             all_mismatch_metrics.append(mismatch_metrics)
 
-        worker_log_item: WorkerLogItem = {"train_entropy": 0.0, "train_metrics": [], "sft_train_metrics": {}}
-        logger_msg = f"Rollout {rollout_idx}: "
-        sum_entropy = cast(torch.Tensor, sum_entropy)
-        dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM)
-        avg_sum_entropy = sum_entropy / global_grad_tokens if global_grad_tokens > 0 else torch.tensor(0.0)
-        worker_log_item["train_entropy"] = avg_sum_entropy.item()
-        logger_msg += f"avg entropy: {avg_sum_entropy:.4f}"
+        metrics = {
+            "sum_entropy": sum_entropy,
+            "sum_rollout_entropy": sum_rollout_entropy,
+            "all_mismatch_metrics": all_mismatch_metrics,
+            "all_rollout_is_metrics": all_rollout_is_metrics,
+        }
+        return loss_ctx_input_list, metrics
 
     @ray_method
     def fit(self, data_batches: list[list[WorkerInputItem]], rollout_idx: int):
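
Note: the hunk above changes the contract of _apply_rollout_is_correction: instead of building a WorkerLogItem and logging inside the helper, it returns the raw accumulators and leaves reduction and logging to fit(). A toy stand-in for that shape (simplified, not the actual xtuner logic):

import torch

# The helper accumulates and returns raw metrics; the caller reduces them.
def apply_rollout_is_correction(batches):
    sum_entropy = torch.tensor(0.0)
    all_mismatch_metrics = []
    for batch in batches:
        sum_entropy += batch["entropy"].sum()
        all_mismatch_metrics.append({"kl": batch["kl"]})
    metrics = {
        "sum_entropy": sum_entropy,
        "all_mismatch_metrics": all_mismatch_metrics,
    }
    return batches, metrics  # loss inputs plus metrics, reduced by the caller


batches = [{"entropy": torch.tensor([0.1, 0.2]), "kl": 0.01}]
_, metrics = apply_rollout_is_correction(batches)
assert torch.isclose(metrics["sum_entropy"], torch.tensor(0.3))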
@@ -579,10 +579,7 @@ def fit(self, data_batches: list[list[WorkerInputItem]], rollout_idx: int):
         global_grad_tokens = rank_grad_tokens.clone()
         dist.all_reduce(global_grad_tokens, op=dist.ReduceOp.SUM)
 
-        worker_log_item: WorkerLogItem = {
-            "train_entropy": 0.0,
-            "train_metrics": [],
-        }
+        worker_log_item: WorkerLogItem = {"train_entropy": 0.0, "train_metrics": [], "sft_train_metrics": {}}
         log_parts = []
 
         sum_entropy = cast(torch.Tensor, metrics["sum_entropy"])

@@ -678,7 +675,10 @@ def fit(self, data_batches: list[list[WorkerInputItem]], rollout_idx: int):
                 f"{key}={value:.4f}" if isinstance(value, float) else f"{key}={value}"
                 for key, value in log_info.items()
             )
-            log_str = f"Rank{self.rank} Rollout {rollout_idx} Step {i}: gradient_accumulation_steps={num_packs_this_step}" + log_str
+            log_str = (
+                f"Rank{self.rank} Rollout {rollout_idx} Step {i}: gradient_accumulation_steps={num_packs_this_step}, "
+                + log_str
+            )
             self.logger.info(log_str)
 
         self._rollout_step += 1
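
Note: with this refactor, the all-reduce and entropy averaging that the first hunk removed from _apply_rollout_is_correction now happen in fit(), which reads metrics["sum_entropy"]. A single-process sketch of that reduction (dist.all_reduce is elided; in the real code sum_entropy and global_grad_tokens are summed across ranks first):

import torch

# Single-process sketch of the averaging moved into fit(). In the distributed
# code, dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM) runs before this.
sum_entropy = torch.tensor(16.0)  # entropy summed over this rank's grad tokens
global_grad_tokens = 64
avg_entropy = sum_entropy / global_grad_tokens if global_grad_tokens > 0 else torch.tensor(0.0)
worker_log_item = {"train_entropy": avg_entropy.item(), "train_metrics": [], "sft_train_metrics": {}}
assert worker_log_item["train_entropy"] == 0.25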

xtuner/v1/train/rl_trainer.py

Lines changed: 3 additions & 2 deletions

@@ -33,6 +33,7 @@
     TrainingWorkerClass,
     TrainingWorkerProxy,
     WorkerConfig,
+    WorkerInputItem,
     WorkerLogItem,
 )
 from xtuner.v1.rl.base import TrainingWorker as BaseTrainingWorker

@@ -774,10 +775,10 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf
             rollout_logprobs = None
 
             seq_ctx = get_train_seq_ctx(input_ids, multimodal_train_info, len(response_ids) - 1)
-            data_dict = {
+            data_dict: WorkerInputItem = {
                 "seq_ctx": seq_ctx,
                 "shifted_labels": shifted_labels,
-                "advantage": advantages[i].item(),
+                "advantages": advantages[i],
                 "rollout_logprobs": rollout_logprobs,
             }
