Skip to content

Commit 25938ed

Browse files
committed
address comments
1 parent 1983fe4 commit 25938ed

File tree

3 files changed

+64
-33
lines changed

3 files changed

+64
-33
lines changed

docs/_scripts/gen_community.py

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,43 @@
1313

1414

1515
def _rst_escape(text: str) -> str:
16-
"""Escape text for RST."""
1716
return text.replace("*", r"\*").replace("_", r"\_").replace("`", r"\`")
1817

1918

19+
def _parse_dataset_path(dataset_path: str) -> tuple:
20+
"""Parse a dataset path that may include a remote prefix.
21+
22+
Args:
23+
dataset_path: Dataset ID, optionally with remote prefix (e.g., "hf://user/dataset-id")
24+
25+
Returns:
26+
Tuple of (remote_path, dataset_id) where remote_path is None for default remote
27+
"""
28+
if "://" in dataset_path:
29+
remote_type, rest = dataset_path.split("://", maxsplit=1)
30+
remote_name, dataset_id = rest.split("/", maxsplit=1)
31+
remote_path = f"{remote_type}://{remote_name}"
32+
return remote_path, dataset_id
33+
return None, dataset_path
34+
35+
2036
def _generate_community_dataset_page(dataset_entry):
2137
"""Generate a dataset page for a community dataset."""
22-
dataset_id = dataset_entry.get("dataset_id", "")
23-
display_name = dataset_entry.get("display_name", dataset_id)
38+
dataset_path = dataset_entry.get("dataset_id", "")
39+
display_name = dataset_entry.get("display_name", dataset_path)
2440

25-
if not dataset_id:
41+
if not dataset_path:
2642
warnings.warn(f"Skipping dataset entry without dataset_id: {dataset_entry}")
2743
return
2844

2945
try:
30-
# Get metadata from remote datasets first
31-
remote_datasets = minari.list_remote_datasets(latest_version=True)
46+
# Parse the dataset path to extract remote_path if present
47+
remote_path, dataset_id = _parse_dataset_path(dataset_path)
48+
49+
# Get metadata from the appropriate remote
50+
remote_datasets = minari.list_remote_datasets(
51+
remote_path=remote_path, latest_version=True
52+
)
3253
if dataset_id not in remote_datasets:
3354
warnings.warn(f"Dataset {dataset_id} not found in remote datasets")
3455
return
@@ -59,7 +80,8 @@ def _generate_community_dataset_page(dataset_entry):
5980
content += "This environment can be recovered from the Minari dataset as follows:\n\n"
6081
content += "```python\n"
6182
content += "import minari\n"
62-
content += f"dataset = minari.load_dataset('{dataset_id}')\n"
83+
load_path = dataset_path if remote_path else dataset_id
84+
content += f"dataset = minari.load_dataset('{load_path}', download=True)\n"
6385
content += "env = dataset.recover_environment()\n"
6486
content += "```\n\n"
6587

@@ -74,16 +96,34 @@ def _generate_community_dataset_page(dataset_entry):
7496
print(f"Generated community dataset page for {dataset_id}")
7597

7698
except Exception as e:
77-
warnings.warn(f"Failed to generate page for {dataset_id}: {e}")
99+
warnings.warn(f"Failed to generate page for {dataset_path}: {e}")
100+
101+
102+
def _get_dataset_metadata(dataset_path: str) -> dict:
103+
"""Get metadata for a dataset, handling custom remotes.
104+
105+
Args:
106+
dataset_path: Dataset ID, optionally with remote prefix
107+
108+
Returns:
109+
Metadata dict or empty dict if not found
110+
"""
111+
remote_path, dataset_id = _parse_dataset_path(dataset_path)
112+
try:
113+
remote_datasets = minari.list_remote_datasets(
114+
remote_path=remote_path, latest_version=True
115+
)
116+
return remote_datasets.get(dataset_id, {})
117+
except Exception:
118+
return {}
78119

79120

80121
def generate_community_page(
81122
yaml_path=DATASET_FOLDER.joinpath("community", "community.yaml"),
82123
out_rst=DATASET_FOLDER.joinpath("community", "index.rst"),
83124
):
84125
if not os.path.exists(yaml_path):
85-
print(f"YAML file not found: {yaml_path}")
86-
return
126+
raise FileNotFoundError(f"YAML file not found: {yaml_path}")
87127

88128
with open(yaml_path) as f:
89129
community_data = yaml.safe_load(f) or []
@@ -95,27 +135,22 @@ def generate_community_page(
95135
# Generate the index page
96136
content = "Community Datasets\n"
97137
content += "==================\n\n"
98-
content += "Below is a list of datasets contributed by the community. "
99-
content += "To add yours, open a PR editing ``docs/datasets/community/community.yaml``.\n\n"
138+
content += "Below is a list of datasets contributed by the community.\n\n"
100139

101140
content += ".. raw:: html\n\n"
102141
content += ' <div class="sphx-glr-thumbnails">\n\n'
103142

104-
# Get metadata for all datasets to extract descriptions
105-
remote_datasets = minari.list_remote_datasets(latest_version=True)
106-
107143
for dataset_entry in community_data:
108-
dataset_id = dataset_entry.get("dataset_id", "")
109-
display_name = _rst_escape(dataset_entry.get("display_name", dataset_id))
144+
dataset_path = dataset_entry.get("dataset_id", "")
145+
display_name = _rst_escape(dataset_entry.get("display_name", dataset_path))
110146

111-
# Get description from metadata
112-
description = ""
113-
local_url = f"/datasets/{dataset_id}/" # Set default URL
147+
# Parse the path to get the actual dataset_id for URL
148+
_, dataset_id = _parse_dataset_path(dataset_path)
114149

115-
if dataset_id in remote_datasets:
116-
description = _rst_escape(
117-
remote_datasets[dataset_id].get("description", "")
118-
)
150+
# Get description from metadata (handles custom remotes)
151+
metadata = _get_dataset_metadata(dataset_path)
152+
description = _rst_escape(metadata.get("description", ""))
153+
local_url = f"/datasets/{dataset_id}/"
119154

120155
# Thumb card - matching tutorial style
121156
content += ".. raw:: html\n\n"

docs/_scripts/gen_dataset_md.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,8 @@ def _generate_namespace_page(namespace: str, namespace_content):
287287
file_content += "```\n"
288288

289289
namespace_file = namespace_path.joinpath("index.md")
290-
file = open(namespace_file, "w", encoding="utf-8")
291-
file.write(file_content)
292-
file.close()
290+
with open(namespace_file, "w", encoding="utf-8") as f:
291+
f.write(file_content)
293292

294293

295294
if __name__ == "__main__":

docs/datasets/community/community.yaml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,10 @@
33
# To add your dataset to this page, open a PR adding an entry below.
44
#
55
# Fields:
6-
# - dataset_id: (required) HuggingFace dataset ID in format "username/dataset-name"
6+
# - dataset_id: (required) Dataset ID. For example, for a custom HuggingFace dataset: "hf://username/dataset-name-v0"
77
# - display_name: (required) Human-readable name for the dataset
88
#
99
# Note: Description, author, and tags are automatically extracted from the dataset metadata.
1010

11-
- dataset_id: D4RL/minigrid/fourrooms-v0
12-
display_name: D4RL Minigrid Four Rooms
13-
14-
- dataset_id: mujoco/halfcheetah/expert-v0
15-
display_name: MuJoCo HalfCheetah Expert
11+
- dataset_id: hf://lukasz-sawala-goat/Farama-Antmaze-v5
12+
display_name: AntMaze-v5

0 commit comments

Comments
 (0)