Skip to content

Commit 5dd8633

Browse files
committed
leave only parameters display
1 parent a9b0c0f commit 5dd8633

File tree

2 files changed

+13
-238
lines changed

2 files changed

+13
-238
lines changed

deepmd/pt/train/training.py

Lines changed: 13 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -743,63 +743,27 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
743743
self.profiling = training_params.get("profiling", False)
744744
self.profiling_file = training_params.get("profiling_file", "timeline.json")
745745

746-
# Log model summary info (descriptor type and parameter count)
746+
# Log model parameter count
747747
if self.rank == 0:
748-
self._log_model_summary()
749-
750-
def _log_model_summary(self) -> None:
751-
"""Log model summary information including descriptor type and parameter count."""
752-
753-
def get_descriptor_type(model: torch.nn.Module) -> str:
754-
"""Get the descriptor type name from model."""
755-
# Standard models have get_descriptor method
756-
if hasattr(model, "get_descriptor"):
757-
descriptor = model.get_descriptor()
758-
if descriptor is not None and hasattr(descriptor, "serialize"):
759-
serialized = descriptor.serialize()
760-
if isinstance(serialized, dict) and "type" in serialized:
761-
return str(serialized["type"]).upper()
762-
# ZBL and other models: use serialize() API
763-
if hasattr(model, "serialize"):
764-
serialized = model.serialize()
765-
if isinstance(serialized, dict):
766-
model_type = str(serialized.get("type", "")).lower()
767-
if model_type == "zbl":
768-
# ZBL model: get descriptor type from the DP sub-model
769-
models_data = serialized.get("models", [])
770-
if models_data and isinstance(models_data[0], dict):
771-
descriptor_data = models_data[0].get("descriptor", {})
772-
if isinstance(descriptor_data, dict):
773-
desc_type = descriptor_data.get("type", "UNKNOWN")
774-
return f"{str(desc_type).upper()} (with ZBL)"
775-
return "UNKNOWN (with ZBL)"
776-
return "UNKNOWN"
777-
778-
def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
779-
"""Count the number of trainable and total parameters.
780-
781-
Returns
782-
-------
783-
tuple[int, int]
784-
(trainable_count, total_count)
785-
"""
786-
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
787-
total = sum(p.numel() for p in model.parameters())
788-
return trainable, total
748+
self._log_parameter_count()
789749

750+
def _log_parameter_count(self) -> None:
751+
"""Log model parameter count."""
790752
if not self.multi_task:
791-
desc_type = get_descriptor_type(self.model)
792-
trainable, total = count_parameters(self.model)
793-
log.info(f"Descriptor: {desc_type}")
753+
trainable = sum(
754+
p.numel() for p in self.model.parameters() if p.requires_grad
755+
)
756+
total = sum(p.numel() for p in self.model.parameters())
794757
log.info(
795758
f"Model Params: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
796759
)
797760
else:
798-
# For multi-task, log each model's info
799761
for model_key in self.model_keys:
800-
desc_type = get_descriptor_type(self.model[model_key])
801-
trainable, total = count_parameters(self.model[model_key])
802-
log.info(f"Descriptor [{model_key}]: {desc_type}")
762+
model = self.model[model_key]
763+
trainable = sum(
764+
p.numel() for p in model.parameters() if p.requires_grad
765+
)
766+
total = sum(p.numel() for p in model.parameters())
803767
log.info(
804768
f"Model Params [{model_key}]: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
805769
)

source/tests/pt/test_model_summary.py

Lines changed: 0 additions & 189 deletions
This file was deleted.

0 commit comments

Comments
 (0)