8 changes: 4 additions & 4 deletions lora_diffusion/cli_lora_add.py
@@ -75,13 +75,13 @@ def add(
path_1,
).to("cpu")

weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), scale=alpha)
if with_text_lora:

weight_apply_lora(
loaded_pipeline.text_encoder,
torch.load(_text_lora_path(path_2)),
alpha=alpha,
scale=alpha,
target_replace_module=["CLIPAttention"],
)

@@ -93,12 +93,12 @@ def add(
path_1,
).to("cpu")

weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), scale=alpha)
if with_text_lora:
weight_apply_lora(
loaded_pipeline.text_encoder,
torch.load(_text_lora_path(path_2)),
alpha=alpha,
scale=alpha,
target_replace_module=["CLIPAttention"],
)

95 changes: 76 additions & 19 deletions lora_diffusion/lora.py
@@ -10,24 +10,39 @@


class LoraInjectedLinear(nn.Module):
def __init__(self, in_features, out_features, bias=False, r=4):
def __init__(self, in_features, out_features, bias=False, r=4, scale=1.0, init=None, nonlin: nn.Module = None):
super().__init__()

if r > min(in_features, out_features):
raise ValueError(
f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
)


if scale <= 0:
raise ValueError(
f"LoRA scale {scale} must be greater than 0"
)

self.r = r
self.scale = scale
self.linear = nn.Linear(in_features, out_features, bias)
self.lora_down = nn.Linear(in_features, r, bias=False)
self.nonlin = nonlin if nonlin else None
self.lora_up = nn.Linear(r, out_features, bias=False)
self.scale = 1.0

nn.init.normal_(self.lora_down.weight, std=1 / r)
if init == "kaiming":
    # Kaiming with a=math.sqrt(5) is already the default init for nn.Linear,
    # so lora_down.weight is left as initialized.
    pass
else:
    nn.init.normal_(self.lora_down.weight, std=1 / r)

nn.init.zeros_(self.lora_up.weight)

def forward(self, input):
return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale
if self.nonlin:
return self.linear(input) + self.lora_up(self.nonlin(self.lora_down(input))) * self.scale
else:
return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale


UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
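
For reference, a minimal usage sketch of the extended layer, assuming it is imported from lora_diffusion.lora; the feature sizes, batch shape, and the choice of nn.GELU are illustrative only:

```python
import torch
import torch.nn as nn

from lora_diffusion.lora import LoraInjectedLinear

# Rank-4 LoRA layer computing y = W x + scale * U(f(D x)),
# where f is the optional nonlinearity between the down/up projections.
layer = LoraInjectedLinear(
    in_features=320,
    out_features=320,
    bias=False,
    r=4,
    scale=0.5,
    init="kaiming",   # keep nn.Linear's default Kaiming init for lora_down
    nonlin=nn.GELU(),
)

x = torch.randn(2, 77, 320)
y = layer(x)          # shape (2, 77, 320)
```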
@@ -116,6 +131,9 @@ def inject_trainable_lora(
model: nn.Module,
target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
r: int = 4,
scale: float = 1.0,
init=None,
nonlin=None,
loras=None, # path to lora .pt
):
"""
@@ -137,7 +155,10 @@
_child_module.in_features,
_child_module.out_features,
_child_module.bias is not None,
r,
r=r,
scale=scale,
init=init,
nonlin=nonlin,
)
_tmp.linear.weight = weight
if bias is not None:
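
A hedged sketch of how the training-side entry point might be called with the new keyword arguments; the UNet checkpoint, rank, learning rate, and the returned (parameters, names) pair are assumptions based on how this repository's training scripts use inject_trainable_lora:

```python
import itertools

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel

from lora_diffusion.lora import inject_trainable_lora

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.requires_grad_(False)  # base weights stay frozen

# Inject rank-8 LoRA layers, Kaiming-initialized, with a ReLU between
# the down- and up-projections of every injected linear.
unet_lora_params, _names = inject_trainable_lora(
    unet,
    r=8,
    scale=1.0,
    init="kaiming",
    nonlin=nn.ReLU(),
)

# Only the injected LoRA parameters are optimized.
optimizer = torch.optim.AdamW(itertools.chain(*unet_lora_params), lr=1e-4)
```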
@@ -333,9 +354,13 @@ def load_safeloras(path, device="cpu"):


def weight_apply_lora(
model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, alpha=1.0
):

model,
loras,
target_replace_module=DEFAULT_TARGET_REPLACE,
r: int = 4,
scale: float = 1.0,
nonlin: nn.Module = None,
):
for _m, _n, _child_module in _find_modules(
model, target_replace_module, search_class=[nn.Linear]
):
@@ -344,13 +369,22 @@ def weight_apply_lora(
up_weight = loras.pop(0).detach().to(weight.device)
down_weight = loras.pop(0).detach().to(weight.device)

# W <- W + U * D
weight = weight + alpha * (up_weight @ down_weight).type(weight.dtype)
if nonlin is None:
    # W <- W + scale * (U @ D)
    weight = weight + scale * (up_weight @ down_weight).type(weight.dtype)
else:
    # W <- W + scale * (U @ nonlin(D)); note that nonlin is applied to the
    # down-projection weights here, not to activations as in LoraInjectedLinear.forward
    weight = weight + scale * (up_weight @ nonlin(down_weight)).type(weight.dtype)

_child_module.weight = nn.Parameter(weight)
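
As a quick sanity check on the linear case (nonlin=None), merging scale * (U @ D) into W should reproduce the output of the unmerged LoRA branch; a self-contained sketch with made-up shapes:

```python
import torch

torch.manual_seed(0)
W = torch.randn(320, 320)        # frozen base weight
U = torch.randn(320, 4) * 0.01   # lora_up.weight  (out_features x r)
D = torch.randn(4, 320) * 0.01   # lora_down.weight (r x in_features)
scale = 0.5
x = torch.randn(8, 320)

# Unmerged branch, as computed by LoraInjectedLinear.forward with nonlin=None
y_branch = x @ W.T + scale * (x @ D.T @ U.T)

# Merged weight, as written back by weight_apply_lora
y_merged = x @ (W + scale * (U @ D)).T

assert torch.allclose(y_branch, y_merged, atol=1e-4)
```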


def monkeypatch_lora(
model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
model,
loras,
target_replace_module=DEFAULT_TARGET_REPLACE,
r: int = 4,
scale: float = 1.0,
nonlin: nn.Module = None,
):
for _module, name, _child_module in _find_modules(
model, target_replace_module, search_class=[nn.Linear]
@@ -362,6 +396,8 @@ def monkeypatch_lora(
_child_module.out_features,
_child_module.bias is not None,
r=r,
scale=scale,
nonlin=nonlin,
)
_tmp.linear.weight = weight

@@ -385,7 +421,12 @@ def monkeypatch_lora(


def monkeypatch_replace_lora(
model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
model,
loras,
target_replace_module=DEFAULT_TARGET_REPLACE,
r: int = 4,
scale: float = 1.0,
nonlin: nn.Module = None,
):
for _module, name, _child_module in _find_modules(
model, target_replace_module, search_class=[LoraInjectedLinear]
@@ -397,7 +438,9 @@ def monkeypatch_replace_lora(
_child_module.linear.out_features,
_child_module.linear.bias is not None,
r=r,
)
scale=scale,
nonlin=nonlin,
)
_tmp.linear.weight = weight

if bias is not None:
@@ -424,6 +467,8 @@ def monkeypatch_or_replace_lora(
loras,
target_replace_module=DEFAULT_TARGET_REPLACE,
r: Union[int, List[int]] = 4,
scale: Union[float, List[float]] = 1.0,
nonlin: Union[nn.Module, List[nn.Module]] = None,
):
for _module, name, _child_module in _find_modules(
model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
@@ -441,6 +486,8 @@
_source.out_features,
_source.bias is not None,
r=r.pop(0) if isinstance(r, list) else r,
scale=scale.pop(0) if isinstance(scale, list) else scale,
nonlin=nonlin.pop(0) if isinstance(nonlin, list) else nonlin,
)
_tmp.linear.weight = weight

@@ -496,7 +543,7 @@ def monkeypatch_add_lora(
model,
loras,
target_replace_module=DEFAULT_TARGET_REPLACE,
alpha: float = 1.0,
scale: float = 1.0,
beta: float = 1.0,
):
for _module, name, _child_module in _find_modules(
@@ -519,12 +566,16 @@
_module._modules[name].to(weight.device)


def tune_lora_scale(model, alpha: float = 1.0):
def tune_lora_scale(model, alpha: float = 1.0, scale: float = None):
    # `alpha` is the original parameter name and means the same thing as `scale`;
    # an explicitly passed `scale` takes precedence.
    if scale is None:
        scale = alpha

for _module in model.modules():
if _module.__class__.__name__ == "LoraInjectedLinear":
_module.scale = alpha

_module.scale = scale


def _text_lora_path(path: str) -> str:
assert path.endswith(".pt"), "Only .pt files are supported"
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
@@ -576,6 +627,8 @@ def patch_pipe(
unet_path,
token: str,
r: int = 4,
scale: float = 1.0,
nonlin: nn.Module = None,
patch_unet=True,
patch_text=False,
patch_ti=False,
@@ -596,6 +649,8 @@
pipe.unet,
torch.load(unet_path),
r=r,
scale=scale,
nonlin=nonlin,
target_replace_module=unet_target_replace_module,
)

@@ -606,6 +661,8 @@
torch.load(text_path),
target_replace_module=text_target_replace_module,
r=r,
scale=scale,
nonlin=nonlin,
)
if patch_ti:
print("LoRA : Patching token input")