diff --git a/patch/npu.patch b/patch/npu.patch index 74090bf605418517a0a0c6f69bdfc60e3338bcd0..143bbb88baf0ae6807baee98c1bb10954463bf2a 100644 --- a/patch/npu.patch +++ b/patch/npu.patch @@ -688,7 +688,7 @@ diff -Nur '--exclude=.git' apex/apex/amp/_process_optimizer.py apex-develop/apex + with torch.npu.stream(reduce_stream): + dist.all_reduce(partial_combined_grad_list[index]) + -+ current_param_size_list[name_dict[name]] += param.storage().size() ++ current_param_size_list[name_dict[name]] += param.numel() + for i, _ in enumerate(current_param_size_list): + if current_param_size_list[i] == target_grads_size_list[i] and current_param_size_list[i] != 0: + ready_reduce_index.append(i) @@ -757,7 +757,7 @@ diff -Nur '--exclude=.git' apex/apex/amp/_process_optimizer.py apex-develop/apex + name_order = 0 + for param_idx, param in enumerate(params): + name = '%d_%d'%(idx, param_idx) -+ cur_size = param.storage().size() ++ cur_size = param.numel() + if cur_size > exchange_threshold_list[idx] and tmp_size != 0: + target_grads_size_list.append(tmp_size) + tmp_size = 0 @@ -1250,8 +1250,8 @@ diff -Nur '--exclude=.git' apex/apex/amp/_process_optimizer.py apex-develop/apex + if p.grad is None: + continue + -+ param_size = p.storage().size() -+ grad_size = p.grad.storage().size() ++ param_size = p.numel() ++ grad_size = p.grad.numel() + if is_grad_in_combined_tensor(p.grad, combined_fp32_from_fp32_grad): + group_fp32_from_fp32_param_size += param_size + group_fp32_from_fp32_params.append(p) @@ -1331,11 +1331,11 @@ diff -Nur '--exclude=.git' apex/apex/amp/_process_optimizer.py apex-develop/apex + if p.grad is None: + continue + -+ param_size = p.storage().size() ++ param_size = p.numel() + group_fp32_param_size += param_size + group_fp32_params.append(p) + -+ grad_size = p.grad.storage().size() ++ grad_size = p.grad.numel() + group_fp32_grad_size += grad_size + + combined_group_fp32_param = None diff --git a/src/apex/contrib/combine_tensors/combine_tensors.py b/src/apex/contrib/combine_tensors/combine_tensors.py index a58af969836c5ee15bf6495cbc16930a01e34ef6..9aa680ea8a1d58b00b6615eddc0e2a9bc085bc67 100644 --- a/src/apex/contrib/combine_tensors/combine_tensors.py +++ b/src/apex/contrib/combine_tensors/combine_tensors.py @@ -18,7 +18,7 @@ from change_data_ptr import change_data_ptr def combine_npu(list_of_tensor, require_copy_value = True): total_numel = 0 for tensor in list_of_tensor: - total_numel += tensor.storage().size() + total_numel += tensor.numel() if total_numel == 0: return None @@ -32,20 +32,20 @@ def combine_npu(list_of_tensor, require_copy_value = True): temp = tensor.clone() change_data_ptr(tensor, combined_tensor, idx) tensor.copy_(temp) - idx += tensor.storage().size() + idx += tensor.numel() else: for tensor in list_of_tensor: change_data_ptr(tensor, combined_tensor, idx) - idx += tensor.storage().size() + idx += tensor.numel() return combined_tensor def get_part_combined_tensor(combined_tensor, index, size): if combined_tensor is None or size == 0: return None - if (index + size) > combined_tensor.storage().size(): - raise RuntimeError("(index + size) ({}) > combined_tensor.storage().size() ({})".format( - index + size, combined_tensor.storage().size())) + if (index + size) > combined_tensor.numel(): + raise RuntimeError("(index + size) ({}) > combined_tensor.numel() ({})".format( + index + size, combined_tensor.numel())) part_tensor = torch.zeros(size, dtype=combined_tensor.dtype).npu() change_data_ptr(part_tensor, combined_tensor, index) @@ -59,7 +59,7 @@ def is_combined_tensor_valid(combined_tensor, list_of_tensor): combined_tensor_start_addr = combined_tensor.data_ptr() combined_tensor_end_addr = combined_tensor_start_addr + \ - combined_tensor.storage().size() * combined_tensor.element_size() + combined_tensor.numel() * combined_tensor.element_size() for tensor in list_of_tensor: if tensor is None or \ diff --git a/src/apex/contrib/test/test_combine_tensors.py b/src/apex/contrib/test/test_combine_tensors.py index eb9e4f7041616fb661dcf97f9cac6f37ca61faeb..dc49d3db6c217bef4e8f8cd6446d0cf0a40bc0d6 100644 --- a/src/apex/contrib/test/test_combine_tensors.py +++ b/src/apex/contrib/test/test_combine_tensors.py @@ -18,6 +18,7 @@ import functools as ft import itertools as it import sys import torch +import torch_npu from apex import amp from apex.contrib.combine_tensors import combine_npu @@ -58,9 +59,9 @@ class TestCombineTensors(unittest.TestCase): # test if combine_tensor is contiguous, and x,y,z are will moved into the combine_tensor. self.assertEqual(True, combine_tensor.is_contiguous()) self.assertEqual(combine_tensor.data_ptr(), x.data_ptr()) - self.assertEqual(x.data_ptr() + x.storage().size() * x.element_size(), y.data_ptr()) - self.assertEqual(y.data_ptr() + y.storage().size() * y.element_size(), z.data_ptr()) - self.assertEqual(combine_tensor.storage().size(), x.storage().size() + y.storage().size() + z.storage().size()) + self.assertEqual(x.data_ptr() + x.numel() * x.element_size(), y.data_ptr()) + self.assertEqual(y.data_ptr() + y.numel() * y.element_size(), z.data_ptr()) + self.assertEqual(combine_tensor.numel(), x.numel() + y.numel() + z.numel()) def test_basic_fp32(self): print('----------------------test basic functionality of fp32------------------------') @@ -88,9 +89,9 @@ class TestCombineTensors(unittest.TestCase): # test for tensors with very large sizes. self.assertEqual(True, combine_tensor.is_contiguous()) self.assertEqual(combine_tensor.data_ptr(), x.data_ptr()) - self.assertEqual(x.data_ptr() + x.storage().size() * x.element_size(), y.data_ptr()) - self.assertEqual(y.data_ptr() + y.storage().size() * y.element_size(), z.data_ptr()) - self.assertEqual(combine_tensor.storage().size(), x.storage().size() + y.storage().size() + z.storage().size()) + self.assertEqual(x.data_ptr() + x.numel() * x.element_size(), y.data_ptr()) + self.assertEqual(y.data_ptr() + y.numel() * y.element_size(), z.data_ptr()) + self.assertEqual(combine_tensor.numel(), x.numel() + y.numel() + z.numel()) def test_computation(self): print('----------------------test computation------------------------') @@ -164,4 +165,5 @@ class TestCombineTensors(unittest.TestCase): self.assertNotEqual(store_x_after_combine, new_tensor.data_ptr()) if __name__ == '__main__': + torch.npu.set_device("npu:0") unittest.main(argv=['test_combine_tensors.py'])