Add DispatchSegmentedTopK that's flexible enough to cover arbitrary per-segment and uniform parameter combinations #6864

@elstehle

Context

#5673 lists an extensive feature matrix for DeviceSegmentedTopK. We need a DispatchSegmentedTopK layer that is flexible enough to cover this wide range of use cases across different domains while still allowing optimal performance for targeted workloads.

Goals

There are two main goals we want to achieve with a flexible DispatchSegmentedTopK design:

  1. Arbitrary Parameter Mixtures: Support any combination of segment-specific parameters and parameters that are uniform across all segments.
  2. Targeted Optimization: Enable peak performance for workloads with narrow constraints (e.g., a fixed segment size of 512 and K=128).

Segment-Specific vs. Uniform vs. Static Parameters

Segmented Top-K depends on three key parameters: k (items to select), segment size, and sort direction. To support the goals above, each of these must be specifiable in three ways:

  1. Uniform: The same value applies to all segments (e.g., "Get top-5 for all segments").
  2. Segment-specific: Values vary per segment (e.g., "Segment 0 needs top-5, Segment 1 needs top-2").
  3. Static: Values are known at compile-time (e.g., "K is always 32").
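To make the three kinds concrete, here is a minimal sketch of how they could be modeled as wrapper types that generic code reads through one common interface. All names (`static_param`, `uniform_param`, `per_segment_param`, `items_to_select`) are hypothetical and not part of CUB; clamping k to the segment size is likewise an assumption made for illustration.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical wrapper types for illustration only; not CUB's actual API.

// Static: the value is encoded in the type and known at compile time.
template <int Value>
struct static_param
{
  static constexpr bool is_static = true;
  constexpr int operator[](std::size_t /*segment*/) const { return Value; }
};

// Uniform: a single runtime value shared by all segments.
struct uniform_param
{
  static constexpr bool is_static = false;
  int value;
  constexpr int operator[](std::size_t /*segment*/) const { return value; }
};

// Segment-specific: one runtime value per segment, read through an iterator.
template <typename Iterator>
struct per_segment_param
{
  static constexpr bool is_static = false;
  Iterator values;
  constexpr int operator[](std::size_t segment) const { return values[segment]; }
};

// Generic code can read any mixture of parameter kinds the same way.
// Assumption: k is clamped to the segment size.
template <typename KParam, typename SizeParam>
constexpr int items_to_select(KParam k, SizeParam segment_size, std::size_t segment)
{
  const int k_seg = k[segment];
  const int n_seg = segment_size[segment];
  return k_seg < n_seg ? k_seg : n_seg;
}
```

With this shape, a call site could combine, say, a per-segment k with a uniform segment size, while a fully static combination leaves the compiler free to fold the values into the kernel.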

Additionally, users should be able to provide the number of segments as either a device-accessible iterator or a host value. Furthermore, if users can guarantee an upper bound on the total number of items (summed across all segments), we can optimize the allocation of temporary buffers for intermediate results (similar to CUB's existing sorting algorithms).

Optimal Performance for Targeted Workloads

LLM workloads (e.g., MoE and sparse attention models) often operate under narrow constraints, such as uniform segment sizes of exactly 512 items or a fixed K=128. We need a mechanism for users to communicate these constraints to the algorithm.

By establishing a flexible interface now, we allow users to provide "hints" (e.g., MaxK=1024) to benefit from future kernel specializations without needing to refactor their algorithm invocations later.

Examples of Constraint-Based Optimization:

  • Small K (Reduce-based): Knowing that K is in the range [1, 8] suggests a reduce-based Top-K algorithm is optimal. If this information is available at compile-time, we can compile only the reduce-based code path, reducing compilation time, binary size, and potentially register pressure.
  • Bounded Segment Sizes: Knowing at runtime dispatch that all segment sizes are in the range [1, 4096] allows us to use a "one-worker-per-segment" approach (e.g., one thread block per segment). This avoids allocating temporary global memory for intermediate results, which removes the associated allocation and memory-traffic overhead.
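The two examples above can be sketched as a compile-time selection driven by a static hint plus a runtime check on the segment-size bound. This is only an illustration under stated assumptions: `max_k_hint`, `select_algorithm`, `use_one_block_per_segment`, and both thresholds (8 and 4096, taken from the ranges quoted above) are hypothetical and do not describe an existing CUB interface.

```cpp
#include <cassert>

enum class topk_algorithm { reduce_based, general };

// Hypothetical compile-time hint: the caller guarantees k is in [1, MaxK].
template <int MaxK>
struct max_k_hint
{
  static constexpr int value = MaxK;
};

// Assumed threshold, taken from the [1, 8] example.
inline constexpr int reduce_based_max_k = 8;

// Compile-time selection: because Hint is a template parameter, the branch
// that is not taken is discarded and its code path is never instantiated.
template <typename Hint>
constexpr topk_algorithm select_algorithm(Hint)
{
  if constexpr (Hint::value <= reduce_based_max_k)
  {
    return topk_algorithm::reduce_based;
  }
  else
  {
    return topk_algorithm::general;
  }
}

// Assumed threshold, taken from the [1, 4096] example.
inline constexpr int one_block_max_segment_size = 4096;

// Runtime dispatch decision: if every segment fits within the bound, a
// one-thread-block-per-segment kernel can be used without temporary
// global memory for intermediate results.
inline bool use_one_block_per_segment(int max_segment_size)
{
  assert(max_segment_size >= 1);
  return max_segment_size <= one_block_max_segment_size;
}
```

Keeping the static hint in the type is what would let a caller pass something like `max_k_hint<1024>` today and pick up a future specialization without changing the call.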

Status

In Progress