Commit a97c980

[Weight-adapter/Trainer] Bypass forward mode in Weight adapter system (#11958)

* Add API of bypass forward module
* Bypass implementation
* Add bypass fwd into nodes list/trainer
1 parent 635406e commit a97c980
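In regular LoRA loading, the adapter delta is merged into the base weights; bypass mode leaves the weights untouched and adds the low-rank path at call time. A small sketch of the difference, using standard LoRA notation rather than anything from this diff:

import torch

x = torch.randn(2, 8)                                   # activations
W = torch.randn(16, 8)                                  # frozen base weight
down, up, strength = torch.randn(4, 8), torch.randn(16, 4), 0.8

patched = x @ (W + strength * (up @ down)).T            # regular patching: delta folded into W
bypass = x @ W.T + strength * (x @ down.T) @ up.T       # bypass: delta computed per forward call
assert torch.allclose(patched, bypass, atol=1e-4)       # same result, weights never modified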

12 files changed (+2,040 −102 lines)


comfy/sd.py

Lines changed: 100 additions & 0 deletions
@@ -20,6 +20,7 @@
 import comfy.ldm.hunyuan_video.vae
 import comfy.ldm.mmaudio.vae.autoencoder
 import comfy.pixel_space_convert
+import comfy.weight_adapter
 import yaml
 import math
 import os
@@ -101,6 +102,105 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     return (new_modelpatcher, new_clip)


+def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip):
+    """
+    Load LoRA in bypass mode without modifying base model weights.
+
+    Instead of patching weights, this injects the LoRA computation into the
+    forward pass: output = base_forward(x) + lora_path(x)
+
+    Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches.
+
+    This is useful for training and when model weights are offloaded.
+    """
+    key_map = {}
+    if model is not None:
+        key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
+    if clip is not None:
+        key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
+
+    logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries")
+
+    lora = comfy.lora_convert.convert_lora(lora)
+    loaded = comfy.lora.load_lora(lora, key_map)
+
+    logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries")
+
+    # Separate adapters (for bypass) from other patches (for regular patching)
+    bypass_patches = {}   # WeightAdapterBase instances -> bypass mode
+    regular_patches = {}  # diff, set, bias patches -> regular weight patching
+
+    for key, patch_data in loaded.items():
+        if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase):
+            bypass_patches[key] = patch_data
+        else:
+            regular_patches[key] = patch_data
+
+    logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches")
+
+    k = set()
+    k1 = set()
+
+    if model is not None:
+        new_modelpatcher = model.clone()
+
+        # Apply regular patches (bias diff, weight diff, etc.) via normal patching
+        if regular_patches:
+            patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model)
+            k.update(patched_keys)
+
+        # Apply adapter patches via bypass injection
+        manager = comfy.weight_adapter.BypassInjectionManager()
+        model_sd_keys = set(new_modelpatcher.model.state_dict().keys())
+
+        for key, adapter in bypass_patches.items():
+            if key in model_sd_keys:
+                manager.add_adapter(key, adapter, strength=strength_model)
+                k.add(key)
+            else:
+                logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}")
+
+        injections = manager.create_injections(new_modelpatcher.model)
+
+        if manager.get_hook_count() > 0:
+            new_modelpatcher.set_injections("bypass_lora", injections)
+    else:
+        new_modelpatcher = None
+
+    if clip is not None:
+        new_clip = clip.clone()
+
+        # Apply regular patches to clip
+        if regular_patches:
+            patched_keys = new_clip.add_patches(regular_patches, strength_clip)
+            k1.update(patched_keys)
+
+        # Apply adapter patches via bypass injection
+        clip_manager = comfy.weight_adapter.BypassInjectionManager()
+        clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys())
+
+        for key, adapter in bypass_patches.items():
+            if key in clip_sd_keys:
+                clip_manager.add_adapter(key, adapter, strength=strength_clip)
+                k1.add(key)
+
+        clip_injections = clip_manager.create_injections(new_clip.cond_stage_model)
+        if clip_manager.get_hook_count() > 0:
+            new_clip.patcher.set_injections("bypass_lora", clip_injections)
+    else:
+        new_clip = None
+
+    for x in loaded:
+        if (x not in k) and (x not in k1):
+            patch_data = loaded[x]
+            patch_type = type(patch_data).__name__
+            if isinstance(patch_data, tuple):
+                patch_type = f"tuple({patch_data[0]})"
+            logging.warning(f"NOT LOADED: {x} (type={patch_type})")
+
+    return (new_modelpatcher, new_clip)
+
+
 class CLIP:
     def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
         if no_init:
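The bypass hook module itself (comfy/weight_adapter/bypass.py) is added elsewhere in this commit and is not shown in this excerpt. As a rough illustration of what output = base_forward(x) + lora_path(x) means in practice, here is a minimal self-contained sketch of a bypass-style hook; the class and method names are hypothetical, not the real BypassForwardHook API:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BypassHookSketch:
    """Adds the LoRA path to a module's output instead of merging it into the weight."""

    def __init__(self, module: nn.Module, down: torch.Tensor, up: torch.Tensor, strength: float = 1.0):
        self.module = module
        self.down = down                    # (rank, in_features) down-projection
        self.up = up                        # (out_features, rank) up-projection
        self.strength = strength
        self.base_forward = module.forward  # keep the original bound method

    def inject(self):
        def forward(x):
            # base output plus the low-rank bypass path, scaled by strength
            lora = F.linear(F.linear(x, self.down), self.up)
            return self.base_forward(x) + self.strength * lora
        self.module.forward = forward       # instance attribute shadows the bound method

    def eject(self):
        self.module.forward = self.base_forward

layer = nn.Linear(64, 64)
hook = BypassHookSketch(layer, down=torch.randn(4, 64) * 0.01, up=torch.zeros(64, 4))
hook.inject()
y = layer(torch.randn(2, 64))  # up is zero-initialized, so this still equals the base output
hook.eject()

Because the base weight tensor is never rewritten, the same idea works when weights are frozen for training or offloaded, which is what the docstring above points at.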
comfy/weight_adapter/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -5,6 +5,11 @@
 from .glora import GLoRAAdapter
 from .oft import OFTAdapter
 from .boft import BOFTAdapter
+from .bypass import (
+    BypassInjectionManager,
+    BypassForwardHook,
+    create_bypass_injections_from_patches,
+)


 adapters: list[type[WeightAdapterBase]] = [
@@ -31,4 +36,7 @@
     "WeightAdapterTrainBase",
     "adapters",
     "adapter_maps",
+    "BypassInjectionManager",
+    "BypassForwardHook",
+    "create_bypass_injections_from_patches",
 ] + [a.__name__ for a in adapters]

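For completeness, a usage sketch of the new loader. The file paths are placeholders, and the surrounding calls (load_checkpoint_guess_config, load_torch_file) are the usual ComfyUI entry points rather than part of this diff:

import comfy.sd
import comfy.utils

# Placeholder paths: any checkpoint and LoRA state dict will do.
model, clip, vae, _ = comfy.sd.load_checkpoint_guess_config(
    "checkpoints/model.safetensors", output_vae=True, output_clip=True)
lora_sd = comfy.utils.load_torch_file("loras/my_lora.safetensors", safe_load=True)

# Base weights stay untouched; adapters run as forward-pass injections,
# while bias/weight diffs in the same file are applied as regular patches.
new_model, new_clip = comfy.sd.load_bypass_lora_for_models(
    model, clip, lora_sd, strength_model=1.0, strength_clip=1.0)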