torch compile and cuda alloc improvements (#1755)
* enable experimental expandable_segments
* HF Trainer seems to be missing torch.compile
* disable PYTORCH_CUDA_ALLOC_CONF to see if that fixes CI/CD
This commit is contained in:
@@ -290,6 +290,18 @@ class AxolotlTrainer(Trainer):
|
|||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
|
def _wrap_model(self, model, training=True, dataloader=None):
    """
    Optionally compile ``model`` with ``torch.compile`` and then delegate to
    the parent class's ``_wrap_model``.

    When ``self.args.torch_compile`` is truthy, dynamo's
    ``accumulated_cache_size_limit`` is set to 256 and the model is compiled
    using ``self.args.torch_compile_backend`` and
    ``self.args.torch_compile_mode``. Otherwise the model is passed through
    unchanged. ``training`` and ``dataloader`` are forwarded as-is.
    """
    compile_requested = self.args.torch_compile
    if compile_requested:
        # configure dynamo before the model is handed to torch.compile
        dynamo_config = torch._dynamo.config  # pylint: disable=protected-access
        dynamo_config.accumulated_cache_size_limit = 256
        compile_kwargs = {
            "backend": self.args.torch_compile_backend,
            "mode": self.args.torch_compile_mode,
        }
        model = torch.compile(model, **compile_kwargs)
    return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
|
|||||||
@@ -52,6 +52,13 @@ class TrainDatasetMeta:
|
|||||||
def train(
|
def train(
|
||||||
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||||
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
||||||
|
# enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
|
# torch_version = torch.__version__.split(".")
|
||||||
|
# torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
|
||||||
|
# if torch_major == 2 and torch_minor >= 2:
|
||||||
|
# if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
|
||||||
|
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||||
|
|||||||
Reference in New Issue
Block a user