make sure to register the base chatml template even if no system message is provided (#1207)
This commit is contained in:
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -106,3 +106,7 @@ jobs:
|
|||||||
- name: GPU Unit Tests monkeypatched w docker image
|
- name: GPU Unit Tests monkeypatched w docker image
|
||||||
run: |
|
run: |
|
||||||
docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest /workspace/axolotl/tests/e2e/patched/
|
docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest /workspace/axolotl/tests/e2e/patched/
|
||||||
|
- name: Prune image from docker
|
||||||
|
if: github.ref != 'refs/heads/main'
|
||||||
|
run: |
|
||||||
|
docker rmi -f ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||||
)
|
)
|
||||||
register_chatml_template(parsed_cfg.default_system_message)
|
register_chatml_template(parsed_cfg.default_system_message)
|
||||||
|
else:
|
||||||
|
register_chatml_template()
|
||||||
|
|
||||||
if not parsed_cfg.dataset_prepared_path:
|
if not parsed_cfg.dataset_prepared_path:
|
||||||
msg = (
|
msg = (
|
||||||
|
|||||||
@@ -43,7 +43,10 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|||||||
f"ChatML set. Adding default system message: {cfg.default_system_message}"
|
f"ChatML set. Adding default system message: {cfg.default_system_message}"
|
||||||
)
|
)
|
||||||
register_chatml_template(cfg.default_system_message)
|
register_chatml_template(cfg.default_system_message)
|
||||||
|
else:
|
||||||
|
register_chatml_template()
|
||||||
|
|
||||||
|
if cfg.rl:
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from datasets import (
|
|||||||
load_from_disk,
|
load_from_disk,
|
||||||
)
|
)
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
from huggingface_hub.utils import HFValidationError
|
||||||
from torch.utils.data import RandomSampler
|
from torch.utils.data import RandomSampler
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
@@ -213,7 +214,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except (FileNotFoundError, ConnectionError):
|
except (FileNotFoundError, ConnectionError, HFValidationError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
ds_from_cloud = False
|
ds_from_cloud = False
|
||||||
|
|||||||
Reference in New Issue
Block a user