
Conversation

@OutisLi (Collaborator) commented Jan 11, 2026

Summary by CodeRabbit

  • Chores
    • Enhanced runtime logging during training initialization to report total and trainable parameter counts. Includes per-model breakdowns for multi-task configurations so model-size details are visible in logs for easier diagnostics and monitoring.


Copilot AI review requested due to automatic review settings January 11, 2026 08:54
@dosubot dosubot bot added the new feature label Jan 11, 2026
coderabbitai bot (Contributor) commented Jan 11, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Walkthrough

Adds two private helper methods to Trainer: _count_parameters (static) and _log_parameter_count, and calls _log_parameter_count during initialization to log total and trainable parameter counts, including per-model counts when model_keys (multi-task) are present.

Changes

Cohort / File(s): Parameter logging in Trainer (deepmd/pt/train/training.py)
Summary: Added @staticmethod def _count_parameters(model: torch.nn.Module) -> tuple[int, int] and def _log_parameter_count(self) -> None; _log_parameter_count is invoked during Trainer init to log total and trainable parameter counts. Supports multi-task by iterating model_keys and logging per-model counts.
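For reference, a minimal sketch of what these two helpers could look like, inferred from the walkthrough above (the method names and the tuple[int, int] return follow the summary; the self.multi_task flag, the self.model[key] indexing, and the exact log wording are assumptions, not the actual implementation):

import logging

import torch

log = logging.getLogger(__name__)


class Trainer:
    @staticmethod
    def _count_parameters(model: torch.nn.Module) -> tuple[int, int]:
        """Return (total, trainable) parameter counts for a module."""
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return total, trainable

    def _log_parameter_count(self) -> None:
        """Log parameter counts; one line per model key when multi-task."""
        if getattr(self, "multi_task", False):
            for key in self.model_keys:
                total, trainable = self._count_parameters(self.model[key])
                log.info(
                    "Model %s: %.3f M total, %.3f M trainable parameters",
                    key,
                    total / 1e6,
                    trainable / 1e6,
                )
        else:
            total, trainable = self._count_parameters(self.model)
            log.info(
                "Model: %.3f M total, %.3f M trainable parameters",
                total / 1e6,
                trainable / 1e6,
            )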

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~4 minutes

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Description check: Passed (check skipped because CodeRabbit's high-level summary is enabled).
  • Title check: Passed. The PR title 'feat(pt): add parameter numbers output' accurately summarizes the main change: adding runtime logging of model parameter counts during trainer initialization.
  • Docstring coverage: Passed. Coverage is 92.31%, above the required 80.00% threshold.




📜 Recent review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9ab02e4 and 28f0d7a.

📒 Files selected for processing (1)
  • deepmd/pt/train/training.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/train/training.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)




coderabbitai bot (Contributor) left a comment


Actionable comments posted: 4

🤖 Fix all issues with AI agents
In @deepmd/pt/train/training.py:
- Around line 753-770: The get_descriptor_type function can raise AttributeError
because it assumes model.get_descriptor(), descriptor.serialize(),
model.atomic_model, model.atomic_model.models[0], and dp_model.descriptor are
non-None; update get_descriptor_type to guard every access: verify
model.get_descriptor exists and returns a non-None descriptor before calling
descriptor.serialize(), ensure serialize() returns a dict before indexing
"type", check model.atomic_model is not None and model.atomic_model.models is a
non-empty sequence with models[0] not None, and verify dp_model.descriptor is
non-None before calling its serialize(); keep returning "UNKNOWN" if any
intermediate value is None or serialize() isn't a dict.
- Around line 776-787: Log messages for single- and multi-task branches use
inconsistent spacing; update the log.info calls in the block using
get_descriptor_type, count_parameters, self.multi_task and self.model_keys so
they use a consistent format (e.g., "Descriptor: {desc}" and "Model Params:
{num:.3f} M" with a single space after the colon or a fixed-width label) for
both single-task and the per-model multi-task loop, ensuring the same
spacing/label style is applied to Descriptor and Model Params messages for each
model_key.

In @source/tests/pt/test_model_summary.py:
- Around line 106-109: The static test helper count_parameters duplicates
production logic; instead extract the parameter-counting logic into a shared
utility (e.g., a function in training.py or a new test_utils module) and have
the test import and call that implementation rather than defining its own static
method; update the test to call the imported function (referencing
count_parameters or the new utility function name) so the test exercises the
real production logic just like get_descriptor_type was refactored.
- Around line 15-33: The test duplicates production logic: extract the
duplicated methods get_descriptor_type and count_parameters out of the test and
into training.py (or a new utils module) as module-level functions, update any
internal callers to use the new module-level functions (e.g., replace
class/static usages with importable functions), then change
source/tests/pt/test_model_summary.py to import get_descriptor_type and
count_parameters from training.py (or the new utils module) and call those
functions directly, removing the duplicated static methods from the test so the
tests exercise the real implementation and avoid drift.
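A minimal sketch of this refactor, assuming the helpers are promoted to module-level functions in training.py (the import path and test shape are illustrative, not the project's actual layout):

# deepmd/pt/train/training.py -- module level, importable by tests
import torch


def count_parameters(model: torch.nn.Module) -> int:
    """Count trainable parameters of a module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# source/tests/pt/test_model_summary.py would then import the real
# implementation instead of redefining it, e.g.:
# from deepmd.pt.train.training import count_parameters, get_descriptor_type


def test_count_parameters_counts_only_trainable() -> None:
    model = torch.nn.Linear(10, 5)  # 10*5 weights + 5 biases = 55 params
    model.bias.requires_grad_(False)  # freeze the bias
    assert count_parameters(model) == 50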
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 82a5f32 and c608729.

📒 Files selected for processing (2)
  • deepmd/pt/train/training.py
  • source/tests/pt/test_model_summary.py
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (2)
deepmd/pt/model/model/dp_model.py (1)
  • get_descriptor (52-54)
deepmd/dpmodel/atomic_model/linear_atomic_model.py (1)
  • serialize (302-313)
🔇 Additional comments (3)
deepmd/pt/train/training.py (1)

746-748: LGTM!

The placement and rank gating are appropriate for logging model summary once after initialization.

source/tests/pt/test_model_summary.py (2)

35-100: Good test coverage for edge cases.

The test methods comprehensively cover standard models, ZBL models, empty lists, missing keys, non-dict serialization, and unknown structures. The test scenarios are well-designed.


111-140: Good test coverage for parameter counting scenarios.

The tests appropriately cover all parameters trainable, mixed trainability, and all frozen. The use of torch.device("cpu") context is consistent and appropriate for testing.

Copilot AI (Contributor) left a comment


Pull request overview

This pull request adds functionality to log model descriptor type and parameter count information during training initialization. The changes provide better visibility into the model architecture being trained.

Changes:

  • Added a new _log_model_summary() method to the Trainer class that logs descriptor type and trainable parameter count
  • Implemented helper functions get_descriptor_type() and count_parameters() to extract model information
  • Added comprehensive unit tests for the helper functions in a new test file

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 10 comments.

  • deepmd/pt/train/training.py: Added the _log_model_summary() method with nested helper functions to log model descriptor type and parameter count during trainer initialization.
  • source/tests/pt/test_model_summary.py: Added comprehensive unit tests for the get_descriptor_type() and count_parameters() functions covering standard models, ZBL models, edge cases, and error conditions.


codecov bot commented Jan 11, 2026

Codecov Report

❌ Patch coverage is 75.00000% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.93%. Comparing base (e5baf69) to head (28f0d7a).
⚠️ Report is 3 commits behind head on master.

Files with missing lines:
  • deepmd/pt/train/training.py: 75.00% patch coverage, 4 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5147      +/-   ##
==========================================
- Coverage   81.94%   81.93%   -0.01%     
==========================================
  Files         713      714       +1     
  Lines       73009    73315     +306     
  Branches     3617     3617              
==========================================
+ Hits        59826    60070     +244     
- Misses      12021    12082      +61     
- Partials     1162     1163       +1     

☔ View full report in Codecov by Sentry.

coderabbitai bot (Contributor) left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @deepmd/pt/train/training.py:
- Around line 746-749: Wrap the call to self._log_model_summary() in a non-fatal
try/except so any serialization or unexpected model-shape errors don't abort
training; keep the existing rank check (self.rank == 0), call
self._log_model_summary() inside the try, and on exception catch Exception as e
and log a warning/error via the trainer logger (e.g., self.logger.warning or
self.logger.error) including the exception text but do not re-raise. A sketch of
this wrapper follows after this list.
- Around line 750-776: The descriptor detection in _log_model_summary should be
made exception-safe and more defensive: wrap calls to model.get_descriptor() and
model.serialize() in try/except so exceptions never propagate, and treat
non-dict/None returns safely; when inspecting ZBL-like serialized data, perform
a case-insensitive compare for "zbl" (e.g. lower()), ensure models_data is a
non-empty list and that models_data[0] is a dict before accessing
.get("descriptor"), and verify descriptor_data is a dict before reading its
"type"; on any unexpected shape or error return a non-fatal fallback like
"UNKNOWN" (or "UNKNOWN (with ZBL)" when ZBL is detected but descriptor missing).
🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)

778-794: Clarify/adjust parameter counting semantics for multi-task + shared params.

count_parameters() is fine for “trainable params in this model module”, but the multi-task logs can be misleading if branches share weights (shared params counted once per branch in the output). Consider clarifying the label (e.g., “Model Params (per-branch)”) or additionally logging a “unique across all branches” total when shared_links is used.
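One way to produce the suggested "unique across all branches" total is to de-duplicate parameter tensors by identity before summing. A sketch, assuming the multi-task branches are reachable as a dict of modules (the dict layout is an assumption):

import torch


def count_unique_trainable(models: dict[str, torch.nn.Module]) -> int:
    """Count trainable params once, even when branches share tensors."""
    seen: dict[int, int] = {}
    for model in models.values():
        for p in model.parameters():
            if p.requires_grad:
                seen[id(p)] = p.numel()  # shared tensor -> same id, counted once
    return sum(seen.values())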

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c608729 and d6aeacc.

📒 Files selected for processing (2)
  • deepmd/pt/train/training.py
  • source/tests/pt/test_model_summary.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • source/tests/pt/test_model_summary.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2024-07-22T21:18:12.787Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4002
File: deepmd/dpmodel/model/model.py:106-106
Timestamp: 2024-07-22T21:18:12.787Z
Learning: The `get_class_by_type` method in `deepmd/utils/plugin.py` includes error handling for unsupported model types by raising a `RuntimeError` and providing a "Did you mean" message if there are close matches.

Applied to files:

  • deepmd/pt/train/training.py
📚 Learning: 2024-07-22T21:19:07.962Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4002
File: deepmd/pt/model/model/__init__.py:209-209
Timestamp: 2024-07-22T21:19:07.962Z
Learning: The `get_class_by_type` method in `deepmd/utils/plugin.py` includes error handling that raises a `RuntimeError` if the class type is unknown, along with a "Did you mean" suggestion.

Applied to files:

  • deepmd/pt/train/training.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (34)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-macosx_x86_64

coderabbitai bot (Contributor) left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)

753-776: Add defensive checks for nested dictionary access in ZBL model handling.

The ZBL model path (lines 767-775) assumes a specific nested structure exists without validation. If models[0] doesn't contain a descriptor key or if descriptor isn't a dict, this could fail silently or raise an AttributeError.

🛡️ Proposed defensive handling
                 if isinstance(serialized, dict):
                     model_type = str(serialized.get("type", "")).lower()
                     if model_type == "zbl":
                         # ZBL model: get descriptor type from the DP sub-model
                         models_data = serialized.get("models", [])
-                        if models_data and isinstance(models_data[0], dict):
+                        if (
+                            models_data
+                            and len(models_data) > 0
+                            and isinstance(models_data[0], dict)
+                        ):
                             descriptor_data = models_data[0].get("descriptor", {})
                             if isinstance(descriptor_data, dict):
                                 desc_type = descriptor_data.get("type", "UNKNOWN")
-                                return f"{str(desc_type).upper()} (with ZBL)"
+                                if desc_type != "UNKNOWN":
+                                    return f"{str(desc_type).upper()} (with ZBL)"
                         return "UNKNOWN (with ZBL)"
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d6aeacc and 50204da.

📒 Files selected for processing (2)
  • deepmd/pt/train/training.py
  • source/tests/pt/test_model_summary.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2024-07-22T21:18:12.787Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4002
File: deepmd/dpmodel/model/model.py:106-106
Timestamp: 2024-07-22T21:18:12.787Z
Learning: The `get_class_by_type` method in `deepmd/utils/plugin.py` includes error handling for unsupported model types by raising a `RuntimeError` and providing a "Did you mean" message if there are close matches.

Applied to files:

  • deepmd/pt/train/training.py
📚 Learning: 2024-07-22T21:19:07.962Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4002
File: deepmd/pt/model/model/__init__.py:209-209
Timestamp: 2024-07-22T21:19:07.962Z
Learning: The `get_class_by_type` method in `deepmd/utils/plugin.py` includes error handling that raises a `RuntimeError` if the class type is unknown, along with a "Did you mean" suggestion.

Applied to files:

  • deepmd/pt/train/training.py
📚 Learning: 2024-09-19T04:25:12.408Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-09-19T04:25:12.408Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.

Applied to files:

  • source/tests/pt/test_model_summary.py
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (4)
deepmd/pt/model/model/dp_model.py (1)
  • get_descriptor (52-54)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  • serialize (168-180)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)
  • serialize (336-347)
deepmd/pt/model/atomic_model/linear_atomic_model.py (2)
  • serialize (375-386)
  • serialize (570-582)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
🔇 Additional comments (10)
deepmd/pt/train/training.py (3)

746-748: LGTM! Clean placement and correct rank guard.

The invocation of _log_model_summary() is well-placed after profiling settings and correctly guarded by self.rank == 0 to ensure logging only happens once in distributed training scenarios.


778-780: LGTM! Correct parameter counting implementation.

The parameter counting logic correctly filters for trainable parameters using p.requires_grad and sums their element counts.


782-793: LGTM! Logging logic correctly handles both single-task and multi-task cases.

The branching logic appropriately handles single-task models and multi-task models with per-model-key logging. The parameter count formatting (.3f with millions) is reasonable for typical model sizes.

source/tests/pt/test_model_summary.py (7)

16-51: LGTM! Well-structured test helper methods.

The helper methods provide clean fixtures for creating mock trainers and models. The explicit deletion of get_descriptor in _create_mock_zbl_model (line 40) correctly simulates ZBL models that don't have this method.


54-69: LGTM! Standard model test validates expected log output.

The test correctly verifies that descriptor type is uppercased ("SE_E2_A") and parameter count appears in logs.


72-86: LGTM! ZBL model test validates special descriptor format.

The test correctly validates that ZBL models log descriptor with "(with ZBL)" suffix, matching the implementation in the training file.


89-110: LGTM! Multi-task test validates per-model logging.

The test correctly verifies that each model key and its corresponding descriptor type appear in the logs for multi-task setups.


113-132: LGTM! Unknown structure test validates graceful fallback.

The test ensures that models without valid descriptor information log "UNKNOWN" without raising exceptions.


135-153: LGTM! None descriptor test validates edge case handling.

The test confirms that get_descriptor() returning None is handled gracefully, avoiding AttributeError when trying to access methods on None.


159-185: LGTM! Parameter count test validates accurate counting.

The test uses a real torch.nn.Linear model with 55 parameters (10×5 + 5 bias) and verifies the count is correctly formatted as "0.000 M" in the logs.
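The arithmetic behind that assertion, as a quick standalone check (plain PyTorch; no project code assumed):

import torch

model = torch.nn.Linear(10, 5)
n = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(n)  # 55 = 10*5 weights + 5 biases
print(f"{n / 1e6:.3f} M")  # "0.000 M": 55 params round to zero at 3 decimals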

@OutisLi OutisLi requested review from iProzd and njzjz January 11, 2026 11:40
@OutisLi OutisLi requested a review from iProzd January 13, 2026 07:12
coderabbitai bot (Contributor) left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)

753-776: Consider adding exception handling for robustness.

The serialize() calls could potentially raise exceptions for edge-case models. While this logging is non-critical, a failure here would prevent training from starting.

💡 Optional: wrap descriptor detection in try-except
 def get_descriptor_type(model: torch.nn.Module) -> str:
     """Get the descriptor type name from model."""
-    # Standard models have get_descriptor method
-    if hasattr(model, "get_descriptor"):
-        descriptor = model.get_descriptor()
-        if descriptor is not None and hasattr(descriptor, "serialize"):
-            serialized = descriptor.serialize()
-            if isinstance(serialized, dict) and "type" in serialized:
-                return str(serialized["type"]).upper()
-    # ZBL and other models: use serialize() API
-    if hasattr(model, "serialize"):
-        serialized = model.serialize()
-        if isinstance(serialized, dict):
-            model_type = str(serialized.get("type", "")).lower()
-            if model_type == "zbl":
-                # ZBL model: get descriptor type from the DP sub-model
-                models_data = serialized.get("models", [])
-                if models_data and isinstance(models_data[0], dict):
-                    descriptor_data = models_data[0].get("descriptor", {})
-                    if isinstance(descriptor_data, dict):
-                        desc_type = descriptor_data.get("type", "UNKNOWN")
-                        return f"{str(desc_type).upper()} (with ZBL)"
-                return "UNKNOWN (with ZBL)"
-    return "UNKNOWN"
+    try:
+        # Standard models have get_descriptor method
+        if hasattr(model, "get_descriptor"):
+            descriptor = model.get_descriptor()
+            if descriptor is not None and hasattr(descriptor, "serialize"):
+                serialized = descriptor.serialize()
+                if isinstance(serialized, dict) and "type" in serialized:
+                    return str(serialized["type"]).upper()
+        # ZBL and other models: use serialize() API
+        if hasattr(model, "serialize"):
+            serialized = model.serialize()
+            if isinstance(serialized, dict):
+                model_type = str(serialized.get("type", "")).lower()
+                if model_type == "zbl":
+                    # ZBL model: get descriptor type from the DP sub-model
+                    models_data = serialized.get("models", [])
+                    if models_data and isinstance(models_data[0], dict):
+                        descriptor_data = models_data[0].get("descriptor", {})
+                        if isinstance(descriptor_data, dict):
+                            desc_type = descriptor_data.get("type", "UNKNOWN")
+                            return f"{str(desc_type).upper()} (with ZBL)"
+                    return "UNKNOWN (with ZBL)"
+    except Exception:
+        pass
+    return "UNKNOWN"
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 50204da and a9b0c0f.

📒 Files selected for processing (1)
  • deepmd/pt/train/training.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2024-07-22T21:18:12.787Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4002
File: deepmd/dpmodel/model/model.py:106-106
Timestamp: 2024-07-22T21:18:12.787Z
Learning: The `get_class_by_type` method in `deepmd/utils/plugin.py` includes error handling for unsupported model types by raising a `RuntimeError` and providing a "Did you mean" message if there are close matches.

Applied to files:

  • deepmd/pt/train/training.py
📚 Learning: 2024-10-08T15:32:11.479Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4002
File: deepmd/pt/model/model/__init__.py:209-209
Timestamp: 2024-10-08T15:32:11.479Z
Learning: The `get_class_by_type` method in `deepmd/utils/plugin.py` includes error handling that raises a `RuntimeError` if the class type is unknown, along with a "Did you mean" suggestion.

Applied to files:

  • deepmd/pt/train/training.py
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (8)
deepmd/pt/model/model/dp_model.py (1)
  • get_descriptor (52-54)
deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1)
  • serialize (193-206)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  • serialize (168-180)
deepmd/pt/model/descriptor/dpa1.py (1)
  • serialize (482-533)
deepmd/pt/model/descriptor/dpa2.py (1)
  • serialize (544-623)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)
  • serialize (336-347)
deepmd/dpmodel/atomic_model/linear_atomic_model.py (1)
  • serialize (302-313)
deepmd/pt/model/descriptor/se_a.py (1)
  • serialize (378-409)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
  • GitHub Check: Test Python (8, 3.13)
  • GitHub Check: Test Python (12, 3.13)
  • GitHub Check: Test Python (2, 3.10)
  • GitHub Check: Test Python (10, 3.10)
  • GitHub Check: Test Python (11, 3.13)
  • GitHub Check: Test Python (12, 3.10)
  • GitHub Check: Test Python (11, 3.10)
  • GitHub Check: Test Python (9, 3.13)
  • GitHub Check: Test Python (7, 3.10)
  • GitHub Check: Test Python (9, 3.10)
  • GitHub Check: Test Python (6, 3.10)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test Python (7, 3.13)
  • GitHub Check: Test Python (1, 3.10)
  • GitHub Check: Test Python (10, 3.13)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test Python (5, 3.10)
  • GitHub Check: Test Python (3, 3.13)
  • GitHub Check: Test Python (6, 3.13)
  • GitHub Check: Test Python (8, 3.10)
  • GitHub Check: Test Python (3, 3.10)
  • GitHub Check: Test Python (5, 3.13)
  • GitHub Check: Test Python (4, 3.10)
  • GitHub Check: Test Python (4, 3.13)
  • GitHub Check: Test Python (1, 3.13)
  • GitHub Check: Test Python (2, 3.13)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test C++ (true, false, false, true)
  • GitHub Check: Test C++ (false, true, true, false)
  • GitHub Check: Test C++ (false, false, false, true)
  • GitHub Check: Test C++ (true, true, true, false)
🔇 Additional comments (2)
deepmd/pt/train/training.py (2)

746-748: LGTM!

Good placement at the end of initialization after the model is fully set up. The rank 0 guard correctly prevents duplicate logging in distributed training scenarios.


778-805: LGTM!

The count_parameters helper uses an efficient generator-based approach. The logging output is well-formatted with clear labeling for both single-task and multi-task scenarios. The parameter counts in millions with 3 decimal places provide good readability.

@OutisLi OutisLi changed the title feat(pt): add descriptor name and parameter numbers output feat(pt): add parameter numbers output Jan 14, 2026
@OutisLi OutisLi requested a review from njzjz January 14, 2026 05:07
@OutisLi OutisLi requested review from iProzd and njzjz January 16, 2026 14:54