3 changes: 3 additions & 0 deletions src/art/dev/train.py
@@ -10,6 +10,9 @@ class TrainConfig(TypedDict, total=False):
     positive advantages. Defaults to 0.0 (perfectly balanced)."""
     allow_training_without_logprobs: bool
     epsilon: float  # clip epsilon, using the same name as TRL
+    normalize_by_length: bool
+    """When True (default), divides loss by response length. Set to False for \
+    Dr. GRPO, which removes length normalization bias."""
     epsilon_high: (
         float | None
     )  # asymmetric clip upper bound. Defaults to epsilon when None
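A minimal usage sketch of the new option. This assumes the config is built as a plain `TrainConfig` dict and that the import path follows the file location in this diff (`src/art/dev/train.py`); the actual entry point for passing it to training is not shown here.

```python
# Sketch only: TrainConfig is the TypedDict extended in this PR.
from art.dev.train import TrainConfig  # import path inferred from the diff

# Dr. GRPO-style training: disable per-response length normalization.
dev_config: TrainConfig = {
    "normalize_by_length": False,  # default is True (length-normalized loss)
    "epsilon": 0.2,                # PPO-style clip range, named as in TRL
}
```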
3 changes: 3 additions & 0 deletions src/art/local/backend.py
@@ -174,6 +174,7 @@ def _get_packed_tensors(
         advantage_balance: float,
         allow_training_without_logprobs: bool,
         scale_rewards: bool,
+        normalize_by_length: bool,
         plot_tensors: bool,
     ) -> PackedTensors | None:
         if model.base_model not in self._tokenizers:
@@ -195,6 +196,7 @@
                 allow_training_without_logprobs,
                 scale_rewards,
                 image_processor=self._image_processors[model.base_model],
+                normalize_by_length=normalize_by_length,
             )
         )
         if not tokenized_results:
@@ -458,6 +460,7 @@ async def _train_model(
"allow_training_without_logprobs", False
),
scale_rewards=dev_config.get("scale_rewards", True),
normalize_by_length=dev_config.get("normalize_by_length", True),
plot_tensors=dev_config.get("plot_tensors", False),
)
if packed_tensors is None:
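The backend reads the flag from the dev config with a default of `True`, so callers that never set it keep the existing length-normalized behavior. A toy illustration of that fallback, using a plain dict rather than the real config object:

```python
# Hypothetical dev_config dict; only "normalize_by_length" is new in this PR.
dev_config = {"scale_rewards": True}

# A missing key falls back to True, preserving pre-PR (length-normalized) behavior.
assert dev_config.get("normalize_by_length", True) is True

dev_config["normalize_by_length"] = False  # opt in to Dr. GRPO-style loss
assert dev_config.get("normalize_by_length", True) is False
```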
11 changes: 8 additions & 3 deletions src/art/preprocessing/tokenize.py
@@ -51,6 +51,7 @@ def tokenize_trajectory_groups(
     scale_rewards: bool,
     shuffle_group_trajectories: bool = True,
     image_processor: BaseImageProcessor | None = None,
+    normalize_by_length: bool = True,
 ) -> Generator["TokenizedResult", None, None]:
     for group in trajectory_groups:
         if not group:
@@ -86,9 +87,13 @@
             allow_training_without_logprobs,
         ):
             trajectory_results.append(result)
-        weight = 1 / (
-            sum(sum(result.assistant_mask) for result in trajectory_results) + 1e-6
-        )
+        if normalize_by_length:
+            weight = 1 / (
+                sum(sum(result.assistant_mask) for result in trajectory_results)
+                + 1e-6
+            )
+        else:
+            weight = 1.0
         for result in trajectory_results:
             result.weight = weight
         results.extend(trajectory_results)
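A standalone sketch of the weight computation above, to make the two modes concrete. This is an illustrative re-implementation only; `TokenizedResult` and the real assistant masks live in `art.preprocessing`.

```python
# Toy re-implementation of the group weight logic above (illustration only).
def group_weight(assistant_masks: list[list[int]], normalize_by_length: bool) -> float:
    if normalize_by_length:
        # GRPO-style: every token in the group is down-weighted by the total
        # number of assistant tokens, so long responses contribute less per token.
        return 1 / (sum(sum(mask) for mask in assistant_masks) + 1e-6)
    # Dr. GRPO: constant weight; token loss is not divided by response length.
    return 1.0

masks = [[1, 1, 1, 0], [1, 1, 0, 0]]  # 5 assistant tokens across two trajectories
print(group_weight(masks, normalize_by_length=True))   # ~0.2 (i.e. 1/5)
print(group_weight(masks, normalize_by_length=False))  # 1.0
```

Note that the weight is shared by every trajectory in the group rather than computed per trajectory, so with normalization enabled a group's total token-weighted contribution sums to roughly 1 regardless of response length; with it disabled, longer groups contribute proportionally more.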