[rl] refactor save and load model weights using DCP #2221
base: gh/wwwjn/6/base
Conversation
self.temp_model_dir = os.path.abspath(
    os.path.join(job_config.job.dump_folder, "vllm_temp_model")

# Load TorchTitan plugin at runtime
from torchtitan.experiments.rl.unified.plugin import register
Use `from torchtitan.experiments.rl.unified import register`?
Can we move the import statement to the top of the file?
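For illustration only, a sketch of the suggested change; whether the `torchtitan.experiments.rl.unified` package re-exports `register` (versus the `.plugin` submodule) is an assumption taken from the comment above:

```python
# Hedged sketch: hoist the plugin import to the top of the file instead of
# importing inside the method body. The exact import path is an assumption.
from torchtitan.experiments.rl.unified import register
```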
return self.load_weights_from_state_dict(torchtitan_state_dict)
def load_weights(self, weights_iter):
Shall we load the weights from `weights_iter`?
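A minimal sketch of what consuming the iterator could look like, assuming `weights_iter` yields `(name, tensor)` pairs and reusing the state-dict path shown elsewhere in this diff; this is an illustration, not the PR's actual implementation:

```python
from typing import Iterable, Tuple

import torch


def load_weights(self, weights_iter: Iterable[Tuple[str, torch.Tensor]]):
    # Materialize the incoming (name, tensor) pairs into a state dict and
    # delegate to the existing state-dict loading path.
    state_dict = {name: tensor for name, tensor in weights_iter}
    return self.load_weights_from_state_dict(state_dict)
```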
I'm wondering whether we should refactor the TorchTitan checkpointer so that it can be used directly in this case. While the current PR works, if TorchTitan migrates to a new checkpoint library, other use cases will need the same updates as well. This is more future work, not blocking this PR.
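One way to read that suggestion, as a hypothetical sketch rather than a proposal for this PR: hide the checkpoint backend behind a small interface so that a future migration to a different checkpoint library only touches one adapter. The names below are invented for illustration:

```python
from typing import Any, Dict, Protocol


class WeightCheckpointer(Protocol):
    """Narrow interface the RL code would depend on; a TorchTitan-specific
    adapter (or any future backend) would implement it."""

    def save(self, state_dict: Dict[str, Any], path: str) -> None: ...

    def load_into(self, state_dict: Dict[str, Any], path: str) -> None: ...
```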
return logits

def load_weights(self, weights_iter):
def load_weights_from_state_dict(self, titan_state_dict):
titan_state_dict is ambiguous -- both sides should be titan models.
What other name could we use, e.g. trainer_state_dict?
# We need to split our weights to match the original 2-shard structure
import glob

# directly update model weights in place
load_weights = self._get_model().load_weights_from_state_dict(state_dict)
- IIUC this only works when the trainer and the generator are on exactly the same global mesh. Is that right?
- Has this been the assumption before this PR as well? I.e., does our current monarch script only allow a colocated trainer and generator?
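For context on the mesh question: one argument for routing the handoff through DCP is that `torch.distributed.checkpoint` reshards on load to whatever layout the reading side's state dict has, so in principle the trainer and generator would not need identical meshes. A minimal sketch under that assumption; the directory name and function names are hypothetical, not from this PR:

```python
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)

CKPT_DIR = "outputs/rl_weight_sync"  # hypothetical staging directory


def push_trainer_weights(trainer_model: nn.Module) -> None:
    # Trainer side: save its (possibly sharded) weights with DCP.
    dcp.save(get_model_state_dict(trainer_model), checkpoint_id=CKPT_DIR)


def pull_into_generator(generator_model: nn.Module) -> None:
    # Generator side: load into its own sharding layout; DCP reshards on load.
    state_dict = get_model_state_dict(generator_model)
    dcp.load(state_dict, checkpoint_id=CKPT_DIR)
    set_model_state_dict(generator_model, state_dict)
```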
Stack from ghstack (oldest at bottom):