diff --git a/src/axolotl/integrations/kd/README.md b/src/axolotl/integrations/kd/README.md
index 4b15ad31d..b1974e979 100644
--- a/src/axolotl/integrations/kd/README.md
+++ b/src/axolotl/integrations/kd/README.md
@@ -21,3 +21,32 @@ datasets:
 ```
 
 An example dataset can be found at [`axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample`](https://huggingface.co/datasets/axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample)
+
+## Online KD (sglang)
+
+```bash
+export UV_TORCH_BACKEND=cu124
+uv venv sglang --python 3.11
+source sglang/bin/activate
+uv pip install --upgrade pip
+uv pip install setuptools
+uv pip install torch~=2.5.1 --index-url https://download.pytorch.org/whl/cu124
+uv pip install sgl-kernel --force-reinstall --no-deps
+uv pip install "sglang[all]>=0.4.2.post4" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
+```
+
+## Online KD (vllm)
+
+```bash
+VLLM_USE_V1=0 vllm serve open-r1/OlympicCoder-32B --max-model-len 16400 --port 8888 --max-logprobs 128 --return-tokens-as-token-ids --tensor-parallel-size 8 --max-num-seqs \
+  256 --gpu-memory-utilization 0.2 --enable-chunked-prefill
+```
+
+```bash
+vllm serve open-r1/OlympicCoder-32B --max-model-len 16400 --port 8888 --max-logprobs 128 --return-tokens-as-token-ids --tensor-parallel-size 8 --no-enable-prefix-caching --gpu-memory-utilization 0.3 --max-num-batched-tokens 131072 --host 0.0.0.0
+```
+
+
+```bash
+python -m sglang.launch_server --model-path open-r1/OlympicCoder-32B --tensor-parallel-size 8 --port 8080 --host 0.0.0.0 --max-running-requests 256 --context-length 16400 --mem-fraction-static 0.2 --schedule-conservativeness 0.3 --chunked-prefill-size 131072 --schedule-policy fcfs --skip-tokenizer-init
+```
diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py
index 3214ced35..c6ea9602f 100644
--- a/src/axolotl/integrations/kd/__init__.py
+++ b/src/axolotl/integrations/kd/__init__.py
@@ -47,6 +47,16 @@ class KDPlugin(BasePlugin):
         if cfg.eval_sample_packing and is_eval:
             use_batch_sampler_collator = True
 
+        if cfg.kd_online_server_base_url:
+            from .collator_online_teacher import OnlineTeacherCollator
+
+            return OnlineTeacherCollator, {
+                "kd_online_server_base_url": cfg.kd_online_server_base_url,
+                "kd_online_topk": cfg.kd_online_topk,
+                "kd_temperature": cfg.kd_temperature,
+                "kd_online_server": cfg.kd_online_server,
+            }
+
         if use_batch_sampler_collator:
             return KDBatchSamplerDataCollatorForSeq2Seq, {}
         return DataCollatorForKD, {}
diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py
index 35cf7a114..2029f6509 100644
--- a/src/axolotl/integrations/kd/args.py
+++ b/src/axolotl/integrations/kd/args.py
@@ -15,9 +15,18 @@
 """
 Plugin args for KD support.
 """
+from enum import Enum
+
 from pydantic import BaseModel
 
 
+class InferenceServerType(str, Enum):
+    """Inference servers that can host the teacher model for online KD."""
+
+    vllm = "vllm"
+    sglang = "sglang"
+
+
 class KDArgs(BaseModel):
     """
     Input args for knowledge distillation.
@@ -31,5 +40,6 @@ class KDArgs(BaseModel):
     kd_temperature: float | None = None  # temperature for sampling during KD
 
-    # TODO online kd
-    # kd_online_server_base_url: str | None = None
-    # kd_online_topk: int | None = None
+    # online KD: teacher logprobs are fetched from an inference server
+    kd_online_server_base_url: str | None = None
+    kd_online_topk: int | None = None
+    kd_online_server: InferenceServerType | None = "vllm"
diff --git a/src/axolotl/integrations/kd/collator_online_teacher.py b/src/axolotl/integrations/kd/collator_online_teacher.py
new file mode 100644
index 000000000..a43936335
--- /dev/null
+++ b/src/axolotl/integrations/kd/collator_online_teacher.py
@@ -0,0 +1,498 @@
+"""
+Data collator that fetches teacher logprobs from an online inference server.
+"""
+
+import logging
+from typing import Any, Dict, List, Optional
+
+import requests
+import torch
+
+from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.data.utils import retry_on_request_exceptions
+
+LOG = logging.getLogger(__name__)
+
+
+class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
+    """
+    Collator for online knowledge distillation.
+
+    For every sub-batch, the token ids are sent to a teacher model served by
+    vLLM or SGLang, and the returned top-k logprobs are written onto each item
+    as `target_token_ids`, `target_logprobs` and `target_mask` before the
+    parent collator is invoked.
+    """
+
+    DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
+
+    def __init__(
+        self,
+        *args: Any,
+        kd_online_server_base_url: Optional[str] = None,
+        kd_online_topk: Optional[int] = None,
+        kd_temperature: Optional[float] = 1.0,
+        kd_online_server: Optional[str] = "vllm",
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            kd_online_server_base_url: base URL of the teacher inference server.
+            kd_online_topk: number of top logprobs to request per position.
+            kd_temperature: temperature forwarded to the inference server.
+            kd_online_server: "sglang" selects the SGLang client; any other
+                value selects the vLLM client.
+
+        Raises:
+            ValueError: if the base URL is missing or top-k is not positive.
+        """
+        super().__init__(*args, **kwargs)
+
+        if kd_online_server_base_url is None:
+            raise ValueError(
+                "kd_online_server_base_url must be provided for OnlineTeacherCollator"
+            )
+        if kd_online_topk is None or kd_online_topk <= 0:
+            raise ValueError(
+                "kd_online_topk must be a positive integer for OnlineTeacherCollator"
+            )
+
+        self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/")
+        self.kd_online_topk = kd_online_topk
+        self.kd_temperature = kd_temperature
+        self.kd_online_server = kd_online_server
+        self.http_session = requests.Session()
+
+    def _append_padding(
+        self,
+        target_logprobs: List[List[float]],
+        target_token_ids: List[List[int]],
+        target_mask: List[List[int]],
+    ) -> None:
+        """Append one fully masked top-k placeholder row for a position."""
+        target_logprobs.append([-float("inf")] * self.kd_online_topk)
+        target_token_ids.append([0] * self.kd_online_topk)
+        target_mask.append([0] * self.kd_online_topk)
+
+    def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
+        """
+        Re-normalize top-k logprobs so that their probabilities sum to 1.
+
+        The input is padded with -inf or truncated to `kd_online_topk` entries.
+        If no entry carries probability mass, a uniform distribution is
+        returned. If the conversion fails, the padded/truncated input is
+        returned unchanged.
+        """
+        topk = self.kd_online_topk
+        if not raw_logprobs or topk == 0:
+            return [-float("inf")] * topk if topk > 0 else []
+
+        if len(raw_logprobs) != topk:
+            LOG.warning(
+                f"Logprobs length mismatch in _normalize_logprobs. "
+                f"Expected {topk}, got {len(raw_logprobs)}. Will pad/truncate."
+            )
+            raw_logprobs = list(raw_logprobs[:topk])
+            raw_logprobs.extend([-float("inf")] * (topk - len(raw_logprobs)))
+
+        try:
+            logprobs = torch.tensor(raw_logprobs, dtype=torch.float32)
+            if torch.isneginf(logprobs).all():
+                # zero total probability cannot be normalized; use uniform
+                logprobs = torch.zeros_like(logprobs)
+            # log_softmax(x) = x - logsumexp(x), i.e. the normalization is done
+            # on the logprobs themselves rather than on probabilities
+            return torch.log_softmax(logprobs, dim=-1).tolist()
+        except (TypeError, ValueError, RuntimeError) as exc:
+            LOG.error(
+                f"Error during online logprob scaling: {exc}. Returning raw logprobs.",
+                exc_info=True,
+            )
+            return raw_logprobs
+
+    @retry_on_request_exceptions(max_retries=10, delay=5)
+    def fetch_online_logprobs_sglang(
+        self, batch_input_ids: List[List[int]], labels: List[List[int]]
+    ):
+        """
+        Fetch top-k logprobs for a batch of input_ids from a teacher served by SGLang.
+
+        Each entry of `input_top_logprobs` in the response is expected to be
+        None or a list of `[logprob, token_id, token_str]` items.
+
+        Returns:
+            dict with `target_token_ids`, `target_logprobs` and `target_mask`,
+            each holding one list per sequence. The lists are empty when the
+            response does not contain one entry per input sequence.
+        """
+        api_endpoint = f"{self.kd_online_server_base_url}/generate"
+
+        payload = {
+            "input_ids": batch_input_ids,
+            "return_logprob": True,
+            "top_logprobs_num": self.kd_online_topk,
+            "logprob_start_len": 0,
+            "return_text_in_logprobs": True,
+            "echo": True,
+            "sampling_params": {
+                "max_new_tokens": 0,
+                "temperature": self.kd_temperature,
+                "skip_special_tokens": False,
+            },
+        }
+
+        ret_logprobs_data = {
+            "target_token_ids": [],
+            "target_logprobs": [],
+            "target_mask": [],
+        }
+
+        try:
+            response = self.http_session.post(api_endpoint, json=payload, timeout=60)
+            response.raise_for_status()
+            api_data: list[dict] = response.json()
+
+            # the response must hold exactly one entry per input sequence
+            if not isinstance(api_data, list) or len(api_data) != len(batch_input_ids):
+                LOG.error(
+                    f"API response format error. Expected a list of {len(batch_input_ids)} "
+                    f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
+                )
+                # Return empty data; items processed later will get default empty KD fields
+                return ret_logprobs_data
+
+            for sequence_data, seq_input_ids, seq_labels in zip(
+                api_data, batch_input_ids, labels
+            ):
+                current_target_logprobs: List[List[float]] = []
+                current_target_token_ids: List[List[int]] = []
+                current_target_mask: List[List[int]] = []
+
+                meta_info = sequence_data.pop("meta_info", {})
+                input_top_logprobs = meta_info.pop("input_top_logprobs", [])
+                if not isinstance(input_top_logprobs, list):
+                    LOG.warning(
+                        f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
+                    )
+                    input_top_logprobs = []  # Treat as empty
+
+                # the logprob data must cover every input token, so no length
+                # padding is handled below
+                if len(seq_input_ids) != len(input_top_logprobs):
+                    raise ValueError(
+                        f"Expected logprobs for {len(seq_input_ids)} tokens, "
+                        f"got {len(input_top_logprobs)}"
+                    )
+
+                for pos_top_logprobs_data, label in zip(input_top_logprobs, seq_labels):
+                    if pos_top_logprobs_data is None:
+                        # always the case for the first token: it is a true
+                        # input, so there is never logprob data for it
+                        self._append_padding(
+                            current_target_logprobs,
+                            current_target_token_ids,
+                            current_target_mask,
+                        )
+                        continue
+
+                    # expect a non-empty list of [logprob, token_id, token_str]
+                    if not (
+                        isinstance(pos_top_logprobs_data, list)
+                        and pos_top_logprobs_data
+                        and all(
+                            isinstance(item, list) and len(item) == 3
+                            for item in pos_top_logprobs_data
+                        )
+                    ):
+                        LOG.warning(
+                            f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
+                        )
+                        self._append_padding(
+                            current_target_logprobs,
+                            current_target_token_ids,
+                            current_target_mask,
+                        )
+                        continue
+
+                    pos_logprobs_raw, pos_token_ids, _ = (
+                        list(column) for column in zip(*pos_top_logprobs_data)
+                    )
+
+                    # pad up to top-k; longer responses are truncated below
+                    pad_len = self.kd_online_topk - len(pos_logprobs_raw)
+                    if pad_len > 0:
+                        pos_logprobs_raw.extend([-float("inf")] * pad_len)
+                        pos_token_ids.extend([0] * pad_len)  # Pad with 0 token_id
+
+                    current_target_token_ids.append(
+                        pos_token_ids[: self.kd_online_topk]
+                    )
+                    current_target_logprobs.append(
+                        self._normalize_logprobs(pos_logprobs_raw[: self.kd_online_topk])
+                    )
+
+                    # Mask depends on the corresponding label for the student
+                    mask_value = int(label != self.DEFAULT_LABEL_PAD_TOKEN_ID)
+                    current_target_mask.append([mask_value] * self.kd_online_topk)
+
+                ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
+                ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
+                ret_logprobs_data["target_mask"].append(current_target_mask)
+
+        except requests.exceptions.RequestException as exc:
+            LOG.error(f"Error fetching logprobs from online teacher: {exc}")
+            raise
+        except Exception as exc:  # log unexpected processing errors, then propagate
+            LOG.error(
+                f"Unexpected error processing API response in fetch_online_logprobs: {exc}",
+                exc_info=True,
+            )
+            raise
+
+        return ret_logprobs_data
+
+    @retry_on_request_exceptions(max_retries=10, delay=5)
+    def fetch_online_logprobs_vllm(
+        self, batch_input_ids: List[List[int]], labels: List[List[int]]
+    ):
+        """
+        Fetch top-k logprobs for a batch of input_ids from a teacher served by vLLM.
+
+        Assumes the API returns token IDs as strings in the logprob dictionary
+        keys. The prompt logprobs in the response look like:
+
+            "prompt_logprobs": [
+                null,  # first token is always null
+                {  # second token logprobs
+                    "8948": {  # token ID
+                        "logprob": -2.3841830625315197e-06,
+                        "rank": 1,
+                        "decoded_token": "system"
+                    },
+                    "1849": {  # token ID
+                        "logprob": -13.187501907348633,
+                        "rank": 2,
+                        "decoded_token": "Ġsystem"
+                    },
+                    ...  # rest of the top-k tokens/logprobs
+                },
+                ...  # more tokens
+            ]
+
+        Returns:
+            dict with `target_token_ids`, `target_logprobs` and `target_mask`,
+            each holding one list per sequence. The lists are empty when the
+            response does not contain one choice per input sequence.
+        """
+        api_endpoint = f"{self.kd_online_server_base_url}/v1/completions"
+
+        # NOTE(review): the OpenAI completions schema names the generation
+        # limit `max_tokens`; confirm that the server honors `max_new_tokens`
+        # so that no tokens are generated only to be discarded.
+        payload = {
+            "prompt": batch_input_ids,
+            "echo": True,
+            "logprobs": True,
+            "prompt_logprobs": self.kd_online_topk,
+            "top_logprobs": self.kd_online_topk,
+            "max_new_tokens": 0,
+            "skip_special_tokens": False,
+            "temperature": self.kd_temperature,
+        }
+
+        ret_logprobs_data = {
+            "target_token_ids": [],
+            "target_logprobs": [],
+            "target_mask": [],
+        }
+
+        try:
+            response = self.http_session.post(api_endpoint, json=payload, timeout=60)
+            response.raise_for_status()
+            api_data: dict = response.json()
+            choices: list[dict] = api_data["choices"]
+
+            # the response must hold exactly one choice per input sequence
+            if not isinstance(choices, list) or len(choices) != len(batch_input_ids):
+                LOG.error(
+                    f"API response format error. Expected a list of {len(batch_input_ids)} "
+                    f"items, got {type(choices)} with length {len(choices) if isinstance(choices, list) else 'N/A'}."
+                )
+                # Return empty data; items processed later will get default empty KD fields
+                return ret_logprobs_data
+
+            for sequence_data, seq_input_ids, seq_labels in zip(
+                choices, batch_input_ids, labels
+            ):
+                current_target_logprobs: List[List[float]] = []
+                current_target_token_ids: List[List[int]] = []
+                current_target_mask: List[List[int]] = []
+
+                input_top_logprobs = sequence_data.pop("prompt_logprobs", [])
+                if not isinstance(input_top_logprobs, list):
+                    LOG.warning(
+                        f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
+                    )
+                    input_top_logprobs = []  # Treat as empty
+
+                for i, (_, label) in enumerate(zip(seq_input_ids, seq_labels)):
+                    # None is always the case for the first token (a true input
+                    # has no logprob data); positions past the end of the
+                    # response are padded the same way
+                    pos_top_logprobs_data = (
+                        input_top_logprobs[i] if i < len(input_top_logprobs) else None
+                    )
+                    if pos_top_logprobs_data is None:
+                        self._append_padding(
+                            current_target_logprobs,
+                            current_target_token_ids,
+                            current_target_mask,
+                        )
+                        continue
+
+                    # expect a non-empty dict of token-id string -> logprob dict
+                    if not (
+                        isinstance(pos_top_logprobs_data, dict)
+                        and pos_top_logprobs_data
+                        and all(
+                            isinstance(item, dict)
+                            for item in pos_top_logprobs_data.values()
+                        )
+                    ):
+                        LOG.warning(
+                            f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
+                        )
+                        self._append_padding(
+                            current_target_logprobs,
+                            current_target_token_ids,
+                            current_target_mask,
+                        )
+                        continue
+
+                    pos_token_ids = [int(token_id) for token_id in pos_top_logprobs_data]
+                    pos_logprobs_raw = [
+                        float(entry.get("logprob", -float("inf")))
+                        for entry in pos_top_logprobs_data.values()
+                    ]
+
+                    # pad up to top-k; longer responses are truncated below
+                    pad_len = self.kd_online_topk - len(pos_logprobs_raw)
+                    if pad_len > 0:
+                        LOG.debug(
+                            f"Padding position {i} with {pad_len} top-k tokens and logprobs."
+                        )
+                        pos_logprobs_raw.extend([-float("inf")] * pad_len)
+                        pos_token_ids.extend([0] * pad_len)  # Pad with 0 token_id
+
+                    current_target_token_ids.append(
+                        pos_token_ids[: self.kd_online_topk]
+                    )
+                    # don't normalize for now as the probs seem to sum to 1.0 already
+                    current_target_logprobs.append(
+                        pos_logprobs_raw[: self.kd_online_topk]
+                    )
+
+                    # Mask depends on the corresponding label for the student
+                    mask_value = int(label != self.DEFAULT_LABEL_PAD_TOKEN_ID)
+                    current_target_mask.append([mask_value] * self.kd_online_topk)
+
+                ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
+                ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
+                ret_logprobs_data["target_mask"].append(current_target_mask)
+
+        except requests.exceptions.RequestException as exc:
+            LOG.error(f"Error fetching logprobs from online teacher: {exc}")
+            raise
+        except Exception as exc:  # log unexpected processing errors, then propagate
+            LOG.error(
+                f"Unexpected error processing API response in fetch_online_logprobs: {exc}",
+                exc_info=True,
+            )
+            raise
+
+        return ret_logprobs_data
+
+    def __call__(
+        self,
+        features: List[List[Dict[str, Any]]],
+        return_tensors: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        """
+        Attach teacher logprobs to every item, then collate with the parent class.
+
+        Items lacking `input_ids` or `labels`, and items the server returned no
+        data for, get empty `target_*` lists unless those keys are already set.
+        """
+        if not features:
+            return super().__call__(features, return_tensors=return_tensors)
+
+        if self.kd_online_server == "sglang":
+            fetch_online_logprobs = self.fetch_online_logprobs_sglang
+        else:
+            fetch_online_logprobs = self.fetch_online_logprobs_vllm
+
+        for sub_batch_features in features:
+            if not sub_batch_features:
+                continue
+
+            input_ids_for_api_call: List[List[int]] = []
+            labels_for_api_call: List[List[int]] = []
+            # references to the original item dicts, which are updated in place
+            items_for_api_call: List[Dict[str, Any]] = []
+
+            for item_dict in sub_batch_features:
+                if not isinstance(item_dict, dict):
+                    LOG.warning(
+                        f"Skipping non-dict item in sub_batch_features: {item_dict}"
+                    )
+                    continue
+
+                current_input_ids = item_dict.get("input_ids")
+                current_labels = item_dict.get("labels")
+                if current_input_ids is None or current_labels is None:
+                    # no teacher logprobs can be fetched for this item; give it
+                    # empty KD fields so downstream collators handle it uniformly
+                    item_dict.setdefault("target_token_ids", [])
+                    item_dict.setdefault("target_logprobs", [])
+                    item_dict.setdefault("target_mask", [])
+                    continue
+
+                # plain lists of ints are required for JSON serialization
+                input_ids_for_api_call.append(
+                    current_input_ids.tolist()
+                    if hasattr(current_input_ids, "tolist")
+                    else list(current_input_ids)
+                )
+                labels_for_api_call.append(
+                    current_labels.tolist()
+                    if hasattr(current_labels, "tolist")
+                    else list(current_labels)
+                )
+                items_for_api_call.append(item_dict)
+
+            if not items_for_api_call:  # nothing to send for this sub-batch
+                continue
+
+            # each value is a list with one entry per item in items_for_api_call,
+            # or an empty list if the response was malformed
+            api_responses_for_sub_batch = fetch_online_logprobs(
+                input_ids_for_api_call, labels_for_api_call
+            )
+            num_returned = len(api_responses_for_sub_batch.get("target_token_ids", []))
+
+            # TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.
+            for i, item_to_update in enumerate(items_for_api_call):
+                if i < num_returned:
+                    item_to_update["target_token_ids"] = api_responses_for_sub_batch["target_token_ids"][i]
+                    item_to_update["target_logprobs"] = api_responses_for_sub_batch["target_logprobs"][i]
+                    item_to_update["target_mask"] = api_responses_for_sub_batch["target_mask"][i]
+                else:
+                    LOG.warning(
+                        f"Failed to get online KD data for an item in the batch (index {i}), or API response was too short. "
+                        f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
+                    )
+                    item_to_update.setdefault("target_token_ids", [])
+                    item_to_update.setdefault("target_logprobs", [])
+                    item_to_update.setdefault("target_mask", [])
+
+        return super().__call__(features, return_tensors=return_tensors)
diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py
index 42d4c1d6b..bb7daa092 100644
--- a/src/axolotl/integrations/kd/trainer.py
+++ b/src/axolotl/integrations/kd/trainer.py
@@ -62,12 +62,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
         Subclass and override for custom behavior.
         """
 
-        # target_logprobs = inputs.pop("target_logprobs")
-        # target_token_ids = inputs.pop("target_token_ids")
-        # target_mask = inputs.pop("target_mask")
-
-        # seq_len = target_token_ids.shape[1]
-
         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
             if num_items_in_batch is not None:
@@ -75,36 +69,3 @@ class AxolotlKDTrainer(AxolotlTrainer):
             inputs = {**inputs, **loss_kwargs}
         outputs = model(**inputs)
         return outputs[0]
-        #
-        # # FIXME: account for tokenizer.padding_side
-        # student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
-        #
-        # shift_logits = student_logits.contiguous()
-        # target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
-        # target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
-        # target_mask_for_loss = target_mask[..., 1:, :].contiguous()
-        #
-        # loss_kd = self.kd_loss_fn(
-        #     shift_logits,
-        #     target_token_ids_for_loss,
-        #     target_logprobs_for_loss,
-        #     target_mask_for_loss,
-        #     num_items_in_batch=num_items_in_batch,
-        # )
-        #
-        # if self.args.kd_ce_alpha > 0:
-        #     kd_alpha = self.args.kd_alpha
-        #     loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
-        # else:
-        #     loss = loss_kd
-        # # Save past state if it exists
-        # # TODO: this needs to be fixed and made cleaner later.
-        # if self.args.past_index >= 0:
-        #     self._past = outputs[  # pylint: disable=attribute-defined-outside-init
-        #         self.args.past_index
-        #     ]
-        #
-        # if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
-        #     loss *= self.accelerator.num_processes
-
-        # return (loss, outputs) if return_outputs else loss
diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py
index 5f3b8d3cc..6e117e5d2 100644
--- a/src/axolotl/utils/data/utils.py
+++ b/src/axolotl/utils/data/utils.py
@@ -40,6 +40,7 @@ def retry_on_request_exceptions(
                 except (
                     requests.exceptions.ReadTimeout,
                     requests.exceptions.ConnectionError,
+                    requests.exceptions.HTTPError,
                     huggingface_hub.errors.HfHubHTTPError,
                 ) as exc:
                     if attempt < max_retries - 1: