diff --git a/domain/SSHConfigChanger.py b/domain/SSHConfigChanger.py index b06cd9b..e19312a 100755 --- a/domain/SSHConfigChanger.py +++ b/domain/SSHConfigChanger.py @@ -1,10 +1,12 @@ from domain.config.Config import Config +from domain.config.Config import Target from domain.Logger import Logger import os +from pathlib import Path import re import subprocess import sys @@ -12,10 +14,6 @@ import sys class SSHConfigChanger: - __DEFAULT_NORMAL_SSH_CONFIG_FILE_NAME = "config" - __DEFAULT_TARGET_NAME = "default" - - # def __init__( self, action_interface, action_command, @@ -62,16 +60,16 @@ class SSHConfigChanger: self.quit(f"We don't need to run for action command: {self.__action_command}") # Determine which ssh config file we should use - ssh_config_name = self.determine_ssh_config_name() - if not ssh_config_name: + ssh_config_target = self.determine_ssh_config_target() + if ssh_config_target is None: self.die("Unable to determine appropriate ssh config name; Quitting") self.__logger.log( - f"Determined ssh config name: {ssh_config_name}" + f"Determined ssh config name: {ssh_config_target.name}" ) # Make paths - ssh_config_path_link = self.__config.ssh_dir / self.__DEFAULT_NORMAL_SSH_CONFIG_FILE_NAME - ssh_config_path_target = self.__config.ssh_dir / ssh_config_name + ssh_config_path_link = self.__config.ssh_dir / self.__config.default_normal_ssh_config_file_name + ssh_config_path_target = self.__config.ssh_dir / ssh_config_target.ssh_config_file_name self.__logger.log( f"Selecting source config file \"{ssh_config_path_target}\"" f" for link \"{ssh_config_path_link}\"" @@ -104,44 +102,37 @@ class SSHConfigChanger: self.__logger.log("Finished") # - def require_symlink_or_none(self, file_path): + def require_symlink_or_none(self, file_path: Path): - # - if ( not os.path.isfile(file_path) or os.path.islink(file_path) ): - return True - - # - self.die("For safety, we cannot continue if the target link exists and is a file (" + file_path + ")") + if file_path.is_file() and file_path.exists() and not file_path.is_symlink(): + self.die( + f"For safety, refusing to continue because the target ssh config file exists and is not a symlink:" + f" {file_path}" + ) - def determine_ssh_config_name(self): + def determine_ssh_config_target(self) -> Target: # self.__logger.log("Attempting to determine SSH config name") + # Start off by assuming the default target + # noinspection PyTypeChecker + selected_target = None + selected_target: Target + # Check each section - found_ssh_config_name = None for target_name in self.__config.targets: - target = self.__config.targets[target_name] self.__logger.log(f"Examining target: {target_name}") - # Don't examine default if anything is picked already - if target_name == self.__DEFAULT_TARGET_NAME and found_ssh_config_name: - self.__logger.log( - f"Skipping default section ({self.__DEFAULT_TARGET_NAME}) because we've already found at least one match: {found_ssh_config_name}" - ) + if selected_target is not None and target_name == selected_target.name: + self.__logger.log(f"Ignoring target because it is already selected: {target_name}") continue - # Check the interface - if ( - # Matches, if current interface found in adapters - self.__action_interface in target.adapters - - # Can also match if we're in the default section - or target_name == self.__DEFAULT_TARGET_NAME - ): - pass - else: + target = self.__config.targets[target_name] + + # Matches, if current interface found in adapters + if self.__action_interface not in target.adapters: self.__logger.log( f"Target \"{target_name}\" didn't match any interfaces; Skipping" ) @@ -151,36 +142,41 @@ class SSHConfigChanger: interface_ssid = self.get_interface_ssid(self.__action_interface) if not interface_ssid: self.__logger.log( - f"Interface \"{interface_ssid}\" isn't connected to anything ... " + f"Interface \"{interface_ssid}\" isn't connected to anything; Done looking" ) + break self.__logger.log( f"Interface \"{self.__action_interface}\" is currently connected to: \"{interface_ssid}\"" ) - # Must also match at least one SSID, - # OR we're in the default section - if interface_ssid not in target.ssids and target_name != self.__DEFAULT_TARGET_NAME: + # Must also match at least one SSID + if interface_ssid in target.ssids: + self.__logger.log( - f"Did not find SSID \"{interface_ssid}\" in target ssids: " + str(target.ssids) + f"Found SSID \"{interface_ssid}\" in target {target_name}" ) - continue - - # Found a match! - found_ssh_config_name = target.ssh_config_name - self.__logger.log( - f"Found matching ssh config name: {found_ssh_config_name}" - ) + + # Only override selected target if this one has less SSIDs, + # or there is no currently selected target + if selected_target is None: + self.__logger.log( + f"Found first suitable target: {target_name}" + ) + selected_target = target + if len(target.ssids) < len(selected_target.ssids): + self.__logger.log( + f"Target \"{target_name}\"" + f" seems to be a better match than \"{selected_target.name}\"" + f" because it has fewer specified SSIDs" + f" ({len(target.ssids)} vs. {len(selected_target.ssids)})" + ) + selected_target = target - # Didn't find anything? Go default ... - if not found_ssh_config_name: - if self.__DEFAULT_TARGET_NAME in self.__config.targets.keys(): - target = self.__config.targets[self.__DEFAULT_TARGET_NAME] - found_ssh_config_name = target.ssh_config_name - self.__logger.log( - f"No matches found; Defaulting to: {found_ssh_config_name}" - ) + if selected_target is None: + selected_target = self.__config.targets[self.__config.default_target_name] + self.__logger.log(f"No suitable target found; Defaulting to: {selected_target.name}") - return found_ssh_config_name + return selected_target @staticmethod def get_interface_ssid(interface_name): diff --git a/domain/config/Config.py b/domain/config/Config.py index 69562ba..1c49d4a 100644 --- a/domain/config/Config.py +++ b/domain/config/Config.py @@ -21,14 +21,14 @@ class Target: self.__ssids = [] # noinspection PyTypeChecker - self.__ssh_config_name: str = None + self.__ssh_config_file_name: str = None def __str__(self): s = "" s += f"Target: {self.__name}" - s += f"\n> SSH config file name: {self.__ssh_config_name}" + s += f"\n> SSH config file name: {self.__ssh_config_file_name}" s += f"\n> Adapters: " if len(self.__adapters_names) > 0: @@ -64,6 +64,15 @@ class Target: self.__data = data + assert "config-file-name" in self.__data.keys(), ( + f"Name of ssh config file must be present at config-file-name" + ) + config_file_name = self.__data["config-file-name"] + assert isinstance(config_file_name, str), ( + f"config-file-name must be a string, but got: {type(config_file_name).__name__}" + ) + self.__ssh_config_file_name = config_file_name + if "adapters" in self.__data.keys(): adapters = self.__data["adapters"] if isinstance(adapters, list): @@ -107,13 +116,14 @@ class Target: assert len(self.__adapters_names) > 0, ( f"At least one adapter must be configured at target-name::adapters" ) - assert len(self.__ssids) > 0, ( - f"At least one ssid must be configured at target-name::ssids" - ) @property - def ssh_config_name(self) -> str: - return self.__ssh_config_name + def name(self) -> str: + return self.__name + + @property + def ssh_config_file_name(self) -> str: + return self.__ssh_config_file_name @property def adapters(self) -> list[str]: @@ -126,6 +136,9 @@ class Target: class Config: + __DEFAULT_NORMAL_SSH_CONFIG_FILE_NAME = "config" + __DEFAULT_SSH_DIRECTORY_NAME = ".ssh" + def __init__(self, logger: Logger, file_path: str): self.__logger = logger @@ -142,7 +155,9 @@ class Config: self.__data = {} self.__dry_run = False - self.__ssh_dir = Path(os.path.expanduser("~")) + self.__ssh_dir = Path(os.path.expanduser("~")) / self.__DEFAULT_SSH_DIRECTORY_NAME + # noinspection PyTypeChecker + self.__default_target_name: str = None self.__targets = {} self._load_config() @@ -155,6 +170,7 @@ class Config: s = "" s += "*** Config ***" + s += "\n Dry run: " + "True" if self.__dry_run else "False" for target in self.__targets.values(): s += "\n" + str(target) @@ -184,7 +200,7 @@ class Config: d = options["dry-run"] assert isinstance(d, bool), "options::dry-run must be a bool" if d: - self.__logger.log(f"Found configured dry run") + self.__logger.complain(f"Dry run enabled in config") self.__dry_run = d if "ssh-dir" in options.keys(): @@ -195,14 +211,16 @@ class Config: else: self.__logger.log(f"options::ssh-dir not found") - self.__targets = {} - - # Setup a default target - t = Target( - logger=self.__logger, - name="default", + assert "default-target" in options.keys(), ( + f"Must specify the name of the default target at options::default-target" ) - self.__targets["default"] = t + default_target_name = options["default-target"] + assert isinstance(default_target_name, str), ( + f"Default target name must be a string but got: {type(default_target_name).__name__}" + ) + self.__default_target_name = default_target_name + + self.__targets = {} assert "targets" in self.__data.keys(), "Config should specify targets" targets = self.__data["targets"] @@ -225,6 +243,15 @@ class Config: raise e self.__targets[target_name] = t + + if self.__default_target_name not in self.__targets.keys(): + raise AssertionError( + f"Default target specified as {self.__default_target_name} but was not found in dict of targets" + ) + + @property + def default_normal_ssh_config_file_name(self) -> str: + return self.__DEFAULT_NORMAL_SSH_CONFIG_FILE_NAME @property def file_path(self) -> Path: @@ -238,6 +265,10 @@ class Config: def dry_run(self, b: bool): self.__dry_run = b + @property + def default_target_name(self) -> str: + return self.__default_target_name + @property def ssh_dir(self) -> Path | None: return self.__ssh_dir