lint
This commit is contained in:
@@ -11,12 +11,11 @@ from accelerate.logging import get_logger
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.train import (
|
from axolotl.train import (
|
||||||
TrainDatasetMeta,
|
TrainDatasetMeta,
|
||||||
setup_model_and_tokenizer,
|
setup_model_and_tokenizer,
|
||||||
)
|
)
|
||||||
from axolotl.telemetry.errors import send_errors
|
|
||||||
from axolotl.train import TrainDatasetMeta
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import cleanup_distributed
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from peft import (
|
|||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
from axolotl.loaders.utils import get_linear_embedding_layers
|
from axolotl.loaders.utils import get_linear_embedding_layers
|
||||||
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from axolotl.loaders.utils import (
|
|||||||
load_model_config,
|
load_model_config,
|
||||||
)
|
)
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from transformers import (
|
|||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from transformers import (
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
|
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
barrier,
|
barrier,
|
||||||
|
|||||||
@@ -59,12 +59,14 @@ class TelemetryCallback(TrainerCallback):
|
|||||||
self.telemetry_manager.send_event(
|
self.telemetry_manager.send_event(
|
||||||
event_type="train-end",
|
event_type="train-end",
|
||||||
properties={
|
properties={
|
||||||
"loss": state.log_history[-1].get("loss", 0)
|
"loss": (
|
||||||
if state.log_history
|
state.log_history[-1].get("loss", 0) if state.log_history else None
|
||||||
else None,
|
),
|
||||||
"learning_rate": state.log_history[-1].get("learning_rate", 0)
|
"learning_rate": (
|
||||||
if state.log_history
|
state.log_history[-1].get("learning_rate", 0)
|
||||||
else None,
|
if state.log_history
|
||||||
|
else None
|
||||||
|
),
|
||||||
}
|
}
|
||||||
| self.tracker.metrics.to_dict(),
|
| self.tracker.metrics.to_dict(),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -307,9 +307,11 @@ class TelemetryManager:
|
|||||||
gpu_info.append(
|
gpu_info.append(
|
||||||
{
|
{
|
||||||
"name": torch.hip.get_device_name(i),
|
"name": torch.hip.get_device_name(i),
|
||||||
"memory": torch.hip.get_device_properties(i).total_memory
|
"memory": (
|
||||||
if hasattr(torch.hip, "get_device_properties")
|
torch.hip.get_device_properties(i).total_memory
|
||||||
else None,
|
if hasattr(torch.hip, "get_device_properties")
|
||||||
|
else None
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -202,8 +202,8 @@ class RuntimeMetricsTracker:
|
|||||||
memory_used = self._get_allocated_memory()
|
memory_used = self._get_allocated_memory()
|
||||||
for i, memory in memory_used.items():
|
for i, memory in memory_used.items():
|
||||||
memory_metrics[f"gpu_{i}_memory_bytes"] = memory
|
memory_metrics[f"gpu_{i}_memory_bytes"] = memory
|
||||||
memory_metrics[
|
memory_metrics[f"gpu_{i}_peak_memory_bytes"] = (
|
||||||
f"gpu_{i}_peak_memory_bytes"
|
self.metrics.peak_gpu_memory.get(i, 0)
|
||||||
] = self.metrics.peak_gpu_memory.get(i, 0)
|
)
|
||||||
|
|
||||||
return memory_metrics
|
return memory_metrics
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ from axolotl.loaders import (
|
|||||||
load_tokenizer,
|
load_tokenizer,
|
||||||
)
|
)
|
||||||
from axolotl.telemetry.errors import send_errors
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.telemetry.errors import send_errors
|
|
||||||
from axolotl.telemetry.manager import TelemetryManager
|
from axolotl.telemetry.manager import TelemetryManager
|
||||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -91,11 +90,11 @@ def setup_model_and_tokenizer(
|
|||||||
if model.generation_config is not None:
|
if model.generation_config is not None:
|
||||||
model.generation_config.do_sample = True
|
model.generation_config.do_sample = True
|
||||||
|
|
||||||
TELEMETRY_MANAGER.track_event(
|
TELEMETRY_MANAGER.send_event(
|
||||||
event_type="model-load", properties=model.config.to_dict()
|
event_type="model-load", properties=model.config.to_dict()
|
||||||
)
|
)
|
||||||
if peft_config:
|
if peft_config:
|
||||||
TELEMETRY_MANAGER.track_event(
|
TELEMETRY_MANAGER.send_event(
|
||||||
event_type="peft-config-load", properties=peft_config.to_dict()
|
event_type="peft-config-load", properties=peft_config.to_dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,5 @@
|
|||||||
"""Tests for telemetry callback module."""
|
"""Tests for telemetry callback module."""
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
import time
|
import time
|
||||||
@@ -15,9 +16,9 @@ def calc_expected_metrics(step, last_step, current_time, last_time, start_time=9
|
|||||||
time_diff = current_time - last_time
|
time_diff = current_time - last_time
|
||||||
step_diff = step - last_step
|
step_diff = step - last_step
|
||||||
return {
|
return {
|
||||||
"steps_per_second": step_diff / time_diff
|
"steps_per_second": (
|
||||||
if time_diff > 0 and step_diff > 0
|
step_diff / time_diff if time_diff > 0 and step_diff > 0 else 0
|
||||||
else 0,
|
),
|
||||||
"time_since_last_report": time_diff,
|
"time_since_last_report": time_diff,
|
||||||
"elapsed_time": current_time - start_time,
|
"elapsed_time": current_time - start_time,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Tests for telemetry error utilities"""
|
"""Tests for telemetry error utilities"""
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Tests for TelemetryManager class and utilities"""
|
"""Tests for TelemetryManager class and utilities"""
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name,protected-access
|
# pylint: disable=redefined-outer-name,protected-access
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -38,9 +39,13 @@ def telemetry_manager_class():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def manager(telemetry_manager_class, mock_whitelist):
|
def manager(telemetry_manager_class, mock_whitelist):
|
||||||
"""Create a TelemetryManager instance with mocked dependencies"""
|
"""Create a TelemetryManager instance with mocked dependencies"""
|
||||||
with patch("posthog.capture"), patch("posthog.flush"), patch("time.sleep"), patch(
|
with (
|
||||||
"axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist
|
patch("posthog.capture"),
|
||||||
), patch.dict(os.environ, {"RANK": "0"}):
|
patch("posthog.flush"),
|
||||||
|
patch("time.sleep"),
|
||||||
|
patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist),
|
||||||
|
patch.dict(os.environ, {"RANK": "0"}),
|
||||||
|
):
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
# Manually enable for most tests
|
# Manually enable for most tests
|
||||||
manager.enabled = True
|
manager.enabled = True
|
||||||
@@ -49,8 +54,10 @@ def manager(telemetry_manager_class, mock_whitelist):
|
|||||||
|
|
||||||
def test_singleton_instance(telemetry_manager_class):
|
def test_singleton_instance(telemetry_manager_class):
|
||||||
"""Test that TelemetryManager is a singleton"""
|
"""Test that TelemetryManager is a singleton"""
|
||||||
with patch("posthog.capture"), patch("time.sleep"), patch.dict(
|
with (
|
||||||
os.environ, {"RANK": "0"}
|
patch("posthog.capture"),
|
||||||
|
patch("time.sleep"),
|
||||||
|
patch.dict(os.environ, {"RANK": "0"}),
|
||||||
):
|
):
|
||||||
first = telemetry_manager_class()
|
first = telemetry_manager_class()
|
||||||
second = telemetry_manager_class()
|
second = telemetry_manager_class()
|
||||||
@@ -60,8 +67,10 @@ def test_singleton_instance(telemetry_manager_class):
|
|||||||
|
|
||||||
def test_telemetry_disabled_by_default(telemetry_manager_class):
|
def test_telemetry_disabled_by_default(telemetry_manager_class):
|
||||||
"""Test that telemetry is disabled by default (opt-in)"""
|
"""Test that telemetry is disabled by default (opt-in)"""
|
||||||
with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch("time.sleep"), patch(
|
with (
|
||||||
"logging.Logger.info"
|
patch.dict(os.environ, {"RANK": "0"}, clear=True),
|
||||||
|
patch("time.sleep"),
|
||||||
|
patch("logging.Logger.info"),
|
||||||
):
|
):
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
assert not manager.enabled
|
assert not manager.enabled
|
||||||
@@ -69,8 +78,9 @@ def test_telemetry_disabled_by_default(telemetry_manager_class):
|
|||||||
|
|
||||||
def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class):
|
def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class):
|
||||||
"""Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0"""
|
"""Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0"""
|
||||||
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}), patch(
|
with (
|
||||||
"time.sleep"
|
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}),
|
||||||
|
patch("time.sleep"),
|
||||||
):
|
):
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
assert manager.enabled
|
assert manager.enabled
|
||||||
@@ -78,8 +88,9 @@ def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class):
|
|||||||
|
|
||||||
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
|
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
|
||||||
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
|
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
|
||||||
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}), patch(
|
with (
|
||||||
"time.sleep"
|
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}),
|
||||||
|
patch("time.sleep"),
|
||||||
):
|
):
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
assert not manager.enabled
|
assert not manager.enabled
|
||||||
@@ -87,17 +98,21 @@ def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
|
|||||||
|
|
||||||
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
|
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
|
||||||
"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
|
"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
|
||||||
with patch.dict(
|
with (
|
||||||
os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"}
|
patch.dict(
|
||||||
), patch("time.sleep"):
|
os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"}
|
||||||
|
),
|
||||||
|
patch("time.sleep"),
|
||||||
|
):
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
assert not manager.enabled
|
assert not manager.enabled
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
|
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
|
||||||
"""Test that telemetry is disabled for non-main processes"""
|
"""Test that telemetry is disabled for non-main processes"""
|
||||||
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}), patch(
|
with (
|
||||||
"time.sleep"
|
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}),
|
||||||
|
patch("time.sleep"),
|
||||||
):
|
):
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
assert not manager.enabled
|
assert not manager.enabled
|
||||||
@@ -105,9 +120,11 @@ def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
|
|||||||
|
|
||||||
def test_opt_in_info_displayed(telemetry_manager_class):
|
def test_opt_in_info_displayed(telemetry_manager_class):
|
||||||
"""Test that opt-in info is displayed when telemetry is not configured"""
|
"""Test that opt-in info is displayed when telemetry is not configured"""
|
||||||
with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch(
|
with (
|
||||||
"logging.Logger.warning"
|
patch.dict(os.environ, {"RANK": "0"}, clear=True),
|
||||||
) as mock_warning, patch("time.sleep"):
|
patch("logging.Logger.warning") as mock_warning,
|
||||||
|
patch("time.sleep"),
|
||||||
|
):
|
||||||
telemetry_manager_class()
|
telemetry_manager_class()
|
||||||
info_displayed = False
|
info_displayed = False
|
||||||
for call in mock_warning.call_args_list:
|
for call in mock_warning.call_args_list:
|
||||||
@@ -120,8 +137,9 @@ def test_opt_in_info_displayed(telemetry_manager_class):
|
|||||||
|
|
||||||
def test_is_whitelisted(telemetry_manager_class, mock_whitelist):
|
def test_is_whitelisted(telemetry_manager_class, mock_whitelist):
|
||||||
"""Test org whitelist functionality"""
|
"""Test org whitelist functionality"""
|
||||||
with patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist), patch.dict(
|
with (
|
||||||
os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}
|
patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist),
|
||||||
|
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
|
||||||
):
|
):
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
|
|
||||||
@@ -150,8 +168,9 @@ def test_system_info_collection(manager):
|
|||||||
|
|
||||||
def test_send_event(telemetry_manager_class):
|
def test_send_event(telemetry_manager_class):
|
||||||
"""Test basic event sending"""
|
"""Test basic event sending"""
|
||||||
with patch("posthog.capture") as mock_capture, patch.dict(
|
with (
|
||||||
os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}
|
patch("posthog.capture") as mock_capture,
|
||||||
|
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
|
||||||
):
|
):
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
|
|
||||||
@@ -171,8 +190,9 @@ def test_send_event(telemetry_manager_class):
|
|||||||
|
|
||||||
def test_send_system_info(telemetry_manager_class):
|
def test_send_system_info(telemetry_manager_class):
|
||||||
"""Test sending system info"""
|
"""Test sending system info"""
|
||||||
with patch("posthog.capture") as mock_capture, patch.dict(
|
with (
|
||||||
os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}
|
patch("posthog.capture") as mock_capture,
|
||||||
|
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
|
||||||
):
|
):
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
manager.send_system_info()
|
manager.send_system_info()
|
||||||
@@ -183,8 +203,9 @@ def test_send_system_info(telemetry_manager_class):
|
|||||||
|
|
||||||
def test_redacted_properties(telemetry_manager_class):
|
def test_redacted_properties(telemetry_manager_class):
|
||||||
"""Test path redaction in send_event method"""
|
"""Test path redaction in send_event method"""
|
||||||
with patch("posthog.capture") as mock_capture, patch.dict(
|
with (
|
||||||
os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}
|
patch("posthog.capture") as mock_capture,
|
||||||
|
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
|
||||||
):
|
):
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
# Test with properties containing various paths and non-paths
|
# Test with properties containing various paths and non-paths
|
||||||
@@ -237,9 +258,10 @@ def test_disable_telemetry(manager):
|
|||||||
|
|
||||||
def test_exception_handling_during_send(manager):
|
def test_exception_handling_during_send(manager):
|
||||||
"""Test that exceptions in PostHog are handled gracefully"""
|
"""Test that exceptions in PostHog are handled gracefully"""
|
||||||
with patch("posthog.capture", side_effect=Exception("Test error")), patch(
|
with (
|
||||||
"logging.Logger.warning"
|
patch("posthog.capture", side_effect=Exception("Test error")),
|
||||||
) as mock_warning:
|
patch("logging.Logger.warning") as mock_warning,
|
||||||
|
):
|
||||||
manager.send_event("test_event")
|
manager.send_event("test_event")
|
||||||
warning_logged = False
|
warning_logged = False
|
||||||
for call in mock_warning.call_args_list:
|
for call in mock_warning.call_args_list:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Tests for runtime metrics telemetry module"""
|
"""Tests for runtime metrics telemetry module"""
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|||||||
Reference in New Issue
Block a user