Skip to content

Conversation

@spike-zhu
Copy link
Contributor

python 测试截图:
image

image

* This function initializes a descriptor that holds all the metadata needed
* for the paged attention computation.
*
* @param handle The handle to the InfiniOP library context.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请在描述中标出各张量的形状和含义

// ----- Input Tensors -----
const Tdata *k_ptr, // Pointer to the source Keys, shape [ntok, nkvh, dh]
const Tdata *v_ptr, // Pointer to the source Values, shape [ntok, nkvh, dh]
const int32_t *slot_mapping_ptr, // Pointer to the slot mapping, shape [ntok]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改成int64,或者最好使用template

const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const int32_t *block_tables_,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用template或者更大的数据类型,int32容易溢出

float scale) {

auto dtype = q_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

缺少block_tables_desc、seq_lens_desc的类型检查

void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
if (_info.head_size == 128) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果计算时对head dim有要求,请在创建算子的描述时做检查

const int seq_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const ptrdiff_t o_stride = q_stride / 3; // qkv
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里也要改吧

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants