chore: lint
This commit is contained in:
@@ -95,4 +95,4 @@
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -574,4 +574,4 @@ torchdistx_path: ${TORCHDISTX_PATH}
|
|||||||
pretraining_dataset: ${PRETRAINING_DATASET}
|
pretraining_dataset: ${PRETRAINING_DATASET}
|
||||||
debug: ${DEBUG}
|
debug: ${DEBUG}
|
||||||
seed: ${SEED}
|
seed: ${SEED}
|
||||||
strict: ${STRICT}
|
strict: ${STRICT}
|
||||||
|
|||||||
@@ -1,9 +1,14 @@
|
|||||||
import runpod
|
"""
|
||||||
|
Runpod serverless entrypoint handler
|
||||||
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import runpod
|
||||||
|
import yaml
|
||||||
|
from huggingface_hub._login import login
|
||||||
from train import train
|
from train import train
|
||||||
from utils import get_output_dir
|
from utils import get_output_dir
|
||||||
from huggingface_hub._login import login
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
BASE_VOLUME = os.environ.get("BASE_VOLUME", "/runpod-volume")
|
BASE_VOLUME = os.environ.get("BASE_VOLUME", "/runpod-volume")
|
||||||
if not os.path.exists(BASE_VOLUME):
|
if not os.path.exists(BASE_VOLUME):
|
||||||
@@ -30,7 +35,7 @@ async def handler(job):
|
|||||||
args["runpod_job_id"] = runpod_job_id
|
args["runpod_job_id"] = runpod_job_id
|
||||||
|
|
||||||
yaml_data = yaml.dump(args, default_flow_style=False)
|
yaml_data = yaml.dump(args, default_flow_style=False)
|
||||||
with open(config_path, "w") as file:
|
with open(config_path, "w", encoding="utf-8") as file:
|
||||||
file.write(yaml_data)
|
file.write(yaml_data)
|
||||||
|
|
||||||
# Handle credentials
|
# Handle credentials
|
||||||
|
|||||||
@@ -58,4 +58,4 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import yaml
|
"""
|
||||||
from torch.cuda import device_count
|
Runpod train entrypoint
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
from typing import Optional, Dict, Any, AsyncGenerator
|
|
||||||
|
|
||||||
|
|
||||||
async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
|
async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
|
||||||
@@ -11,20 +11,23 @@ async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
|
|||||||
:param config_path: Path to the YAML config file
|
:param config_path: Path to the YAML config file
|
||||||
:param gpu_id: GPU ID to use (default: "0")
|
:param gpu_id: GPU ID to use (default: "0")
|
||||||
:param preprocess: Whether to run preprocessing (default: True)
|
:param preprocess: Whether to run preprocessing (default: True)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# First check if preprocessing is needed
|
# First check if preprocessing is needed
|
||||||
if preprocess:
|
if preprocess:
|
||||||
# Preprocess command
|
# Preprocess command
|
||||||
preprocess_cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} axolotl preprocess {config_path}"
|
preprocess_cmd = (
|
||||||
|
f"CUDA_VISIBLE_DEVICES={gpu_id} axolotl preprocess {config_path}"
|
||||||
|
)
|
||||||
process = await asyncio.create_subprocess_shell(
|
process = await asyncio.create_subprocess_shell(
|
||||||
preprocess_cmd,
|
preprocess_cmd,
|
||||||
stdout=asyncio.subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stderr=asyncio.subprocess.STDOUT
|
stderr=asyncio.subprocess.STDOUT,
|
||||||
)
|
)
|
||||||
|
|
||||||
async for line in process.stdout:
|
if process.stdout is not None:
|
||||||
yield f"Preprocessing: {line.decode().strip()}"
|
async for line in process.stdout:
|
||||||
|
yield f"Preprocessing: {line.decode().strip()}"
|
||||||
await process.wait()
|
await process.wait()
|
||||||
yield "Preprocessing completed."
|
yield "Preprocessing completed."
|
||||||
else:
|
else:
|
||||||
@@ -33,11 +36,10 @@ async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
|
|||||||
# Training command
|
# Training command
|
||||||
train_cmd = f"axolotl train {config_path}"
|
train_cmd = f"axolotl train {config_path}"
|
||||||
process = await asyncio.create_subprocess_shell(
|
process = await asyncio.create_subprocess_shell(
|
||||||
train_cmd,
|
train_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
|
||||||
stdout=asyncio.subprocess.PIPE,
|
|
||||||
stderr=asyncio.subprocess.STDOUT
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async for line in process.stdout:
|
if process.stdout is not None:
|
||||||
yield f"Training: {line.decode().strip()}"
|
async for line in process.stdout:
|
||||||
|
yield f"Training: {line.decode().strip()}"
|
||||||
await process.wait()
|
await process.wait()
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
|
"""
|
||||||
|
Runpod launcher utils
|
||||||
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
@@ -27,7 +32,8 @@ def make_valid_config(input_args):
|
|||||||
:return: str, path to the updated config file
|
:return: str, path to the updated config file
|
||||||
"""
|
"""
|
||||||
# Load default config
|
# Load default config
|
||||||
all_args = yaml.safe_load(open("config/config.yaml", "r"))
|
with open("config/config.yaml", "r", encoding="utf-8") as fin:
|
||||||
|
all_args = yaml.safe_load(fin)
|
||||||
|
|
||||||
if not input_args:
|
if not input_args:
|
||||||
print("No args provided, using defaults")
|
print("No args provided, using defaults")
|
||||||
@@ -38,7 +44,7 @@ def make_valid_config(input_args):
|
|||||||
updated_config_path = "config/updated_config.yaml"
|
updated_config_path = "config/updated_config.yaml"
|
||||||
|
|
||||||
# Save updated config to new file
|
# Save updated config to new file
|
||||||
with open(updated_config_path, "w") as f:
|
with open(updated_config_path, "w", encoding="utf-8") as f:
|
||||||
yaml.dump(all_args, f)
|
yaml.dump(all_args, f)
|
||||||
|
|
||||||
return updated_config_path
|
return updated_config_path
|
||||||
@@ -57,9 +63,9 @@ def set_config_env_vars(args: dict):
|
|||||||
"""Convert Python values to string format for environment variables"""
|
"""Convert Python values to string format for environment variables"""
|
||||||
if value is None:
|
if value is None:
|
||||||
return ""
|
return ""
|
||||||
elif isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
return str(value).lower()
|
return str(value).lower()
|
||||||
elif isinstance(value, (list, dict)):
|
if isinstance(value, (list, dict)):
|
||||||
return str(value)
|
return str(value)
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
|||||||
@@ -90,4 +90,4 @@
|
|||||||
"12.0"
|
"12.0"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user