From c9d99a50922d3d0bdf033a0d1621e80604821731 Mon Sep 17 00:00:00 2001 From: g00663789 Date: Mon, 11 Aug 2025 21:32:58 +0800 Subject: [PATCH 1/3] Opensource commit --- CMakeLists.txt | 12 ++ README.md | 5 + kernels/load_and_reshape_flash.cpp | 265 +++++++++++++++++++++++ kernels/multi_layer_mem_kernels.cpp | 233 ++++++++++++++++++++ kernels/single_layer_mem_kernels.cpp | 309 +++++++++++++++++++++++++++ kernels/types.h | 28 +++ npu_lib.cmake | 11 + 7 files changed, 863 insertions(+) create mode 100644 CMakeLists.txt create mode 100644 README.md create mode 100644 kernels/load_and_reshape_flash.cpp create mode 100644 kernels/multi_layer_mem_kernels.cpp create mode 100644 kernels/single_layer_mem_kernels.cpp create mode 100644 kernels/types.h create mode 100644 npu_lib.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..9be332b --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,12 @@ +# ${KERNEL_FILES} are used to compile library, push files written by ascendc in ${KERNEL_FILES}. +# ref to cmake/npu.cmake ascendc_library, cmake/cpu.cmake add_library +file(GLOB KERNEL_FILES kernels/*.cpp) + +message(STATUS "kernel files: ${KERNEL_FILES}") + +include(npu_lib.cmake) + +target_include_directories(cache_kernels + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} +) \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..541ff1d --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# KVCache Ops + +KVCache Ops is a simple library containing LLM KVCache related operators for Ascend NPU. + +We currently have a few operators that support KVCache offload. diff --git a/kernels/load_and_reshape_flash.cpp b/kernels/load_and_reshape_flash.cpp new file mode 100644 index 0000000..8117fee --- /dev/null +++ b/kernels/load_and_reshape_flash.cpp @@ -0,0 +1,265 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "kernel_operator.h" +#include +#include "types.h" + +template class LoadAndReshapeFlashCopy { + using local_scalar_t = AscendC::LocalTensor; + +public: + __aicore__ inline LoadAndReshapeFlashCopy() + { + } + + __aicore__ inline void init(GM_ADDR cacheTensor, GM_ADDR keyCachePtr, GM_ADDR valueCachePtr, GM_ADDR slotmappings, + const int64_t numPages, const int64_t hiddenDims, const int32_t pagedSize, + const int32_t numTokens, const int32_t numLayers, const int32_t layerIdx, + const bool page2L, AscendC::TPipe *pipe) + { + this->pipe_ = pipe; + this->numPages_ = numPages; + this->hiddenDims_ = hiddenDims; + this->numTokens_ = numTokens; + this->pagedSize_ = pagedSize; + this->numLayers_ = numLayers; + this->layerIdx_ = layerIdx; + this->valid_ = true; + this->page2L_ = page2L; + + // TODO: Not sure how many to allocate, but let's do 4 blocks of hiddenDims_ + // if it was fp16, 2048, we would get 16kb.? + // should check whether hiddenDims_ is > 192KB. 
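+        // Spelling out the sizing above: the queue below reserves 4 buffers of
+        // hiddenDims_ * sizeof(scalar_t) bytes each, so for hiddenDims_ = 2048 in fp16
+        // that is 4 * 2048 * 2 B = 16 KB of unified buffer. Taking the 192 KB figure
+        // above as the per-core budget, the allocation only becomes a problem once
+        // 4 * hiddenDims_ * sizeof(scalar_t) exceeds 192 KB, i.e. hiddenDims_ beyond
+        // roughly 24K fp16 elements; a host-side check of that bound before launch
+        // would be one way to close the TODO.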
+ this->pipe_->InitBuffer(this->pagedTokenQue_, 4, this->hiddenDims_*sizeof(scalar_t)); + } + + __aicore__ inline void reset(){ + this->valid_ = true; + } + + __aicore__ inline void updateTensorMemOffsetAndProcess(__gm__ uint8_t *pagedKeyTensor, + __gm__ uint8_t *pagedValueTensor, + __gm__ uint8_t* nonPagedTensor, + __gm__ uint8_t *slotmappings, const int tokenIdx) + { + __gm__ slot_t *slotmappingPtr = reinterpret_cast<__gm__ slot_t*>(slotmappings); + int64_t slot = static_cast(slotmappingPtr[tokenIdx]); + + if (slot == -1) { + this->valid_ = false; + return; + } + + // for the page tensor + int64_t pagedIdxOffset = slot * this->hiddenDims_; + + // for the lmc tensor + int64_t nonPagedKeyOffset = this->layerIdx_ * this->numTokens_ * this->hiddenDims_ + + tokenIdx * this->hiddenDims_; + + // values are stored after keys in the non-paged tensor + int64_t nonPagedValueOffset = this->numLayers_ * this->numTokens_ * this->hiddenDims_ + + this->layerIdx_ * this->numTokens_ * this->hiddenDims_ + + tokenIdx * this->hiddenDims_; + + // keys + this->keyTokensGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ scalar_t*>(pagedKeyTensor) + pagedIdxOffset, + this->hiddenDims_); + this->lmcBufferKeyGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ scalar_t*>(nonPagedTensor) + nonPagedKeyOffset, + this->hiddenDims_); + // values + this->valueTokensGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ scalar_t*>(pagedValueTensor) + pagedIdxOffset, + this->hiddenDims_); + this->lmcBufferValueGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ scalar_t*>(nonPagedTensor) + nonPagedValueOffset, + this->hiddenDims_); + } + + __aicore__ inline void processFunc() { + if (!this->valid_) { + return; + } + // 1. Alloc Tensor for local page + local_scalar_t hiddenKeysDimTensor = this->pagedTokenQue_.template AllocTensor(); + local_scalar_t hiddenValuesDimTensor = this->pagedTokenQue_.template AllocTensor();; + + // 2. copy from global tensor into local (GM -> UB) + if (this->page2L_) { + AscendC::DataCopy(hiddenKeysDimTensor, this->keyTokensGlobal_, this->hiddenDims_); + AscendC::DataCopy(hiddenValuesDimTensor, this->valueTokensGlobal_, this->hiddenDims_); + } else { + AscendC::DataCopy(hiddenKeysDimTensor, this->lmcBufferKeyGlobal_, this->hiddenDims_); + AscendC::DataCopy(hiddenValuesDimTensor, this->lmcBufferValueGlobal_, this->hiddenDims_); + } + + // 3. enque vecin + pagedTokenQue_.EnQue(hiddenKeysDimTensor); + pagedTokenQue_.EnQue(hiddenValuesDimTensor); + + // 4. deque vecin, possible to reuse due to QueBind + hiddenKeysDimTensor = pagedTokenQue_.DeQue(); + hiddenValuesDimTensor = pagedTokenQue_.DeQue(); + + // 5. datacopy into GM + if (this->page2L_) { + AscendC::DataCopy(this->lmcBufferKeyGlobal_, hiddenKeysDimTensor, this->hiddenDims_); + AscendC::DataCopy(this->lmcBufferValueGlobal_, hiddenValuesDimTensor, this->hiddenDims_); + } else { + AscendC::DataCopy(this->keyTokensGlobal_, hiddenKeysDimTensor, this->hiddenDims_); + AscendC::DataCopy(this->valueTokensGlobal_, hiddenValuesDimTensor, this->hiddenDims_); + } + // 6. free alloced Tensor + pagedTokenQue_.FreeTensor(hiddenKeysDimTensor); + pagedTokenQue_.FreeTensor(hiddenValuesDimTensor); + } + +private: + AscendC::TPipe *pipe_; + AscendC::TQueBind pagedTokenQue_; + + // [numPages, pagedSize, heads*headsize] + AscendC::GlobalTensor keyTokensGlobal_; + AscendC::GlobalTensor valueTokensGlobal_; + + // Depends on LMC setting whether we store in tokensMajor or not. 
+ // the layout would be the followings: + // [tokens, kvs, heads*headsize] or [kvs, tokens, heads*headsize] + // TODO: check whether should combine the two and use a loop + AscendC::GlobalTensor lmcBufferKeyGlobal_; + AscendC::GlobalTensor lmcBufferValueGlobal_; + + int64_t numPages_; // num vllm npu blocks + int32_t pagedSize_; // per npu block tokens + int64_t hiddenDims_; // heads * headsize + int32_t numTokens_; // num tokens in the cache tensor chunk + int32_t numLayers_; // num layers in the cache tensor + int32_t layerIdx_; // layer idx in the cache tensor + bool valid_; + bool page2L_; // true, from pagedTensor to LMC, false otherwise +}; + +#define LOAD_AND_RESHAPE_FLASH_COPY_TYPE_DECLARE(TYPE, SLOTTYPE) \ + extern "C" __global__ __aicore__ void load_and_reshape_flash_copy_##TYPE##_##SLOTTYPE( \ + __gm__ uint8_t* dstCacheTensor, __gm__ uint8_t* keyCachePtr, __gm__ uint8_t* valueCachePtr, \ + __gm__ uint8_t* slotmappings, const int64_t hiddenDims, const int64_t numPages, const int32_t pagedSize, \ + const int32_t numTokens, const int32_t numLayers, const int32_t layerIdx, const bool page2L, \ + const int blockNum) \ + { \ + AscendC::TPipe pipe; \ + LoadAndReshapeFlashCopy op{}; \ + op.init(dstCacheTensor, keyCachePtr, valueCachePtr, slotmappings, numPages, hiddenDims, pagedSize, \ + numTokens, numLayers, layerIdx, page2L, &pipe); \ + int64_t bIdx = AscendC::GetBlockIdx(); \ + for (int64_t i = bIdx; i < numTokens; i+=blockNum) \ + { \ + op.reset(); \ + op.updateTensorMemOffsetAndProcess(keyCachePtr, valueCachePtr, dstCacheTensor, slotmappings, i); \ + op.processFunc(); \ + } \ + } + +// Declare support kernel entry +LOAD_AND_RESHAPE_FLASH_COPY_TYPE_DECLARE(half, int32_t); +LOAD_AND_RESHAPE_FLASH_COPY_TYPE_DECLARE(half, int64_t); +LOAD_AND_RESHAPE_FLASH_COPY_TYPE_DECLARE(bfloat16_t, int32_t); +LOAD_AND_RESHAPE_FLASH_COPY_TYPE_DECLARE(bfloat16_t, int64_t); +LOAD_AND_RESHAPE_FLASH_COPY_TYPE_DECLARE(int8_t, int32_t); +LOAD_AND_RESHAPE_FLASH_COPY_TYPE_DECLARE(int8_t, int64_t); + +namespace kvcache_ops { + +#define LOAD_AND_RESHAPE_FLASH_COPY_KERNEL_CALL(TYPE, SLOTTYPE) \ + load_and_reshape_flash_copy_##TYPE##_##SLOTTYPE<<>>(dstCacheTensor, keyCachePtr, \ + valueCachePtr, slotmappings, hiddenDims, numPages, pagedSize, \ + numTokens, numLayers, layerIdx, page2L, blockDim); + +template +void load_and_reshape_kernel_call(uint32_t blockDim, void *stream, uint8_t *dstCacheTensor, uint8_t *keyCachePtr, + uint8_t *valueCachePtr, uint8_t *slotmappings, const int64_t hiddenDims, const int64_t numPages, + const int32_t pagedSize, const int32_t numTokens, const int32_t numLayers, + const int32_t layerIdx, const bool page2L); + + +#define LOAD_AND_RESHAPE_KERNEL_CALL_TYPE_DECLARE(TYPE, SLOTTYPE) \ +template<> \ +void load_and_reshape_kernel_call(uint32_t blockDim, void *stream, uint8_t *dstCacheTensor, \ + uint8_t *keyCachePtr, uint8_t *valueCachePtr, uint8_t *slotmappings, \ + const int64_t hiddenDims, const int64_t numPages, \ + const int32_t pagedSize, const int32_t numTokens, \ + const int32_t numLayers, const int32_t layerIdx, \ + const bool page2L) { \ + LOAD_AND_RESHAPE_FLASH_COPY_KERNEL_CALL(TYPE, SLOTTYPE); \ +} + +LOAD_AND_RESHAPE_KERNEL_CALL_TYPE_DECLARE(half, int32_t); +LOAD_AND_RESHAPE_KERNEL_CALL_TYPE_DECLARE(half, int64_t); +LOAD_AND_RESHAPE_KERNEL_CALL_TYPE_DECLARE(bfloat16_t, int32_t); +LOAD_AND_RESHAPE_KERNEL_CALL_TYPE_DECLARE(bfloat16_t, int64_t); +LOAD_AND_RESHAPE_KERNEL_CALL_TYPE_DECLARE(int8_t, int32_t); +LOAD_AND_RESHAPE_KERNEL_CALL_TYPE_DECLARE(int8_t, int64_t); + +template 
+void dispatch_on_slot_type(kvcache_ops::AscendType slotType, uint32_t blockDim, void *stream, + uint8_t *dstCacheTensor, uint8_t *keyCachePtr, uint8_t *valueCachePtr, + uint8_t *slotmappings, const int64_t hiddenDims, const int64_t numPages, + const int32_t pagedSize, const int32_t numTokens, const int32_t numLayers, + const int32_t layerIdx, const bool page2L) { + switch(slotType) { + case kvcache_ops::AscendType::INT32: + load_and_reshape_kernel_call(blockDim, stream, dstCacheTensor, keyCachePtr, valueCachePtr, + slotmappings, hiddenDims, numPages, pagedSize, numTokens, numLayers, layerIdx, + page2L); + break; + case kvcache_ops::AscendType::INT64: + load_and_reshape_kernel_call(blockDim, stream, dstCacheTensor, keyCachePtr, valueCachePtr, + slotmappings, hiddenDims, numPages, pagedSize, numTokens, numLayers, layerIdx, + page2L); + break; + default: + return; + } +} + +extern void load_and_reshape_flash_kernel(kvcache_ops::AscendType type, kvcache_ops::AscendType slotType, + uint32_t blockDim, void *stream, + uint8_t *dstCacheTensor, uint8_t *keyCachePtr, uint8_t *valueCachePtr, + uint8_t *slotmappings, const int64_t hiddenDims, const int64_t numPages, + const int32_t pagedSize, const int32_t numTokens, const int32_t numLayers, + const int32_t layerIdx, bool page2L) +{ + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY); + + switch(type) { + case kvcache_ops::AscendType::FP16: + dispatch_on_slot_type(slotType, blockDim, stream, dstCacheTensor, keyCachePtr, valueCachePtr, + slotmappings, hiddenDims, numPages, pagedSize, numTokens, numLayers, layerIdx, + page2L); + break; + case kvcache_ops::AscendType::BF16: + dispatch_on_slot_type(slotType, blockDim, stream, dstCacheTensor, keyCachePtr, valueCachePtr, + slotmappings, hiddenDims, numPages, pagedSize, numTokens, numLayers, layerIdx, + page2L); + break; + case kvcache_ops::AscendType::INT8: + dispatch_on_slot_type(slotType, blockDim, stream, dstCacheTensor, keyCachePtr, valueCachePtr, + slotmappings, hiddenDims, numPages, pagedSize, numTokens, numLayers, layerIdx, + page2L); + break; + default: + return; + } +} + +} // namespace kvcache_ops diff --git a/kernels/multi_layer_mem_kernels.cpp b/kernels/multi_layer_mem_kernels.cpp new file mode 100644 index 0000000..b0178eb --- /dev/null +++ b/kernels/multi_layer_mem_kernels.cpp @@ -0,0 +1,233 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "kernel_operator.h" +#include +#include "types.h" + +template class MultiLayerPagedKVCopy { + using local_scalar_t = AscendC::LocalTensor; + +public: + __aicore__ inline MultiLayerPagedKVCopy() + { + } + + __aicore__ inline void init(GM_ADDR pagedKVCaches, GM_ADDR cacheTensor, GM_ADDR slotmappings, + const int64_t hiddenDims, const int32_t numLayers, const int64_t pageBuffSize, + const int32_t numTokensChunk, const bool page2L, + AscendC::TPipe *pipe) + { + this->pipe_ = pipe; + this->numLayers_ = numLayers; + this->hiddenDims_ = hiddenDims; + this->pageBuffSize_ = pageBuffSize; + this->numTokensChunk_ = numTokensChunk; + this->page2L_ = page2L; + this->valid_ = true; + + this->pipe_->InitBuffer(pagedTokenQue_, 4, this->hiddenDims_*sizeof(scalar_t)); + } + + __aicore__ inline void reset(){ + this->valid_ = true; + } + + __aicore__ inline void updateMemOffset(__gm__ uint8_t *pagedKVCaches, __gm__ uint8_t* cacheTensor, + __gm__ uint8_t *slotmappings, const int tokenIdx, + const int kvIdx, const int layerIdx) + { + __gm__ slot_t *slotmappingPtr = reinterpret_cast<__gm__ slot_t*>(slotmappings); + int64_t slot = static_cast(slotmappingPtr[tokenIdx]); + + if (slot == -1) { + this->valid_ = false; + return; + } + + // its a pointer within the GM addr space, that point to another GM addr space + __gm__ uint8_t * __gm__ *pagedKVCachesPtr = reinterpret_cast<__gm__ uint8_t* __gm__ *>(pagedKVCaches); + + // getting the right ptr to the paged kvcache layer + __gm__ uint8_t *pagedLayerKVCaches = pagedKVCachesPtr[layerIdx]; + + int64_t pagedIdxOffset = kvIdx * this->pageBuffSize_ * this->hiddenDims_ + + slot * this->hiddenDims_; + + int64_t dstTensorIdxOffset = kvIdx * this->numLayers_ * this->numTokensChunk_ * this->hiddenDims_ + + layerIdx * this->numTokensChunk_ * this->hiddenDims_ + + tokenIdx * this->hiddenDims_; + + this->pagedTokenGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ scalar_t*>(pagedLayerKVCaches) + pagedIdxOffset, + this->hiddenDims_); + this->lmcBufferGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ scalar_t*>(cacheTensor) + dstTensorIdxOffset, + this->hiddenDims_); + } + + __aicore__ inline void processFunc() { + if (!this->valid_) { + return; + } + // 1. Alloc Tensor for local page + local_scalar_t hiddenDimTensor = pagedTokenQue_.AllocTensor(); + + // 2. copy from global tensor into local + if (this->page2L_) { + AscendC::DataCopy(hiddenDimTensor, this->pagedTokenGlobal_, this->hiddenDims_); + } else { + AscendC::DataCopy(hiddenDimTensor, this->lmcBufferGlobal_, this->hiddenDims_); + } + + // 3. enque vecin + pagedTokenQue_.EnQue(hiddenDimTensor); + // 4. deque vecin, possible to reuse due to QueBind + hiddenDimTensor = pagedTokenQue_.DeQue(); + + // 5. datacopy into GM + if (this->page2L_) { + AscendC::DataCopy(this->lmcBufferGlobal_, hiddenDimTensor, this->hiddenDims_); + } else { + AscendC::DataCopy(this->pagedTokenGlobal_, hiddenDimTensor, this->hiddenDims_); + } + + // 6. 
free alloced Tensor + pagedTokenQue_.FreeTensor(hiddenDimTensor); + } + + +private: + AscendC::TPipe *pipe_; + AscendC::TQueBind pagedTokenQue_; + + // [layers * [kvs, numPages * pagedSize, heads*headsize]] + AscendC::GlobalTensor pagedTokenGlobal_; + // [kvs, layers, numTokensChunk, heads*headsize] + AscendC::GlobalTensor lmcBufferGlobal_; + int32_t numLayers_; // num layers + int64_t pageBuffSize_; // pages * pageSize + int64_t hiddenDims_; // heads * headSize + int32_t numTokensChunk_; // num tokens in the cache tensor chunk + bool valid_; + bool page2L_; // true, from pagedTensor to LMC, false otherwise +}; + +// NOTE: there are potential micro optimizaiton here. +#define MULTI_LAYER_PAGED_KV_COPY_TYPE_DECLARE(TYPE, SLOTTYPE) \ + extern "C" __global__ __aicore__ void multi_layer_paged_kv_copy_##TYPE##_##SLOTTYPE( \ + __gm__ uint8_t* pagedKVCaches, __gm__ uint8_t* dstCacheTensor, __gm__ uint8_t* slotmappings, \ + const int64_t hiddenDims, const int32_t kvs, const int32_t numLayers, \ + const int64_t pageBuffSize, const int32_t numTokensChunk, const int coreNum, const bool page2L) \ + { \ + AscendC::TPipe pipe; \ + MultiLayerPagedKVCopy op{}; \ + op.init(pagedKVCaches, dstCacheTensor, slotmappings, hiddenDims, \ + numLayers, pageBuffSize, numTokensChunk, page2L, &pipe); \ + int64_t bIdx = AscendC::GetBlockIdx(); \ + for (int64_t i = bIdx; i < numTokensChunk; i+=coreNum) { \ + for (int32_t kvIdx = 0; kvIdx < kvs; kvIdx ++) { \ + for (int32_t layerIdx = 0; layerIdx < numLayers; layerIdx++) { \ + op.reset(); \ + op.updateMemOffset(pagedKVCaches, dstCacheTensor, slotmappings, i, kvIdx, layerIdx); \ + op.processFunc(); \ + } \ + } \ + } \ + } + +// Declare support kernel entry +MULTI_LAYER_PAGED_KV_COPY_TYPE_DECLARE(half, int32_t); +MULTI_LAYER_PAGED_KV_COPY_TYPE_DECLARE(half, int64_t); +MULTI_LAYER_PAGED_KV_COPY_TYPE_DECLARE(bfloat16_t, int32_t); +MULTI_LAYER_PAGED_KV_COPY_TYPE_DECLARE(bfloat16_t, int64_t); +MULTI_LAYER_PAGED_KV_COPY_TYPE_DECLARE(int8_t, int32_t); +MULTI_LAYER_PAGED_KV_COPY_TYPE_DECLARE(int8_t, int64_t); + +namespace kvcache_ops { + +#define MULTI_LAYER_PAGED_KV_COPY_KERNEL_CALL(TYPE, SLOTTYPE) \ + multi_layer_paged_kv_copy_##TYPE##_##SLOTTYPE<<>>(pagedKVCaches, dstCacheTensor, \ + slotmappings, hiddenDims, kvs, \ + numLayers, pageBuffSize, \ + numTokensChunk, blockDim, page2L); + +template +void multi_layer_paged_kernel(uint32_t blockDim, void *stream, uint8_t *pagedKVCaches, uint8_t *dstCacheTensor, + uint8_t *slotmappings, const int64_t hiddenDims, const int32_t kvs, const int32_t numLayers, + const int64_t pageBuffSize, const int32_t numTokensChunk, const bool page2L); + +#define MULTI_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(TYPE, SLOTTYPE) \ +template<> \ +void multi_layer_paged_kernel(uint32_t blockDim, void *stream, uint8_t *pagedKVCaches, \ + uint8_t *dstCacheTensor, uint8_t *slotmappings, \ + const int64_t hiddenDims, const int32_t kvs, const int32_t numLayers, \ + const int64_t pageBuffSize, const int32_t numTokensChunk, \ + const bool page2L){ \ + MULTI_LAYER_PAGED_KV_COPY_KERNEL_CALL(TYPE, SLOTTYPE); \ +} + +MULTI_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(half, int32_t); +MULTI_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(half, int64_t); +MULTI_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(bfloat16_t, int32_t); +MULTI_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(bfloat16_t, int64_t); +MULTI_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(int8_t, int32_t); +MULTI_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(int8_t, int64_t); + +template +void dispatch_paged_kernel_on_slot_type(kvcache_ops::AscendType slotType, 
uint32_t blockDim, void *stream, + uint8_t *pagedKVCaches, uint8_t *dstCacheTensor, uint8_t *slotmappings, + const int64_t hiddenDims, const int32_t kvs, const int32_t numLayers, + const int64_t pageBuffSize, const int32_t numTokensChunk, const bool page2L) { + switch(slotType) { + case kvcache_ops::AscendType::INT32: + multi_layer_paged_kernel(blockDim, stream, pagedKVCaches, dstCacheTensor, slotmappings, + hiddenDims, kvs, numLayers, pageBuffSize, numTokensChunk, page2L); + break; + case kvcache_ops::AscendType::INT64: + multi_layer_paged_kernel(blockDim, stream, pagedKVCaches, dstCacheTensor, slotmappings, + hiddenDims, kvs, numLayers, pageBuffSize, numTokensChunk, page2L); + break; + default: + return; + } +} + +extern void multi_layer_kv_transfer_kernel(kvcache_ops::AscendType type, kvcache_ops::AscendType slotType, + uint32_t blockDim, void *stream, uint8_t *pagedKVCaches, + uint8_t *dstCacheTensor, uint8_t *slotmappings, + const int64_t hiddenDims, const int32_t kvs, const int32_t numLayers, + const int64_t pageBuffSize, const int32_t numTokensChunk, const bool page2L) +{ + switch(type) { + case kvcache_ops::AscendType::FP16: + dispatch_paged_kernel_on_slot_type(slotType, blockDim, stream, pagedKVCaches, dstCacheTensor, + slotmappings, hiddenDims, kvs, numLayers, pageBuffSize, + numTokensChunk, page2L); + break; + case kvcache_ops::AscendType::BF16: + dispatch_paged_kernel_on_slot_type(slotType, blockDim, stream, pagedKVCaches, dstCacheTensor, + slotmappings, hiddenDims, kvs, numLayers, pageBuffSize, + numTokensChunk, page2L); + break; + case kvcache_ops::AscendType::INT8: + dispatch_paged_kernel_on_slot_type(slotType, blockDim, stream, pagedKVCaches, dstCacheTensor, + slotmappings, hiddenDims, kvs, numLayers, pageBuffSize, + numTokensChunk, page2L); + break; + default: + return; + } +} + +} // namespace kvcache_ops \ No newline at end of file diff --git a/kernels/single_layer_mem_kernels.cpp b/kernels/single_layer_mem_kernels.cpp new file mode 100644 index 0000000..dd75af6 --- /dev/null +++ b/kernels/single_layer_mem_kernels.cpp @@ -0,0 +1,309 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "kernel_operator.h" +#include +#include "types.h" + +template class SingleLayerPagedKVCopy { + using local_scalar_t = AscendC::LocalTensor; + +public: + __aicore__ inline SingleLayerPagedKVCopy() + { + } + + __aicore__ inline void init(GM_ADDR cacheTensor, GM_ADDR keyCachePtr, GM_ADDR valueCachePtr, GM_ADDR slotmappings, + const int64_t hiddenDims, const int32_t numTokens, const bool page2L, + const bool tokenMajor, AscendC::TPipe *pipe) + { + this->pipe_ = pipe; + this->hiddenDims_ = hiddenDims; + this->numTokens_ = numTokens; + this->tokenMajor_ = tokenMajor; + this->valid_ = true; + this->page2L_ = page2L; + if constexpr (IsMLA) { + this->numKvs_ = 1; + } else { + this->numKvs_ = 2; + } + // TODO: Not sure how many to allocate, but let's do 4 blocks of hiddenDims_ + // if it was fp16, 2048, we would get 16kb ? + this->pipe_->InitBuffer(this->pagedTokenQue_, 4, this->hiddenDims_*sizeof(scalar_t)); + } + + __aicore__ inline void reset(){ + this->valid_ = true; + } + + __aicore__ inline void updateTensorMemOffsetAndProcess(__gm__ uint8_t *pagedTensor, __gm__ uint8_t* nonPagedTensor, + __gm__ uint8_t *slotmappings, const int tokenIdx, const int kvIdx) + { + __gm__ slot_t *slotmappingPtr = reinterpret_cast<__gm__ slot_t*>(slotmappings); + int64_t slot = slotmappingPtr[tokenIdx]; + + if (slot == -1) { + this->valid_ = false; + return; + } + + // for the page tensor + int64_t pagedIdxOffset = slot * this->hiddenDims_; + + // for the lmc tensor + int64_t nonPagedIdxOffset = -1; + if (this->tokenMajor_) { + nonPagedIdxOffset = tokenIdx * this->numKvs_ * this->hiddenDims_ + + kvIdx * this->hiddenDims_; + } else { + nonPagedIdxOffset = kvIdx * this->numTokens_ * this -> hiddenDims_ + + tokenIdx * this->hiddenDims_; + } + + if (kvIdx == 0) { + // keys + this->keyTokensGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ scalar_t*>(pagedTensor) + pagedIdxOffset, + this->hiddenDims_); + this->lmcBufferKeyGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ scalar_t*>(nonPagedTensor) + nonPagedIdxOffset, + this->hiddenDims_); + } else { + // values + this->valueTokensGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ scalar_t*>(pagedTensor) + pagedIdxOffset, + this->hiddenDims_); + this->lmcBufferValueGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ scalar_t*>(nonPagedTensor) + nonPagedIdxOffset, + this->hiddenDims_); + } + } + + __aicore__ inline void processFunc() { + if (!this->valid_) { + return; + } + // 1. Alloc Tensor for local page + local_scalar_t hiddenKeysDimTensor = this->pagedTokenQue_.template AllocTensor(); + local_scalar_t hiddenValuesDimTensor; + if constexpr(!IsMLA) { + hiddenValuesDimTensor = this->pagedTokenQue_.template AllocTensor(); + } + + // 2. copy from global tensor into local + if (this->page2L_) { + AscendC::DataCopy(hiddenKeysDimTensor, this->keyTokensGlobal_, this->hiddenDims_); + if constexpr (!IsMLA) { + AscendC::DataCopy(hiddenValuesDimTensor, this->valueTokensGlobal_, this->hiddenDims_); + } + } else { + AscendC::DataCopy(hiddenKeysDimTensor, this->lmcBufferKeyGlobal_, this->hiddenDims_); + if constexpr(!IsMLA) { + AscendC::DataCopy(hiddenValuesDimTensor, this->lmcBufferValueGlobal_, this->hiddenDims_); + } + } + + // 3. enque vecin + pagedTokenQue_.EnQue(hiddenKeysDimTensor); + if constexpr(!IsMLA) { + pagedTokenQue_.EnQue(hiddenValuesDimTensor); + } + + // 4. deque vecin, possible to reuse due to QueBind + hiddenKeysDimTensor = pagedTokenQue_.DeQue(); + if constexpr(!IsMLA) { + hiddenValuesDimTensor = pagedTokenQue_.DeQue(); + } + + // 5. 
datacopy into GM + if (this->page2L_) { + AscendC::DataCopy(this->lmcBufferKeyGlobal_, hiddenKeysDimTensor, this->hiddenDims_); + if constexpr(!IsMLA) { + AscendC::DataCopy(this->lmcBufferValueGlobal_, hiddenValuesDimTensor, this->hiddenDims_); + } + } else { + AscendC::DataCopy(this->keyTokensGlobal_, hiddenKeysDimTensor, this->hiddenDims_); + if constexpr(!IsMLA) { + AscendC::DataCopy(this->valueTokensGlobal_, hiddenValuesDimTensor, this->hiddenDims_); + } + } + + // 6. free alloced Tensor + pagedTokenQue_.FreeTensor(hiddenKeysDimTensor); + if constexpr(!IsMLA) { + pagedTokenQue_.FreeTensor(hiddenValuesDimTensor); + } + } + +private: + AscendC::TPipe *pipe_; + // a depth of 2 + AscendC::TQueBind pagedTokenQue_; + + // [kvs, numPages * pagedSize, heads*headsize] + AscendC::GlobalTensor keyTokensGlobal_; + // iff !isMLA + AscendC::GlobalTensor valueTokensGlobal_; + + // Depends on LMC setting whether we store in tokensMajor or not. + // the layout would be the followings: + // [tokens, kvs, heads*headsize] or [kvs, tokens, heads*headsize] + // TODO: check whether should combine the two and use a loop + AscendC::GlobalTensor lmcBufferKeyGlobal_; + AscendC::GlobalTensor lmcBufferValueGlobal_; + + int64_t hiddenDims_; // heads * headsize + int32_t numTokens_; // num tokens in the cache tensor chunk + int16_t numKvs_; // 1 if MLA else 2 + bool page2L_; // whether the direction of copy is from page to lmc + bool tokenMajor_; // whether the lmc buffer is in token major. + bool valid_; +}; + +#define SINGLE_LAYER_PAGED_KV_COPY_TYPE_DECLARE(TYPE, SLOTTYPE, ISMLA) \ + extern "C" __global__ __aicore__ void single_layer_paged_kv_copy_##TYPE##_##SLOTTYPE##_##ISMLA( \ + __gm__ uint8_t* dstCacheTensor, __gm__ uint8_t* keyCachePtr, __gm__ uint8_t* valueCachePtr, \ + __gm__ uint8_t* slotmappings, const int64_t hiddenDims, const int32_t numTokens, const int coreNums, \ + const bool page2L, const bool tokenMajor) \ + { \ + AscendC::TPipe pipe; \ + SingleLayerPagedKVCopy op{}; \ + op.init(dstCacheTensor, keyCachePtr, valueCachePtr, slotmappings, hiddenDims, numTokens, \ + page2L, tokenMajor, &pipe); \ + int64_t bIdx = AscendC::GetBlockIdx(); \ + for (int64_t i = bIdx; i < numTokens; i+=coreNums) \ + { \ + op.reset(); \ + op.updateTensorMemOffsetAndProcess(keyCachePtr, dstCacheTensor, slotmappings, i, 0); \ + if constexpr(!ISMLA) { \ + op.updateTensorMemOffsetAndProcess(valueCachePtr, dstCacheTensor, slotmappings, i, 1); \ + } \ + op.processFunc(); \ + } \ + } + +// Declare support kernel entry +SINGLE_LAYER_PAGED_KV_COPY_TYPE_DECLARE(half, int32_t, false); +SINGLE_LAYER_PAGED_KV_COPY_TYPE_DECLARE(half, int32_t, true); +SINGLE_LAYER_PAGED_KV_COPY_TYPE_DECLARE(bfloat16_t, int32_t, false); +SINGLE_LAYER_PAGED_KV_COPY_TYPE_DECLARE(bfloat16_t, int32_t, true); +SINGLE_LAYER_PAGED_KV_COPY_TYPE_DECLARE(int8_t, int32_t, false); +SINGLE_LAYER_PAGED_KV_COPY_TYPE_DECLARE(int8_t, int32_t, true); + +SINGLE_LAYER_PAGED_KV_COPY_TYPE_DECLARE(half, int64_t, false); +SINGLE_LAYER_PAGED_KV_COPY_TYPE_DECLARE(half, int64_t, true); +SINGLE_LAYER_PAGED_KV_COPY_TYPE_DECLARE(bfloat16_t, int64_t, false); +SINGLE_LAYER_PAGED_KV_COPY_TYPE_DECLARE(bfloat16_t, int64_t, true); +SINGLE_LAYER_PAGED_KV_COPY_TYPE_DECLARE(int8_t, int64_t, false); +SINGLE_LAYER_PAGED_KV_COPY_TYPE_DECLARE(int8_t, int64_t, true); + +namespace kvcache_ops { + +#define SINGLE_LAYER_PAGED_KV_COPY_KERNEL_CALL(TYPE, SLOTTYPE, ISMLA) \ + single_layer_paged_kv_copy_##TYPE##_##SLOTTYPE##_##ISMLA<<>>(dstCacheTensor, \ + keyCachePtr, valueCachePtr, slotmappings, hiddenDims, 
\ + numTokens, blockDim, page2L, tokenMajor); + +template +void single_layer_paged_kernel(uint32_t blockDim, void *stream, uint8_t *dstCacheTensor, uint8_t *keyCachePtr, + uint8_t *valueCachePtr, uint8_t *slotmappings, const int64_t hiddenDims, + const int32_t numTokens, const bool page2L, const bool tokenMajor); + +#define SINGLE_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(TYPE, SLOTTYPE, ISMLA) \ +template<> \ +void single_layer_paged_kernel(uint32_t blockDim, void *stream, uint8_t *dstCacheTensor, \ + uint8_t *keyCachePtr, uint8_t *valueCachePtr, uint8_t *slotmappings, \ + const int64_t hiddenDims, const int32_t numTokens, const bool page2L, \ + const bool tokenMajor){ \ + SINGLE_LAYER_PAGED_KV_COPY_KERNEL_CALL(TYPE, SLOTTYPE, ISMLA); \ +} + + +SINGLE_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(half, int32_t, false); +SINGLE_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(half, int64_t, false); +SINGLE_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(bfloat16_t, int32_t, false); +SINGLE_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(bfloat16_t, int64_t, false); +SINGLE_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(int8_t, int32_t, false); +SINGLE_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(int8_t, int64_t, false); + +SINGLE_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(half, int32_t, true); +SINGLE_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(half, int64_t, true); +SINGLE_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(bfloat16_t, int32_t, true); +SINGLE_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(bfloat16_t, int64_t, true); +SINGLE_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(int8_t, int32_t, true); +SINGLE_LAYER_PAGED_KERNEL_CALL_TYPE_DECLARE(int8_t, int64_t, true); + + + +template +void dispatch_single_layer_kernel_on_slot_type(kvcache_ops::AscendType slotType, uint32_t blockDim, void *stream, + uint8_t *dstCacheTensor, uint8_t *keyCachePtr, uint8_t *valueCachePtr, + uint8_t *slotmappings, const int64_t hiddenDims, const int32_t numTokens, + const bool page2L, const bool tokenMajor, const bool isMLA) { + if (isMLA) { + switch(slotType) { + case kvcache_ops::AscendType::INT32: + single_layer_paged_kernel(blockDim, stream, dstCacheTensor, keyCachePtr, valueCachePtr, + slotmappings, hiddenDims, numTokens, page2L, tokenMajor); + break; + case kvcache_ops::AscendType::INT64: + single_layer_paged_kernel(blockDim, stream, dstCacheTensor, keyCachePtr, valueCachePtr, + slotmappings, hiddenDims, numTokens, page2L, tokenMajor); + break; + default: + return; + } + } else { + switch(slotType) { + case kvcache_ops::AscendType::INT32: + single_layer_paged_kernel(blockDim, stream, dstCacheTensor, keyCachePtr, valueCachePtr, + slotmappings, hiddenDims, numTokens, page2L, tokenMajor); + break; + case kvcache_ops::AscendType::INT64: + single_layer_paged_kernel(blockDim, stream, dstCacheTensor, keyCachePtr, valueCachePtr, + slotmappings, hiddenDims, numTokens, page2L, tokenMajor); + break; + default: + return; + } + } + +} + + +extern void single_layer_kv_transfer_kernel(kvcache_ops::AscendType type, kvcache_ops::AscendType slotType, + uint32_t blockDim, void *stream, uint8_t *dstCacheTensor, + uint8_t *keyCachePtr, uint8_t *valueCachePtr, + uint8_t *slotmappings, const int64_t hiddenDims, const int32_t numTokens, + const bool page2L, const bool tokenMajor, const bool isMLA) +{ + switch(type) { + case kvcache_ops::AscendType::FP16: + dispatch_single_layer_kernel_on_slot_type(slotType, blockDim, stream, dstCacheTensor, keyCachePtr, + valueCachePtr, slotmappings, hiddenDims, numTokens, page2L, + tokenMajor, isMLA); + break; + case kvcache_ops::AscendType::BF16: + 
dispatch_single_layer_kernel_on_slot_type(slotType, blockDim, stream, dstCacheTensor, keyCachePtr, + valueCachePtr, slotmappings, hiddenDims, numTokens, + page2L, tokenMajor, isMLA); + break; + case kvcache_ops::AscendType::INT8: + dispatch_single_layer_kernel_on_slot_type(slotType, blockDim, stream, dstCacheTensor, keyCachePtr, + valueCachePtr, slotmappings, hiddenDims, numTokens, page2L, + tokenMajor, isMLA); + default: + return; + } +} + +} // namespace kvcache_ops \ No newline at end of file diff --git a/kernels/types.h b/kernels/types.h new file mode 100644 index 0000000..e2e9433 --- /dev/null +++ b/kernels/types.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace kvcache_ops { +enum struct AscendType { + FP16 = 0, + BF16 = 1, + FP32 = 2, + INT8 = 3, + INT32 = 4, + INT64 = 5, +}; +} \ No newline at end of file diff --git a/npu_lib.cmake b/npu_lib.cmake new file mode 100644 index 0000000..53f6e40 --- /dev/null +++ b/npu_lib.cmake @@ -0,0 +1,11 @@ +if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake) + set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake) +elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake) + set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake) +else() + message(FATAL_ERROR "ascendc_kernel_cmake does not exist ,please check whether the cann package is installed") +endif() +include(${ASCENDC_CMAKE_DIR}/ascendc.cmake) + +# ascendc_library use to add kernel file to generate ascendc library +ascendc_library(cache_kernels SHARED ${KERNEL_FILES}) -- Gitee From 9db2994811671ccf30567c0f71bdd4aa549ecf49 Mon Sep 17 00:00:00 2001 From: gfmyeung Date: Mon, 11 Aug 2025 22:55:18 +0800 Subject: [PATCH 2/3] update license copyright --- kernels/types.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernels/types.h b/kernels/types.h index e2e9433..31fd173 100644 --- a/kernels/types.h +++ b/kernels/types.h @@ -1,5 +1,5 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. -- Gitee From 1e0a7d10b8c3c1f3014a2de852d5a2a247dc2fc1 Mon Sep 17 00:00:00 2001 From: gfmyeung Date: Mon, 11 Aug 2025 22:55:45 +0800 Subject: [PATCH 3/3] add .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..600d2d3 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.vscode \ No newline at end of file -- Gitee
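
As a rough illustration of how the operators added by this series might be consumed, the sketch below drives the single-layer transfer entry point from host code. The prototype mirrors the extern declaration in kernels/single_layer_mem_kernels.cpp and the enum comes from kernels/types.h; everything else (the local re-declaration, buffer shapes, slot values, and the aclrt setup) is assumed for illustration, since the series does not add a public header or a host-side example. Error handling is omitted.

#include <cstdint>
#include <vector>
#include "acl/acl.h"
#include "kernels/types.h"   // AscendType; the repo root is exported as an include dir by CMakeLists.txt

namespace kvcache_ops {
// Re-declared locally for this sketch; the definition lives in the cache_kernels library.
extern void single_layer_kv_transfer_kernel(AscendType type, AscendType slotType,
                                            uint32_t blockDim, void *stream,
                                            uint8_t *dstCacheTensor, uint8_t *keyCachePtr,
                                            uint8_t *valueCachePtr, uint8_t *slotmappings,
                                            const int64_t hiddenDims, const int32_t numTokens,
                                            const bool page2L, const bool tokenMajor,
                                            const bool isMLA);
}

int main()
{
    // Hypothetical shapes: 8 heads * 128 head size, 16 tokens to offload,
    // 1024 paged slots (numPages * pagedSize) in the device KV cache.
    const int64_t hiddenDims = 8 * 128;
    const int32_t numTokens = 16;
    const int64_t pagedSlots = 1024;
    const size_t elemBytes = sizeof(uint16_t);  // fp16

    aclInit(nullptr);
    aclrtSetDevice(0);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    // Device buffers: paged key/value caches, the contiguous offload chunk, slot mapping.
    void *keyCache = nullptr, *valueCache = nullptr, *lmcChunk = nullptr, *slots = nullptr;
    aclrtMalloc(&keyCache, pagedSlots * hiddenDims * elemBytes, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(&valueCache, pagedSlots * hiddenDims * elemBytes, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(&lmcChunk, 2 * numTokens * hiddenDims * elemBytes, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(&slots, numTokens * sizeof(int32_t), ACL_MEM_MALLOC_HUGE_FIRST);

    // Slot mapping: token i sits in paged slot i; a -1 entry makes the kernel skip that token.
    std::vector<int32_t> hostSlots(numTokens);
    for (int32_t i = 0; i < numTokens; ++i) {
        hostSlots[i] = i;
    }
    aclrtMemcpy(slots, numTokens * sizeof(int32_t), hostSlots.data(),
                numTokens * sizeof(int32_t), ACL_MEMCPY_HOST_TO_DEVICE);

    // Paged KV -> offload chunk (page2L = true), non-MLA, kv-major chunk layout.
    kvcache_ops::single_layer_kv_transfer_kernel(
        kvcache_ops::AscendType::FP16, kvcache_ops::AscendType::INT32,
        /*blockDim=*/8, stream,
        static_cast<uint8_t *>(lmcChunk), static_cast<uint8_t *>(keyCache),
        static_cast<uint8_t *>(valueCache), static_cast<uint8_t *>(slots),
        hiddenDims, numTokens, /*page2L=*/true, /*tokenMajor=*/false, /*isMLA=*/false);
    aclrtSynchronizeStream(stream);

    aclrtFree(keyCache);
    aclrtFree(valueCache);
    aclrtFree(lmcChunk);
    aclrtFree(slots);
    aclrtDestroyStream(stream);
    aclrtResetDevice(0);
    aclFinalize();
    return 0;
}

Such an example would be linked against the cache_kernels shared library produced by ascendc_library() in npu_lib.cmake; the target_include_directories() call in CMakeLists.txt makes the kernels/ headers visible to consumers of that target.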