refactor extension_manager to use pathlib.Path instead of os.path.join

This commit is contained in:
Pavol Rusnak 2023-03-12 18:28:40 +01:00 committed by Vlad Stan
parent 281d37f79c
commit a805d2a0e8

View file

@ -201,18 +201,18 @@ class InstallableExtension(BaseModel):
return "not-installed" return "not-installed"
@property @property
def zip_path(self) -> str: def zip_path(self) -> Path:
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions") extensions_data_dir = Path(settings.lnbits_data_folder, "extensions")
os.makedirs(extensions_data_dir, exist_ok=True) os.makedirs(extensions_data_dir, exist_ok=True)
return os.path.join(extensions_data_dir, f"{self.id}.zip") return Path(extensions_data_dir, f"{self.id}.zip")
@property @property
def ext_dir(self) -> str: def ext_dir(self) -> Path:
return os.path.join("lnbits", "extensions", self.id) return Path(settings.lnbits_path, "extensions", self.id)
@property @property
def ext_upgrade_dir(self) -> str: def ext_upgrade_dir(self) -> Path:
return os.path.join("lnbits", "upgrades", f"{self.id}-{self.hash}") return Path("lnbits", "upgrades", f"{self.id}-{self.hash}")
@property @property
def module_name(self) -> str: def module_name(self) -> str:
@ -224,15 +224,19 @@ class InstallableExtension(BaseModel):
@property @property
def has_installed_version(self) -> bool: def has_installed_version(self) -> bool:
if not Path(self.ext_dir).is_dir(): if not self.ext_dir.is_dir():
return False return False
config_file = os.path.join(self.ext_dir, "config.json") config_file = Path(self.ext_dir, "config.json")
return Path(config_file).is_file() if not config_file.is_file():
return False
with open(config_file, "r") as json_file:
config_json = json.load(json_file)
return config_json.get("is_installed") is True
def download_archive(self): def download_archive(self):
logger.info(f"Downloading extension {self.name}.") logger.info(f"Downloading extension {self.name}.")
ext_zip_file = self.zip_path ext_zip_file = self.zip_path
if os.path.isfile(ext_zip_file): if ext_zip_file.isfile():
os.remove(ext_zip_file) os.remove(ext_zip_file)
try: try:
download_url(self.installed_release.archive, ext_zip_file) download_url(self.installed_release.archive, ext_zip_file)
@ -246,7 +250,7 @@ class InstallableExtension(BaseModel):
archive_hash = file_hash(ext_zip_file) archive_hash = file_hash(ext_zip_file)
if self.installed_release.hash and self.installed_release.hash != archive_hash: if self.installed_release.hash and self.installed_release.hash != archive_hash:
# remove downloaded archive # remove downloaded archive
if os.path.isfile(ext_zip_file): if ext_zip_file.isfile():
os.remove(ext_zip_file) os.remove(ext_zip_file)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, status_code=HTTPStatus.NOT_FOUND,
@ -255,20 +259,20 @@ class InstallableExtension(BaseModel):
def extract_archive(self): def extract_archive(self):
logger.info(f"Extracting extension {self.name}.") logger.info(f"Extracting extension {self.name}.")
os.makedirs(os.path.join("lnbits", "upgrades"), exist_ok=True) os.makedirs(Path("lnbits", "upgrades"), exist_ok=True)
shutil.rmtree(self.ext_upgrade_dir, True) shutil.rmtree(self.ext_upgrade_dir, True)
with zipfile.ZipFile(self.zip_path, "r") as zip_ref: with zipfile.ZipFile(self.zip_path, "r") as zip_ref:
zip_ref.extractall(self.ext_upgrade_dir) zip_ref.extractall(self.ext_upgrade_dir)
generated_dir_name = os.listdir(self.ext_upgrade_dir)[0] generated_dir_name = os.listdir(self.ext_upgrade_dir)[0]
os.rename( os.rename(
os.path.join(self.ext_upgrade_dir, generated_dir_name), Path(self.ext_upgrade_dir, generated_dir_name),
os.path.join(self.ext_upgrade_dir, self.id), Path(self.ext_upgrade_dir, self.id),
) )
# Pre-packed extensions can be upgraded # Pre-packed extensions can be upgraded
# Mark the extension as installed so we know it is not the pre-packed version # Mark the extension as installed so we know it is not the pre-packed version
with open( with open(
os.path.join(self.ext_upgrade_dir, self.id, "config.json"), "r+" Path(self.ext_upgrade_dir, self.id, "config.json"), "r+"
) as json_file: ) as json_file:
config_json = json.load(json_file) config_json = json.load(json_file)
@ -286,8 +290,8 @@ class InstallableExtension(BaseModel):
shutil.rmtree(self.ext_dir, True) shutil.rmtree(self.ext_dir, True)
shutil.copytree( shutil.copytree(
os.path.join(self.ext_upgrade_dir, self.id), Path(self.ext_upgrade_dir, self.id),
os.path.join("lnbits", "extensions", self.id), Path(settings.lnbits_path, "extensions", self.id),
) )
logger.success(f"Extension {self.name} installed.") logger.success(f"Extension {self.name} installed.")
@ -306,7 +310,7 @@ class InstallableExtension(BaseModel):
def clean_extension_files(self): def clean_extension_files(self):
# remove downloaded archive # remove downloaded archive
if os.path.isfile(self.zip_path): if self.zip_path.isfile():
os.remove(self.zip_path) os.remove(self.zip_path)
# remove module from extensions # remove module from extensions