When consuming configs inside a dir, ignore nested dirs and files without valid YAML extensions

This commit is contained in:
root 2023-03-27 18:44:27 -07:00
parent 30bb98dff0
commit 9a2efa9f0a

View File

@ -24,12 +24,19 @@ import yaml
class BackupRotator: class BackupRotator:
__DEFAULT_VALID_EXTENSIONS = [
"yaml",
"yml"
]
def __init__(self): def __init__(self):
self.__dry_run = False self.__dry_run = False
self.__configs = [] self.__configs = []
self.__config_paths = [] self.__config_paths = []
self.__calculated_actions = [] self.__calculated_actions = []
self.__valid_extensions = self.__DEFAULT_VALID_EXTENSIONS
def run(self, configs, dry_run: bool = False): def run(self, configs, dry_run: bool = False):
@ -77,8 +84,9 @@ class BackupRotator:
# Use each config path # Use each config path
for path in paths: for path in paths:
# If this is a single path # If this is a single file
if os.path.isfile(path): if os.path.isfile(path):
self._consume_config(path) self._consume_config(path)
# If this is a directory # If this is a directory
@ -86,7 +94,11 @@ class BackupRotator:
# Iterate over each file inside # Iterate over each file inside
for file_name in os.listdir(path): for file_name in os.listdir(path):
self._consume_config(os.path.join(path, file_name))
one_file = os.path.join(path, file_name)
if os.path.isfile(one_file) and self._check_file_extension(file_path=one_file, extensions=None):
self._consume_config(one_file)
def _consume_config(self, path: str): def _consume_config(self, path: str):
@ -386,3 +398,20 @@ class BackupRotator:
self.log("No value found for \"minimum-items\"; Will not enforce minimum item constraint.") self.log("No value found for \"minimum-items\"; Will not enforce minimum item constraint.")
return minimum_items return minimum_items
def _check_file_extension(self, file_path, extensions: list=None):
if extensions is None:
extensions = self.__valid_extensions
file_name, file_extension = os.path.splitext(file_path)
if len(file_extension) > 0 and file_extension[0] == ".":
file_extension = file_extension[1:]
file_extension = file_extension.lower()
for valid_extension in extensions:
#print(file_name, "---", file_extension, "---", valid_extension)
if file_extension == valid_extension:
return True
return False