@@ -275,37 +275,19 @@ def compute_logits(
 
         return logits
 
-    def load_weights(self, weights_iter):
+    def load_weights_from_state_dict(self, titan_state_dict):
         """
-        Load weights from HF checkpoint using the provided state dict adapter.
-        vLLM engine would call this function to load model weights.
-
-        Args:
-            weights_iter: Iterator of (name, tensor) pairs from HF checkpoint
-
-        Returns:
-            Set of loaded parameter names
+        Load model weights directly from a TorchTitan-format state dict.
         """
-        # Collect weights from iterator
-        hf_state_dict = {}
-        for name, tensor in weights_iter:
-            hf_state_dict[name] = tensor
-
-        # Use adapter to convert HF → TorchTitan format
-        adapter = self.state_dict_adapter(
-            model_args=self.config,
-            hf_assets_path=None,
-        )
 
-        torchtitan_state_dict = adapter.from_hf(hf_state_dict)
         model_state_dict = {k: v for k, v in self.model.state_dict().items()}
 
         # Convert to DTensor if target is DTensor
-        for name, tensor in torchtitan_state_dict.items():
+        for name, tensor in titan_state_dict.items():
             if name in model_state_dict and isinstance(model_state_dict[name], DTensor):
                 target_dtensor = model_state_dict[name]
                 device_mesh = target_dtensor.device_mesh
-                torchtitan_state_dict[name] = DTensor.from_local(
+                titan_state_dict[name] = DTensor.from_local(
                     tensor.to(device_mesh.device_type),
                     device_mesh=device_mesh,
                     placements=[Replicate()],
@@ -314,10 +296,37 @@ def load_weights(self, weights_iter):
         # Load state dict
         set_model_state_dict(
             model=self.model,
-            model_state_dict=torchtitan_state_dict,
+            model_state_dict=titan_state_dict,
             options=StateDictOptions(strict=False),
         )
 
-        loaded_params = {f"model.{name}" for name in torchtitan_state_dict.keys()}
+        loaded_params = titan_state_dict.keys()
 
         return loaded_params
+
+    def load_weights(self, weights_iter):
+        """
+        vLLM weight-loading hook. The engine calls this to load weights from
+        the HF checkpoint, but here the weights are already loaded at init.
+
+        Args:
+            weights_iter: Iterator of (name, tensor) pairs from HF checkpoint (unused)
+
+        Returns:
+            Set of loaded parameter names
+        """
+
+        # Since our model weights are already loaded during initialization,
+        # we need to return the names of all parameters that have been loaded
+        # so vLLM's safety check passes.
+        loaded_param_names = set()
+        for name, _ in self.model.named_parameters():
+            loaded_param_names.add(name)
+
+        logger.info(
+            "Weights already loaded during model initialization. Returning "
+            f"{len(loaded_param_names)} loaded parameter names to satisfy vLLM safety check."
+        )
+
+        # Return the names of all loaded parameters so vLLM knows they were handled
+        return loaded_param_names
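
For context, a minimal caller-side sketch of how the split is meant to be used, assuming the HF → TorchTitan conversion now happens outside the vLLM load_weights hook. The load_hf_checkpoint_into_model helper and the wrapper argument are hypothetical; the adapter calls mirror the code removed from load_weights above.

    def load_hf_checkpoint_into_model(wrapper, weights_iter):
        # Collect (name, tensor) pairs from the HF checkpoint iterator.
        hf_state_dict = {name: tensor for name, tensor in weights_iter}

        # Convert HF → TorchTitan format, as the old load_weights body did
        # (adapter construction copied from the removed code above).
        adapter = wrapper.state_dict_adapter(
            model_args=wrapper.config,
            hf_assets_path=None,
        )
        titan_state_dict = adapter.from_hf(hf_state_dict)

        # Push the converted weights into the model; load_weights_from_state_dict
        # wraps plain tensors into DTensors where the target parameter is a DTensor.
        return wrapper.load_weights_from_state_dict(titan_state_dict)

When vLLM later calls wrapper.load_weights(...), the model only reports the already-loaded parameter names so the engine's safety check passes.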