fixes to accelerator so that iterable pretraining datasets work (#1759)
* fixes to accelerator so that iterable pretraining datasets work * fix the pretraining test params * split batches, not dispatch batches needs to be set * update c4 datasets * set epochs in pretrain config test * need to set both split_batches and dispatch_batches to false for pretraining * fix bool val in comment
This commit is contained in:
@@ -1481,6 +1481,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
sys.path.append(self.cfg.torchdistx_path)
|
||||
importlib.import_module("torchdistx")
|
||||
|
||||
if self.cfg.accelerator_config:
|
||||
training_arguments_kwargs[
|
||||
"accelerator_config"
|
||||
] = self.cfg.accelerator_config
|
||||
|
||||
training_args = (
|
||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
|
||||
@@ -77,6 +77,7 @@ class PretrainingDataset(BaseModel):
|
||||
split: Optional[str] = "train"
|
||||
text_column: Optional[str] = "text"
|
||||
type: Optional[str] = "pretrain"
|
||||
trust_remote_code: Optional[bool] = False
|
||||
|
||||
|
||||
class UserDefinedPrompterType(BaseModel):
|
||||
@@ -118,6 +119,8 @@ class SFTDataset(BaseModel):
|
||||
roles: Optional[Dict[str, List[str]]] = None
|
||||
drop_system_message: Optional[bool] = None
|
||||
|
||||
trust_remote_code: Optional[bool] = False
|
||||
|
||||
|
||||
class UserDefinedDPOType(BaseModel):
|
||||
"""User defined typing for DPO"""
|
||||
@@ -158,6 +161,7 @@ class KTODataset(BaseModel):
|
||||
split: Optional[str] = None
|
||||
type: Optional[Union[UserDefinedKTOType, str]] = None
|
||||
data_files: Optional[List[str]] = None
|
||||
trust_remote_code: Optional[bool] = False
|
||||
|
||||
|
||||
class RLType(str, Enum):
|
||||
@@ -504,6 +508,8 @@ class AxolotlInputConfig(
|
||||
dataloader_prefetch_factor: Optional[int] = None
|
||||
dataloader_drop_last: Optional[bool] = None
|
||||
|
||||
accelerator_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
remove_unused_columns: Optional[bool] = None
|
||||
|
||||
push_dataset_to_hub: Optional[str] = None
|
||||
@@ -702,6 +708,24 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_pretraining_split_batches_accelerate(cls, data):
|
||||
# alternatively set ACCELERATE_SPLIT_BATCHES=False
|
||||
if data.get("pretraining_dataset"):
|
||||
accelerator_config = data.get("accelerator_config", {})
|
||||
if not accelerator_config:
|
||||
data["accelerator_config"] = {
|
||||
"split_batches": False,
|
||||
"dispatch_batches": False,
|
||||
}
|
||||
else:
|
||||
if accelerator_config.get("split_batches") is None:
|
||||
data["accelerator_config"]["split_batches"] = False
|
||||
if accelerator_config.get("dispatch_batches") is None:
|
||||
data["accelerator_config"]["dispatch_batches"] = False
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_gptq_w_revision(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user