Spaces:
Runtime error
Runtime error
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import atexit | |
| import functools | |
| from typing import Any, Callable, List, Optional | |
| from lightning.pytorch.callbacks import Callback as PTLCallback | |
| from nemo.lightning.base_callback import BaseCallback | |
| from nemo.lightning.one_logger_callback import OneLoggerNeMoCallback | |
class CallbackGroup:
    """A singleton registry to host and fan-out lifecycle callbacks.

    Other code should call methods on this group (e.g., `on_model_init_start`).
    The group will iterate all registered callbacks and, if a callback implements
    the method, invoke it with the provided arguments.

    Only public attribute names are treated as lifecycle hooks. Names that start
    with an underscore (dunder protocol lookups, private state) raise
    ``AttributeError`` as usual so that `hasattr`, `copy` and `pickle` keep
    working on the group itself.
    """

    _instance: Optional['CallbackGroup'] = None

    @classmethod
    def get_instance(cls) -> 'CallbackGroup':
        """Get the singleton instance of CallbackGroup, creating it on first use.

        Returns:
            CallbackGroup: The singleton instance.
        """
        if cls._instance is None:
            cls._instance = CallbackGroup()
        return cls._instance

    def __init__(self) -> None:
        self._callbacks: List['BaseCallback'] = [OneLoggerNeMoCallback()]
        # Ensure application-end is emitted at most once per process
        self._app_end_emitted: bool = False

    def register(self, callback: 'BaseCallback') -> None:
        """Register a callback to the callback group.

        Args:
            callback: The callback to register.
        """
        self._callbacks.append(callback)

    def update_config(self, nemo_version: str, trainer: Any, **kwargs) -> None:
        """Update configuration across all registered callbacks and attach them to trainer.

        Calling this more than once for the same trainer is safe: group callbacks
        that are already present in ``trainer.callbacks`` are not attached again.

        Args:
            nemo_version: Version key (e.g., 'v1' or 'v2') for downstream config builders.
            trainer: Lightning Trainer to which callbacks should be attached if missing.
            **kwargs: Forwarded to each callback's update_config implementation.
        """
        # Forward update to each callback that supports update_config
        sanitized_group_callbacks: List['BaseCallback'] = []
        for cb in self._callbacks:
            # Will ignore other callbacks like unittest.mock.MagicMock
            if not isinstance(cb, BaseCallback):
                continue
            method = getattr(cb, 'update_config', None)
            if callable(method):
                method(nemo_version=nemo_version, trainer=trainer, **kwargs)
            sanitized_group_callbacks.append(cb)

        # Filter trainer callbacks to avoid leaking MagicMocks from tests, and drop
        # entries that are group callbacks already (identity match) so repeated
        # calls do not attach the same callback object twice.
        existing = list(getattr(trainer, 'callbacks', None) or [])
        group_ids = {id(cb) for cb in sanitized_group_callbacks}
        sanitized_trainer_callbacks = [
            cb for cb in existing if isinstance(cb, PTLCallback) and id(cb) not in group_ids
        ]
        callbacks = sanitized_group_callbacks + sanitized_trainer_callbacks

        # Sanitize callback state_key for pickling safety
        for cb in callbacks:
            try:
                key = getattr(cb, 'state_key', None)
                if not isinstance(key, str):
                    safe_key = (
                        f"{cb.__class__.__module__}.{getattr(cb.__class__, '__qualname__', cb.__class__.__name__)}"
                    )
                    setattr(cb, 'state_key', safe_key)
            except Exception:
                # Deliberately best-effort: a callback whose state_key cannot be
                # read or replaced is attached unchanged rather than aborting setup.
                continue

        trainer.callbacks = callbacks

    # NOTE(review): exposed as a plain method here; confirm whether callers expect a property.
    def callbacks(self) -> List['BaseCallback']:
        """Get the list of registered callbacks.

        Returns:
            List[BaseCallback]: The group's own list object (not a copy).
        """
        return self._callbacks

    def __getattr__(self, method_name: str) -> Callable:
        """Dynamically create a dispatcher for unknown public attributes.

        Any public attribute access that is not resolved normally is treated as
        a lifecycle method name. When invoked, the dispatcher will call that
        method on each registered callback if it exists and is callable.

        Args:
            method_name: Name of the attribute being looked up.

        Returns:
            Callable: A function forwarding its arguments to every callback that
            implements ``method_name``. The function itself returns None.

        Raises:
            AttributeError: If ``method_name`` starts with an underscore. Such names
                are never lifecycle hooks; answering them with a dispatcher would
                make protocol probes such as ``__deepcopy__`` appear implemented.
        """
        if method_name.startswith('_'):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {method_name!r}")

        def dispatcher(*args, **kwargs):
            for cb in self._callbacks:
                method = getattr(cb, method_name, None)
                if callable(method):
                    method(*args, **kwargs)

        return dispatcher

    # Explicit idempotent app-end to avoid duplicate emissions across multiple callers
    def on_app_end(self, *args, **kwargs) -> None:
        """Emit application-end callbacks at most once per group instance.

        Invokes `on_app_end` on each registered callback, if present. Subsequent
        calls are no-ops. All positional and keyword arguments are forwarded.
        """
        if self._app_end_emitted:
            return
        # Flip the flag before dispatching so a failing callback cannot cause a re-emit.
        self._app_end_emitted = True
        for cb in self._callbacks:
            method = getattr(cb, 'on_app_end', None)
            if callable(method):
                method(*args, **kwargs)
def hook_class_init_with_callbacks(cls, start_callback: str, end_callback: str) -> None:
    """Hook a class's __init__ to emit CallbackGroup start/end hooks.

    The class is modified in place: ``cls.__init__`` is replaced by a wrapper
    that calls ``start_callback`` on the CallbackGroup singleton, runs the
    original ``__init__``, then calls ``end_callback``. If the original
    ``__init__`` raises, the end hook is not emitted and the exception
    propagates. Calling this function again for an already wrapped ``__init__``
    (including one inherited from a wrapped base class) does nothing.

    Args:
        cls (type): Class whose __init__ should be wrapped.
        start_callback (str): CallbackGroup method to call before __init__.
        end_callback (str): CallbackGroup method to call after __init__.
    """
    if not hasattr(cls, '__init__'):
        return

    original_init = cls.__init__

    # Idempotence guard: avoid wrapping the same __init__ multiple times (e.g., in multiple inheritance)
    if getattr(original_init, '_init_wrapped_for_callbacks', False):
        return

    # Preserve the original __init__'s name, docstring and signature for introspection.
    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        # Reentrancy guard: avoid double-emitting hooks across super().__init__ chains
        if getattr(self, '_in_wrapped_init', False):
            # If we're already inside a wrapped __init__, just call the original
            return original_init(self, *args, **kwargs)

        setattr(self, '_in_wrapped_init', True)
        try:
            group = CallbackGroup.get_instance()
            if hasattr(group, start_callback):
                getattr(group, start_callback)()
            result = original_init(self, *args, **kwargs)
            if hasattr(group, end_callback):
                getattr(group, end_callback)()
            return result
        finally:
            # Clear the guard on every exit path so it only marks "currently
            # inside a wrapped __init__" and never lingers on the instance.
            setattr(self, '_in_wrapped_init', False)

    wrapped_init._init_wrapped_for_callbacks = True
    cls.__init__ = wrapped_init
# Build the singleton at import time so that callers running early in start-up
# always find a ready instance.
CallbackGroup.get_instance()


def _emit_app_end_at_exit() -> None:
    """Fire the group's application-end hook during interpreter shutdown."""
    CallbackGroup.get_instance().on_app_end()


# Covers entrypoints that never signal app-end themselves (e.g., pytest
# end-of-session, non-Hydra entrypoints). on_app_end ignores repeat calls, so
# registering this in addition to explicit callers is harmless.
atexit.register(_emit_app_end_at_exit)

__all__ = ['CallbackGroup', 'hook_class_init_with_callbacks']