Compare commits
2 Commits
phi-moe
...
transforme
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60763b2e61 | ||
|
|
082a41af9d |
2
.github/workflows/base.yml
vendored
2
.github/workflows/base.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
- cuda: "124"
|
- cuda: "124"
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.10"
|
python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
- cuda: "124"
|
- cuda: "124"
|
||||||
|
|||||||
3
.github/workflows/main.yml
vendored
3
.github/workflows/main.yml
vendored
@@ -114,9 +114,6 @@ jobs:
|
|||||||
images: |
|
images: |
|
||||||
winglian/axolotl-cloud
|
winglian/axolotl-cloud
|
||||||
axolotlai/axolotl-cloud
|
axolotlai/axolotl-cloud
|
||||||
tags: |
|
|
||||||
type=ref,event=branch
|
|
||||||
type=semver,pattern={{version}}
|
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
|
|||||||
5
.github/workflows/multi-gpu-e2e.yml
vendored
5
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -8,11 +8,6 @@ on:
|
|||||||
schedule:
|
schedule:
|
||||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||||
|
|
||||||
# Cancel jobs on the same ref if a new one is triggered
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}
|
|
||||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test-axolotl-multigpu:
|
test-axolotl-multigpu:
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
|
|||||||
1
.github/workflows/tests-nightly.yml
vendored
1
.github/workflows/tests-nightly.yml
vendored
@@ -48,7 +48,6 @@ jobs:
|
|||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
||||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
||||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
||||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
|
|||||||
@@ -2,4 +2,4 @@
|
|||||||
set -e
|
set -e
|
||||||
|
|
||||||
# only run one test at a time so as not to OOM the GPU
|
# only run one test at a time so as not to OOM the GPU
|
||||||
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
|
pytest -n1 /workspace/axolotl/tests/e2e/multigpu/
|
||||||
|
|||||||
@@ -91,7 +91,6 @@ datasets:
|
|||||||
name: # Optional[str] name of dataset configuration to load
|
name: # Optional[str] name of dataset configuration to load
|
||||||
train_on_split: train # Optional[str] name of dataset split to load from
|
train_on_split: train # Optional[str] name of dataset split to load from
|
||||||
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
||||||
trust_remote_code: # Optional[bool] Trust remote code for untrusted source
|
|
||||||
|
|
||||||
# Custom user instruction prompt
|
# Custom user instruction prompt
|
||||||
- path: repo
|
- path: repo
|
||||||
|
|||||||
@@ -1,67 +0,0 @@
|
|||||||
base_model: Qwen/Qwen2.5-0.5B
|
|
||||||
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
chat_template: qwen_25
|
|
||||||
rl: dpo
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
|
||||||
type: chat_template.default
|
|
||||||
field_messages: conversation
|
|
||||||
field_chosen: chosen
|
|
||||||
field_rejected: rejected
|
|
||||||
message_field_role: role
|
|
||||||
message_field_content: content
|
|
||||||
roles:
|
|
||||||
system:
|
|
||||||
- system
|
|
||||||
user:
|
|
||||||
- user
|
|
||||||
assistant:
|
|
||||||
- assistant
|
|
||||||
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/dpo-out
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
evals_per_epoch: 4
|
|
||||||
eval_table_size:
|
|
||||||
eval_max_new_tokens: 128
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.13.2
|
peft==0.13.2
|
||||||
transformers==4.46.2
|
transformers==4.46.1
|
||||||
tokenizers>=0.20.1
|
tokenizers>=0.20.1
|
||||||
bitsandbytes==0.44.1
|
bitsandbytes==0.44.1
|
||||||
accelerate==1.1.0
|
accelerate==1.1.0
|
||||||
datasets==3.1.0
|
datasets==3.0.1
|
||||||
deepspeed==0.15.3
|
deepspeed==0.15.3
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
addict
|
addict
|
||||||
@@ -53,4 +53,3 @@ immutabledict==4.2.0
|
|||||||
antlr4-python3-runtime==4.13.2
|
antlr4-python3-runtime==4.13.2
|
||||||
|
|
||||||
torchao==0.5.0
|
torchao==0.5.0
|
||||||
schedulefree==1.3.0
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
# Export specific ENV variables to /etc/rp_environment
|
# Export specific ENV variables to /etc/rp_environment
|
||||||
echo "Exporting environment variables..."
|
echo "Exporting environment variables..."
|
||||||
printenv | grep -E '^HF_|^BNB_|^CUDA_|^NCCL_|^NV|^RUNPOD_|^PATH=|^_=' | sed 's/^\([^=]*\)=\(.*\)$/export \1="\2"/' | grep -v 'printenv' >> /etc/rp_environment
|
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
|
||||||
echo 'source /etc/rp_environment' >> ~/.bashrc
|
echo 'source /etc/rp_environment' >> ~/.bashrc
|
||||||
|
|
||||||
add_keys_to_authorized() {
|
add_keys_to_authorized() {
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ MOE_ARCH_BLOCK = {
|
|||||||
"JetMoeMoE",
|
"JetMoeMoE",
|
||||||
],
|
],
|
||||||
"mixtral": "MixtralSparseMoeBlock",
|
"mixtral": "MixtralSparseMoeBlock",
|
||||||
"phimoe": "PhiMoESparseMoeBlock",
|
|
||||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||||
"deepseek_v2": "DeepseekV2MoE",
|
"deepseek_v2": "DeepseekV2MoE",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1038,37 +1038,24 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def tokenize_row(
|
def tokenize_row(
|
||||||
|
self,
|
||||||
features,
|
features,
|
||||||
processing_class,
|
processing_class,
|
||||||
max_prompt_length,
|
max_prompt_length,
|
||||||
max_completion_length,
|
max_completion_length,
|
||||||
add_special_tokens,
|
add_special_tokens,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
res = DPOTrainer.tokenize_row(
|
res = super().tokenize_row(
|
||||||
features,
|
features,
|
||||||
processing_class,
|
processing_class,
|
||||||
max_prompt_length,
|
max_prompt_length,
|
||||||
max_completion_length,
|
max_completion_length,
|
||||||
add_special_tokens,
|
add_special_tokens,
|
||||||
)
|
)
|
||||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
|
||||||
for key in res.keys():
|
for key in res.keys():
|
||||||
res[key] = res[key][1:]
|
res[key] = res[key][1:]
|
||||||
|
|
||||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
|
||||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
|
||||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
|
||||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
|
||||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
|
||||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
|
||||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
|
||||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
|
||||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
|
||||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def training_step(
|
def training_step(
|
||||||
@@ -1429,15 +1416,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||||
# no eval set, so don't eval
|
# no eval set, so don't eval
|
||||||
training_arguments_kwargs["eval_strategy"] = "no"
|
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||||
elif self.cfg.eval_steps:
|
elif self.cfg.eval_steps:
|
||||||
training_arguments_kwargs["eval_strategy"] = "steps"
|
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
elif self.cfg.eval_strategy:
|
elif self.cfg.evaluation_strategy:
|
||||||
training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
training_arguments_kwargs[
|
||||||
|
"evaluation_strategy"
|
||||||
|
] = self.cfg.evaluation_strategy
|
||||||
else:
|
else:
|
||||||
# we have an eval set, but no steps defined, default to use epoch
|
# we have an eval set, but no steps defined, default to use epoch
|
||||||
training_arguments_kwargs["eval_strategy"] = "epoch"
|
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||||
|
|
||||||
if self.cfg.save_steps:
|
if self.cfg.save_steps:
|
||||||
training_arguments_kwargs["save_strategy"] = "steps"
|
training_arguments_kwargs["save_strategy"] = "steps"
|
||||||
@@ -1871,10 +1860,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||||
|
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
training_args_kwargs["eval_strategy"] = "steps"
|
training_args_kwargs["evaluation_strategy"] = "steps"
|
||||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["eval_strategy"] = "no"
|
training_args_kwargs["evaluation_strategy"] = "no"
|
||||||
|
|
||||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||||
training_args_kwargs["bf16"] = True
|
training_args_kwargs["bf16"] = True
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
"""multipack patching for v2 of sample packing"""
|
"""multipack patching for v2 of sample packing"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
@@ -20,7 +19,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"falcon",
|
"falcon",
|
||||||
"phi",
|
"phi",
|
||||||
"phi3",
|
"phi3",
|
||||||
"phimoe",
|
|
||||||
"gemma",
|
"gemma",
|
||||||
"gemma2",
|
"gemma2",
|
||||||
"gemmoe",
|
"gemmoe",
|
||||||
@@ -29,28 +27,71 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
|
||||||
if has_remote_code:
|
if model_type == "gemmoe":
|
||||||
patch_remote(model_name)
|
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
|
||||||
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
elif model_type == "deepseek_v2":
|
||||||
|
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
|
||||||
|
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
|
||||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
|
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||||
|
patch_mixtral_moe_forward_zero3()
|
||||||
|
return
|
||||||
|
|
||||||
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
# retain for legacy
|
||||||
patch_mixtral_moe_forward_zero3()
|
if model_type == "mixtral":
|
||||||
|
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
patch_mixtral_moe_forward_zero3()
|
||||||
|
elif model_type == "llama":
|
||||||
|
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
|
||||||
|
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "mistral":
|
||||||
|
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
|
||||||
|
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "qwen2":
|
||||||
|
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "qwen2_moe":
|
||||||
|
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "falcon":
|
||||||
|
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "phi":
|
||||||
|
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "gemma":
|
||||||
|
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "gemma2":
|
||||||
|
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "starcoder2":
|
||||||
|
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def patch_remote(model_name):
|
def patch_remote(model_name, config_name, modeling_name):
|
||||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||||
# we need to load the model here in order for modeling_* to be available
|
# we need to load the model here in order for modeling_* to be available
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||||
parts = model_config.__class__.__module__.split(".")
|
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
|
||||||
parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
|
|
||||||
module_name = ".".join(parts)
|
|
||||||
modeling_arch = importlib.import_module(module_name)
|
modeling_arch = importlib.import_module(module_name)
|
||||||
if hasattr(modeling_arch, "_get_unpad_data"):
|
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
|
||||||
modeling_arch._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,83 +0,0 @@
|
|||||||
"""
|
|
||||||
fix for FSDP gradient accumulation
|
|
||||||
see https://github.com/huggingface/transformers/pull/34645
|
|
||||||
"""
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
from accelerate.logging import get_logger
|
|
||||||
from transformers.trainer import Trainer
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
|
||||||
|
|
||||||
LOG = get_logger("axolotl.monkeypatch.trainer_fsdp_grad_accumulation")
|
|
||||||
|
|
||||||
ORIGINAL_CONTEXT_CODE = """
|
|
||||||
context = (
|
|
||||||
functools.partial(self.accelerator.no_sync, model=model)
|
|
||||||
if i == len(batch_samples) - 1
|
|
||||||
else contextlib.nullcontext
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATCHED_CONTEXT_CODE = """
|
|
||||||
context = (
|
|
||||||
functools.partial(self.accelerator.no_sync, model=model)
|
|
||||||
if i != len(batch_samples) - 1
|
|
||||||
else contextlib.nullcontext
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_training_loop_code() -> str:
|
|
||||||
training_loop = inspect.getsource(
|
|
||||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
|
||||||
)
|
|
||||||
return training_loop
|
|
||||||
|
|
||||||
|
|
||||||
def check_training_loop_is_patchable() -> bool:
|
|
||||||
train_loop = get_training_loop_code()
|
|
||||||
train_loop, _ = detab_code(train_loop)
|
|
||||||
return ORIGINAL_CONTEXT_CODE in train_loop
|
|
||||||
|
|
||||||
|
|
||||||
def patch_training_loop_for_fsdp_grad_accum():
|
|
||||||
"""
|
|
||||||
monkeypatch for fixing the training loop for FSDP gradient accumulation
|
|
||||||
"""
|
|
||||||
|
|
||||||
train_loop = get_training_loop_code()
|
|
||||||
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
|
|
||||||
train_loop
|
|
||||||
)
|
|
||||||
train_loop, _ = detab_code(train_loop)
|
|
||||||
assert (
|
|
||||||
ORIGINAL_CONTEXT_CODE in train_loop
|
|
||||||
), "Original _inner_training_loop code not found"
|
|
||||||
|
|
||||||
train_loop = train_loop.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
|
|
||||||
train_loop = train_loop.replace(
|
|
||||||
"def _inner_training_loop(",
|
|
||||||
"def _fixed_inner_training_loop(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# load imports necessary
|
|
||||||
import transformers.trainer
|
|
||||||
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(transformers.trainer):
|
|
||||||
if item in train_loop:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
"from transformers.trainer import ("
|
|
||||||
+ ", ".join(x for x in items_to_import)
|
|
||||||
+ ")",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
exec(train_loop, globals()) # pylint: disable=exec-used # nosec B102
|
|
||||||
LOG.info("patching _inner_training_loop", main_process_only=True)
|
|
||||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
|
||||||
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
|
|
||||||
)
|
|
||||||
@@ -64,7 +64,10 @@ class EvalFirstStepCallback(
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
if (
|
||||||
|
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||||
|
and state.global_step == 1
|
||||||
|
):
|
||||||
control.should_evaluate = True
|
control.should_evaluate = True
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""Module for working with config dicts"""
|
"""Module for working with config dicts"""
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -8,6 +10,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
|
|
||||||
from axolotl.integrations.config import merge_input_args
|
from axolotl.integrations.config import merge_input_args
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
@@ -244,3 +247,370 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
|||||||
return DictDefault(
|
return DictDefault(
|
||||||
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
|
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def legacy_validate_config(cfg):
|
||||||
|
"""
|
||||||
|
This is a "pre-validation" step that handles the yaml configuration before we have any
|
||||||
|
information about the model architecture
|
||||||
|
"""
|
||||||
|
if is_torch_bf16_gpu_available():
|
||||||
|
if not cfg.bf16 and not cfg.bfloat16:
|
||||||
|
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
not cfg.merge_lora
|
||||||
|
and not cfg.is_preprocess
|
||||||
|
and (cfg.bf16 is True or cfg.bfloat16 is True)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
# pylint: disable=too-many-boolean-expressions
|
||||||
|
not (cfg.bf16 or cfg.bfloat16)
|
||||||
|
and (cfg.fp16 or cfg.float16)
|
||||||
|
and not cfg.adapter
|
||||||
|
and not cfg.flash_attention
|
||||||
|
and cfg.sample_packing
|
||||||
|
):
|
||||||
|
LOG.warning(
|
||||||
|
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
|
||||||
|
)
|
||||||
|
# ValueError: Attempting to unscale FP16 gradients.
|
||||||
|
# OR
|
||||||
|
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
|
||||||
|
if cfg.max_packed_sequence_len:
|
||||||
|
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
||||||
|
|
||||||
|
if cfg.sample_packing and cfg.rl:
|
||||||
|
raise ValueError("`sample_packing: true` does not work with RLHF training")
|
||||||
|
|
||||||
|
if cfg.sample_packing and not cfg.pad_to_sequence_len:
|
||||||
|
LOG.warning(
|
||||||
|
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
"please set only one of gradient_accumulation_steps or batch_size"
|
||||||
|
)
|
||||||
|
if cfg.batch_size:
|
||||||
|
LOG.warning(
|
||||||
|
"%s\n%s",
|
||||||
|
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||||
|
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
cfg.eval_batch_size
|
||||||
|
and cfg.micro_batch_size
|
||||||
|
and cfg.eval_batch_size != cfg.micro_batch_size
|
||||||
|
):
|
||||||
|
LOG.warning(
|
||||||
|
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.adapter == "qlora":
|
||||||
|
if cfg.merge_lora:
|
||||||
|
# can't merge qlora if loaded in 8bit or 4bit
|
||||||
|
if cfg.load_in_8bit:
|
||||||
|
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||||
|
|
||||||
|
if cfg.gptq:
|
||||||
|
raise ValueError("Can't merge qlora if gptq")
|
||||||
|
|
||||||
|
if cfg.load_in_4bit:
|
||||||
|
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||||
|
|
||||||
|
else:
|
||||||
|
if cfg.load_in_8bit:
|
||||||
|
raise ValueError("Can't load qlora in 8bit")
|
||||||
|
|
||||||
|
if cfg.gptq:
|
||||||
|
raise ValueError("Can't load qlora if gptq")
|
||||||
|
|
||||||
|
if not cfg.load_in_4bit:
|
||||||
|
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||||
|
|
||||||
|
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||||
|
raise ValueError("Fused modules are not supported with QLoRA")
|
||||||
|
|
||||||
|
loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
|
||||||
|
if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq:
|
||||||
|
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||||
|
|
||||||
|
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
||||||
|
raise ValueError("Fused modules are not supported with LoRA")
|
||||||
|
|
||||||
|
if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
|
||||||
|
raise ValueError(
|
||||||
|
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.relora_steps:
|
||||||
|
if cfg.adapter not in ("lora", "qlora"):
|
||||||
|
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||||
|
|
||||||
|
if cfg.fsdp:
|
||||||
|
raise ValueError("fsdp not supported with ReLoRA")
|
||||||
|
|
||||||
|
if cfg.deepspeed:
|
||||||
|
raise ValueError("deepspeed not supported with ReLoRA")
|
||||||
|
|
||||||
|
if cfg.lr_scheduler == "one_cycle":
|
||||||
|
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
||||||
|
|
||||||
|
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||||
|
raise ValueError("Fused modules are not supported with ReLoRA")
|
||||||
|
|
||||||
|
if cfg.trust_remote_code:
|
||||||
|
LOG.warning(
|
||||||
|
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
|
||||||
|
raise ValueError(
|
||||||
|
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
||||||
|
raise ValueError("FSDP is not supported for falcon models")
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.base_model and "mpt" in cfg.base_model.lower()
|
||||||
|
) and cfg.gradient_checkpointing:
|
||||||
|
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||||
|
|
||||||
|
if cfg.flash_optimum is True:
|
||||||
|
if cfg.adapter:
|
||||||
|
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
||||||
|
if cfg.fp16 or cfg.bf16:
|
||||||
|
raise ValueError("AMP is not supported with BetterTransformer")
|
||||||
|
if cfg.float16 is not True and cfg.bfloat16 is not True:
|
||||||
|
LOG.warning(
|
||||||
|
"You should probably set bfloat16 or float16 to true to "
|
||||||
|
"load the model in float16 for BetterTransformers"
|
||||||
|
)
|
||||||
|
if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
|
||||||
|
LOG.warning("torch>=2.0.0 required")
|
||||||
|
raise ValueError(
|
||||||
|
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.pretraining_dataset and cfg.group_by_length:
|
||||||
|
LOG.warning(
|
||||||
|
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
||||||
|
)
|
||||||
|
if cfg.pretraining_dataset and not cfg.max_steps:
|
||||||
|
raise ValueError(
|
||||||
|
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
|
||||||
|
)
|
||||||
|
|
||||||
|
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
||||||
|
not cfg.optimizer or "adamw" not in cfg.optimizer
|
||||||
|
):
|
||||||
|
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||||
|
|
||||||
|
if cfg.push_to_hub_model_id:
|
||||||
|
raise ValueError(
|
||||||
|
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
|
||||||
|
LOG.warning(
|
||||||
|
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.gptq and cfg.revision_of_model:
|
||||||
|
raise ValueError(
|
||||||
|
"revision_of_model is not supported for GPTQ models. "
|
||||||
|
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||||
|
+ "point to its path, and remove revision_of_model from the config."
|
||||||
|
)
|
||||||
|
|
||||||
|
# if cfg.sample_packing and cfg.sdp_attention:
|
||||||
|
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||||
|
# raise ValueError(
|
||||||
|
# "sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||||
|
# )
|
||||||
|
|
||||||
|
if cfg.sample_packing and cfg.xformers_attention:
|
||||||
|
raise ValueError(
|
||||||
|
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
|
||||||
|
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
|
||||||
|
LOG.warning(
|
||||||
|
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
|
||||||
|
"This may work on H100s."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.early_stopping_patience:
|
||||||
|
if not cfg.save_steps or not cfg.eval_steps:
|
||||||
|
raise ValueError(
|
||||||
|
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
|
||||||
|
)
|
||||||
|
if cfg.save_steps % cfg.eval_steps != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.saves_per_epoch and cfg.save_steps:
|
||||||
|
raise ValueError(
|
||||||
|
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||||
|
)
|
||||||
|
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
|
||||||
|
raise ValueError(
|
||||||
|
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||||
|
)
|
||||||
|
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||||
|
raise ValueError(
|
||||||
|
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||||
|
)
|
||||||
|
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||||
|
raise ValueError(
|
||||||
|
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
cfg.evals_per_epoch
|
||||||
|
and cfg.evaluation_strategy
|
||||||
|
and cfg.evaluation_strategy != "steps"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
cfg.evaluation_strategy
|
||||||
|
and cfg.eval_steps
|
||||||
|
and cfg.evaluation_strategy != "steps"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.val_set_size == 0
|
||||||
|
and (cfg.eval_steps or cfg.evaluation_strategy)
|
||||||
|
and not cfg.test_datasets
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.sample_packing
|
||||||
|
and cfg.eval_table_size
|
||||||
|
and cfg.eval_sample_packing is not False
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
|
||||||
|
raise ValueError(
|
||||||
|
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
||||||
|
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.rope_scaling:
|
||||||
|
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
||||||
|
|
||||||
|
if cfg.wandb_run_id and not cfg.wandb_name:
|
||||||
|
cfg.wandb_name = cfg.wandb_run_id
|
||||||
|
|
||||||
|
LOG.warning(
|
||||||
|
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.noisy_embedding_alpha is not None:
|
||||||
|
# Deprecated, use neftune_noise_alpha
|
||||||
|
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
||||||
|
if cfg.neftune_noise_alpha is None:
|
||||||
|
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
||||||
|
else:
|
||||||
|
# User is providing both; bail and have them sort out their settings
|
||||||
|
raise ValueError(
|
||||||
|
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||||
|
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||||
|
|
||||||
|
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.unfrozen_parameters
|
||||||
|
and cfg.gradient_checkpointing_kwargs
|
||||||
|
and cfg.gradient_checkpointing_kwargs.use_reentrant is True
|
||||||
|
):
|
||||||
|
# https://github.com/huggingface/transformers/issues/21381
|
||||||
|
raise ValueError(
|
||||||
|
"`use_reentrant` must be false when used with partially frozen model."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.deepspeed and Path(cfg.deepspeed).is_file():
|
||||||
|
with open(cfg.deepspeed, encoding="utf-8") as file:
|
||||||
|
contents = file.read()
|
||||||
|
deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
|
||||||
|
if cfg.flash_attention:
|
||||||
|
if (
|
||||||
|
deepspeed_cfg.zero_optimization
|
||||||
|
and deepspeed_cfg.zero_optimization.stage == 3
|
||||||
|
):
|
||||||
|
if not (
|
||||||
|
(
|
||||||
|
deepspeed_cfg.bf16
|
||||||
|
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
deepspeed_cfg.fp16
|
||||||
|
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
|
||||||
|
)
|
||||||
|
if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer:
|
||||||
|
LOG.warning(
|
||||||
|
f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.test_datasets and cfg.val_set_size:
|
||||||
|
raise ValueError(
|
||||||
|
"non-zero val_set_size should not be used with test_datasets configuration"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.fsdp and "bnb" in cfg.optimizer:
|
||||||
|
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
|
||||||
|
|
||||||
|
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
|
||||||
|
raise ValueError(
|
||||||
|
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.eval_causal_lm_metrics:
|
||||||
|
if not isinstance(cfg.eval_causal_lm_metrics, list):
|
||||||
|
raise ValueError("eval_causal_lm_metrics must be a list")
|
||||||
|
# only ["sacrebleu", "comet", "ter", "chrf"] supported
|
||||||
|
if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
|
||||||
|
raise ValueError(
|
||||||
|
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
# MPT 7b
|
||||||
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
# no 8bit adaAmw w bf16
|
||||||
|
|
||||||
|
# GPT-NeoX
|
||||||
|
# evals broken when extending context len
|
||||||
|
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||||
|
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
|
||||||
|
# attention_mask = causal_mask + attention_mask
|
||||||
|
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
|
||||||
|
|||||||
@@ -68,7 +68,6 @@ class DeprecatedParameters(BaseModel):
|
|||||||
rope_scaling: Optional[Any] = None
|
rope_scaling: Optional[Any] = None
|
||||||
noisy_embedding_alpha: Optional[float] = None
|
noisy_embedding_alpha: Optional[float] = None
|
||||||
dpo_beta: Optional[float] = None
|
dpo_beta: Optional[float] = None
|
||||||
evaluation_strategy: Optional[str] = None
|
|
||||||
|
|
||||||
@field_validator("max_packed_sequence_len")
|
@field_validator("max_packed_sequence_len")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -100,13 +99,6 @@ class DeprecatedParameters(BaseModel):
|
|||||||
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
||||||
return dpo_beta
|
return dpo_beta
|
||||||
|
|
||||||
@field_validator("evaluation_strategy")
|
|
||||||
@classmethod
|
|
||||||
def validate_evaluation_strategy(cls, evaluation_strategy):
|
|
||||||
if evaluation_strategy is not None:
|
|
||||||
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
|
|
||||||
return evaluation_strategy
|
|
||||||
|
|
||||||
|
|
||||||
class RemappedParameters(BaseModel):
|
class RemappedParameters(BaseModel):
|
||||||
"""parameters that have been remapped to other names"""
|
"""parameters that have been remapped to other names"""
|
||||||
@@ -739,7 +731,7 @@ class AxolotlInputConfig(
|
|||||||
warmup_ratio: Optional[float] = None
|
warmup_ratio: Optional[float] = None
|
||||||
eval_steps: Optional[Union[int, float]] = None
|
eval_steps: Optional[Union[int, float]] = None
|
||||||
evals_per_epoch: Optional[Union[int]] = None
|
evals_per_epoch: Optional[Union[int]] = None
|
||||||
eval_strategy: Optional[str] = None
|
evaluation_strategy: Optional[str] = None
|
||||||
save_steps: Optional[Union[int, float]] = None
|
save_steps: Optional[Union[int, float]] = None
|
||||||
saves_per_epoch: Optional[int] = None
|
saves_per_epoch: Optional[int] = None
|
||||||
save_strategy: Optional[str] = None
|
save_strategy: Optional[str] = None
|
||||||
@@ -1041,21 +1033,21 @@ class AxolotlInputConfig(
|
|||||||
@classmethod
|
@classmethod
|
||||||
def check_evals(cls, data):
|
def check_evals(cls, data):
|
||||||
if (
|
if (
|
||||||
data.get("eval_strategy")
|
data.get("evaluation_strategy")
|
||||||
and data.get("eval_steps")
|
and data.get("eval_steps")
|
||||||
and data.get("eval_strategy") != "steps"
|
and data.get("evaluation_strategy") != "steps"
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
|
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
data.get("val_set_size") == 0
|
data.get("val_set_size") == 0
|
||||||
and (data.get("eval_steps") or data.get("eval_strategy"))
|
and (data.get("eval_steps") or data.get("evaluation_strategy"))
|
||||||
and not data.get("test_datasets")
|
and not data.get("test_datasets")
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"eval_steps and eval_strategy are not supported with val_set_size == 0"
|
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||||
)
|
)
|
||||||
if data.get("evals_per_epoch") and data.get("eval_steps"):
|
if data.get("evals_per_epoch") and data.get("eval_steps"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -1063,11 +1055,11 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
data.get("evals_per_epoch")
|
data.get("evals_per_epoch")
|
||||||
and data.get("eval_strategy")
|
and data.get("evaluation_strategy")
|
||||||
and data.get("eval_strategy") != "steps"
|
and data.get("evaluation_strategy") != "steps"
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||||
)
|
)
|
||||||
|
|
||||||
if data.get("do_bench_eval") and not (
|
if data.get("do_bench_eval") and not (
|
||||||
@@ -1299,25 +1291,6 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def warn_qlora_zero3_w_use_reentrant(cls, data):
|
|
||||||
if (
|
|
||||||
data.get("adapter") == "qlora"
|
|
||||||
and data.get("gradient_checkpointing_kwargs", {})
|
|
||||||
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
|
|
||||||
is False
|
|
||||||
and "zero3" in data.get("deepspeed", "")
|
|
||||||
):
|
|
||||||
# may result in:
|
|
||||||
# torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint:
|
|
||||||
# Recomputed values for the following tensors have different metadata
|
|
||||||
# than during the forward pass.
|
|
||||||
LOG.warning(
|
|
||||||
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_val_w_test_datasets(cls, data):
|
def check_val_w_test_datasets(cls, data):
|
||||||
@@ -1327,19 +1300,6 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_eval_strategy(cls, data):
|
|
||||||
if (
|
|
||||||
data.get("evaluation_strategy") is not None
|
|
||||||
and data.get("eval_strategy") is None
|
|
||||||
):
|
|
||||||
LOG.info(
|
|
||||||
"explicitly setting `eval_strategy` from the `evaluation_strategy`"
|
|
||||||
)
|
|
||||||
data["eval_strategy"] = data.get("evaluation_strategy")
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_fsdp_offload_w_8bit_optimizer(cls, data):
|
def check_fsdp_offload_w_8bit_optimizer(cls, data):
|
||||||
@@ -1442,6 +1402,17 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_fsdp_grad_accum_4_46_2(cls, data):
|
||||||
|
if data.get("fsdp") and data.get("gradient_accumulation_steps") > 1:
|
||||||
|
if version("transformers") == "4.46.2":
|
||||||
|
raise ValueError(
|
||||||
|
"FSDP w/ gradient_accumulation_steps > 1 is broken with transformers==4.46.2. "
|
||||||
|
"Please use a lower value or switch to an older version of transformers."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||||
|
|||||||
@@ -260,7 +260,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
for config_dataset in for_d_in_datasets(cfg_datasets):
|
for config_dataset in for_d_in_datasets(cfg_datasets):
|
||||||
ds: Optional[Union[Dataset, DatasetDict]] = None
|
ds: Optional[Union[Dataset, DatasetDict]] = None
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
ds_trust_remote_code = config_dataset.trust_remote_code
|
|
||||||
try:
|
try:
|
||||||
# this is just a basic check to see if the path is a
|
# this is just a basic check to see if the path is a
|
||||||
# valid HF dataset that's loadable
|
# valid HF dataset that's loadable
|
||||||
@@ -270,7 +269,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
streaming=True,
|
streaming=True,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
revision=config_dataset.revision,
|
revision=config_dataset.revision,
|
||||||
trust_remote_code=ds_trust_remote_code,
|
|
||||||
)
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||||
@@ -350,15 +348,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
split=None,
|
split=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
ds = load_from_disk(config_dataset.path)
|
||||||
ds = load_from_disk(config_dataset.path)
|
|
||||||
except FileNotFoundError:
|
|
||||||
ds = load_dataset(
|
|
||||||
config_dataset.path,
|
|
||||||
name=config_dataset.name,
|
|
||||||
streaming=False,
|
|
||||||
split=None,
|
|
||||||
)
|
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
ds_type = get_ds_type(config_dataset)
|
ds_type = get_ds_type(config_dataset)
|
||||||
|
|
||||||
@@ -376,7 +366,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
elif ds_from_hub:
|
elif ds_from_hub:
|
||||||
load_ds_kwargs = {}
|
load_ds_kwargs = {}
|
||||||
if config_dataset.split:
|
if config_dataset.split:
|
||||||
load_ds_kwargs["split"] = config_dataset.split
|
load_ds_kwargs = {"split": config_dataset.split}
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
config_dataset.path,
|
config_dataset.path,
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
@@ -384,7 +374,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
data_files=config_dataset.data_files,
|
data_files=config_dataset.data_files,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
revision=config_dataset.revision,
|
revision=config_dataset.revision,
|
||||||
trust_remote_code=config_dataset.trust_remote_code,
|
|
||||||
**load_ds_kwargs,
|
**load_ds_kwargs,
|
||||||
)
|
)
|
||||||
elif ds_from_cloud and remote_file_system:
|
elif ds_from_cloud and remote_file_system:
|
||||||
@@ -402,7 +391,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
split=None,
|
||||||
storage_options=storage_options,
|
storage_options=storage_options,
|
||||||
trust_remote_code=config_dataset.trust_remote_code,
|
|
||||||
)
|
)
|
||||||
elif config_dataset.path.startswith("https://"):
|
elif config_dataset.path.startswith("https://"):
|
||||||
ds_type = get_ds_type(config_dataset)
|
ds_type = get_ds_type(config_dataset)
|
||||||
@@ -413,7 +401,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
split=None,
|
||||||
storage_options=storage_options,
|
storage_options=storage_options,
|
||||||
trust_remote_code=config_dataset.trust_remote_code,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if isinstance(config_dataset.data_files, str):
|
if isinstance(config_dataset.data_files, str):
|
||||||
|
|||||||
@@ -238,7 +238,6 @@ def load_tokenizer(cfg):
|
|||||||
x in cfg.lora_modules_to_save for x in lora_modules_to_save
|
x in cfg.lora_modules_to_save for x in lora_modules_to_save
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
and k != "pad_token"
|
|
||||||
):
|
):
|
||||||
lora_modules_to_save = ", ".join(
|
lora_modules_to_save = ", ".join(
|
||||||
[f"`{x}`" for x in lora_modules_to_save]
|
[f"`{x}`" for x in lora_modules_to_save]
|
||||||
@@ -395,17 +394,10 @@ class ModelLoader:
|
|||||||
and self.cfg.flash_attention
|
and self.cfg.flash_attention
|
||||||
and self.cfg.sample_packing
|
and self.cfg.sample_packing
|
||||||
):
|
):
|
||||||
has_remote_code = (
|
|
||||||
"auto_map" in self.model_config
|
|
||||||
and "AutoModelForCausalLM" in self.model_config["auto_map"]
|
|
||||||
)
|
|
||||||
if has_remote_code and self.cfg.trust_remote_code is False:
|
|
||||||
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
|
|
||||||
has_remote_code = self.cfg.trust_remote_code
|
|
||||||
patch_for_multipack(
|
patch_for_multipack(
|
||||||
self.cfg.model_config_type,
|
self.cfg.model_config_type,
|
||||||
model_name=self.cfg.base_model,
|
model_name=self.cfg.base_model,
|
||||||
has_remote_code=has_remote_code,
|
is_remote_code=self.cfg.trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.is_llama_derived_model:
|
if self.cfg.is_llama_derived_model:
|
||||||
|
|||||||
@@ -16,9 +16,6 @@ from torch.utils.data import DataLoader, RandomSampler
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
from axolotl.monkeypatch.trainer_fsdp_grad_accum import (
|
|
||||||
patch_training_loop_for_fsdp_grad_accum,
|
|
||||||
)
|
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
@@ -496,11 +493,6 @@ def prepare_opinionated_env(cfg):
|
|||||||
def setup_trainer(
|
def setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||||
):
|
):
|
||||||
if cfg.fsdp:
|
|
||||||
try:
|
|
||||||
patch_training_loop_for_fsdp_grad_accum()
|
|
||||||
except AssertionError:
|
|
||||||
pass
|
|
||||||
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
"""
|
|
||||||
shared pytest fixtures
|
|
||||||
"""
|
|
||||||
import shutil
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def temp_dir():
|
|
||||||
# Create a temporary directory
|
|
||||||
_temp_dir = tempfile.mkdtemp()
|
|
||||||
yield _temp_dir
|
|
||||||
# Clean up the directory after the test
|
|
||||||
shutil.rmtree(_temp_dir)
|
|
||||||
@@ -3,25 +3,28 @@ E2E tests for multigpu eval
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from ..utils import with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
class TestMultiGPUEval:
|
class TestMultiGPUEval(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test case for MultiGPU Eval Sample Packing
|
Test case for MultiGPU Eval Sample Packing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
def test_eval_sample_packing(self, temp_dir):
|
def test_eval_sample_packing(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -80,14 +83,13 @@ class TestMultiGPUEval:
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
"--main_process_port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
def test_eval(self, temp_dir):
|
def test_eval(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -146,8 +148,6 @@ class TestMultiGPUEval:
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
"--main_process_port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
|||||||
@@ -4,17 +4,17 @@ E2E tests for multigpu lora tinyllama
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import is_hopper
|
from ..utils import is_hopper, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
@@ -28,16 +28,18 @@ def download_model():
|
|||||||
snapshot_download("TinyLlama/TinyLlama_v1.1")
|
snapshot_download("TinyLlama/TinyLlama_v1.1")
|
||||||
|
|
||||||
|
|
||||||
class TestMultiGPULlama:
|
class TestMultiGPULlama(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test case for Llama models using LoRA
|
Test case for Llama models using LoRA
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
def test_lora_ddp(self, temp_dir):
|
def test_lora_ddp(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
@@ -46,7 +48,9 @@ class TestMultiGPULlama:
|
|||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"pad_token": "<|endoftext|>",
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -77,23 +81,19 @@ class TestMultiGPULlama:
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
"--main_process_port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@with_temp_dir
|
||||||
"gradient_accumulation_steps",
|
def test_lora_ddp_packed(self, temp_dir):
|
||||||
[1, 4],
|
|
||||||
)
|
|
||||||
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_sample_packing": False,
|
"eval_sample_packing": False,
|
||||||
@@ -105,7 +105,9 @@ class TestMultiGPULlama:
|
|||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"pad_token": "<|endoftext|>",
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -116,7 +118,7 @@ class TestMultiGPULlama:
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 15,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
"gradient_accumulation_steps": 4,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
@@ -136,8 +138,6 @@ class TestMultiGPULlama:
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
"--main_process_port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
@@ -145,6 +145,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
|
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
|
||||||
|
@with_temp_dir
|
||||||
def test_dpo_lora_ddp(self, temp_dir):
|
def test_dpo_lora_ddp(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -209,14 +210,13 @@ class TestMultiGPULlama:
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
"--main_process_port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
def test_dpo_qlora_ddp(self, temp_dir):
|
def test_dpo_qlora_ddp(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -278,94 +278,25 @@ class TestMultiGPULlama:
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
"--main_process_port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@with_temp_dir
|
||||||
"gradient_accumulation_steps",
|
def test_fsdp(self, temp_dir):
|
||||||
[1, 4],
|
|
||||||
)
|
|
||||||
def test_fsdp(self, temp_dir, gradient_accumulation_steps):
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||||
"sequence_len": 2048,
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
"val_set_size": 0.01,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "tatsu-lab/alpaca",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 10,
|
|
||||||
"micro_batch_size": 4,
|
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"fsdp": [
|
|
||||||
"full_shard",
|
|
||||||
"auto_wrap",
|
|
||||||
],
|
|
||||||
"fsdp_config": {
|
|
||||||
"fsdp_limit_all_gathers": True,
|
|
||||||
"fsdp_offload_params": False,
|
|
||||||
"fsdp_sync_module_states": True,
|
|
||||||
"fsdp_use_orig_params": False,
|
|
||||||
"fsdp_cpu_ram_efficient_loading": False,
|
|
||||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
|
||||||
"fsdp_state_dict_type": "FULL_STATE_DICT",
|
|
||||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# write cfg to yaml file
|
|
||||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"accelerate",
|
|
||||||
"launch",
|
|
||||||
"--num-processes",
|
|
||||||
"2",
|
|
||||||
"--main_process_port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
"-m",
|
|
||||||
"axolotl.cli.train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"fsdp_state_dict_type",
|
|
||||||
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
|
|
||||||
)
|
|
||||||
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
|
||||||
"sample_packing": True,
|
|
||||||
"pad_to_sequence_len": True,
|
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"pad_token": "<|endoftext|>",
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -393,7 +324,7 @@ class TestMultiGPULlama:
|
|||||||
"fsdp_use_orig_params": False,
|
"fsdp_use_orig_params": False,
|
||||||
"fsdp_cpu_ram_efficient_loading": False,
|
"fsdp_cpu_ram_efficient_loading": False,
|
||||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||||
"fsdp_state_dict_type": fsdp_state_dict_type,
|
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -410,14 +341,79 @@ class TestMultiGPULlama:
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
"--main_process_port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_fsdp_packed(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"sample_packing": True,
|
||||||
|
"eval_sample_packing": False,
|
||||||
|
"pad_to_sequence_len": True,
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"val_set_size": 0.05,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 15,
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"fsdp": [
|
||||||
|
"full_shard",
|
||||||
|
"auto_wrap",
|
||||||
|
],
|
||||||
|
"fsdp_config": {
|
||||||
|
"fsdp_limit_all_gathers": True,
|
||||||
|
"fsdp_offload_params": False,
|
||||||
|
"fsdp_sync_module_states": True,
|
||||||
|
"fsdp_use_orig_params": False,
|
||||||
|
"fsdp_cpu_ram_efficient_loading": False,
|
||||||
|
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||||
|
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||||
|
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"accelerate",
|
||||||
|
"launch",
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"-m",
|
||||||
|
"axolotl.cli.train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -487,29 +483,28 @@ class TestMultiGPULlama:
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
"--main_process_port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@with_temp_dir
|
||||||
"gradient_accumulation_steps",
|
def test_ds_zero3_packed(self, temp_dir):
|
||||||
[1, 4],
|
|
||||||
)
|
|
||||||
def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps):
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
|
"eval_sample_packing": False,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"pad_token": "<|endoftext|>",
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -520,7 +515,7 @@ class TestMultiGPULlama:
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 15,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
"gradient_accumulation_steps": 4,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_torch",
|
||||||
@@ -541,19 +536,19 @@ class TestMultiGPULlama:
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
"--main_process_port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
def test_ds_zero3_qlora_packed(self, temp_dir):
|
def test_ds_zero3_qlora_packed(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
@@ -566,7 +561,9 @@ class TestMultiGPULlama:
|
|||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"pad_token": "<|endoftext|>",
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -598,8 +595,6 @@ class TestMultiGPULlama:
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
"--main_process_port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
|||||||
@@ -4,30 +4,31 @@ E2E tests for multigpu qwen2
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from ..utils import with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
class TestMultiGPUQwen2:
|
class TestMultiGPUQwen2(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test case for Llama models using LoRA
|
Test case for Llama models using LoRA
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
|
@with_temp_dir
|
||||||
def test_qlora_fsdp_dpo(self, base_model, temp_dir):
|
def test_qlora_fsdp_dpo(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": base_model,
|
"base_model": "Qwen/Qwen2-1.5B",
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"rl": "dpo",
|
"rl": "dpo",
|
||||||
"chat_template": "chatml",
|
"chat_template": "chatml",
|
||||||
@@ -46,9 +47,9 @@ class TestMultiGPUQwen2:
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 5,
|
"max_steps": 15,
|
||||||
"warmup_steps": 20,
|
"warmup_steps": 20,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
@@ -90,8 +91,6 @@ class TestMultiGPUQwen2:
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
"--main_process_port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.trainer_fsdp_grad_accum import check_training_loop_is_patchable
|
|
||||||
|
|
||||||
|
|
||||||
class TestTrainerFSDPIntegration(unittest.TestCase):
|
|
||||||
"""Unsloth monkeypatch integration tests."""
|
|
||||||
|
|
||||||
def test_train_loop_patchable(self):
|
|
||||||
# ensures the current version of transformers has loss code that matches our patching code
|
|
||||||
self.assertTrue(
|
|
||||||
check_training_loop_is_patchable(),
|
|
||||||
"HF transformers _inner_training_loop has changed and isn't patchable",
|
|
||||||
)
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
"""
|
|
||||||
E2E tests for llama
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from .utils import with_temp_dir
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestLlama(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
Test case for Llama models
|
|
||||||
"""
|
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_fft_trust_remote_code(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "JackFram/llama-68m",
|
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"trust_remote_code": True,
|
|
||||||
"sequence_len": 512,
|
|
||||||
"val_set_size": 0.1,
|
|
||||||
"special_tokens": {
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 8,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_bnb_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"sample_packing": True,
|
|
||||||
"bf16": True,
|
|
||||||
"save_safetensors": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
|
||||||
@@ -108,37 +108,3 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_fft_schedule_free_adamw(self, temp_dir):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"val_set_size": 0.1,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 4,
|
|
||||||
"gradient_accumulation_steps": 2,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "schedule_free_adamw",
|
|
||||||
"lr_scheduler": "constant",
|
|
||||||
"save_safetensors": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
|
||||||
|
|||||||
@@ -1,85 +0,0 @@
|
|||||||
"""
|
|
||||||
E2E tests for qwen
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
|
||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.qwen")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestE2eQwen:
|
|
||||||
"""
|
|
||||||
Test cases for qwen models
|
|
||||||
"""
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
|
|
||||||
def test_dpo(self, base_model, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": base_model,
|
|
||||||
"rl": "dpo",
|
|
||||||
"chat_template": "qwen_25",
|
|
||||||
"sequence_len": 2048,
|
|
||||||
"val_set_size": 0.0,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
|
||||||
"split": "train",
|
|
||||||
"type": "chat_template.default",
|
|
||||||
"field_messages": "conversation",
|
|
||||||
"field_chosen": "chosen",
|
|
||||||
"field_rejected": "rejected",
|
|
||||||
"message_field_role": "role",
|
|
||||||
"message_field_content": "content",
|
|
||||||
"roles": {
|
|
||||||
"system": ["system"],
|
|
||||||
"user": ["user"],
|
|
||||||
"assistant": ["assistant"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 5,
|
|
||||||
"warmup_steps": 20,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"gradient_accumulation_steps": 2,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_bnb_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"bf16": "auto",
|
|
||||||
"tf32": True,
|
|
||||||
"gradient_checkpointing": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# write cfg to yaml file
|
|
||||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"accelerate",
|
|
||||||
"launch",
|
|
||||||
"--num-processes",
|
|
||||||
"2",
|
|
||||||
"--main_process_port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
"-m",
|
|
||||||
"axolotl.cli.train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
@@ -371,79 +371,44 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
def test_load_local_hub_with_revision(self):
|
def test_load_local_hub_with_revision(self):
|
||||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
with tempfile.TemporaryDirectory() as tmp_dir2:
|
||||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
|
||||||
snapshot_download(
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
repo_id="mhenrichsen/alpaca_2k_test",
|
snapshot_download(
|
||||||
repo_type="dataset",
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
local_dir=tmp_ds_path,
|
repo_type="dataset",
|
||||||
revision="d05c1cb",
|
local_dir=tmp_ds_path,
|
||||||
)
|
revision="d05c1cb",
|
||||||
|
)
|
||||||
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
"ds_type": "parquet",
|
"ds_type": "parquet",
|
||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
"data_files": [
|
"data_files": [
|
||||||
f"{tmp_ds_path}/alpaca_2000.parquet",
|
f"{tmp_ds_path}/alpaca_2000.parquet",
|
||||||
],
|
],
|
||||||
"revision": "d05c1cb",
|
"revision": "d05c1cb",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(
|
||||||
self.tokenizer, cfg, prepared_path
|
self.tokenizer, cfg, prepared_path
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
shutil.rmtree(tmp_ds_path)
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
def test_loading_local_dataset_folder(self):
|
|
||||||
"""Verify that a dataset downloaded to a local folder can be loaded"""
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
|
||||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
snapshot_download(
|
|
||||||
repo_id="mhenrichsen/alpaca_2k_test",
|
|
||||||
repo_type="dataset",
|
|
||||||
local_dir=tmp_ds_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": str(tmp_ds_path),
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
|
||||||
assert "input_ids" in dataset.features
|
|
||||||
assert "attention_mask" in dataset.features
|
|
||||||
assert "labels" in dataset.features
|
|
||||||
shutil.rmtree(tmp_ds_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -726,7 +726,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"eval_strategy": "epoch",
|
"evaluation_strategy": "epoch",
|
||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -734,14 +734,14 @@ class TestValidation(BaseValidation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match=r".*eval_strategy and eval_steps mismatch.*"
|
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"eval_strategy": "no",
|
"evaluation_strategy": "no",
|
||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -749,14 +749,14 @@ class TestValidation(BaseValidation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match=r".*eval_strategy and eval_steps mismatch.*"
|
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"eval_strategy": "steps",
|
"evaluation_strategy": "steps",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -767,7 +767,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"eval_strategy": "steps",
|
"evaluation_strategy": "steps",
|
||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -790,7 +790,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"eval_strategy": "no",
|
"evaluation_strategy": "no",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -801,7 +801,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"eval_strategy": "epoch",
|
"evaluation_strategy": "epoch",
|
||||||
"val_set_size": 0,
|
"val_set_size": 0,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -810,7 +810,7 @@ class TestValidation(BaseValidation):
|
|||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*",
|
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
@@ -826,7 +826,7 @@ class TestValidation(BaseValidation):
|
|||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*",
|
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
@@ -856,7 +856,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"eval_strategy": "epoch",
|
"evaluation_strategy": "epoch",
|
||||||
"val_set_size": 0.01,
|
"val_set_size": 0.01,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -1095,24 +1095,6 @@ class TestValidation(BaseValidation):
|
|||||||
assert new_cfg["dpo_beta"] is None
|
assert new_cfg["dpo_beta"] is None
|
||||||
assert len(self._caplog.records) == 1
|
assert len(self._caplog.records) == 1
|
||||||
|
|
||||||
def test_eval_strategy_remap(self, minimal_cfg):
|
|
||||||
cfg = (
|
|
||||||
DictDefault(
|
|
||||||
{
|
|
||||||
"evaluation_strategy": "steps",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
| minimal_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
new_cfg = validate_config(cfg)
|
|
||||||
assert new_cfg.eval_strategy == "steps"
|
|
||||||
assert (
|
|
||||||
"evaluation_strategy is deprecated, use eval_strategy instead"
|
|
||||||
in self._caplog.records[0].message
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestValidationCheckModelConfig(BaseValidation):
|
class TestValidationCheckModelConfig(BaseValidation):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user