chore: lint

This commit is contained in:
Wing Lian
2025-04-22 14:13:48 -04:00
parent cb7185998b
commit 168ec339e5
7 changed files with 41 additions and 28 deletions

View File

@@ -95,4 +95,4 @@
} }
] ]
} }
} }

View File

@@ -574,4 +574,4 @@ torchdistx_path: ${TORCHDISTX_PATH}
pretraining_dataset: ${PRETRAINING_DATASET} pretraining_dataset: ${PRETRAINING_DATASET}
debug: ${DEBUG} debug: ${DEBUG}
seed: ${SEED} seed: ${SEED}
strict: ${STRICT} strict: ${STRICT}

View File

@@ -1,9 +1,14 @@
import runpod """
Runpod serverless entrypoint handler
"""
import os import os
import runpod
import yaml
from huggingface_hub._login import login
from train import train from train import train
from utils import get_output_dir from utils import get_output_dir
from huggingface_hub._login import login
import yaml
BASE_VOLUME = os.environ.get("BASE_VOLUME", "/runpod-volume") BASE_VOLUME = os.environ.get("BASE_VOLUME", "/runpod-volume")
if not os.path.exists(BASE_VOLUME): if not os.path.exists(BASE_VOLUME):
@@ -30,7 +35,7 @@ async def handler(job):
args["runpod_job_id"] = runpod_job_id args["runpod_job_id"] = runpod_job_id
yaml_data = yaml.dump(args, default_flow_style=False) yaml_data = yaml.dump(args, default_flow_style=False)
with open(config_path, "w") as file: with open(config_path, "w", encoding="utf-8") as file:
file.write(yaml_data) file.write(yaml_data)
# Handle credentials # Handle credentials

View File

@@ -58,4 +58,4 @@
} }
} }
} }
} }

View File

@@ -1,8 +1,8 @@
import yaml """
from torch.cuda import device_count Runpod train entrypoint
"""
import asyncio import asyncio
import os
from typing import Optional, Dict, Any, AsyncGenerator
async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True): async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
@@ -11,20 +11,23 @@ async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
:param config_path: Path to the YAML config file :param config_path: Path to the YAML config file
:param gpu_id: GPU ID to use (default: "0") :param gpu_id: GPU ID to use (default: "0")
:param preprocess: Whether to run preprocessing (default: True) :param preprocess: Whether to run preprocessing (default: True)
""" """
# First check if preprocessing is needed # First check if preprocessing is needed
if preprocess: if preprocess:
# Preprocess command # Preprocess command
preprocess_cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} axolotl preprocess {config_path}" preprocess_cmd = (
f"CUDA_VISIBLE_DEVICES={gpu_id} axolotl preprocess {config_path}"
)
process = await asyncio.create_subprocess_shell( process = await asyncio.create_subprocess_shell(
preprocess_cmd, preprocess_cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT stderr=asyncio.subprocess.STDOUT,
) )
async for line in process.stdout: if process.stdout is not None:
yield f"Preprocessing: {line.decode().strip()}" async for line in process.stdout:
yield f"Preprocessing: {line.decode().strip()}"
await process.wait() await process.wait()
yield "Preprocessing completed." yield "Preprocessing completed."
else: else:
@@ -33,11 +36,10 @@ async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
# Training command # Training command
train_cmd = f"axolotl train {config_path}" train_cmd = f"axolotl train {config_path}"
process = await asyncio.create_subprocess_shell( process = await asyncio.create_subprocess_shell(
train_cmd, train_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT
) )
async for line in process.stdout: if process.stdout is not None:
yield f"Training: {line.decode().strip()}" async for line in process.stdout:
yield f"Training: {line.decode().strip()}"
await process.wait() await process.wait()

View File

@@ -1,4 +1,9 @@
"""
Runpod launcher utils
"""
import os import os
import yaml import yaml
@@ -27,7 +32,8 @@ def make_valid_config(input_args):
:return: str, path to the updated config file :return: str, path to the updated config file
""" """
# Load default config # Load default config
all_args = yaml.safe_load(open("config/config.yaml", "r")) with open("config/config.yaml", "r", encoding="utf-8") as fin:
all_args = yaml.safe_load(fin)
if not input_args: if not input_args:
print("No args provided, using defaults") print("No args provided, using defaults")
@@ -38,7 +44,7 @@ def make_valid_config(input_args):
updated_config_path = "config/updated_config.yaml" updated_config_path = "config/updated_config.yaml"
# Save updated config to new file # Save updated config to new file
with open(updated_config_path, "w") as f: with open(updated_config_path, "w", encoding="utf-8") as f:
yaml.dump(all_args, f) yaml.dump(all_args, f)
return updated_config_path return updated_config_path
@@ -57,9 +63,9 @@ def set_config_env_vars(args: dict):
"""Convert Python values to string format for environment variables""" """Convert Python values to string format for environment variables"""
if value is None: if value is None:
return "" return ""
elif isinstance(value, bool): if isinstance(value, bool):
return str(value).lower() return str(value).lower()
elif isinstance(value, (list, dict)): if isinstance(value, (list, dict)):
return str(value) return str(value)
return str(value) return str(value)

View File

@@ -90,4 +90,4 @@
"12.0" "12.0"
] ]
} }
} }