Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,20 @@ revision: ''
weights_dtype: 'bfloat16'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'
vae_weights_dtype: 'float32'
vae_dtype: 'float32'
scheduler_dtype: 'float32'

# Replicates the VAE across devices instead of sharding it with the model's sharding annotations.
replicate_vae: False
vae_spatial: -1 # default to total_device * 2 // (dp)

# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory.
vae_decode_chunk: 1

# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value.
# Increase to improve encode time at the cost of memory.
vae_encode_chunk: 4
vae_spatial: -1

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
Expand Down
12 changes: 12 additions & 0 deletions src/maxdiffusion/configs/base_wan_1_3b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,21 @@ revision: ''
weights_dtype: 'bfloat16'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'
vae_weights_dtype: 'float32'
vae_dtype: 'float32'
scheduler_dtype: 'float32'

# Replicates the VAE across devices instead of sharding it with the model's sharding annotations.
replicate_vae: False

# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory.
vae_decode_chunk: 1

# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value.
# Increase to improve encode time at the cost of memory.
vae_encode_chunk: 4
vae_spatial: -1

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
# fp32 activations and fp32 weights with HIGHEST will provide the best precision
Expand Down Expand Up @@ -159,6 +170,7 @@ logical_axis_rules: [
['out_channels', 'tensor'],
['conv_out', 'context'],
]

vae_logical_axis_rules: [
['activation_batch', 'redundant'],
['activation_length', 'vae_spatial'],
Expand Down
12 changes: 11 additions & 1 deletion src/maxdiffusion/configs/base_wan_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,20 @@ revision: ''
weights_dtype: 'bfloat16'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'
vae_weights_dtype: 'float32'
vae_dtype: 'float32'
scheduler_dtype: 'float32'

# Replicates the VAE across devices instead of sharding it with the model's sharding annotations.
replicate_vae: False
vae_spatial: -1 # default to total_device * 2 // (dp)

# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory.
vae_decode_chunk: 1

# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value.
# Increase to improve encode time at the cost of memory.
vae_encode_chunk: 4
vae_spatial: -1

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
Expand Down
10 changes: 10 additions & 0 deletions src/maxdiffusion/configs/base_wan_animate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,19 @@ revision: ''
weights_dtype: 'bfloat16'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'
vae_weights_dtype: 'float32'
vae_dtype: 'float32'
scheduler_dtype: 'float32'

# Replicates the VAE across devices instead of sharding it with the model's sharding annotations.
replicate_vae: False

# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory.
vae_decode_chunk: 1

# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value.
# Increase to improve encode time at the cost of memory.
vae_encode_chunk: 4
# Number of devices to shard VAE spatial activations across. -1 uses all devices.
vae_spatial: -1

Expand Down
12 changes: 11 additions & 1 deletion src/maxdiffusion/configs/base_wan_i2v_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,20 @@ revision: ''
weights_dtype: 'bfloat16'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'
vae_weights_dtype: 'float32'
vae_dtype: 'float32'
scheduler_dtype: 'float32'

# Replicates the VAE across devices instead of sharding it with the model's sharding annotations.
replicate_vae: False
vae_spatial: -1 # default to total_device * 2 // (dp)

# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory.
vae_decode_chunk: 1

# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value.
# Increase to improve encode time at the cost of memory.
vae_encode_chunk: 4
vae_spatial: -1

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
Expand Down
12 changes: 11 additions & 1 deletion src/maxdiffusion/configs/base_wan_i2v_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,20 @@ revision: ''
weights_dtype: 'bfloat16'
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
activations_dtype: 'bfloat16'
vae_weights_dtype: 'float32'
vae_dtype: 'float32'
scheduler_dtype: 'float32'

# Replicates the VAE across devices instead of sharding it with the model's sharding annotations.
replicate_vae: False
vae_spatial: -1 # default to total_device * 2 // (dp)

# Chunk size for VAE decode scan. Increase to improve decode time at the cost of memory.
vae_decode_chunk: 1

# Chunk size for VAE encode scan. (num_input_frames - 1) must be divisible by this value.
# Increase to improve encode time at the cost of memory.
vae_encode_chunk: 4
vae_spatial: -1

# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
# Options are "DEFAULT", "HIGH", "HIGHEST"
Expand Down
8 changes: 7 additions & 1 deletion src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,17 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
f" Inference: {generation_time:>7.1f}s",
]
if trace:
vae_decode_total = trace.get("vae_decode", 0.0)
vae_decode_tpu = trace.get("vae_decode_tpu", 0.0)
vae_decode_post = vae_decode_total - vae_decode_tpu
summary.extend([
f" {'─' * 40}",
f" Conditioning: {trace.get('conditioning', 0.0):>7.1f}s",
f" - VAE Encode: {trace.get('vae_encode', 0.0):>7.1f}s",
f" Denoise Total: {trace.get('denoise_total', 0.0):>7.1f}s",
f" VAE Decode: {trace.get('vae_decode', 0.0):>7.1f}s",
f" VAE Decode: {vae_decode_total:>7.1f}s",
f" - TPU Compute: {vae_decode_tpu:>7.1f}s",
f" - Host Formatting: {vae_decode_post:>7.1f}s",
])
summary.append(f"{'=' * 50}")
max_logging.log("\n".join(summary))
Expand Down
6 changes: 4 additions & 2 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def _tpu_flash_attention(
) -> jax.Array:
"""TPU Flash Attention"""

num_context_shards = mesh.shape["context"]
num_context_shards = mesh.shape["context"] if "context" in mesh.shape else 1
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
Expand Down Expand Up @@ -491,7 +491,9 @@ def ring_scan_body(carry, _):
raise ValueError("ring attention requires context > 1")
return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
data_dim = mesh.shape["data"] if "data" in mesh.shape else 1
fsdp_dim = mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1
devices_in_batch_sharding = data_dim * fsdp_dim
# This warning may appear in some situations, for example during model eval when
# calculating model flops; in those cases it is expected.
if not (query.shape[0] / devices_in_batch_sharding).is_integer():
Expand Down
Loading
Loading