From c56e0a79a564ec37ec892ad95565d025e0ad7b9c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 6 Aug 2024 10:31:50 -0400 Subject: [PATCH] logging improvements (#1808) [skip ci] * logging improvements * fix sort --- src/axolotl/utils/data/sft.py | 7 ++++++- src/axolotl/utils/distributed.py | 6 +++--- src/axolotl/utils/models.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 97061cc62..1b6df1cde 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -187,7 +187,12 @@ def load_tokenized_prepared_datasets( else: if cfg.push_dataset_to_hub: LOG.info("Unable to find prepared dataset in Huggingface hub") - LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") + if cfg.is_preprocess: + LOG.info( + f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..." + ) + else: + LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") LOG.info("Loading raw datasets...") if not cfg.is_preprocess: LOG.warning( diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index 4444a20c9..3a559f5f5 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -153,11 +153,11 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name if is_main_process(): value_scalar = fn() value_tensor = torch.tensor( - value_scalar, device=torch.cuda.current_device() - ).float() + value_scalar, device=torch.cuda.current_device(), dtype=torch.float32 + ) else: value_tensor = torch.tensor( - 0.0, device=torch.cuda.current_device() + 0.0, device=torch.cuda.current_device(), dtype=torch.float32 ) # Placeholder tensor # Broadcast the tensor to all processes. diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 87f50d9a2..1e9819c56 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1022,7 +1022,7 @@ def load_lora(model, cfg, inference=False, config_only=False): if cfg.lora_target_linear: linear_names = find_all_linear_names(model) - LOG.info(f"found linear modules: {repr(linear_names)}") + LOG.info(f"found linear modules: {repr(sorted(linear_names))}") lora_target_modules = list(set(lora_target_modules + linear_names)) lora_config_kwargs = {}