issue/834: add paged attention for nvidia gpu #836
base: main
Conversation
```cpp
 * This function initializes a descriptor that holds all the metadata needed
 * for the paged attention computation.
 *
 * @param handle The handle to the InfiniOP library context.
```
Please annotate each tensor's shape and meaning in the description.
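For reference, a sketch of what the requested annotations might look like; every parameter name below other than `handle` is an assumption inferred from the rest of the diff:

```cpp
/**
 * This function initializes a descriptor that holds all the metadata needed
 * for the paged attention computation.
 *
 * @param handle            The handle to the InfiniOP library context.
 * @param q_desc            Query tensor,      shape [nseq, nh, dh].
 * @param k_cache_desc      Paged key cache,   shape [num_blocks, nkvh, block_size, dh].
 * @param v_cache_desc      Paged value cache, shape [num_blocks, nkvh, block_size, dh].
 * @param block_tables_desc Per-sequence block table, shape [nseq, max_blocks_per_seq].
 * @param seq_lens_desc     Current length of each sequence, shape [nseq].
 */
```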
```cpp
// ----- Input Tensors -----
const Tdata *k_ptr,              // Pointer to the source Keys, shape [ntok, nkvh, dh]
const Tdata *v_ptr,              // Pointer to the source Values, shape [ntok, nkvh, dh]
const int32_t *slot_mapping_ptr, // Pointer to the slot mapping, shape [ntok]
```
Change this to int64, or better yet, use a template.
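One way to follow this suggestion, sketched with a hypothetical kernel name and a template parameter for the index type; the contiguous [ntok, nkvh, dh] layout is an assumption taken from the comments above:

```cpp
#include <cstdint>

// Sketch: template the index type instead of hard-coding int32_t, so an
// int64_t slot mapping works without duplicating the kernel.
// Grid: one block per token; threads stride over that token's elements.
template <typename Tdata, typename Tindex>
__global__ void storeKVCacheKernel(
    Tdata *k_cache, Tdata *v_cache,
    const Tdata *k_ptr,             // shape [ntok, nkvh, dh]
    const Tdata *v_ptr,             // shape [ntok, nkvh, dh]
    const Tindex *slot_mapping_ptr, // shape [ntok]; Tindex = int32_t or int64_t
    int row_elems) {                // nkvh * dh, elements per token
    const Tindex slot = slot_mapping_ptr[blockIdx.x];
    if (slot < 0) return;           // padding token: nothing to store
    for (int i = threadIdx.x; i < row_elems; i += blockDim.x) {
        k_cache[static_cast<int64_t>(slot) * row_elems + i] =
            k_ptr[static_cast<int64_t>(blockIdx.x) * row_elems + i];
        v_cache[static_cast<int64_t>(slot) * row_elems + i] =
            v_ptr[static_cast<int64_t>(blockIdx.x) * row_elems + i];
    }
}
```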
```cpp
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const int32_t *block_tables_,
```
Use a template or a larger data type; int32 overflows easily.
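The overflow risk is in the offset arithmetic: with a large cache, a product like `block_number * kv_block_stride` can exceed `INT32_MAX`. A minimal sketch of the fix, with all names illustrative rather than taken from the patch:

```cpp
#include <cstdint>

// Sketch: widen block-table reads to 64-bit before computing element
// offsets, so block_number * kv_block_stride cannot overflow int32.
__device__ int64_t kvBlockOffset(
    const int32_t *block_tables, // shape [nseq, max_blocks_per_seq]
    int seq_idx, int block_idx, int max_blocks_per_seq,
    int64_t kv_block_stride,     // elements between physical blocks
    int64_t kv_head_stride,      // elements between kv heads
    int kv_head_idx) {
    const int64_t block_number =
        block_tables[seq_idx * max_blocks_per_seq + block_idx];
    return block_number * kv_block_stride + kv_head_idx * kv_head_stride;
}
```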
```cpp
float scale) {

    auto dtype = q_desc->dtype();
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
```
The dtype checks for block_tables_desc and seq_lens_desc are missing.
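A minimal sketch of the missing checks, mirroring the `CHECK_DTYPE` call above; the assumption is that these descriptors hold integer indices/lengths and that the library defines `INFINI_DTYPE_I32`/`INFINI_DTYPE_I64`:

```cpp
// Sketch: validate the index tensors alongside the float dtype check.
CHECK_DTYPE(block_tables_desc->dtype(), INFINI_DTYPE_I32, INFINI_DTYPE_I64);
CHECK_DTYPE(seq_lens_desc->dtype(), INFINI_DTYPE_I32, INFINI_DTYPE_I64);
```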
```cpp
void *stream_) const {
    cudaStream_t stream = (cudaStream_t)stream_;
    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
        if (_info.head_size == 128) {
```
If the computation places requirements on the head dim, check them when creating the operator descriptor.
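A sketch of such a check in the descriptor's create path; the supported set `{64, 128, 256}` and the status code are assumptions, not values taken from the patch:

```cpp
// Sketch: reject unsupported head sizes at descriptor-creation time,
// so callers fail early instead of hitting an unhandled dispatch branch.
if (head_size != 64 && head_size != 128 && head_size != 256) {
    return INFINI_STATUS_BAD_PARAM; // assumed status code for invalid params
}
```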
```cpp
const int seq_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const ptrdiff_t o_stride = q_stride / 3; // qkv
```
This also needs to be changed, right?
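Reading the remark as pointing at the hard-coded `q_stride / 3`, which only holds when Q, K, and V are fused in one buffer, one hedged alternative is to pass the output stride explicitly; the kernel name and surrounding parameters below are illustrative:

```cpp
#include <cstddef>

// Sketch: take o_stride as a kernel argument rather than deriving it
// from q_stride, so non-fused Q/K/V layouts also work.
template <typename Tdata>
__global__ void pagedAttentionKernel(
    Tdata *out,         // shape [nseq, nh, dh]
    const Tdata *q,
    ptrdiff_t q_stride, // elements between query tokens
    ptrdiff_t o_stride, // elements between output tokens, passed explicitly
    int head_size) {
    const int seq_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    Tdata *out_row = out + seq_idx * o_stride
                   + static_cast<ptrdiff_t>(head_idx) * head_size;
    // ... attention result for (seq_idx, head_idx) is written to out_row ...
}
```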
Force-pushed from fb9bd9c to 9695b7f.
Python test screenshot:
