import importlib
import os
import signal
import struct
import time
import subprocess
from collections.abc import Callable, ValuesView
from abc import ABC, abstractmethod
from multiprocessing import Process

from setproctitle import setproctitle

from cereal import car, log
import cereal.messaging as messaging
import openpilot.selfdrive.sentry as sentry
from openpilot.common.basedir import BASEDIR
from openpilot.common.params import Params
from openpilot.common.swaglog import cloudlog

# Prefix of the per-process watchdog files; ManagerProcess.check_watchdog appends the pid.
WATCHDOG_FN = "/dev/shm/wd_"
# First assignment derives the flag from the environment: enabled unless NO_WATCHDOG is set.
ENABLE_WATCHDOG = "NO_WATCHDOG" not in os.environ
# FIXME: watchdog restarts are currently force-disabled, so NO_WATCHDOG has no effect.
ENABLE_WATCHDOG = False
# Directory for per-process log files; assigned by ensure_running().
_log_dir = ""

def nativelauncher(pargs: list[str], cwd: str, name: str, log_path: str) -> None:
  """Run a native command and tee its combined stdout/stderr to stdout and *log_path*.

  Runs inside the multiprocessing.Process created by NativeProcess.start and
  blocks until the command exits. If the launcher is interrupted (e.g. by the
  KeyboardInterrupt raised by the SIGINT that ManagerProcess.stop sends), the
  child is sent SIGINT and, if it does not exit in time, SIGKILL, so it does
  not outlive the launcher.

  Args:
    pargs: argv of the command, executed after changing directory to *cwd*.
    cwd: working directory to switch to before launching.
    name: process name; exported to the child as MANAGER_DAEMON.
    log_path: file the child's output is appended to.
  """
  os.environ['MANAGER_DAEMON'] = name
  # Line-buffered so output already read is on disk even if the launcher is SIGKILLed.
  with open(log_path, 'a', buffering=1) as log_file:
    os.chdir(cwd)
    with subprocess.Popen(pargs, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                          bufsize=1, universal_newlines=True) as child:
      log_file.write(f"Started {name}\n")
      try:
        for line in child.stdout:
          print(line, end='')
          log_file.write(line)
        child.wait()
      finally:
        # Manager signals reach only this launcher's pid; make sure the child does
        # not outlive it. 3 s stays under the 5 s grace period used by stop().
        if child.poll() is None:
          child.send_signal(signal.SIGINT)
          try:
            child.wait(timeout=3)
          except subprocess.TimeoutExpired:
            child.kill()
            child.wait()

def launcher(proc: str, name: str, log_path: str, restart_delay: float = 1.0) -> None:
  """Keep ``python -m <proc>`` running, teeing its output to stdout and *log_path*.

  Runs inside the multiprocessing.Process created by PythonProcess.start and never
  returns normally: whenever the child exits, or setup raises an Exception, the
  loop waits *restart_delay* seconds and launches it again. Exceptions are printed
  and reported to sentry. KeyboardInterrupt (e.g. from the SIGINT sent by
  ManagerProcess.stop) is not caught: the child is signalled and the launcher exits.

  Args:
    proc: dotted name of the module to run.
    name: process name used for logging and the sentry "daemon" tag.
    log_path: file the child's output is appended to.
    restart_delay: pause between restarts so a module that fails immediately
      does not spin the CPU or flood sentry.
  """
  while True:
    try:
      # Importing here surfaces ImportError / module-level errors through the
      # except branch below before a child is spawned.
      importlib.import_module(proc)
      setproctitle(proc)
      messaging.context = messaging.Context()
      cloudlog.bind(daemon=name)
      sentry.set_tag("daemon", name)
      # Line-buffered so output already read is on disk even if the launcher is SIGKILLed.
      with open(log_path, 'a', buffering=1) as log_file:
        # Named `child` so `proc` (the module name) stays intact for the next iteration.
        with subprocess.Popen(['python', '-m', proc], stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                              bufsize=1, universal_newlines=True) as child:
          log_file.write(f"Started {name}\n")
          try:
            for line in child.stdout:
              print(line, end='')
              log_file.write(line)
            child.wait()
          finally:
            # Manager signals reach only this launcher's pid; make sure the child does
            # not outlive it. 3 s stays under the 5 s grace period used by stop().
            if child.poll() is None:
              child.send_signal(signal.SIGINT)
              try:
                child.wait(timeout=3)
              except subprocess.TimeoutExpired:
                child.kill()
                child.wait()
    except Exception as e:
      print("Fatal: " + name)
      print(e)
      sentry.capture_exception()
    time.sleep(restart_delay)

def join_process(process: Process, timeout: float) -> None:
  """Wait up to *timeout* seconds for *process* to exit; never raises on timeout.

  Polls ``process.exitcode`` every millisecond instead of calling
  ``Process.join(timeout)``. NOTE(review): presumably this sidesteps the join()
  hang reported in bpo-28382 — confirm before switching to join().
  """
  deadline = time.monotonic() + timeout
  while time.monotonic() < deadline:
    if process.exitcode is not None:
      break
    time.sleep(0.001)

class ManagerProcess(ABC):
  """Base class for a process supervised by the manager.

  Subclasses implement prepare() and start() (which typically assigns
  ``self.proc``); this class provides stop/restart, signalling, watchdog checks
  and ManagerState reporting on top of ``self.proc``.
  """
  daemon = False
  sigkill = False  # stop() defaults to SIGKILL instead of SIGINT when True
  should_run: Callable[[bool, Params, car.CarParams], bool]
  proc: Process | None = None
  enabled = True
  name = ""
  # Raw value read from the watchdog file; divided by 1e9 and compared against
  # time.monotonic() (presumably monotonic nanoseconds — confirm with the writer).
  last_watchdog_time = 0
  watchdog_max_dt: int | None = None  # seconds without an update before a restart; None disables the check
  watchdog_seen = False  # set once a non-stale watchdog timestamp has been observed
  # Also allow a timeout restart after the process exited, even if no timestamp was
  # ever seen. Defined here because check_watchdog reads it for every subclass.
  always_watchdog = False
  shutting_down = False  # a stop signal was sent and the process has not been cleared yet

  @abstractmethod
  def prepare(self) -> None:
    """Do any work needed before the process can be started."""

  @abstractmethod
  def start(self) -> None:
    """Launch the process if it is not already tracked."""

  def restart(self) -> None:
    """Kill the current process (if any) and start a new one.

    The stop blocks until the old process has exited and ``self.proc`` is
    cleared; start() does nothing while ``self.proc`` is set, so a
    non-blocking or conditional stop would leave a hung process in place.
    """
    self.stop(sig=signal.SIGKILL)
    self.start()

  def check_watchdog(self, started: bool) -> None:
    """Restart the process when its watchdog timestamp is older than watchdog_max_dt.

    Reads ``WATCHDOG_FN + pid``. A restart only happens when ENABLE_WATCHDOG is
    set and either a fresh timestamp was seen before, or always_watchdog is set
    and the process has already exited. *started* is only used in the log line.
    """
    if self.watchdog_max_dt is None or self.proc is None:
      return
    try:
      fn = WATCHDOG_FN + str(self.proc.pid)
      with open(fn, "rb") as f:
        self.last_watchdog_time = struct.unpack('Q', f.read())[0]
    except (OSError, struct.error):
      # Best effort: the file may be missing or truncated; keep the previous timestamp.
      pass
    dt = time.monotonic() - self.last_watchdog_time / 1e9
    if dt > self.watchdog_max_dt:
      may_restart = self.watchdog_seen or (self.always_watchdog and self.proc.exitcode is not None)
      if may_restart and ENABLE_WATCHDOG:
        cloudlog.error(f"Watchdog timeout for {self.name} (exitcode {self.proc.exitcode}) restarting ({started=})")
        self.restart()
    else:
      self.watchdog_seen = True

  def stop(self, retry: bool = True, block: bool = True, sig: signal.Signals | None = None) -> int | None:
    """Signal the process to exit and optionally wait for it.

    Args:
      retry: if the process is still alive after 5 s, send SIGKILL and wait.
      block: wait for the process; when False, return None right after signalling.
      sig: signal to send; defaults to SIGKILL if ``self.sigkill`` else SIGINT.
        Not sent again if a stop is already in progress.

    Returns:
      The exit code, or None if there is no process or it is still running.
    """
    if self.proc is None:
      return None
    if self.proc.exitcode is None:
      if not self.shutting_down:
        cloudlog.info(f"killing {self.name}")
        if sig is None:
          sig = signal.SIGKILL if self.sigkill else signal.SIGINT
        self.signal(sig)
        self.shutting_down = True
        if not block:
          return None
      join_process(self.proc, 5)
      # Escalate if the process ignored the first signal.
      if self.proc.exitcode is None and retry:
        cloudlog.info(f"killing {self.name} with SIGKILL")
        self.signal(signal.SIGKILL)
        self.proc.join()
    ret = self.proc.exitcode
    cloudlog.info(f"{self.name} is dead with {ret}")
    if self.proc.exitcode is not None:
      # Only forget the process once it has actually exited.
      self.shutting_down = False
      self.proc = None
    return ret

  def signal(self, sig: int) -> None:
    """Send *sig* to the process if it is still running; otherwise do nothing."""
    if self.proc is None or self.proc.exitcode is not None or self.proc.pid is None:
      return
    cloudlog.info(f"sending signal {sig} to {self.name}")
    os.kill(self.proc.pid, sig)

  def get_process_state_msg(self):
    """Build a ManagerState.ProcessState message describing this process."""
    state = log.ManagerState.ProcessState.new_message()
    state.name = self.name
    if self.proc:
      state.running = self.proc.is_alive()
      state.shouldBeRunning = not self.shutting_down
      state.pid = self.proc.pid or 0
      state.exitCode = self.proc.exitcode or 0
    return state


class NativeProcess(ManagerProcess):
  """A native command run through nativelauncher in a child multiprocessing.Process."""
  def __init__(self, name, cwd, cmdline, should_run, enabled=True, sigkill=False, watchdog_max_dt=None, always_watchdog=False):
    self.name = name
    self.cwd = cwd  # joined onto BASEDIR in start()
    self.cmdline = cmdline
    self.should_run = should_run
    self.enabled = enabled
    self.sigkill = sigkill
    self.watchdog_max_dt = watchdog_max_dt
    self.launcher = nativelauncher
    self.always_watchdog = always_watchdog

  def prepare(self) -> None:
    pass

  def start(self) -> None:
    """Launch the command unless an instance is still tracked or shutting down."""
    if self.shutting_down or self.proc is not None:
      return
    # Built only when actually launching, so a no-op call never touches _log_dir.
    log_path = _log_dir + "/" + self.name + ".log"
    cwd = os.path.join(BASEDIR, self.cwd)
    self.proc = Process(name=self.name, target=nativelauncher, args=(self.cmdline, cwd, self.name, log_path))
    self.proc.start()
    # Re-arm the watchdog for the new instance (as PythonProcess.start does). A stale
    # watchdog_seen would let check_watchdog restart it again before its first update.
    self.watchdog_seen = False
    self.shutting_down = False

class PythonProcess(ManagerProcess):
  """A Python module run through launcher (``python -m <module>``) in a child multiprocessing.Process."""
  def __init__(self, name, module, should_run, enabled=True, sigkill=False, watchdog_max_dt=None, always_watchdog=False):
    self.name = name
    self.module = module
    self.should_run = should_run
    self.enabled = enabled
    self.sigkill = sigkill
    self.watchdog_max_dt = watchdog_max_dt
    self.launcher = launcher
    # check_watchdog reads this attribute; set it explicitly (same option as NativeProcess).
    self.always_watchdog = always_watchdog

  def prepare(self) -> None:
    """Import the module in the manager process when enabled."""
    if self.enabled:
      cloudlog.info(f"preimporting {self.module}")
      importlib.import_module(self.module)

  def start(self) -> None:
    """Launch the module unless an instance is still tracked or shutting down."""
    if self.shutting_down or self.proc is not None:
      return
    # Built only when actually launching, so a no-op call never touches _log_dir.
    log_path = _log_dir + "/" + self.name + ".log"
    self.proc = Process(name=self.name, target=launcher, args=(self.module, self.name, log_path))
    self.proc.start()
    self.watchdog_seen = False
    self.shutting_down = False

class DaemonProcess(ManagerProcess):
  """Python process that has to stay running across manager restart.
  This is used for athena so you don't lose SSH access when restarting manager.

  The daemon's pid is persisted in the *param_name* param. start() launches a new
  daemon only when that pid is not alive. The Popen handle is kept in
  ``self._popen`` rather than ``self.proc``, because the base class uses
  ``self.proc`` as a multiprocessing.Process (exitcode / is_alive()).
  """
  def __init__(self, name, module, param_name, enabled=True):
    self.name = name
    self.module = module
    self.param_name = param_name
    self.enabled = enabled
    self.params = None
    self._popen: subprocess.Popen | None = None  # daemon launched by this manager instance, if any

  @staticmethod
  def should_run(started, params, CP):
    return True

  def prepare(self) -> None:
    pass

  def start(self) -> None:
    """Launch the daemon unless the pid stored in params is still alive."""
    if self.params is None:
      self.params = Params()

    # Reap our own child if it has exited; an unreaped zombie would still pass
    # the os.kill(pid, 0) probe below and never be restarted.
    if self._popen is not None and self._popen.poll() is not None:
      self._popen = None

    pid = self.params.get(self.param_name, encoding='utf-8')
    if pid is not None:
      try:
        os.kill(int(pid), 0)
        return  # Process is already running
      except (OSError, ValueError):
        pass  # Not running (or unparsable pid): start a new one

    log_path = _log_dir + "/" + self.name + ".log"
    cloudlog.info(f"starting daemon {self.name}")
    # Popen duplicates these descriptors into the child, so the parent's copies
    # can be closed right away instead of leaking on every launch.
    with open('/dev/null') as devnull, open(log_path, 'a') as log_file:
      self._popen = subprocess.Popen(['python', '-m', self.module],
                                     stdin=devnull,
                                     stdout=log_file,
                                     stderr=subprocess.STDOUT,
                                     preexec_fn=os.setpgrp)
    self.params.put(self.param_name, str(self._popen.pid))

  def stop(self, retry=True, block=True, sig=None) -> None:
    """Intentionally a no-op: the daemon must survive manager restarts."""
    pass


def ensure_running(procs: ValuesView[ManagerProcess], started: bool, params=None, CP: car.CarParams=None, not_run: list[str] | None=None, log_dir: str = None) -> list[ManagerProcess]:
    global _log_dir
    _log_dir = log_dir
    if not_run is None:
        not_run = []

    running = []
    for p in procs:
        if p.enabled and p.name not in not_run and p.should_run(started, params, CP):
            if p.proc is None or (hasattr(p.proc, 'exitcode') and p.proc.exitcode is not None):
                p.start()
            running.append(p)
        else:
            p.stop(block=False)

        p.check_watchdog(started)

    return running
