feat: add class-based callback system for training lifecycle hooks #706
📝 Walkthrough
Adds a training callback subsystem with new contracts, async dispatch, callback serialization across
Changes
Training Callback System
Sequence Diagram(s)
sequenceDiagram
participant run_training
participant torchrun
participant main
participant CallbackManager
participant train
run_training->>torchrun: --callbacks=encoded callbacks
torchrun->>main: args.callbacks
main->>CallbackManager: deserialize + configure context
main->>train: callback_manager
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
Actionable comments posted: 6
🧹 Nitpick comments (2)
tests/unit/test_callbacks.py (1)
195-205: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Strengthen `test_on_train_end_blocks` to verify actual blocking behavior.
The current callback body is instant, so this test can pass without proving that `fire("on_train_end")` waits. Add a deliberate delay in the callback and assert on elapsed wall time.
Suggested assertion upgrade:

         def test_on_train_end_blocks(self):
    -        called = []
    +        called = []

             class SlowCb(TrainerCallback):
                 def on_train_end(self, context):
    +                time.sleep(0.05)
                     called.append(True)

             mgr = CallbackManager()
             mgr.add_callback(SlowCb())
    +        t0 = time.perf_counter()
             mgr.fire("on_train_end")
    +        elapsed = time.perf_counter() - t0
             assert called == [True]
    +        assert elapsed >= 0.05
    +        mgr.close()

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/test_callbacks.py` around lines 195 - 205, The test_on_train_end_blocks test only verifies that the callback is called, but does not actually prove that the fire method waits for the callback to complete. Strengthen this test by adding a deliberate time delay inside the SlowCb.on_train_end method, then measure the elapsed wall time around the mgr.fire("on_train_end") call and assert that the elapsed time is at least as long as the injected delay, which will verify that fire actually blocks and waits for the callback to finish executing.
src/instructlab/training/main_ds.py (1)
355-371: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
`on_log` fires only on rank-0, unlike the other hooks.
This hook is nested inside the `if local_rank == 0:` block, so it dispatches only on rank-0, whereas `on_step_begin`, `on_step_end`, `on_epoch_begin`, `on_epoch_end`, `on_save`, and `on_train_end` fire on all ranks. This asymmetry is reasonable since the logging metrics (`current_lr`, `cuda_mem_allocated`, throughput) are computed only on rank-0, but it's an inconsistency callback authors won't expect. Worth documenting in the hook contract that `on_log` is rank-0-only.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/instructlab/training/main_ds.py` around lines 355 - 371, The callback_manager.fire("on_log") call is nested inside the if local_rank == 0: block, making it rank-0-only, unlike other hooks such as on_step_begin, on_step_end, on_epoch_begin, on_epoch_end, on_save, and on_train_end which fire on all ranks. Document this rank-0-only behavior in the callback hook contract or interface documentation to clarify this asymmetry for callback authors who may not expect on_log to behave differently from the other hooks.
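For callback authors, the practical consequence of hooks firing on every rank can be sketched as follows. `FakeContext` and `RankZeroRecorder` are hypothetical stand-ins written so the example runs without the library; only the `is_world_process_zero` attribute and the `on_step_end` hook name are taken from this PR, and the four-rank loop is a simulation, not real distributed dispatch.

```python
from dataclasses import dataclass


@dataclass
class FakeContext:
    """Hypothetical stand-in for the context object passed to hooks."""

    step: int
    is_world_process_zero: bool


class RankZeroRecorder:
    """Shaped like a TrainerCallback subclass, but standalone for illustration."""

    def __init__(self):
        self.records = []

    def on_step_end(self, context):
        # This hook fires on every rank, so guard rank-specific side effects.
        if not context.is_world_process_zero:
            return
        self.records.append(f"step {context.step} logged once")


# Simulate one step dispatched on a 4-rank job (rank 0 is world process zero).
recorders = [RankZeroRecorder() for _ in range(4)]
for rank, recorder in enumerate(recorders):
    recorder.on_step_end(FakeContext(step=7, is_world_process_zero=(rank == 0)))

total_side_effects = sum(len(r.records) for r in recorders)
```

Without the guard, the same side effect (a log line, a checkpoint upload, a job submission) would run once per rank.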
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/instructlab/training/callbacks.py`:
- Around line 219-262: Add documentation to the serialize_callback and
deserialize_callback functions clearly stating that TrainerCallback subclasses
must be self-contained (with no external symbol dependencies) and have
zero-argument constructors. Then enhance serialize_callback to validate these
constraints by attempting to instantiate the callback class with no arguments
and checking for any NameError when the callback is executed, raising
informative errors during serialization in the parent process rather than
allowing silent failures in the worker. Additionally, update the test cases to
cover callbacks with constructor parameters and external dependencies to ensure
the validation catches these cases.
In `@src/instructlab/training/main_ds.py`:
- Around line 448-453: The `_save_and_exit()` function returns early at lines
285 and 309, which bypasses the `on_train_end` callback dispatch and
`callback_manager.close()` cleanup code at lines 451-453. To fix this, ensure
that `on_train_end` and `close()` are invoked even when `_save_and_exit()`
triggers an early exit. Either add these cleanup calls directly in the
`_save_and_exit()` function before each return statement, or wrap the training
logic in a try-finally block that guarantees execution of these callbacks and
cleanup regardless of how the function exits.
- Around line 404-407: The code is accessing args.ckpt_output_dir which does not
exist as a defined argument in the parser, causing an AttributeError crash
during the first checkpoint save. Replace all occurrences of
args.ckpt_output_dir with args.output_dir in the three callback_manager.fire()
calls where checkpoint_path is being set as a parameter. The correct attribute
available from the argument parser is args.output_dir.
- Around line 683-703: The TrainerCallback class docstring lacks critical
documentation about the callback execution model. Add prominent documentation to
the TrainerCallback class that clearly states callbacks fire on all distributed
ranks, not just rank 0, and explicitly require callback authors to check
context.is_world_process_zero before executing any rank-specific side effects
such as checkpointing, logging, or job submission. Include concrete examples of
when this guard is necessary to prevent callback authors from inadvertently
running duplicate operations across ranks.
In `@tests/unit/test_callbacks.py`:
- Line 70: Replace all fixed time.sleep(0.1) calls with event-based
synchronization using threading.Event for deterministic completion checks.
Create a threading.Event object before each test section, set it inside the
callback function to signal completion, and then use event.wait(timeout=0.1)
instead of the sleep call. This applies to all occurrences throughout the test
file at the locations mentioned: lines 70, 84, 109, 121, 138, 151, 192, 238, and
253, where the pattern involves waiting for a callback to complete rather than
using arbitrary timing delays.
- Around line 58-255: The test class TestCallbackManager creates multiple
CallbackManager instances throughout its test methods but does not properly
clean up the background threads they spawn. Add a call to mgr.close() at the end
of each test method in TestCallbackManager (in test_fire_dispatches,
test_fire_skips_non_overridden, test_has_callbacks, test_snapshot_isolation,
test_exception_isolation, test_multiple_callbacks, test_kwargs_set_on_snapshot,
test_add_callback_type_error, test_remove_callback_by_instance,
test_remove_callback_by_type, test_fire_all_ranks, test_on_train_end_blocks,
test_fire_invalid_kwarg_raises, test_empty_manager_no_callbacks,
test_hook_name_set_on_snapshot, and test_dict_fields_snapshot_isolation) to
ensure the background thread is properly terminated after each test completes
and prevent thread leakage.
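The event-based synchronization requested for the `time.sleep(0.1)` calls could look roughly like this. `run_async` is a hypothetical stand-in for the manager's background dispatch, and the 1.0 second timeout is an assumed, deliberately generous upper bound rather than a value from the PR.

```python
import threading


def run_async(fn):
    """Hypothetical stand-in for async dispatch: run fn on a background thread."""
    thread = threading.Thread(target=fn, daemon=True)
    thread.start()
    return thread


done = threading.Event()
called = []


def callback():
    called.append("on_step_end")
    done.set()  # signal completion instead of relying on a fixed sleep


worker = run_async(callback)
# wait() returns True as soon as the event is set, or False on timeout.
finished = done.wait(timeout=1.0)
worker.join(timeout=1.0)
```

The test returns as soon as the callback signals, and a timeout shows up as an explicit `False` rather than a silently empty `called` list.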
---
Nitpick comments:
In `@src/instructlab/training/main_ds.py`:
- Around line 355-371: The callback_manager.fire("on_log") call is nested inside
the if local_rank == 0: block, making it rank-0-only, unlike other hooks such as
on_step_begin, on_step_end, on_epoch_begin, on_epoch_end, on_save, and
on_train_end which fire on all ranks. Document this rank-0-only behavior in the
callback hook contract or interface documentation to clarify this asymmetry for
callback authors who may not expect on_log to behave differently from the other
hooks.
In `@tests/unit/test_callbacks.py`:
- Around line 195-205: The test_on_train_end_blocks test only verifies that the
callback is called, but does not actually prove that the fire method waits for
the callback to complete. Strengthen this test by adding a deliberate time delay
inside the SlowCb.on_train_end method, then measure the elapsed wall time around
the mgr.fire("on_train_end") call and assert that the elapsed time is at least
as long as the injected delay, which will verify that fire actually blocks and
waits for the callback to finish executing.
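A standalone sketch of what "blocks until the callback finishes" means, and how elapsed wall time proves it. `fire_blocking` is a hypothetical helper built on `concurrent.futures`, not the PR's `fire()` implementation; the 10 second timeout mirrors the figure in the PR summary, and the 0.05 second delay mirrors the suggested test.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeout


def fire_blocking(executor, fn, timeout):
    """Run fn on the executor and wait for it, up to timeout seconds."""
    future = executor.submit(fn)
    try:
        future.result(timeout=timeout)
        return True
    except FutureTimeout:
        return False


def slow_cleanup():
    time.sleep(0.05)  # deliberate delay so blocking is observable


executor = ThreadPoolExecutor(max_workers=1)
t0 = time.perf_counter()
finished = fire_blocking(executor, slow_cleanup, timeout=10.0)
elapsed = time.perf_counter() - t0
executor.shutdown(wait=True)
```

If the dispatch were fire-and-forget, `elapsed` would be close to zero even though the callback eventually ran, which is exactly the gap the nitpick points at.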
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: bd59a26c-5aa7-48c2-816a-43015e0fc798
📒 Files selected for processing (6)
- src/instructlab/training/__init__.py
- src/instructlab/training/batch_loss_manager.py
- src/instructlab/training/callbacks.py
- src/instructlab/training/config.py
- src/instructlab/training/main_ds.py
- tests/unit/test_callbacks.py
278713d to d118f17
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/unit/test_callbacks.py`:
- Around line 211-215: In the test_close method, restructure the test to wrap
the CallbackManager instance creation and assertions in a try/finally block,
ensuring that m.close() is always called in the finally block, even if any
assertion fails before it. This prevents the background thread from leaking into
subsequent tests when assertions fail.
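The `try/finally` shape being asked for can be sketched like this. `TinyManager` is a hypothetical stand-in that owns a background thread, as `CallbackManager` is described as doing; the simulated failing assertion is there only to show that `close()` still runs.

```python
import threading


class TinyManager:
    """Hypothetical stand-in: owns a background thread that close() stops."""

    def __init__(self):
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._stop.wait, daemon=True)
        self._thread.start()
        self.closed = False

    def close(self):
        self._stop.set()
        self._thread.join(timeout=1.0)
        self.closed = True


def run_test_body(manager, should_fail):
    try:
        assert manager._thread.is_alive()
        if should_fail:
            raise AssertionError("simulated failing assertion")
    finally:
        # Runs whether or not an assertion above failed.
        manager.close()


m = TinyManager()
try:
    run_test_body(m, should_fail=True)
except AssertionError:
    pass  # the failure still propagates; cleanup has already happened
```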
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 936f5620-2645-4239-b9fb-9b7c9f79ca07
📒 Files selected for processing (6)
- src/instructlab/training/__init__.py
- src/instructlab/training/batch_loss_manager.py
- src/instructlab/training/callbacks.py
- src/instructlab/training/config.py
- src/instructlab/training/main_ds.py
- tests/unit/test_callbacks.py
🚧 Files skipped from review as they are similar to previous changes (4)
- src/instructlab/training/batch_loss_manager.py
- src/instructlab/training/config.py
- src/instructlab/training/__init__.py
- src/instructlab/training/main_ds.py
@RobotSail please review this PR. Thank you.
- H1: Guard fire() against closed event loop (no-op after close())
- H2: Wrap train() call in try/finally to ensure close() on exceptions
- M2: Per-callback snapshot copy to prevent mutation cross-talk
- M3: Drain pending async tasks before stopping event loop in close()
- M6: Log on_train_end callback failures instead of bare except pass

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: multica-agent <github@multica.ai>
on_step_begin(step=N) was followed by on_step_end(step=N+1) because global_step was incremented before on_step_end fired. Move on_step_end before the increment so begin/end see the same step number.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: multica-agent <github@multica.ai>
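The ordering this commit describes can be illustrated with a toy loop. `run_steps` and its `start_step` default are hypothetical; only the hook names and the "fire `on_step_end` before incrementing" rule come from the commit message.

```python
def run_steps(num_steps, fire, start_step=0):
    """Fire begin/end hooks around each step; start_step is an assumed default."""
    global_step = start_step
    for _ in range(num_steps):
        fire("on_step_begin", global_step)
        # ... forward, backward, and optimizer step would happen here ...
        fire("on_step_end", global_step)  # fire BEFORE the increment
        global_step += 1
    return global_step


events = []
final_step = run_steps(3, lambda hook, step: events.append((hook, step)))
```

Each begin/end pair now carries the same step number, so a callback can key per-step state on it without an off-by-one adjustment.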
Adversarial Code Review — 3 independent reviewers (code quality, security, Python/PyTorch)
CRITICAL
1. exec() on CLI-supplied input enables arbitrary code execution — callbacks.py:deserialize_callback
deserialize_callback() passes base64-decoded source to exec(). The --callbacks CLI arg accepts any base64 string. The restricted namespace dict is not a sandbox — exec() has full access to __builtins__ and can import subprocess, read files, etc. The # noqa: S102 suppression acknowledges the linter flag without mitigation. While the comment says "never untrusted input," this is not architecturally enforced. Recommendation: HMAC-sign the serialized payload with a per-session secret so only the parent process's callbacks are accepted, or pass via a temporary file with restricted permissions.
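A minimal sketch of the HMAC recommendation, assuming the parent process generates a per-session key and hands it to the worker over some channel other than the command line (that channel is not specified here and is left open). `sign_payload` and `verify_payload` are hypothetical names, not functions in `callbacks.py`.

```python
import base64
import hashlib
import hmac
import secrets


def sign_payload(source: str, key: bytes) -> str:
    """Return 'hex_tag.base64_body' for the given callback source."""
    body = base64.b64encode(source.encode("utf-8")).decode("ascii")
    tag = hmac.new(key, body.encode("ascii"), hashlib.sha256).hexdigest()
    return f"{tag}.{body}"


def verify_payload(token: str, key: bytes) -> str:
    """Check the tag before decoding; raise ValueError if it does not match."""
    tag, sep, body = token.partition(".")
    if not sep:
        raise ValueError("malformed payload")
    expected = hmac.new(key, body.encode("ascii"), hashlib.sha256).hexdigest()
    # compare_digest avoids leaking the match position through timing.
    if not hmac.compare_digest(tag, expected):
        raise ValueError("signature mismatch")
    return base64.b64decode(body).decode("utf-8")


session_key = secrets.token_bytes(32)  # assumed to be created by the parent
token = sign_payload("class MyCb: pass", session_key)
```

The point of the design is that verification happens before anything reaches `exec()`, so a hand-crafted `--callbacks` value without the session key is rejected. It does not sandbox the code that is accepted.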
MAJOR
2. on_train_end not fired on unexpected exceptions — main_ds.py
If train() throws an unexpected exception (not one of the handled early-return paths), the finally block in main() calls close() but does NOT call fire("on_train_end"). Callbacks needing cleanup (flushing logs, closing connections) won't be notified on crash. The finally block should fire on_train_end before close(), or train() should wrap the loop in its own try/finally.
3. inspect.getsource() fragility — callbacks.py:serialize_callback
Fails with OSError for classes defined in Jupyter notebooks, interactive sessions, dynamically generated classes, or .pyc-only distributions. These are common ML workflows. No error handling or user-facing guidance is provided. At minimum, catch OSError/TypeError and raise a clear error explaining the self-contained source requirement.
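One way the suggested error handling might look. `get_callback_source` is a hypothetical wrapper, not the PR's `serialize_callback`; the class below is created with `type()` and given a deliberately unresolvable `__module__` so that `inspect.getsource` fails deterministically in this sketch.

```python
import inspect


def get_callback_source(cls: type) -> str:
    """Return the class source, or raise a clear error if it is unavailable."""
    try:
        return inspect.getsource(cls)
    except (OSError, TypeError) as exc:
        raise ValueError(
            f"Cannot serialize callback {cls.__qualname__!r}: its source code "
            "is not available. Define the class in a regular .py file."
        ) from exc


# A dynamically generated class with no source file behind it (illustrative).
Dynamic = type("Dynamic", (), {"__module__": "no_such_module_for_demo"})

try:
    get_callback_source(Dynamic)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

Raising in the parent process at serialization time gives the user an actionable message instead of a failure inside the torchrun worker.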
4. No-arg constructor assumption loses instance state — callbacks.py:deserialize_callback
classes[0]() instantiates with zero args. A callback like MyCallback(threshold=0.5) will fail on deserialization, and the original instance state is lost entirely (only source code is serialized, not __dict__). Should validate at serialization time that cls() works, and raise a clear error if not.
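The serialization-time check suggested here could be as simple as attempting the zero-argument call. `require_zero_arg_constructor` is a hypothetical helper; note the assumption that constructing a throwaway instance has no harmful side effects.

```python
def require_zero_arg_constructor(cls: type) -> None:
    """Raise ValueError if cls() cannot be called without arguments."""
    try:
        cls()  # assumption: constructing a throwaway instance is side-effect free
    except TypeError as exc:
        raise ValueError(
            f"{cls.__qualname__} must be constructible with no arguments; "
            "only its source is serialized, so instance state is not carried over."
        ) from exc


class NoArgs:
    pass


class NeedsThreshold:
    def __init__(self, threshold):
        self.threshold = threshold


require_zero_arg_constructor(NoArgs)  # passes silently

try:
    require_zero_arg_constructor(NeedsThreshold)
    rejected = False
except ValueError:
    rejected = True
```

This catches required parameters, but a callback with defaulted parameters would still pass while losing any non-default values, so the documentation requirement remains.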
5. arbitrary_types_allowed=True blast radius — config.py:194
This is a model-wide Pydantic ConfigDict setting affecting ALL fields in TrainingArgs, not just callbacks. It weakens type validation across the entire model. Consider a scoped solution (custom validator, or BeforeValidator).
6. Bare list typing on callbacks field — config.py:415
callbacks: list | None accepts any list ([42, "hello"]) without validation. Should be list[TrainerCallback] | None so type checkers and Pydantic catch invalid inputs at construction time.
7. Early-return callback cleanup is fragile and error-prone — main_ds.py
The fire("on_train_end") + close() pattern is manually duplicated at 4 early-return points in train(). Any new early-return added later must remember to duplicate this, or on_train_end won't fire. A single try/finally inside train() would be more robust.
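Findings 2 and 7 point at the same remedy: a single `try/finally` inside `train()`. The sketch below uses a hypothetical `RecordingManager` and a toy `train()` with simulated early-return and crash paths; it is not the code in `main_ds.py`.

```python
class RecordingManager:
    """Hypothetical stand-in for CallbackManager that records what was called."""

    def __init__(self):
        self.events = []

    def fire(self, hook_name, **kwargs):
        self.events.append(hook_name)

    def close(self):
        self.events.append("close")


def train(manager, steps, stop_at=None, crash_at=None):
    try:
        for step in range(steps):
            if step == crash_at:
                raise RuntimeError("simulated crash")
            if step == stop_at:
                return "early-exit"  # no manual cleanup needed here
            manager.fire("on_step_end", step=step)
        return "completed"
    finally:
        # One place covers normal completion, early returns, and exceptions.
        manager.fire("on_train_end")
        manager.close()


normal, early, crashed = RecordingManager(), RecordingManager(), RecordingManager()
normal_result = train(normal, steps=2)
early_result = train(early, steps=5, stop_at=1)
try:
    train(crashed, steps=5, crash_at=1)
    crash_propagated = False
except RuntimeError:
    crash_propagated = True
```

New early-return paths added later inherit the cleanup automatically, which is the robustness argument made in finding 7.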
8. Context fields populated inconsistently across ranks — main_ds.py
learning_rate, grad_norm, elapsed_time, overall_throughput, cuda_mem_allocated are only set inside if local_rank == 0: but callbacks fire on all ranks. Non-zero ranks see stale None defaults for these fields. This is partially documented but will confuse callback authors.
MINOR
9. HOOK_NAMES is a list used for O(n) membership testing on hot path — callbacks.py:108
fire() does if hook_name not in HOOK_NAMES every call. Use a frozenset for O(1) lookups.
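A sketch of the `frozenset` variant. The set below is partial: it lists only the hook names mentioned in this review, whereas the PR defines more. `check_hook_name` is a hypothetical helper.

```python
# Partial, illustrative set: only hook names that appear in this review.
HOOK_NAMES = frozenset(
    {
        "on_step_begin",
        "on_step_end",
        "on_epoch_begin",
        "on_epoch_end",
        "on_save",
        "on_log",
        "on_train_end",
    }
)


def check_hook_name(hook_name: str) -> str:
    """Hash-based membership test; also makes the collection immutable."""
    if hook_name not in HOOK_NAMES:
        raise ValueError(f"unknown hook: {hook_name!r}")
    return hook_name
```

With around a dozen names the speed difference is small in absolute terms; immutability of the constant is arguably the more useful property.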
10. Double base64 encoding inflates payload ~77% — callbacks.py
serialize_callback base64-encodes each source, then serialize_callbacks_for_cli JSON-encodes the list and base64-encodes again. The inner encoding is unnecessary — JSON can carry source strings directly.
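The ~77% figure follows from base64's 4/3 expansion applied twice (4/3 × 4/3 = 16/9 ≈ 1.78). The arithmetic can be checked with a synthetic payload; the 3000-character source string is an arbitrary stand-in, and real source would add some JSON escaping overhead for newlines and quotes.

```python
import base64
import json

source = "x" * 3000  # synthetic stand-in for callback source text

# Current scheme as described: base64 each source, JSON-encode, base64 again.
inner = base64.b64encode(source.encode("utf-8")).decode("ascii")
double_encoded = base64.b64encode(json.dumps([inner]).encode("utf-8"))

# Suggested scheme: JSON carries the source directly, then one base64 pass.
single_encoded = base64.b64encode(json.dumps([source]).encode("utf-8"))

double_ratio = len(double_encoded) / len(source)
single_ratio = len(single_encoded) / len(source)
```

This matters mainly because the payload travels as a command-line argument, where length limits apply.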
11. No type annotations on callback_manager parameters — main_ds.py:174, batch_loss_manager.py:38
Both train() and BatchLossManager.__init__() lack type hints: should be callback_manager: CallbackManager | None = None.
12. on_step_end fires before global_step increment — main_ds.py
Both on_step_begin and on_step_end see the same context.step value. This is consistent but differs from HuggingFace Transformers convention — worth documenting.
13. Process argument visibility — main_ds.py
Log redaction of --callbacks=<redacted> is good, but the actual command line (with full callback source) remains visible via /proc/<pid>/cmdline and ps. Consider passing via temp file or env var.
14. copy.copy() shallow copy may break snapshot isolation for nested mutables — callbacks.py:262
Currently safe (values are primitives), but dict[str, Any] typing allows mutable values. A future change storing tensors/lists as metric values would silently break isolation.
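The isolation concern can be reproduced with a plain dictionary. The `metrics` contents below are invented for illustration; the behavior shown is that of the standard `copy` module, not of the PR's snapshot code.

```python
import copy

metrics = {"loss": 0.5, "grad_history": [1.0]}  # illustrative values

shallow = copy.copy(metrics)
deep = copy.deepcopy(metrics)

metrics["loss"] = 0.4                # rebinding a key: neither copy sees it
metrics["grad_history"].append(2.0)  # mutating a nested list: shallow sees it
```

After these two mutations, `shallow["grad_history"]` has grown while `deep["grad_history"]` has not. `deepcopy` restores isolation at a cost that may be significant for large objects, so it is a trade-off rather than a free fix.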
Re-reviewed after fix commits 094a30f, 88d334a, c6578c5, and 5299b05. The author addressed 12 of 14 findings and provided sound justification for the 2 "by design" items (exec() trust boundary, bare list to avoid circular import).
Key fixes verified:
- Per-callback snapshot isolation (c6578c5)
- `close()` idempotent + pending task drain (094a30f, c6578c5)
- `fire()` guards closed loop (c6578c5)
- `try/finally` in `main()` for crash safety (c6578c5)
- `on_log` context fields moved outside rank-0 block (094a30f)
- `total_tokens` now cumulative (094a30f)
- `on_step_end` fires before `global_step` increment (5299b05)
- `checkpoint_path` points to actual subdirectory (094a30f)
- Source removed from error messages (094a30f)
- `--callbacks` redacted from log output (88d334a)
Remaining minor items (HOOK_NAMES as list, arbitrary_types_allowed blast radius) are low-risk. All 3 reviewers passed.
E2E Distributed Training Validation - Callbacks
Ran end-to-end distributed training on 4x A100-SXM4-80GB via FSDP with real custom callbacks to validate the callback mechanism works in a distributed setting.
Commit tested:
Test Setup
Both callbacks were serialized across the torchrun subprocess boundary (via the
Results: PASS ✅
All 12 hooks fired on all 4 ranks (96 invocations per rank):
What was validated:
Unit Tests + Lint: PASS ✅
Also ran the full CI suite via tox:
One observation (not a blocker)
@Mergifyio queue
☑️ Command disallowed due to command restrictions in the Mergify configuration.
Summary
Adds a callback mechanism to the InstructLab Training library, enabling users and Training Hub to hook into training lifecycle events without modifying source code.
- `TrainerCallback` base class with 13 lifecycle hooks that users subclass and override
- `context.is_world_process_zero` or `context.is_local_process_zero`
- `inspect.getsource` + base64 encoding
- `on_train_end` blocks up to 10 seconds to allow cleanup callbacks to finish before process exit
Files changed
- `src/instructlab/training/callbacks.py`
- `src/instructlab/training/config.py`: `callbacks` field to TrainingArgs
- `src/instructlab/training/batch_loss_manager.py`
- `src/instructlab/training/main_ds.py`
- `src/instructlab/training/__init__.py`
- `tests/unit/test_callbacks.py`
Test plan
Resolves #694