At import time, these calls to logging.debug() implicitly call logging.basicConfig (https://docs.python.org/3/library/logging.html#logging.basicConfig), because the root logger has no handlers yet. This configures logging for the whole project, and a later logging.basicConfig() call from the application cannot override it. For instance, consider the following test script:
```python
import logging
import jax
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
logger.info('info')
```

This should log 'info'. However, import jax runs this _mlir_lib/__init__.py file, and its logging.debug() call triggers logging.basicConfig(). As a result, my own logging.basicConfig(level=logging.INFO) call does nothing.
Fix: instead of logging through the root logger with logging.debug(), use a module-level logger (logger = logging.getLogger(__name__)) and call logger.debug().
Found in this issue: https://github.com/google/jax/issues/12526