2 changes: 1 addition & 1 deletion .github/workflows/nightly_tests.yml
@@ -248,7 +248,7 @@ jobs:
BIG_GPU_MEMORY: 40
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-m "big_gpu_with_torch_cuda" \
-m "big_accelerator" \
--make-reports=tests_big_gpu_torch_cuda \
--report-log=tests_big_gpu_torch_cuda.log \
tests/
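For reference, a minimal sketch (not part of this diff) of how a test could opt into the renamed marker; the decorator form is standard pytest usage and the test name is hypothetical:

```python
import pytest


# Hypothetical test; it is selected by the `-m "big_accelerator"` filter
# used in the nightly workflow command above.
@pytest.mark.big_accelerator
def test_pipeline_on_big_accelerator_runner():
    assert True
```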
6 changes: 6 additions & 0 deletions docs/source/en/api/cache.md
@@ -28,3 +28,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
[[autodoc]] FasterCacheConfig

[[autodoc]] apply_faster_cache

### FirstBlockCacheConfig

[[autodoc]] FirstBlockCacheConfig

[[autodoc]] apply_first_block_cache
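A hedged usage sketch to accompany the new autodoc entries (not part of this diff); it assumes `apply_first_block_cache` takes a module and a config, mirroring `apply_faster_cache`, and that `FirstBlockCacheConfig` exposes a `threshold` field:

```python
import torch
from diffusers import DiffusionPipeline, FirstBlockCacheConfig, apply_first_block_cache

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# `threshold` is an assumed field name controlling how aggressively cached
# block outputs are reused; check the autodoc above for the actual config fields.
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

image = pipe("an astronaut riding a horse on the moon", num_inference_steps=28).images[0]
```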
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -147,11 +147,13 @@
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"FirstBlockCacheConfig",
"HookRegistry",
"LayerSkipConfig",
"PyramidAttentionBroadcastConfig",
"SmoothedEnergyGuidanceConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
]
@@ -793,11 +795,13 @@
)
from .hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
HookRegistry,
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast,
)
15 changes: 15 additions & 0 deletions src/diffusers/hooks/__init__.py
@@ -1,8 +1,23 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_torch_available


if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
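With both export lists updated, the new symbols become importable from the package root as well as from the hooks subpackage, for example:

```python
# Top-level import, enabled by the `src/diffusers/__init__.py` change above;
# `diffusers.hooks` re-exports the same names per its `__init__.py`.
from diffusers import FirstBlockCacheConfig, apply_first_block_cache
```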
141 changes: 67 additions & 74 deletions src/diffusers/hooks/_helpers.py
@@ -12,23 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Any, Callable, Type

from ..models.attention import BasicTransformerBlock
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenReplaceSingleTransformerBlock,
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
from typing import Any, Callable, Dict, Type


@dataclass
@@ -38,40 +24,90 @@ class AttentionProcessorMetadata:

@dataclass
class TransformerBlockMetadata:
skip_block_output_fn: Callable[[Any], Any]
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None

_cls: Type = None
_cached_parameter_indices: Dict[str, int] = None

def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
kwargs = kwargs or {}
if identifier in kwargs:
return kwargs[identifier]
if self._cached_parameter_indices is not None:
return args[self._cached_parameter_indices[identifier]]
if self._cls is None:
raise ValueError("Model class is not set for metadata.")
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
parameters = parameters[1:] # skip `self`
self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
if identifier not in self._cached_parameter_indices:
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
index = self._cached_parameter_indices[identifier]
if index >= len(args):
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
return args[index]


class AttentionProcessorRegistry:
_registry = {}
# TODO(aryan): this is only required for the time being because we need to do the registrations
# for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
# import errors because of the models imported in this file.
_is_registered = False

@classmethod
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
cls._register()
cls._registry[model_class] = metadata

@classmethod
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
cls._register()
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]

@classmethod
def _register(cls):
if cls._is_registered:
return
cls._is_registered = True
_register_attention_processors_metadata()


class TransformerBlockRegistry:
_registry = {}
# TODO(aryan): this is only required for the time being because we need to do the registrations
# for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
# import errors because of the models imported in this file.
_is_registered = False

@classmethod
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
cls._register()
metadata._cls = model_class
cls._registry[model_class] = metadata

@classmethod
def get(cls, model_class: Type) -> TransformerBlockMetadata:
cls._register()
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]

@classmethod
def _register(cls):
if cls._is_registered:
return
cls._is_registered = True
_register_transformer_blocks_metadata()


def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor

# AttnProcessor2_0
AttentionProcessorRegistry.register(
model_class=AttnProcessor2_0,
@@ -90,11 +126,24 @@ def _register_attention_processors_metadata():


def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenReplaceSingleTransformerBlock,
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock

# BasicTransformerBlock
TransformerBlockRegistry.register(
model_class=BasicTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
@@ -104,7 +153,6 @@ def _register_transformer_blocks_metadata():
TransformerBlockRegistry.register(
model_class=CogVideoXBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
@@ -114,7 +162,6 @@ def _register_transformer_blocks_metadata():
TransformerBlockRegistry.register(
model_class=CogView4TransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
@@ -124,15 +171,13 @@ def _register_transformer_blocks_metadata():
TransformerBlockRegistry.register(
model_class=FluxTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock,
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
TransformerBlockRegistry.register(
model_class=FluxSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock,
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
@@ -142,31 +187,27 @@ def _register_transformer_blocks_metadata():
TransformerBlockRegistry.register(
model_class=HunyuanVideoTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
@@ -176,7 +217,6 @@ def _register_transformer_blocks_metadata():
TransformerBlockRegistry.register(
model_class=LTXVideoTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
@@ -186,7 +226,6 @@ def _register_transformer_blocks_metadata():
TransformerBlockRegistry.register(
model_class=MochiTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
@@ -196,7 +235,6 @@ def _register_transformer_blocks_metadata():
TransformerBlockRegistry.register(
model_class=WanTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
@@ -223,49 +261,4 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):

_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states


def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
return hidden_states


def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states


def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return encoder_hidden_states, hidden_states


_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
# fmt: on


_register_attention_processors_metadata()
_register_transformer_blocks_metadata()
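To illustrate how the reworked metadata is meant to be consumed (a sketch, not part of this diff), a hook can look up a block's metadata lazily and resolve a forward argument by name whether it was passed positionally or as a keyword; `block`, `args`, and `kwargs` below are assumed to come from an intercepted forward call:

```python
from diffusers.hooks._helpers import TransformerBlockRegistry


def extract_hidden_states(block, args, kwargs):
    # The first `get()` call triggers the lazy registration above, which avoids
    # the circular imports that eager, module-level registration caused.
    metadata = TransformerBlockRegistry.get(type(block))
    # Resolves "hidden_states" from kwargs or from its cached positional index.
    return metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
```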