diff --git a/src/cam/comm_operator/pybind/functions.h b/src/cam/comm_operator/pybind/functions.h index 8f583349f9d1b851242ff3047aec8d22e22d69ae..4d5b4be7ca85b586422b851744925becb84264b5 100644 --- a/src/cam/comm_operator/pybind/functions.h +++ b/src/cam/comm_operator/pybind/functions.h @@ -58,7 +58,7 @@ moe_combine_normal_impl_autograd( const at::Tensor &tokenSrcInfo, \ const at::Tensor &epRecvCounts, \ const at::Tensor &recvTopkWeights, \ - const std::optional &tpRecvCounts, \ + const c10::optional &tpRecvCounts, \ c10::string_view epGroupName, \ int64_t epWorldSize, \ int64_t epRankId, \ diff --git a/src/cam/comm_operator/pybind/moe_dispatch_normal.cpp b/src/cam/comm_operator/pybind/moe_dispatch_normal.cpp index 31a2987f28fa388d4926e940b2e825e43dea7d2c..d927d13da1f49cbe7011562576086c9091e1a5ce 100644 --- a/src/cam/comm_operator/pybind/moe_dispatch_normal.cpp +++ b/src/cam/comm_operator/pybind/moe_dispatch_normal.cpp @@ -173,7 +173,7 @@ public: } }; -tensor_list moe_dispatch_normal_impl_autograd( +std::tuple moe_dispatch_normal_impl_autograd( const at::Tensor &x, \ const at::Tensor &topkIdx, \ const at::Tensor &sendOffset, \ @@ -193,7 +193,7 @@ tensor_list moe_dispatch_normal_impl_autograd( auto result = ExtMoeDispatchNormal::apply(x, topkIdx, sendOffset, sendTokenIdx, recvOffset, \ recvCount, groupEp, epWorldSize, epRankId, \ groupTp, tpWorldSize, tpRankId, moeExpertNum, quantMode, globalBs); - return result; + return std::make_tuple(result[0], result[1], result[2]); } // moe_dispatch_normal diff --git a/src/cam/comm_operator/pybind/setup.py b/src/cam/comm_operator/pybind/setup.py index dd46a73c96da0cf6129982e1d96500f4925b80d2..1da4931fb92b540d01a16edab739bd957008b52d 100644 --- a/src/cam/comm_operator/pybind/setup.py +++ b/src/cam/comm_operator/pybind/setup.py @@ -13,23 +13,23 @@ import torch import platform import importlib.util -sys.path.append(os.path.join(os.path.dirname(__file__), './pytorch_extension')) +sys.path.append(os.path.join(os.path.dirname(__file__), "./pytorch_extension")) from bdist_wheel_build import BdistWheelBuild from setuptools import setup, find_packages from torch.utils.cpp_extension import BuildExtension from torch_npu.utils.cpp_extension import NpuExtension # 格式: V版本.R版本.C版本.B版本 -env_version = os.getenv('CAM_WHL_VERSION', '208.1.0.B001') +env_version = os.getenv("CAM_WHL_VERSION", "208.1.0.B001") torch_path = os.path.dirname(torch.__file__) -torch_npu_spec = importlib.util.find_spec('torch_npu') +torch_npu_spec = importlib.util.find_spec("torch_npu") torch_npu_path = os.path.dirname(torch_npu_spec.origin) print(f"torch_path: {torch_path}") print(f"torch_npu_path: {torch_npu_path}") PYTORCH_NPU_INSTALL_PATH = os.path.dirname(os.path.abspath(torch_npu_spec.origin)) architecture = str(platform.machine()) -if architecture.startswith('x86'): +if architecture.startswith("x86"): arch = "x86_64" else: arch = "aarch64" @@ -39,45 +39,47 @@ for env_name in env_names: if env_name not in os.environ: print(f"{env_name} is not in env, please export {env_name} first") compile_args = [ - '-I' + os.path.join(PYTORCH_NPU_INSTALL_PATH, "include/third_party/acl/inc"), + "-I" + os.path.join(PYTORCH_NPU_INSTALL_PATH, "include/third_party/acl/inc"), "-fPIC", "-fstack-protector-strong", "-w", "-D_FORTIFY_SOURCE=2", ] if "BUILD_TYPE" in os.environ and os.environ.get("BUILD_TYPE") == "Debug": - compile_args.extend(['-g', '-O0']) + compile_args.extend(["-g", "-O0"]) else: - compile_args.extend(['-O2']) + compile_args.extend(["-O2"]) if "ENABLE_COV" in os.environ and os.environ.get("ENABLE_COV") == "1": - compile_args.extend(['-coverage']) + compile_args.extend(["-coverage"]) print(compile_args) exts = [] ext1 = NpuExtension( name="cam_ge_op_lib", include_dirs=[ - os.path.join(torch_npu_path, 'include'), - os.path.join(torch_npu_path, 'include/third_party/acl/inc/acl/'), - os.path.join(torch_npu_path, 'include/third_party/acl/inc'), - os.path.join(os.environ["ASCEND_HOME_PATH"], f'{arch}-linux', 'include'), - os.path.join(os.environ["ASCEND_HOME_PATH"], f'{arch}-linux', 'include', 'hccl'), - os.path.join(os.environ["ASCEND_HOME_PATH"], f'{arch}-linux', 'include', 'experiment', 'runtime'), - os.path.join(os.environ["ASCEND_HOME_PATH"], f'{arch}-linux', 'include', "experiment", 'msprof'), - os.path.join(torch_path, 'include'), + os.path.join(torch_npu_path, "include"), + os.path.join(torch_npu_path, "include/third_party/acl/inc/acl/"), + os.path.join(torch_npu_path, "include/third_party/acl/inc"), + os.path.join(os.environ["ASCEND_HOME_PATH"], f"{arch}-linux", "include"), + os.path.join(os.environ["ASCEND_HOME_PATH"], f"{arch}-linux", "include", "hccl"), + os.path.join(os.environ["ASCEND_HOME_PATH"], f"{arch}-linux", "include", "experiment", "runtime"), + os.path.join(os.environ["ASCEND_HOME_PATH"], f"{arch}-linux", "include", "experiment", "msprof"), + os.path.join(torch_path, "include"), os.path.join(os.path.dirname(__file__), "./", "pytorch_extension")], library_dirs=[ - os.path.join(torch_path, 'lib'), - os.path.join(torch_npu_path, 'lib'), - os.path.join(os.environ["ASCEND_HOME_PATH"], 'opp', 'vendors', 'CAM', 'op_api', 'lib'), - os.path.join(os.environ["ASCEND_HOME_PATH"], f'{arch}-linux', 'lib64')], + os.path.join(torch_path, "lib"), + os.path.join(torch_npu_path, "lib"), + os.path.join(os.environ["ASCEND_HOME_PATH"], f"{arch}-linux", "lib64")], libraries=[ "torch_npu", - 'gcov', - 'runtime', 'torch', 'ascendcl', 'profapi', 'opapi', 'cust_opapi'], + "gcov", + "runtime", + "torch", + "ascendcl", + "profapi"], sources=["./fused_deep_moe.cpp", "./moe_dispatch_normal.cpp", "./moe_combine_normal.cpp", - "./pybind.cpp", + "./pybind.cpp", ], extra_compile_args = compile_args, @@ -92,11 +94,11 @@ BdistWheelBuild.dependencies = ["libc10.so", "libtorch.so", "libtorch_cpu.so", " setup( name="cam_ge_operator", version=env_version, - keywords='cam_ge_op_lib', + keywords="cam_ge_op_lib", ext_modules=exts, packages=find_packages(), cmdclass={ "build_ext": BuildExtension, - 'bdist_wheel': BdistWheelBuild + "bdist_wheel": BdistWheelBuild }, ) \ No newline at end of file