#!/usr/bin/python3
#
# This file is part of OpenMediaVault.
#
# @license   http://www.gnu.org/licenses/gpl.html GPL Version 3
# @author    Volker Theile <volker.theile@openmediavault.org>
# @copyright Copyright (c) 2009-2025 Volker Theile
#
# OpenMediaVault is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
#
# OpenMediaVault is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with OpenMediaVault. If not, see <http://www.gnu.org/licenses/>.

import datetime
import os
import re
import subprocess
import sys
from collections import namedtuple
from typing import Dict, List, Literal

import click
import openmediavault.config
import openmediavault.log
import openmediavault.procutils
import openmediavault.stringutils

import openmediavault

# Record describing a single Btrfs snapshot as collected by
# `list_snapshots`:
# - id: The numeric subvolume ID.
# - otime: The creation time string as printed by `btrfs`.
# - name: The base name of the snapshot path.
# - path: The snapshot path as printed by `btrfs`.
# - mount_dir: The directory the listing was requested for.
# - kind: One of `SNAPSHOT_KINDS`.
Snapshot = namedtuple(
    'Snapshot',
    ['id', 'otime', 'name', 'path', 'mount_dir', 'kind'],
)

# Aliases that make the nested dictionary annotations in this file
# easier to read.
SnapshotName = str
SnapshotKind = Literal[
    'custom', 'hourly', 'daily', 'weekly', 'monthly', 'yearly'
]
UnixTimestamp = int

# All supported snapshot kinds. The order of this list is the order in
# which the commands below iterate over the kinds.
SNAPSHOT_KINDS: List[SnapshotKind] = [
    'custom',
    'hourly',
    'daily',
    'weekly',
    'monthly',
    'yearly',
]


def list_snapshots(mount_dir: str, verbose: bool = False) -> Dict[SnapshotName, Dict[SnapshotKind, Dict[UnixTimestamp, Snapshot]]]:
    """
    Collect the Btrfs snapshots below a directory that follow the
    openmediavault naming scheme and belong to a configured shared folder.

    :param mount_dir: The directory that is handed over to
        `btrfs subvolume list`.
    :param verbose: Forwarded to the logger as its `verbose` keyword.
    :return: A nested dictionary of the form
        `shared folder name -> kind -> Unix timestamp -> Snapshot`.
        The timestamp is derived from the creation time (`otime`) of
        the snapshot. A shared folder only shows up if at least one
        snapshot was accepted for it; in that case it has an entry
        (possibly empty) for every kind in `SNAPSHOT_KINDS`.
    """
    openmediavault.log.info(f'Searching for snapshots in "{mount_dir}" ...',
                            verbose=verbose)
    # Snapshots that can not be assigned to one of the configured shared
    # folders are ignored, so the shared folder names are needed first.
    known_sf_names: List[str] = [
        sf_obj.get('name')
        for sf_obj in openmediavault.config.Database().get(
            'conf.system.sharedfolder')
    ]
    # Ask Btrfs for the snapshots.
    raw_listing = openmediavault.procutils.check_output(
        ['btrfs', 'subvolume', 'list', '-s', '-q', '-u', mount_dir]
    )
    # One line of the `btrfs` output.
    line_pattern = re.compile(
        r'^ID (\d+) gen (\d+) cgen (\d+) top level (\d+) otime (.+) parent_uuid (.+) uuid (.+) path (.+)$')
    # The openmediavault naming scheme of the snapshot file name.
    # Examples:
    # - homes@daily_20230228T095758
    # - homes_20230311T120024
    name_pattern = re.compile(r'^(.+)_(\d{8}T\d{6})$')

    result: Dict[SnapshotName, Dict[SnapshotKind,
                                    Dict[UnixTimestamp, Snapshot]]] = {}
    for line in raw_listing.decode().split('\n'):
        line_match = line_pattern.match(line)
        if line_match is None:
            continue
        subvol_id, otime, path = line_match.group(1, 5, 8)
        file_name: str = os.path.basename(path)
        name_match = name_pattern.match(file_name)
        if name_match is None:
            continue
        # Split the part in front of the timestamp into the name of the
        # shared folder and the kind of snapshot, e.g. `weekly`. A name
        # without `@` denotes a `custom` snapshot.
        snapshot_name: SnapshotName = name_match.group(1)
        sf_name, separator, kind = snapshot_name.rpartition('@')
        if not separator:
            sf_name, kind = snapshot_name, 'custom'
        # Ignore snapshots that do not match a shared folder or whose
        # kind is not supported.
        if sf_name not in known_sf_names or kind not in SNAPSHOT_KINDS:
            continue
        per_kind = result.setdefault(
            sf_name, {known_kind: {} for known_kind in SNAPSHOT_KINDS})
        created = datetime.datetime.strptime(otime, '%Y-%m-%d %H:%M:%S')
        timestamp: UnixTimestamp = int(created.timestamp())
        per_kind[kind][timestamp] = Snapshot(
            id=int(subvol_id),
            otime=otime,
            name=file_name,
            path=path,
            mount_dir=mount_dir,
            kind=kind)
    return result


def delete_snapshot(snapshot: Snapshot, simulate: bool = False) -> None:
    """
    Delete the given snapshot unless it contains subordinate subvolumes.

    :param snapshot: The snapshot to delete. Its `mount_dir` and `path`
        fields are joined to build the path that is handed to `btrfs`.
    :param simulate: If `True`, the `btrfs subvolume delete` command is
        not executed. The log message about the deletion is written
        nevertheless.
    """
    target: str = os.path.join(snapshot.mount_dir, snapshot.path)
    # `btrfs` is not able to delete a subvolume that still has child
    # subvolumes, so find out first whether there are any.
    children = openmediavault.procutils.check_output(
        ["btrfs", "subvolume", "list", "-o", target]
    )
    has_children: bool = bool(children)
    if has_children:
        # Leave such a snapshot alone instead of letting the delete
        # command fail.
        openmediavault.log.info(
            f'The snapshot "{target}" has been skipped as it contains subordinate subvolumes.')
        return
    if not simulate:
        delete_cmd = ['btrfs', 'subvolume', 'delete', target]
        openmediavault.procutils.check_call(
            delete_cmd, stdout=subprocess.DEVNULL)
    openmediavault.log.info(f'The snapshot "{target}" has been deleted.')


class Context:
    """State object that is shared between the command group and its
    subcommands."""
    # Assigned by `cli` from the `-v/--verbose` flag and read by the
    # subcommands to control their log calls.
    verbose: bool


# Decorator that passes the `Context` object as first argument to the
# decorated command callback. According to the click documentation,
# `ensure=True` creates the object on demand if there is none yet.
pass_context = click.make_pass_decorator(Context, ensure=True)


# The root command group. It only stores the global `--verbose` flag in
# the shared context object, the real work is done by the subcommands.
@click.group()
@click.option('-v', '--verbose', is_flag=True, default=False,
              help='Shows verbose output.')
@pass_context
def cli(ctx: Context, verbose: bool):
    # NOTE: Plain comments are used here on purpose. No `help` is passed
    # to the group, so click would show a docstring as its help text.
    ctx.verbose = verbose
    return 0


@cli.command(name='list', help='List shared folder snapshots.')
@pass_context
def list_cmd(ctx: Context):
    # Print a table of all shared folder snapshots that are found on the
    # configured Btrfs file systems. The rows are grouped by file system,
    # then sorted by shared folder name, then by kind (in the order of
    # `SNAPSHOT_KINDS`) and finally by creation time, oldest first.
    # The process is terminated with exit code 0 afterwards.
    row_format: str = '{:<5} {:<20} {:<35} {:<8} {:<20}\n'
    sys.stdout.write(row_format.format(
        'ID', 'SHARED FOLDER', 'NAME', 'KIND', 'CREATED'))
    sys.stdout.write(row_format.format(
        '--', '-------------', '----', '----', '-------'))
    # Query the mount point configuration of the Btrfs file systems.
    db: openmediavault.config.Database = openmediavault.config.Database()
    btrfs_only = openmediavault.config.DatabaseFilter(
        {'operator': 'stringEquals', 'arg0': 'type', 'arg1': 'btrfs'}
    )
    mntent_objs: List[openmediavault.config.object.Object] = db.get_by_filter(
        'conf.system.filesystem.mountpoint', btrfs_only)
    for mntent_obj in mntent_objs:
        found: Dict[SnapshotName, Dict[SnapshotKind, Dict[UnixTimestamp, Snapshot]]] = list_snapshots(
            mntent_obj.get("dir"), verbose=False)
        for sf_name in sorted(found):
            for kind in SNAPSHOT_KINDS:
                # The Unix timestamps are unique per shared folder and
                # kind, so sorting the items orders the snapshots by
                # their creation time in ascending order.
                for _, snapshot in sorted(found[sf_name][kind].items()):
                    sys.stdout.write(row_format.format(
                        snapshot.id, sf_name, snapshot.name,
                        snapshot.kind, snapshot.otime))
    sys.exit(0)


@cli.command(name='cleanup', help='Clean up shared folder snapshots.')
@click.option('-s', '--simulate', is_flag=True, default=False,
              help='Perform a simulation. Do not delete any snapshot.')
@pass_context
def cleanup_cmd(ctx: Context, simulate: bool):
    # Delete outdated shared folder snapshots according to the settings
    # in `conf.system.sharedfolder.snapshot.lifecycle`:
    # - `enable`: Nothing is done (exit code 1) if this is not set.
    # - `retentionperiod`: Number of seconds that is added to the
    #   creation time of a snapshot. Only snapshots for which the
    #   resulting point in time has passed are deletion candidates.
    # - `limit<kind>`: Number of candidates per shared folder and kind
    #   that are kept, preferring the youngest ones.
    # Snapshots that are used as a shared folder are never deleted.
    # With `simulate` the snapshots are only reported, not deleted.
    db: openmediavault.config.Database = openmediavault.config.Database()

    lifecycle: openmediavault.config.object.Object = db.get(
        'conf.system.sharedfolder.snapshot.lifecycle')
    limits: Dict[SnapshotKind, int] = {}
    for snapshot_kind in SNAPSHOT_KINDS:
        limits[snapshot_kind] = lifecycle.get('limit' + snapshot_kind)

    # Nothing to do if the snapshot cleanup is disabled.
    if not lifecycle.get('enable'):
        openmediavault.log.info(
            'Cleaning up snapshots from shared folders has been disabled.',
            verbose=ctx.verbose)
        sys.exit(1)

    openmediavault.log.info(
        'Start cleaning snapshots from shared folders ...',
        verbose=ctx.verbose)

    # Query the mount point configuration of the Btrfs file systems.
    btrfs_only = openmediavault.config.DatabaseFilter(
        {'operator': 'stringEquals', 'arg0': 'type', 'arg1': 'btrfs'}
    )
    mntent_objs: List[openmediavault.config.object.Object] = db.get_by_filter(
        'conf.system.filesystem.mountpoint', btrfs_only)
    for mntent_obj in mntent_objs:
        found: Dict[SnapshotName, Dict[SnapshotKind, Dict[UnixTimestamp, Snapshot]]] = list_snapshots(
            mntent_obj.get("dir"), verbose=ctx.verbose)
        for sf_name in sorted(found):  # Ordered by shared folder name.
            for kind in SNAPSHOT_KINDS:
                by_timestamp: Dict[UnixTimestamp, Snapshot] = found[sf_name][kind]
                # The Unix timestamps in descending order, i.e. the
                # youngest snapshot comes first.
                youngest_first: List[UnixTimestamp] = sorted(
                    by_timestamp, reverse=True)

                openmediavault.log.info(
                    f'{len(youngest_first)} {kind} snapshot(s) were found from the shared folder "{sf_name}".',
                    verbose=ctx.verbose)

                # Collect the snapshots that are allowed to be deleted.
                expired: List[UnixTimestamp] = []
                for timestamp in youngest_first:
                    snapshot: Snapshot = by_timestamp[timestamp]
                    # A snapshot that is referenced by a shared folder
                    # of this file system must not be touched.
                    used_as_sf = openmediavault.config.DatabaseFilter(
                        {
                            "operator": "and",
                            "arg0": {
                                "operator": "stringEquals",
                                "arg0": "mntentref",
                                "arg1": mntent_obj.get("uuid"),
                            },
                            "arg1": {
                                "operator": "stringEquals",
                                "arg0": "reldirpath",
                                "arg1": openmediavault.stringutils.path_prettify(
                                    snapshot.path),
                            },
                        },
                    )
                    if db.exists("conf.system.sharedfolder", used_as_sf):
                        abs_path: str = os.path.join(
                            snapshot.mount_dir, snapshot.path)
                        openmediavault.log.info(
                            f'The snapshot "{abs_path}" has been skipped because it is used as a shared folder.',
                            verbose=ctx.verbose)
                        continue
                    # The snapshot is withheld until its creation time
                    # plus the retention period has passed.
                    withhold_until = datetime.datetime.fromtimestamp(
                        timestamp + lifecycle.get('retentionperiod'))
                    if datetime.datetime.now() > withhold_until:
                        expired.append(timestamp)

                limit: int = limits[kind]
                if len(expired) <= limit:
                    continue

                openmediavault.log.info(
                    f'The maximum number of {limit} snapshot(s) was exceeded from the shared folder "{sf_name}". The youngest snapshot(s) will be kept.',
                    verbose=ctx.verbose)

                # The first `limit` candidates are the youngest ones and
                # are kept. The remaining, older ones are deleted in the
                # order of the list, i.e. the oldest snapshot is
                # processed last.
                for timestamp in expired[limit:]:
                    delete_snapshot(by_timestamp[timestamp], simulate)
    sys.exit(0)


def main():
    """Entry point of the script. Invokes the `cli` command group."""
    cli()


# Run the command line interface when the file is executed as a script.
# Note that the `list` and `cleanup` subcommands terminate the process
# themselves via `sys.exit`.
if __name__ == '__main__':
    sys.exit(main())
