Skip to content

Commit

Permalink
Add Perturbed-Attention Guidance toggle. See https://ku-cvlab.github.io/Perturbed-Attention-Guidance/
Browse files Browse the repository at this point in the history
  • Loading branch information
FeepingCreature committed Feb 5, 2025
1 parent de6ee85 commit bc48608
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 0 deletions.
1 change: 1 addition & 0 deletions ai_diffusion/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class CheckpointInput:
rescale_cfg: float = 0.7
self_attention_guidance: bool = False
dynamic_caching: bool = False
perturbed_attention_guidance: bool = False


@dataclass
Expand Down
3 changes: 3 additions & 0 deletions ai_diffusion/comfy_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,9 @@ def apply_ip_adapter_face(
def apply_self_attention_guidance(self, model: Output):
    """Wrap the model with ComfyUI's SelfAttentionGuidance node (fixed scale/blur)."""
    sag_params = dict(model=model, scale=0.5, blur_sigma=2.0)
    return self.add("SelfAttentionGuidance", 1, **sag_params)

def apply_perturbed_attention_guidance(self, model: Output):
    """Wrap the model with ComfyUI's PerturbedAttentionGuidance node.

    No extra parameters are passed, so the node's own defaults apply.
    """
    node_name = "PerturbedAttentionGuidance"
    return self.add(node_name, 1, model=model)

def inpaint_preprocessor(self, image: Output, mask: Output):
    """Run the InpaintPreprocessor node on an image/mask pair."""
    inputs = dict(image=image, mask=mask)
    return self.add("InpaintPreprocessor", 1, **inputs)

Expand Down
8 changes: 8 additions & 0 deletions ai_diffusion/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ class StyleSettings:
_("Pay more attention to difficult parts of the image. Can improve fine details."),
)

# Toggle for PAG (Perturbed-Attention Guidance).
# Fix: user-visible title/description were the only Setting strings in this
# class not wrapped in _() for translation (compare the sibling settings);
# also repaired the dangling "steer away from." in the description.
perturbed_attention_guidance = Setting(
    _("Enable PAG / Perturbed-Attention Guidance"),
    False,
    _('Deliberately introduce errors in "difficult" parts to steer away from them. Can improve coherence.'),
)

# Resolution (in pixels) that the selected checkpoint was trained for;
# 0 means unspecified.
preferred_resolution = Setting(
    _("Preferred Resolution"),
    0,
    _("Image resolution the checkpoint was trained on"),
)
Expand Down Expand Up @@ -121,6 +127,7 @@ class Style:
v_prediction_zsnr: bool = StyleSettings.v_prediction_zsnr.default
rescale_cfg: float = StyleSettings.rescale_cfg.default
self_attention_guidance: bool = StyleSettings.self_attention_guidance.default
perturbed_attention_guidance: bool = StyleSettings.perturbed_attention_guidance.default
preferred_resolution: int = StyleSettings.preferred_resolution.default
sampler: str = StyleSettings.sampler.default
sampler_steps: int = StyleSettings.sampler_steps.default
Expand Down Expand Up @@ -199,6 +206,7 @@ def get_models(self, available_checkpoints: Iterable[str]):
rescale_cfg=self.rescale_cfg,
loras=[LoraInput.from_dict(l) for l in self.loras if l.get("enabled", True)],
self_attention_guidance=self.self_attention_guidance,
perturbed_attention_guidance=self.perturbed_attention_guidance,
)
return result

Expand Down
7 changes: 7 additions & 0 deletions ai_diffusion/ui/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,13 +662,19 @@ def add(name: str, widget: SettingWidget):
SwitchSetting(StyleSettings.self_attention_guidance, parent=self),
)

self._pag = add(
"perturbed_attention_guidance",
SwitchSetting(StyleSettings.perturbed_attention_guidance, parent=self),
)

self._checkpoint_advanced_widgets = [
self._arch_select,
self._vae,
self._clip_skip,
self._resolution_spin,
self._zsnr,
self._sag,
self._pag,
]
for widget in self._checkpoint_advanced_widgets:
widget.indent = 1
Expand Down Expand Up @@ -867,6 +873,7 @@ def _enable_checkpoint_advanced(self):
self._clip_skip.enabled = arch.supports_clip_skip and self.current_style.clip_skip > 0
self._zsnr.enabled = arch.supports_attention_guidance
self._sag.enabled = arch.supports_attention_guidance
self._pag.enabled = arch.supports_attention_guidance

def _read_style(self, style: Style):
with self._write_guard:
Expand Down
3 changes: 3 additions & 0 deletions ai_diffusion/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
if arch.supports_attention_guidance and checkpoint.self_attention_guidance:
model = w.apply_self_attention_guidance(model)

if checkpoint.perturbed_attention_guidance:
model = w.apply_perturbed_attention_guidance(model)

return model, clip, vae


Expand Down

0 comments on commit bc48608

Please sign in to comment.