diff --git a/vllm/config.py b/vllm/config.py index 9b3f4f920630..bb86cc1b2b7c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -124,6 +124,9 @@ class ModelConfig: def __init__( self, model: str, + with_ladder : bool, + ladder_model_path : str, + sub_layers_ids: Optional[int], tokenizer: str, tokenizer_mode: str, trust_remote_code: bool, @@ -147,8 +150,15 @@ def __init__( served_model_name: Optional[Union[str, List[str]]] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None, use_async_output_proc: bool = True, - override_neuron_config: Optional[Dict[str, Any]] = None) -> None: + override_neuron_config: Optional[Dict[str, Any]] = None, + aix_model_config: Union[Dict, None] = None) -> None: self.model = model + # -------------------------------------------------------------------------------- + # Yocto : add two argument to support ladder net + self.with_ladder = with_ladder + self.ladder_model_path = ladder_model_path + self.sub_layers_ids = sub_layers_ids + # -------------------------------------------------------------------------------- self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode self.trust_remote_code = trust_remote_code @@ -174,7 +184,7 @@ def __init__( self.skip_tokenizer_init = skip_tokenizer_init self.hf_config = get_config(self.model, trust_remote_code, revision, - code_revision, rope_scaling, rope_theta) + code_revision, rope_scaling, rope_theta, aix_model_config) self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f0b866db6432..4c6ef12c66b8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,7 +1,7 @@ import argparse import dataclasses import json -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Type, Union) @@ -26,6 +26,8 @@ ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"] +def parse_list_of_ints(string): + return json.loads(string) if string else [] def nullable_str(val: str): if not val or val == "None": @@ -58,6 +60,9 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: class EngineArgs: """Arguments for vLLM engine.""" model: str = 'facebook/opt-125m' + with_ladder : bool = False + ladder_model_path : str = "" + sub_layers_ids : Optional[List[int]] = field(default_factory=list) served_model_name: Optional[Union[str, List[str]]] = None tokenizer: Optional[str] = None skip_tokenizer_init: bool = False @@ -165,6 +170,27 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, default=EngineArgs.model, help='Name or path of the huggingface model to use.') + # -------------------------------------------------------------------------------- + # Yocto : add two argument to support ladder net + parser.add_argument( + '--with-ladder', + action='store_true', + help="Flag to indicate if the model has undergone domain-specific training." + ) + parser.add_argument( + '--ladder-model-path', + type=str, + default='', + help="Path to the ladder model if domain-specific training is utilized." + ) + parser.add_argument( + '--sub-layers-ids', + type=parse_list_of_ints, + default="", + help="When domain-adaptive training is activated, \ + it is necessary to specify in which layers to insert auxiliary layers." 
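# --------------------------------------------------------------------------------
# A minimal, self-contained sketch (not part of the patch) of how the three new
# engine flags are meant to be used. The checkpoint path and layer ids below are
# hypothetical; --sub-layers-ids takes a JSON list string because
# parse_list_of_ints() above simply runs json.loads() on the raw argument.
import argparse
import json

def parse_list_of_ints(string):
    # Mirrors the helper added in arg_utils.py: empty string -> empty list.
    return json.loads(string) if string else []

parser = argparse.ArgumentParser()
parser.add_argument("--with-ladder", action="store_true")
parser.add_argument("--ladder-model-path", type=str, default="")
parser.add_argument("--sub-layers-ids", type=parse_list_of_ints, default="")

args = parser.parse_args([
    "--with-ladder",
    "--ladder-model-path", "/path/to/ladder_ckpt",   # hypothetical path
    "--sub-layers-ids", "[0, 8, 16, 24, 31]",        # hypothetical layer ids
])
assert args.with_ladder is True
assert args.sub_layers_ids == [0, 8, 16, 24, 31]
# --------------------------------------------------------------------------------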
+ ) + # -------------------------------------------------------------------------------- parser.add_argument( '--tokenizer', type=nullable_str, @@ -752,7 +778,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: }, default=None, help="override or set neuron device configuration.") - return parser @classmethod @@ -763,7 +788,7 @@ def from_cli_args(cls, args: argparse.Namespace): engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) return engine_args - def create_engine_config(self) -> EngineConfig: + def create_engine_config(self, aix_model_config: Union[Dict, None] = None) -> EngineConfig: # gguf file needs a specific model loader and doesn't use hf_repo if check_gguf_file(self.model): self.quantization = self.load_format = "gguf" @@ -787,10 +812,12 @@ def create_engine_config(self) -> EngineConfig: assert self.cpu_offload_gb >= 0, ( "CPU offload space must be non-negative" f", but got {self.cpu_offload_gb}") - device_config = DeviceConfig(device=self.device) model_config = ModelConfig( model=self.model, + with_ladder=self.with_ladder, + sub_layers_ids=self.sub_layers_ids, + ladder_model_path=self.ladder_model_path, tokenizer=self.tokenizer, tokenizer_mode=self.tokenizer_mode, trust_remote_code=self.trust_remote_code, @@ -813,7 +840,8 @@ def create_engine_config(self) -> EngineConfig: served_model_name=self.served_model_name, limit_mm_per_prompt=self.limit_mm_per_prompt, use_async_output_proc=not self.disable_async_output_proc, - override_neuron_config=self.override_neuron_config) + override_neuron_config=self.override_neuron_config, + aix_model_config=aix_model_config) cache_config = CacheConfig( block_size=self.block_size if self.device != "neuron" else self.max_model_len, # neuron needs block_size = max_model_len diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 50dcb6937eb6..c416d454b97a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -301,7 +301,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.input_registry = input_registry self.input_processor = input_registry.create_input_processor( model_config) - self.model_executor = executor_class( model_config=model_config, cache_config=cache_config, @@ -314,7 +313,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: prompt_adapter_config=prompt_adapter_config, observability_config=self.observability_config, ) - if not self.model_config.embedding_mode: self._initialize_kv_caches() diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b997507ea738..d9df8420ee3a 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -118,7 +118,7 @@ def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - weight = Parameter(torch.empty(sum(output_partition_sizes), + weight = Parameter(torch.zeros(sum(output_partition_sizes), input_size_per_partition, dtype=params_dtype), requires_grad=False) @@ -299,7 +299,6 @@ def __init__(self, if output_sizes is None: output_sizes = [output_size] - self.quant_method.create_weights( layer=self, input_size_per_partition=self.input_size, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c00da106734a..fa1ddea3edc9 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -92,6 +92,8 @@ class SamplerOutput( # On-device tensor 
containing the logprobs of each token. logprobs: Optional["torch.Tensor"] = None + probs: Optional["torch.Tensor"] = None + # Holds either (1) the pythonized sampler result (single-step scheduling) # or (2) what will be arguments for later deferred pythonization of the # sampler result (muliti-step scheduling) @@ -305,13 +307,15 @@ def forward( prompt_logprobs, sample_logprobs = get_logprobs( logprobs, sampling_metadata, maybe_deferred_sample_results) - return _build_sampler_output( + __out: SamplerOutput = _build_sampler_output( maybe_deferred_sample_results, sampling_metadata, prompt_logprobs, sample_logprobs, on_device_tensors=on_device_tensors, skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) + __out.probs = probs + return __out @property def _should_modify_greedy_probs_inplace(self) -> bool: diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 553fa848489b..0f9ef2a59ea8 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -147,15 +147,25 @@ def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig, quant_config: Optional[QuantizationConfig], *, lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], - scheduler_config: Optional[SchedulerConfig]) -> nn.Module: + scheduler_config: Optional[SchedulerConfig], + with_ladder: Optional[bool], + sub_layers_ids: Optional[List[int]]) -> nn.Module: extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config, multimodal_config, scheduler_config) - return model_class(config=hf_config, - cache_config=cache_config, - quant_config=quant_config, - **extra_kwargs) + if with_ladder: + return model_class(config=hf_config, + cache_config=cache_config, + quant_config=quant_config, + with_ladder=with_ladder, + sub_layers_ids=sub_layers_ids, + **extra_kwargs) + else: + return model_class(config=hf_config, + cache_config=cache_config, + quant_config=quant_config, + **extra_kwargs) def _initialize_model( @@ -166,6 +176,8 @@ def _initialize_model( scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module: """Initialize a model with the given configurations.""" model_class, _ = get_model_architecture(model_config) + with_ladder = model_config.with_ladder + sub_layers_ids = model_config.sub_layers_ids return build_model( model_class, @@ -175,6 +187,8 @@ def _initialize_model( lora_config=lora_config, multimodal_config=model_config.multimodal_config, scheduler_config=scheduler_config, + with_ladder=with_ladder, + sub_layers_ids=sub_layers_ids ) @@ -299,12 +313,17 @@ def _prepare_weights(self, model_name_or_path: str, return hf_folder, hf_weights_files, use_safetensors def _get_weights_iterator( - self, model_name_or_path: str, revision: Optional[str], + self, model_name_or_path: str, + with_ladder : bool, ladder_model_path : str, + revision: Optional[str], fall_back_to_pt: bool ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( model_name_or_path, revision, fall_back_to_pt) + if with_ladder: + hf_weights_files.append(os.path.join(ladder_model_path, 'CodeLST.pt')) + if self.load_config.load_format == LoadFormat.NPCACHE: # Currently np_cache only support *.bin checkpoints assert use_safetensors is False @@ -341,13 +360,26 @@ def load_model(self, *, model_config: ModelConfig, model = _initialize_model(model_config, self.load_config, 
lora_config, cache_config, scheduler_config) - model.load_weights( - self._get_weights_iterator(model_config.model, - model_config.revision, - fall_back_to_pt=getattr( - model, - "fall_back_to_pt_during_load", - True)), ) + if model_config.with_ladder: + model.load_weights_with_ladder( + self._get_weights_iterator(model_config.model, + model_config.with_ladder, + model_config.ladder_model_path, + model_config.revision, + fall_back_to_pt=getattr( + model, + "fall_back_to_pt_during_load", + True)), ) + else: + model.load_weights( + self._get_weights_iterator(model_config.model, + model_config.with_ladder, + model_config.ladder_model_path, + model_config.revision, + fall_back_to_pt=getattr( + model, + "fall_back_to_pt_during_load", + True)), ) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) @@ -408,6 +440,133 @@ def _get_weights_iterator( self) -> Generator[Tuple[str, torch.Tensor], None, None]: tensorizer_args = self.tensorizer_config._construct_tensorizer_args() return tensorizer_weights_iterator(tensorizer_args) + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str]) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=self.load_config.ignore_patterns, + ) + else: + model_path = model + return model_path + return None + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path + + is_local = os.path.isdir(model_name_or_path) + load_format = LoadFormat.PT + use_safetensors = False + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == LoadFormat.SAFETENSORS: + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. 
+ # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, self.load_config.download_dir, + revision) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iteratorV2( + self, model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision, fall_back_to_pt) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + weights_iterator = np_cache_weights_iterator( + model_name_or_path, self.load_config.download_dir, hf_folder, + hf_weights_files) + elif use_safetensors: + weights_iterator = safetensors_weights_iterator(hf_weights_files) + else: + weights_iterator = pt_weights_iterator(hf_weights_files) + + if current_platform.is_tpu(): + # In PyTorch XLA, we should call `xm.mark_step` frequently so that + # not too many ops are accumulated in the XLA program. + import torch_xla.core.xla_model as xm + + def _xla_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + xm.mark_step() + + weights_iterator = _xla_weights_iterator(weights_iterator) + return weights_iterator + def _load_model_serialized_cpu( self, @@ -458,7 +617,17 @@ def _load_model_serialized( tensorizer_config.hf_config = model_config.hf_config tensorizer_config.dtype = model_config.dtype - model = load_with_tensorizer(tensorizer_config, **extra_kwargs) + model = load_with_tensorizer(tensorizer_config, model_config, **extra_kwargs) + + if model_config.with_ladder: + model.load_weights_with_ladder( + self._get_weights_iteratorV2(model_config.ladder_model_path, + model_config.revision, + fall_back_to_pt=getattr( + model, + "fall_back_to_pt_during_load", + True)), ) + return model.eval() def load_model(self, *, model_config: ModelConfig, @@ -474,7 +643,6 @@ def load_model(self, *, model_config: ModelConfig, self.tensorizer_config.tensorizer_uri = \ self.tensorizer_config.tensorizer_uri \ % get_tensor_model_parallel_rank() - if is_vllm_tensorized(self.tensorizer_config): return self._load_model_serialized(model_config, device_config, lora_config, cache_config) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index b009ad8c882d..3ce2ff35a8d0 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -101,8 +101,9 @@ def verify_with_model_config(self, model_config: "ModelConfig") -> None: def load_with_tensorizer(tensorizer_config: TensorizerConfig, + model_config: Optional[ModelConfig], **extra_kwargs) -> nn.Module: - tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs) + tensorizer = TensorizerAgent(tensorizer_config, model_config, **extra_kwargs) return tensorizer.deserialize() @@ -262,6 +263,7 @@ class TensorizerAgent: """ def __init__(self, tensorizer_config: TensorizerConfig, + model_config: 
Optional[ModelConfig], quant_config: QuantizationConfig, **extra_kwargs): if tensorizer_error_msg is not None: raise ImportError( @@ -270,6 +272,7 @@ def __init__(self, tensorizer_config: TensorizerConfig, "Error message: {}".format(tensorizer_error_msg)) self.tensorizer_config = tensorizer_config + self.model_config = model_config self.tensorizer_args = ( self.tensorizer_config._construct_tensorizer_args()) self.extra_kwargs = extra_kwargs @@ -288,6 +291,8 @@ def _init_model(self): return self.tensorizer_config.model_class( config=model_args, quant_config=self.quant_config, + with_ladder=self.model_config.with_ladder, + sub_layers_ids =self.model_config.sub_layers_ids, **self.extra_kwargs) def _resize_lora_embeddings(self): @@ -359,6 +364,7 @@ def deserialize(self): def tensorizer_weights_iterator( + tensorizer_args: "TensorizerArgs" ) -> Generator[Tuple[str, torch.Tensor], None, None]: logger.warning( diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 0666457756b0..98c367d283b8 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -407,6 +407,7 @@ def pt_weights_iterator( """Iterate over the weights in the model bin/pt files.""" enable_tqdm = not torch.distributed.is_initialized( ) or torch.distributed.get_rank() == 0 + for bin_file in tqdm( hf_weights_files, desc="Loading pt checkpoint shards", diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e55c01316087..c88c9106b1ae 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -34,6 +34,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -49,11 +50,101 @@ default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.utils import is_hip +from vllm.utils import is_hip, AixQkvWeightHelper from .interfaces import SupportsLoRA from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +# ========================================================================================== +# Yocto : support ladder net + +# class RMSNorm(torch.nn.Module): +# def __init__(self, dim: int, eps: float = 1e-6): +# super().__init__() +# self.eps = eps +# self.weight = torch.nn.Parameter(torch.ones(dim)) + +# def _norm(self, x): +# return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + +# def forward(self, x): +# output = self._norm(x.float()).type_as(x) +# return output * self.weight + +def dropout_add(x, residual, prob, training): + # type: (torch.Tensor, torch.Tensor, float, bool) -> torch.Tensor + out = torch.nn.functional.dropout(x, p=prob, training=training) + out = residual + out + return out + +# @torch.jit.script +def dropout_add_fused_inference(x: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return dropout_add(x, residual, prob, False) + +class ParallelGatedLinearUnit(torch.nn.Module): + def __init__(self, in_dim, out_dim): + super(ParallelGatedLinearUnit, self).__init__() + + feed_forward_dim = int(int(out_dim/(3*4)) * 4) + + self.dense_1 = ColumnParallelLinear( + in_dim, + 
feed_forward_dim, + gather_output=False, + skip_bias_add=True, + bias=False, + ) + self.dense_2 = ColumnParallelLinear( + in_dim, + feed_forward_dim, + gather_output=False, + skip_bias_add=True, + bias=False, + ) + self.dense_3 = RowParallelLinear( + feed_forward_dim, + out_dim, + input_is_parallel=True, + skip_bias_add=True, + bias=False, + ) + self.activation = torch.nn.SiLU() + + def forward(self, hidden_state): + part_one, _ = self.dense_1(hidden_state) + part_one = self.activation(part_one) + part_two, _ = self.dense_2(hidden_state) + + hidden_state, _ = self.dense_3(torch.multiply(part_one, part_two)) + return hidden_state + +class LlamaLadderLayer(nn.Module): + + def __init__( + self, + config: LlamaConfig, + ) -> None: + super().__init__() + self.input_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.feed_forward = ParallelGatedLinearUnit( + in_dim=4096, out_dim=4096 + ) + self.merge_gates = nn.Parameter(torch.ones(1)) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.input_norm(hidden_states).to(hidden_states.dtype) + + hidden_states = self.feed_forward(hidden_states) + hidden_states = dropout_add_fused_inference(hidden_states, residual=residual, prob=0) + + return hidden_states + +# ========================================================================================== + class LlamaMLP(nn.Module): @@ -177,9 +268,16 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # debug_list.append({"q":q.clone().detach()}) + # debug_list.append({"k":k.clone().detach()}) + # debug_list.append({"v":v.clone().detach()}) q, k = self.rotary_emb(positions, q, k) + # debug_list.append({"rotary_emb_q":q.clone().detach()}) + # debug_list.append({"rotary_emb_k":k.clone().detach()}) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + # debug_list.append({"attn_output":attn_output.clone().detach()}) output, _ = self.o_proj(attn_output) + # debug_list.append({"o_proj_output":output.clone().detach()}) return output @@ -242,23 +340,40 @@ def forward( residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention + # debug_list = [] + # debug_list.append({"layer_norm_input" : hidden_states.clone().detach()}) + # debug_list.append({"residual is None" : residual is None}) + # debug_list.append({"layer_norm_weight" : self.input_layernorm.weight}) + # debug_list.append({"layer_norm_eps" : self.input_layernorm.variance_epsilon}) + if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm( hidden_states, residual) + # debug_list.append({"layer_norm_output" : hidden_states.clone().detach()}) + hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, ) + # for item in attn_debug_list: + # debug_list.append(item) + # debug_list.append({"attention_output" : hidden_states.clone().detach()}) + # Fully Connected + # debug_list.append({"post_attention_norm_input" : hidden_states.clone().detach()}) + # debug_list.append({"post_attention_norm_input_res" : residual.clone().detach()}) hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) + # debug_list.append({"post_attention_norm_output" : hidden_states.clone().detach()}) hidden_states = self.mlp(hidden_states) + # debug_list.append({"mlp_output" : hidden_states.clone().detach()}) + # return 
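# ==========================================================================================
# Single-GPU sketch (not part of the patch) of what ParallelGatedLinearUnit and
# LlamaLadderLayer compute, with plain nn.Linear standing in for the column/row
# parallel layers and a local RMSNorm mirroring the commented-out one above.
# For out_dim = 4096 the gate width works out to int(int(4096 / (3 * 4)) * 4) = 1364.
# The *Sketch class names are illustrative only.
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed * self.weight

class GatedLinearUnitSketch(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        hidden = int(int(out_dim / (3 * 4)) * 4)               # 1364 for out_dim=4096
        self.dense_1 = nn.Linear(in_dim, hidden, bias=False)   # gate branch
        self.dense_2 = nn.Linear(in_dim, hidden, bias=False)   # value branch
        self.dense_3 = nn.Linear(hidden, out_dim, bias=False)  # down projection
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style gating: W3( SiLU(W1 x) * (W2 x) )
        return self.dense_3(self.activation(self.dense_1(x)) * self.dense_2(x))

class LadderLayerSketch(nn.Module):
    def __init__(self, hidden_size: int = 4096, eps: float = 1e-6):
        super().__init__()
        self.input_norm = RMSNormSketch(hidden_size, eps)
        self.feed_forward = GatedLinearUnitSketch(hidden_size, hidden_size)
        self.merge_gates = nn.Parameter(torch.ones(1))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual block, matching LlamaLadderLayer.forward above.
        residual = hidden_states
        hidden_states = self.feed_forward(self.input_norm(hidden_states))
        return residual + hidden_states

x = torch.randn(2, 4096)
print(LadderLayerSketch()(x).shape)   # torch.Size([2, 4096])
# ==========================================================================================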
hidden_states, residual, debug_list return hidden_states, residual @@ -271,6 +386,8 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, prefix: str = "", + with_ladder: Optional[bool] = False, + sub_layers_ids: Optional[List[int]] = [], ) -> None: super().__init__() self.config = config @@ -300,6 +417,20 @@ def __init__( self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() +# ========================================================================================== + # Yocto : support ladder net + # TODO(Yocto) : remove this hack + self.with_ladder = with_ladder + if self.with_ladder: + self.sub_layers_ids = sub_layers_ids + self.input_scaler = ParallelGatedLinearUnit(in_dim=4096, out_dim=4096) + self.output_layer = ParallelGatedLinearUnit(in_dim=4096, out_dim=4096) + self.ladder_layers = torch.nn.ModuleList([LlamaLadderLayer(self.config) for _ in range(len(self.sub_layers_ids))]) + self.finnal_norm = RMSNorm(4096) + self.sigmoid = torch.nn.Sigmoid() + self.activation_func = torch.nn.SiLU() +# ========================================================================================== + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -313,6 +444,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: + # debug_tensor = {} + stem_hidden_state = None + # debug_tensor["input_ids"] = input_ids.clone().detach().int() if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -323,9 +457,16 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + ladder_hidden_states = None + # debug_tensor.append({"embd_output" : hidden_states.clone().detach()}) + # debug_tensor.append({"residual" : residual}) for i in range(self.start_layer, self.end_layer): layer = self.layers[i] + # hidden_list.append({"layer_id" : i}) + # hidden_list.append({"layer" : layer}) + # hidden_list.append({"hidden_states" : hidden_states}) + # hidden_list.append({"residual" : residual}) hidden_states, residual = layer( positions, hidden_states, @@ -333,15 +474,41 @@ def forward( attn_metadata, residual, ) + if self.with_ladder: + if i in self.sub_layers_ids: + # debug_tensor['hidden_states'] = hidden_states.clone().detach() + # debug_tensor['residual'] = residual.clone().detach() + stem_hidden_state = hidden_states[-1].clone().contiguous() + residual[-1].clone().contiguous() + if i == 0: + # debug_tensor['layer_norm_input'] = stem_hidden_state.clone().detach() + ladder_hidden_states = self.input_scaler(stem_hidden_state) + # debug_tensor['layer_norm_output'] = ladder_hidden_states.clone().detach() + else: + # debug_tensor[f'layer_{i}_input'] = ladder_hidden_states.clone().detach() + ladder_layer = self.ladder_layers[self.sub_layers_ids.index(i) - 1] + gate_value = self.sigmoid(ladder_layer.merge_gates) + # debug_tensor[f'layer_{i}_gate_value'] = gate_value.clone().detach() + # debug_tensor[f'layer_{i}_stem_hidden_state'] = stem_hidden_state.clone().detach() + merged = torch.add(stem_hidden_state * gate_value, ladder_hidden_states * (1 - gate_value)) + # debug_tensor[f'layer_{i}_merged'] = merged.clone().detach() + ladder_hidden_states = ladder_layer(merged) + # debug_tensor[f'layer_{i}_output'] = 
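# ==========================================================================================
# Sketch (not part of the patch) of the recurrence that the ladder branch of
# LlamaModel.forward() above implements. At each tapped decoder layer the side
# network reads the stem's last-token state (hidden + residual) and mixes it with
# its own state through a learned sigmoid gate. The function and argument names
# are placeholders; the patch keys the first tap on layer id 0, which is
# equivalent to "first entry of sub_layers_ids" whenever 0 is the first tap.
import torch

def run_ladder(stem_states, sub_layers_ids, input_scaler, ladder_layers,
               output_layer, final_norm):
    """stem_states: dict mapping layer id -> last-token hidden state of the stem."""
    ladder = None
    for k, i in enumerate(sub_layers_ids):
        stem = stem_states[i]
        if k == 0:
            # First tap: project the stem state into the ladder stream.
            ladder = input_scaler(stem)
        else:
            layer = ladder_layers[k - 1]
            gate = torch.sigmoid(layer.merge_gates)        # scalar gate in (0, 1)
            merged = stem * gate + ladder * (1 - gate)     # convex mix of the streams
            ladder = layer(merged)
    # Ladder head, applied once after the last tap (see the end of forward()).
    return final_norm(output_layer(torch.nn.functional.silu(ladder)))
# ==========================================================================================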
ladder_hidden_states.clone().detach() if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, "residual": residual }) + # debug_tensor["ladder_hidden_states"] = ladder_hidden_states.clone().detach() + if self.with_ladder: + ladder_hidden_states = self.activation_func(ladder_hidden_states) + ladder_hidden_states = self.output_layer(ladder_hidden_states) + ladder_hidden_states = self.finnal_norm(ladder_hidden_states) + hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + + return hidden_states, ladder_hidden_states class LlamaForCausalLM(nn.Module, SupportsLoRA): @@ -382,6 +549,8 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + with_ladder: Optional[bool] = False, + sub_layers_ids:Optional[List[int]] = [], ) -> None: super().__init__() @@ -392,7 +561,9 @@ def __init__( cache_config, quant_config, lora_config=lora_config, - prefix="model") + prefix="model", + with_ladder=with_ladder, + sub_layers_ids=sub_layers_ids) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -426,9 +597,9 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, kv_caches, + model_output, ladder_hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) - return model_output + return model_output, ladder_hidden_states def compute_logits( self, @@ -439,6 +610,15 @@ def compute_logits( sampling_metadata) return logits +# ========================================================================================== + def compute_ladder_logits( + self, + ladder_hidden_states: torch.Tensor + ) -> Optional[torch.Tensor]: + ladder_logits = torch.nn.functional.linear(ladder_hidden_states, self.lm_head.weight) + return ladder_logits +# ========================================================================================== + def sample( self, logits: torch.Tensor, @@ -461,6 +641,45 @@ def make_empty_intermediate_tensors( device=device), }) + def load_weights_with_ladder(self, weights: Iterable[Tuple[str, torch.Tensor]]): + qkv_weight_helper = AixQkvWeightHelper(self) + params_dict = dict(self.named_parameters()) + weight_dtype = params_dict[next(iter(params_dict.keys()))] + for name, loaded_weight in weights: + if name == 'merge_gates': + for layer_id in range(loaded_weight.shape[0]): + source_name = 'model.ladder_layers.' + str(layer_id) + '.merge_gates' + assert source_name in params_dict.keys(), f"{name} is not in params_dict" + assert params_dict[source_name].shape == loaded_weight[layer_id].unsqueeze(0).shape, f"{name} shape error" + params_dict[source_name].data = loaded_weight[layer_id].unsqueeze(0).cuda() + elif name.startswith("input_scaler"): + source_name = 'model.' 
+ name + assert source_name in params_dict.keys(), f"{name} is not in params_dict" + assert params_dict[source_name].shape == loaded_weight.shape, f"{name} shape error" + params_dict[source_name].data = loaded_weight.to(weight_dtype).cuda() + elif name.startswith('layers.') and 'feed_forward.dense' in name: + source_name = 'model.ladder_' + name + assert source_name in params_dict.keys(), f"{name} is not in params_dict" + assert params_dict[source_name].shape == loaded_weight.shape, f"{name} shape error" + params_dict[source_name].data = loaded_weight.cuda() + elif name.startswith('layers.') and 'input_norm.weight' in name: + source_name = 'model.ladder_' + name + assert source_name in params_dict.keys(), f"{name} is not in params_dict" + assert params_dict[source_name].shape == loaded_weight.shape, f"{name} shape error" + params_dict[source_name].data = loaded_weight.cuda() + elif name.startswith('output_layer'): + source_name = 'model.' + name + assert source_name in params_dict.keys(), f"{name} is not in params_dict" + assert params_dict[source_name].shape == loaded_weight.shape, f"{name} shape error" + params_dict[source_name].data = loaded_weight.cuda() + elif name == 'finnal_norm.weight': + source_name = 'model.' + name + assert source_name in params_dict.keys(), f"{name} is not in params_dict" + assert params_dict[source_name].shape == loaded_weight.shape, f"{name} shape error" + params_dict[source_name].data = loaded_weight.cuda() + else: + raise ValueError(f"key {name} error.") + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -524,7 +743,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - + # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should # make sure to leave KV cache scale factors in a known good (dummy) state diff --git a/vllm/sequence.py b/vllm/sequence.py index a5ebf152ce77..a0d36601a964 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -40,6 +40,7 @@ class Logprob: logprob: float rank: Optional[int] = None decoded_token: Optional[str] = None + ori_logprob: Optional[float] = None # {token_id -> logprob} per each sequence group. 
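# --------------------------------------------------------------------------------
# Sketch (not part of the patch) of the checkpoint-key remapping that
# load_weights_with_ladder() above performs: names stored in CodeLST.pt are
# rewritten onto the ladder submodules registered under `model.`. The example
# keys are hypothetical; the stacked 'merge_gates' tensor, which the real method
# splits into one scalar per ladder layer, is not covered here.
def remap_ladder_key(name: str) -> str:
    if (name.startswith("input_scaler") or name.startswith("output_layer")
            or name == "finnal_norm.weight"):
        return "model." + name
    if name.startswith("layers.") and ("feed_forward.dense" in name
                                       or "input_norm.weight" in name):
        # 'layers.3.feed_forward.dense_1.weight'
        #     -> 'model.ladder_layers.3.feed_forward.dense_1.weight'
        return "model.ladder_" + name
    raise ValueError(f"unexpected ladder checkpoint key: {name}")

assert remap_ladder_key("layers.3.input_norm.weight") == \
    "model.ladder_layers.3.input_norm.weight"
assert remap_ladder_key("output_layer.dense_2.weight") == \
    "model.output_layer.dense_2.weight"
# --------------------------------------------------------------------------------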
None if the corresponding diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index dfe83ddb731d..32069b71571b 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,4 +1,5 @@ import contextlib +import os from pathlib import Path from typing import Any, Dict, Optional, Type, Union @@ -53,6 +54,7 @@ def get_config( code_revision: Optional[str] = None, rope_scaling: Optional[dict] = None, rope_theta: Optional[float] = None, + aix_model_config: Union[Dict, None] = None, **kwargs, ) -> PretrainedConfig: @@ -63,12 +65,15 @@ def get_config( model = Path(model).parent try: - config = AutoConfig.from_pretrained( - model, - trust_remote_code=trust_remote_code, - revision=revision, - code_revision=code_revision, - **kwargs) + if not os.path.exists(os.path.join(model, "config.json")) and isinstance(aix_model_config, dict): + config = PretrainedConfig.from_dict(aix_model_config) + else: + config = AutoConfig.from_pretrained( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + **kwargs) except ValueError as e: if (not trust_remote_code and "requires you to execute the configuration file" in str(e)): diff --git a/vllm/utils.py b/vllm/utils.py index 657a3ecef696..1cffced62336 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1224,3 +1224,99 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, def supports_dynamo() -> bool: base_torch_version = Version(Version(torch.__version__).base_version) return base_torch_version >= Version("2.4.0") + +def split(v: Union[np.ndarray, torch.Tensor], + tp_size: int, + tp_rank: int, + dim=0): + if tp_size == 1: + return v + assert len(v.shape) > 1 or dim == 0 + if isinstance(v, np.ndarray): + return np.ascontiguousarray( + np.split(v, tp_size, axis=dim)[tp_rank].copy()) + else: + assert v.shape[dim] % tp_size == 0, \ + 'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.' + split_size = v.shape[dim] // tp_size + return v.split(split_size, dim=dim)[tp_rank].clone().detach() + +# Yocto +def dup_kv_weight(v, num_head, tp_size): + assert tp_size % num_head == 0 + reps = tp_size // num_head + head_size = v.shape[0] // num_head + v = v.reshape(num_head, head_size, + -1)[:, None, :, :].expand(num_head, reps, head_size, + v.shape[1]) + return v.reshape(num_head * reps * head_size, -1).clone().detach() + +# Yocto +class AixQkvWeightHelper: + """ A helper utility for loading QKV weights from single files. 
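# --------------------------------------------------------------------------------
# Sketch (not part of the patch) of the fallback added to get_config(): when the
# model directory has no config.json but the caller supplied an aix_model_config
# dict, the HF config is built directly from that dict instead of going through
# AutoConfig. load_config() and the example values are hypothetical.
import os

from transformers import AutoConfig, PretrainedConfig

def load_config(model: str, aix_model_config=None) -> PretrainedConfig:
    if (not os.path.exists(os.path.join(model, "config.json"))
            and isinstance(aix_model_config, dict)):
        return PretrainedConfig.from_dict(aix_model_config)
    return AutoConfig.from_pretrained(model)

# e.g. load_config("/models/aix3-7b-without-config",          # hypothetical path
#                  {"hidden_size": 4096, "num_hidden_layers": 32})
# --------------------------------------------------------------------------------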
""" + + def __init__(self, model): + self.hidden_size = 4096 + self.num_heads = 32 + self.num_kv_heads = 8 + self.tp_size = 1 + self.tp_rank = 0 + self.is_mha = self.num_heads == self.num_kv_heads + + @staticmethod + def is_qkv_weight(name): + if "attention.query_key_value.weight" in name: + return True + return False + + def permute_qkv_weights(self, name, param): + """ + name: layers.0.attention.query_key_value.weight + param's shape: [6144, 4096] + param's permutation: + - [qqqqkkvv qqqqkkvv] for np(heads per local)=2, kvnp(kv heads per local)=1, hn(hidden_size per head)=4, hidden_size=8 + - [group_size, num_of_query_local_heads_in_each_group + 1 + 1, head_dim, hidden_size] => [(num_of_heads + num_of_kv_heads * 2) * head_dim, hidden_size] + + """ + + # TODO: now, only work for aix3-7B + assert self.num_kv_heads == 8 + assert self.num_heads == 32 + assert self.hidden_size == 4096 + assert self.is_mha == False + + assert param.shape[0] == (self.num_kv_heads *2) * self.hidden_size//self.num_heads + self.hidden_size + assert param.shape[1] == self.hidden_size + + + if not self.is_mha: + head_size = self.hidden_size // self.num_heads + param = param.reshape(self.num_kv_heads, self.num_heads // self.num_kv_heads + 2, head_size, self.hidden_size) + + # w_q.shape (8, 4, 128, 4096) + # w_kv.shape (8, 2, 128, 4096) + w_q, w_kv = torch.split(param, self.num_heads // self.num_kv_heads, dim=1) + w_k = torch.clone(w_kv[:, 0:1, :, :]) + w_v = torch.clone(w_kv[:, 1:2, :, :]) + + if self.num_kv_heads < self.tp_size: + # duplicate the KV heads up to tensor_parallel + w_k = dup_kv_weight(w_k, self.num_kv_heads, self.tp_size) + v = dup_kv_weight(v, self.num_kv_heads, self.tp_size) + + w_q = torch.reshape(w_q, [self.hidden_size, self.hidden_size]) + w_k = torch.reshape(w_k, [self.num_kv_heads * head_size, self.hidden_size]) + w_v = torch.reshape(w_v, [self.num_kv_heads * head_size, self.hidden_size]) + assert w_k.shape[0] % (self.tp_size * head_size) == 0, f"w_k.shape: {w_k.shape}, self.tp_size: {self.tp_size}, head_size: {head_size}" + assert w_v.shape[0] % (self.tp_size * head_size) == 0, f"w_v.shape: {w_v.shape}, self.tp_size: {self.tp_size}, head_size: {head_size}" + wq = split(w_q, self.tp_size, self.tp_rank) + wk = split(w_k, self.tp_size, self.tp_rank) + wv = split(w_v, self.tp_size, self.tp_rank) + fused_qkv = torch.cat((wq, wk, wv), dim=0) + else: + qkv = param + qkv = qkv.reshape(3, self.hidden_size, self.hidden_size) + fused_qkv = split(qkv, self.tp_size, self.tp_rank, dim=1) + fused_qkv = fused_qkv.reshape(3 * (self.hidden_size // self.tp_size), + self.hidden_size) + return fused_qkv \ No newline at end of file diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 74f7d4e0860d..6b0791b1cf62 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -913,6 +913,8 @@ def __init__( def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) + if self.model_config.with_ladder: + logger.info("Starting to load ladder model %s...", self.model_config.ladder_model_path) with CudaMemoryProfiler() as m: self.model = get_model(model_config=self.model_config, device_config=self.device_config, @@ -1447,7 +1449,7 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - hidden_or_intermediate_states = model_executable( + results = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, kv_caches=kv_caches, @@ -1456,7 +1458,14 @@ def 
execute_model( **MultiModalInputs.as_kwargs(multi_modal_kwargs, device=self.device), **seqlen_agnostic_kwargs) + + if isinstance(results, tuple): + hidden_or_intermediate_states, ladder_hidden_states = results + else: + hidden_or_intermediate_states = results + ladder_hidden_states = None + if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() @@ -1482,6 +1491,13 @@ execute_model( logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) + + if ladder_hidden_states is not None: + ladder_logits = self.model.compute_ladder_logits(ladder_hidden_states.to(hidden_or_intermediate_states.dtype)) + logits = ladder_logits.unsqueeze(0) if \ + torch.max(torch.softmax(ladder_logits.float(), dim=-1)) >= \ + torch.max(torch.softmax(logits.float(), dim=-1)) \ + else logits if not self.is_driver_worker: return []
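# --------------------------------------------------------------------------------
# Sketch (not part of the patch) of the confidence-based switch at the end of
# execute_model() above: the ladder head's logits replace the stem logits only
# when the ladder distribution is at least as peaked, i.e. its maximum softmax
# probability is >= the stem's. select_logits() and the shapes are illustrative.
import torch

def select_logits(stem_logits: torch.Tensor,
                  ladder_logits: torch.Tensor) -> torch.Tensor:
    stem_conf = torch.softmax(stem_logits.float(), dim=-1).max()
    ladder_conf = torch.softmax(ladder_logits.float(), dim=-1).max()
    return ladder_logits if ladder_conf >= stem_conf else stem_logits

stem = torch.randn(1, 32000)      # hypothetical vocab size
ladder = torch.randn(1, 32000)
print(select_logits(stem, ladder).shape)   # torch.Size([1, 32000])
# --------------------------------------------------------------------------------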