Conversation

@wwwjn (Contributor) commented Jan 13, 2026

meta-cla bot added the CLA Signed label on Jan 13, 2026
wwwjn added a commit that referenced this pull request Jan 13, 2026
ghstack-source-id: bcd9f5e
Pull Request resolved: #2221
@wwwjn changed the title from "refactor save and load model weights using DCP" to "[WIP] refactor save and load model weights using DCP" on Jan 13, 2026
wwwjn added a commit that referenced this pull request Jan 13, 2026
ghstack-source-id: b7642b4
Pull Request resolved: #2221
wwwjn added a commit that referenced this pull request Jan 14, 2026
ghstack-source-id: 87a29dc
Pull Request resolved: #2221
@wwwjn changed the title from "[WIP] refactor save and load model weights using DCP" to "[rl] refactor save and load model weights using DCP" on Jan 14, 2026
self.temp_model_dir = os.path.abspath(
    os.path.join(job_config.job.dump_folder, "vllm_temp_model")
)
# Load TorchTitan plugin at runtime
from torchtitan.experiments.rl.unified.plugin import register
@acisseJZhong (Contributor) commented Jan 14, 2026


use from torchtitan.experiments.rl.unified import register?
can we move the import statement to the header?
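For illustration, a minimal sketch of the suggested change, assuming torchtitan.experiments.rl.unified re-exports register from its package __init__ and that register is called once during setup; the surrounding class and call site are hypothetical:

```python
import os

# Suggested: import the plugin entry point at the top of the module
# (assumes the package __init__ re-exports register; otherwise keep
# the .plugin submodule path).
from torchtitan.experiments.rl.unified import register


class GeneratorWorker:  # hypothetical class name for illustration
    def __init__(self, job_config):
        # Register the TorchTitan plugin once during construction instead
        # of importing it lazily inside the method that needs it.
        register()
        self.temp_model_dir = os.path.abspath(
            os.path.join(job_config.job.dump_folder, "vllm_temp_model")
        )
```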


return self.load_weights_from_state_dict(torchtitan_state_dict)

def load_weights(self, weights_iter):

shall we load weights from weights_iter?
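A possible shape for that, as a sketch: assume weights_iter yields (name, tensor) pairs in vLLM's usual load_weights convention and funnel them into the state-dict path added in this PR; any name mapping between vLLM and TorchTitan parameter names is elided.

```python
from typing import Iterable, Tuple

import torch


def load_weights(self, weights_iter: Iterable[Tuple[str, torch.Tensor]]):
    # Materialize the (name, tensor) pairs into a plain state dict and
    # delegate to the in-place loader introduced in this PR. Sketch only;
    # real code may need to translate vLLM names to TorchTitan names.
    state_dict = dict(weights_iter)
    return self.load_weights_from_state_dict(state_dict)
```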

@fegin (Contributor) commented Jan 14, 2026

I'm wondering whether we should refactor the TorchTitan checkpointer so that it can be used directly in this case. While the current PR works, if TorchTitan migrates to a new checkpoint library, other use cases will need the same updates as well. This is more future work, not blocking this PR.
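For context, the DCP round trip this PR builds on can be sketched roughly as below; the paths, the "model" key, and the trainer_model/generator_model names are illustrative, and the real Checkpointer wraps additional state (optimizer, dataloader, etc.):

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)

# Trainer side: each rank saves its shards of the model weights.
trainer_sd = {"model": get_model_state_dict(trainer_model)}
dcp.save(trainer_sd, checkpoint_id="/tmp/vllm_temp_model")  # path illustrative

# Generator side: load into the generator's own sharding; DCP reshards
# the saved tensors to match the layout of the destination state dict.
generator_sd = {"model": get_model_state_dict(generator_model)}
dcp.load(generator_sd, checkpoint_id="/tmp/vllm_temp_model")
set_model_state_dict(generator_model, generator_sd["model"])
```

Folding this behind the existing Checkpointer, as suggested, would keep the on-disk format and save/load logic in one place if the backing library ever changes.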

return logits

def load_weights(self, weights_iter):
def load_weights_from_state_dict(self, titan_state_dict):

titan_state_dict is ambiguous -- both sides should be titan models.
What other name could we use, e.g. trainer_state_dict?

# We need to split our weights to match the original 2-shard structure
import glob
# directly update model weights in place
load_weights = self._get_model().load_weights_from_state_dict(state_dict)

  • IIUC this only works when the trainer and the generator are on exactly the same global mesh. Is that right?
  • Has this been the assumption before this PR as well? I.e., does our current Monarch script only allow a colocated trainer and generator?
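To make the assumption concrete, a rough sketch of what a direct in-place update amounts to (names are illustrative): each local parameter on the generator is overwritten with the corresponding trainer tensor, which only succeeds if both sides produce identically shaped local shards, i.e. the same global mesh and sharding.

```python
import torch


def load_weights_from_state_dict(self, trainer_state_dict):
    # Sketch of the in-place update being discussed: copy each trainer
    # tensor into the matching generator parameter. copy_ requires the
    # local shapes to match, so trainer and generator must shard the
    # model identically (same global mesh) for this to work.
    model = self._get_model()
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in trainer_state_dict:
                src = trainer_state_dict[name]
                param.copy_(src.to(device=param.device, dtype=param.dtype))
```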


Labels

ciflow/8gpu, CLA Signed

5 participants