# Source code for jax._src.compiler

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Interface to the compiler

from __future__ import annotations

from collections.abc import Sequence
import logging
import os
import tempfile
import time
from typing import Any
import warnings

from jax._src import compilation_cache
from jax._src import config as config
from jax._src import distributed
from jax._src import lib
from jax._src import monitoring
from jax._src import profiler
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.xla_bridge import process_count
import numpy as np


# Flag: when true, get_compile_options sets XLA's backend optimization level
# to 0, disables expensive LLVM passes and turns off
# xla_test_all_input_layouts, trading runtime performance for compile time.
_DISABLE_MOST_OPTIMIZATIONS = config.DEFINE_bool(
    'jax_disable_most_optimizations',
    config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
    'Try not to do much optimization work. This can be useful if the cost of '
    'optimization is greater than that of running a less-optimized program.')

# Flag: operation-count threshold read by use_detailed_logging. A module is
# considered "interesting" (detailed XLA logging enabled) when it contains
# more MLIR operations than this value. Defaults to 10 unless the
# JAX_COMPILER_DETAILED_LOGGING_MIN_OPS environment variable overrides it.
_COMPILER_DETAILED_LOGGING_MIN_OPS = config.DEFINE_integer(
    "jax_compiler_detailed_logging_min_ops",
    config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10),
    help=(
        'How big should a module be in MLIR operations before JAX enables '
        'detailed compiler logging? The intent of this flag is to suppress '
        'detailed logging for small/uninteresting computations.'
    ),
)

# The special XLA-AutoFDO profile version that indicates that a profile is not
# available and retrieval should not be attempted. get_compile_options uses it
# as the fallback value of compile_options.profile_version.
_NO_PROFILE_DONT_RETRIEVE = -1

# Register this file with traceback_util's exclusion list; presumably so that
# frames from this module are hidden from user-facing filtered tracebacks
# (NOTE(review): confirm against traceback_util).
traceback_util.register_exclusion(__file__)

# Re-export of the XLA client's CompileOptions type from this module.
CompileOptions = xc.CompileOptions

# Logger named after this module.
logger = logging.getLogger(__name__)


# Will be monkeypatched with the function that gets the XLA-AutoFDO profile
# version. The default returns _NO_PROFILE_DONT_RETRIEVE (-1), which takes care
# of errors by signalling that no profile retrieval should be attempted.
# TODO(b/289098047): consider refactoring this interface.
def get_latest_profile_version(backend: xc.Client) -> int:
  """Returns the latest XLA-AutoFDO profile version for `backend`.

  This default implementation ignores `backend` and always reports
  `_NO_PROFILE_DONT_RETRIEVE`. Environments that can look up real profiles
  replace (monkeypatch) this function. Note that `get_compile_options` treats
  a return value of 0 as an error.

  Args:
    backend: the XLA client the profile would be looked up for (unused here).

  Returns:
    The profile version; `_NO_PROFILE_DONT_RETRIEVE` for this default.
  """
  del backend  # Unused by the default implementation.
  # Use the named sentinel rather than a bare -1 so the meaning stays in sync
  # with the constant get_compile_options compares against.
  return _NO_PROFILE_DONT_RETRIEVE


def _walk_operations(op, k):
  k -= 1
  if k < 0:
    return k
  for region in op.regions:
    for block in region:
      for child_op in block:
        k = _walk_operations(child_op, k)
        if k < 0:
          return k
  return k


def use_detailed_logging(module: ir.Module) -> bool:
  """Decides whether `module` is big enough to warrant detailed logging.

  Returns True when `module` (counting its top-level operation and every
  nested operation) has more MLIR operations than the
  `jax_compiler_detailed_logging_min_ops` flag value.
  """
  threshold = _COMPILER_DETAILED_LOGGING_MIN_OPS.value
  budget_left = _walk_operations(module.operation, threshold)
  # A negative remainder means the walk exhausted the threshold.
  return budget_left < 0


def get_compile_options(
    num_replicas: int,
    num_partitions: int,
    device_assignment=None,
    use_spmd_partitioning: bool = True,
    use_auto_spmd_partitioning: bool = False,
    auto_spmd_partitioning_mesh_shape: list[int] | None = None,
    auto_spmd_partitioning_mesh_ids: list[int] | None = None,
    env_options_overrides: dict[str, str] | None = None,
    fdo_profile: bytes | None = None,
    detailed_logging: bool = True,
    backend: xc.Client | None = None,
) -> xc.CompileOptions:
  """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: Number of replicas for which to compile.
    num_partitions: Number of partitions for which to compile.
    device_assignment: Optional ndarray of jax devices indicating the assignment
      of logical replicas to physical devices (default inherited from
      xla_client.CompileOptions). Must be consistent with `num_replicas` and
      `num_partitions`. A 1D assignment is accepted when `num_partitions` is 1.
    use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
      partitioning in XLA.
    use_auto_spmd_partitioning: boolean indicating whether to automatically
      generate XLA shardings for SPMD partitioner.
    auto_spmd_partitioning_mesh_shape: device mesh shape used to create
      auto_spmd_partitioning search space.
    auto_spmd_partitioning_mesh_ids: device ids used to create
      auto_spmd_partitioning search space.
    env_options_overrides: dict of additional options parsed by the compiler
    fdo_profile: Optional profile for feedback-directed optimization passed to
      XLA.
    detailed_logging: Is this an "interesting" computation about which XLA
      would be wise to log compilation information?
    backend: the client, if available.

  Returns:
    A populated `xc.CompileOptions`.

  Raises:
    ValueError: if `device_assignment` is not two-dimensional (after promoting
      a 1D assignment when `num_partitions` is 1), or if its shape disagrees
      with `num_replicas` / `num_partitions`.
  """
  compile_options = xc.CompileOptions()
  compile_options.num_replicas = num_replicas
  compile_options.num_partitions = num_partitions
  build_options = compile_options.executable_build_options
  build_options.use_spmd_partitioning = use_spmd_partitioning
  build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
  if fdo_profile is not None:
    build_options.fdo_profile = fdo_profile
  if use_auto_spmd_partitioning:
    build_options.auto_spmd_partitioning_mesh_shape = (
        auto_spmd_partitioning_mesh_shape or [])
    build_options.auto_spmd_partitioning_mesh_ids = (
        auto_spmd_partitioning_mesh_ids or [])
  if device_assignment is not None:
    logger.debug(
        'get_compile_options: num_replicas=%s num_partitions=%s '
        'device_assignment=%s',
        num_replicas, num_partitions, device_assignment)
    device_assignment = np.array(device_assignment)

    # Allow 1D device assignment if num_partitions is 1.
    if (device_assignment.ndim == 1) and (num_partitions == 1):
      device_assignment = device_assignment[:, None]

    # Anything still not (replicas, partitions)-shaped is a caller error;
    # report it as a ValueError instead of an IndexError on `shape[1]` below.
    if device_assignment.ndim != 2:
      raise ValueError(
          'device_assignment must be a 2D array of shape '
          '(num_replicas, num_partitions); got shape '
          f'{device_assignment.shape}.')

    if num_replicas != device_assignment.shape[0]:
      msg = 'device_assignment does not match num_replicas: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_replicas))

    if num_partitions != device_assignment.shape[1]:
      msg = 'device_assignment does not match num_partitions: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_partitions))

    # An object array holds device objects; XLA wants their integer ids.
    if device_assignment.dtype == object:
      device_assignment = np.vectorize(lambda d: d.id, otypes=[int])(
          device_assignment)
    device_assignment = xc.DeviceAssignment.create(device_assignment)
    assert device_assignment.replica_count() == num_replicas
    assert device_assignment.computation_count() == num_partitions
    compile_options.device_assignment = device_assignment

  if env_options_overrides is not None:
    compile_options.env_option_overrides = list(env_options_overrides.items())

  debug_options = compile_options.executable_build_options.debug_options
  if lib.cuda_path is not None:
    debug_options.xla_gpu_cuda_data_dir = lib.cuda_path

  if _DISABLE_MOST_OPTIMIZATIONS.value:
    debug_options.xla_backend_optimization_level = 0
    debug_options.xla_llvm_disable_expensive_passes = True
    debug_options.xla_test_all_input_layouts = False

  # XLA-AutoFDO profile version: precedence order is:
  # 1. --jax_xla_profile_version, if it is set to a positive value.
  # 2. Otherwise, call the function set in get_latest_profile_version and use
  #    the return value if non-zero. If the function returns 0, keep -1; this
  #    is an error.
  # -1 indicates that no attempt should be made to retrieve the latest profile
  # later on.
  jax_xla_profile_version = config.jax_xla_profile_version.value
  if jax_xla_profile_version > 0:
    compile_options.profile_version = jax_xla_profile_version
    logger.debug("get_compile_options XLA-AutoFDO profile: "
                 "using JAX XLA profile version %d from flag",
                 jax_xla_profile_version)
  else:
    compile_options.profile_version = _NO_PROFILE_DONT_RETRIEVE
    if backend is None:
      # Use the module logger: the root-level logging.info() would configure
      # the root logger via basicConfig() if it has no handlers.
      logger.info("get_compile_options: no backend supplied; "
                  "disabling XLA-AutoFDO profile")
    else:
      fdo_profile_version = get_latest_profile_version(backend)
      if fdo_profile_version != 0:
        compile_options.profile_version = fdo_profile_version
        logger.debug("get_compile_options XLA-AutoFDO profile: "
                     "using XLA-AutoFDO profile version %d",
                     fdo_profile_version)
      else:
        logger.error("get_compile_options XLA-AutoFDO profile: "
                     "XLA-AutoFDO profile version is 0; this should not happen")

  debug_options.xla_detailed_logging = detailed_logging
  return compile_options
@profiler.annotate_function
def backend_compile(
    backend: xc.Client,
    module: ir.Module,
    options: xc.CompileOptions,
    host_callbacks: Sequence[Any],
) -> xc.LoadedExecutable:
  """Compiles `module` on `backend` with `options`.

  Args:
    backend: the XLA client to compile with.
    module: the MLIR module to compile.
    options: compile options passed through to `backend.compile`.
    host_callbacks: host callbacks. They are forwarded only when non-empty,
      because some backends do not accept the argument.

  Returns:
    The loaded executable returned by `backend.compile`.
  """
  # Serialize the ir.Module to MLIR bytecode, unless the backend explicitly
  # flags (needs_str_ir = False) that it can take the module object directly,
  # which avoids the overhead of a serialization round trip.
  if getattr(backend, "needs_str_ir", True):
    built_c = mlir.module_to_bytecode(module)
  else:
    built_c = module

  # We use a separate function call to ensure that XLA compilation appears
  # separately in Python profiling results.
  if host_callbacks:
    return backend.compile(built_c, compile_options=options,
                           host_callbacks=host_callbacks)
  # Some backends don't have `host_callbacks` option yet
  # TODO(sharadmv): remove this fallback when all backends allow `compile`
  # to take in `host_callbacks`
  return backend.compile(built_c, compile_options=options)


def compile_or_get_cached(
    backend: xc.Client,
    computation: ir.Module,
    devices: np.ndarray,
    compile_options: xc.CompileOptions,
    host_callbacks: Sequence[Any],
) -> xc.LoadedExecutable:
  """Compiles `computation`, consulting the persistent compilation cache.

  The cache is used only when it is enabled, the backend supports executable
  serialization, and the platform is TPU or GPU. On a cache miss, the
  executable is compiled in one of three ways:
  - It is shared across hosts (multi-process, binary sharing enabled, no host
    callbacks).
  - It is compiled with an autotune config shared across hosts (multi-process,
    autotune sharing enabled).
  - Otherwise, it is compiled locally and the result is written to the cache.
  """
  sym_name = computation.operation.attributes['sym_name']
  module_name = ir.StringAttr(sym_name).value

  if dumped_to := mlir.dump_module_to_file(computation, "compile"):
    # Use the module logger: the root-level logging.info() would configure the
    # root logger via basicConfig() if it has no handlers.
    logger.info("Dumped the module to %s.", dumped_to)

  # Persistent compilation cache only implemented on TPU and GPU and the backend
  # that supports serialization of executables.
  # TODO(skye): add warning when initializing cache on unsupported default platform
  supported_platforms = ["tpu", "gpu"]
  # TODO(b/323256224): Add back support for CPU together with extra fields in a
  # cache key with underlying hardware features (xla_extension_version >= 230).
  use_compilation_cache = (
      config.enable_compilation_cache.value
      and getattr(backend, "supports_executable_serialization", True)
      and backend.platform in supported_platforms
  )

  if not use_compilation_cache:
    return backend_compile(backend, computation, compile_options,
                           host_callbacks)

  compilation_cache.set_once_cache_used(
      lambda: monitoring.record_event(
          "/jax/compilation_cache/tasks_using_cache"))

  monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache')

  try:
    cache_key = compilation_cache.get_cache_key(
        computation, devices, compile_options, backend)
  except xc._xla.XlaRuntimeError as ex:
    logger.error("compile_or_get_cached: unable to generate cache key, "
                 "skipping the cache: %s", ex)
    return backend_compile(backend, computation, compile_options,
                           host_callbacks)

  cache_retrieval_start = time.monotonic()
  retrieved_executable, retrieved_compile_time = _cache_read(
      module_name, cache_key, compile_options, backend)
  cache_retrieval_time = time.monotonic() - cache_retrieval_start

  if retrieved_executable is not None:
    assert retrieved_compile_time is not None
    logger.debug("Persistent compilation cache hit for '%s'", module_name)

    monitoring.record_event('/jax/compilation_cache/cache_hits')
    monitoring.record_event_duration_secs(
        '/jax/compilation_cache/compile_time_saved_sec',
        retrieved_compile_time - cache_retrieval_time)

    monitoring.record_event_duration_secs(
        "/jax/compilation_cache/cache_retrieval_time_sec", cache_retrieval_time)

    return retrieved_executable
  elif (
      process_count() > 1
      and config.share_binary_between_hosts.value
      and distributed.global_state.client is not None
      # Host callbacks are currently baked into the HLO module so we can't
      # share them.
      and len(host_callbacks) == 0
  ):
    return _compile_and_share_module(
        backend,
        computation,
        compile_options,
        host_callbacks,
        distributed.global_state.client,
        module_name,
        cache_key,
    )
  elif (
      process_count() > 1
      and config.share_autotune_config_between_hosts.value
      and distributed.global_state.client is not None
  ):
    return _compile_and_write_autotune_config(
        backend,
        computation,
        compile_options,
        host_callbacks,
        distributed.global_state.client,
        module_name,
        cache_key,
    )
  else:
    return _compile_and_write_cache(
        backend,
        computation,
        compile_options,
        host_callbacks,
        module_name,
        cache_key,
    )


# The process with id 0 should compile the module and write an autotune config
# to the K-V storage.
def _compile_and_write_autotune_config(
    backend: xc.Client,
    computation: ir.Module,
    compile_options: xc.CompileOptions,
    host_callbacks: Sequence[Any],
    global_client: lib.xla_extension.DistributedRuntimeClient,
    module_name: str,
    cache_key: str,
) -> xc.LoadedExecutable:
  """Compiles with an autotune config shared across processes.

  Process 0 compiles with XLA dumping autotune results to a temp file. It then
  publishes the compressed file contents in the K-V store under `cache_key`.
  The other processes block on that key, write the config locally, and compile
  with XLA loading autotune results from it. If the local temp file already
  exists, it is reused without touching the K-V store.

  Note: this mutates the debug options of `compile_options`.
  """
  share_timeout = config.share_binary_between_hosts_timeout_ms.value
  debug_options = compile_options.executable_build_options.debug_options

  autotune_tmp_file = os.path.join(
      _compile_and_write_autotune_config.autotune_configs_dir, cache_key
  )

  if os.path.exists(autotune_tmp_file):
    debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
    return _compile_and_write_cache(
        backend,
        computation,
        compile_options,
        host_callbacks,
        module_name,
        cache_key,
    )

  if distributed.global_state.process_id == 0:
    debug_options.xla_gpu_dump_autotune_results_to = autotune_tmp_file
    executable = _compile_and_write_cache(
        backend,
        computation,
        compile_options,
        host_callbacks,
        module_name,
        cache_key,
    )

    with open(autotune_tmp_file, "rb") as f:
      autotune_config = f.read()

    autotune_config = compilation_cache.compress_executable(autotune_config)
    global_client.key_value_set_bytes(cache_key, autotune_config)
  else:
    autotune_config = global_client.blocking_key_value_get_bytes(
        cache_key, share_timeout
    )

    autotune_config = compilation_cache.decompress_executable(autotune_config)
    with open(autotune_tmp_file, "wb") as f:
      f.write(autotune_config)

    debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
    executable = _compile_and_write_cache(
        backend,
        computation,
        compile_options,
        host_callbacks,
        module_name,
        cache_key,
    )
  return executable


# Per-process scratch directory holding autotune config files keyed by
# cache_key (created once, at import time).
_compile_and_write_autotune_config.autotune_configs_dir = tempfile.mkdtemp()


# The process with id 0 should compile the module and write it to the K-V
# storage.
# TODO: In case when the process with id 0 is not participating in computation
# we need to choose another process to compile the module.
def _compile_and_share_module(
    backend: xc.Client,
    computation: ir.Module,
    compile_options: xc.CompileOptions,
    host_callbacks: Sequence[Any],
    global_client: lib.xla_extension.DistributedRuntimeClient,
    module_name: str,
    cache_key: str,
) -> xc.LoadedExecutable:
  """Compiles on process 0 and distributes the serialized executable.

  Process 0 compiles the executable, serializes and compresses it, and
  publishes it in the K-V store under `cache_key`. The other processes block
  on that key and deserialize the executable. Results are memoized in-process
  in `_compile_and_share_module.modules_cache`.
  """
  share_timeout = config.share_binary_between_hosts_timeout_ms.value

  # TODO: We need a proper eviction protocol here, otherwise all compiled
  # modules will pile in memory.
  if cache_key in _compile_and_share_module.modules_cache:
    return _compile_and_share_module.modules_cache[cache_key]

  if distributed.global_state.process_id == 0:
    executable = _compile_and_write_cache(
        backend,
        computation,
        compile_options,
        host_callbacks,
        module_name,
        cache_key,
    )
    serialized_executable = backend.serialize_executable(executable)
    serialized_executable = compilation_cache.compress_executable(
        serialized_executable
    )
    global_client.key_value_set_bytes(cache_key, serialized_executable)
  else:
    serialized_executable = global_client.blocking_key_value_get_bytes(
        cache_key, share_timeout
    )
    serialized_executable = compilation_cache.decompress_executable(
        serialized_executable
    )
    executable = backend.deserialize_executable(
        serialized_executable, compile_options
    )

  _compile_and_share_module.modules_cache[cache_key] = executable
  return executable


# Unbounded in-process memo of shared executables, keyed by cache_key.
_compile_and_share_module.modules_cache = {}


def _compile_and_write_cache(
    backend: xc.Client,
    computation: ir.Module,
    compile_options: xc.CompileOptions,
    host_callbacks: Sequence[Any],
    module_name: str,
    cache_key: str,
) -> xc.LoadedExecutable:
  """Compiles `computation`, times it, and offers the result to the cache."""
  start_time = time.monotonic()
  executable = backend_compile(
      backend, computation, compile_options, host_callbacks
  )
  compile_time = time.monotonic() - start_time
  _cache_write(
      cache_key, compile_time, module_name, backend, executable, host_callbacks
  )
  return executable


def _cache_read(
    module_name: str, cache_key: str, compile_options: xc.CompileOptions,
    backend: xc.Client
) -> tuple[xc.LoadedExecutable | None, int | None]:
  """Looks up `cache_key` in the persistent compilation cache.

  Returns:
    The cached executable and its recorded compilation time, or
    `(None, None)` if reading failed and errors are not configured to raise.
    (NOTE(review): the miss case is whatever
    `compilation_cache.get_executable_and_time` returns; presumably
    `(None, None)`.)
  """
  try:
    return compilation_cache.get_executable_and_time(
        cache_key, compile_options, backend)
  except Exception as ex:
    if config.raise_persistent_cache_errors.value:
      raise
    warnings.warn(
        f"Error reading persistent compilation cache entry for "
        f"'{module_name}': {type(ex).__name__}: {ex}")
    return None, None


def _cache_write(cache_key: str,
                 compile_time_secs: float,
                 module_name: str,
                 backend: xc.Client, executable: xc.LoadedExecutable,
                 host_callbacks: Sequence[Any]) -> None:
  """Writes `executable` and its compilation time to the persistent cache.

  The entry is skipped in three cases:
  - The current process is not process 0.
  - The executable uses host callbacks.
  - Compilation was faster than `persistent_cache_min_compile_time_secs`.

  The compile time is stored truncated to whole seconds. Write errors are
  turned into warnings unless `raise_persistent_cache_errors` is set.
  """
  # Only write cache entries from the first process. Otherwise we create
  # problems with contention for writes on some filesystems, e.g., GCS.
  if distributed.global_state.process_id != 0:
    logger.debug("Not writing persistent cache entry since process_id != 0")
    return

  if host_callbacks:
    logger.debug(
        "Not writing persistent cache entry for '%s' because it uses host "
        "callbacks (e.g. from jax.debug.print or breakpoint)", module_name)
    return

  min_compile_time = config.persistent_cache_min_compile_time_secs.value
  if compile_time_secs < min_compile_time:
    logger.debug(
        "Not writing persistent cache entry for '%s' because it took < %.2f "
        "seconds to compile (%.2fs)", module_name, min_compile_time,
        compile_time_secs)
    return
  logger.debug(
      "'%s' took at least %.2f seconds to compile (%.2fs)",
      module_name, min_compile_time, compile_time_secs)

  try:
    compilation_cache.put_executable_and_time(
        cache_key, module_name, executable, backend, int(compile_time_secs))
  except Exception as ex:
    if config.raise_persistent_cache_errors.value:
      raise
    warnings.warn(
        f"Error writing persistent compilation cache entry for "
        f"'{module_name}': {type(ex).__name__}: {ex}")