Fused LogP
Fused LogP computes, for each row of model logits, the log probability of one selected token. It targets RL post-training workloads, where repeatedly running log_softmax followed by gather materializes full-vocabulary intermediates and creates memory pressure at large group sizes.
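To make the memory pressure concrete, here is some back-of-the-envelope arithmetic. It assumes hypothetical sizes of N = 8192 scored tokens and V = 128,000 vocabulary entries. The unfused path materializes at least one [N, V] float32 tensor (the log_softmax output; `logits.float()` adds another fp32 copy when the logits are bfloat16). The selected result is only [N]. The helper names are illustrative only.

```python
# Rough size of the [N, V] float32 log-probability tensor that an unfused
# log_softmax(logits.float()) + gather materializes before the gather.
# N and V below are hypothetical values, not measurements from this project.

def materialized_logprob_bytes(num_rows: int, vocab_size: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed to hold a full [num_rows, vocab_size] log-probability tensor."""
    return num_rows * vocab_size * bytes_per_elem


def selected_logprob_bytes(num_rows: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed to hold only the [num_rows] selected log probabilities."""
    return num_rows * bytes_per_elem


if __name__ == "__main__":
    n, v = 8192, 128_000  # hypothetical: 8192 scored tokens, 128k-entry vocabulary
    full = materialized_logprob_bytes(n, v)
    selected = selected_logprob_bytes(n)
    print(f"full [N, V] float32: {full / 2**30:.2f} GiB")       # 3.91 GiB
    print(f"selected [N] float32: {selected / 2**10:.0f} KiB")  # 32 KiB
```

The gap grows linearly with both the number of rows and the vocabulary size. Larger RL groups therefore push the unfused intermediate up quickly, while the output stays small.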
Entry Point
```python
from rl_engine.kernels.registry import kernel_registry

logp_op = kernel_registry.get_op("logp")
# logits: [N, V], token_ids: [N]; see Tensor Contract for dtype/device rules.
output = logp_op(logits, token_ids)
```
Backends
| Backend | Wrapper | Native symbol | Notes |
|---|---|---|---|
| CUDA SM90 | `FusedLogpSM90Op` | `_C.fused_logp_sm90` | TMA-oriented path for Hopper-class GPUs. |
| CUDA generic | `FusedLogpGenericOp` | `_C.fused_logp` | Generic compiled extension fallback. |
| PyTorch native | `NativeOp` | None | Baseline fallback path. |
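The backends listed above can be read as a preference order. The sketch below shows one hypothetical way a registry could choose among them. The `DeviceInfo` type, the `select_logp_backend` function, and its rules are assumptions. Those rules are: SM90 only for compute capability 9.0 with bfloat16 logits, otherwise the generic CUDA extension, otherwise `NativeOp`. The actual dispatch logic in `rl_engine/kernels/registry.py` may differ.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical illustration of backend selection; not the registry's real code.


@dataclass(frozen=True)
class DeviceInfo:
    is_cuda: bool
    compute_capability: Optional[Tuple[int, int]]  # e.g. (9, 0) on Hopper; None on CPU


def select_logp_backend(device: DeviceInfo, has_extension: bool, logits_dtype: str) -> str:
    """Return the wrapper name of the backend to use (assumed preference order)."""
    if device.is_cuda and has_extension:
        # Assumed SM90 condition: compute capability 9.0 and bfloat16 logits,
        # matching the bfloat16 requirement in the Tensor Contract.
        if device.compute_capability == (9, 0) and logits_dtype == "bfloat16":
            return "FusedLogpSM90Op"
        # Any other CUDA case with the compiled extension available.
        return "FusedLogpGenericOp"
    # No CUDA device or no compiled extension: PyTorch baseline.
    return "NativeOp"
```

Keeping `NativeOp` as the unconditional last resort means the operator always has a working path, even when the compiled extension is missing.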
Tensor Contract
| Argument | Shape | Dtype | Requirements |
|---|---|---|---|
| `logits` | `[N, V]` | `bfloat16` for the SM90 path | Contiguous, on the target device. |
| `token_ids` / `labels` | `[N]` | Converted to `int32` | Same logical device as `logits`. |
| Output | `[N]` | Backend-defined | One selected log probability per row. |
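Reference Semantics
```python
import torch
# Full-vocabulary log probabilities in float32, then pick one entry per row.
ref = torch.log_softmax(logits.float(), dim=-1)
ref = torch.gather(ref, dim=-1, index=token_ids.unsqueeze(-1).long()).squeeze(-1)
```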
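Per row, the reference above reduces to `logits[i, t] - logsumexp(logits[i, :])`, where `t = token_ids[i]`. A single pass that keeps only a running maximum and a rescaled running sum per row is therefore enough, and no full log_softmax row needs to be stored. The pure-Python sketch below illustrates that idea using the standard online logsumexp rescaling and checks it against the two-pass form. It shows why fusion can avoid the [N, V] intermediate; it is not a description of what the CUDA kernels actually do. The function names and the chunk size are hypothetical.

```python
import math
from typing import Sequence


def selected_logp_reference(row: Sequence[float], token_id: int) -> float:
    """Two-pass form: logsumexp over the whole row, then pick one entry."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return row[token_id] - lse


def selected_logp_streaming(row: Sequence[float], token_id: int, chunk: int = 4) -> float:
    """One pass over vocab chunks, keeping only a running max and running sum."""
    running_max = -math.inf
    running_sum = 0.0
    for start in range(0, len(row), chunk):
        block = row[start:start + chunk]
        new_max = max(running_max, max(block))
        # Rescale the previous sum to the new max before adding this chunk.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(x - new_max) for x in block
        )
        running_max = new_max
    return row[token_id] - (running_max + math.log(running_sum))


row = [2.0, -1.0, 0.5, 3.0, 1.5, -0.25, 0.0]
print(selected_logp_reference(row, 3), selected_logp_streaming(row, 3, chunk=3))
```

Subtracting the running maximum before exponentiating keeps every `exp` argument at or below zero. This is the same numerical-stability trick `torch.log_softmax` relies on.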
Tests
The current accuracy test compares the output of the dispatched operator with a PyTorch reference, using a dtype-dependent error threshold.
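Here is a minimal sketch of an accuracy check with a dtype-dependent threshold, written in pure Python over per-row outputs. The tolerance values in `ATOL_BY_DTYPE` and the helper names are hypothetical placeholders. They only show a looser bound for lower-precision dtypes and are not the thresholds used by the project's test.

```python
from typing import Dict, Sequence

# Hypothetical placeholder tolerances keyed by logits dtype (NOT the project's values).
ATOL_BY_DTYPE: Dict[str, float] = {
    "float32": 1e-5,
    "float16": 1e-2,
    "bfloat16": 5e-2,
}


def max_abs_error(actual: Sequence[float], expected: Sequence[float]) -> float:
    """Largest per-row absolute difference between two [N] outputs."""
    if len(actual) != len(expected):
        raise ValueError("outputs must have the same length [N]")
    return max((abs(a - e) for a, e in zip(actual, expected)), default=0.0)


def check_logp_accuracy(actual: Sequence[float], expected: Sequence[float], dtype: str) -> bool:
    """True if the dispatched output is within the dtype's tolerance of the reference."""
    return max_abs_error(actual, expected) <= ATOL_BY_DTYPE[dtype]
```

In a PyTorch test, `torch.testing.assert_close` with explicit `atol`/`rtol` arguments serves the same purpose directly on tensors.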
Implementation Files
- `rl_engine/kernels/registry.py`
- `rl_engine/kernels/ops/cuda.py`
- `csrc/ops.cpp`
- `csrc/fused_logp_kernel.cu`
- `csrc/cuda/fused_logp_sm90.cu`