refactor: move more logic to InstallableExtension

This commit is contained in:
Vlad Stan 2023-01-11 10:25:09 +02:00
parent cb6349fd76
commit 3ed2b3cdeb
4 changed files with 121 additions and 132 deletions

View file

@ -1,18 +1,8 @@
import hashlib
import importlib
import os
import re
import urllib.request
from http import HTTPStatus
from typing import List
import httpx
from fastapi.exceptions import HTTPException
from loguru import logger
from lnbits.helpers import InstallableExtension, get_valid_extensions
from lnbits.settings import settings
from . import db as core_db
from .crud import update_migration_version
@ -48,97 +38,3 @@ async def run_migration(db, migrations_module, current_version):
else:
async with core_db.connect() as conn:
await update_migration_version(conn, db_name, version)
async def get_installable_extensions() -> List[InstallableExtension]:
extension_list: List[InstallableExtension] = []
async with httpx.AsyncClient() as client:
for url in settings.lnbits_extensions_manifests:
resp = await client.get(url)
if resp.status_code != 200:
raise HTTPException(
status_code=404,
detail=f"Unable to fetch extension list for repository: {url}",
)
for e in resp.json()["extensions"]:
extension_list += [
InstallableExtension(
id=e["id"],
name=e["name"],
archive=e["archive"],
hash=e["hash"],
short_description=e["shortDescription"],
details=e["details"] if "details" in e else "",
icon=e["icon"],
dependencies=e["dependencies"] if "dependencies" in e else [],
)
]
return extension_list
async def get_installable_extension_meta(
ext_id: str, hash: str
) -> InstallableExtension:
installable_extensions: List[
InstallableExtension
] = await get_installable_extensions()
valid_extensions = [
e for e in installable_extensions if e.id == ext_id and e.hash == hash
]
if len(valid_extensions) == 0:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Unknown extension id: {ext_id}",
)
extension = valid_extensions[0]
# check that all dependecies are installed
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
if not set(extension.dependencies).issubset(installed_extensions):
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail=f"Not all dependencies are installed: {extension.dependencies}",
)
return extension
def download_extension_archive(archive: str, ext_zip_file: str, hash: str):
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
try:
download_url(archive, ext_zip_file)
except Exception as ex:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Cannot fetch extension archive file",
)
archive_hash = file_hash(ext_zip_file)
if hash != archive_hash:
# remove downloaded archive
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="File hash missmatch. Will not install.",
)
def download_url(url, save_path):
with urllib.request.urlopen(url) as dl_file:
with open(save_path, "wb") as out_file:
out_file.write(dl_file.read())
def file_hash(filename):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()

View file

@ -1,6 +1,5 @@
import asyncio
import hashlib
import importlib
import inspect
import json
import os
@ -30,26 +29,18 @@ from fastapi import (
)
from fastapi.exceptions import HTTPException
from fastapi.params import Body
from genericpath import isfile
from loguru import logger
from pydantic import BaseModel
from pydantic.fields import Field
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from sse_starlette.sse import EventSourceResponse
from starlette.responses import StreamingResponse
from lnbits import bolt11, lnurl
from lnbits.core.helpers import (
download_extension_archive,
file_hash,
get_installable_extension_meta,
get_installable_extensions,
migrate_extension_database,
)
from lnbits.core.helpers import migrate_extension_database
from lnbits.core.models import Payment, User, Wallet
from lnbits.decorators import (
WalletTypeInfo,
check_admin,
check_user_exists,
get_key_type,
require_admin_key,
require_invoice_key,
@ -737,34 +728,34 @@ async def websocket_update_get(item_id: str, data: str):
async def api_install_extension(
ext_id: str, hash: str, user: User = Depends(check_admin)
):
ext_meta: InstallableExtension = await get_installable_extension_meta(ext_id, hash)
download_extension_archive(ext_meta.archive, ext_meta.zip_path, ext_meta.hash)
ext_info: InstallableExtension = await InstallableExtension.get_extension_info(
ext_id, hash
)
ext_info.download_archive()
try:
ext_dir = os.path.join("lnbits/extensions", ext_id)
shutil.rmtree(ext_dir, True)
with zipfile.ZipFile(ext_meta.zip_path, "r") as zip_ref:
with zipfile.ZipFile(ext_info.zip_path, "r") as zip_ref:
zip_ref.extractall("lnbits/extensions")
ext_upgrade_dir = os.path.join(
"lnbits/upgrades", f"{ext_meta.id}-{ext_meta.hash}"
"lnbits/upgrades", f"{ext_info.id}-{ext_info.hash}"
)
os.makedirs("lnbits/upgrades", exist_ok=True)
shutil.rmtree(ext_upgrade_dir, True)
with zipfile.ZipFile(ext_meta.zip_path, "r") as zip_ref:
with zipfile.ZipFile(ext_info.zip_path, "r") as zip_ref:
zip_ref.extractall(ext_upgrade_dir)
module_name = f"lnbits.extensions.{ext_id}"
module_installed = module_name in sys.modules
# todo: is admin only
ext = Extension(
code=ext_meta.id,
code=ext_info.id,
is_valid=True,
is_admin_only=False,
name=ext_meta.name,
hash=ext_meta.hash if module_installed else "",
name=ext_info.name,
hash=ext_info.hash if module_installed else "",
)
current_versions = await get_dbversions()
@ -791,8 +782,8 @@ async def api_install_extension(
except Exception as ex:
logger.warning(ex)
# remove downloaded archive
if os.path.isfile(ext_meta.zip_path):
os.remove(ext_meta.zip_path)
if os.path.isfile(ext_info.zip_path):
os.remove(ext_info.zip_path)
# remove module from extensions
shutil.rmtree(ext_dir, True)
@ -804,7 +795,9 @@ async def api_install_extension(
@core_app.delete("/api/v1/extension/{ext_id}")
async def api_uninstall_extension(ext_id: str, user: User = Depends(check_admin)):
try:
extension_list: List[InstallableExtension] = await get_installable_extensions()
extension_list: List[
InstallableExtension
] = await InstallableExtension.get_installable_extensions()
except Exception as ex:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,

View file

@ -11,7 +11,6 @@ from pydantic.types import UUID4
from starlette.responses import HTMLResponse, JSONResponse
from lnbits.core import db
from lnbits.core.helpers import get_installable_extensions
from lnbits.core.models import User
from lnbits.decorators import check_admin, check_user_exists
from lnbits.helpers import template_renderer, url_for
@ -81,7 +80,9 @@ async def extensions_install(
)
try:
extension_list: List[InstallableExtension] = await get_installable_extensions()
extension_list: List[
InstallableExtension
] = await InstallableExtension.get_installable_extensions()
except Exception as ex:
logger.warning(ex)
raise HTTPException(

View file

@ -1,12 +1,18 @@
import glob
import hashlib
import json
import os
import shutil
import urllib.request
from http import HTTPStatus
from typing import Any, List, NamedTuple, Optional
import httpx
import jinja2
import shortuuid # type: ignore
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from loguru import logger
from starlette.types import ASGIApp, Receive, Scope, Send
from lnbits.jinja2_templating import Jinja2Templates
@ -52,10 +58,87 @@ class InstallableExtension(NamedTuple):
def zip_path(self):
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
os.makedirs(extensions_data_dir, exist_ok=True)
ext_data_dir = os.path.join(extensions_data_dir, self.id)
shutil.rmtree(ext_data_dir, True)
return os.path.join(extensions_data_dir, f"{self.id}.zip")
def download_archive(self):
ext_zip_file = self.zip_path
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
try:
download_url(self.archive, ext_zip_file)
except Exception as ex:
logger.warning(ex)
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Cannot fetch extension archive file",
)
archive_hash = file_hash(ext_zip_file)
if self.hash != archive_hash:
# remove downloaded archive
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="File hash missmatch. Will not install.",
)
@classmethod
async def get_extension_info(cls, ext_id: str, hash: str) -> "InstallableExtension":
installable_extensions: List[
InstallableExtension
] = await InstallableExtension.get_installable_extensions()
valid_extensions = [
e for e in installable_extensions if e.id == ext_id and e.hash == hash
]
if len(valid_extensions) == 0:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Unknown extension id: {ext_id}",
)
extension = valid_extensions[0]
# check that all dependecies are installed
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
if not set(extension.dependencies).issubset(installed_extensions):
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail=f"Not all dependencies are installed: {extension.dependencies}",
)
return extension
@classmethod
async def get_installable_extensions(cls) -> List["InstallableExtension"]:
extension_list: List[InstallableExtension] = []
async with httpx.AsyncClient() as client:
for url in settings.lnbits_extensions_manifests:
resp = await client.get(url)
if resp.status_code != 200:
raise HTTPException(
status_code=404,
detail=f"Unable to fetch extension list for repository: {url}",
)
for e in resp.json()["extensions"]:
extension_list += [
InstallableExtension(
id=e["id"],
name=e["name"],
archive=e["archive"],
hash=e["hash"],
short_description=e["shortDescription"],
details=e["details"] if "details" in e else "",
icon=e["icon"],
dependencies=e["dependencies"]
if "dependencies" in e
else [],
)
]
return extension_list
class ExtensionManager:
def __init__(self, include_disabled_exts=False):
@ -289,3 +372,19 @@ def get_current_extension_name() -> str:
except:
ext_name = extension_director_name
return ext_name
def download_url(url, save_path):
with urllib.request.urlopen(url) as dl_file:
with open(save_path, "wb") as out_file:
out_file.write(dl_file.read())
def file_hash(filename):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()