diff --git a/mindformers/core/callback/__init__.py b/mindformers/core/callback/__init__.py index 306f2547814068b36c6b35aac5cb7ecc9c4177ee..922fa9d192627579f4b64b6d174c8e017f0d64f5 100644 --- a/mindformers/core/callback/__init__.py +++ b/mindformers/core/callback/__init__.py @@ -21,10 +21,12 @@ from .callback import ( ProfileMonitor, SummaryMonitor, StressDetectCallBack, - TrainingStateMonitor + TrainingStateMonitor, + MaxLogitsMonitor ) __all__ = [ 'CheckpointMonitor', 'EvalCallBack', 'MFLossMonitor', - 'ProfileMonitor', 'SummaryMonitor', 'TrainingStateMonitor' + 'ProfileMonitor', 'SummaryMonitor', 'TrainingStateMonitor', + 'MaxLogitsMonitor' ] diff --git a/mindformers/core/callback/callback.py b/mindformers/core/callback/callback.py index 2b1bf9fd8afd0b4a15fd780fbc706f66cd338275..a7ff25944d08fefdd66de491ec3a94e000c038a5 100644 --- a/mindformers/core/callback/callback.py +++ b/mindformers/core/callback/callback.py @@ -58,6 +58,7 @@ from mindspore.communication.comm_func import all_gather_into_tensor, barrier from mindspore.profiler import ProfilerLevel, schedule from mindspore.utils import stress_detect +from mindformers.wrapper.wrapper import get_real_models from mindformers.core.context.build_context import is_legacy_model from mindformers.tools import get_output_root_path from mindformers.tools.logger import logger @@ -651,7 +652,9 @@ class TrainingStateMonitor(Callback): dataset_size (int, optional): Required in sink mode. Training dataset size. Default: ``None``. initial_epoch (int, optional): The beginning epoch. Default: ``0``. initial_step (int, optional): The beginning step. Default: ``0``. + micro_batch_num (int, optional): MicroBatch size for Pipeline Parallel. Default: ``0``. global_batch_size (int, optional): The total batch size. Default: ``0``. + tensor_model_parallel_size (int, optional): Tensor model parallel size. Default: ``0``. check_for_nan_in_loss_and_grad (bool, optional): Whether to check loss and norm of grad is Nan. Default: ``False``. 
use_skip_data_by_global_norm (bool, optional): Whether to use the skip data function @@ -669,7 +672,9 @@ class TrainingStateMonitor(Callback): dataset_size: int = None, initial_epoch: int = 0, initial_step: int = 0, + micro_batch_num: int = 0, global_batch_size: int = 0, + tensor_model_parallel_size: int = 0, check_for_nan_in_loss_and_grad: bool = False, use_skip_data_by_global_norm: bool = False, embedding_size: int = 4096, @@ -688,7 +693,9 @@ class TrainingStateMonitor(Callback): self.origin_epochs = origin_epochs self.initial_epoch = initial_epoch self.initial_step = initial_step + self.micro_batch_num = micro_batch_num self.global_batch_size = global_batch_size + self.tensor_model_parallel_size = tensor_model_parallel_size self.global_norm_spike_count = 0 self.use_skip_data_by_global_norm = use_skip_data_by_global_norm self.embedding_size = embedding_size @@ -765,6 +772,8 @@ class TrainingStateMonitor(Callback): self._dump_data_in_step(cb_params.cur_step_num) if self.optimizer_state_format: self._dump_optimizer_state(cb_params) + if self.max_attention_logit_format: + self._dump_max_attention_logit(cb_params) if self.weight_state_format: network = cb_params.network if isinstance(network, ms.nn.TrainOneStepCell): @@ -884,6 +893,7 @@ class TrainingStateMonitor(Callback): self.device_local_norm_format = config.get('device_local_norm_format', None) self.device_local_loss_format = \ config.get('device_local_loss_format', None) if is_last_pipeline_stage() else None + self.max_attention_logit_format = config.get('max_attention_logit_format', None) self.optimizer_state_format = config.get('optimizer_state_format', None) self.weight_state_format = config.get('weight_state_format', None) self.throughput_baseline = config.get('throughput_baseline', None) @@ -907,7 +917,7 @@ class TrainingStateMonitor(Callback): if not isinstance(self.print_struct, bool): raise TypeError("The value of 'print_struct' should be bool.") attrs = ['local_norm_format', 'local_loss_format', 'device_local_norm_format', 'device_local_loss_format', - 'optimizer_state_format', 'weight_state_format'] + 'optimizer_state_format', 'weight_state_format', 'max_attention_logit_format'] for attr in attrs: self._check_attr_formats(attr) if self.global_norm_record_path and os.path.exists(self.global_norm_record_path): @@ -1016,6 +1026,41 @@ class TrainingStateMonitor(Callback): if 'tensorboard' in self.local_loss_format: self._output(f'local_{loss_tag}_loss', np.mean(loss_list), self.dump_step, ['tensorboard']) + def _dump_max_attention_logit(self, cb_params): + """write the max attention logit to log/tensorboard""" + network = cb_params.train_network + network = get_real_models(network) + params = network.get_max_attention_logit() + + if not params: + return + step = cb_params.cur_step_num + vals = [] + for param_name, param in params.items(): + v = param.asnumpy().squeeze() + v = v / max(1, self.micro_batch_num) + + tag = f"max_attention_logit/{param_name}" + if 'log' in self.max_attention_logit_format: + self._output(tag, v.tolist(), step, ['log']) + if 'tensorboard' in self.max_attention_logit_format: + tp_id = get_rank() // self.tensor_model_parallel_size + head_start = tp_id * len(v) + data = {f"head_{head_start+i}": max_attention_logit for i, max_attention_logit in enumerate(v)} + self._output(tag, data, step, ['tensorboard']) + + vals.append(v) + + if vals: + mean_v = float(np.mean(vals)) + max_v = float(np.max(vals)) + if 'tensorboard' in self.max_attention_logit_format: + self._output('max_attention_logit/mean', mean_v, step, 
['tensorboard']) + self._output('max_attention_logit/max', max_v, step, ['tensorboard']) + if 'log' in self.max_attention_logit_format: + self._output('max_attention_logit/mean', mean_v, step, ['log']) + self._output('max_attention_logit/max', max_v, step, ['log']) + def _dump_optimizer_state(self, cb_params): """write the optimizer state to tensorboard""" optimizer = cb_params.optimizer @@ -1055,19 +1100,33 @@ class TrainingStateMonitor(Callback): return file_list = os.listdir(self.dump_path) for f in file_list: + if f.startswith(".nfs"): + continue os.remove(os.path.join(self.dump_path, f)) def _to_tensorboard(self, tag, data, global_step): """Write data to tensorboard if possible""" if self.tensor_writer is not None: - self.tensor_writer.add_scalar(tag, data, global_step=global_step) + if isinstance(data, dict): + self.tensor_writer.add_scalars(tag, data, global_step=global_step) + else: + self.tensor_writer.add_scalar(tag, data, global_step=global_step) def _to_log(self, tag, data, global_step): """Write data to log file""" cur_epoch_num = (global_step + self.initial_step - 1) // self.steps_per_epoch + 1 cur_step_num = (global_step + self.initial_step - 1) % self.steps_per_epoch + 1 - logger.info("Epoch:[%3d/%3d], step:[%5d/%5d] %s: %.4f", - cur_epoch_num, self.origin_epochs, cur_step_num, self.steps_per_epoch, tag, data) + if isinstance(data, list): + logger.info( + "Epoch:[%3d/%3d], step:[%5d/%5d] %s: %s", + cur_epoch_num, self.origin_epochs, cur_step_num, self.steps_per_epoch, tag, + [round(x, 4) if isinstance(x, (float, int)) else x for x in data] + ) + else: + logger.info( + "Epoch:[%3d/%3d], step:[%5d/%5d] %s: %.4f", + cur_epoch_num, self.origin_epochs, cur_step_num, self.steps_per_epoch, tag, data + ) def _output(self, tag, data, global_step, formats): """Write data in specified formats""" @@ -1276,7 +1335,7 @@ class CheckpointMonitor(ModelCheckpoint): self.global_batch_size = global_batch_size # this list records parameters which will be ignored when saving ckpt. self.filter_list = ['accu_grads', 'fi_parameter', 'zeros_k_pe', 'zeros_k_nope', 'zeros_value_states', '_cache', - '_device_local_norm', '_device_local_loss', 'expert_load'] + '_device_local_norm', '_device_local_loss', 'expert_load', 'max_logits_val'] self.save_info_list = defaultdict( lambda: { @@ -2406,6 +2465,47 @@ class StressDetectCallBack(Callback): logger.warning(f"Stress detection failed with error code: {ret}") +@MindFormerRegister.register(MindFormerModuleType.CALLBACK) +class MaxLogitsMonitor(Callback): + """ + Callback to reset max attention logits during training. + + This callback resets the maximum attention logit values at the end of each training step. + """ + + def __init__(self,): + pass + + def _reset_max_attention_logit(self, network): + """Reset max attention logit in the network. + + Args: + network: The network to reset max attention logit. + + Raises: + RuntimeError: If the network does not have reset_max_attention_logit method. 
+ """ + while hasattr(network, "network"): + network = network.network + if hasattr(network, "reset_max_attention_logit"): + network.reset_max_attention_logit() + else: + raise RuntimeError(f"network {type(network).__name__} should have reset_max_attention_logit") + + def on_train_step_end(self, run_context): + """Reset the max attention logits at the end of each training step.""" + cb_params = run_context.original_args() + self.cur_step = cb_params.cur_step_num + # pylint: disable=W0212 + network = cb_params.train_network + while hasattr(network, 'network'): + network = network.network + parallel_mode = get_auto_parallel_context("parallel_mode") + if parallel_mode in ["semi_auto_parallel", "auto_parallel"] and ms.get_context('mode') == 0: + network = network._backbone + self._reset_max_attention_logit(network) + + @MindFormerRegister.register(MindFormerModuleType.CALLBACK) class TopkBiasBalanceCallback(Callback): """ diff --git a/mindformers/core/optim/__init__.py b/mindformers/core/optim/__init__.py index 6947d1857a578a68db7d1547ddbaa53ac59a8ad5..0ae30e129b4bf41a282027bcccf880e04df5aa31 100644 --- a/mindformers/core/optim/__init__.py +++ b/mindformers/core/optim/__init__.py @@ -21,8 +21,9 @@ from .adamw import AdamW as BasicAdamW from .fused_adamw import FusedAdamW from .pma_adamw import PmaAdamW as BasicPmaAdamW from .fused_pma_adamw import FusedPmaAdamW +from .muon import Muon -__all__ = ['AdamW', 'PmaAdamW'] +__all__ = ['AdamW', 'PmaAdamW', 'Muon'] @MindFormerRegister.register(MindFormerModuleType.OPTIMIZER) diff --git a/mindformers/core/optim/build_optim.py b/mindformers/core/optim/build_optim.py index b8e0887ff2774d472b2d007deae5fb939d2f14d6..aa7530d809df8491509a8ca302d2dfefafc23670 100644 --- a/mindformers/core/optim/build_optim.py +++ b/mindformers/core/optim/build_optim.py @@ -33,6 +33,7 @@ def get_tft_wrapped_cls(class_name, config): optim_cls = optim_cls.get_actual_adamw_cls(use_fused) if check_tft_valid(): + # pylint: disable=C0415 from mindspore.train.callback import TrainFaultTolerance optim_cls = TrainFaultTolerance.get_optimizer_wrapper(optim_cls) else: @@ -89,7 +90,7 @@ def build_optim( if default_args is not None: config.update(default_args) - + config = config.copy() optim_cls, config = get_tft_wrapped_cls(config.pop('type'), config) else: optim_cls, config = get_tft_wrapped_cls(class_name, kwargs) diff --git a/mindformers/core/optim/muon.py b/mindformers/core/optim/muon.py new file mode 100644 index 0000000000000000000000000000000000000000..46754524b1011f66abac56d4a2e72d3e1ab4cf94 --- /dev/null +++ b/mindformers/core/optim/muon.py @@ -0,0 +1,539 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ +"""Muon API""" + +from __future__ import absolute_import + +import hashlib + +import numpy as np +from mindspore.common import dtype as mstype +from mindspore.ops import functional as F, composite as C, operations as P +from mindspore.common.api import jit +from mindspore.nn.optim.optimizer import Optimizer +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.communication.management import create_group, get_rank +from mindspore.ops.auto_generate import Chunk +from mindspore import get_auto_parallel_context + +from mindformers.tools.register import MindFormerRegister, MindFormerModuleType +from mindformers.core.context import is_legacy_model +from mindformers.tools.logger import logger + +_muon_opt = C.MultitypeFuncGraph("muon_opt") + + +def _perform_allgather_op(ns_inputs_item, op, tp, tp_dim, op_group, tp_group, param_name): + """Perform AllGather operations based on op and tp settings.""" + if "mlp.experts.weight" not in param_name: + # all gather op_shard + if op > 1: + ns_inputs_item = P.AllGather(group=op_group)(ns_inputs_item) + + # all gather tp_shard + if tp > 1: + if tp_dim == 0: + ns_inputs_item = P.AllGather(group=tp_group)(ns_inputs_item) + elif tp_dim == 1: + ns_inputs_item = P.AllGather(group=tp_group)(ns_inputs_item.T) + ns_inputs_item = ns_inputs_item.T + return ns_inputs_item + + +def zeropower_via_newtonschulz5_2d(x, dim_a, dim_b): + """Apply Newton-Schulz iteration for 2D tensors.""" + a, b, c = (3.4445, -4.7750, 2.0315) + + if dim_a > dim_b: + x = x.T + # Ensure spectral norm is at most 1 + x = x / (x.norm() + 1e-7) + # Perform the NS iterations + for _ in range(5): + a_mat = x @ x.T + b_mat = b * a_mat + c * a_mat @ a_mat + x = a * x + b_mat @ x + if dim_a > dim_b: + x = x.T + return x + + +def zeropower_via_newtonschulz5_3d(x, dim_a, dim_b): + """Apply Newton-Schulz iteration for 3D tensors.""" + a, b, c = (3.4445, -4.7750, 2.0315) + + if dim_a > dim_b: + x = P.Transpose()(x, (0, 2, 1)) + # Ensure spectral norm is at most 1 + x = x / P.ExpandDims()(P.ExpandDims()((x.norm(dim=(1, 2)) + 1e-7), 1), 1) + # Perform the NS iterations + for _ in range(5): + a_mat = P.BatchMatMul(transpose_b=True)(x, x) + b_mat = b * a_mat + c * P.BatchMatMul()(a_mat, a_mat) + x = a * x + P.BatchMatMul()(b_mat, x) + if dim_a > dim_b: + x = P.Transpose()(x, (0, 2, 1)) + return x + + +def _slice_tensor_to_shards(x, tp, tp_dim, op, rank_id, op_group, tp_group): + """Slice tensor to tp_shard and op_shard.""" + # slice X to tp_shard and slice X to op_shard + if tp > 1: + if tp_dim >= 0: + chunk_id = rank_id % tp + x = Chunk()(x, tp, tp_dim)[chunk_id] + + if op > 1: + if tp_dim == -1: + chunk_id = rank_id % op + else: + chunk_id = rank_id // tp % op + x = Chunk()(x, op)[chunk_id] + return x + + +def _apply_muon_update( + gradient, muon_m, momentum, use_nesterov, param, lr, weight_decay, + matched_adamw_rms, muon_split_fn, muon_merge_fn, param_name, + op, tp, tp_dim, rank_id, op_group, tp_group): + """Apply Muon optimizer update.""" + op_sqrt = P.Sqrt() + op_cast = P.Cast() + op_reshape = P.Reshape() + op_shape = P.Shape() + + m_fp32 = op_cast(muon_m, mstype.float32) + gradient_fp32 = op_cast(gradient, mstype.float32) + next_m = m_fp32 * momentum + gradient_fp32 + + if use_nesterov: + gradient_fp32 = gradient_fp32 + next_m * momentum + else: + gradient_fp32 = next_m + + ns_inputs = op_cast(gradient_fp32, mstype.bfloat16) + ns_inputs_list = 
muon_split_fn(param_name, ns_inputs) + x_list = [] + + dim_a, dim_b = None, None + for ns_inputs_item in ns_inputs_list: + dim_a, dim_b = op_shape(ns_inputs_item)[-2:] + + if len(op_shape(ns_inputs_item)) == 2: + ns_inputs_item = _perform_allgather_op( + ns_inputs_item, op, tp, tp_dim, op_group, tp_group, param_name) + x = zeropower_via_newtonschulz5_2d(ns_inputs_item, dim_a, dim_b) + x = _slice_tensor_to_shards(x, tp, tp_dim, op, rank_id, op_group, tp_group) + else: + x = zeropower_via_newtonschulz5_3d(ns_inputs_item, dim_a, dim_b) + + x_list.append(x) + + x_ret = muon_merge_fn(param_name, x_list) + param_fp32 = op_cast(param, mstype.float32) + param_fp32 = param_fp32 * (1 - lr * weight_decay) + + adjusted_ratio = op_sqrt(op_cast(max(dim_a, dim_b), mstype.float32)) * matched_adamw_rms + adjusted_lr = lr * adjusted_ratio + update_with_lr = adjusted_lr * x_ret + next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) + next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) + next_param = F.depend(next_param, F.assign(muon_m, op_cast(next_m, F.dtype(muon_m)))) + return op_cast(next_param, F.dtype(param)) + + +def _apply_adamw_update(param, exp_avg, exp_avg_sq, gradient, beta1, beta2, step, eps, lr, weight_decay): + """Apply AdamW optimizer update.""" + op_mul = P.Mul() + op_pow = P.Pow() + op_sqrt = P.Sqrt() + op_cast = P.Cast() + addcmul = P.Addcmul() + + param_fp32 = op_cast(param, mstype.float32) + next_param = op_mul(param_fp32, 1 - lr * weight_decay) + gradient_fp32 = op_cast(gradient, mstype.float32) + + next_param = F.depend( + next_param, + F.assign( + exp_avg, + op_mul(exp_avg, beta1) + + op_mul(gradient_fp32, op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1), + ), + ) + next_param = F.depend( + next_param, + F.assign( + exp_avg_sq, + addcmul( + op_mul(exp_avg_sq, beta2), + gradient_fp32, + gradient_fp32, + op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, + ), + ), + ) + + bias_correction1 = 1 - op_pow(op_cast(beta1, mstype.float32), step) + bias_correction2 = 1 - op_pow(op_cast(beta2, mstype.float32), step) + step_size = lr / bias_correction1 + denom = op_sqrt(exp_avg_sq / bias_correction2) + eps + return_param = next_param - op_mul(exp_avg / denom, step_size) + F.assign(param, op_cast(return_param, F.dtype(param))) + return op_cast(return_param, F.dtype(param)) + + +@_muon_opt.register( + "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", + "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Number", "Number", "Bool", + "Bool", "Bool", "String", "String", "String", "Function", "Function") +def _update_run_op( + momentum, matched_adamw_rms, beta1, beta2, step, eps, lr, weight_decay, rank_id, + param, exp_avg, exp_avg_sq, gradient, muon_m, tp, op, tp_dim, use_muon, + use_nesterov, optim_filter, op_group, tp_group, param_name, muon_split_fn, muon_merge_fn): + """ + Update parameters. + + Args: + beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). + beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). + eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. + lr (Tensor): Learning rate. + weight_decay (numbers.Number): Weight decay. Should be equal to or greater than 0. + param (Tensor): Parameters. + m (Tensor): m value of parameters. + v (Tensor): v value of parameters. + gradient (Tensor): Gradient of parameters. 
+ use_muon (bool): Whether this parameter is updated with the Muon branch; otherwise AdamW is used. + optim_filter (bool): Applies parameter update or not. + + Returns: + Tensor, the updated parameter, or the unchanged gradient when no update is applied. + """ + op_cast = P.Cast() + if "max_logits_val" in param_name: + return op_cast(gradient, F.dtype(param)) + + if not optim_filter: + return gradient + + if use_muon: + return _apply_muon_update( + gradient, muon_m, momentum, use_nesterov, param, lr, weight_decay, + matched_adamw_rms, muon_split_fn, muon_merge_fn, param_name, + op, tp, tp_dim, rank_id, op_group, tp_group) + + return _apply_adamw_update(param, exp_avg, exp_avg_sq, gradient, beta1, beta2, step, eps, lr, weight_decay) + + +@MindFormerRegister.register(MindFormerModuleType.OPTIMIZER) +class Muon(Optimizer): + """ + Muon optimizer implementation. + + Args: + params: model parameters to optimize. + learning_rate (float): Learning rate. Default: ``2e-2``. + weight_decay (float): Weight decay factor. Default: ``0.1``. + matched_adamw_rms (float): RMS matching parameter for AdamW. Default: ``0.2``. + momentum (float): Momentum factor. Default: ``0.95``. + nesterov (bool): Whether to use Nesterov momentum. Default: ``True``. + ns_steps (int): Number of Newton-Schulz steps. Default: ``5``. + adamw_betas (tuple): Beta parameters for AdamW. Default: ``(0.95, 0.95)``. + adamw_eps (float): Epsilon for AdamW. Default: ``1e-8``. + micro_batch_num (int): Number of micro batches. Default: ``1``. + qk_clip_threshold (float): QK clip threshold. Default: ``4``. + model: The model instance; it supplies parameter layouts, the Muon split/merge functions and the QK-clip hooks. Default: ``None``. + """ + + def __init__( + self, + params, + learning_rate=2e-2, + weight_decay=0.1, + matched_adamw_rms=0.2, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_betas=(0.95, 0.95), + adamw_eps=1e-8, + micro_batch_num=1, + qk_clip_threshold=4, + model=None, + ): + super().__init__(learning_rate, params, weight_decay) + + self._verify_model(model) + + # Initialize basic parameters + self._initialize_basic_params(adamw_betas, adamw_eps, momentum, matched_adamw_rms, nesterov) + + # Initialize model configuration + self._initialize_network_config(model) + + # Initialize parameter layers + self._initialize_param_layers(model) + + # Initialize QK-clip parameters + self.ones = Tensor([1.0], mstype.float32) + self.rank_id = get_rank() + self.rank_ids = tuple(self.rank_id for _ in self._parameters) + self.logit_threshold = Tensor([qk_clip_threshold * micro_batch_num], dtype=mstype.float32) + + # Initialize Muon momentum + self._initialize_muon_moments(model) + + # Initialize tensor parallel dimensions + self._initialize_tp_dims(model) + + # Initialize AdamW moments + self._initialize_adamw_moments(model) + + # Initialize parallel configuration + self._initialize_parallel_config(model) + + # Initialize communication groups + self._initialize_communication_groups() + + # Initialize optimizer parallel groups + self._initialize_op_groups(model) + + # Store model for QK-clip + self.model = model + self.ns_steps = ns_steps + + def _verify_model(self, model): + """Verify that the model is compatible with the Muon optimizer.""" + if model is None: + raise ValueError("Model must be provided for Muon optimizer.") + + if is_legacy_model(): + raise ValueError("Muon does not support Legacy Model.") + + config = model.get_gpt_transformer_config() + + if not config.multi_latent_attention: + raise ValueError("Current Muon implementation only supports models with Multi-Latent Attention enabled.") + + def _initialize_basic_params(self, adamw_betas, adamw_eps, momentum, matched_adamw_rms, nesterov): +
"""Initialize basic optimizer parameters.""" + self.beta1 = Tensor(np.array([adamw_betas[0]]).astype(np.float32)) + self.beta2 = Tensor(np.array([adamw_betas[1]]).astype(np.float32)) + self.eps = Tensor(np.array([adamw_eps]).astype(np.float32)) + self.muon_momentum = Tensor(np.array([momentum]).astype(np.float32)) + self.matched_adamw_rms = Tensor(np.array([matched_adamw_rms]).astype(np.float32)) + self.use_nesterov = tuple(nesterov for _ in self._parameters) + self.param_name_tuple = tuple(p.name for p in self._parameters) + + def _initialize_network_config(self, model): + """Initialize Model configuration and split/merge functions.""" + + self.muon_split_fn, self.muon_merge_fn = model.make_model_muon_fns() + self.muon_split_fns = tuple(self.muon_split_fn for _ in self._parameters) + self.muon_merge_fns = tuple(self.muon_merge_fn for _ in self._parameters) + + def _initialize_param_layers(self, model): + """Initialize parameter layer indices.""" + self.param_layer = model.get_param_layer_indices(self._parameters) + + def _initialize_muon_moments(self, model): + """Initialize Muon momentum parameters.""" + muon_filter = model.get_muon_filter() + + self.muon_m = [] + self.param_idx_in_opt = {} + for idx, param in enumerate(self._parameters): + self.param_idx_in_opt[param.name] = idx + + for param in self._parameters: + if muon_filter(param): + x1 = param.clone("zeros") + x1.name = "muon_m" + "." + x1.name + self.muon_m.append(x1) + logger.info(f"Muon apply: {param}") + else: + self.muon_m.append(Parameter(Tensor(np.array([0]).astype(np.float32)), name="muon_m." + param.name)) + self.muon_m = ParameterTuple(self.muon_m) + self.use_muon = tuple(muon_filter(param) for param in self._parameters) + + def _initialize_tp_dims(self, model): + """Initialize tensor parallel dimensions.""" + self.tp_dims = model.get_tp_dims(self._parameters) + + def _initialize_adamw_moments(self, model): + """Initialize AdamW momentum parameters.""" + muon_filter = model.get_muon_filter() + + self.moments1 = [] + self.moments2 = [] + for param in self._parameters: + if not muon_filter(param): + x1 = param.clone("zeros") + x1.name = "adam_m" + "." + x1.name + self.moments1.append(x1) + x2 = param.clone("zeros") + x2.name = "adam_v" + "." + x2.name + self.moments2.append(x2) + logger.info(f"Adam apply: {param}") + else: + self.moments1.append(Parameter(Tensor(np.array([0]).astype(np.float32)), name="adam_m." + param.name)) + self.moments2.append(Parameter(Tensor(np.array([0]).astype(np.float32)), name="adam_v." 
+ param.name)) + self.moments1 = ParameterTuple(self.moments1) + self.moments2 = ParameterTuple(self.moments2) + + def _initialize_parallel_config(self, model): + """Initialize parallel configuration.""" + self.tp = model.get_gpt_transformer_config().tensor_model_parallel_size + self.tps = tuple(self.tp for _ in self._parameters) + logger.info(f"Muon tp group size is: {self.tp}") + + if not get_auto_parallel_context('enable_parallel_optimizer'): + self.op = 1 + else: + self.op = get_auto_parallel_context('optimizer_weight_shard_size') + if self.op == -1: + raise ValueError( + "Must set parallel.parallel_optimizer_config.optimizer_weight_shard_size when using Muon") + logger.info(f"Muon op group size is: {self.op}") + + def _initialize_communication_groups(self): + """Initialize communication groups for parallel training.""" + self.tp_group = self._get_tp_group_name(self.rank_id, self.tp) + self.op_group, self.op_in_tp_group = self._get_op_group_name(self.rank_id, self.tp, self.op, self.tp_group) + self.tp_groups = tuple(self.tp_group for _ in self._parameters) + + def _initialize_op_groups(self, model): + """Initialize optimizer parallel groups for parameters.""" + self.ops, self.op_groups = model.get_op_groups_info(self._parameters, self.op) + + def _create_communication_group(self, rank_list): + """ + Create a communication group with a hashed name. + + Args: + rank_list: List of ranks in the communication group + + Returns: + str: The created group name + """ + rank_list_str = "-".join([str(i) for i in rank_list]) + hashed = hashlib.md5(rank_list_str.encode()).hexdigest()[:48] + group_name = str(hashed) + create_group(group_name, rank_list) + return group_name + + def _get_op_group_name(self, rank_id, tp, op, tp_group): + """ + Generates a unique group name for optimizer parallel communication group. + + Returns: + tuple: The optimizer group name and optimizer-in-tensor-parallel group name + """ + dp_range = tp + op_range = tp * op + rank_start = rank_id % dp_range + rank_id // op_range * op_range + rank_end = rank_start + op_range + rank_list = list(range(rank_start, rank_end, dp_range)) + logger.info(f"Muon op group list is: {rank_list}") + op_group_name = self._create_communication_group(rank_list) + + if tp == op: + logger.info( + f"op_in_tp group will reuse tp group" \ + f", since tensor_parallel_size({tp}) == optimizer_parallel_size({op})." + ) + op_in_tp_group_name = tp_group + else: + logger.info(f"Muon op_in_tp group list is: {rank_list}") + op_in_tp_group_name = self._get_tp_group_name(rank_id, op) + + return op_group_name, op_in_tp_group_name + + def _get_tp_group_name(self, rank_id, tp): + """ + Generates a unique group name for tensor parallel communication group. + + Returns: + str: The tensor parallel group name + """ + rank_start = rank_id // tp * tp + rank_end = rank_id // tp * tp + tp + rank_list = list(range(rank_start, rank_end)) + logger.info(f"Muon tp group list is: {rank_list}") + tp_group_name = self._create_communication_group(rank_list) + return tp_group_name + + @jit(backend="ms_backend") + def construct(self, gradients): + """Construct method for optimizer. + + Args: + gradients: Gradients for optimization. + + Returns: + Updated gradients after optimization. 
+ """ + gradients = self.flatten_gradients(gradients) + weight_decay = self.get_weight_decay() + lr = self.get_lr() + self.assignadd(self.global_step, self.global_step_increase_tensor) + optim_result = self.hyper_map( + F.partial( + _muon_opt, + self.muon_momentum, + self.matched_adamw_rms, + self.beta1, + self.beta2, + self.global_step, + self.eps, + lr, + ), + weight_decay, + self.rank_ids, + self._parameters, + self.moments1, + self.moments2, + gradients, + self.muon_m, + self.tps, + self.ops, + self.tp_dims, + self.use_muon, + self.use_nesterov, + self.optim_filter, + self.op_groups, + self.tp_groups, + self.param_name_tuple, + self.muon_split_fns, + self.muon_merge_fns, + ) + + updates = self.model.apply_qk_clip_scaling( + self._parameters, + self.param_name_tuple, + self.param_layer, + self.logit_threshold, + self.muon_split_fn, + self.muon_merge_fn, + ) + + # Apply the weight updates + for param_idx, weights in updates: + optim_result = F.depend(optim_result, F.assign(self._parameters[param_idx], weights)) + + return optim_result diff --git a/mindformers/core/optim/muon_utils.py b/mindformers/core/optim/muon_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e23dfbf7faf7ee9c7714bdbab5ec3c88d4a28ede --- /dev/null +++ b/mindformers/core/optim/muon_utils.py @@ -0,0 +1,322 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Muon utils""" + +import math +from fnmatch import fnmatch +from mindspore import ops as P +from mindspore.ops.operations import Morph +from mindspore import nn + + +class BlockSplitReshape(nn.Cell): + """ + Reshape tensor by splitting its last dimension into blocks. + + This operation takes a tensor and splits its last dimension into equal-sized blocks, + adding a new dimension for the block index. + + Args: + block: Block size for splitting the last dimension. + """ + + def __init__( + self, + block + ): + super().__init__() + self.block = block + self.local_reshape = Morph(self.reshape_fn, + self.reshape_infer_shape, + self.reshape_infer_dtype + ) + + def reshape_fn(self, x, shp): + """Reshape function.""" + return P.Reshape()(x, shp) + + def reshape_infer_shape(self, *args): + *prefix, dim = args[0] + t = prefix + [dim // self.block, self.block] + return t + + def reshape_infer_dtype(self, *args): + return args[0] + + def construct(self, tensor, shp): + return self.local_reshape(tensor, shp) + + +class TensorReshapeTo3D(nn.Cell): + """ + Reshape tensor to 3D with specified middle and last dimensions. + + This operation reshapes a tensor to a 3-dimensional tensor where the first dimension + is automatically calculated from the total size, and the last two dimensions are fixed. + + Args: + dim1: The second dimension (middle dimension) of the output 3D tensor. + dim2: The third dimension (last dimension) of the output 3D tensor. 
+ """ + + def __init__( + self, + dim1, + dim2, + ): + super().__init__() + self.dim1 = dim1 + self.dim2 = dim2 + self.local_reshape = Morph(self.reshape_fn, + self.reshape_infer_shape, + self.reshape_infer_dtype + ) + + def reshape_fn(self, x, shp): + """Reshape function.""" + return P.Reshape()(x, shp) + + def reshape_infer_shape(self, *args): + tensor_shape = args[0] + total = math.prod(tensor_shape) + t = [total // (self.dim1 * self.dim2), self.dim1, self.dim2] + return t + + def reshape_infer_dtype(self, *args): + return args[0] + + def construct(self, tensor, shp): + return self.local_reshape(tensor, shp) + + +class PrefixDimensionReshape(nn.Cell): + """ + Reshape tensor with fixed prefix dimensions and calculated last dimension. + + This operation reshapes a tensor by specifying the leading (prefix) dimensions, + while the last dimension is automatically calculated from the total size. + + Args: + *prefix: Variable number of prefix dimensions for the output tensor shape. + """ + + def __init__( + self, + *prefix + ): + self.prefix = list(prefix) + super().__init__() + self.local_reshape = Morph(self.reshape_fn, + self.reshape_infer_shape, + self.reshape_infer_dtype + ) + + def reshape_fn(self, x, shp): + """Reshape function.""" + return P.Reshape()(x, shp) + + def reshape_infer_shape(self, *args): + tensor_shape = args[0] + total = math.prod(tensor_shape) + prefix_total = math.prod(self.prefix) + t = self.prefix + [total // prefix_total] + return t + + def reshape_infer_dtype(self, *args): + return args[0] + + def construct(self, tensor, shp): + return self.local_reshape(tensor, shp) + + +class TensorReshapeTo2D(nn.Cell): + """ + Reshape tensor to 2D with specified last dimension. + + This operation flattens a tensor to a 2-dimensional tensor where the last dimension + is fixed and the first dimension is automatically calculated from the total size. + + Args: + dim: The second dimension (last dimension) of the output 2D tensor. + """ + + def __init__( + self, + dim + ): + self.dim = dim + super().__init__() + self.local_reshape = Morph(self.reshape_fn, + self.reshape_infer_shape, + self.reshape_infer_dtype + ) + + def reshape_fn(self, x, shp): + """Reshape function.""" + return P.Reshape()(x, shp) + + def reshape_infer_shape(self, *args): + tensor_shape = args[0] + total = math.prod(tensor_shape) + t = [total // self.dim, self.dim] + return t + + def reshape_infer_dtype(self, *args): + return args[0] + + def construct(self, tensor, shp): + return self.local_reshape(tensor, shp) + + +def muon_split(tensor, part_a: int, part_b: int, num_blocks: int): + """ + Split a 2D tensor into two periodic parts along its first dimension. + The split pattern repeats every (part_a + part_b) elements. + + Args: + tensor: Input tensor of shape (M, N). + part_a: Number of elements in the first part of each block. + part_b: Number of elements in the second part of each block. + num_blocks: Total number of (part_a + part_b) blocks. + + Returns: + A tuple of two tensors (first_part, second_part), + where: + - first_part contains all part_a segments of each block. + - second_part contains all part_b segments of each block. 
+ """ + tensor = tensor.T + *prefix, _ = tensor.shape + block = part_a + part_b + t = BlockSplitReshape(block)(tensor, (*prefix, -1, block)) + + first_part = PrefixDimensionReshape(*prefix)(t[..., :part_a], (*prefix, -1)).T + second_part = PrefixDimensionReshape(*prefix)(t[..., part_a:], (*prefix, -1)).T + return first_part, second_part + + +def muon_merge(tensor_a, tensor_b, part_a: int, part_b: int, num_blocks: int): + """ + Merge two tensors back into the original periodic layout + that was split by muon_split(). + + Args: + tensor_a: Tensor containing the first part of each block. + tensor_b: Tensor containing the second part of each block. + part_a: Number of elements in the first part of each block. + part_b: Number of elements in the second part of each block. + num_blocks: Total number of (part_a + part_b) blocks. + + Returns: + A single tensor of the same shape as before muon_split(). + """ + tensor_a = tensor_a.T + tensor_b = tensor_b.T + *prefix, _ = tensor_a.shape + + a = BlockSplitReshape(part_a)(tensor_a, (*prefix, -1, part_a)) + b = BlockSplitReshape(part_b)(tensor_b, (*prefix, -1, part_b)) + t = P.Concat(axis=-1)([a, b]) + out = PrefixDimensionReshape(*prefix)(t, (*prefix, -1)).T + return out + + +def _eval_tuple(spec, name, tensor): + return spec(name, tensor) if callable(spec) else spec + + +def make_muon_fns(schema): + """ + Generate two generic functions: + - split_one(param_name, tensor) -> List[tensor] + - merge_one(param_name, parts_list) -> tensor + + Dimensions in schema should be either numbers or callback functions: + - periodic: rule["parts"] = (a, b, num_blocks) or lambda(name, tensor)->(a,b,blocks) + - reshape_* : rule["reshape"] = (x, y, z) or lambda(name, tensor)->(x,y,z) + """ + + def split_fn(param_name, tensor): + """ + Input a 2D tensor, split it according to schema rules, and return several segments (List[tensor]). + """ + + for rule in schema: + if not any(fnmatch(param_name, pat) for pat in rule["patterns"]): + continue + + kind = rule["kind"] + + if kind == "periodic": + part_a, part_b, num_blocks = _eval_tuple(rule["parts"], param_name, tensor) + first_part, second_part = muon_split(tensor, part_a, part_b, num_blocks) + return [first_part, second_part] + + if kind == "reshape_concat": + # e.g. experts.weight1: first reshape to [E, H, 2I], then split into two halves along the last dimension + _, hidden_size, total_intermediate = _eval_tuple(rule["reshape"], param_name, tensor) + half_intermediate = total_intermediate // 2 + t3 = TensorReshapeTo3D(hidden_size, total_intermediate)(tensor, (-1, hidden_size, total_intermediate)) + return [t3[..., :half_intermediate], t3[..., half_intermediate:]] + + if kind == "reshape_only": + # e.g. experts.weight2: just reshape to [E, I, H], no split + _, intermediate_size, hidden_size = _eval_tuple(rule["reshape"], param_name, tensor) + return [TensorReshapeTo3D(intermediate_size, hidden_size)(tensor, (-1, intermediate_size, hidden_size))] + + if kind == "alt_pair_periodic": + # Alternating rows 1,1 (blocks = M//2) + num_blocks = tensor.shape[0] // 2 + a, b = muon_split(tensor, 1, 1, num_blocks) + return [a, b] + + # Default: no processing, return as whole block + return [tensor] + + def merge_fn(param_name, parts_list): + """ + Merge the output of split_one (List[tensor]) back to 2D according to the same rules. 
+ """ + concat = P.Concat(axis=-1) + + for rule in schema: + if not any(fnmatch(param_name, pat) for pat in rule["patterns"]): + continue + + kind = rule["kind"] + + if kind == "periodic": + part_a, part_b, num_blocks = _eval_tuple(rule["parts"], param_name, parts_list[0]) + # Convention: periodic always has two segments + return muon_merge(parts_list[0], parts_list[1], part_a, part_b, num_blocks) + + if kind == "reshape_concat": + _, hidden_size, total_intermediate = _eval_tuple(rule["reshape"], param_name, parts_list[0]) + cat = concat([parts_list[0], parts_list[1]]) # [..., I] + [..., I] -> [..., 2I] + return TensorReshapeTo2D(total_intermediate)(cat, (-1, total_intermediate)) + + if kind == "reshape_only": + _, _, hidden_size = _eval_tuple(rule["reshape"], param_name, parts_list[0]) + # Only one segment, directly restore to 2D + return TensorReshapeTo2D(hidden_size)(parts_list[0], (-1, hidden_size)) + + if kind == "alt_pair_periodic": + num_blocks = parts_list[0].shape[0] # 1 row per block + return muon_merge(parts_list[0], parts_list[1], 1, 1, num_blocks) + + # Default: directly take the first segment + return parts_list[0] + + return split_fn, merge_fn diff --git a/mindformers/models/deepseek3/modeling_deepseek_v3_train.py b/mindformers/models/deepseek3/modeling_deepseek_v3_train.py index 7e48bbae01e8dfa9466a13fb254db4cf12483504..ae7197a4b808100c72a6d28d42637c39b4df101f 100644 --- a/mindformers/models/deepseek3/modeling_deepseek_v3_train.py +++ b/mindformers/models/deepseek3/modeling_deepseek_v3_train.py @@ -90,3 +90,10 @@ class TrainingDeepseekV3ForCausalLM(TrainModelMixin, DeepseekV3PreTrainedModel): update topk bias and reset expert_load of router in MoELayers. """ return self.model.update_topk_bias(gradient_accumulation_steps) + + def reset_max_attention_logit(self,): + """ + Will be called by mindformer.core.callback.TopkBiasBalanceCallback to + update topk bias and reset expert_load of router in MoELayers. 
+ """ + return self.model.reset_max_attention_logit() diff --git a/mindformers/parallel_core/mf_model_config.py b/mindformers/parallel_core/mf_model_config.py index 3457186019f9a57736a62c97a5011b48290f4f44..069b177d67dd74e3cc38b431475b1d0372e31bb1 100644 --- a/mindformers/parallel_core/mf_model_config.py +++ b/mindformers/parallel_core/mf_model_config.py @@ -268,6 +268,9 @@ class MFModelConfig: mask_func_type: str = "attn_mask_fill" """Mask function type to use for the attention layer.""" + monitor_max_attention_logit: bool = False + """Whether to monitor the maximum attention logit value during training.""" + #################################################### # MoE Configuration Items For MindSpore Transformers #################################################### diff --git a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py index 55fcd02e64177c4b97fcf74bc9a66f9a5b09fa7a..6ee0be8dd03cc95100c03fa1d0a546e297c66253 100644 --- a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py +++ b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py @@ -15,10 +15,12 @@ """mindformers GPT model""" __all__ = ['GPTModel'] +import hashlib from typing import Literal, Optional, Union import numpy as np import mindspore as ms +from mindspore.communication import create_group, get_group_size, get_rank from mindspore.ops import functional as F from mindspore.ops import operations as P from mindspore.ops import auto_generate as aclnn_ops @@ -26,6 +28,7 @@ from mindspore.ops.operations import Morph from mindspore import Tensor, dtype, nn from mindspore.context import ParallelMode from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation +from mindspore import ops from mindformers.parallel_core.training_graph.loss_func import CrossEntropyLoss from mindformers.parallel_core.training_graph.transformer.multi_token_prediction import MultiTokenPredictionBlock @@ -51,6 +54,61 @@ from mindformers.parallel_core.inference.parallel_state import initialize_model_ from mindformers.tools.logger import logger from mindformers.models.utils import get_current_rank_stage, get_model_parameters from mindformers.version_control import get_lazy_inline as lazy_inline +from mindformers.core.optim.muon_utils import make_muon_fns +from mindformers.checkpoint.sharded_tensor import ShardedTensor + + +def compute_repeat_num_and_model_parallel_size(sharded_info: ShardedTensor, world_size: int, pp: int, op: int): + """Compute real op size.""" + axis_fragmentations = sharded_info.axis_fragmentations + flag = False + weight_sharded_size = 1 + for axis in axis_fragmentations: + if axis == 1: + continue + if flag: + raise ValueError("Only one axis can be fragmented in Muon optimizer.") + flag = True + weight_sharded_size *= axis + repeat_num = world_size // pp // weight_sharded_size + real_op_size = min(op, repeat_num) + if sharded_info.local_shape[0] % real_op_size != 0: + real_op_size = 1 + return real_op_size, weight_sharded_size + + +def create_communication_group(rank_list): + """ + Create a communication group with a hashed name. 
+ + Args: + rank_list: List of ranks in the communication group + + Returns: + str: The created group name + """ + rank_list_str = "-".join([str(i) for i in rank_list]) + hashed = hashlib.md5(rank_list_str.encode()).hexdigest()[:48] + group_name = str(hashed) + create_group(group_name, rank_list) + return group_name + + +OP_GROUP_NAME = {} + + +def get_op_group_name(rank_id: int, real_op_size: int, model_parallel_size: int): + """Get op group name.""" + if (rank_id, real_op_size, model_parallel_size) in OP_GROUP_NAME: + return OP_GROUP_NAME[(rank_id, real_op_size, model_parallel_size)] + dp_range = model_parallel_size + op_range = model_parallel_size * real_op_size + rank_start = rank_id % dp_range + rank_id // op_range * op_range + rank_end = rank_start + op_range + rank_list = list(range(rank_start, rank_end, dp_range)) + op_group_name = create_communication_group(rank_list) + OP_GROUP_NAME[(rank_id, real_op_size, model_parallel_size)] = (op_group_name, rank_list) + return op_group_name, rank_list def func_infer_dtype(*args): @@ -634,6 +692,47 @@ class GPTModel(nn.Cell): if hasattr(self.decoder.layers[i].mlp, "router"): self.assign_aux(self.decoder.layers[i].mlp.router.fi_accu, self.zeros_tensor) + def _iter_core_attentions(self): + """Iterate over all core_attention modules with their param names. + + Yields: + Tuple[str, module]: A tuple of (param_name, core_attention_module). + """ + num_layers = self.config.num_layers + mtp_num_layers = self.config.mtp_num_layers + + for i in range(num_layers): + core_attn = self.decoder.layers[i].self_attention.core_attention + yield f"decoder.layers.{i}.self_attention.core_attention", core_attn + + for i in range(mtp_num_layers): + core_attn = self.mtp.layers[i].transformer_layer.self_attention.core_attention + yield f"mtp.layers.{i}.transformer_layer.self_attention.core_attention", core_attn + + def get_max_attention_logit(self): + """Get max attention logit values for all layers. + + Returns: + dict: A dictionary mapping parameter names to their max logit values. + Only includes layers with valid (sum > 0) max_logits_val. 
+ """ + max_logits = {} + for param_name, core_attn in self._iter_core_attentions(): + if not hasattr(core_attn, "max_logits_val"): + continue + param = core_attn.max_logits_val.value() + if param.sum() <= 0: + continue + max_logits[f"{param_name}.max_logits_val"] = param + return max_logits + + def reset_max_attention_logit(self): + """Reset max attention logit to zeros for all layers.""" + for _, core_attn in self._iter_core_attentions(): + if hasattr(core_attn, "max_logits_val"): + param = core_attn.max_logits_val + F.assign(param, F.zeros_like(param)) + def shard(self, config: TransformerConfig): """parallel shard.""" dp = config.data_parallel_size @@ -683,6 +782,14 @@ class GPTModel(nn.Cell): def sharding_propagation(self, config: TransformerConfig): pass + def sharded_state_dict(self): + """Get all sharded state dict.""" + sharded_state_dict = {} + for _, sub_cell in self.cells_and_names(): + if sub_cell != self and hasattr(sub_cell, "sharded_state_dict"): + sharded_state_dict.update(sub_cell.sharded_state_dict()) + return sharded_state_dict + def get_model_parameters(self): """Get current rank trainable parameters in gpt model .""" params = set() @@ -690,10 +797,309 @@ class GPTModel(nn.Cell): if ms.get_auto_parallel_context('pipeline_stages') > 1: if current_pipeline_stage == self.output_layer.pipeline_stage: params.update(get_model_parameters(self.output_layer)) - if current_pipeline_stage == self.mtp.pipeline_stage: - params.update(get_model_parameters(self.mtp)) + if hasattr(self, "mtp"): + if current_pipeline_stage == self.mtp.pipeline_stage: + params.update(get_model_parameters(self.mtp)) params.update(self.decoder.get_model_parameters()) params.update(get_model_parameters(self.embedding)) else: params.update(get_model_parameters(self)) return params + + def make_model_muon_fns(self,): + """Read values from TransformersConfig and generate schema.""" + + num_moe_experts = self.config.num_moe_experts + hidden_size = self.config.hidden_size + moe_ffn_hidden_size = self.config.moe_ffn_hidden_size + qk_head_dim = self.config.qk_head_dim + qk_pos_emb_head_dim = self.config.qk_pos_emb_head_dim + num_attention_heads = self.config.num_attention_heads + kv_lora_rank = self.config.kv_lora_rank + value_head_dim = self.config.v_head_dim + + schema = [ + # experts.weight1: reshape → split into two [num_moe_experts, hidden_size, moe_ffn_hidden_size] + { + "patterns": ["*mlp.experts.weight1*"], + "kind": "reshape_concat", + "reshape": (num_moe_experts, hidden_size, 2 * moe_ffn_hidden_size), + }, + # experts.weight2: reshape → [num_moe_experts, moe_ffn_hidden_size, hidden_size] + { + "patterns": ["*mlp.experts.weight2*"], + "kind": "reshape_only", + "reshape": (num_moe_experts, moe_ffn_hidden_size, hidden_size), + }, + # q_proj / q_up_proj: periodic split across heads + { + "patterns": [ + "*self_attention.linear_q_proj.weight*", + "*self_attention.linear_q_up_proj.weight*", + ], + "kind": "periodic", + "parts": (qk_head_dim, qk_pos_emb_head_dim, num_attention_heads), + }, + # kv_down_proj: one block + { + "patterns": ["*self_attention.linear_kv_down_proj.weight*"], + "kind": "periodic", + "parts": (kv_lora_rank, qk_pos_emb_head_dim, 1), + }, + # kv_up_proj: periodic split across heads + { + "patterns": ["*self_attention.linear_kv_up_proj.weight*"], + "kind": "periodic", + "parts": (qk_head_dim, value_head_dim, num_attention_heads), + }, + # fc1 and shared_fc1: alternating 1,1 split along rows + { + "patterns": [ + "*mlp.shared_experts.linear_fc1.weight*", + "*mlp.linear_fc1.weight*", + ], + 
"kind": "alt_pair_periodic", + }, + ] + + return make_muon_fns(schema) + + def get_muon_filter(self): + """Return a filter function to determine if a parameter should use Muon optimization. + + Returns: + A function that takes a parameter and returns True if it should use Muon. + """ + def muon_filter(param): + return ( + (len(param.shape) == 2 or len(param.shape) == 3) + and "word_embeddings" not in param.name + and "output_layer" not in param.name + ) + return muon_filter + + def get_tp_dims(self, params): + """Return tensor parallel dimensions for each parameter. + + Args: + params: List of parameters from the optimizer. + + Returns: + Tuple of TP dimensions for each parameter. + """ + no_tp_list = [ + "linear_q_down_proj", + "linear_kv_down_proj", + "shared_experts", + "mlp.router", + "hnorm.weight", "enorm.weight", "eh_proj.weight", + ] + + tp_dim_1_list = [ + "self_attention.linear_proj.weight", + "mlp.linear_fc2.weight" + ] + + def name_filter(param_name, full_name_list): + for full_name in full_name_list: + if full_name in param_name: + return True + return False + + tp_dims = [] + for param in params: + if name_filter(param.name, tp_dim_1_list): + tp_dims.append(1) + elif name_filter(param.name, no_tp_list): + tp_dims.append(-1) + else: + tp_dims.append(0) + return tuple(tp_dims) + + def get_op_groups_info(self, params, op): + """Return optimizer parallel group information for each parameter. + + Args: + params: List of parameters from the optimizer. + op: Optimizer parallel size. + tp_group: Tensor parallel group name. + + Returns: + Tuple of (ops, op_groups) where: + - ops: tuple of op values for each parameter + - op_groups: tuple of group names for each parameter + """ + no_op_list = [ + "self_attention.linear_q_proj.weight", + "self_attention.linear_q_up_proj.weight", + "self_attention.linear_q_down_proj.weight", + "self_attention.linear_kv_up_proj.weight", + "self_attention.linear_kv_down_proj.weight", + "eh_proj", + "max_logits_val" + ] + + sharded_state_dict = self.sharded_state_dict() + world_size = get_group_size() + ep = self.config.expert_model_parallel_size + pp = self.config.pipeline_model_parallel_size + + def name_filter(param_name, full_name_list): + for full_name in full_name_list: + if full_name in param_name: + return True + return False + + op_list = [] + op_groups = [] + + for param in params: + if name_filter(param.name, no_op_list): + op_list.append(1) + + op_groups.append("") + if param.parallel_optimizer: + param.parallel_optimizer = False + logger.warning( + f"Parameter {param.name}: parallel_optimizer was set to False due to the use of Muon optimizer." 
+ ) + continue + + # compute real op size + sharded_info = sharded_state_dict.get(param.name) + real_op_size, weight_sharded_size = compute_repeat_num_and_model_parallel_size(sharded_info, world_size, pp, + op) + if real_op_size == 1: + op_list.append(1) + op_groups.append("") + logger.info(f"Parameter {param.name} : No op group.") + continue + + op_list.append(real_op_size) + op_group_name, rank_list = get_op_group_name(get_rank(), real_op_size, weight_sharded_size) + logger.info(f"Parameter {param.name} : Muon real_op_size={real_op_size} group list is: {rank_list}") + op_groups.append(op_group_name) + + # check if op is valid for expert + for param, real_op_size in zip(params, op_list): + if "mlp.experts.weight1" not in param.name: + continue + # Validate MoE expert counts divisibility constraint: + # num_moe_experts must be divisible by (optimizer_weight_shard_size * expert_model_parallel_size) + num_moe_experts = self.config.num_moe_experts + if bool(num_moe_experts and num_moe_experts > 0): + if num_moe_experts % (real_op_size * ep) != 0: + error_msg = (f"Invalid configuration: 'num_moe_experts' ({num_moe_experts}) must be divisible by " + f"'real_op_size * expert_model_parallel_size' ({real_op_size} * " + f"{ep} = {real_op_size * ep}).\n" + f"Hint:\n" + f" Although you set `optimizer_weight_shard_size={op}`, the maximum optimizer shard size " + f"for `{param.name}` is `{real_op_size}`. Try reducing 'optimizer_weight_shard_size'.") + logger.error(error_msg) + raise ValueError( + error_msg + ) + # All expert weights share the same real_op_size, so we only need to check once + break + + return tuple(op_list), tuple(op_groups) + + def get_param_layer_indices(self, params): + """Return layer indices for each parameter (used for QK-clip). + + Args: + params: List of parameters from the optimizer. + + Returns: + Tuple of layer indices for each parameter, where: + - layer_idx >= 0 stands for the layer_idx-th decoder layer + - layer_idx < 0 stands for the -(layer_idx+1)-th MTP layer + """ + param_layer = [] + for param in params: + name = param.name + try: + layer_idx = int(name.split(".")[2]) + except (ValueError, IndexError): + layer_idx = 0 + if name.startswith('mtp'): + layer_idx = -layer_idx - 1 + param_layer.append(layer_idx) + return tuple(param_layer) + + def apply_qk_clip_scaling(self, params, param_names, param_layer, logit_threshold, + muon_split_fn, muon_merge_fn): + """Apply QK-clip scaling to attention weight parameters. + + Args: + params: List of all parameters. + param_names: Tuple of parameter names. + param_layer: Tuple of layer indices for each parameter. + logit_threshold: Threshold for logit clipping. + muon_split_fn: Function to split parameters. + muon_merge_fn: Function to merge parameters. + + Returns: + List of (param_idx, scaled_weights) tuples to be updated. 
+ """ + if not self.config.multi_latent_attention: + return [] + ones = ms.Tensor([1.0], dtype.float32) + qk_head_dim = self.config.qk_head_dim + qk_pos_emb_head_dim = self.config.qk_pos_emb_head_dim + + def get_scale_broadcast(scales, head_dim): + scale_broadcast = ops.tile(ops.expand_dims(scales, 1), (1, head_dim)).reshape(-1) + scale_broadcast = ops.expand_dims(scale_broadcast, 1) + return scale_broadcast + + # Build param name to index mapping + param_idx_in_opt = {name: idx for idx, name in enumerate(param_names)} + + updates = [] + for idx, param_name in enumerate(param_names): + if ( + "self_attention.linear_q_proj.weight" not in param_name + and "self_attention.linear_q_up_proj.weight" not in param_name + and "self_attention.linear_kv_up_proj.weight" not in param_name + ): + continue + + layer_idx = param_layer[idx] + param = params[idx] + + # Compute per-head scale factor + logit_threshold_f32 = ops.cast(logit_threshold, dtype=dtype.float32) + if layer_idx >= 0: + max_logits_name = (f"decoder.layers.{layer_idx}.self_attention." + "core_attention.max_logits_val") + else: + max_logits_name = (f"mtp.layers.{-(layer_idx+1)}.transformer_layer." + "self_attention.core_attention.max_logits_val") + + if max_logits_name not in param_idx_in_opt: + continue + + logits_row = params[param_idx_in_opt[max_logits_name]].reshape(-1) + mask = ops.greater_equal(logits_row, logit_threshold_f32) + safe_den = ops.where(mask, logits_row, ones) + scales = ops.where(mask, logit_threshold_f32 / safe_den, ones) + + weights = None + if ( + "self_attention.linear_q_proj.weight" in param_name + or "self_attention.linear_q_up_proj.weight" in param_name + ): + l2q_nope_proj, l2q_pe_proj = muon_split_fn(param_name, param) + l2q_nope_proj *= get_scale_broadcast(ops.sqrt(scales), qk_head_dim) + l2q_pe_proj *= get_scale_broadcast(scales, qk_pos_emb_head_dim) + weights = muon_merge_fn(param_name, [l2q_nope_proj, l2q_pe_proj]) + elif "self_attention.linear_kv_up_proj.weight" in param_name: + lkv2kv_k_nope, lkv2kv_v = muon_split_fn(param_name, param) + lkv2kv_k_nope *= get_scale_broadcast(ops.sqrt(scales), qk_head_dim) + weights = muon_merge_fn(param_name, [lkv2kv_k_nope, lkv2kv_v]) + + if weights is not None: + updates.append((idx, weights)) + + return updates diff --git a/mindformers/parallel_core/training_graph/tensor_parallel/layers.py b/mindformers/parallel_core/training_graph/tensor_parallel/layers.py index 9979dabc3d40b36546225f6204766d4df2ee8029..5bba189dbb947d5842ef6afc92f161ab19daec18 100644 --- a/mindformers/parallel_core/training_graph/tensor_parallel/layers.py +++ b/mindformers/parallel_core/training_graph/tensor_parallel/layers.py @@ -36,6 +36,7 @@ from mindspore.ops.operations import Morph from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation from mindspore import mint +from mindformers.checkpoint.sharded_tensor import ShardedTensor from mindformers.parallel_core.transformer_config import TransformerConfig from mindformers.parallel_core.utils.init_method import init_method_zero from mindformers.parallel_core.inference.utils import divide @@ -207,6 +208,28 @@ class VocabParallelEmbedding(nn.Cell): def sharding_propagation(self, config: TransformerConfig): pass + def sharded_state_dict(self): + """Provide the sharded state dict. 
Sharded info is not complete, Only for Muon optimizer now.""" + sharded_state_dict = {} + weight_shape = (self.num_embeddings, self.embedding_dim) + + if self.enable_embedding_tp: + axis_fragmentations = (self.tp, 1) + local_shape = (self.num_embeddings // self.tp, self.embedding_dim) + else: + axis_fragmentations = (1, 1) + local_shape = (self.num_embeddings, self.embedding_dim) + + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=local_shape, + global_shape=weight_shape, + global_offset=(0, 0), + axis_fragmentations=axis_fragmentations) + return sharded_state_dict + class ColumnParallelLinear(nn.Cell): """Linear layer with column parallelism. @@ -427,6 +450,33 @@ class ColumnParallelLinear(nn.Cell): matmul_in_strategy = ((dp * cp, 1), weight_strategy) self.matmul.shard(in_strategy=matmul_in_strategy) + def sharded_state_dict(self): + """Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now.""" + tp = self.config.tensor_model_parallel_size + + sharded_state_dict = {} + weight_shape = (self.output_size, self.input_size) + local_shape = (self.output_size // tp, self.input_size) + if not self.skip_weight_param_allocation: + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=local_shape, + global_shape=weight_shape, + global_offset=(0, 0), + axis_fragmentations=(tp, 1)) + if self.has_bias: + sharded_state_dict[self.bias.name] = ShardedTensor( + key=self.bias.name, + org_key=self.bias.name, + dtype=self.bias.dtype, + local_shape=(self.output_size,), + global_shape=(self.output_size,), + global_offset=(0,), + axis_fragmentations=(1,)) + return sharded_state_dict + class RowParallelLinear(nn.Cell): """Linear layer with row parallelism. @@ -663,6 +713,33 @@ class RowParallelLinear(nn.Cell): matmul_in_strategy = ((dp * cp, tp), weight_strategy) self.matmul.shard(in_strategy=matmul_in_strategy) + def sharded_state_dict(self): + """Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now.""" + tp = self.config.tensor_model_parallel_size + sharded_state_dict = {} + weight_shape = (self.output_size, self.input_size) + local_shape = (self.output_size, self.input_size // tp) + + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=local_shape, + global_shape=weight_shape, + global_offset=(0, 0), + axis_fragmentations=(1, tp)) + + if self.has_bias: + sharded_state_dict[self.bias.name] = ShardedTensor( + key=self.bias.name, + org_key=self.bias.name, + dtype=self.bias.dtype, + local_shape=(self.output_size,), + global_shape=(self.output_size,), + global_offset=(0,), + axis_fragmentations=(1,)) + return sharded_state_dict + class LinearNoTP(ColumnParallelLinear): """Linear layer without tensor parallelism. @@ -712,6 +789,32 @@ class LinearNoTP(ColumnParallelLinear): ) ) + def sharded_state_dict(self): + """Provide the sharded state dict. 
Sharded info is not complete, Only for Muon optimizer now.""" + sharded_state_dict = {} + weight_shape = (self.output_size, self.input_size) + local_shape = (self.output_size, self.input_size) + + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=local_shape, + global_shape=weight_shape, + global_offset=(0, 0), + axis_fragmentations=(1, 1)) + + if self.has_bias: + sharded_state_dict[self.bias.name] = ShardedTensor( + key=self.bias.name, + org_key=self.bias.name, + dtype=self.bias.dtype, + local_shape=(self.output_size,), + global_shape=(self.output_size,), + global_offset=(0,), + axis_fragmentations=(1,)) + return sharded_state_dict + class SequenceParallelLinear(ColumnParallelLinear): """Linear layer without tensor parallelism. @@ -761,3 +864,29 @@ class SequenceParallelLinear(ColumnParallelLinear): layout(("cp", "tp"), "dp", "None"), # output [S, B, H] ) ) + + def sharded_state_dict(self): + """Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now.""" + sharded_state_dict = {} + weight_shape = (self.output_size, self.input_size) + local_shape = (self.output_size, self.input_size) + + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=local_shape, + global_offset=(0, 0), + global_shape=weight_shape, + axis_fragmentations=(1, 1)) + + if self.has_bias: + sharded_state_dict[self.bias.name] = ShardedTensor( + key=self.bias.name, + org_key=self.bias.name, + dtype=self.bias.dtype, + local_shape=(self.output_size,), + global_offset=(0,), + global_shape=(self.output_size,), + axis_fragmentations=(1,)) + return sharded_state_dict diff --git a/mindformers/parallel_core/training_graph/transformer/flash_attention.py b/mindformers/parallel_core/training_graph/transformer/flash_attention.py index ec55e0ff2076850db68d3c3c25b716949c6a038f..b3b588d53522d48982f0c96689cc76aaa988d8c4 100644 --- a/mindformers/parallel_core/training_graph/transformer/flash_attention.py +++ b/mindformers/parallel_core/training_graph/transformer/flash_attention.py @@ -16,10 +16,11 @@ __all__ = ['FlashAttention'] import math +import numpy as np import mindspore.common.dtype as mstype import mindspore as ms -from mindspore import ops, ParallelMode +from mindspore import ops, ParallelMode, Parameter from mindspore.common.tensor import Tensor from mindspore.nn.cell import Cell from mindspore.ops import auto_generate as aclnn_ops @@ -90,7 +91,7 @@ class FlashAttention(Cell): softmax_scale: float = None, cp_comm_type: str = None, ): - super(FlashAttention, self).__init__() + super().__init__() # FA (Flash Attention) is an optimized version of DotProductAttention in Megatron v0.12.0, # with nearly identical computational precision. 
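# --- Illustrative sketch, not part of the patch above: how the ShardedTensor metadata
# returned by the new sharded_state_dict() methods in layers.py is expected to fit
# together. axis_fragmentations[i] gives the number of fragments along axis i, so the
# per-rank local shape is the global shape divided element-wise by the fragmentation.
# The helper below is hypothetical and only mirrors that convention.
def _local_shape_from_fragmentation(global_shape, axis_fragmentations):
    """Derive the per-rank local shape from a global shape and per-axis fragmentation."""
    assert all(g % f == 0 for g, f in zip(global_shape, axis_fragmentations)), \
        "every sharded axis must divide evenly across its fragments"
    return tuple(g // f for g, f in zip(global_shape, axis_fragmentations))

# ColumnParallelLinear shards axis 0 by tp, e.g. (output_size, input_size) = (1024, 512), tp = 4:
assert _local_shape_from_fragmentation((1024, 512), (4, 1)) == (256, 512)
# RowParallelLinear shards axis 1 by tp instead:
assert _local_shape_from_fragmentation((1024, 512), (1, 4)) == (1024, 128)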
@@ -159,6 +160,18 @@ class FlashAttention(Cell): self.reshape = aclnn_ops.Reshape() self.fa_out_transpose = aclnn_ops.Transpose() + self.monitor_max_attention_logit = self.config.monitor_max_attention_logit + + if self.monitor_max_attention_logit: + self.max_logits_val = Parameter( + Tensor(np.zeros((self.head_num)), dtype=mstype.float32), + parallel_optimizer=False, requires_grad=False + ) + self.reduce_max = aclnn_ops.ReduceMax() + self.reduce_max.add_prim_attr("self_define_shard", True) + self.assign_add = ops.AssignAdd() + self.assign_add.add_prim_attr("self_define_shard", True) + if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation(): self.sharding_propagation(config) elif _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL,): @@ -249,7 +262,7 @@ class FlashAttention(Cell): attention_mask = cast(attention_mask, ms.uint8) if self.input_layout == "TND": - _, _, _, output = self.flash_attention(query, + softmax_val, _, _, output = self.flash_attention(query, key, value, alibi_mask, @@ -259,6 +272,9 @@ class FlashAttention(Cell): prefix, actual_seq_qlen, actual_seq_kvlen) + if self.monitor_max_attention_logit: + max_logits = self.reduce_max(softmax_val, (0, 2)) + output = F.depend(output, self.assign_add(self.max_logits_val, max_logits)) return output q_seq_len, bsz = query.shape[:2] @@ -283,7 +299,7 @@ class FlashAttention(Cell): drop_mask_bits = None if self.use_alibi_mask: alibi_mask = self.alibi_rescale_mul(alibi_mask, F.cast(self.alibi_rescale_factor, alibi_mask.dtype)) - _, _, _, output = self.flash_attention(query, + softmax_val, _, _, output = self.flash_attention(query, key, value, alibi_mask, @@ -291,6 +307,10 @@ class FlashAttention(Cell): padding_mask, attention_mask, prefix) + if self.monitor_max_attention_logit: + max_logits = self.reduce_max(softmax_val, (0, 2, 3)) + output = F.depend(output, self.assign_add(self.max_logits_val, max_logits)) + if self.input_layout == "BNSD": output = self._merge_heads(output) elif self.input_layout == "BSH": @@ -331,4 +351,19 @@ class FlashAttention(Cell): if self.use_alibi_mask: self.alibi_rescale_mul.shard(((dp, tp, cp, 1), (1,))) + if self.monitor_max_attention_logit: + self.assign_add.shard( + in_strategy=(layout("tp"), layout("tp")), + out_strategy=(layout("tp"),) + ) + if self.input_layout == "BNSD": + self.reduce_max.shard( + in_strategy=(layout("None", "tp", "None", "None"),), + out_strategy=(layout("tp"),) + ) + elif self.input_layout == "TND": + self.reduce_max.shard( + in_strategy=(layout("None", "tp", "None"),), + out_strategy=(layout("tp"),) + ) return self diff --git a/mindformers/parallel_core/training_graph/transformer/moe/ffn.py b/mindformers/parallel_core/training_graph/transformer/moe/ffn.py index e4a54710f55ac51eea06d666661624da7f8cf075..83b72cea0bb64c3dcdcbec971862a3650af24a74 100644 --- a/mindformers/parallel_core/training_graph/transformer/moe/ffn.py +++ b/mindformers/parallel_core/training_graph/transformer/moe/ffn.py @@ -20,6 +20,7 @@ from mindspore.ops.auto_generate import Shape, Cast, GroupedMatmul, Reshape, Swi from mindspore.ops.operations import Morph from mindspore.parallel._utils import _get_parallel_mode +from mindformers.checkpoint.sharded_tensor import ShardedTensor from mindformers.parallel_core.training_graph.device_matrix import layout_moe as layout from mindformers.parallel_core.training_graph.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher, MoEAlltoAllDeredundencyTokenDispatcher, MoEAlltoAllZeroRedundancyTokenDispatcher from 
mindformers.parallel_core.transformer_config import TransformerConfig @@ -200,3 +201,27 @@ class FFNGroupedGEMM(nn.Cell): layout(dp, sp, mp0), # output [B, S, h] ) ) + + def sharded_state_dict(self): + """Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now.""" + ep = self.config.expert_model_parallel_size + sharded_state_dict = {} + sharded_state_dict[self.weight1.name] = ShardedTensor( + key=self.weight1.name, + org_key=self.weight1.name, + dtype=self.weight1.dtype, + local_shape=(self.num_local_experts // ep * self.hidden_size, self.moe_ffn_hidden_size * 2), + global_shape=(self.num_local_experts * self.hidden_size, self.moe_ffn_hidden_size * 2), + global_offset=(0, 0), + axis_fragmentations=(ep, 1), + ) + sharded_state_dict[self.weight2.name] = ShardedTensor( + key=self.weight2.name, + org_key=self.weight2.name, + dtype=self.weight2.dtype, + local_shape=(self.num_local_experts // ep * self.moe_ffn_hidden_size, self.hidden_size), + global_shape=(self.num_local_experts * self.moe_ffn_hidden_size, self.hidden_size), + global_offset=(0, 0), + axis_fragmentations=(ep, 1), + ) + return sharded_state_dict diff --git a/mindformers/parallel_core/training_graph/transformer/moe/router.py b/mindformers/parallel_core/training_graph/transformer/moe/router.py index 87f682e16d51515ed43442f9a31f734bf6e9bf73..c9bbf817ad35453a120ee789f4dce1b18c62672f 100644 --- a/mindformers/parallel_core/training_graph/transformer/moe/router.py +++ b/mindformers/parallel_core/training_graph/transformer/moe/router.py @@ -26,6 +26,7 @@ from mindspore.ops.auto_generate import AddExt, AssignAdd, Cast, Div, Mul, Resha from mindspore.ops.operations import Shape, ReduceSum, ReduceMean from mindspore.parallel._utils import _get_parallel_mode +from mindformers.checkpoint.sharded_tensor import ShardedTensor from mindformers.parallel_core.training_graph.device_matrix import layout_moe as layout from mindformers.parallel_core.transformer_config import TransformerConfig from mindformers.tools.utils import get_real_group_size, get_real_rank @@ -117,6 +118,20 @@ class Router(ABC, nn.Cell): router_logits = self.linear(inputs.astype(self.moe_router_dtype), weight) return router_logits + def sharded_state_dict(self): + """Provide the sharded state dict. 
Sharded info is not complete, Only for Muon optimizer now.""" + sharded_state_dict = {} + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=(self.expert_dim, self.hidden_size), + global_shape=(self.expert_dim, self.hidden_size), + global_offset=(0, 0), + axis_fragmentations=(1, 1), + ) + return sharded_state_dict + class TopKRouter(Router): """Route each token to the top-k experts.""" diff --git a/mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py b/mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py index e00c7f27e09c930341aca554cd3a390d5ed15328..09ff7100a9bea2ae3c0229cb888d962d744faee4 100644 --- a/mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py +++ b/mindformers/parallel_core/training_graph/transformer/moe/shared_experts.py @@ -23,6 +23,7 @@ from mindspore.ops.auto_generate import Cast, Mul, Sigmoid from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation from mindspore.context import ParallelMode +from mindformers.checkpoint.sharded_tensor import ShardedTensor from mindformers.parallel_core.training_graph.transformer.mlp import MLP, MLPSubmodules, MLPInterleaved from mindformers.parallel_core.transformer_config import TransformerConfig from mindformers.parallel_core.training_graph.device_matrix import layout @@ -108,6 +109,20 @@ class SharedExpertMLP(MLP): def expert_sharding_propagation(self, config: TransformerConfig): super().sharding_propagation(config) + def sharded_state_dict(self): + """Provide the sharded state dict. Sharded info is not complete, Only for Muon optimizer now.""" + sharded_state_dict = {} + sharded_state_dict[self.shared_experts_gate.weight.name] = ShardedTensor( + key=self.shared_experts_gate.weight.name, + org_key=self.shared_experts_gate.weight.name, + dtype=self.shared_experts_gate.weight.dtype, + local_shape=(1, self.hidden_size), + global_shape=(1, self.hidden_size), + global_offset=(0, 0), + axis_fragmentations=(1, 1), + ) + return sharded_state_dict + class SharedExpertMLPInterleaved(MLPInterleaved): """ diff --git a/mindformers/parallel_core/training_graph/transformer/moe/token_dispatcher.py b/mindformers/parallel_core/training_graph/transformer/moe/token_dispatcher.py index 6e50269836774a3eb00fc50644896b1494f6bd21..87ace5765d8c0c06053e67188257db530517bff7 100644 --- a/mindformers/parallel_core/training_graph/transformer/moe/token_dispatcher.py +++ b/mindformers/parallel_core/training_graph/transformer/moe/token_dispatcher.py @@ -308,7 +308,7 @@ class MoEAlltoAllDeredundencyTokenDispatcher(MoETokenDispatcher): self.nonzero = ops.NonZero().recompute(False) self.squeeze_0 = ops.Squeeze(0).recompute(False) self.oep_allgather = ops.AllGather(group=self.oep_group).recompute(False) - self.onehot = ops.OneHot().recompute(False) + self.onehot = ops.OneHot() self.iep_alltoallv = ops.AlltoAllV(group=self.iep_group, block_size=1).recompute(False) def _get_oep_group_name(self): diff --git a/mindformers/parallel_core/training_graph/transformer/multi_token_prediction.py b/mindformers/parallel_core/training_graph/transformer/multi_token_prediction.py index a04f0ff112dcca6da4d36394d1291c16b4bfa3a5..e0dde20f18f66be5026d48ef73900fe78d7ac9f1 100644 --- a/mindformers/parallel_core/training_graph/transformer/multi_token_prediction.py +++ b/mindformers/parallel_core/training_graph/transformer/multi_token_prediction.py @@ -541,6 +541,9 @@ class 
MtpSharedVocabParallelEmbedding(VocabParallelEmbedding): output = self.embedding_morph(input_ids, weight) return output + def sharded_state_dict(self): + return {} + class MtpSharedLanguageModelEmbedding(LanguageModelEmbedding): """Embedding layer used in Multi-Token Prediction module, same to standard LanguageModelEmbedding.""" diff --git a/mindformers/parallel_core/training_graph/transformer/norm.py b/mindformers/parallel_core/training_graph/transformer/norm.py index 01adc1d019073247ebabea5dc5b85981fd881ee6..043f87f8a6b38dfceeb78c02517480347bc0d249 100644 --- a/mindformers/parallel_core/training_graph/transformer/norm.py +++ b/mindformers/parallel_core/training_graph/transformer/norm.py @@ -23,6 +23,7 @@ from mindspore.ops.auto_generate import MeanExt, Sqrt, Rsqrt, SubExt, AddExt, Mu from mindspore.common.initializer import initializer from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation +from mindformers.checkpoint.sharded_tensor import ShardedTensor from mindformers.parallel_core.transformer_config import TransformerConfig from mindformers.parallel_core.training_graph.device_matrix import layout @@ -55,6 +56,7 @@ class LayerNorm(nn.Cell): super(LayerNorm, self).__init__() self.params_dtype = config.params_dtype self.compute_type = config.layernorm_compute_dtype + self.dim = dim self.gamma = Parameter(initializer('ones', dim, self.params_dtype), name="gamma", parallel_optimizer=False) @@ -115,6 +117,29 @@ class LayerNorm(nn.Cell): def sharding_propagation(self, config: TransformerConfig): pass + def sharded_state_dict(self): + """Return sharding metadata for LayerNorm parameters.""" + sharded_state_dict = {} + sharded_state_dict[self.gamma.name] = ShardedTensor( + key=self.gamma.name, + org_key=self.gamma.name, + dtype=self.gamma.dtype, + local_shape=(self.dim,), + global_shape=(self.dim,), + global_offset=(0,), + axis_fragmentations=(1,), + ) + sharded_state_dict[self.beta.name] = ShardedTensor( + key=self.beta.name, + org_key=self.beta.name, + dtype=self.beta.dtype, + local_shape=(self.dim,), + global_shape=(self.dim,), + global_offset=(0,), + axis_fragmentations=(1,), + ) + return sharded_state_dict + class FusedLayerNorm(nn.Cell): r""" @@ -136,7 +161,7 @@ class FusedLayerNorm(nn.Cell): super(FusedLayerNorm, self).__init__() self.params_dtype = config.params_dtype self.compute_type = config.layernorm_compute_dtype - + self.dim = dim self.layer_norm = P.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=eps) @@ -178,6 +203,29 @@ class FusedLayerNorm(nn.Cell): def sharding_propagation(self, config: TransformerConfig): pass + def sharded_state_dict(self): + """Return sharding metadata for FusedLayerNorm parameters.""" + sharded_state_dict = {} + sharded_state_dict[self.gamma.name] = ShardedTensor( + key=self.gamma.name, + org_key=self.gamma.name, + dtype=self.gamma.dtype, + local_shape=(self.dim,), + global_shape=(self.dim,), + global_offset=(0,), + axis_fragmentations=(1,), + ) + sharded_state_dict[self.beta.name] = ShardedTensor( + key=self.beta.name, + org_key=self.beta.name, + dtype=self.beta.dtype, + local_shape=(self.dim,), + global_shape=(self.dim,), + global_offset=(0,), + axis_fragmentations=(1,), + ) + return sharded_state_dict + class RMSNorm(nn.Cell): r""" @@ -199,7 +247,7 @@ class RMSNorm(nn.Cell): super(RMSNorm, self).__init__() self.params_dtype = config.params_dtype self.compute_type = config.layernorm_compute_dtype - + self.dim = dim self.eps = eps self.weight = Parameter(initializer('ones', (dim), self.params_dtype)) @@ -249,6 
+297,20 @@ class RMSNorm(nn.Cell): def sharding_propagation(self, config: TransformerConfig): pass + def sharded_state_dict(self): + """Return sharding metadata for RMSNorm parameters.""" + sharded_state_dict = {} + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=(self.dim,), + global_shape=(self.dim,), + global_offset=(0,), + axis_fragmentations=(1,), + ) + return sharded_state_dict + class FusedRMSNorm(nn.Cell): r""" @@ -270,7 +332,7 @@ class FusedRMSNorm(nn.Cell): super(FusedRMSNorm, self).__init__() self.params_dtype = config.params_dtype self.compute_type = config.layernorm_compute_dtype - + self.dim = dim self.eps = eps self.weight = Parameter(initializer('ones', (dim), self.params_dtype)) @@ -294,11 +356,25 @@ class FusedRMSNorm(nn.Cell): if in_strategy: self.norm.shard(in_strategy) else: - self.norm.shard((layout("cp", "dp", "None"), layout("None",))) + self.norm.shard((layout("cp", "dp", "None"), layout("None", ))) def sharding_propagation(self, config: TransformerConfig): pass + def sharded_state_dict(self): + """Return sharding metadata for FusedRMSNorm parameters.""" + sharded_state_dict = {} + sharded_state_dict[self.weight.name] = ShardedTensor( + key=self.weight.name, + org_key=self.weight.name, + dtype=self.weight.dtype, + local_shape=(self.dim,), + global_shape=(self.dim,), + global_offset=(0,), + axis_fragmentations=(1,), + ) + return sharded_state_dict + class Norm: """ diff --git a/mindformers/parallel_core/transformer_config_utils.py b/mindformers/parallel_core/transformer_config_utils.py index 12921a6369ecebe6fe3b857822dbd70d4d5fd315..46ef5fe6bcb8d3b5674480dfa2d8015eda0eb579 100644 --- a/mindformers/parallel_core/transformer_config_utils.py +++ b/mindformers/parallel_core/transformer_config_utils.py @@ -301,6 +301,7 @@ COMMON_CONFIG_MAPPING = { "untie_embeddings_and_output_weights": "untie_embeddings_and_output_weights", "hidden_act": "hidden_act", "mask_func_type": "mask_func_type", + "monitor_max_attention_logit": "monitor_max_attention_logit", ("extend_method", "position_embedding_type"): "position_embedding_type", ("init_method_std", "initializer_range"): "init_method_std", diff --git a/mindformers/parallel_core/utils/model_mixin.py b/mindformers/parallel_core/utils/model_mixin.py index 4524aaac77272863fd0aeb3060651442fba1427a..781671b7cad73b9f444887ca01e72c60642a50a9 100644 --- a/mindformers/parallel_core/utils/model_mixin.py +++ b/mindformers/parallel_core/utils/model_mixin.py @@ -451,12 +451,60 @@ class TrainModelMixin: raise ValueError(f"the length of cur_layer_linear_fc2_weights_dict is " f"{len(cur_layer_linear_fc2_weights_dict)}, can't stack them.") - def get_model_parameters(self): - """Get current rank trainable parameters in model .""" + def check_and_get_model(self): + """Check and get GPT model instance.""" if not hasattr(self, 'model'): raise RuntimeError("Mcore model definition should use the fixed paradigm: " "self.model = GPTModel(*args, **kwargs) definition. " "Currently, this attribute cannot be correctly recognized. 
" "Please modify the GPTModel definition method.") - model = getattr(self, 'model') + return getattr(self, 'model') + + def get_model_parameters(self): + """Get current rank trainable parameters in model .""" + model = self.check_and_get_model() return model.get_model_parameters() + + def get_max_attention_logit(self): + """Get max attention logit values from the model.""" + model = self.check_and_get_model() + return model.get_max_attention_logit() + + def make_model_muon_fns(self): + """Make model muon functions.""" + model = self.check_and_get_model() + return model.make_model_muon_fns() + + def get_muon_filter(self): + """Get muon filter.""" + model = self.check_and_get_model() + return model.get_muon_filter() + + def get_tp_dims(self, parameters): + """Get tensor parallel dimensions for parameters.""" + model = self.check_and_get_model() + return model.get_tp_dims(parameters) + + def get_op_groups_info(self, parameters, op_size): + """Get operation groups information for parameters.""" + model = self.check_and_get_model() + return model.get_op_groups_info(parameters, op_size) + + def get_parallel_config_for_muon(self): + """Get parallel configuration for Muon optimizer.""" + model = self.check_and_get_model() + return model.get_parallel_config_for_muon() + + def get_param_layer_indices(self, parameters): + """Get layer indices for parameters.""" + model = self.check_and_get_model() + return model.get_param_layer_indices(parameters) + + def apply_qk_clip_scaling(self, parameters, param_names, param_layers, + logit_threshold, split_fn, merge_fn): + """Apply QK clip scaling to parameters.""" + model = self.check_and_get_model() + return model.apply_qk_clip_scaling( + parameters, param_names, param_layers, + logit_threshold, split_fn, merge_fn + ) diff --git a/mindformers/tools/register/template.py b/mindformers/tools/register/template.py index f96a6f628b0e0a6706374bba0c0a564375b9bc4d..48f6770bea077bb28c60d5a0c74f71f24a74b780 100644 --- a/mindformers/tools/register/template.py +++ b/mindformers/tools/register/template.py @@ -544,6 +544,7 @@ class MonitorConfig(Config): device_local_loss_format = None optimizer_state_format = None weight_state_format = None + max_attention_logit_format = None throughput_baseline = None print_struct = False check_for_global_norm = False @@ -686,7 +687,7 @@ class ConfigTemplate: continue new_config[sub_config] = class_.apply(config.pop(sub_config, None)) - unused_config = [key for key in config.keys()] + unused_config = list(config.keys()) if unused_config: logger.warning(f"Some configs in yaml are useless for {run_mode}: {unused_config}") config.update(new_config) diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py index 105d73bd5fa1132cc6f69eebf9db92e85003000c..70aaeb4ec2382fe1aa1e102dcf0b973a0fd19dcb 100644 --- a/mindformers/trainer/base_trainer.py +++ b/mindformers/trainer/base_trainer.py @@ -468,6 +468,7 @@ class BaseTrainer: if self.config.get("generation_config", None): self.config.model.generation_config = self.config.generation_config network = build_network(self.config.model, default_args=default_args) + self.real_model = network if hasattr(network, "check_pipeline_stage") and callable(network.check_pipeline_stage): network.check_pipeline_stage() return network @@ -600,19 +601,27 @@ class BaseTrainer: optimizer_type=self.config.optimizer.type, model_params=model_params) if lr_schedule is not None: + default_args = {"params": group_params, "learning_rate": lr_schedule} + if self.config.optimizer.type == "Muon": + 
default_args["micro_batch_num"] = self.config.parallel_config.micro_batch_num + default_args["model"] = None if not hasattr(self, 'real_model') else self.real_model self.optimizer = build_optim( self.config.optimizer, - default_args={"params": group_params, - "learning_rate": lr_schedule}) + default_args=default_args) else: if self.config.optimizer.learning_rate is None: raise ValueError("learning_rate must be input") self.config.optimizer.learning_rate = self.learning_rate_scale( self.config.optimizer.learning_rate, scale_factor) \ if learning_scale and scale_factor is not None else self.config.optimizer.learning_rate + default_args = {"params": group_params} + if self.config.optimizer.type == "Muon": + default_args["micro_batch_num"] = self.config.parallel_config.micro_batch_num + default_args["model"] = None if not hasattr(self, 'real_model') else self.real_model + # Build optimizer with fixed learning rate self.optimizer = build_optim( self.config.optimizer, - default_args={"params": group_params}) + default_args=default_args) return self.optimizer def create_optimizer_scheduler_without_param_init(self, network, model_params: set, layer_scale=False): @@ -1184,7 +1193,7 @@ class BaseTrainer: network = self.network model_params = set() - if self.config.optimizer.type in ("PmaAdamW", "FusedPmaAdamW"): + if self.config.optimizer.type in ("PmaAdamW", "FusedPmaAdamW", "Muon"): if hasattr(network, "get_model_parameters"): model_params.update(network.get_model_parameters()) else: @@ -1285,7 +1294,9 @@ class BaseTrainer: "dataset_size": config.data_size, "initial_epoch": config.runner_config.initial_epoch, "initial_step": config.runner_config.initial_step, + "micro_batch_num": config.parallel_config.micro_batch_num, "global_batch_size": self.global_batch_size, + "tensor_model_parallel_size": config.parallel_config.model_parallel, "check_for_nan_in_loss_and_grad": getattr(config, "check_for_nan_in_loss_and_grad", False), "use_skip_data_by_global_norm": getattr(config, "use_skip_data_by_global_norm", False), "embedding_size": embedding_size, @@ -1353,6 +1364,13 @@ class BaseTrainer: } default_callbacks.append(build_callback({"type": "TrainFaultTolerance"}, default_args=default_args)) + if config.optimizer.type == "Muon" or config.monitor_config.max_attention_logit_format is not None: + logger.info( + f"Added MaxLogitsMonitor | optimizer={config.optimizer.type}, " + f"max_attention_logit_format={config.monitor_config.max_attention_logit_format}" + ) + default_callbacks.append(build_callback({"type": "MaxLogitsMonitor"})) + if callbacks is not None: if isinstance(callbacks, list): default_callbacks.extend(callbacks) diff --git a/mindformers/trainer/optimizer_grouped_parameters.py b/mindformers/trainer/optimizer_grouped_parameters.py index b614550d0414789f1c063f808fdeb9ff7e287c6a..24a191240e87abc9c6a47339b70814a796a39282 100644 --- a/mindformers/trainer/optimizer_grouped_parameters.py +++ b/mindformers/trainer/optimizer_grouped_parameters.py @@ -57,7 +57,7 @@ def get_optimizer_grouped_parameters(model: Optional[PreTrainedModel] = None, decay_parameters_names = [] - if optimizer_type in ("PmaAdamW", "FusedPmaAdamW"): + if optimizer_type in ("PmaAdamW", "FusedPmaAdamW", "Muon"): filter_current_stage_parameters(model, model_params) for param in model.trainable_params(): @@ -107,7 +107,6 @@ def get_optimizer_grouped_parameters(model: Optional[PreTrainedModel] = None, parameter_group_vars[group_name]["params"].append(param) parameter_group_names[group_name]["params"].append(param.name) - param_groups = 
json.dumps(parameter_group_names, indent=2) logger.info("Param groups = %s", param_groups) return list(parameter_group_vars.values()) diff --git a/mindformers/trainer/trainer.py b/mindformers/trainer/trainer.py index 0ed9c54e14511c068c455ed8c729cfd47143cdbe..680840c7086660a9dc44562042088566802aed8e 100644 --- a/mindformers/trainer/trainer.py +++ b/mindformers/trainer/trainer.py @@ -270,6 +271,8 @@ class Trainer: task_config = self.get_task_config(self.task, self.model_name) self.config = self._config_init(args, task_config) + if self.config.get('optimizer') and self.config.optimizer.type == "Muon": + self.config.model.model_config.monitor_max_attention_logit = True self._reassign_monitor_config() # build parallel config build_parallel_config(self.config) @@ -343,6 +346,8 @@ class Trainer: dump_local_norm=bool(monitor_config.get('local_norm_format')), dump_device_local_norm=bool(monitor_config.get('device_local_norm_format')) ) + if monitor_config.max_attention_logit_format: + self.config.model.model_config.monitor_max_attention_logit = True if monitor_config.local_loss_format: set_context(monitor_local_loss=True) if monitor_config.device_local_loss_format: @@ -1220,7 +1225,7 @@ class Trainer: """ Initializes a git repo in `self.config.hub_model_id`. """ - from modelfoundry_hub import create_repo + from modelfoundry_hub import create_repo # pylint: disable=import-outside-toplevel if self.config.rank_id: return @@ -1382,7 +1387,7 @@ class Trainer: def _try_use_pretrained_model_dir_as_ckpt(self): """Use `pretrained_model_dir` as fallback for checkpoint if applicable.""" if not self.config.load_checkpoint and self.config.pretrained_model_dir: - from mindformers.utils import contains_safetensors_files + from mindformers.utils import contains_safetensors_files # pylint: disable=import-outside-toplevel if contains_safetensors_files(self.config.pretrained_model_dir): self.config.load_checkpoint = self.config.pretrained_model_dir logger.info(f'Parameter load_checkpoint does not set the weight path. Defaulting to ' @@ -1508,7 +1513,7 @@ class Trainer: The URL of the repository where the model was pushed if `blocking=False`, or a `Future` object tracking the progress of the commit if `blocking=True`. 
""" - from modelfoundry_hub import upload_folder + from modelfoundry_hub import upload_folder # pylint: disable=import-outside-toplevel if self.hub_model_id is None: self.init_openmind_repo() diff --git a/tests/st/test_ut/base_schema.json b/tests/st/test_ut/base_schema.json index 08a34b0095669d0293b00815035fd8f4110435f3..94ddfee585730dfb024d97a536925924e55d0adf 100644 --- a/tests/st/test_ut/base_schema.json +++ b/tests/st/test_ut/base_schema.json @@ -273,7 +273,7 @@ "signature": "(summary_dir=None, collect_freq=10, collect_specified_data=None, keep_default_action=True, custom_lineage_data=None, collect_tensor_freq=None, max_file_size=None, export_options=None)" }, "mindformers.core.TrainingStateMonitor": { - "signature": "(origin_epochs: int, config: dict = None, step_interval: int = 1, dataset_size: int = None, initial_epoch: int = 0, initial_step: int = 0, global_batch_size: int = 0, check_for_nan_in_loss_and_grad: bool = False, use_skip_data_by_global_norm: bool = False, embedding_size: int = 4096, use_local_norm: bool = False)" + "signature": "(origin_epochs: int, config: dict = None, step_interval: int = 1, dataset_size: int = None, initial_epoch: int = 0, initial_step: int = 0, micro_batch_num: int = 0, global_batch_size: int = 0, tensor_model_parallel_size: int = 0, check_for_nan_in_loss_and_grad: bool = False, use_skip_data_by_global_norm: bool = False, embedding_size: int = 4096, use_local_norm: bool = False)" }, "mindformers.core.TrainingStateMonitor.on_train_epoch_begin": { "signature": "(self, run_context)" @@ -471,7 +471,7 @@ "signature": "(summary_dir=None, collect_freq=10, collect_specified_data=None, keep_default_action=True, custom_lineage_data=None, collect_tensor_freq=None, max_file_size=None, export_options=None)" }, "mindformers.core.callback.TrainingStateMonitor": { - "signature": "(origin_epochs: int, config: dict = None, step_interval: int = 1, dataset_size: int = None, initial_epoch: int = 0, initial_step: int = 0, global_batch_size: int = 0, check_for_nan_in_loss_and_grad: bool = False, use_skip_data_by_global_norm: bool = False, embedding_size: int = 4096, use_local_norm: bool = False)" + "signature": "(origin_epochs: int, config: dict = None, step_interval: int = 1, dataset_size: int = None, initial_epoch: int = 0, initial_step: int = 0, micro_batch_num: int = 0, global_batch_size: int = 0, tensor_model_parallel_size: int = 0, check_for_nan_in_loss_and_grad: bool = False, use_skip_data_by_global_norm: bool = False, embedding_size: int = 4096, use_local_norm: bool = False)" }, "mindformers.core.callback.TrainingStateMonitor.on_train_epoch_begin": { "signature": "(self, run_context)" diff --git a/tests/st/test_ut/test_core/test_optim/test_get_op_group.py b/tests/st/test_ut/test_core/test_optim/test_get_op_group.py new file mode 100644 index 0000000000000000000000000000000000000000..05728ab5128650e4f695680300cd3be40df2cc95 --- /dev/null +++ b/tests/st/test_ut/test_core/test_optim/test_get_op_group.py @@ -0,0 +1,178 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test get op groups info for GPT model.""" + +from unittest.mock import patch + +import mindspore as ms +import pytest + +from mindformers import build_context +from mindformers.checkpoint.sharded_tensor import build_sharded_tensor +from mindformers.parallel_core.training_graph.base_models.gpt import gpt_model +from mindformers.parallel_core.training_graph.base_models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, \ + get_gpt_mtp_block_spec +from mindformers.parallel_core.training_graph.base_models.gpt.gpt_model import GPTModel, \ + compute_repeat_num_and_model_parallel_size, get_op_group_name +from mindformers.parallel_core.transformer_config import TransformerConfig + + +def build_transformer_config() -> TransformerConfig: + """Create a minimal transformer config for tensor-parallel unit tests.""" + return TransformerConfig( + data_parallel_size=1, + pipeline_model_parallel_size=1, + tensor_model_parallel_size=1, + # model architecture + vocab_size=1024, + position_embedding_type="rope", + num_attention_heads=2, + num_layers=2, + hidden_size=128, + ffn_hidden_size=512, + # moe architecture + num_moe_experts=4, + first_k_dense_replace=1, + mtp_num_layers=1, + add_bias_linear=False, + moe_grouped_gemm=True + ) + + +def build_gpt_model(): + """Construct a GPTModel instance with the default test configuration.""" + config = build_transformer_config() + transformer_layer_spec = get_gpt_decoder_block_spec(config) + mtp_block_spec = None + if config.mtp_num_layers is not None: + mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec) + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=config.vocab_size, + max_sequence_length=config.max_position_embeddings, + position_embedding_type=config.position_embedding_type, + rotary_percent=1.0, + rotary_base=config.rotary_base, + rope_scaling=False, + mtp_block_spec=mtp_block_spec + ) + return model + + +def build_sharded_info(local_shape, axis_fragmentations): + """Helper to create a simple ShardedTensor descriptor.""" + return build_sharded_tensor( + param_name="test", + param_dtype=ms.float32, + local_shape=local_shape, + global_shape=local_shape, + axis_fragmentations=axis_fragmentations, + global_offset=(0,) * len(local_shape), + ) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gpt_model_sharded_state_dict(): + """ + Feature: GPTModel + Description: Test the sharded state dict of GPT model. + Expectation: The sharded state dict has all the trainable parameters and the shape is correct. 
+ """ + build_context({"use_legacy": False}) + model = build_gpt_model() + sharded_state_dict = model.sharded_state_dict() + + params = model.trainable_params() + for param in params: + assert param.name in sharded_state_dict + assert param.shape == sharded_state_dict[param.name].global_shape + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize( + "axis_fragmentations, world_size, pipeline_parallel, opt_group_size, local_shape, expected", + [ + # case 0: real_op_size == opt_group_size + ((1, 1), 12, 2, 4, (12, 4), (4, 1)), + # case 1: real_op_size < opt_group_size + ((2, 1), 16, 2, 8, (12, 4), (4, 2)), + # case 2: real_op_size = 1 due to local shape not divisible by real_op_size + ((4, 1), 32, 2, 4, (10, 4), (1, 4)), + ], +) +def test_compute_repeat_num_and_model_parallel_size(axis_fragmentations, world_size, pipeline_parallel, + opt_group_size, local_shape, expected): + """ + Feature: compute_repeat_num_and_model_parallel_size() + Description: Test the compute repeat num and model parallel size. + Expectation: The compute repeat num and model parallel size should be correct. + """ + sharded_info = build_sharded_info(local_shape, axis_fragmentations) + assert compute_repeat_num_and_model_parallel_size( + sharded_info, + world_size=world_size, + pp=pipeline_parallel, + op=opt_group_size, + ) == expected + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_compute_repeat_num_and_model_parallel_size_multiple_axis_error(): + """ + Feature: compute_repeat_num_and_model_parallel_size() + Description: Test the error of compute repeat num and model parallel size. + Expectation: The ValueError should be raised. + """ + sharded_info = build_sharded_info((8, 8), (2, 2)) + with pytest.raises(ValueError): + compute_repeat_num_and_model_parallel_size(sharded_info, world_size=16, pp=1, op=2) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@patch("mindformers.parallel_core.training_graph.base_models.gpt.gpt_model.create_communication_group") +def test_get_op_group_name_with_mock(mock_create_group): + """ + Feature: get_op_group_name() + Description: Test the get op group name with mock. + Expectation: The get op group name with mock should be correct. + """ + mock_create_group.return_value = "mock_group" + gpt_model.OP_GROUP_NAME.clear() + + # case 0: model_parallel_size > 1 + result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=2) + assert result == ("mock_group", [1, 3]) + mock_create_group.assert_called_once_with([1, 3]) + + second_result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=2) + assert second_result == result + mock_create_group.assert_called_once() + + # case 1: model_parallel_size = 1 + result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=1) + assert result == ("mock_group", [2, 3]) + + # case 2: model_parallel_size = 4 + result = get_op_group_name(rank_id=3, real_op_size=2, model_parallel_size=4) + assert result == ("mock_group", [3, 7])