Skip to content

Commit b1c77f6

Browse files
yiyixuxugithub-actions[bot]sayakpaul
authored
[modular] add auto_docstring & more doc related refactors (#12958)
* up * up up * update outputs * style * add modular_auto_docstring! * more auto docstring * style * up up up * more more * up * address feedback * add TODO in the description for empty docstring * refactor based on dhruv's feedback: remove the class method * add template method * up * up up up * apply auto docstring * make style * remove space in make docstring * Apply suggestions from code review * revert change in z * fix * Apply style fixes * include auto-docstring check in the modular ci. (#13004) * Run ruff format after auto docstring generation * up * upup * upup * style --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Sayak Paul <[email protected]>
1 parent 956bdcc commit b1c77f6

File tree

14 files changed

+4104
-639
lines changed

14 files changed

+4104
-639
lines changed

.github/workflows/pr_modular_tests.yml

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,27 @@ jobs:
7575
if: ${{ failure() }}
7676
run: |
7777
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
78+
check_auto_docs:
  # Verifies that the auto-generated docstrings of the modular blocks are up to date.
  runs-on: ubuntu-22.04
  steps:
    - uses: actions/checkout@v6
    - name: Set up Python
      uses: actions/setup-python@v6
      with:
        python-version: "3.10"
    - name: Install dependencies
      run: |
        pip install --upgrade pip
        pip install .[quality]
    - name: Check auto docs
      # The target name is spelled exactly as it is declared in the Makefile.
      run: make modular-autodoctrings
    - name: Check if failure
      if: ${{ failure() }}
      run: |
        # The suggested command is wrapped in single quotes, not backticks: inside a
        # double-quoted shell string, backticks are command substitution and would
        # execute the fix script here instead of printing it in the summary.
        echo "Auto docstring checks failed. Please run 'python utils/modular_auto_docstring.py --fix_and_overwrite'." >> $GITHUB_STEP_SUMMARY
7896
7997
run_fast_tests:
80-
needs: [check_code_quality, check_repository_consistency]
98+
needs: [check_code_quality, check_repository_consistency, check_auto_docs]
8199
name: Fast PyTorch Modular Pipeline CPU tests
82100

83101
runs-on:

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ fix-copies:
7070
python utils/check_copies.py --fix_and_overwrite
7171
python utils/check_dummies.py --fix_and_overwrite
7272

73+
# Auto docstrings in modular blocks
74+
modular-autodoctrings:
75+
python utils/modular_auto_docstring.py
76+
7377
# Run tests for the library
7478

7579
test:

src/diffusers/modular_pipelines/modular_pipeline_utils.py

Lines changed: 236 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from dataclasses import dataclass, field, fields
1919
from typing import Any, Dict, List, Literal, Optional, Type, Union
2020

21+
import PIL.Image
2122
import torch
2223

2324
from ..configuration_utils import ConfigMixin, FrozenDict
@@ -323,11 +324,192 @@ class ConfigSpec:
323324
description: Optional[str] = None
324325

325326

326-
# YiYi Notes: both inputs and intermediate_inputs are InputParam objects
327-
# however some fields are not relevant for intermediate_inputs
328-
# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
329-
# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs
330-
# -> should we use different class for inputs and intermediate_inputs?
327+
# ======================================================
328+
# InputParam and OutputParam templates
329+
# ======================================================
330+
331+
INPUT_PARAM_TEMPLATES = {
332+
"prompt": {
333+
"type_hint": str,
334+
"required": True,
335+
"description": "The prompt or prompts to guide image generation.",
336+
},
337+
"negative_prompt": {
338+
"type_hint": str,
339+
"description": "The prompt or prompts not to guide the image generation.",
340+
},
341+
"max_sequence_length": {
342+
"type_hint": int,
343+
"default": 512,
344+
"description": "Maximum sequence length for prompt encoding.",
345+
},
346+
"height": {
347+
"type_hint": int,
348+
"description": "The height in pixels of the generated image.",
349+
},
350+
"width": {
351+
"type_hint": int,
352+
"description": "The width in pixels of the generated image.",
353+
},
354+
"num_inference_steps": {
355+
"type_hint": int,
356+
"default": 50,
357+
"description": "The number of denoising steps.",
358+
},
359+
"num_images_per_prompt": {
360+
"type_hint": int,
361+
"default": 1,
362+
"description": "The number of images to generate per prompt.",
363+
},
364+
"generator": {
365+
"type_hint": torch.Generator,
366+
"description": "Torch generator for deterministic generation.",
367+
},
368+
"sigmas": {
369+
"type_hint": List[float],
370+
"description": "Custom sigmas for the denoising process.",
371+
},
372+
"strength": {
373+
"type_hint": float,
374+
"default": 0.9,
375+
"description": "Strength for img2img/inpainting.",
376+
},
377+
"image": {
378+
"type_hint": Union[PIL.Image.Image, List[PIL.Image.Image]],
379+
"required": True,
380+
"description": "Reference image(s) for denoising. Can be a single image or list of images.",
381+
},
382+
"latents": {
383+
"type_hint": torch.Tensor,
384+
"description": "Pre-generated noisy latents for image generation.",
385+
},
386+
"timesteps": {
387+
"type_hint": torch.Tensor,
388+
"description": "Timesteps for the denoising process.",
389+
},
390+
"output_type": {
391+
"type_hint": str,
392+
"default": "pil",
393+
"description": "Output format: 'pil', 'np', 'pt'.",
394+
},
395+
"attention_kwargs": {
396+
"type_hint": Dict[str, Any],
397+
"description": "Additional kwargs for attention processors.",
398+
},
399+
"denoiser_input_fields": {
400+
"name": None,
401+
"kwargs_type": "denoiser_input_fields",
402+
"description": "conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
403+
},
404+
# inpainting
405+
"mask_image": {
406+
"type_hint": PIL.Image.Image,
407+
"required": True,
408+
"description": "Mask image for inpainting.",
409+
},
410+
"padding_mask_crop": {
411+
"type_hint": int,
412+
"description": "Padding for mask cropping in inpainting.",
413+
},
414+
# controlnet
415+
"control_image": {
416+
"type_hint": PIL.Image.Image,
417+
"required": True,
418+
"description": "Control image for ControlNet conditioning.",
419+
},
420+
"control_guidance_start": {
421+
"type_hint": float,
422+
"default": 0.0,
423+
"description": "When to start applying ControlNet.",
424+
},
425+
"control_guidance_end": {
426+
"type_hint": float,
427+
"default": 1.0,
428+
"description": "When to stop applying ControlNet.",
429+
},
430+
"controlnet_conditioning_scale": {
431+
"type_hint": float,
432+
"default": 1.0,
433+
"description": "Scale for ControlNet conditioning.",
434+
},
435+
"layers": {
436+
"type_hint": int,
437+
"default": 4,
438+
"description": "Number of layers to extract from the image",
439+
},
440+
# common intermediate inputs
441+
"prompt_embeds": {
442+
"type_hint": torch.Tensor,
443+
"required": True,
444+
"description": "text embeddings used to guide the image generation. Can be generated from text_encoder step.",
445+
},
446+
"prompt_embeds_mask": {
447+
"type_hint": torch.Tensor,
448+
"required": True,
449+
"description": "mask for the text embeddings. Can be generated from text_encoder step.",
450+
},
451+
"negative_prompt_embeds": {
452+
"type_hint": torch.Tensor,
453+
"description": "negative text embeddings used to guide the image generation. Can be generated from text_encoder step.",
454+
},
455+
"negative_prompt_embeds_mask": {
456+
"type_hint": torch.Tensor,
457+
"description": "mask for the negative text embeddings. Can be generated from text_encoder step.",
458+
},
459+
"image_latents": {
460+
"type_hint": torch.Tensor,
461+
"required": True,
462+
"description": "image latents used to guide the image generation. Can be generated from vae_encoder step.",
463+
},
464+
"batch_size": {
465+
"type_hint": int,
466+
"default": 1,
467+
"description": "Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
468+
},
469+
"dtype": {
470+
"type_hint": torch.dtype,
471+
"default": torch.float32,
472+
"description": "The dtype of the model inputs, can be generated in input step.",
473+
},
474+
}
475+
476+
OUTPUT_PARAM_TEMPLATES = {
477+
"images": {
478+
"type_hint": List[PIL.Image.Image],
479+
"description": "Generated images.",
480+
},
481+
"latents": {
482+
"type_hint": torch.Tensor,
483+
"description": "Denoised latents.",
484+
},
485+
# intermediate outputs
486+
"prompt_embeds": {
487+
"type_hint": torch.Tensor,
488+
"kwargs_type": "denoiser_input_fields",
489+
"description": "The prompt embeddings.",
490+
},
491+
"prompt_embeds_mask": {
492+
"type_hint": torch.Tensor,
493+
"kwargs_type": "denoiser_input_fields",
494+
"description": "The encoder attention mask.",
495+
},
496+
"negative_prompt_embeds": {
497+
"type_hint": torch.Tensor,
498+
"kwargs_type": "denoiser_input_fields",
499+
"description": "The negative prompt embeddings.",
500+
},
501+
"negative_prompt_embeds_mask": {
502+
"type_hint": torch.Tensor,
503+
"kwargs_type": "denoiser_input_fields",
504+
"description": "The negative prompt embeddings mask.",
505+
},
506+
"image_latents": {
507+
"type_hint": torch.Tensor,
508+
"description": "The latent representation of the input image.",
509+
},
510+
}
511+
512+
331513
@dataclass
332514
class InputParam:
333515
"""Specification for an input parameter."""
@@ -337,11 +519,31 @@ class InputParam:
337519
default: Any = None
338520
required: bool = False
339521
description: str = ""
340-
kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
522+
kwargs_type: str = None
341523

342524
def __repr__(self):
343525
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
344526

527+
@classmethod
def template(cls, template_name: str, note: str = None, **overrides) -> "InputParam":
    """Build an `InputParam` from a predefined entry in `INPUT_PARAM_TEMPLATES`.

    Args:
        template_name: Key of the entry to look up in `INPUT_PARAM_TEMPLATES`.
        note: Optional text appended in parentheses to the template's description. It is only
            applied when the template entry defines a `description`.
        **overrides: Field values that take precedence over the template's values. A `name`
            given here takes precedence over any name stored in the template.

    Returns:
        A new `InputParam` built from the template fields merged with `overrides`.

    Raises:
        ValueError: If `template_name` has no entry in `INPUT_PARAM_TEMPLATES`.
    """
    if template_name not in INPUT_PARAM_TEMPLATES:
        raise ValueError(f"InputParam template for {template_name} not found")

    # Work on a shallow copy so the shared template entry is never mutated.
    fields_from_template = dict(INPUT_PARAM_TEMPLATES[template_name])

    # "name" is always removed from the template fields so it cannot collide with the explicit
    # `name=` keyword below. Precedence: override > template entry > template key.
    template_level_name = fields_from_template.pop("name", template_name)
    resolved_name = overrides.pop("name", template_level_name)

    if note and "description" in fields_from_template:
        base_description = fields_from_template["description"]
        fields_from_template["description"] = f"{base_description} ({note})"

    merged_fields = {**fields_from_template, **overrides}
    return cls(name=resolved_name, **merged_fields)
546+
345547

346548
@dataclass
347549
class OutputParam:
@@ -350,13 +552,33 @@ class OutputParam:
350552
name: str
351553
type_hint: Any = None
352554
description: str = ""
353-
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
555+
kwargs_type: str = None
354556

355557
def __repr__(self):
356558
return (
357559
f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
358560
)
359561

562+
@classmethod
def template(cls, template_name: str, note: str = None, **overrides) -> "OutputParam":
    """Build an `OutputParam` from a predefined entry in `OUTPUT_PARAM_TEMPLATES`.

    Args:
        template_name: Key of the entry to look up in `OUTPUT_PARAM_TEMPLATES`.
        note: Optional text appended in parentheses to the template's description. It is only
            applied when the template entry defines a `description`.
        **overrides: Field values that take precedence over the template's values. A `name`
            given here takes precedence over any name stored in the template.

    Returns:
        A new `OutputParam` built from the template fields merged with `overrides`.

    Raises:
        ValueError: If `template_name` has no entry in `OUTPUT_PARAM_TEMPLATES`.
    """
    if template_name not in OUTPUT_PARAM_TEMPLATES:
        raise ValueError(f"OutputParam template for {template_name} not found")

    # Work on a shallow copy so the shared template entry is never mutated.
    fields_from_template = dict(OUTPUT_PARAM_TEMPLATES[template_name])

    # "name" is always removed from the template fields so it cannot collide with the explicit
    # `name=` keyword below. Precedence: override > template entry > template key.
    template_level_name = fields_from_template.pop("name", template_name)
    resolved_name = overrides.pop("name", template_level_name)

    if note and "description" in fields_from_template:
        base_description = fields_from_template["description"]
        fields_from_template["description"] = f"{base_description} ({note})"

    merged_fields = {**fields_from_template, **overrides}
    return cls(name=resolved_name, **merged_fields)
581+
360582

361583
def format_inputs_short(inputs):
362584
"""
@@ -509,10 +731,12 @@ def wrap_text(text, indent, max_length):
509731
desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description)
510732
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
511733
param_str += f"\n{desc_indent}{wrapped_desc}"
734+
else:
735+
param_str += f"\n{desc_indent}TODO: Add description."
512736

513737
formatted_params.append(param_str)
514738

515-
return "\n\n".join(formatted_params)
739+
return "\n".join(formatted_params)
516740

517741

518742
def format_input_params(input_params, indent_level=4, max_line_length=115):
@@ -582,7 +806,7 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty
582806
loading_field_values = []
583807
for field_name in component.loading_fields():
584808
field_value = getattr(component, field_name)
585-
if field_value is not None:
809+
if field_value:
586810
loading_field_values.append(f"{field_name}={field_value}")
587811

588812
# Add loading field information if available
@@ -669,17 +893,17 @@ def make_doc_string(
669893
# Add description
670894
if description:
671895
desc_lines = description.strip().split("\n")
672-
aligned_desc = "\n".join(" " + line for line in desc_lines)
896+
aligned_desc = "\n".join(" " + line.rstrip() for line in desc_lines)
673897
output += aligned_desc + "\n\n"
674898

675899
# Add components section if provided
676900
if expected_components and len(expected_components) > 0:
677-
components_str = format_components(expected_components, indent_level=2)
901+
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
678902
output += components_str + "\n\n"
679903

680904
# Add configs section if provided
681905
if expected_configs and len(expected_configs) > 0:
682-
configs_str = format_configs(expected_configs, indent_level=2)
906+
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
683907
output += configs_str + "\n\n"
684908

685909
# Add inputs section

0 commit comments

Comments
 (0)