From 30cbfd0e61cd436487b4d8a5976ee42942e7ab85 Mon Sep 17 00:00:00 2001
From: niujunhao
Date: Fri, 12 Dec 2025 12:07:07 +0800
Subject: [PATCH] add weight decay.

---
 mindformers/trainer/base_trainer.py                |  4 +++-
 .../trainer/optimizer_grouped_parameters.py        | 16 ++++++++++++++--
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py
index 0672c679b..18fedc5d1 100644
--- a/mindformers/trainer/base_trainer.py
+++ b/mindformers/trainer/base_trainer.py
@@ -638,7 +638,9 @@ class BaseTrainer:
                 model_params=model_params,
                 optimizer_type=self.config.optimizer.type,
                 layer_scale=self.config.layer_scale,
-                layer_decay=self.config.layer_decay
+                layer_decay=self.config.layer_decay,
+                need_weight_decay=self.config.need_weight_decay,
+                no_weight_decay=self.config.no_weight_decay,
             )
 
         # Build optimizer with dynamic lr_scheduler if available
diff --git a/mindformers/trainer/optimizer_grouped_parameters.py b/mindformers/trainer/optimizer_grouped_parameters.py
index bbc18379e..f386c1ab3 100644
--- a/mindformers/trainer/optimizer_grouped_parameters.py
+++ b/mindformers/trainer/optimizer_grouped_parameters.py
@@ -100,9 +100,12 @@ def get_optimizer_grouped_parameters(model: Optional[PreTrainedModel] = None,
                                      model_params: set = None,
                                      grouped_lr_schedule: dict = None,
                                      layer_scale: bool = False,
-                                     layer_decay: float = 1.0,):
+                                     layer_decay: float = 1.0,
+                                     need_weight_decay: list = [],
+                                     no_weight_decay: list = [],
+                                     ):
     """
-    Build optimizer parameter groups with appropriate weight decay, 
+    Build optimizer parameter groups with appropriate weight decay,
     learning rate scheduling, and optional parameter grouping.
     """
     if not isinstance(model, (Cell, PreTrainedModel)):
@@ -119,6 +122,8 @@ def get_optimizer_grouped_parameters(model: Optional[PreTrainedModel] = None,
         no_wd_keywords = model.no_weight_decay_keywords()
         logger.info(f'Get no weight decay keywords: {no_wd_keywords}')
 
+    logger.info(f'Get need_weight_decay: {need_weight_decay}, no_weight_decay: {no_weight_decay}')
+
     # set default values if not provided
     if not weight_decay:
         weight_decay = 0.0
@@ -144,6 +149,13 @@ def get_optimizer_grouped_parameters(model: Optional[PreTrainedModel] = None,
             or param_name in no_wd_params
             or check_keywords_in_name(param_name, no_wd_keywords)
         )
+
+        if check_keywords_in_name(param_name, no_weight_decay):
+            no_wd = True
+        elif check_keywords_in_name(param_name, need_weight_decay):
+            no_wd = False
+        logger.info(f"Set param: {param_name}, weight_decay: {not no_wd}")
+
         if no_wd:
             wd_mul = 0.0
             group_name = 'no_weight_decay'
-- 
Gitee
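
For reviewers, a minimal standalone sketch of the override precedence this patch introduces: parameters matching the user-configured no_weight_decay keywords are forced out of weight decay, parameters matching need_weight_decay are forced back in, and everything else keeps the existing default grouping. The sketch assumes check_keywords_in_name does plain substring matching; resolve_weight_decay and the example parameter names are hypothetical and only illustrate the intended behavior, not the mindformers API.

def check_keywords_in_name(name, keywords):
    # Assumed behavior: plain substring match; the real mindformers helper may differ.
    return any(keyword in name for keyword in keywords)


def resolve_weight_decay(param_name, no_wd_default, need_weight_decay, no_weight_decay):
    """Return True if the parameter should receive weight decay (hypothetical helper)."""
    no_wd = no_wd_default  # model / built-in defaults (bias, norm, 1-D params, ...)
    if check_keywords_in_name(param_name, no_weight_decay):
        no_wd = True       # user-configured no_weight_decay has the highest priority
    elif check_keywords_in_name(param_name, need_weight_decay):
        no_wd = False      # user-configured need_weight_decay re-enables decay
    return not no_wd


# Example: keep decay off for norm layers but re-enable it for embedding tables
# that the default rules would otherwise exclude.
print(resolve_weight_decay("model.tok_embeddings.embedding_weight",
                           no_wd_default=True,
                           need_weight_decay=["embedding"],
                           no_weight_decay=["norm"]))   # -> True
print(resolve_weight_decay("model.layers.0.attention_norm.gamma",
                           no_wd_default=False,
                           need_weight_decay=["embedding"],
                           no_weight_decay=["norm"]))   # -> False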