# Upload-page residue (not part of the module): subhankarg's picture;
# "Upload folder using huggingface_hub"; revision 0558aa4 (verified).
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import functools
from typing import Any, Callable, List, Optional
from lightning.pytorch.callbacks import Callback as PTLCallback
from nemo.lightning.base_callback import BaseCallback
from nemo.lightning.one_logger_callback import OneLoggerNeMoCallback
class CallbackGroup:
    """A singleton registry to host and fan-out lifecycle callbacks.

    Other code should call methods on this group (e.g., `on_model_init_start`).
    The group will iterate all registered callbacks and, if a callback implements
    the method, invoke it with the provided arguments.

    Only public attribute names (no leading underscore) are treated as lifecycle
    hooks; private and dunder lookups that miss raise ``AttributeError`` as usual.
    """

    _instance: Optional['CallbackGroup'] = None

    @classmethod
    def get_instance(cls) -> 'CallbackGroup':
        """Get the singleton instance of CallbackGroup, creating it on first use.

        Returns:
            CallbackGroup: The singleton instance.
        """
        if cls._instance is None:
            cls._instance = CallbackGroup()
        return cls._instance

    def __init__(self) -> None:
        # The group always starts with the OneLogger callback registered.
        self._callbacks: List['BaseCallback'] = [OneLoggerNeMoCallback()]
        # Ensure application-end is emitted at most once per process
        self._app_end_emitted: bool = False

    def register(self, callback: 'BaseCallback') -> None:
        """Register a callback to the callback group.

        Args:
            callback: The callback to register.
        """
        self._callbacks.append(callback)

    def update_config(self, nemo_version: str, trainer: Any, **kwargs) -> None:
        """Update configuration across all registered callbacks and attach them to trainer.

        Only registered callbacks that are ``BaseCallback`` instances are configured
        and attached. Entries of ``trainer.callbacks`` that are not Lightning
        callbacks are dropped, and a group callback that is already present on the
        trainer is not attached a second time.

        Args:
            nemo_version: Version key (e.g., 'v1' or 'v2') for downstream config builders.
            trainer: Lightning Trainer to which callbacks should be attached if missing.
            **kwargs: Forwarded to each callback's update_config implementation.
        """
        # Forward update to each callback that supports update_config
        group_callbacks: List['BaseCallback'] = []
        for cb in self._callbacks:
            # Will ignore other callbacks like unittest.mock.MagicMock
            if not isinstance(cb, BaseCallback):
                continue
            method = getattr(cb, 'update_config', None)
            if callable(method):
                method(nemo_version=nemo_version, trainer=trainer, **kwargs)
            group_callbacks.append(cb)

        # Filter trainer callbacks to avoid leaking MagicMocks from tests, and skip
        # the ones the group contributes itself so that repeated update_config calls
        # (or a callback attached earlier) never yield duplicates on the trainer.
        group_ids = {id(cb) for cb in group_callbacks}
        existing = list(getattr(trainer, 'callbacks', None) or [])
        trainer_callbacks = [
            cb for cb in existing if isinstance(cb, PTLCallback) and id(cb) not in group_ids
        ]
        callbacks = group_callbacks + trainer_callbacks

        # Sanitize callback state_key for pickling safety
        for cb in callbacks:
            try:
                if not isinstance(getattr(cb, 'state_key', None), str):
                    cb_type = cb.__class__
                    qualname = getattr(cb_type, '__qualname__', cb_type.__name__)
                    setattr(cb, 'state_key', f"{cb_type.__module__}.{qualname}")
            except Exception:
                # Deliberately best-effort: a callback whose state_key cannot be read
                # or assigned is attached unchanged rather than failing the update.
                continue

        trainer.callbacks = callbacks

    @property
    def callbacks(self) -> List['BaseCallback']:
        """Get the list of registered callbacks.

        Returns:
            List[BaseCallback]: The group's own list object (not a copy).
        """
        return self._callbacks

    def __getattr__(self, method_name: str) -> Callable:
        """Dynamically create a dispatcher for unknown public attributes.

        Any public attribute access that misses is treated as a lifecycle method
        name. When invoked, the dispatcher will call that method on each registered
        callback if it exists and is callable.

        Raises:
            AttributeError: For names starting with an underscore. Without this,
                protocol probes such as ``__deepcopy__`` or ``__setstate__`` would
                receive a no-op dispatcher, and a group missing ``_callbacks``
                would hand back a function instead of failing.
        """
        if method_name.startswith('_'):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {method_name!r}")

        def dispatcher(*args, **kwargs):
            for cb in self._callbacks:
                method = getattr(cb, method_name, None)
                if callable(method):
                    method(*args, **kwargs)

        return dispatcher

    # Explicit idempotent app-end to avoid duplicate emissions across multiple callers
    def on_app_end(self, *args, **kwargs) -> None:
        """Emit application-end callbacks at most once per group instance.

        Invokes `on_app_end` on each registered callback, if present. Subsequent
        calls are no-ops. All positional and keyword arguments are forwarded.
        """
        if self._app_end_emitted:
            return
        # Mark first so a callback that re-enters on_app_end cannot trigger a second round.
        self._app_end_emitted = True
        for cb in self._callbacks:
            method = getattr(cb, 'on_app_end', None)
            if callable(method):
                method(*args, **kwargs)
def hook_class_init_with_callbacks(cls, start_callback: str, end_callback: str) -> None:
    """Hook a class's __init__ to emit CallbackGroup start/end hooks.

    The class is modified in place: ``cls.__init__`` is replaced by a wrapper that
    calls ``start_callback`` on the CallbackGroup singleton, runs the original
    ``__init__``, then calls ``end_callback``. The end hook is only emitted when
    the original ``__init__`` returns normally.

    Args:
        cls (type): Class whose __init__ should be wrapped.
        start_callback (str): CallbackGroup method to call before __init__.
        end_callback (str): CallbackGroup method to call after __init__.
    """
    if not hasattr(cls, '__init__'):
        return

    original_init = cls.__init__

    # Idempotence guard: avoid wrapping the same __init__ multiple times (e.g., in multiple inheritance)
    if getattr(original_init, '_init_wrapped_for_callbacks', False):
        return

    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        # Reentrancy guard: avoid double-emitting hooks across super().__init__ chains
        if getattr(self, '_in_wrapped_init', False):
            # If we're already inside a wrapped __init__, just call the original
            return original_init(self, *args, **kwargs)

        setattr(self, '_in_wrapped_init', True)
        try:
            group = CallbackGroup.get_instance()
            if hasattr(group, start_callback):
                getattr(group, start_callback)()
            result = original_init(self, *args, **kwargs)
            if hasattr(group, end_callback):
                getattr(group, end_callback)()
            return result
        finally:
            # Clear the guard once the outermost __init__ is done (or has raised);
            # otherwise the instance would look "inside __init__" forever and any
            # later __init__ call on it would silently skip both hooks.
            try:
                setattr(self, '_in_wrapped_init', False)
            except (AttributeError, TypeError):
                # Best-effort: an __init__ that made the instance read-only must
                # not turn a successful construction into a failure here.
                pass

    # Marker read by the idempotence guard above on subsequent hook attempts.
    wrapped_init._init_wrapped_for_callbacks = True
    cls.__init__ = wrapped_init
# Build the singleton at import time so it already exists for the earliest callers.
CallbackGroup.get_instance()


def _emit_app_end_at_exit() -> None:
    """Forward interpreter shutdown to the group's on_app_end."""
    CallbackGroup.get_instance().on_app_end()


# Guarantee one app-end emission at process shutdown (e.g., pytest end-of-session,
# non-Hydra entrypoints). Repeated emission is harmless because on_app_end is idempotent.
atexit.register(_emit_app_end_at_exit)

__all__ = ['CallbackGroup', 'hook_class_init_with_callbacks']