Source code for jax.experimental.key_reuse._core

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections import defaultdict
from functools import partial, reduce, wraps
from typing import Any, Callable, NamedTuple

import jax
from jax import lax
from jax import tree_util
from jax.interpreters import batching, mlir
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import prng
from jax._src import random
from jax._src import source_info_util
from jax._src import util
from jax._src.ad_checkpoint import remat_p
from jax._src.debugging import debug_callback_p
from jax._src.interpreters import partial_eval as pe
from jax._src.util import weakref_lru_cache

from jax.experimental.shard_map import shard_map_p
import numpy as np


class Sink(NamedTuple):
  idx: int
  mask: bool | np.ndarray = True

  def __repr__(self):
    if isinstance(self.mask, bool) and self.mask:
      return f"Sink({self.idx})"
    else:
      return f"Sink({self.idx}, mask={self.mask})"


class Source(NamedTuple):
  idx: int
  mask: bool | np.ndarray = True

  def __repr__(self):
    if isinstance(self.mask, bool) and self.mask:
      return f"Source({self.idx})"
    else:
      return f"Source({self.idx}, mask={self.mask})"

class Forward(NamedTuple):
  in_idx: int
  out_idx: int


class KeyReuseSignature(NamedTuple):
  sinks: list[Sink]
  sources: list[Source]
  forwards: list[Forward] = []

  def check_signature(self, *args, funcname="function", context=None):
    for sink in self.sinks:
      if not isinstance(args[sink.idx], prng.PRNGKeyArray):
        continue
      if np.any(args[sink.idx]._consumed & sink.mask):
        msg = f"Previously-consumed key passed to {funcname} at index {sink.idx}"
        if context:
          msg += " {context}"
        raise KeyReuseError(msg)

  def update_consumption(self, args_in, args_out):
    for sink in self.sinks:
      arg = args_in[sink.idx]
      if isinstance(arg, prng.PRNGKeyArray):
        arg._consumed = arg._consumed | sink.mask
    for arg in args_out:
      if isinstance(arg, prng.PRNGKeyArray):
        arg._consumed = True
    for source in self.sources:
      if isinstance(args_out[source.idx], prng.PRNGKeyArray):
        args_out[source.idx]._consumed = ~np.asarray(source.mask)
    for forward in self.forwards:
      arg_in = args_in[forward.in_idx]
      arg_out = args_out[forward.out_idx]
      if isinstance(arg_in, prng.PRNGKeyArray) and isinstance(arg_out, prng.PRNGKeyArray):
        arg_out._consumed = arg_in._consumed


[docs] class KeyReuseError(RuntimeError): pass
consume_p = core.Primitive("consume") consume_p.def_impl(lambda x: x) consume_p.def_abstract_eval(lambda x: x) batching.defvectorized(consume_p) mlir.register_lowering( consume_p, mlir.lower_fun(lambda x: x, multiple_results=False)) def consume(key): """Consume the key and return a consumed copy.""" return consume_p.bind(key) assert_consumed_value_p = core.Primitive("assert_consumed_value") assert_consumed_value_p.def_impl(lambda x, *, value: x) assert_consumed_value_p.def_abstract_eval(lambda x, *, value: x) batching.defvectorized(assert_consumed_value_p) mlir.register_lowering( assert_consumed_value_p, mlir.lower_fun(lambda x, *, value: x, multiple_results=False)) def assert_unconsumed(key): """Assert that a key is unconsumed""" assert_consumed_value_p.bind(key, value=False) def assert_consumed(key, value=True): """Assert that a key is consumed""" assert_consumed_value_p.bind(key, value=value) def _check_consumed_value(eqn, consumed): """Extra check for use with assert_consumed_value_p""" expected = eqn.params['value'] if not np.all(consumed == expected): if np.all(expected): raise AssertionError(f"Expected key to be consumed in {eqn}") elif not np.any(expected): raise AssertionError(f"Expected key to not be consumed in {eqn}") else: raise AssertionError(f"Expected {expected}, got {consumed} in {eqn}") # The behavior of most primitives can be described via simple signatures. key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {} key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [], [Forward(0, 0)]) key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature([], [], [Forward(0, 0)]) key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignature([], [Source(0)]) key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], []) # TODO(jakevdp): should fold_in sink its input key? # key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)]) key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([], [Source(0)]) key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature([], [Source(0)]) key_reuse_signatures[prng.random_split_p] = KeyReuseSignature([Sink(0)], [Source(0)]) key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature([Sink(0)], []) # TODO(jakevdp): broadcast should probably consume the input to avoid implicit duplication key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature([], [], [Forward(0, 0)]) key_reuse_signatures[lax.copy_p] = KeyReuseSignature([], [], [Forward(0, 0)]) key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature([], [], [Forward(0, 0)]) key_reuse_signatures[lax.device_put_p] = KeyReuseSignature([], [], [Forward(0, 0)]) key_reuse_signatures[lax.reshape_p] = KeyReuseSignature([], [], [Forward(0, 0)]) key_reuse_signatures[lax.squeeze_p] = KeyReuseSignature([], [], [Forward(0, 0)]) key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature([], [Source(0)], []) # TODO(jakevdp): should unwrap sink its input key? key_reuse_signatures[prng.random_unwrap_p] = KeyReuseSignature([], [], []) key_reuse_signatures[debug_callback_p] = KeyReuseSignature([], []) key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature([], [], [Forward(0, 0)]) key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature([Sink(1)], [], [Forward(0, 0)]) key_reuse_signatures[lax.gather_p] = KeyReuseSignature([], [], [Forward(0, 0)]) key_reuse_signatures[lax.scatter_p] = KeyReuseSignature([Sink(2)], [], [Forward(0, 0)]) # Equality checks don't consume key_reuse_signatures[lax.eq_p] = KeyReuseSignature([], [], []) key_reuse_signatures[lax.ne_p] = KeyReuseSignature([], [], []) # Rules which require more dynamic logic. key_reuse_signatures_dynamic: dict[core.Primitive, Callable[..., KeyReuseSignature]] = {} # The default signature will Sink all key inputs, and not Source any. def unknown_signature(eqn): def is_key(var: core.Atom): return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key) return KeyReuseSignature( sinks=[Sink(idx, True) for idx, var in enumerate(eqn.invars) if is_key(var)], sources=[], ) @weakref_lru_cache def get_jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature: """Parse the jaxpr to determine key reuse signature""" consumed: dict[core.Atom, bool | np.ndarray] = {} forwards: dict[core.Atom, core.Atom] = {} # map forwarded outputs to inputs. def resolve_forwards(var: core.Atom) -> core.Atom: if not forwards: return var for _ in range(len(forwards) + 1): if isinstance(var, core.Literal): return var if var in forwards: var = forwards[var] else: return var raise ValueError("forwarding cycle detected") def is_key(var: core.Atom): return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key) def sink(var: core.Atom, mask=True): if not is_key(var): return var = resolve_forwards(var) assert not isinstance(var, core.Literal) if np.any(np.logical_and(consumed.get(var, False), mask)): return True consumed[var] = np.logical_or(consumed.get(var, False), mask) def source(var: core.Atom, mask=False): if not is_key(var): return var = resolve_forwards(var) assert not isinstance(var, core.Literal) consumed[var] = mask def is_consumed(var: core.Atom): var = resolve_forwards(var) if isinstance(var, core.Literal): return False return consumed.get(var, False) for eqn in jaxpr.eqns: traceback = eqn.source_info.traceback name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack with source_info_util.user_context(traceback, name_stack=name_stack): if eqn.primitive in key_reuse_signatures: signature = key_reuse_signatures[eqn.primitive] elif eqn.primitive in key_reuse_signatures_dynamic: signature = key_reuse_signatures_dynamic[eqn.primitive](eqn) else: signature = unknown_signature(eqn) if eqn.primitive == assert_consumed_value_p: # This is a special case that goes beyond normal key reuse logic. _check_consumed_value(eqn, is_consumed(eqn.invars[0])) for in_idx, out_idx in signature.forwards: forwards[eqn.outvars[out_idx]] = eqn.invars[in_idx] for snk in signature.sinks: if not 0 <= snk.idx < len(eqn.invars): raise KeyReuseError(f"In {eqn.primitive}, sink {snk.idx} out of range [0, {len(eqn.invars)}]") if sink(eqn.invars[snk.idx], snk.mask): raise KeyReuseError(f"In {eqn.primitive}, argument {snk.idx} is already consumed.") for var in eqn.outvars: if not isinstance(var, core.Literal) and var not in forwards: source(var, True) # consumed unless in a Source. for src in signature.sources: if not 0 <= src.idx < len(eqn.outvars): raise KeyReuseError(f"In {eqn.primitive}, source {src.idx} out of range [0, {len(eqn.outvars)}]") source(eqn.outvars[src.idx]) return KeyReuseSignature( sinks=[Sink(i, consumed[v]) for i, v in enumerate(jaxpr.invars) if is_key(v) and np.any(consumed.get(v, False))], sources=[Source(i) for i, v in enumerate(jaxpr.outvars) if is_key(v) and resolve_forwards(v) not in jaxpr.invars and not consumed.get(v, False)], forwards=[Forward(jaxpr.invars.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type] for idx_out, outvar in enumerate(jaxpr.outvars) if is_key(outvar) and resolve_forwards(outvar) in jaxpr.invars] ) def check_key_reuse_jaxpr(jaxpr: core.Jaxpr) -> None: """Check the jaxpr for key reuse.""" get_jaxpr_type_signature(jaxpr) def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None: """Function to statically check key reuse.""" args_flat, in_tree = tree_util.tree_flatten(args) in_avals_flat = [core.get_aval(arg) for arg in args_flat] wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) check_key_reuse_jaxpr(jaxpr) #---------------------------------------------------------------------------------- # key reuse rules for particular primitives: def _slice_signature(eqn): in_aval = eqn.invars[0].aval if not jax.dtypes.issubdtype(in_aval.dtype, jax.dtypes.prng_key): return KeyReuseSignature([], [], [Forward(0, 0)]) if any(core.is_symbolic_dim(s) for s in in_aval.shape): return KeyReuseSignature([], [], [Forward(0, 0)]) start_indices = eqn.params['start_indices'] limit_indices = eqn.params['limit_indices'] strides = eqn.params['strides'] or (1,) * len(start_indices) idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides)) sink = np.zeros(in_aval.shape, dtype=bool) sink[idx] = True return KeyReuseSignature([Sink(0, sink)], [Source(0)]) key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature def _pjit_key_type_signature(eqn): return get_jaxpr_type_signature(eqn.params['jaxpr'].jaxpr) key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature def _shard_map_type_signature(eqn): return get_jaxpr_type_signature(eqn.params['jaxpr']) key_reuse_signatures_dynamic[shard_map_p] = _shard_map_type_signature def _cond_key_type_signature(eqn): signatures = [get_jaxpr_type_signature(branch.jaxpr) for branch in eqn.params['branches']] sinks = defaultdict(list) sources = defaultdict(list) for sig in signatures: for sink in sig.sinks: sinks[sink.idx].append(sink.mask) for source in sig.sources: sources[source.idx].append(source.mask) combined_sinks = [Sink(i + 1, reduce(np.logical_or, m)) for i, m in sinks.items()] combined_sources = [Source(i, reduce(np.logical_and, m)) for i, m in sources.items()] combined_forwards = [Forward(f.in_idx + 1, f.out_idx) for f in set.intersection(*(set(sig.forwards) for sig in signatures))] return KeyReuseSignature(combined_sinks, combined_sources, combined_forwards) key_reuse_signatures_dynamic[lax.cond_p] = _cond_key_type_signature def _scan_key_type_signature(eqn): jaxpr = eqn.params['jaxpr'].jaxpr num_consts = eqn.params['num_consts'] num_carry = eqn.params['num_carry'] signature = get_jaxpr_type_signature(jaxpr) # scan body should not consume key in constants if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts): raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, " "because key constants are repeatedly consumed:\n" f" {signature=}\n" f" {eqn=}\n" f" {jaxpr=}") # scan carry should only consume keys that are sourced on output. carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks if 0 <= s.idx - num_consts < num_carry and np.any(s.mask)} carry_sources = {s.idx: s.mask for s in signature.sources if 0 <= s.idx < num_carry and np.any(s.mask)} if not set(carry_sinks).issubset(set(carry_sources)): # TODO(jakevdp): check that masks match raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, " "because consumed inputs don't match sourced outputs:\n" f" {signature=}\n" f" {eqn=}\n" f" {jaxpr=}") return signature key_reuse_signatures_dynamic[jax.lax.scan_p] = _scan_key_type_signature def _while_key_type_signature(eqn): cond_jaxpr = eqn.params['cond_jaxpr'].jaxpr cond_nconsts = eqn.params['cond_nconsts'] body_jaxpr = eqn.params['body_jaxpr'].jaxpr body_nconsts = eqn.params['body_nconsts'] cond_signature = get_jaxpr_type_signature(cond_jaxpr) body_signature = get_jaxpr_type_signature(body_jaxpr) # Error if there are sinks among consts. if any(np.any(s.mask) for s in cond_signature.sinks if s.idx < cond_nconsts): raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: " f" {cond_signature=}\n" f" {eqn=}") if any(np.any(s.mask) for s in body_signature.sinks if s.idx < body_nconsts): raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: " f" {body_signature=}\n" f" {eqn=}") # carry should only consume keys that are sourced on output. body_carry_sinks = {s.idx - body_nconsts: s.mask for s in body_signature.sinks if s.idx >= body_nconsts} cond_carry_sinks = {s.idx - cond_nconsts: s.mask for s in cond_signature.sinks if s.idx >= cond_nconsts} carry_sources = {s.idx: s.mask for s in body_signature.sources} # TODO(jakevdp): check masks at each index? if not (cond_carry_sinks.keys() <= carry_sources.keys()): raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: " f" {cond_signature=}\n" f" {eqn=}") if not (body_carry_sinks.keys() <= carry_sources.keys()): raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: " f" {body_signature=}\n" f" {eqn=}") if body_carry_sinks.keys() & cond_carry_sinks.keys(): raise KeyReuseError("while_loop cond and body functions both use the same key: " f" {cond_signature=}\n" f" {body_signature=}\n" f" {eqn=}") return body_signature key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature def _remat_key_type_signature(eqn): # The assumption here is that the non-differentiated pass contains all relevant # key usage, and the differentiated pass # 1) will only consume keys that are already consumed in the non-differentiated pass # 2) will never create keys # Therefore, the differentiated pass is a no-op. if eqn.params['differentiated']: return KeyReuseSignature([], []) return get_jaxpr_type_signature(eqn.params['jaxpr']) key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature # TODO(jakevdp): when we integrate key reuse checks more tightly with JAX, # we should move this logic directly into each primitive impl. def key_reuse_impl_rule(prim, original_rule): @wraps(original_rule) def key_reuse_impl(*args, **kwargs): if config.enable_key_reuse_checks.value: if prim == pjit.pjit_p: funcname = "jit-compiled function" jaxpr = kwargs['jaxpr'].jaxpr signature = get_jaxpr_type_signature(jaxpr) elif prim in key_reuse_signatures: funcname = str(prim) jaxpr = None signature = key_reuse_signatures[prim] elif prim in key_reuse_signatures_dynamic: funcname = str(prim) jaxpr = jax.make_jaxpr(partial(prim.bind, **kwargs))(*args).jaxpr signature = get_jaxpr_type_signature(jaxpr) else: raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}") signature.check_signature(*args, funcname=funcname) result = original_rule(*args, **kwargs) signature.update_consumption(args, result if prim.multiple_results else [result]) return result else: return original_rule(*args, **kwargs) return key_reuse_impl for prim in (*key_reuse_signatures, *key_reuse_signatures_dynamic): prim.impl = key_reuse_impl_rule(prim, prim.impl) # type: ignore[method-assign]