diff --git a/jax/experimental/pallas/g3doc/debugging.md b/jax/experimental/pallas/g3doc/debugging.md index 40b109d102d5..6dfa95eb16fa 100644 --- a/jax/experimental/pallas/g3doc/debugging.md +++ b/jax/experimental/pallas/g3doc/debugging.md @@ -22,7 +22,7 @@ Note that interpret mode will not be able to fully replicate the behavior or pro ### debug_print -The `pl.debug_print` function can be used to print runtime values inside of a kernel. The implementation is currently limited to scalar values, but we are working on lifting this limitation. +The `pl.debug_print` function can be used to print runtime values inside of a kernel. For TPUs only, the kernel must be compiled with the 'xla_tpu_enable_log_recorder' option.