
Performance Issue: logits_processor and stopping_criteria Account for 63.5% of Generation Time #51


@AmitMY

Summary

Benchmarking shows that logits_processor and stopping_criteria in bytes_decoder.generate() account for 63.5% of total generation time (warm cache). Eliminating their overhead entirely would yield up to a 2.74x speedup.

Location

welt/model.py:503-509 in _generate_word_bytes():

return self.bytes_decoder.generate(
    inputs_embeds=inputs_embeds,
    generation_config=bytes_generation_config,
    tokenizer=tokenizer,
    logits_processor=[self.logits_processor],  # ← Costs 31.6% of runtime
    stopping_criteria=stopping_criteria,        # ← Costs 27.9% of runtime
)

Benchmark Results

Warm Cache (After torch.compile):

| Configuration | Time (s) | Time Reduction vs Baseline | % of Baseline |
|---|---|---|---|
| Baseline (both parameters) | 3.1276 | - | 100% |
| Without logits_processor | 2.1392 | 31.6% | 68.4% |
| Without stopping_criteria | 2.2548 | 27.9% | 72.1% |
| Without both | 1.1408 | 63.5% | 36.5% |
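Because the reduction column reports less wall-clock time rather than a speedup factor (speedup = baseline time / configuration time), the derived figures can be recomputed from the raw timings as a sanity check. A minimal sketch; the `summarize` helper is illustrative and not part of the repo:

```python
# Warm-cache timings copied from the table above (seconds).
BASELINE = 3.1276
RUNS = {
    "Without logits_processor": 2.1392,
    "Without stopping_criteria": 2.2548,
    "Without both": 1.1408,
}


def summarize(baseline: float, t: float) -> tuple[float, float, float]:
    """Return (% of baseline, % time reduction, speedup factor) for one run."""
    pct_of_baseline = 100.0 * t / baseline
    reduction = 100.0 - pct_of_baseline  # what the "% faster" column reports
    speedup = baseline / t               # how many times faster the run is
    return pct_of_baseline, reduction, speedup


for name, t in RUNS.items():
    pct, red, sp = summarize(BASELINE, t)
    print(f"{name:28s} {pct:5.1f}% of baseline, {red:4.1f}% less time, {sp:.2f}x")
```

So a 63.5% time reduction is the 2.74x speedup quoted in the summary, while removing only logits_processor is about a 1.46x speedup.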

Cold Cache (First Run with torch.compile):

| Configuration | Time (s) | Time Reduction vs Baseline |
|---|---|---|
| Baseline (both parameters) | 20.1456 | - |
| Without logits_processor | 17.2984 | 14.1% |
| Without stopping_criteria | 18.4905 | 8.2% |
| Without both | 16.9498 | 15.9% |

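Note: Cold-cache results show a smaller relative impact because compilation overhead dominates (the cold baseline is 6.4x slower than the warm baseline). In absolute terms, removing both components still saves more time cold (3.20s) than warm (1.99s).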

Analysis

Components:

  1. logits_processor: UTF8ValidationLogitsProcessor (compiled at line 419)

    • Ensures valid UTF-8 byte sequences during generation
    • Costs ~0.99s per benchmark (31.6% of runtime)
  2. stopping_criteria: WordStoppingCriteria

    • Stops generation at word boundaries
    • Costs ~0.87s per benchmark (27.9% of runtime)

Why This Matters:

Together, these two components take about 1.7x as long as the model forward passes, word encoding, and tokenization combined (1.99s vs 1.14s).

Reproduction

# Run benchmark with warmup
python -m welt_training.sample

# The script now includes:
# 1. Warmup run to compile everything
# 2. Timed benchmark run with warm cache
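For reference, a warmup-then-measure loop of the kind the script describes could look like the sketch below. `benchmark` is a hypothetical helper, not the script's actual code. Because CUDA kernels launch asynchronously, the optional `sync` callback (for example `torch.cuda.synchronize`) makes each timed run include the GPU work it queued.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object], *, warmup: int = 1, repeats: int = 3,
              sync: Optional[Callable[[], None]] = None) -> dict:
    """Time fn() after `warmup` untimed calls (e.g. to trigger torch.compile)."""
    for _ in range(warmup):
        fn()  # untimed: compilation and cache warmup happen here
    if sync is not None:
        sync()  # drain any queued GPU work before starting the clock
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for async kernels so the timing covers the real work
        times.append(time.perf_counter() - start)
    return {"min": min(times), "median": statistics.median(times), "runs": times}
```

Hypothetical usage: `benchmark(lambda: model.generate(...), warmup=1, repeats=5, sync=torch.cuda.synchronize)`. Reporting the median of several timed runs helps show whether differences such as 2.1392s vs 2.2548s are larger than run-to-run noise.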

Proposed Solutions

Option 1: Optimize Existing Implementations

  • Profile UTF8ValidationLogitsProcessor to identify bottlenecks
  • Profile WordStoppingCriteria for optimization opportunities
  • Consider vectorization or JIT compilation improvements
  • Investigate if torch.compile is effectively optimizing these components
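One low-effort way to profile each component in place is to wrap it, accumulate time per call, and compare the total against the ablation figures (~0.99s and ~0.87s). The `TimedCall` wrapper below is a hypothetical sketch; it assumes generate() accepts any callable with the same call signature as the wrapped object, which may not hold in every transformers version (subclassing the relevant base class would be the fallback). Synchronizing around each call adds its own overhead, so treat the totals as relative rather than exact.

```python
import time
from typing import Any, Callable, Optional


class TimedCall:
    """Wrap a callable and accumulate the wall-clock time spent inside it."""

    def __init__(self, fn: Callable[..., Any],
                 sync: Optional[Callable[[], None]] = None) -> None:
        self.fn = fn
        self.sync = sync  # e.g. torch.cuda.synchronize for GPU work
        self.calls = 0
        self.total_s = 0.0

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if self.sync is not None:
            self.sync()  # exclude work queued before this call
        start = time.perf_counter()
        out = self.fn(*args, **kwargs)
        if self.sync is not None:
            self.sync()  # include async work launched by this call
        self.total_s += time.perf_counter() - start
        self.calls += 1
        return out

    @property
    def mean_ms(self) -> float:
        """Average time per call in milliseconds (0.0 before any call)."""
        return 1000.0 * self.total_s / self.calls if self.calls else 0.0
```

For example, `timed_lp = TimedCall(self.logits_processor, sync=torch.cuda.synchronize)` could be passed as `logits_processor=[timed_lp]`, then `timed_lp.calls` and `timed_lp.mean_ms` read after generation. A high call count with a small per-call cost would suggest per-step launch or sync overhead rather than expensive logic.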

Option 2: Alternative Implementations

  • Implement validation logic directly in CUDA/Triton for GPU acceleration
  • Move stopping criteria checks to a more efficient location in the generation loop
  • Consider caching or batching validation checks

Option 3: Make Optional

  • Add flags to disable these checks for inference when validation isn't critical
  • Document the trade-offs (performance vs correctness guarantees)
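A sketch of how opt-out flags might be threaded into the existing generate() call. `GenerationChecks` and `constraint_kwargs` are hypothetical names; omitting a key simply leaves generate()'s default for that argument in place.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class GenerationChecks:
    """Hypothetical switches for the two per-token checks (names illustrative)."""
    validate_utf8: bool = True          # keep the UTF-8 logits constraint
    stop_at_word_boundary: bool = True  # keep the word-level stopping criterion


def constraint_kwargs(checks: GenerationChecks, logits_processor: Any,
                      stopping_criteria: Any) -> dict:
    """Build the optional generate() kwargs, omitting disabled checks entirely."""
    kwargs: dict = {}
    if checks.validate_utf8:
        kwargs["logits_processor"] = [logits_processor]
    if checks.stop_at_word_boundary:
        kwargs["stopping_criteria"] = stopping_criteria
    return kwargs
```

Hypothetical call site: `self.bytes_decoder.generate(inputs_embeds=inputs_embeds, generation_config=bytes_generation_config, tokenizer=tokenizer, **constraint_kwargs(checks, self.logits_processor, stopping_criteria))`. Disabling word stopping changes when generation ends (EOS or the length limit instead of a word boundary), so outputs and step counts may differ from the baseline, and the measured saving may not be purely the criterion's own cost.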

Questions

  1. Are these components already torch.compiled effectively? (They are compiled at lines 419 and 576.)
  2. Could validation be moved to post-processing to avoid per-token overhead?
  3. Is there redundancy in the checks that could be eliminated?
  4. What's the actual implementation complexity of these components?
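On question 2, a post-processing check could decode each unconstrained byte sequence once after generation and flag the invalid ones. `decode_after_generation` is a hypothetical helper. The trade-off is that this detects invalid output rather than preventing it, so flagged samples would need replacement characters or re-generation.

```python
from typing import List, Tuple


def decode_after_generation(byte_seqs: List[bytes]) -> Tuple[List[str], List[int]]:
    """Decode outputs once, after generation, instead of validating per token.

    Returns the decoded strings (invalid bytes replaced with U+FFFD) and the
    indices of sequences that were not valid UTF-8, e.g. for re-generation.
    """
    texts: List[str] = []
    invalid: List[int] = []
    for i, seq in enumerate(byte_seqs):
        try:
            texts.append(seq.decode("utf-8"))  # strict: raises on bad bytes
        except UnicodeDecodeError:
            invalid.append(i)
            texts.append(seq.decode("utf-8", errors="replace"))
    return texts, invalid
```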

Additional Context

  • Model: sign/WeLT-string-repetition
  • Hardware: NVIDIA GB10 (CUDA compute capability 12.1)
  • PyTorch optimizations enabled: cuDNN benchmark mode, TF32, Flash Attention
  • Generation config: max_generated_words=32
  • Batch size: 3 samples

Expected Outcome

Ideally, we should be able to:

  1. Keep the correctness guarantees of UTF-8 validation and word stopping
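  2. Reduce their combined overhead from ~2s to <0.5s (a 75% reduction)
  3. Get as close as possible to the 1.14s generation time while maintaining safety

Meeting the <0.5s target would bring warm-cache generation to about 1.64s, roughly a 1.9x overall speedup without compromising functionality; the full 2.74x requires eliminating the overhead entirely.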
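To relate an overhead budget to the end-to-end gain, assume the remaining 1.1408s stays fixed and the checks' cost simply adds to it (a simplification). The projected speedup is then baseline / (1.1408 + overhead):

```python
BASELINE_S = 3.1276      # warm-cache baseline from the benchmark table
WITHOUT_BOTH_S = 1.1408  # warm-cache time with both components removed


def projected_speedup(overhead_s: float) -> float:
    """Speedup over baseline if the two checks together cost `overhead_s` seconds."""
    return BASELINE_S / (WITHOUT_BOTH_S + overhead_s)


# Current overhead (~1.99s), the 0.5s target, a tighter budget, and zero.
for budget in (BASELINE_S - WITHOUT_BOTH_S, 0.5, 0.25, 0.0):
    print(f"overhead {budget:.2f}s -> {projected_speedup(budget):.2f}x")
```

Under this model a 0.5s budget gives about 1.91x, and only a zero-overhead implementation reaches 2.74x.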
