diff --git a/backend/ssh_client.py b/backend/ssh_client.py index 0841ad4..e1885c0 100644 --- a/backend/ssh_client.py +++ b/backend/ssh_client.py @@ -66,15 +66,24 @@ class SSHClient: self.connected = False break + def _conda_activate_cmd(self) -> str: + candidates = [ + f"/home/{self.username}/miniconda3/etc/profile.d/conda.sh", + f"/home/{self.username}/miniforge3/etc/profile.d/conda.sh", + f"/home/{self.username}/anaconda3/etc/profile.d/conda.sh", + "/opt/conda/etc/profile.d/conda.sh", + ] + checks = " || ".join( + f'[ -f "{p}" ] && source "{p}"' for p in candidates + ) + return f'( {checks} ) 2>/dev/null; conda activate synthetic-data 2>/dev/null' + def execute(self, command: str, use_conda: bool = True) -> tuple: if not self.is_connected(): raise Exception("Not connected to SSH server") if use_conda: - full_cmd = ( - f"source /home/{self.username}/miniconda3/etc/profile.d/conda.sh && " - f"conda activate synthetic-data && {command}" - ) + full_cmd = f"{self._conda_activate_cmd()} && {command}" else: full_cmd = command @@ -90,10 +99,7 @@ class SSHClient: raise Exception("Not connected to SSH server") if use_conda: - full_cmd = ( - f"source /home/{self.username}/miniconda3/etc/profile.d/conda.sh && " - f"conda activate synthetic-data && {command}" - ) + full_cmd = f"{self._conda_activate_cmd()} && {command}" else: full_cmd = command @@ -131,17 +137,14 @@ class SSHClient: channel.get_pty(term=term, width=width, height=height) channel.invoke_shell() - # Auto-activate conda env - activate = ( - f"source /home/{self.username}/miniconda3/etc/profile.d/conda.sh && " - f"conda activate synthetic-data\n" - ) - channel.send(activate) + channel.send(self._conda_activate_cmd() + "\n") return channel def upload_file(self, local_path: str, remote_path: str): if not self.is_connected(): raise Exception("Not connected to SSH server") + remote_dir = remote_path.rsplit("/", 1)[0] + self.execute(f"mkdir -p '{remote_dir}'", use_conda=False) sftp = self.client.open_sftp() try: sftp.put(local_path, remote_path)