feat(pt): add parameter numbers output #5147
base: master
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough: Adds two private helper methods to …
Estimated code review effort: 🎯 1 (Trivial) | ⏱️ ~4 minutes
🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In @deepmd/pt/train/training.py:
- Around line 753-770: The get_descriptor_type function can raise AttributeError
because it assumes model.get_descriptor(), descriptor.serialize(),
model.atomic_model, model.atomic_model.models[0], and dp_model.descriptor are
non-None; update get_descriptor_type to guard every access: verify
model.get_descriptor exists and returns a non-None descriptor before calling
descriptor.serialize(), ensure serialize() returns a dict before indexing
"type", check model.atomic_model is not None and model.atomic_model.models is a
non-empty sequence with models[0] not None, and verify dp_model.descriptor is
non-None before calling its serialize(); keep returning "UNKNOWN" if any
intermediate value is None or serialize() isn't a dict.
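The guard chain described above can be sketched as plain duck-typed Python; the helper and the mock classes below are illustrative stand-ins, not the production code in training.py:

```python
def get_descriptor_type_safe(model) -> str:
    """Sketch of a fully guarded lookup; returns "UNKNOWN" if any step is missing."""
    get_desc = getattr(model, "get_descriptor", None)
    if callable(get_desc):
        descriptor = get_desc()
        # Descriptor must be non-None and expose serialize()
        if descriptor is not None and hasattr(descriptor, "serialize"):
            serialized = descriptor.serialize()
            # serialize() must return a dict before we index "type"
            if isinstance(serialized, dict) and "type" in serialized:
                return str(serialized["type"]).upper()
    return "UNKNOWN"


# Hypothetical stand-ins for a model and its descriptor.
class FakeDescriptor:
    def serialize(self):
        return {"type": "se_e2_a"}


class FakeModel:
    def get_descriptor(self):
        return FakeDescriptor()
```

Any object that breaks the chain (no `get_descriptor`, a `None` descriptor, a non-dict `serialize()` result) falls through to `"UNKNOWN"` instead of raising.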
- Around line 776-787: Log messages for single- and multi-task branches use
inconsistent spacing; update the log.info calls in the block using
get_descriptor_type, count_parameters, self.multi_task and self.model_keys so
they use a consistent format (e.g., "Descriptor: {desc}" and "Model Params:
{num:.3f} M" with a single space after the colon or a fixed-width label) for
both single-task and the per-model multi-task loop, ensuring the same
spacing/label style is applied to Descriptor and Model Params messages for each
model_key.
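A minimal sketch of one shared format for both branches (the helper name and label strings are assumptions, not the PR's actual wording):

```python
def format_summary_lines(desc: str, n_params: int, model_key: str = "") -> list:
    """Hypothetical helper: one fixed label style for single- and multi-task logs."""
    prefix = f"[{model_key}] " if model_key else ""
    return [
        f"{prefix}Descriptor: {desc}",
        f"{prefix}Model Params: {n_params / 1e6:.3f} M",
    ]
```

Routing every `log.info` call through one formatter guarantees identical spacing for each `model_key`.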
In @source/tests/pt/test_model_summary.py:
- Around line 106-109: The static test helper count_parameters duplicates
production logic; instead extract the parameter-counting logic into a shared
utility (e.g., a function in training.py or a new test_utils module) and have
the test import and call that implementation rather than defining its own static
method; update the test to call the imported function (referencing
count_parameters or the new utility function name) so the test exercises the
real production logic just like get_descriptor_type was refactored.
- Around line 15-33: The test duplicates production logic: extract the
duplicated methods get_descriptor_type and count_parameters out of the test and
into training.py (or a new utils module) as module-level functions, update any
internal callers to use the new module-level functions (e.g., replace
class/static usages with importable functions), then change
source/tests/pt/test_model_summary.py to import get_descriptor_type and
count_parameters from training.py (or the new utils module) and call those
functions directly, removing the duplicated static methods from the test so the
tests exercise the real implementation and avoid drift.
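The suggested refactor might look like the sketch below, with the helper promoted to module level so the test can import it instead of re-implementing it; the mock classes are hypothetical test fixtures:

```python
# training.py (sketch): module-level helper instead of a nested function.
def count_parameters(model) -> int:
    """Count trainable parameters; duck-typed so tests can pass lightweight mocks."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# test_model_summary.py (sketch): import the real implementation, e.g.
#   from deepmd.pt.train.training import count_parameters, get_descriptor_type
# and exercise it with minimal fixtures:
class MockParam:
    def __init__(self, n, trainable=True):
        self._n = n
        self.requires_grad = trainable

    def numel(self):
        return self._n


class MockModel:
    def __init__(self, params):
        self._params = params

    def parameters(self):
        return iter(self._params)
```

Because the test now calls the imported function, any later change to the production logic is covered automatically and cannot drift.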
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pt/train/training.py
- source/tests/pt/test_model_summary.py
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (2)
- deepmd/pt/model/model/dp_model.py (1): get_descriptor (52-54)
- deepmd/dpmodel/atomic_model/linear_atomic_model.py (1): serialize (302-313)
🔇 Additional comments (3)
deepmd/pt/train/training.py (1)
746-748: LGTM! The placement and rank gating are appropriate for logging the model summary once after initialization.
source/tests/pt/test_model_summary.py (2)
35-100: Good test coverage for edge cases. The test methods comprehensively cover standard models, ZBL models, empty lists, missing keys, non-dict serialization, and unknown structures. The test scenarios are well-designed.
111-140: Good test coverage for parameter counting scenarios. The tests appropriately cover all parameters trainable, mixed trainability, and all frozen. The use of the `torch.device("cpu")` context is consistent and appropriate for testing.
Pull request overview
This pull request adds functionality to log model descriptor type and parameter count information during training initialization. The changes provide better visibility into the model architecture being trained.
Changes:
- Added a new `_log_model_summary()` method to the `Trainer` class that logs descriptor type and trainable parameter count
- Implemented helper functions `get_descriptor_type()` and `count_parameters()` to extract model information
- Added comprehensive unit tests for the helper functions in a new test file
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| deepmd/pt/train/training.py | Added _log_model_summary() method with nested helper functions to log model descriptor type and parameter count during trainer initialization |
| source/tests/pt/test_model_summary.py | Added comprehensive unit tests for get_descriptor_type() and count_parameters() functions covering standard models, ZBL models, edge cases, and error conditions |
Codecov Report

❌ Patch coverage is …

Additional details and impacted files

@@ Coverage Diff @@
## master #5147 +/- ##
==========================================
- Coverage 81.94% 81.93% -0.01%
==========================================
Files 713 714 +1
Lines 73009 73315 +306
Branches 3617 3617
==========================================
+ Hits 59826 60070 +244
- Misses 12021 12082 +61
- Partials 1162 1163 +1

☔ View full report in Codecov by Sentry.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @deepmd/pt/train/training.py:
- Around line 746-749: Wrap the call to self._log_model_summary() in a non-fatal
try/except so any serialization or unexpected model-shape errors don't abort
training; keep the existing rank check (self.rank == 0), call
self._log_model_summary() inside the try, and on exception catch Exception as e
and log a warning/error via the trainer logger (e.g., self.logger.warning or
self.logger.error) including the exception text but do not re-raise.
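A minimal sketch of the suggested non-fatal wrapper, assuming a duck-typed trainer with `rank` and `_log_model_summary()` attributes; the function name is illustrative:

```python
import logging

log = logging.getLogger(__name__)


def log_model_summary_nonfatal(trainer) -> None:
    """Call the summary logger on rank 0 only; never let a failure abort training."""
    if getattr(trainer, "rank", 0) != 0:
        return
    try:
        trainer._log_model_summary()
    except Exception as e:  # non-fatal by design: log and continue
        log.warning("Could not log model summary: %s", e)
```

The exception is caught, logged, and deliberately not re-raised, so a malformed model can at worst lose one log line.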
- Around line 750-776: The descriptor detection in _log_model_summary should be
made exception-safe and more defensive: wrap calls to model.get_descriptor() and
model.serialize() in try/except so exceptions never propagate, and treat
non-dict/None returns safely; when inspecting ZBL-like serialized data, perform
a case-insensitive compare for "zbl" (e.g. lower()), ensure models_data is a
non-empty list and that models_data[0] is a dict before accessing
.get("descriptor"), and verify descriptor_data is a dict before reading its
"type"; on any unexpected shape or error return a non-fatal fallback like
"UNKNOWN" (or "UNKNOWN (with ZBL)" when ZBL is detected but descriptor missing).
🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)
778-794: Clarify/adjust parameter counting semantics for multi-task + shared params.

`count_parameters()` is fine for "trainable params in this model module", but the multi-task logs can be misleading if branches share weights (shared params are counted once per branch in the output). Consider clarifying the label (e.g., "Model Params (per-branch)") or additionally logging a "unique across all branches" total when `shared_links` is used.
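One way to compute the suggested "unique across all branches" total is to deduplicate parameters by object identity; this is a sketch with stub classes, not the deepmd-kit API:

```python
def count_unique_parameters(models) -> int:
    """Count each trainable tensor once, even when branches share weights."""
    seen = set()
    total = 0
    for model in models:
        for p in model.parameters():
            if p.requires_grad and id(p) not in seen:
                seen.add(id(p))
                total += p.numel()
    return total


# Hypothetical minimal stand-ins for a parameter tensor and a model branch.
class ParamStub:
    def __init__(self, n):
        self.requires_grad = True
        self._n = n

    def numel(self):
        return self._n


class BranchStub:
    def __init__(self, params):
        self._params = params

    def parameters(self):
        return iter(self._params)
```

With a shared 1000-element tensor in two branches, the per-branch sums would total 2030, while the identity-deduplicated count reports 1030 unique trainable elements.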
📜 Review details
Configuration used: Repository UI · Review profile: CHILL · Plan: Pro
📒 Files selected for processing (2): deepmd/pt/train/training.py, source/tests/pt/test_model_summary.py
🚧 Files skipped from review as they are similar to previous changes (1): source/tests/pt/test_model_summary.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2024-07-22T21:18:12.787Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4002
File: deepmd/dpmodel/model/model.py:106-106
Timestamp: 2024-07-22T21:18:12.787Z
Learning: The `get_class_by_type` method in `deepmd/utils/plugin.py` includes error handling for unsupported model types by raising a `RuntimeError` and providing a "Did you mean" message if there are close matches.
Applied to files:
deepmd/pt/train/training.py
📚 Learning: 2024-07-22T21:19:07.962Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4002
File: deepmd/pt/model/model/__init__.py:209-209
Timestamp: 2024-07-22T21:19:07.962Z
Learning: The `get_class_by_type` method in `deepmd/utils/plugin.py` includes error handling that raises a `RuntimeError` if the class type is unknown, along with a "Did you mean" suggestion.
Applied to files:
deepmd/pt/train/training.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (34)
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)
753-776: Add defensive checks for nested dictionary access in ZBL model handling.

The ZBL model path (lines 767-775) assumes a specific nested structure exists without validation. If `models[0]` doesn't contain a `descriptor` key, or if `descriptor` isn't a dict, this could fail silently or raise an AttributeError.

🛡️ Proposed defensive handling

```diff
 if isinstance(serialized, dict):
     model_type = str(serialized.get("type", "")).lower()
     if model_type == "zbl":
         # ZBL model: get descriptor type from the DP sub-model
         models_data = serialized.get("models", [])
-        if models_data and isinstance(models_data[0], dict):
+        if (
+            models_data
+            and len(models_data) > 0
+            and isinstance(models_data[0], dict)
+        ):
             descriptor_data = models_data[0].get("descriptor", {})
             if isinstance(descriptor_data, dict):
                 desc_type = descriptor_data.get("type", "UNKNOWN")
-                return f"{str(desc_type).upper()} (with ZBL)"
+                if desc_type != "UNKNOWN":
+                    return f"{str(desc_type).upper()} (with ZBL)"
         return "UNKNOWN (with ZBL)"
```
📜 Review details
Configuration used: Repository UI · Review profile: CHILL · Plan: Pro
📒 Files selected for processing (2): deepmd/pt/train/training.py, source/tests/pt/test_model_summary.py
🧰 Additional context used
🧠 Learnings (3)
📚 (The two 2024-07-22 learnings about `get_class_by_type` error handling, listed in the previous review, apply here as well.)
📚 Learning: 2024-09-19T04:25:12.408Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-09-19T04:25:12.408Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
Applied to files:
source/tests/pt/test_model_summary.py
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (4)
- deepmd/pt/model/model/dp_model.py (1): get_descriptor (52-54)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): serialize (168-180)
- deepmd/pt/model/atomic_model/base_atomic_model.py (1): serialize (336-347)
- deepmd/pt/model/atomic_model/linear_atomic_model.py (2): serialize (375-386), serialize (570-582)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
🔇 Additional comments (10)
deepmd/pt/train/training.py (3)
746-748: LGTM! Clean placement and correct rank guard. The invocation of `_log_model_summary()` is well-placed after the profiling settings and correctly guarded by `self.rank == 0` to ensure logging happens only once in distributed training scenarios.
778-780: LGTM! Correct parameter counting implementation. The logic correctly filters for trainable parameters using `p.requires_grad` and sums their element counts.
782-793: LGTM! Logging logic correctly handles both single-task and multi-task cases. The branching appropriately covers single-task models and multi-task models with per-model-key logging. The parameter count formatting (`.3f`, in millions) is reasonable for typical model sizes.

source/tests/pt/test_model_summary.py (7)
16-51: LGTM! Well-structured test helper methods. The helpers provide clean fixtures for creating mock trainers and models. The explicit deletion of `get_descriptor` in `_create_mock_zbl_model` (line 40) correctly simulates ZBL models that don't have this method.
54-69: LGTM! Standard model test validates expected log output. The test correctly verifies that the descriptor type is uppercased ("SE_E2_A") and that the parameter count appears in the logs.
72-86: LGTM! ZBL model test validates the special descriptor format. The test correctly validates that ZBL models log the descriptor with a "(with ZBL)" suffix, matching the implementation in the training file.
89-110: LGTM! Multi-task test validates per-model logging. The test correctly verifies that each model key and its corresponding descriptor type appear in the logs for multi-task setups.
113-132: LGTM! Unknown structure test validates the graceful fallback. The test ensures that models without valid descriptor information log "UNKNOWN" without raising exceptions.
135-153: LGTM! None descriptor test validates edge case handling. The test confirms that `get_descriptor()` returning `None` is handled gracefully, avoiding an AttributeError when accessing methods on None.
159-185: LGTM! Parameter count test validates accurate counting. The test uses a real `torch.nn.Linear` model with 55 parameters (10×5 + 5 bias) and verifies the count is correctly formatted as "0.000 M" in the logs.
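The arithmetic behind that fixture checks out: a `Linear(10, 5)` layer holds a 10×5 weight matrix plus 5 biases, and 55 parameters format as "0.000 M" at three decimals. A quick check without needing torch:

```python
# Parameter count of a Linear(in_features=10, out_features=5) layer.
in_features, out_features = 10, 5
n_weight = in_features * out_features  # 50 weight entries
n_bias = out_features                  # 5 bias entries
total = n_weight + n_bias              # 55 parameters in all

# 55 / 1e6 = 0.000055, which rounds to "0.000" at three decimal places.
formatted = f"{total / 1e6:.3f} M"
```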
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)
753-776: Consider adding exception handling for robustness.

The `serialize()` calls could potentially raise exceptions for edge-case models. While this logging is non-critical, a failure here would prevent training from starting.

💡 Optional: wrap descriptor detection in try-except

```diff
 def get_descriptor_type(model: torch.nn.Module) -> str:
     """Get the descriptor type name from model."""
-    # Standard models have get_descriptor method
-    if hasattr(model, "get_descriptor"):
-        descriptor = model.get_descriptor()
-        if descriptor is not None and hasattr(descriptor, "serialize"):
-            serialized = descriptor.serialize()
-            if isinstance(serialized, dict) and "type" in serialized:
-                return str(serialized["type"]).upper()
-    # ZBL and other models: use serialize() API
-    if hasattr(model, "serialize"):
-        serialized = model.serialize()
-        if isinstance(serialized, dict):
-            model_type = str(serialized.get("type", "")).lower()
-            if model_type == "zbl":
-                # ZBL model: get descriptor type from the DP sub-model
-                models_data = serialized.get("models", [])
-                if models_data and isinstance(models_data[0], dict):
-                    descriptor_data = models_data[0].get("descriptor", {})
-                    if isinstance(descriptor_data, dict):
-                        desc_type = descriptor_data.get("type", "UNKNOWN")
-                        return f"{str(desc_type).upper()} (with ZBL)"
-                return "UNKNOWN (with ZBL)"
-    return "UNKNOWN"
+    try:
+        # Standard models have get_descriptor method
+        if hasattr(model, "get_descriptor"):
+            descriptor = model.get_descriptor()
+            if descriptor is not None and hasattr(descriptor, "serialize"):
+                serialized = descriptor.serialize()
+                if isinstance(serialized, dict) and "type" in serialized:
+                    return str(serialized["type"]).upper()
+        # ZBL and other models: use serialize() API
+        if hasattr(model, "serialize"):
+            serialized = model.serialize()
+            if isinstance(serialized, dict):
+                model_type = str(serialized.get("type", "")).lower()
+                if model_type == "zbl":
+                    # ZBL model: get descriptor type from the DP sub-model
+                    models_data = serialized.get("models", [])
+                    if models_data and isinstance(models_data[0], dict):
+                        descriptor_data = models_data[0].get("descriptor", {})
+                        if isinstance(descriptor_data, dict):
+                            desc_type = descriptor_data.get("type", "UNKNOWN")
+                            return f"{str(desc_type).upper()} (with ZBL)"
+                    return "UNKNOWN (with ZBL)"
+    except Exception:
+        pass
+    return "UNKNOWN"
```
📜 Review details
Configuration used: Repository UI · Review profile: CHILL · Plan: Pro
📒 Files selected for processing (1): deepmd/pt/train/training.py
🧰 Additional context used
🧠 Learnings (2)
📚 (The two learnings about `get_class_by_type` error handling, listed in the previous review, apply here as well.)
🧬 Code graph analysis (1)
deepmd/pt/train/training.py (8)
- deepmd/pt/model/model/dp_model.py (1): get_descriptor (52-54)
- deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1): serialize (193-206)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): serialize (168-180)
- deepmd/pt/model/descriptor/dpa1.py (1): serialize (482-533)
- deepmd/pt/model/descriptor/dpa2.py (1): serialize (544-623)
- deepmd/pt/model/atomic_model/base_atomic_model.py (1): serialize (336-347)
- deepmd/dpmodel/atomic_model/linear_atomic_model.py (1): serialize (302-313)
- deepmd/pt/model/descriptor/se_a.py (1): serialize (378-409)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
🔇 Additional comments (2)
deepmd/pt/train/training.py (2)
746-748: LGTM! Good placement at the end of initialization, after the model is fully set up. The rank 0 guard correctly prevents duplicate logging in distributed training scenarios.
778-805: LGTM! The `count_parameters` helper uses an efficient generator-based approach. The logging output is well-formatted, with clear labeling for both single-task and multi-task scenarios. Parameter counts in millions with 3 decimal places provide good readability.