From 39a98ae8134fecf1ea9942e978368dc2de58734b Mon Sep 17 00:00:00 2001 From: sunshine Date: Mon, 24 Nov 2025 10:33:03 +0800 Subject: [PATCH] 310P support for fusedrope --- kernels/fused_rope/fused_rope.cpp | 2 ++ kernels/fused_rope/fused_rope_bf16.h | 32 +++++++++++++++++++--------- kernels/fused_rope/fused_rope_fp32.h | 4 ++-- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/kernels/fused_rope/fused_rope.cpp b/kernels/fused_rope/fused_rope.cpp index 66e105e..4767bac 100644 --- a/kernels/fused_rope/fused_rope.cpp +++ b/kernels/fused_rope/fused_rope.cpp @@ -18,6 +18,7 @@ extern "C" __global__ __aicore__ void FusedRopeKernel( uint64_t numTokensFrontCoreLastLoop, uint64_t numTokensTailCoreLastLoop, uint64_t tilingKey) { TPipe pipe; +#if (ASCEND_AICORE_ARCH >= 220) // DT_BF16 if (tilingKey == (uint64_t)kvcache_ops::AscendType::BF16) { TPipe* ptr = &pipe; @@ -35,6 +36,7 @@ extern "C" __global__ __aicore__ void FusedRopeKernel( op.Process(); } } +#endif // DT_FLOAT16 if (tilingKey == (uint64_t)kvcache_ops::AscendType::FP16) { TPipe* ptr = &pipe; diff --git a/kernels/fused_rope/fused_rope_bf16.h b/kernels/fused_rope/fused_rope_bf16.h index d5635a0..eed962a 100644 --- a/kernels/fused_rope/fused_rope_bf16.h +++ b/kernels/fused_rope/fused_rope_bf16.h @@ -312,14 +312,26 @@ __aicore__ inline void FusedRopeFP16::Rope( this->rotaryDim); PipeBarrier(); } - Cast( - inQueCalLocal, temp1Local, AscendC::RoundMode::CAST_RINT, - static_cast(loopN * this->numHeads * this->rotaryDim)); + #if ASCEND_AICORE_ARCH >= 220 + Cast( + inQueCalLocal, temp1Local, AscendC::RoundMode::CAST_RINT, + static_cast(loopN * this->numHeads * this->rotaryDim)); + #else + Cast( + inQueCalLocal, temp1Local, AscendC::RoundMode::CAST_NONE, + static_cast(loopN * this->numHeads * this->rotaryDim)); + #endif PipeBarrier(); } else { - Cast( - inQueCalLocal, inLocal, AscendC::RoundMode::CAST_RINT, - static_cast(loopN * this->numHeads * this->rotaryDim)); + #if ASCEND_AICORE_ARCH >= 220 + Cast( + inQueCalLocal, inLocal, AscendC::RoundMode::CAST_RINT, + static_cast(loopN * this->numHeads * this->rotaryDim)); + #else + Cast( + inQueCalLocal, inLocal, AscendC::RoundMode::CAST_NONE, + static_cast(loopN * this->numHeads * this->rotaryDim)); + #endif PipeBarrier(); } } @@ -369,11 +381,11 @@ __aicore__ inline void FusedRopeFP16::Compute(uint64_t index, uint64_t loopN) for (uint32_t i = 0; i < loopN * this->numHeads; i++) { GatherMask( inLocal[i * this->rotaryDim], temp1Local[i * this->rotaryDim], static_cast(1), true, - this->rotaryDim, {1, 1, 0, 0}, rsv); + this->rotaryDim, {1, 1, 8, 0}, rsv); PipeBarrier(); GatherMask( inLocal[i * this->rotaryDim + this->rotaryDim / 2], temp1Local[i * this->rotaryDim], - static_cast(2), true, this->rotaryDim, {1, 1, 0, 0}, rsv); + static_cast(2), true, this->rotaryDim, {1, 1, 8, 0}, rsv); PipeBarrier(); } } else { @@ -388,8 +400,8 @@ __aicore__ inline void FusedRopeFP16::Compute(uint64_t index, uint64_t loopN) inQueueCosSinCacheBeforeCastLocal, oldPositionIdGM, cosSinCacheGM, dstShape, srcShape, dstShape4Negone); PipeBarrier(); - Rope(index, loopN, inLocal, reverseQ, cosSin, oneNeg, inCosSin, - inQueueCosSinCacheBeforeCastLocal, inQueCalLocal, temp1Local, offsetLocal, + Rope(index, loopN, inLocal, reverseQ, cosSin, oneNeg, inCosSin, + inQueueCosSinCacheBeforeCastLocal, inQueCalLocal, temp1Local, offsetLocal, newPositionIdGM, cosSinCacheGM, dstShape, srcShape, dstShape4Negone); PipeBarrier(); diff --git a/kernels/fused_rope/fused_rope_fp32.h b/kernels/fused_rope/fused_rope_fp32.h index 59d24d9..77eaf9d 100644 --- a/kernels/fused_rope/fused_rope_fp32.h +++ b/kernels/fused_rope/fused_rope_fp32.h @@ -316,11 +316,11 @@ __aicore__ inline void FusedRopeFP32::Compute(uint64_t index, uint64_t loopN) for (uint32_t i = 0; i < loopN * this->numHeads; i++) { GatherMask( inQueCalLocal[i * this->rotaryDim], temp1Local[i * this->rotaryDim], static_cast(1), true, - this->rotaryDim, {1, 1, 0, 0}, rsv); + this->rotaryDim, {1, 1, 8, 0}, rsv); PipeBarrier(); GatherMask( inQueCalLocal[i * this->rotaryDim + this->rotaryDim / 2], temp1Local[i * this->rotaryDim], - static_cast(2), true, this->rotaryDim, {1, 1, 0, 0}, rsv); + static_cast(2), true, this->rotaryDim, {1, 1, 8, 0}, rsv); PipeBarrier(); } } else { -- Gitee