#!/usr/bin/env python3
"""
Find public dn42 peers from endpoints listed in this repository.
"""

import argparse
import dataclasses
import logging
from pathlib import Path
import re
import shutil
import subprocess
import sys
from collections import defaultdict

import yaml

# Module-wide logger used for every diagnostic message in this script.
log = logging.getLogger('peerfinder')

# coloredlogs is an optional dependency: when it can be imported it is
# installed on our logger at INFO level; otherwise the name is bound to None
# and the standard library's basicConfig is used at the same level.
try:
    import coloredlogs
    coloredlogs.install(level=logging.INFO, logger=log)
except ImportError:
    coloredlogs = None
    logging.basicConfig(level=logging.INFO)


def _check_deps():
    """Verify that the external ``fping`` program can be located.

    Raises:
        RuntimeError: if ``shutil.which`` finds no ``fping`` executable.
    """
    fping_path = shutil.which('fping')
    if fping_path:
        return
    raise RuntimeError("fping is not installed")

@dataclasses.dataclass
class Network:
    asn: int
    as_name: str | None = None
    description: str | None = None
    url: str | None = None

@dataclasses.dataclass
class Server:
    """One endpoint (node) together with the network it belongs to.

    Only ``network`` is required; the other fields default to ``None``.
    """
    network: Network  # owning network record
    server_name: str | None = dataclasses.field(default=None)
    country: str | None = dataclasses.field(default=None)


# Defaults shared by the command-line options and the Peerfinder constructor.
_DEFAULT_PING_COUNT = 3  # pings sent to each endpoint (passed to `fping -c`)
_DEFAULT_PING_BATCH_SIZE = 400  # endpoints handed to a single fping invocation
_DEFAULT_RESULT_SIZE = 10  # networks (ASNs) shown in the output
_DEFAULT_NODES_PER_PEER = 4  # nodes listed under each network
# Matches one per-target summary line of the form
#   <addr> : xmt/rcv/%loss = 3/3/0%, min/avg/max = 1.23/1.45/1.67
# The ", min/avg/max = ..." part is optional in the pattern; when it is absent
# the minrtt/avgrtt/maxrtt groups are None. MULTILINE lets ^ and $ anchor at
# every line of a multi-line string.
_FPING_RE = re.compile(
    r'^(?P<addr>\S+)\s+: xmt/rcv/%loss = (?P<xmt>\d+)/(?P<rcv>\d+)/(?P<pctloss>\d+)%(?:, min/avg/max = (?P<minrtt>[0-9.]+)/(?P<avgrtt>[0-9.]+)/(?P<maxrtt>[0-9.]+))?$',
    flags=re.MULTILINE,
)

class Peerfinder:
    """Collect peer endpoints from YAML files, ping them and print a ranking.

    Typical use is ``get_endpoints()`` followed by ``run()``.

    Attributes:
        data_dir: Directory searched for ``*.y*ml`` files named after an ASN.
        n_results: Maximum number of networks printed; values <= 0 print all.
        ping_count: Value passed to ``fping -c``.
        nodes_per_peer: Maximum nodes printed per network; values <= 0 print
            all of them.
        skip_asns: ASNs whose files are ignored by ``get_endpoints``.
        batch_size: Endpoints per fping invocation; values <= 0 ping every
            endpoint in a single invocation.
        endpoints: Mapping of endpoint address to its ``Server`` record.
    """

    def __init__(self, data_dir: Path, n_results: int = _DEFAULT_RESULT_SIZE,
                 ping_count: int = _DEFAULT_PING_COUNT,
                 nodes_per_peer: int = _DEFAULT_NODES_PER_PEER,
                 skip_asns: list[int] | None = None,
                 batch_size: int = _DEFAULT_PING_BATCH_SIZE):
        self.data_dir = data_dir
        self.n_results = n_results
        self.ping_count = ping_count
        self.nodes_per_peer = nodes_per_peer
        self.skip_asns = skip_asns or []
        self.batch_size = batch_size

        self.endpoints: dict[str, Server] = {}

    def get_endpoints(self):
        """Rebuild ``self.endpoints`` from the YAML files in ``data_dir``.

        Each file's stem must be an integer ASN. From the top-level mapping
        the keys ``name``, ``description``, ``url`` and ``servers`` are read;
        each entry of ``servers`` contributes one endpoint keyed by its
        ``address`` (with optional ``country``). Files that cannot be
        attributed to an ASN, cannot be read or parsed, or do not hold a
        mapping are skipped with a warning instead of aborting the scan.
        """
        self.endpoints.clear()
        files_processed = 0
        for filename in self.data_dir.glob('*.y*ml'):
            try:
                asn = int(filename.stem)
            except ValueError:
                log.warning('Could not get ASN from filename %r', filename)
                # Without an ASN the file cannot be attributed to a network;
                # falling through would reuse the previous file's ASN (or fail
                # on an unbound name for the very first file).
                continue
            if asn in self.skip_asns:
                log.info('Ignoring skipped ASN %d', asn)
                continue

            try:
                with open(filename, encoding='utf-8') as f:
                    yaml_data = yaml.safe_load(f)
            except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
                log.warning('FAILED to read %s: %s', filename, e)
                continue
            # An empty document loads as None and a scalar/list document has
            # no .get(); neither can describe a network.
            if not isinstance(yaml_data, dict):
                log.warning('FAILED to read %s: %s', filename,
                            'top-level YAML value is not a mapping')
                continue

            net_info = Network(
                asn=asn,
                as_name=yaml_data.get('name'),
                description=yaml_data.get('description'),
                url=yaml_data.get('url'),
            )
            # "servers:" with no value loads as None rather than being absent.
            servers = yaml_data.get('servers') or {}
            if not isinstance(servers, dict):
                log.warning("Ignoring 'servers' in %s: expected a mapping",
                            filename)
                servers = {}
            for server_name, server_info in servers.items():
                self._add_endpoint(net_info, server_name, server_info)
            files_processed += 1
        log.info('Read %d peer endpoints from %d networks', len(self.endpoints),
                 files_processed)

    def _add_endpoint(self, net_info: Network, server_name, server_info) -> None:
        """Register one ``servers`` entry; entries without a usable address are skipped."""
        if not isinstance(server_info, dict):
            return
        addr = server_info.get('address')
        if not addr:
            return
        if not isinstance(addr, str):
            # Addresses are later joined into fping's target list as text.
            log.warning('Ignoring non-string address %r for AS%d', addr,
                        net_info.asn)
            return
        existing_entry = self.endpoints.get(addr)
        if existing_entry is not None:
            log.warning('Duplicate address %r from %s and %s', addr,
                        existing_entry.network.asn, net_info.asn)
        # The most recently read entry wins for a duplicated address.
        self.endpoints[addr] = Server(
            network=net_info,
            server_name=server_name,
            country=server_info.get('country'),
        )

    @staticmethod
    def _fping_sort(match: re.Match) -> float:
        """Sort key for an fping summary line: average RTT, or +inf when none was reported."""
        avg_rtt = match.group('avgrtt')
        if avg_rtt is None:
            return float('inf')
        return float(avg_rtt)

    def _ping_endpoints(self) -> list[re.Match]:
        """Run fping over all endpoints in batches and return the parsed summary lines."""
        targets = list(self.endpoints)
        batch_size = self.batch_size or 0
        if batch_size <= 0:
            # 0 disables batching; a negative size would otherwise yield an
            # empty range and silently ping nothing.
            batch_size = len(targets)
        total_batches = (len(targets) + batch_size - 1) // batch_size

        outputs = []
        for batch_no, start in enumerate(range(0, len(targets), batch_size), start=1):
            batch = targets[start:start + batch_size]
            log.info('Processing batch %d of %d (%d endpoints)...',
                     batch_no, total_batches, len(batch))
            # check=False: the exit status is deliberately ignored and only
            # the summary captured from stderr is parsed.
            result = subprocess.run(
                ['fping', '-c', str(self.ping_count), '-q', '-f', '-'],
                input='\n'.join(batch),
                capture_output=True,
                text=True,
                check=False,
            )
            outputs.append(result.stderr)
        return list(_FPING_RE.finditer('\n'.join(outputs)))

    def _group_by_asn(self, matches: list[re.Match]) -> dict[int, list[re.Match]]:
        """Group summary lines by ASN, each group sorted by ascending latency."""
        asn_groups = defaultdict(list)
        for match in matches:
            addr = match.group('addr')
            server_info = self.endpoints.get(addr)
            if server_info is None:
                log.warning('Ignoring fping result for unknown target %r', addr)
                continue
            asn_groups[server_info.network.asn].append(match)
        for group in asn_groups.values():
            group.sort(key=self._fping_sort)
        return asn_groups

    def _print_network(self, asn: int, matches_for_asn: list[re.Match]) -> None:
        """Print the header and the (possibly truncated) node list of one network."""
        # Network details are shared by all nodes of the ASN; take the first.
        network_info = self.endpoints[matches_for_asn[0].group('addr')].network

        header = f'\nAS{asn}'
        if network_info.as_name:
            header += f' ({network_info.as_name})'
        print(f'{header}:')
        if network_info.description:
            print(f" Description: {network_info.description}")
        if network_info.url:
            print(f" URL: {network_info.url}")
        print()

        limit = len(matches_for_asn)
        if 0 < self.nodes_per_peer < limit:
            limit = self.nodes_per_peer
        for match in matches_for_asn[:limit]:
            print(f'  {match.group(0)}')
            server_info = self.endpoints[match.group('addr')]
            node_line = f'    {server_info.server_name}'
            if server_info.country:
                node_line += f' ({server_info.country})'
            print(node_line)
        hidden = len(matches_for_asn) - limit
        if hidden > 0:
            print(f' Use "-a" to view the remaining {hidden} nodes')
        print()

    def run(self) -> int:
        """Ping all loaded endpoints and print them grouped by network.

        Networks are ordered by the lowest average RTT among their nodes;
        nodes without an RTT sort last. At most ``n_results`` networks and
        ``nodes_per_peer`` nodes per network are printed when those limits
        are positive.

        Returns:
            1 if no endpoints are loaded, otherwise 0.
        """
        if not self.endpoints:
            log.error("No endpoints found to ping")
            return 1

        asn_groups = self._group_by_asn(self._ping_endpoints())

        # Rank networks by their best (first, after sorting) node.
        sorted_asns = sorted(
            asn_groups,
            key=lambda asn_v: self._fping_sort(asn_groups[asn_v][0]))

        if self.n_results > 0:
            print(file=sys.stderr)
            print(f"Showing the top {self.n_results} results. Use '-a' to view all results.", file=sys.stderr)
            sorted_asns = sorted_asns[:self.n_results]

        for asn in sorted_asns:
            self._print_network(asn, asn_groups[asn])
        return 0

def main():
    """Command-line entry point.

    Parses the options, checks that fping is available, loads the endpoints
    and exits the process with the status returned by ``Peerfinder.run``.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=__doc__,
    )
    cli.add_argument('-d', '--data-dir', type=Path,
                     default=Path(__file__).parent / 'servers',
                     help='data directory for peers')
    cli.add_argument('-n', '--num', type=int,
                     default=_DEFAULT_RESULT_SIZE,
                     help="number of peers to return")
    cli.add_argument('-N', '--nodes-per-peer', type=int,
                     default=_DEFAULT_NODES_PER_PEER,
                     help="maximum number of nodes to return per peer")
    cli.add_argument('-a', '--all', action='store_true',
                     help="show all peers, not just the closest ones")
    cli.add_argument('-c', '--ping-count', type=int,
                     default=_DEFAULT_PING_COUNT,
                     help="number of times to ping each endpoint")
    cli.add_argument('-s', '--skip-asns', nargs='+', type=int,
                     help="list of ASNs to filter out from the output")
    cli.add_argument('-b', '--batch-size', type=int,
                     default=_DEFAULT_PING_BATCH_SIZE,
                     help="ping hosts in batches of this size (0 disables batching)")
    opts = cli.parse_args()

    _check_deps()

    # -a lifts both output limits; they are passed to Peerfinder as 0.
    if opts.all:
        result_limit = node_limit = 0
    else:
        result_limit, node_limit = opts.num, opts.nodes_per_peer

    finder = Peerfinder(
        opts.data_dir,
        n_results=result_limit,
        ping_count=opts.ping_count,
        nodes_per_peer=node_limit,
        skip_asns=opts.skip_asns,
        batch_size=opts.batch_size,
    )
    finder.get_endpoints()
    exit_code = finder.run()
    sys.exit(exit_code)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
