Compare commits
10 Commits
codecov-pu
...
dump-confi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b594f18f6e | ||
|
|
700791deb9 | ||
|
|
d6d2cc673b | ||
|
|
1d8f500709 | ||
|
|
83525f14a0 | ||
|
|
68c0e31fd1 | ||
|
|
22f930c658 | ||
|
|
0494359c6c | ||
|
|
26c39e1ca7 | ||
|
|
45adf1bfb9 |
50
.github/workflows/tests.yml
vendored
50
.github/workflows/tests.yml
vendored
@@ -106,12 +106,13 @@ jobs:
|
|||||||
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
|
|
||||||
- name: Upload coverage artifacts
|
- name: Upload coverage to Codecov
|
||||||
uses: actions/upload-artifact@v4
|
uses: codecov/codecov-action@v5
|
||||||
with:
|
with:
|
||||||
name: coverage-${{ matrix.pytorch_version }}-${{ github.run_id }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
path: ./coverage.xml
|
files: ./coverage.xml
|
||||||
retention-days: 1
|
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
||||||
|
fail_ci_if_error: false
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -233,14 +234,6 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
modal run cicd.e2e_tests
|
modal run cicd.e2e_tests
|
||||||
|
|
||||||
- name: Upload coverage artifacts
|
|
||||||
if: always()
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: coverage-e2e-1st-${{ github.run_id }}
|
|
||||||
path: ./e2e-coverage.xml
|
|
||||||
retention-days: 1
|
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
@@ -304,14 +297,6 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
modal run cicd.e2e_tests
|
modal run cicd.e2e_tests
|
||||||
|
|
||||||
- name: Upload coverage artifacts
|
|
||||||
if: always()
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: coverage-e2e-${{ matrix.cuda }}-${{ matrix.pytorch }}-${{ github.run_id }}
|
|
||||||
path: ./e2e-coverage.xml
|
|
||||||
retention-days: 1
|
|
||||||
|
|
||||||
docker-e2e-cleanup:
|
docker-e2e-cleanup:
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 90
|
timeout-minutes: 90
|
||||||
@@ -351,26 +336,3 @@ jobs:
|
|||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.cleanup
|
modal run cicd.cleanup
|
||||||
|
|
||||||
upload-coverage:
|
|
||||||
name: Upload Coverage to Codecov
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
needs: [pytest, docker-e2e-tests, docker-e2e-tests-1st]
|
|
||||||
if: github.event_name == 'pull_request' || github.ref == 'refs/heads/main'
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Download coverage reports
|
|
||||||
uses: actions/download-artifact@v4
|
|
||||||
with:
|
|
||||||
path: coverage-reports
|
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
|
||||||
uses: codecov/codecov-action@v5
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
|
||||||
directory: coverage-reports
|
|
||||||
fail_ci_if_error: false
|
|
||||||
verbose: true
|
|
||||||
name: codecov-umbrella
|
|
||||||
override_commit: ${{ github.event.pull_request.head.sha || github.sha }}
|
|
||||||
override_pr: ${{ github.event.pull_request.number }}
|
|
||||||
|
|||||||
@@ -51,3 +51,5 @@ pytest -v --durations=10 \
|
|||||||
--cov=axolotl \
|
--cov=axolotl \
|
||||||
--cov-append \
|
--cov-append \
|
||||||
--cov-report=xml:e2e-coverage.xml
|
--cov-report=xml:e2e-coverage.xml
|
||||||
|
|
||||||
|
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
"""Modal app to run axolotl GPU tests"""
|
"""Modal app to run axolotl GPU tests"""
|
||||||
|
|
||||||
import pathlib
|
|
||||||
|
|
||||||
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
|
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
|
||||||
|
|
||||||
|
|
||||||
@@ -14,21 +12,9 @@ from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
|
|||||||
volumes=VOLUME_CONFIG,
|
volumes=VOLUME_CONFIG,
|
||||||
)
|
)
|
||||||
def cicd_pytest():
|
def cicd_pytest():
|
||||||
|
|
||||||
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
|
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
|
||||||
|
|
||||||
# Read the coverage file if it exists
|
|
||||||
coverage_file = pathlib.Path("/workspace/axolotl/e2e-coverage.xml")
|
|
||||||
if coverage_file.exists():
|
|
||||||
return coverage_file.read_text(encoding="utf-8")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@app.local_entrypoint()
|
@app.local_entrypoint()
|
||||||
def main():
|
def main():
|
||||||
coverage = cicd_pytest.remote()
|
cicd_pytest.remote()
|
||||||
|
|
||||||
# Save the coverage file to the local filesystem if it was generated
|
|
||||||
if coverage:
|
|
||||||
with open("e2e-coverage.xml", "w", encoding="utf-8") as f:
|
|
||||||
f.write(coverage)
|
|
||||||
|
|||||||
@@ -77,18 +77,7 @@ def run_cmd(cmd: str, run_folder: str):
|
|||||||
def cicd_pytest():
|
def cicd_pytest():
|
||||||
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
|
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
|
||||||
|
|
||||||
# Read the coverage file if it exists
|
|
||||||
coverage_file = pathlib.Path("/workspace/axolotl/multigpu-coverage.xml")
|
|
||||||
if coverage_file.exists():
|
|
||||||
return coverage_file.read_text(encoding="utf-8")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@app.local_entrypoint()
|
@app.local_entrypoint()
|
||||||
def main():
|
def main():
|
||||||
coverage = cicd_pytest.remote()
|
cicd_pytest.remote()
|
||||||
|
|
||||||
# Save the coverage file to the local filesystem if it was generated
|
|
||||||
if coverage:
|
|
||||||
with open("multigpu-coverage.xml", "w", encoding="utf-8") as file:
|
|
||||||
file.write(coverage)
|
|
||||||
|
|||||||
@@ -9,11 +9,11 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
> A: Usually an issue with the GPUs communicating with each other. See the [NCCL doc](nccl.qmd)
|
> A: Usually an issue with the GPUs communicating with each other. See the [NCCL doc](nccl.qmd)
|
||||||
|
|
||||||
**Q: Exitcode -9**
|
**Q: exitcode: -9**
|
||||||
|
|
||||||
> A: This usually happens when you run out of system RAM.
|
> A: This usually happens when you run out of system RAM.
|
||||||
|
|
||||||
**Q: Exitcode -7 while using deepspeed**
|
**Q: exitcode: -7 while using deepspeed**
|
||||||
|
|
||||||
> A: Try upgrading deepspeed w: `pip install -U deepspeed`
|
> A: Try upgrading deepspeed w: `pip install -U deepspeed`
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ tokenizers>=0.21.1
|
|||||||
accelerate==1.7.0
|
accelerate==1.7.0
|
||||||
datasets==3.6.0
|
datasets==3.6.0
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.18.1
|
trl==0.18.2
|
||||||
hf_xet==1.1.2
|
hf_xet==1.1.2
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Union
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.cloud.modal_ import ModalCloud
|
from axolotl.cli.cloud.modal_ import ModalCloud
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -24,7 +23,6 @@ def do_cli_preprocess(
|
|||||||
cloud_config: Union[Path, str],
|
cloud_config: Union[Path, str],
|
||||||
config: Union[Path, str],
|
config: Union[Path, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
@@ -39,7 +37,6 @@ def do_cli_train(
|
|||||||
cwd=None,
|
cwd=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
@@ -54,7 +51,6 @@ def do_cli_lm_eval(
|
|||||||
cloud_config: Union[Path, str],
|
cloud_config: Union[Path, str],
|
||||||
config: Union[Path, str],
|
config: Union[Path, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
|
|||||||
@@ -26,7 +26,9 @@ from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
|||||||
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
LOG = get_logger(__name__, use_environ=True)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
API_KEY_FIELDS = {"comet_api_key"}
|
||||||
|
|
||||||
|
|
||||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||||
@@ -233,4 +235,15 @@ def load_cfg(
|
|||||||
setup_comet_env_vars(cfg)
|
setup_comet_env_vars(cfg)
|
||||||
plugin_set_cfg(cfg)
|
plugin_set_cfg(cfg)
|
||||||
|
|
||||||
|
cfg_to_log = {
|
||||||
|
k: "[REDACTED]" if k in API_KEY_FIELDS else v
|
||||||
|
for k, v in cfg.items()
|
||||||
|
if v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"config:\n%s",
|
||||||
|
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
|
||||||
|
)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
@@ -35,7 +34,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
patch_optimized_env()
|
patch_optimized_env()
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||||
|
|
||||||
from axolotl.cli.args import InferenceCliArgs
|
from axolotl.cli.args import InferenceCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.utils.chat_templates import (
|
from axolotl.utils.chat_templates import (
|
||||||
@@ -255,7 +254,6 @@ def do_cli(
|
|||||||
kwargs: Additional keyword arguments to override config file values.
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
||||||
parsed_cfg.sample_packing = False
|
parsed_cfg.sample_packing = False
|
||||||
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from axolotl.cli.args import (
|
|||||||
TrainerCliArgs,
|
TrainerCliArgs,
|
||||||
VllmServeCliArgs,
|
VllmServeCliArgs,
|
||||||
)
|
)
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.sweeps import generate_sweep_configs
|
from axolotl.cli.sweeps import generate_sweep_configs
|
||||||
from axolotl.cli.utils import (
|
from axolotl.cli.utils import (
|
||||||
add_options_from_config,
|
add_options_from_config,
|
||||||
@@ -40,6 +41,7 @@ LOG = get_logger(__name__)
|
|||||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||||
def cli():
|
def cli():
|
||||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||||
|
print_axolotl_text_art()
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import Union
|
|||||||
import fire
|
import fire
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -23,8 +22,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from huggingface_hub import split_torch_state_dict_into_shards
|
|||||||
from safetensors.torch import save_file as safe_save_file
|
from safetensors.torch import save_file as safe_save_file
|
||||||
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -194,7 +193,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
kwargs: Additional keyword arguments to override config file values.
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
|
||||||
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.cli.args import PreprocessCliArgs
|
from axolotl.cli.args import PreprocessCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
@@ -33,7 +32,6 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Preprocessing-specific CLI arguments.
|
cli_args: Preprocessing-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Union
|
|||||||
|
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.loaders import load_tokenizer
|
from axolotl.loaders import load_tokenizer
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -27,7 +26,6 @@ def do_quantize(
|
|||||||
config (Union[Path, str]): The path to the config file
|
config (Union[Path, str]): The path to the config file
|
||||||
cli_args (dict): Additional command-line arguments
|
cli_args (dict): Additional command-line arguments
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
cfg = load_cfg(config)
|
cfg = load_cfg(config)
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
@@ -35,7 +34,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
patch_optimized_env()
|
patch_optimized_env()
|
||||||
|
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from transformers import PreTrainedModel, Trainer
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__, use_environ=True)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.common.datasets import TrainDatasetMeta
|
from axolotl.common.datasets import TrainDatasetMeta
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from axolotl.utils.logging import get_logger
|
|||||||
|
|
||||||
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
LOG = get_logger(__name__, use_environ=True)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install cut_cross_entropy with transformers support using "
|
"Please install cut_cross_entropy with transformers support using "
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from axolotl.utils.logging import get_logger
|
|||||||
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
from .utils import patch_with_compile_disable
|
from .utils import patch_with_compile_disable
|
||||||
|
|
||||||
LOG = get_logger(__name__, use_environ=True)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LigerPlugin(BasePlugin):
|
class LigerPlugin(BasePlugin):
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
"""
|
"""
|
||||||
Module for handling LIGER input arguments.
|
Module for handling LIGER input arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|||||||
@@ -273,7 +273,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
{"additional_special_tokens": additional_special_tokens}
|
{"additional_special_tokens": additional_special_tokens}
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_main_process(use_environ=True):
|
if is_main_process():
|
||||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||||
|
|||||||
@@ -13,9 +13,9 @@ import inspect
|
|||||||
import accelerate
|
import accelerate
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate.logging import get_logger
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ import inspect
|
|||||||
import types
|
import types
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
|
||||||
from peft import PeftModelForCausalLM
|
from peft import PeftModelForCausalLM
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
|||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.common.datasets import TrainDatasetMeta
|
from axolotl.common.datasets import TrainDatasetMeta
|
||||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
@@ -545,8 +544,6 @@ def train(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (model, tokenizer) after training
|
Tuple of (model, tokenizer) after training
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
||||||
(
|
(
|
||||||
trainer,
|
trainer,
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from axolotl.utils.schemas.config import (
|
|||||||
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||||
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
||||||
|
|
||||||
LOG = get_logger(__name__, use_environ=True)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def choose_device(cfg):
|
def choose_device(cfg):
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
"""
|
"""Utilities for distributed functionality."""
|
||||||
utility helpers for distributed checks
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import pickle # nosec
|
import pickle # nosec
|
||||||
@@ -19,7 +17,7 @@ from transformers.utils.import_utils import (
|
|||||||
distributed_state = None # pylint: disable=invalid-name
|
distributed_state = None # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
def get_device_type():
|
def get_device_type() -> torch.device:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if is_torch_cuda_available():
|
if is_torch_cuda_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
@@ -30,7 +28,7 @@ def get_device_type():
|
|||||||
return device
|
return device
|
||||||
|
|
||||||
|
|
||||||
def get_device_count():
|
def get_device_count() -> int:
|
||||||
cur_device = get_device_type()
|
cur_device = get_device_type()
|
||||||
if "cuda" in str(cur_device):
|
if "cuda" in str(cur_device):
|
||||||
return torch.cuda.device_count()
|
return torch.cuda.device_count()
|
||||||
@@ -39,7 +37,7 @@ def get_device_count():
|
|||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
def get_current_device():
|
def get_current_device() -> int:
|
||||||
cur_device = get_device_type()
|
cur_device = get_device_type()
|
||||||
if "cuda" in str(cur_device):
|
if "cuda" in str(cur_device):
|
||||||
return torch.cuda.current_device()
|
return torch.cuda.current_device()
|
||||||
@@ -48,15 +46,24 @@ def get_current_device():
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def is_distributed():
|
def init_distributed_state():
|
||||||
"""
|
|
||||||
Check if distributed training is initialized.
|
|
||||||
"""
|
|
||||||
global distributed_state # pylint: disable=global-statement
|
global distributed_state # pylint: disable=global-statement
|
||||||
if not distributed_state:
|
if distributed_state is None:
|
||||||
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
|
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
|
||||||
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
||||||
|
|
||||||
|
|
||||||
|
def get_distributed_state() -> PartialState | None:
|
||||||
|
return distributed_state
|
||||||
|
|
||||||
|
|
||||||
|
def is_distributed() -> bool:
|
||||||
|
"""Check if distributed training is initialized."""
|
||||||
|
init_distributed_state()
|
||||||
|
|
||||||
|
if distributed_state is None:
|
||||||
|
return False
|
||||||
|
|
||||||
return distributed_state.use_distributed and distributed_state.initialized
|
return distributed_state.use_distributed and distributed_state.initialized
|
||||||
|
|
||||||
|
|
||||||
@@ -69,31 +76,31 @@ def barrier():
|
|||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
def is_main_process(use_environ=False):
|
def is_main_process() -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the current process is the main process. If not in distributed mode,
|
Check if the current process is the main process. If not in distributed mode,
|
||||||
always return `True`.
|
always return `True`.
|
||||||
|
|
||||||
Args:
|
We use a simpler logic when the distributed state is not initialized: we just log
|
||||||
- use_environ (bool, optional): Use environment variable to determine main process.
|
on the 0-th local rank.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- bool: `True` if the current process is the main process, `False` otherwise.
|
`True` if the current process is the main process, `False` otherwise.
|
||||||
"""
|
"""
|
||||||
if use_environ:
|
if get_distributed_state() is None:
|
||||||
return os.environ.get("LOCAL_RANK", "0") == "0"
|
return os.environ.get("LOCAL_RANK", "0") == "0"
|
||||||
if not is_distributed():
|
if not is_distributed():
|
||||||
return True
|
return True
|
||||||
return dist.get_rank() == 0
|
return dist.get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
def is_local_main_process(use_environ=False):
|
def is_local_main_process() -> bool:
|
||||||
if use_environ:
|
if get_distributed_state() is None:
|
||||||
return os.environ.get("LOCAL_RANK", "0") == "0"
|
return os.environ.get("LOCAL_RANK", "0") == "0"
|
||||||
return PartialState().is_local_main_process
|
return PartialState().is_local_main_process
|
||||||
|
|
||||||
|
|
||||||
def get_world_size():
|
def get_world_size() -> int:
|
||||||
return int(os.getenv("WORLD_SIZE", "1"))
|
return int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
|
||||||
|
|
||||||
@@ -115,7 +122,7 @@ def cleanup_distributed():
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def zero_first(is_main):
|
def zero_first(is_main: bool):
|
||||||
"""
|
"""
|
||||||
runs the wrapped context so that rank 0 runs first before other ranks
|
runs the wrapped context so that rank 0 runs first before other ranks
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -5,9 +5,8 @@ module to freeze/unfreeze parameters by name
|
|||||||
import re
|
import re
|
||||||
from typing import Callable, List, Tuple, Union
|
from typing import Callable, List, Tuple, Union
|
||||||
|
|
||||||
from accelerate.logging import get_logger
|
|
||||||
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
"""
|
"""Logging helpers to only log on main process."""
|
||||||
logging helpers to only log on main process
|
|
||||||
"""
|
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
@@ -14,27 +12,18 @@ from axolotl.utils.distributed import is_main_process
|
|||||||
|
|
||||||
class MultiProcessAdapter(logging.LoggerAdapter):
|
class MultiProcessAdapter(logging.LoggerAdapter):
|
||||||
"""
|
"""
|
||||||
logger adapter for distributed logging, specifically to only log on main process
|
Logger adapter for distributed logging, specifically to only log on main process.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, logger, use_environ=False, extra=None):
|
|
||||||
super().__init__(logger, extra)
|
|
||||||
self.use_environ = use_environ
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _should_log(main_process_only, use_environ=False):
|
def _should_log(main_process_only: bool):
|
||||||
return not main_process_only or (
|
return not main_process_only or is_main_process()
|
||||||
main_process_only and is_main_process(use_environ=use_environ)
|
|
||||||
)
|
|
||||||
|
|
||||||
def log(self, level, msg, *args, **kwargs):
|
def log(self, level, msg, *args, **kwargs):
|
||||||
use_environ = kwargs.pop("use_environ", self.use_environ)
|
|
||||||
main_process_only = kwargs.pop("main_process_only", True)
|
main_process_only = kwargs.pop("main_process_only", True)
|
||||||
kwargs.setdefault("stacklevel", 2)
|
kwargs.setdefault("stacklevel", 2)
|
||||||
|
|
||||||
if self.isEnabledFor(level) and self._should_log(
|
if self.isEnabledFor(level) and self._should_log(main_process_only):
|
||||||
main_process_only, use_environ=use_environ
|
|
||||||
):
|
|
||||||
msg, kwargs = self.process(msg, kwargs)
|
msg, kwargs = self.process(msg, kwargs)
|
||||||
self.logger.log(level, msg, *args, **kwargs)
|
self.logger.log(level, msg, *args, **kwargs)
|
||||||
|
|
||||||
@@ -50,13 +39,11 @@ class MultiProcessAdapter(logging.LoggerAdapter):
|
|||||||
self.warning(*args, **kwargs)
|
self.warning(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def get_logger(
|
def get_logger(name: str, log_level: str | None = None) -> MultiProcessAdapter:
|
||||||
name: str, log_level: str | None = None, use_environ: bool = False
|
|
||||||
) -> MultiProcessAdapter:
|
|
||||||
if log_level is None:
|
if log_level is None:
|
||||||
log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None)
|
log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None)
|
||||||
logger = logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
if log_level is not None:
|
if log_level is not None:
|
||||||
logger.setLevel(log_level.upper())
|
logger.setLevel(log_level.upper())
|
||||||
logger.root.setLevel(log_level.upper())
|
logger.root.setLevel(log_level.upper())
|
||||||
return MultiProcessAdapter(logger, use_environ=use_environ, extra={})
|
return MultiProcessAdapter(logger, extra={})
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ from axolotl.utils.schemas.trl import TRLConfig
|
|||||||
from axolotl.utils.schemas.validation import ValidationMixin
|
from axolotl.utils.schemas.validation import ValidationMixin
|
||||||
from axolotl.utils.schemas.vllm import VllmConfig
|
from axolotl.utils.schemas.vllm import VllmConfig
|
||||||
|
|
||||||
LOG = get_logger(__name__, use_environ=True)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-ancestors
|
# pylint: disable=too-many-ancestors
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, field_validator
|
|||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__, use_environ=True)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ModelInputConfig(BaseModel):
|
class ModelInputConfig(BaseModel):
|
||||||
|
|||||||
@@ -11,14 +11,14 @@ from typing import List, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
from accelerate.logging import get_logger
|
|
||||||
from datasets import IterableDataset, disable_caching, enable_caching
|
from datasets import IterableDataset, disable_caching, enable_caching
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
|
||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
@@ -537,6 +537,12 @@ def setup_deepspeed_env(cfg, stage=None):
|
|||||||
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
||||||
if stage == 3:
|
if stage == 3:
|
||||||
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
||||||
|
|
||||||
|
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
|
||||||
|
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
|
||||||
|
# to model load.
|
||||||
|
init_distributed_state()
|
||||||
|
|
||||||
# If we don't assign this, it doesn't actually get set in the accelerate weakref
|
# If we don't assign this, it doesn't actually get set in the accelerate weakref
|
||||||
_ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
|
_ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user