A minimal, customizable implementation of CASteer for concept steering in Stable Diffusion models.
Figure: Castle with metallic steering vector, prompt: “Epic fantasy castle on top of a mountain, clouds swirling around, dramatic lighting.”
This repository provides a minimal implementation of concept steering using activation-based guidance vectors. It works by:
- Collecting activations from cross-attention layers when processing prompt pairs
- Building steering vectors from the collected activations
- Applying these vectors during inference to guide image generation

This implementation takes a different approach from the official implementation: it uses steering hooks to collect activations rather than re-implementing the cross-attention layer.
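
As a rough illustration of the idea behind these steps, the sketch below builds a steering vector as the normalized mean difference between activations of paired prompts and then adds a scaled copy of it to new activations. This is an assumed, simplified formulation on synthetic NumPy arrays, not the repository's code; `build_steering_vector` and `apply_steering` are hypothetical helper names, and the actual implementation may normalize or apply vectors differently.

```python
import numpy as np

def build_steering_vector(pos_acts, neg_acts):
    """Mean difference of paired activations, scaled to unit length.

    pos_acts, neg_acts: arrays of shape (num_pairs, hidden_dim), one row per
    prompt pair (with and without the concept). Hypothetical helper.
    """
    diff = (pos_acts - neg_acts).mean(axis=0)
    norm = np.linalg.norm(diff)
    return diff / norm if norm > 0 else diff

def apply_steering(activations, vector, scale):
    """Shift every token's activation along the steering direction."""
    return activations + scale * vector

# Synthetic stand-in for collected cross-attention activations.
rng = np.random.default_rng(0)
hidden_dim = 8
concept_direction = np.zeros(hidden_dim)
concept_direction[0] = 1.0

neg = rng.normal(size=(16, hidden_dim))
pos = neg + 3.0 * concept_direction  # "concept" prompts differ along one axis

vector = build_steering_vector(pos, neg)
tokens = rng.normal(size=(4, hidden_dim))
steered = apply_steering(tokens, vector, scale=5.0)
```

With a scale of 0.0 the activations are returned unchanged, which is why a scale sweep such as `[0.0, 5.0, 10.0]` gives an unsteered baseline alongside the steered results.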
Figure: Architecture of integrating vector steering into the diffusion pipeline (SD 1.5 pipeline from Demir's blog).
- tutorial.ipynb: Example of CASteer
- tutorial_composition.ipynb: Example of composing vectors using CASteer
- tutorial_efficient.ipynb: Example of applying CASteer based on diffusion steps
The notebooks should run on Google Colab without changes; they have been confirmed to work on an L4 GPU.
- Hook-based Implementation: Uses forward hooks for both activation collection and steering application.
- Experimentally Customizable: Easy modification of how activations are collected and steering vectors are applied (editing the hook).
- Composable Implementation: Example of how to perform composable vector steering of multiple attributes.
- Steering only after N diffusion steps: Optionally steer only the last 50% of diffusion steps (see tutorial_efficient.ipynb), or edit one line to start steering after any Nth step.
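
To make the features above concrete, here is a minimal sketch of a single PyTorch forward hook that records a layer's output in one mode and shifts it in another, composes several vectors, and starts steering only after a given number of calls. It is an illustration under stated assumptions, not the repository's hook: `SteerHook`, its attributes, and the vectors `vec_a` and `vec_b` are hypothetical, an `nn.Linear` stands in for a cross-attention layer, and one forward call is assumed to correspond to one diffusion step.

```python
import torch
from torch import nn

class SteerHook:
    """Forward hook that either records a layer's output or shifts it."""

    def __init__(self):
        self.mode = "collect"   # "collect" or "steer"
        self.collected = []     # outputs recorded in collect mode
        self.vectors = []       # (vector, scale) pairs, summed when steering
        self.start_step = 0     # steer only once the call index reaches this
        self.step = 0           # forward calls seen so far in steer mode

    def __call__(self, module, inputs, output):
        if self.mode == "collect":
            self.collected.append(output.detach().clone())
            return None         # returning None leaves the output unchanged
        current = self.step
        self.step += 1
        if current < self.start_step:
            return None         # too early in the schedule: no steering
        for vector, scale in self.vectors:
            output = output + scale * vector
        return output           # a returned value replaces the layer output

torch.manual_seed(0)
layer = nn.Linear(4, 4)         # toy stand-in for a cross-attention layer
hook = SteerHook()
handle = layer.register_forward_hook(hook)

x = torch.randn(2, 4)
with torch.no_grad():
    baseline = layer(x)         # collect mode: recorded, not modified

vec_a = torch.tensor([1.0, 0.0, 0.0, 0.0])  # hypothetical steering vectors
vec_b = torch.tensor([0.0, 1.0, 0.0, 0.0])
hook.mode = "steer"
hook.vectors = [(vec_a, 5.0), (vec_b, 2.0)]
hook.start_step = 2             # skip the first 2 of 4 simulated steps

with torch.no_grad():
    outputs = [layer(x) for _ in range(4)]
handle.remove()
```

Because both collection and steering live in one small callable, changing how activations are gathered or how vectors are applied only requires editing `__call__`.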
The implementation requires only a few steps:
```python
# Add hooks to collect activations
steer_hooks = steering.add_steer_hooks(pipe)

# Build steering vectors from prompt pairs
final_vectors = steering.build_final_steering_vectors(pipe, steer_hooks, prompts)

# Add calculated vectors to hooks for inference
steering.add_final_steer_vectors(steer_hooks, final_vectors)

# Generate images with steering
steering.run_grid_experiment(pipe, steer_hooks, test_prompts, steer_scale_list=[0.0, 5.0, 10.0])
```