@@ -743,63 +743,27 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
         self.profiling = training_params.get("profiling", False)
         self.profiling_file = training_params.get("profiling_file", "timeline.json")

-        # Log model summary info (descriptor type and parameter count)
+        # Log model parameter count
         if self.rank == 0:
-            self._log_model_summary()
-
-    def _log_model_summary(self) -> None:
-        """Log model summary information including descriptor type and parameter count."""
-
-        def get_descriptor_type(model: torch.nn.Module) -> str:
-            """Get the descriptor type name from model."""
-            # Standard models have get_descriptor method
-            if hasattr(model, "get_descriptor"):
-                descriptor = model.get_descriptor()
-                if descriptor is not None and hasattr(descriptor, "serialize"):
-                    serialized = descriptor.serialize()
-                    if isinstance(serialized, dict) and "type" in serialized:
-                        return str(serialized["type"]).upper()
-            # ZBL and other models: use serialize() API
-            if hasattr(model, "serialize"):
-                serialized = model.serialize()
-                if isinstance(serialized, dict):
-                    model_type = str(serialized.get("type", "")).lower()
-                    if model_type == "zbl":
-                        # ZBL model: get descriptor type from the DP sub-model
-                        models_data = serialized.get("models", [])
-                        if models_data and isinstance(models_data[0], dict):
-                            descriptor_data = models_data[0].get("descriptor", {})
-                            if isinstance(descriptor_data, dict):
-                                desc_type = descriptor_data.get("type", "UNKNOWN")
-                                return f"{str(desc_type).upper()} (with ZBL)"
-                        return "UNKNOWN (with ZBL)"
-            return "UNKNOWN"
-
-        def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
-            """Count the number of trainable and total parameters.
-
-            Returns
-            -------
-            tuple[int, int]
-                (trainable_count, total_count)
-            """
-            trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
-            total = sum(p.numel() for p in model.parameters())
-            return trainable, total
+            self._log_parameter_count()

+    def _log_parameter_count(self) -> None:
+        """Log model parameter count."""
         if not self.multi_task:
-            desc_type = get_descriptor_type(self.model)
-            trainable, total = count_parameters(self.model)
-            log.info(f"Descriptor: {desc_type}")
+            trainable = sum(
+                p.numel() for p in self.model.parameters() if p.requires_grad
+            )
+            total = sum(p.numel() for p in self.model.parameters())
             log.info(
                 f"Model Params: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
             )
         else:
-            # For multi-task, log each model's info
             for model_key in self.model_keys:
-                desc_type = get_descriptor_type(self.model[model_key])
-                trainable, total = count_parameters(self.model[model_key])
-                log.info(f"Descriptor [{model_key}]: {desc_type}")
+                model = self.model[model_key]
+                trainable = sum(
+                    p.numel() for p in model.parameters() if p.requires_grad
+                )
+                total = sum(p.numel() for p in model.parameters())
                 log.info(
                     f"Model Params [{model_key}]: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
                 )
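For reference, the counting logic that the new _log_parameter_count method inlines can be exercised standalone. A minimal sketch, assuming only PyTorch is installed; the Linear stand-in model and the count_parameters helper name are hypothetical and serve only to illustrate the log-line format produced above:

import torch

def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
    # Same counting strategy as the diff: numel() summed over all
    # parameters, with requires_grad filtering for the trainable count.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

model = torch.nn.Linear(128, 64)  # hypothetical stand-in model
trainable, total = count_parameters(model)
print(f"Model Params: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)")

For a model with frozen layers (parameters with requires_grad=False), the two numbers diverge, which is why both are logged.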