Skip to content

Fused LogP

Fused LogP computes selected token log probabilities from model logits. It targets RL post-training workloads where repeated log_softmax + gather operations create memory pressure at large group sizes.

Entry Point

from rl_engine.kernels.registry import kernel_registry

logp_op = kernel_registry.get_op("logp")
output = logp_op(logits, token_ids)

Backends

Backend Wrapper Native symbol Notes
CUDA SM90 FusedLogpSM90Op _C.fused_logp_sm90 TMA-oriented path for Hopper-class GPUs.
CUDA generic FusedLogpGenericOp _C.fused_logp Generic compiled extension fallback.
PyTorch native NativeOp None Baseline fallback path.

Tensor Contract

Argument Shape Dtype Requirements
logits [N, V] bfloat16 for SM90 path Contiguous, on the target device.
token_ids / labels [N] Converted to int32 Same logical device as logits.
Output [N] Backend-defined tensor dtype One selected log probability per row.

Reference Semantics

ref = torch.log_softmax(logits.float(), dim=-1)
ref = torch.gather(ref, dim=-1, index=token_ids.unsqueeze(-1).long()).squeeze(-1)

Tests

python tests/test_op_accuracy.py

The current accuracy test compares the dispatched operator with a PyTorch reference and uses a dtype-dependent threshold.

Implementation Files

  • rl_engine/kernels/registry.py
  • rl_engine/kernels/ops/cuda.py
  • csrc/ops.cpp
  • csrc/fused_logp_kernel.cu
  • csrc/cuda/fused_logp_sm90.cu