chore: lint
This commit is contained in:
@@ -95,4 +95,4 @@
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -574,4 +574,4 @@ torchdistx_path: ${TORCHDISTX_PATH}
|
||||
pretraining_dataset: ${PRETRAINING_DATASET}
|
||||
debug: ${DEBUG}
|
||||
seed: ${SEED}
|
||||
strict: ${STRICT}
|
||||
strict: ${STRICT}
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
import runpod
|
||||
"""
|
||||
Runpod serverless entrypoint handler
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import runpod
|
||||
import yaml
|
||||
from huggingface_hub._login import login
|
||||
from train import train
|
||||
from utils import get_output_dir
|
||||
from huggingface_hub._login import login
|
||||
import yaml
|
||||
|
||||
BASE_VOLUME = os.environ.get("BASE_VOLUME", "/runpod-volume")
|
||||
if not os.path.exists(BASE_VOLUME):
|
||||
@@ -30,7 +35,7 @@ async def handler(job):
|
||||
args["runpod_job_id"] = runpod_job_id
|
||||
|
||||
yaml_data = yaml.dump(args, default_flow_style=False)
|
||||
with open(config_path, "w") as file:
|
||||
with open(config_path, "w", encoding="utf-8") as file:
|
||||
file.write(yaml_data)
|
||||
|
||||
# Handle credentials
|
||||
|
||||
@@ -58,4 +58,4 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import yaml
|
||||
from torch.cuda import device_count
|
||||
"""
|
||||
Runpod train entrypoint
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Optional, Dict, Any, AsyncGenerator
|
||||
|
||||
|
||||
async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
|
||||
@@ -11,20 +11,23 @@ async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
|
||||
:param config_path: Path to the YAML config file
|
||||
:param gpu_id: GPU ID to use (default: "0")
|
||||
:param preprocess: Whether to run preprocessing (default: True)
|
||||
|
||||
|
||||
"""
|
||||
# First check if preprocessing is needed
|
||||
if preprocess:
|
||||
# Preprocess command
|
||||
preprocess_cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} axolotl preprocess {config_path}"
|
||||
preprocess_cmd = (
|
||||
f"CUDA_VISIBLE_DEVICES={gpu_id} axolotl preprocess {config_path}"
|
||||
)
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
preprocess_cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
)
|
||||
|
||||
async for line in process.stdout:
|
||||
yield f"Preprocessing: {line.decode().strip()}"
|
||||
|
||||
if process.stdout is not None:
|
||||
async for line in process.stdout:
|
||||
yield f"Preprocessing: {line.decode().strip()}"
|
||||
await process.wait()
|
||||
yield "Preprocessing completed."
|
||||
else:
|
||||
@@ -33,11 +36,10 @@ async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
|
||||
# Training command
|
||||
train_cmd = f"axolotl train {config_path}"
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
train_cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT
|
||||
train_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
|
||||
)
|
||||
|
||||
async for line in process.stdout:
|
||||
yield f"Training: {line.decode().strip()}"
|
||||
|
||||
if process.stdout is not None:
|
||||
async for line in process.stdout:
|
||||
yield f"Training: {line.decode().strip()}"
|
||||
await process.wait()
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
"""
|
||||
Runpod launcher utils
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@@ -27,7 +32,8 @@ def make_valid_config(input_args):
|
||||
:return: str, path to the updated config file
|
||||
"""
|
||||
# Load default config
|
||||
all_args = yaml.safe_load(open("config/config.yaml", "r"))
|
||||
with open("config/config.yaml", "r", encoding="utf-8") as fin:
|
||||
all_args = yaml.safe_load(fin)
|
||||
|
||||
if not input_args:
|
||||
print("No args provided, using defaults")
|
||||
@@ -38,7 +44,7 @@ def make_valid_config(input_args):
|
||||
updated_config_path = "config/updated_config.yaml"
|
||||
|
||||
# Save updated config to new file
|
||||
with open(updated_config_path, "w") as f:
|
||||
with open(updated_config_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(all_args, f)
|
||||
|
||||
return updated_config_path
|
||||
@@ -57,9 +63,9 @@ def set_config_env_vars(args: dict):
|
||||
"""Convert Python values to string format for environment variables"""
|
||||
if value is None:
|
||||
return ""
|
||||
elif isinstance(value, bool):
|
||||
if isinstance(value, bool):
|
||||
return str(value).lower()
|
||||
elif isinstance(value, (list, dict)):
|
||||
if isinstance(value, (list, dict)):
|
||||
return str(value)
|
||||
return str(value)
|
||||
|
||||
|
||||
@@ -90,4 +90,4 @@
|
||||
"12.0"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user