Source code for flax.configurations

# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Global configuration options for Flax.

Now a wrapper over jax.config, in which all config vars have a 'flax\_' prefix.

To modify a config value at run time, call:
``flax.config.update('flax_<config_name>', <value>)``
"""

import os
from jax import config as jax_config

from contextlib import contextmanager

# Keep a wrapper in the flax namespace, in case we make our own implementation
# in the future.
config = jax_config

# Config parsing utils


def define_bool_state(name, default, help):
    """Register a boolean Flax flag through JAX's config system.

    The flag is registered under the name ``'flax_' + name``. It is backed by
    the environment variable ``FLAX_<UPPERCASE_NAME>``, and JAX config allows
    it to be overwritten at run time with
    ``flax.config.update('flax_<config_name>', <value>)``.

    Args:
      name: the flag name, without the 'flax_' prefix.
      default: the value the flag takes when nothing overrides it.
      help: a human-readable description of the flag.

    Returns:
      Whatever ``jax_config.define_bool_state`` returns for the new flag.
    """
    # Every Flax flag lives in JAX's config namespace under a 'flax_' prefix.
    prefixed_name = f'flax_{name}'
    return jax_config.define_bool_state(prefixed_name, default, help)
def static_bool_env(varname: str, default: bool) -> bool:
    """Parse an environment variable as a boolean.

    Deprecated: prefer define_bool_state() unless the flag is consumed in a
    static method and never needs to be updated at run time.

    Accepted spellings are matched case-insensitively. 'y', 'yes', 't',
    'true', 'on' and '1' mean True; 'n', 'no', 'f', 'false', 'off' and '0'
    mean False.

    Args:
      varname: the name of the environment variable to read.
      default: the boolean used when the variable is not set.

    Returns:
      The boolean derived from the environment, or from ``default`` when the
      variable is unset.

    Raises:
      ValueError: if the variable holds anything other than the spellings
        listed above.
    """
    truthy = ('y', 'yes', 't', 'true', 'on', '1')
    falsy = ('n', 'no', 'f', 'false', 'off', '0')

    # An unset variable falls back to the stringified default ('True' or
    # 'False'), which then goes through the same parsing as a real value.
    normalized = os.getenv(varname, str(default)).lower()

    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise ValueError(
        'invalid truth value {!r} for environment {!r}'.format(
            normalized, varname
        )
    )
@contextmanager
def temp_flip_flag(var_name: str, var_value: bool):
    """Temporarily override a Flax feature flag, e.g. inside a test function.

    The flag's current value is captured before the override and written back
    when the ``with`` block exits, including when the body raises.

    Args:
      var_name: the config variable name (without the 'flax_' prefix)
      var_value: the boolean value to set var_name to temporarily
    """
    flag = f'flax_{var_name}'
    # Read the current value before touching anything, so a failed lookup
    # leaves the config unmodified.
    previous = getattr(config, flag)
    try:
        config.update(flag, var_value)
        yield
    finally:
        config.update(flag, previous)


# Flax Global Configuration Variables:

# Whether to use the lazy rng implementation.
# Parsed from the environment when this module is imported.
flax_lazy_rng = static_bool_env('FLAX_LAZY_RNG', True)

# The flags below are registered via define_bool_state, i.e. under a 'flax_'
# prefix in the JAX config.
flax_filter_frames = define_bool_state(
    name='filter_frames',
    default=True,
    help='Whether to hide flax-internal stack frames from tracebacks.',
)

flax_profile = define_bool_state(
    name='profile',
    default=True,
    help='Whether to run Module methods under jax.named_scope for profiles.',
)

flax_use_orbax_checkpointing = define_bool_state(
    name='use_orbax_checkpointing',
    default=True,
    help='Whether to use Orbax to save checkpoints.',
)

flax_preserve_adopted_names = define_bool_state(
    name='preserve_adopted_names',
    default=False,
    help="When adopting outside modules, don't clobber existing names.",
)

# TODO(marcuschiam): remove this feature flag once regular dict migration is complete
flax_return_frozendict = define_bool_state(
    name='return_frozendict',
    default=False,
    help='Whether to return FrozenDicts when calling init or apply.',
)

# NOTE: the module attribute is `flax_fix_rng`, but the flag itself is
# registered as 'fix_rng_separator' (so 'flax_fix_rng_separator' in the config).
flax_fix_rng = define_bool_state(
    name='fix_rng_separator',
    default=False,
    help=(
        'Whether to add separator characters when folding in static data into'
        ' PRNG keys.'
    ),
)