diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index ffe4699f8..21b14d986 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations import gc +import json import logging import os import traceback @@ -808,11 +809,44 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback): artifact.add_file(temp_file.name) wandb.log_artifact(artifact) wandb.save(temp_file.name) - LOG.info( - "The Axolotl config has been saved to the WandB run under files." - ) + LOG.info( + "The Axolotl config has been saved to the WandB run under files." + ) except (FileNotFoundError, ConnectionError) as err: LOG.warning(f"Error while saving Axolotl config to WandB: {err}") + + if args.deepspeed: + try: + # sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later. + with NamedTemporaryFile( + mode="w", + delete=False, + suffix=".json", + prefix="deepspeed_config_", + ) as temp_file: + skip_upload = False + if isinstance(args.deepspeed, dict): + json.dump(args.deepspeed, temp_file, indent=4) + elif isinstance(args.deepspeed, str) and os.path.exists( + args.deepspeed + ): + copyfile(args.deepspeed, temp_file.name) + else: + skip_upload = True + if not skip_upload: + artifact = wandb.Artifact( + f"deepspeed-config-{wandb.run.id}", + type="deepspeed-config", + ) + artifact.add_file(temp_file.name) + wandb.log_artifact(artifact) + wandb.save(temp_file.name) + LOG.info( + "The DeepSpeed config has been saved to the WandB run under files." + ) + except (FileNotFoundError, ConnectionError) as err: + LOG.warning(f"Error while saving DeepSpeed config to WandB: {err}") + return control