diff --git a/bale/drawer.py b/bale/drawer.py index 07b591a..999a816 100644 --- a/bale/drawer.py +++ b/bale/drawer.py @@ -66,7 +66,7 @@ class Drawer(object): ) self._table.tailwind.width("full") self._table.visible = False - for name in ssh.get_hosts("data"): + for name in ssh.get_hosts(): self._add_host_to_table(name) chevron = ui.button(icon="chevron_left", color=None, on_click=toggle_drawer).props("padding=0px") chevron.classes("absolute") @@ -87,7 +87,7 @@ class Drawer(object): save = None async def send_key(): - s = ssh.Ssh("data", host=host_input.value, hostname=hostname_input.value, username=username_input.value, password=password_input.value) + s = ssh.Ssh(host_input.value, hostname=hostname_input.value, username=username_input.value, password=password_input.value) result = await s.send_key() if result.stdout.strip() != "": el.notify(result.stdout.strip(), multi_line=True, type="positive") @@ -110,12 +110,12 @@ class Drawer(object): c.tailwind.width("full") with ui.scroll_area() as s: s.tailwind.height("[160px]") - public_key = await ssh.get_public_key("data") + public_key = await ssh.get_public_key() ui.label(public_key).classes("text-secondary break-all") el.DButton("SAVE", on_click=lambda: host_dialog.submit("save")).bind_enabled_from(save_em, "no_errors") host_input.value = name if name != "": - s = ssh.Ssh(path="data", host=name) + s = ssh.Ssh(name) hostname_input.value = s.hostname username_input.value = s.username @@ -125,11 +125,11 @@ class Drawer(object): default = Tab(spinner=None).common.get("default", "") if default == name: Tab(spinner=None).common["default"] = "" - ssh.Ssh(path="data", host=name).remove() + ssh.Ssh(name).remove() for row in self._table.rows: if name == row["name"]: self._table.remove_rows(row) - ssh.Ssh(path="data", host=host_input.value, hostname=hostname_input.value, username=username_input.value) + ssh.Ssh(host_input.value, hostname=hostname_input.value, username=username_input.value) self._add_host_to_table(host_input.value) def _modify_host(self, mode): @@ -162,7 +162,7 @@ class Drawer(object): if self._selection_mode == "remove": if len(e.selection) > 0: for row in e.selection: - ssh.Ssh(path="data", host=row["name"]).remove() + ssh.Ssh(row["name"]).remove() self._table.remove_rows(row) self._modify_host(None) diff --git a/bale/interfaces/cli.py b/bale/interfaces/cli.py index 8ff8351..ee01d44 100644 --- a/bale/interfaces/cli.py +++ b/bale/interfaces/cli.py @@ -116,30 +116,49 @@ class Cli: self._terminate.clear() self._busy = False return Result( - command=command, return_code=process.returncode, stdout_lines=self.stdout.copy(), stderr_lines=self.stderr.copy(), terminated=terminated, truncated=self._truncated + command=command, + return_code=process.returncode, + stdout_lines=self.stdout.copy(), + stderr_lines=self.stderr.copy(), + terminated=terminated, + truncated=self._truncated, ) - async def shell(self, command: str) -> Result: + async def shell(self, command: str, max_output_lines: int = 0) -> Result: self._busy = True try: process = await asyncio.create_subprocess_shell(command, stdout=PIPE, stderr=PIPE) if process is not None and process.stdout is not None and process.stderr is not None: - self.clear_buffers() + self.stdout.clear() + self.stderr.clear() self._terminate.clear() + self._truncated = False + terminated = False now = datetime.now().strftime("%Y/%m/%d %H:%M:%S") self.prefix_line = f"<{now}> {command}\n" for terminal in self._stdout_terminals: terminal.call_terminal_method("write", "\n" + self.prefix_line) await asyncio.gather( + self._controller(process=process, max_output_lines=max_output_lines), self._read_stdout(stream=process.stdout), self._read_stderr(stream=process.stderr), ) + if self._terminate.is_set(): + terminated = True await process.wait() except Exception as e: raise e finally: + self._terminate.clear() self._busy = False - return Result(command=command, return_code=process.returncode, stdout_lines=self.stdout.copy(), stderr_lines=self.stderr.copy(), terminated=False) + return Result( + command=command, + return_code=process.returncode, + stdout_lines=self.stdout.copy(), + stderr_lines=self.stderr.copy(), + terminated=terminated, + truncated=self._truncated, + ) def clear_buffers(self): self.prefix_line = "" diff --git a/bale/interfaces/ssh.py b/bale/interfaces/ssh.py index 44b87ca..6621fa3 100644 --- a/bale/interfaces/ssh.py +++ b/bale/interfaces/ssh.py @@ -1,12 +1,10 @@ -from typing import Dict, Union +from typing import Dict, Optional, Union import os -import asyncio from pathlib import Path -from bale.result import Result -from bale.interfaces.cli import Cli +from bale.interfaces import cli -def get_hosts(path): +def get_hosts(path: str = "data"): path = f"{Path(path).resolve()}/config" hosts = [] try: @@ -20,32 +18,42 @@ def get_hosts(path): return [] -async def get_public_key(path: str) -> str: +async def get_public_key(path: str = "data") -> str: path = Path(path).resolve() if "id_rsa.pub" not in os.listdir(path) or "id_rsa" not in os.listdir(path): - await Cli().shell(f"""ssh-keygen -t rsa -N "" -f {path}/id_rsa""") + await cli.Cli().shell(f"""ssh-keygen -t rsa -N "" -f {path}/id_rsa""") with open(f"{path}/id_rsa.pub", "r", encoding="utf-8") as reader: return reader.read() -class Ssh(Cli): - def __init__(self, path: str, host: str, hostname: str = "", username: str = "", password: Union[str, None] = None, seperator: bytes = b"\n") -> None: +class Ssh(cli.Cli): + def __init__( + self, + host: str, + hostname: str = "", + username: str = "", + password: Optional[str] = None, + options: Optional[Dict[str, str]] = None, + path: str = "data", + seperator: bytes = b"\n", + ) -> None: super().__init__(seperator=seperator) self._raw_path: str = path self._path: Path = Path(path).resolve() - self.host: str = host + self.host: str = host.replace(" ", "") self.password: Union[str, None] = password self.use_key: bool = False if password is None: self.use_key = True + self.options: Optional[Dict[str, str]] = options self.key_path: str = f"{self._path}/id_rsa" - self._base_cmd: str = "" - self._full_cmd: str = "" + self._base_command: str = "" + self._full_command: str = "" self._config_path: str = f"{self._path}/config" self._config: Dict[str, Dict[str, str]] = {} self.read_config() - self.hostname: str = hostname or self._config.get(host, {}).get("HostName", "") - self.username: str = username or self._config.get(host, {}).get("User", "") + self.hostname: str = hostname or self._config.get(host.replace(" ", ""), {}).get("HostName", "") + self.username: str = username or self._config.get(host.replace(" ", ""), {}).get("User", "") self.set_config() def read_config(self) -> None: @@ -57,7 +65,7 @@ class Ssh(Cli): if line == "" or line.startswith("#"): continue if line.startswith("Host "): - current_host = line.split(" ")[1].strip() + current_host = line.split(" ", 1)[1].strip().replace('"', "") self._config[current_host] = {} else: key, value = line.split(" ", 1) @@ -76,30 +84,40 @@ class Ssh(Cli): def set_config(self) -> None: self._config[self.host] = { "IdentityFile": self.key_path, - "PasswordAuthentication": "no", "StrictHostKeychecking": "no", "IdentitiesOnly": "yes", } + self._config[self.host]["PasswordAuthentication"] = "no" if self.password is None else "yes" if self.hostname != "": self._config[self.host]["HostName"] = self.hostname if self.username != "": self._config[self.host]["User"] = self.username + if self.options is not None: + self._config[self.host].update(self.options) self.write_config() def remove(self) -> None: del self._config[self.host] self.write_config() - async def execute(self, command: str, max_output_lines: int = 0) -> Result: - self._base_cmd = f"{'' if self.use_key else f'sshpass -p {self.password} '} ssh -F {self._config_path} {self.host}" - self._full_cmd = f"{self._base_cmd} {command}" - return await super().execute(self._full_cmd, max_output_lines) + async def execute(self, command: str, max_output_lines: int = 0) -> cli.Result: + self._full_command = f"{self.base_command} {command}" + return await super().execute(self._full_command, max_output_lines) - async def send_key(self) -> Result: + async def shell(self, command: str, max_output_lines: int = 0) -> cli.Result: + self._full_command = f"{self.base_command} {command}" + return await super().shell(self._full_command, max_output_lines) + + async def send_key(self) -> cli.Result: await get_public_key(self._raw_path) cmd = f"sshpass -p {self.password} " f"ssh-copy-id -o IdentitiesOnly=yes -i {self.key_path} " f"-o StrictHostKeychecking=no {self.username}@{self.hostname}" - return await super().execute(cmd) + return await super().shell(cmd) @property def config_path(self): return self._config_path + + @property + def base_command(self): + self._base_command = f'{"" if self.use_key else f"sshpass -p {self.password} "} ssh -F {self._config_path} {self.host}' + return self._base_command diff --git a/bale/interfaces/zfs.py b/bale/interfaces/zfs.py index f9b2789..e480b1f 100644 --- a/bale/interfaces/zfs.py +++ b/bale/interfaces/zfs.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import re from datetime import datetime from dataclasses import dataclass @@ -243,8 +243,17 @@ class Zfs: class Ssh(ssh.Ssh, Zfs): - def __init__(self, path: str, host: str, hostname: str = "", username: str = "", password: Union[str, None] = None) -> None: - super().__init__(path, host, hostname, username, password) + def __init__( + self, + host: str, + hostname: str = "", + username: str = "", + password: Optional[str] = None, + options: Optional[Dict[str, str]] = None, + path: str = "data", + seperator: bytes = b"\n", + ) -> None: + super().__init__(host, hostname, username, password, options, path, seperator) Zfs.__init__(self) def notify(self, command: str): diff --git a/bale/tabs/__init__.py b/bale/tabs/__init__.py index ec81498..6915484 100644 --- a/bale/tabs/__init__.py +++ b/bale/tabs/__init__.py @@ -83,7 +83,7 @@ class Tab: @classmethod def register_connection(cls, host: str) -> None: - cls._zfs[host] = Ssh(path="data", host=host) + cls._zfs[host] = Ssh(host) async def _display_result(self, result: Result) -> None: with ui.dialog() as dialog, el.Card(): diff --git a/bale/tabs/automation.py b/bale/tabs/automation.py index 3086c33..06e1bb7 100644 --- a/bale/tabs/automation.py +++ b/bale/tabs/automation.py @@ -48,7 +48,7 @@ def populate_job_handler(app: str, job_id: str, host: str): tab = Tab(host=None, spinner=None) if job_id not in job_handlers: if app == "remote": - job_handlers[job_id] = ssh.Ssh("data", host=host) + job_handlers[job_id] = ssh.Ssh(host) else: job_handlers[job_id] = cli.Cli() return job_handlers[job_id]