Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 31 additions & 43 deletions nemo_deploy/llm/inference/inference_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,8 @@
get_default_load_sharded_strategy,
)
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.inference.contexts.static_context import StaticInferenceContext
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
GPTInferenceWrapper,
)
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
from megatron.core.inference.apis import MegatronLLM
from megatron.core.inference.config import InferenceConfig
from megatron.core.transformer.enums import AttnBackend
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import MLATransformerConfig
Expand Down Expand Up @@ -390,27 +384,24 @@ def setup_model_and_tokenizer_for_inference(


class MCoreEngineWithCleanup:
"""Wrapper around MCoreEngine that ensures proper cleanup of distributed resources.
"""Wrapper around MegatronLLM that ensures proper cleanup of distributed resources.

This class delegates all operations to the underlying MCoreEngine while ensuring that
distributed resources are properly cleaned up when the engine is destroyed.
This class delegates all operations to the underlying MegatronLLM engine while ensuring
that distributed resources are properly cleaned up when the engine is destroyed.
"""

def __init__(
self,
mcore_engine: MCoreEngine,
model_inference_wrapper: GPTInferenceWrapper,
llm: MegatronLLM,
tokenizer: Union[MCoreTokenizerWrappper, MegatronTokenizer],
):
"""Initialize the MCoreEngineWithCleanup.

Args:
mcore_engine (MCoreEngine): The underlying MCoreEngine instance
model_inference_wrapper (GPTInferenceWrapper): The model inference wrapper
llm (MegatronLLM): The underlying MegatronLLM instance
tokenizer (Union[MCoreTokenizerWrappper, MegatronTokenizer]): The tokenizer instance
"""
self.mcore_engine = mcore_engine
self.model_inference_wrapper = model_inference_wrapper
self.mcore_engine = llm
self.tokenizer = tokenizer

def __del__(self):
Expand Down Expand Up @@ -446,16 +437,16 @@ def create_mcore_engine(
buffer_size_gb: float = 10.0,
legacy_model_format: bool = False,
**model_config_kwargs,
) -> Tuple[MCoreEngineWithCleanup, GPTInferenceWrapper, Union[MCoreTokenizerWrappper, MegatronTokenizer]]:
"""Set up the model, tokenizer and MCoreEngine for inference.
) -> Tuple[MCoreEngineWithCleanup, Union[MCoreTokenizerWrappper, MegatronTokenizer]]:
"""Set up the model, tokenizer and MegatronLLM engine for inference.

Args:
path (Path): Path to the checkpoint file
params_dtype (torch.dtype): Data type for model parameters (default: torch.bfloat16)
inference_batch_times_seqlen_threshold (int): Threshold for batch size times sequence length
inference_max_seq_length (int): Maximum sequence length for inference
max_batch_size (int): Maximum batch size for inference
random_seed (Optional[int]): Random seed for reproducibility
random_seed (Optional[int]): Random seed for reproducibility (set globally during init)
tensor_model_parallel_size (Optional[int]): Size of tensor model parallelism
pipeline_model_parallel_size (Optional[int]): Size of pipeline model parallelism
context_parallel_size (Optional[int]): Size of context parallelism
Expand All @@ -466,11 +457,10 @@ def create_mcore_engine(
model_type (str): Type of model to load (default: "gpt")
model_format (str): Format of model to load (default: "nemo")
micro_batch_size (Optional[int]): Micro batch size for model execution
legacy_model_format (bool): Whether to use the legacy StaticInferenceEngine path in MCoreEngine (default: False)
legacy_model_format (bool): Deprecated; no longer used (DynamicInferenceEngine is always used)
Returns:
Tuple[MCoreEngineWithCleanup, GPTInferenceWrapper, Union[MCoreTokenizerWrappper, MegatronTokenizer]]: Tuple containing:
Tuple[MCoreEngineWithCleanup, Union[MCoreTokenizerWrappper, MegatronTokenizer]]: Tuple containing:
- MCoreEngineWithCleanup: Engine for text generation with proper cleanup
- GPTInferenceWrapper: Inference-wrapped model
- Union[MCoreTokenizerWrappper, MegatronTokenizer]: Tokenizer instance
"""
# Default to 1 for any parallelism dimension that's None
Expand Down Expand Up @@ -512,34 +502,32 @@ def create_mcore_engine(
else:
raise ValueError(f"Model format {model_format} not supported.")

# MLA models require block_size_tokens=64 for the dynamic engine, which is not
# configurable in the current Megatron-LM version. Fall back to the legacy static
# engine so MLA inference works correctly without touching Megatron-LM.
model.eval()

# MLA models require block_size_tokens=64 for correct KV cache operation with the
# dynamic inference engine. Set the attention backend to flash if not already set.
block_size_tokens = 256
model_config = getattr(model, "config", None)
if isinstance(model_config, MLATransformerConfig):
legacy_model_format = True
# The legacy static engine requires an explicit attention backend.
# MLA models use flash attention (attention_mask is handled internally).
block_size_tokens = 64
if not model_config.attention_backend:
model_config.attention_backend = AttnBackend.flash

inference_context = StaticInferenceContext(
max_batch_size=max_batch_size,
inference_config = InferenceConfig(
max_sequence_length=inference_max_seq_length,
buffer_size_gb=int(buffer_size_gb),
max_requests=max_batch_size,
block_size_tokens=block_size_tokens,
materialize_only_last_token_logits=True,
)
model_inference_wrapper = GPTInferenceWrapper(model, inference_context)
text_generation_controller = TextGenerationController(
inference_wrapped_model=model_inference_wrapper, tokenizer=tokenizer
)
mcore_engine = MCoreEngine(
text_generation_controller=text_generation_controller,
max_batch_size=max_batch_size,
random_seed=random_seed,
buffer_size_gb=buffer_size_gb,
legacy=legacy_model_format,

llm = MegatronLLM(
model=model,
tokenizer=tokenizer,
inference_config=inference_config,
)

# Wrap the engine to ensure cleanup
wrapped_engine = MCoreEngineWithCleanup(mcore_engine, model_inference_wrapper, tokenizer)
wrapped_engine = MCoreEngineWithCleanup(llm, tokenizer)

return wrapped_engine, model_inference_wrapper, tokenizer
return wrapped_engine, tokenizer
34 changes: 16 additions & 18 deletions nemo_deploy/llm/megatronllm_deployable.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import torch
import torch.distributed
from jinja2 import Template
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.inference_request import DynamicInferenceRequest
from megatron.core.inference.sampling_params import SamplingParams

from nemo_deploy import ITritonDeployable
from nemo_deploy.llm.inference.inference_base import create_mcore_engine
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
if model_type not in ["gpt", "mamba"]:
raise ValueError(f"Model type {model_type} not supported for Megatron models.")

self.mcore_engine, self.inference_wrapped_model, self.mcore_tokenizer = create_mcore_engine(
self.mcore_engine, self.mcore_tokenizer = create_mcore_engine(
num_devices=num_devices,
num_nodes=num_nodes,
path=Path(megatron_checkpoint_filepath),
Expand Down Expand Up @@ -144,18 +144,18 @@ def __init__(
def generate(
self,
prompts: List[str],
inference_params: Optional[CommonInferenceParams] = None,
) -> List[InferenceRequest]:
inference_params: Optional[SamplingParams] = None,
) -> List[DynamicInferenceRequest]:
"""Generates text based on the provided input prompts.

Args:
prompts (List[str]): A list of input strings.
inference_params (Optional[CommonInferenceParams]): Parameters for controlling the inference process.
inference_params (Optional[SamplingParams]): Parameters for controlling the inference process.

Returns:
List[InferenceRequest]: A list containing the generated results.
List[DynamicInferenceRequest]: A list containing the generated results.
"""
inference_params = inference_params or CommonInferenceParams()
inference_params = inference_params or SamplingParams()

# Store the original number of prompts
orig_num_prompts = len(prompts)
Expand All @@ -173,17 +173,15 @@ def generate(

results = self.mcore_engine.generate(
prompts=padded_prompts,
add_BOS=False,
common_inference_params=inference_params,
sampling_params=inference_params,
)

# Only return results for the original prompts
return list(results)[:orig_num_prompts]
else:
results = self.mcore_engine.generate(
prompts=prompts,
add_BOS=False,
common_inference_params=inference_params,
sampling_params=inference_params,
)
return list(results)

Expand All @@ -198,7 +196,7 @@ def generate_other_ranks(self):
data=[None], src=0
)

inference_params = CommonInferenceParams(
inference_params = SamplingParams(
temperature=temperature,
top_k=int(top_k),
top_p=float(top_p),
Expand All @@ -208,7 +206,7 @@ def generate_other_ranks(self):
)

if log_probs:
dynamic_engine = getattr(self.mcore_engine, "dynamic_engine", None)
dynamic_engine = getattr(self.mcore_engine, "engine", None)
if dynamic_engine is not None:
dynamic_engine.materialize_only_last_token_logits = False
dynamic_engine.context.config.materialize_only_last_token_logits = False
Expand Down Expand Up @@ -419,15 +417,15 @@ def _infer_fn(
)

# Cast top_k and top_p to native int and float; the type-check assertions added in MCore 0.13 raise an error otherwise.
# return_prompt_top_n_logprobs returns top_logprobs for prompt tokens too when top_logprobs>0.
inference_params = CommonInferenceParams(
# skip_prompt_log_probs=False (default) includes prompt tokens in top-N logprobs when top_logprobs>0.
inference_params = SamplingParams(
temperature=temperature,
top_k=int(top_k),
top_p=float(top_p),
num_tokens_to_generate=num_tokens_to_generate,
return_log_probs=log_probs,
top_n_logprobs=top_logprobs,
return_prompt_top_n_logprobs=bool(top_logprobs),
skip_prompt_log_probs=not bool(top_logprobs),
stop_words=stop_words,
)

Expand All @@ -436,7 +434,7 @@ def _infer_fn(
# (prompt log probs are required for logprob eval benchmarks).
# Toggle it on both the engine and the context config (controls the
# model forward pass and log prob calculations).
dynamic_engine = getattr(self.mcore_engine, "dynamic_engine", None)
dynamic_engine = getattr(self.mcore_engine, "engine", None)
needs_all_logits = log_probs or bool(top_logprobs)
if dynamic_engine is not None and needs_all_logits:
dynamic_engine.materialize_only_last_token_logits = False
Expand Down
6 changes: 3 additions & 3 deletions tests/functional_tests/utils/run_nemo_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@

in_framework_supported = True
try:
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.sampling_params import SamplingParams

from nemo_deploy.llm import NemoQueryLLMPyTorch
from nemo_deploy.llm.megatronllm_deployable import MegatronLLMDeployable
except Exception as e:
LOGGER.warning(
"Cannot import MegatronLLMDeployable class, or NemoQueryLLMPyTorch, or CommonInferenceParams, "
"Cannot import MegatronLLMDeployable class, or NemoQueryLLMPyTorch, or SamplingParams, "
f"in-framework inference will not be available. Reason: {type(e).__name__}: {e}"
)
in_framework_supported = False
Expand Down Expand Up @@ -98,7 +98,7 @@ def get_accuracy_with_lambada(model, nq, lora_uids, test_data_path, use_vllm: bo
if in_framework_supported and isinstance(model, MegatronLLMDeployable):
model_output = model.generate(
prompts=[prompt],
inference_params=CommonInferenceParams(
inference_params=SamplingParams(
temperature=0.1,
top_k=1,
top_p=0.0,
Expand Down
50 changes: 15 additions & 35 deletions tests/unit_tests/deploy/test_etp_sequence_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,19 +389,14 @@ class TestCreateMcoreEngineETPSequenceParallel(unittest.TestCase):
"""Tests that create_mcore_engine handles ETP/SP defaults and passes them down."""

@patch("nemo_deploy.llm.inference.inference_base.setup_model_and_tokenizer_for_inference")
@patch("nemo_deploy.llm.inference.inference_base.MCoreEngine")
@patch("nemo_deploy.llm.inference.inference_base.MegatronLLM")
@patch("nemo_deploy.llm.inference.inference_base.MCoreEngineWithCleanup")
@patch("nemo_deploy.llm.inference.inference_base.GPTInferenceWrapper")
@patch("nemo_deploy.llm.inference.inference_base.StaticInferenceContext")
@patch("nemo_deploy.llm.inference.inference_base.TextGenerationController")
def test_etp_defaults_to_1_when_none(
self, mock_tgc, mock_ctx, mock_wrapper, mock_cleanup, mock_engine_cls, mock_setup
):
def test_etp_defaults_to_1_when_none(self, mock_cleanup, mock_llm_cls, mock_setup):
"""expert_tensor_parallel_size=None is normalised to 1 before forwarding."""
from nemo_deploy.llm.inference.inference_base import create_mcore_engine

mock_setup.return_value = ([MagicMock()], MagicMock())
mock_engine_cls.return_value = MagicMock()
mock_llm_cls.return_value = MagicMock()
mock_cleanup.return_value = MagicMock()

create_mcore_engine(path=Path("/fake"), model_format="nemo", expert_tensor_parallel_size=None)
Expand All @@ -410,19 +405,14 @@ def test_etp_defaults_to_1_when_none(
assert kwargs["expert_tensor_parallel_size"] == 1

@patch("nemo_deploy.llm.inference.inference_base.setup_model_and_tokenizer_for_inference")
@patch("nemo_deploy.llm.inference.inference_base.MCoreEngine")
@patch("nemo_deploy.llm.inference.inference_base.MegatronLLM")
@patch("nemo_deploy.llm.inference.inference_base.MCoreEngineWithCleanup")
@patch("nemo_deploy.llm.inference.inference_base.GPTInferenceWrapper")
@patch("nemo_deploy.llm.inference.inference_base.StaticInferenceContext")
@patch("nemo_deploy.llm.inference.inference_base.TextGenerationController")
def test_sp_defaults_to_1_when_none(
self, mock_tgc, mock_ctx, mock_wrapper, mock_cleanup, mock_engine_cls, mock_setup
):
def test_sp_defaults_to_1_when_none(self, mock_cleanup, mock_llm_cls, mock_setup):
"""sequence_parallel=None is normalised to 1 before forwarding."""
from nemo_deploy.llm.inference.inference_base import create_mcore_engine

mock_setup.return_value = ([MagicMock()], MagicMock())
mock_engine_cls.return_value = MagicMock()
mock_llm_cls.return_value = MagicMock()
mock_cleanup.return_value = MagicMock()

create_mcore_engine(path=Path("/fake"), model_format="nemo", sequence_parallel=None)
Expand All @@ -431,19 +421,14 @@ def test_sp_defaults_to_1_when_none(
assert kwargs["sequence_parallel"] == 1

@patch("nemo_deploy.llm.inference.inference_base.setup_model_and_tokenizer_for_inference")
@patch("nemo_deploy.llm.inference.inference_base.MCoreEngine")
@patch("nemo_deploy.llm.inference.inference_base.MegatronLLM")
@patch("nemo_deploy.llm.inference.inference_base.MCoreEngineWithCleanup")
@patch("nemo_deploy.llm.inference.inference_base.GPTInferenceWrapper")
@patch("nemo_deploy.llm.inference.inference_base.StaticInferenceContext")
@patch("nemo_deploy.llm.inference.inference_base.TextGenerationController")
def test_explicit_etp_passed_through(
self, mock_tgc, mock_ctx, mock_wrapper, mock_cleanup, mock_engine_cls, mock_setup
):
def test_explicit_etp_passed_through(self, mock_cleanup, mock_llm_cls, mock_setup):
"""An explicit expert_tensor_parallel_size value is forwarded unchanged."""
from nemo_deploy.llm.inference.inference_base import create_mcore_engine

mock_setup.return_value = ([MagicMock()], MagicMock())
mock_engine_cls.return_value = MagicMock()
mock_llm_cls.return_value = MagicMock()
mock_cleanup.return_value = MagicMock()

create_mcore_engine(path=Path("/fake"), model_format="nemo", expert_tensor_parallel_size=4)
Expand All @@ -452,19 +437,14 @@ def test_explicit_etp_passed_through(
assert kwargs["expert_tensor_parallel_size"] == 4

@patch("nemo_deploy.llm.inference.inference_base.setup_model_and_tokenizer_for_inference")
@patch("nemo_deploy.llm.inference.inference_base.MCoreEngine")
@patch("nemo_deploy.llm.inference.inference_base.MegatronLLM")
@patch("nemo_deploy.llm.inference.inference_base.MCoreEngineWithCleanup")
@patch("nemo_deploy.llm.inference.inference_base.GPTInferenceWrapper")
@patch("nemo_deploy.llm.inference.inference_base.StaticInferenceContext")
@patch("nemo_deploy.llm.inference.inference_base.TextGenerationController")
def test_explicit_sp_passed_through(
self, mock_tgc, mock_ctx, mock_wrapper, mock_cleanup, mock_engine_cls, mock_setup
):
def test_explicit_sp_passed_through(self, mock_cleanup, mock_llm_cls, mock_setup):
"""An explicit sequence_parallel=True value is forwarded unchanged."""
from nemo_deploy.llm.inference.inference_base import create_mcore_engine

mock_setup.return_value = ([MagicMock()], MagicMock())
mock_engine_cls.return_value = MagicMock()
mock_llm_cls.return_value = MagicMock()
mock_cleanup.return_value = MagicMock()

create_mcore_engine(path=Path("/fake"), model_format="nemo", sequence_parallel=True)
Expand All @@ -487,7 +467,7 @@ def test_expert_tensor_parallel_size_forwarded(self, mock_create):
"""expert_tensor_parallel_size is forwarded to create_mcore_engine."""
from nemo_deploy.llm.megatronllm_deployable import MegatronLLMDeployable

mock_create.return_value = (MagicMock(), MagicMock(), MagicMock())
mock_create.return_value = (MagicMock(), MagicMock())

MegatronLLMDeployable(
megatron_checkpoint_filepath="model.ckpt",
Expand All @@ -503,7 +483,7 @@ def test_sequence_parallel_forwarded(self, mock_create):
"""sequence_parallel is forwarded to create_mcore_engine."""
from nemo_deploy.llm.megatronllm_deployable import MegatronLLMDeployable

mock_create.return_value = (MagicMock(), MagicMock(), MagicMock())
mock_create.return_value = (MagicMock(), MagicMock())

MegatronLLMDeployable(
megatron_checkpoint_filepath="model.ckpt",
Expand All @@ -519,7 +499,7 @@ def test_defaults_etp_1_and_sp_false(self, mock_create):
"""Defaults: expert_tensor_parallel_size=1, sequence_parallel=False."""
from nemo_deploy.llm.megatronllm_deployable import MegatronLLMDeployable

mock_create.return_value = (MagicMock(), MagicMock(), MagicMock())
mock_create.return_value = (MagicMock(), MagicMock())

MegatronLLMDeployable(megatron_checkpoint_filepath="model.ckpt")

Expand Down
Loading
Loading