
Commit ab0872a

refactor save and load model weights using DCP
ghstack-source-id: b7642b4
Pull Request resolved: #2221
1 parent 8a35268


3 files changed: +38 −34 lines
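
The "DCP" in the title is torch.distributed.checkpoint. Rather than hand-converting parameter names between TorchTitan and vLLM layouts, the diffs below lean on DCP's model state-dict helpers. A minimal sketch of that API, with two placeholder nn.Modules standing in for the trainer and generator models (names and shapes are ours, not the repo's):

import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)

# Placeholder modules standing in for the trainer and generator models.
trainer_model = nn.Linear(8, 8)
generator_model = nn.Linear(8, 8)

# Extract a full state dict from one module and load it into the other;
# for DTensor-based models, DCP resolves device meshes and placements.
weights = get_model_state_dict(trainer_model)
set_model_state_dict(
    model=generator_model,
    model_state_dict=weights,
    options=StateDictOptions(strict=False),
)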


torchtitan/experiments/rl/unified/actors/generator.py

Lines changed: 2 additions & 2 deletions
@@ -57,7 +57,7 @@ class TrajectoryData:
     advantages: torch.Tensor


-class VLLMRolloutEngine:
+class VLLMGenerator:
     """
     vLLM engine for fast rollouts with weight updates.

@@ -355,7 +355,7 @@ def __init__(
             Comm(),
         )
         # Initialize vLLM engine with job_config
-        self.vllm_engine = VLLMRolloutEngine(job_config, self.model_path)
+        self.vllm_engine = VLLMGenerator(job_config, self.model_path)

         # State machine
         self.state = GeneratorState.READY_TO_UPDATE

torchtitan/experiments/rl/unified/actors/trainer.py

Lines changed: 3 additions & 8 deletions
@@ -19,7 +19,6 @@
 from torchtitan.experiments.rl.vllm_compat.simple_rl import (
     compute_policy_gradient_loss_vllm,
 )
-from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm

 logger = logging.getLogger(__name__)

@@ -52,7 +51,6 @@ def __init__(

         # load trainer model and patch to vllm.Attention()
         self.model = load_trainer_model(model_path)
-
         self.parallel_dims = create_trainer_parallel_dims(self.ddp_size, self.tp_size)

         # apply PT-D Parallelism
@@ -77,14 +75,13 @@ def __init__(

     @endpoint
     async def get_weights(self) -> dict:
-        """Get vLLM weights for generator.
+        """Get model weights for generator.

         Returns:
-            vLLM state dict
+            model state dict
         """
         titan_state = self.model.state_dict()
-        vllm_state = torchtitan_to_vllm(titan_state)
-        return vllm_state
+        return titan_state

     @endpoint
     async def step(self, trajectory: TrajectoryData) -> dict:
@@ -114,8 +111,6 @@ async def step(self, trajectory: TrajectoryData) -> dict:

         self.policy_version += 1

-        # TODO: save dcp checkpoint to file here instead of sending weight dicts
-
         # Return metrics
         metrics = {
             "loss": loss.item(),

torchtitan/experiments/rl/unified/models/vllm_wrapper.py

Lines changed: 33 additions & 24 deletions
@@ -275,37 +275,19 @@ def compute_logits(

         return logits

-    def load_weights(self, weights_iter):
+    def load_weights_from_state_dict(self, titan_state_dict):
         """
-        Load weights from HF checkpoint using the provided state dict adapter.
-        vLLM engine would call this function to load model weights.
-
-        Args:
-            weights_iter: Iterator of (name, tensor) pairs from HF checkpoint
-
-        Returns:
-            Set of loaded parameter names
+        Load model weights directly from a TorchTitan-format state dict.
         """
-        # Collect weights from iterator
-        hf_state_dict = {}
-        for name, tensor in weights_iter:
-            hf_state_dict[name] = tensor
-
-        # Use adapter to convert HF → TorchTitan format
-        adapter = self.state_dict_adapter(
-            model_args=self.config,
-            hf_assets_path=None,
-        )

-        torchtitan_state_dict = adapter.from_hf(hf_state_dict)
         model_state_dict = {k: v for k, v in self.model.state_dict().items()}

         # Convert to DTensor if target is DTensor
-        for name, tensor in torchtitan_state_dict.items():
+        for name, tensor in titan_state_dict.items():
             if name in model_state_dict and isinstance(model_state_dict[name], DTensor):
                 target_dtensor = model_state_dict[name]
                 device_mesh = target_dtensor.device_mesh
-                torchtitan_state_dict[name] = DTensor.from_local(
+                titan_state_dict[name] = DTensor.from_local(
                     tensor.to(device_mesh.device_type),
                     device_mesh=device_mesh,
                     placements=[Replicate()],
@@ -314,10 +296,37 @@ def load_weights(self, weights_iter):
         # Load state dict
         set_model_state_dict(
             model=self.model,
-            model_state_dict=torchtitan_state_dict,
+            model_state_dict=titan_state_dict,
             options=StateDictOptions(strict=False),
         )

-        loaded_params = {f"model.{name}" for name in torchtitan_state_dict.keys()}
+        loaded_params = titan_state_dict.keys()

         return loaded_params
+
+    def load_weights(self, weights_iter):
+        """
+        Load weights from HF checkpoint using the provided state dict adapter.
+        vLLM engine would call this function to load model weights.
+
+        Args:
+            weights_iter: Iterator of (name, tensor) pairs from HF checkpoint
+
+        Returns:
+            Set of loaded parameter names
+        """
+        # Since our model weights are already loaded during initialization,
+        # we need to return the names of all parameters that have been loaded
+        # so vLLM's safety check passes.
+        loaded_param_names = set()
+        for name, _ in self.model.named_parameters():
+            loaded_param_names.add(name)
+
+        logger.info(
+            f"Weights already loaded during model initialization. "
+            f"Returning {len(loaded_param_names)} loaded parameter names to satisfy vLLM safety check."
+        )
+
+        # Return the names of all loaded parameters so vLLM knows they were handled
+        return loaded_param_names
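
Taken together, the new load_weights_from_state_dict reduces to: wrap each incoming plain tensor as a replicated DTensor on the target parameter's device mesh, then load non-strictly. A consolidated, standalone sketch of that pattern (the free function and its name are ours; PyTorch 2.4+ assumed for torch.distributed.tensor):

import torch
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)

def load_plain_state_dict(model: torch.nn.Module, incoming: dict) -> set:
    """Wrap plain tensors as replicated DTensors wherever the matching target
    parameter is a DTensor, then load without requiring every key to match."""
    targets = dict(model.state_dict())
    for name, tensor in incoming.items():
        target = targets.get(name)
        if isinstance(target, DTensor):
            incoming[name] = DTensor.from_local(
                tensor.to(target.device_mesh.device_type),
                device_mesh=target.device_mesh,
                placements=[Replicate()],
            )
    set_model_state_dict(
        model=model,
        model_state_dict=incoming,
        options=StateDictOptions(strict=False),
    )
    return set(incoming.keys())

The companion load_weights stub, by contrast, loads nothing: vLLM's loader insists that every parameter be accounted for, so returning all parameter names satisfies that safety check for weights that were already placed at initialization.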
