#!/usr/bin/env python3
import argparse
import datetime as dt
import getpass
import json
import os
from pathlib import Path
import re
import sys
import time
import urllib.error
import urllib.request


# Default OAuth service base URL. Overridable via the CLOCKBRIDGE_OAUTH_BASE_URL
# environment variable, and per invocation with --base-url (see main()).
DEFAULT_BASE_URL = os.environ.get("CLOCKBRIDGE_OAUTH_BASE_URL", "https://feishu-oauth.dev.clock-p.com")
# Accepted user_name shape: lowercase ASCII letters, digits and hyphens only.
USER_NAME_RE = re.compile(r"^[a-z0-9-]+$")


def resolve_windows_roaming_dir() -> Path:
    """Return the current user's roaming AppData directory.

    Behavior depends on the platform and on which sources are usable:

    - On non-Windows hosts this simply returns ``~/AppData/Roaming``.
    - On Windows the shell API (``SHGetFolderPathW`` with ``CSIDL_APPDATA``)
      is consulted first.
    - If that fails, the ``APPDATA`` and ``USERPROFILE`` environment variables
      and ``Path.home()`` are tried in that order.

    A result inside a ``...\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.*
    \\LocalCache\\Roaming`` folder (the redirected location used by the Store
    build of Python) is mapped back to ``%USERPROFILE%\\AppData\\Roaming``.
    """

    def _unsandbox(candidate: Path) -> Path:
        # Translate a Store-Python redirected roaming folder back to the real one.
        folded = str(candidate).lower().replace("/", "\\")
        in_store_sandbox = (
            "\\appdata\\local\\packages\\pythonsoftwarefoundation.python." in folded
            and folded.endswith("\\localcache\\roaming")
        )
        if not in_store_sandbox:
            return candidate
        profile = str(os.environ.get("USERPROFILE", "")).strip().strip('"')
        if not profile:
            return candidate
        real_roaming = Path(profile) / "AppData" / "Roaming"
        return real_roaming if real_roaming.is_absolute() else candidate

    if os.name != "nt":
        return Path.home() / "AppData" / "Roaming"

    # Preferred source on Windows: ask the shell instead of concatenating env vars.
    try:
        import ctypes

        csidl_appdata = 26  # CSIDL_APPDATA
        shgfp_type_current = 0  # SHGFP_TYPE_CURRENT
        folder_buf = ctypes.create_unicode_buffer(32768)
        status = ctypes.windll.shell32.SHGetFolderPathW(None, csidl_appdata, None, shgfp_type_current, folder_buf)
        api_value = str(folder_buf.value or "").strip().strip('"') if status == 0 else ""
        if api_value and Path(api_value).is_absolute():
            return _unsandbox(Path(api_value))
    except Exception:
        pass  # best effort: fall through to the environment-based sources

    fallback_sources = (
        os.environ.get("APPDATA", ""),
        os.path.join(os.environ.get("USERPROFILE", ""), "AppData", "Roaming"),
        str(Path.home() / "AppData" / "Roaming"),
    )
    for source in fallback_sources:
        cleaned = str(source or "").strip().strip('"')
        if not cleaned:
            continue
        resolved = os.path.expandvars(os.path.expanduser(cleaned))
        # An unexpanded "%VAR%" means the variable was undefined; skip it.
        if resolved and "%" not in resolved and Path(resolved).is_absolute():
            return _unsandbox(Path(resolved))
    return _unsandbox(Path.home() / "AppData" / "Roaming")


def print_path_diagnostics(token_file: Path, user_id_file: Path, user_name_file: Path) -> None:
    """Print the resolved credential file locations (used by --debug-path).

    On Windows the raw APPDATA / USERPROFILE environment values and the roaming
    directory picked by resolve_windows_roaming_dir() are printed as well.
    """
    labelled = (
        ("token file", token_file),
        ("user_id file", user_id_file),
        ("user_name file", user_name_file),
    )
    for label, location in labelled:
        print(f"{label}: {location}")
    if os.name != "nt":
        return
    env = os.environ
    print(f"windows APPDATA env: {env.get('APPDATA', '')}")
    print(f"windows USERPROFILE env: {env.get('USERPROFILE', '')}")
    print(f"windows roaming dir(api): {resolve_windows_roaming_dir()}")


def resolve_token_file(local: bool) -> Path:
    """Return the path of the stored access token.

    ``local=True`` keeps credentials under ``./.dev.clock-p.com`` in the current
    working directory. Otherwise the location depends on the OS:

    - Windows uses ``<roaming AppData>/dev.clock-p.com``.
    - Every other OS uses ``~/.dev.clock-p.com``.
    """
    if local:
        credential_dir = Path.cwd() / ".dev.clock-p.com"
    elif os.name == "nt":
        # Windows-only resolution lives in its own helper so POSIX paths stay untouched.
        credential_dir = resolve_windows_roaming_dir() / "dev.clock-p.com"
    else:
        credential_dir = Path.home() / ".dev.clock-p.com"
    return credential_dir / "feishu-token"


def resolve_user_id_file(local: bool) -> Path:
    """Return the user_id file path, stored in the same directory as the token file."""
    token_path = resolve_token_file(local)
    return token_path.parent / "feishu-user_id"


def resolve_user_name_file(local: bool) -> Path:
    """Return the user_name file path, stored in the same directory as the token file."""
    token_path = resolve_token_file(local)
    return token_path.parent / "feishu-user_name"


def ensure_parent(path: Path) -> None:
    """Create the parent directory of ``path`` (and any ancestors) if missing.

    On non-Windows systems the parent directory is set to mode 0o700 on every
    call, so only the owner can access the stored credentials.
    """
    directory = path.parent
    directory.mkdir(parents=True, exist_ok=True)
    if os.name == "nt":
        return
    os.chmod(directory, 0o700)


def write_and_verify(path: Path, value: str) -> None:
    """Atomically write ``value`` (stripped, newline-terminated) to ``path`` and verify it.

    The write proceeds in these steps:

    - The data goes to a temporary sibling file, which is fsync'ed and then
      moved into place with ``os.replace``.
    - On non-Windows systems the temporary file is restricted to mode 0o600
      while it is still empty. The secret is therefore never readable by
      other users, and ``os.replace`` carries that mode over to ``path``.
    - If writing or replacing fails, the temporary file is removed before the
      error propagates.

    Raises:
        RuntimeError: if, after a ~1.5 s retry window, the file is missing or its
            content differs from what was written.
        OSError: if creating, writing or replacing the file fails.
    """
    path = path.expanduser()
    ensure_parent(path)
    data = value.strip() + "\n"

    tmp_path = path.parent / f".{path.name}.tmp.{os.getpid()}.{int(time.time() * 1000)}"
    try:
        with tmp_path.open("w", encoding="utf-8", newline="\n") as f:
            if os.name != "nt":
                # Tighten permissions before any secret bytes hit the disk.
                os.chmod(tmp_path, 0o600)
            f.write(data)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    except BaseException:
        # Never leave a partially written (possibly secret-bearing) temp file behind.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise

    # Some Windows environments have delayed visibility; retry a few times.
    # monotonic() so a wall-clock adjustment cannot shrink or stretch the window.
    deadline = time.monotonic() + 1.5
    last_verify = ""
    while time.monotonic() < deadline:
        if path.exists():
            last_verify = path.read_text(encoding="utf-8")
            if last_verify == data:
                return
        time.sleep(0.05)
    if not path.exists():
        raise RuntimeError(f"write failed: file not found after save: {path}")
    raise RuntimeError(f"write verification failed: content mismatch: {path} got={last_verify!r}")


def read_token(path: Path) -> str:
    """Return the whitespace-stripped contents of ``path``, or "" if it does not exist.

    Used for the token file as well as the user_id / user_name files.
    """
    return path.read_text(encoding="utf-8").strip() if path.exists() else ""


def save_token(path: Path, token: str) -> None:
    """Persist the access token to ``path`` using write_and_verify()."""
    write_and_verify(path=path, value=token)


def save_user_id(path: Path, user_id: str) -> None:
    """Persist the resolved user_id to ``path`` using write_and_verify()."""
    write_and_verify(path=path, value=user_id)


def save_user_name(path: Path, user_name: str) -> None:
    """Persist the resolved user_name to ``path`` using write_and_verify()."""
    write_and_verify(path=path, value=user_name)


def request_json(method: str, url: str, payload=None, bearer_token: str = ""):
    """Send an HTTP request and return the decoded JSON response body.

    Args:
        method: HTTP method (upper-cased before sending).
        url: absolute request URL.
        payload: optional JSON-serializable body. When not None it is sent
            with ``Content-Type: application/json``.
        bearer_token: optional token, sent as ``Authorization: Bearer <token>``.

    Returns:
        The parsed JSON value, or ``{}`` for an empty / whitespace-only body.

    Raises:
        RuntimeError: in any of these cases, so callers (and ``__main__``)
            only need to handle one exception type:

            - an HTTP error status. The message is the JSON ``message`` or
              ``error`` field, the raw body, or ``http <code>``.
            - a transport failure (connection errors, timeouts, resets).
            - a successful response whose body is not valid JSON.
    """
    data = None
    headers = {"Accept": "application/json"}
    if payload is not None:
        data = json.dumps(payload).encode("utf-8")
        headers["Content-Type"] = "application/json"
    if bearer_token:
        headers["Authorization"] = f"Bearer {bearer_token}"

    req = urllib.request.Request(url, data=data, headers=headers, method=method.upper())
    try:
        with urllib.request.urlopen(req, timeout=30) as resp:
            body = resp.read().decode("utf-8", errors="replace")
    except urllib.error.HTTPError as exc:
        body = exc.read().decode("utf-8", errors="replace")
        message = f"http {exc.code}"
        try:
            parsed = json.loads(body) if body.strip() else {}
        except ValueError:
            # Non-JSON error page: surface the raw body text instead.
            if body.strip():
                message = body.strip()
        else:
            if isinstance(parsed, dict):
                message = str(parsed.get("message") or parsed.get("error") or message)
        raise RuntimeError(message) from exc
    except urllib.error.URLError as exc:
        raise RuntimeError(f"request failed: {exc.reason}") from exc
    except OSError as exc:
        # Timeouts/resets while reading the body surface as plain OSError
        # (e.g. TimeoutError, ConnectionResetError), not as URLError.
        raise RuntimeError(f"request failed: {exc}") from exc

    if not body.strip():
        return {}
    try:
        return json.loads(body)
    except ValueError as exc:
        # e.g. an HTML page from a proxy or captive portal with a 200 status.
        raise RuntimeError(f"invalid JSON response from {url}") from exc


def fetch_me(base_url: str, token: str, strict: bool = False):
    """Query ``<base_url>/api/auth/me`` with ``token`` and return its ``user`` dict.

    An empty ``token`` returns None without any request. For a non-empty token,
    a failure is any of: the request raising RuntimeError, a payload that is
    not a dict or lacks a truthy ``ok``, or a missing/non-dict ``user``.

    - ``strict=False``: every failure yields None.
    - ``strict=True``: a failure raises RuntimeError instead. Request errors
      are re-raised unchanged.
    """
    if not token:
        return None
    try:
        payload = request_json("GET", base_url.rstrip("/") + "/api/auth/me", bearer_token=token)
    except RuntimeError:
        if not strict:
            return None
        raise

    ok = isinstance(payload, dict) and payload.get("ok")
    if not ok:
        problem = "auth/me returned non-ok payload"
    else:
        user = payload.get("user")
        if isinstance(user, dict):
            return user
        problem = "auth/me missing user payload"

    if strict:
        raise RuntimeError(problem)
    return None


def looks_like_token(value: str) -> bool:
    """Heuristically decide whether ``value`` is a token rather than a flow code.

    Returns True only for strings longer than 32 characters that contain exactly
    two dots (three dot-separated segments).
    """
    if len(value) <= 32:
        return False
    return value.count(".") == 2


def derive_user_id(user: dict) -> str:
    """Extract the user id from an /api/auth/me ``user`` object.

    Prefers a non-blank ``userId``. Otherwise it looks at ``principalId`` (or
    ``username``). If that starts with ``feishu:`` and has at least three
    colon-separated parts, the stripped last part is returned. Returns "" when
    nothing usable is found.
    """
    explicit = str(user.get("userId") or "").strip()
    if explicit:
        return explicit
    principal = str(user.get("principalId") or user.get("username") or "").strip()
    if not principal.startswith("feishu:"):
        return ""
    segments = principal.split(":")
    if len(segments) < 3:
        return ""
    return str(segments[-1] or "").strip()


def derive_user_name(user: dict) -> str:
    """Return the stripped, lower-cased ``userName`` from an /api/auth/me ``user`` object.

    The value is returned only if it fully matches USER_NAME_RE (lowercase
    letters, digits and hyphens). Otherwise, or when it is missing or blank,
    "" is returned.
    """
    candidate = str(user.get("userName") or "").strip().lower()
    return candidate if candidate and USER_NAME_RE.fullmatch(candidate) else ""


def fmt_exp_utc(exp: object) -> str:
    """Format a Unix ``exp`` timestamp (seconds) as an ISO-8601 UTC string.

    Returns "" in these cases, so callers can simply skip printing:

    - ``exp`` is not an int. ``bool`` is rejected too, even though it
      subclasses ``int``.
    - ``exp`` lies outside the range the platform's ``datetime`` can represent.
    """
    if isinstance(exp, bool) or not isinstance(exp, int):
        return ""
    try:
        return dt.datetime.fromtimestamp(exp, tz=dt.timezone.utc).isoformat()
    except (OverflowError, OSError, ValueError):
        # e.g. a millisecond timestamp or a garbage value from the server;
        # an unprintable expiry must not crash status/login/save.
        return ""


def cmd_login(
    base_url: str, token_file: Path, user_id_file: Path, user_name_file: Path, with_token: bool, debug_path: bool
) -> int:
    """Log in via the non-web OAuth flow and persist token, user_id and user_name.

    Existing token: if ``token_file`` holds a token that /api/auth/me accepts,
    no new flow is started. user_id / user_name are re-resolved (server first,
    then the stored files), written back, and 0 is returned.

    New login proceeds as follows:

    - A flow is started and the verify URL is printed.
    - The user pastes either a code, which is exchanged for a token, or a
      value that looks_like_token() accepts, which is used as the token.
    - The token is verified against /api/auth/me before anything is saved.

    Returns:
        0 on success.

    Raises:
        RuntimeError: on an invalid flow response, empty input, empty exchange
            result, failed verification, or unresolvable user_id / user_name.
    """
    for credential_path in (token_file, user_id_file, user_name_file):
        ensure_parent(credential_path)
    if debug_path:
        print_path_diagnostics(token_file, user_id_file, user_name_file)

    existing = read_token(token_file)
    current_user = fetch_me(base_url, existing) if existing else None
    if current_user:
        who = str(current_user.get("principalId") or current_user.get("username") or "")
        user_id_value = derive_user_id(current_user) or read_token(user_id_file)
        user_name_value = derive_user_name(current_user) or read_token(user_name_file)
        if not user_name_value:
            raise RuntimeError("cannot resolve user_name from /api/auth/me; check oauth_aliases and re-login")
        if user_id_value:
            save_user_id(user_id_file, user_id_value)
        # user_name_value is guaranteed non-empty past the check above.
        save_user_name(user_name_file, user_name_value)
        print(f"already login: {who}")
        print(f"token file: {token_file}")
        if user_id_value:
            print(f"user_id: {user_id_value}")
        print(f"user_name: {user_name_value}")
        print(f"user_id file: {user_id_file}")
        print(f"user_name file: {user_name_file}")
        return 0

    api_root = base_url.rstrip("/")
    start = request_json("POST", api_root + "/api/nonweb/flow/start", payload={"showTokenOnPage": with_token})
    flow_id = str(start.get("flowId") or "").strip()
    verify_url = str(start.get("verifyUrl") or "").strip()
    if not (flow_id and verify_url):
        raise RuntimeError("invalid flow response")

    print("open this url in browser and finish Feishu OAuth:")
    print(verify_url)
    prompt = "paste code or token: " if with_token else "paste code: "
    pasted = input(prompt).strip()
    if not pasted:
        raise RuntimeError("empty input")

    if looks_like_token(pasted):
        token = pasted
    else:
        exchanged = request_json(
            "POST", api_root + "/api/nonweb/flow/exchange", payload={"flowId": flow_id, "code": pasted}
        )
        token = str(exchanged.get("accessToken") or "").strip()
        if not token:
            raise RuntimeError("exchange succeeded but token is empty")

    try:
        me = fetch_me(base_url, token, strict=True)
    except RuntimeError as exc:
        raise RuntimeError(f"token verification failed: {exc}") from exc

    user_id_value = derive_user_id(me)
    if not user_id_value:
        raise RuntimeError("cannot resolve user_id from token")
    user_name_value = derive_user_name(me)
    if not user_name_value:
        raise RuntimeError("cannot resolve user_name from /api/auth/me; check oauth_aliases and re-login")

    save_token(token_file, token)
    save_user_id(user_id_file, user_id_value)
    save_user_name(user_name_file, user_name_value)
    who = str(me.get("principalId") or me.get("username") or "")
    exp_utc = fmt_exp_utc(me.get("exp"))
    for line in (f"login success: {who}", f"user_id: {user_id_value}", f"user_name: {user_name_value}"):
        print(line)
    if exp_utc:
        print(f"token exp (utc): {exp_utc}")
    print(f"token file: {token_file}")
    print(f"user_id file: {user_id_file}")
    print(f"user_name file: {user_name_file}")
    return 0


def cmd_logout(token_file: Path, user_id_file: Path, user_name_file: Path) -> int:
    """Delete the stored token, user_id and user_name files.

    Each removed file is reported on stdout. If none existed, a single
    "already logout" line is printed instead. Always returns 0.
    """
    removed = []
    for credential_path in (token_file, user_id_file, user_name_file):
        if not credential_path.exists():
            continue
        credential_path.unlink()
        print(f"logout success: removed {credential_path}")
        removed.append(credential_path)
    if not removed:
        print("already logout: token/user_id/user_name file not found")
    return 0


def cmd_status(base_url: str, token_file: Path, user_id_file: Path, user_name_file: Path, debug_path: bool) -> int:
    """Print the local login state.

    Returns 0 when /api/auth/me accepts the stored token, otherwise 1. On
    success, user_id / user_name resolved from the server are written back to
    their files when they differ from the stored values.
    """
    if debug_path:
        print_path_diagnostics(token_file, user_id_file, user_name_file)
    token = read_token(token_file)
    stored_user_id = read_token(user_id_file)
    stored_user_name = read_token(user_name_file)
    # fetch_me() returns None for an empty token without contacting the server.
    me = fetch_me(base_url, token)
    if not me:
        print("not login" if not token else "token exists but invalid/expired")
        print(f"token file: {token_file}")
        print(f"user_id file: {user_id_file}")
        print(f"user_name file: {user_name_file}")
        if stored_user_id:
            print(f"user_id: {stored_user_id}")
        if stored_user_name:
            print(f"user_name: {stored_user_name}")
        return 1

    who = str(me.get("principalId") or me.get("username") or "")
    user_id_value = derive_user_id(me) or stored_user_id
    user_name_value = derive_user_name(me) or stored_user_name
    if user_id_value and user_id_value != stored_user_id:
        save_user_id(user_id_file, user_id_value)
    if user_name_value and user_name_value != stored_user_name:
        save_user_name(user_name_file, user_name_value)
    exp_utc = fmt_exp_utc(me.get("exp"))
    print(f"login: {who}")
    if user_id_value:
        print(f"user_id: {user_id_value}")
    if user_name_value:
        print(f"user_name: {user_name_value}")
    if exp_utc:
        print(f"token exp (utc): {exp_utc}")
    print(f"token file: {token_file}")
    print(f"user_id file: {user_id_file}")
    print(f"user_name file: {user_name_file}")
    return 0


def cmd_save(
    base_url: str,
    token_file: Path,
    user_id_file: Path,
    user_name_file: Path,
    token: str,
    user_id: str,
    user_name: str,
    debug_path: bool,
) -> int:
    """Verify ``token`` against /api/auth/me, then save it with user_id / user_name.

    Non-blank explicit ``user_id`` / ``user_name`` arguments take precedence over
    values derived from the server response.

    Returns:
        0 on success.

    Raises:
        RuntimeError: if the token is empty, fails verification, or user_id /
            user_name cannot be resolved.
    """
    for credential_path in (token_file, user_id_file, user_name_file):
        ensure_parent(credential_path)
    if debug_path:
        print_path_diagnostics(token_file, user_id_file, user_name_file)
    cleaned_token = token.strip()
    if not cleaned_token:
        raise RuntimeError("empty token")

    try:
        me = fetch_me(base_url, cleaned_token, strict=True)
    except RuntimeError as exc:
        raise RuntimeError(f"token verification failed: {exc}") from exc

    final_user_id = user_id.strip() or derive_user_id(me)
    if not final_user_id:
        raise RuntimeError("cannot resolve user_id from token; use --user-id")
    final_user_name = user_name.strip() or derive_user_name(me)
    if not final_user_name:
        raise RuntimeError("cannot resolve user_name from /api/auth/me; use --user-name or check oauth_aliases")

    save_token(token_file, cleaned_token)
    save_user_id(user_id_file, final_user_id)
    save_user_name(user_name_file, final_user_name)
    principal = str(me.get("principalId") or me.get("username") or "")
    exp_utc = fmt_exp_utc(me.get("exp"))
    summary = [
        f"save success: {token_file}",
        f"principal: {principal}",
        f"user_id: {final_user_id}",
        f"user_name: {final_user_name}",
    ]
    if exp_utc:
        summary.append(f"token exp (utc): {exp_utc}")
    summary.append(f"user_id file: {user_id_file}")
    summary.append(f"user_name file: {user_name_file}")
    for line in summary:
        print(line)
    return 0


def main() -> int:
    """Parse command-line arguments and dispatch to the selected subcommand.

    Returns the subcommand's exit code, or 2 for an unsupported command.
    """
    parser = argparse.ArgumentParser(description="Feishu OAuth token client for dev.clock-p.com")
    parser.add_argument("--base-url", default=DEFAULT_BASE_URL, help="OAuth service base URL")
    parser.add_argument("--local", action="store_true", help="use local token path: ./.dev.clock-p.com/feishu-token")
    parser.add_argument("--debug-path", action="store_true", help="print resolved credential paths before execution")

    subcommands = parser.add_subparsers(dest="cmd", required=True)
    login_cmd = subcommands.add_parser("login", help="start non-web login flow")
    login_cmd.add_argument(
        "--with-token",
        action="store_true",
        help="ask server to show token on browser page (default only shows code)",
    )
    save_cmd = subcommands.add_parser("save", help="save token to local file without login flow")
    save_cmd.add_argument("token", nargs="?", help="token value (omit to input securely)")
    save_cmd.add_argument("--stdin", action="store_true", help="read token from stdin")
    save_cmd.add_argument("--user-id", default="", help="explicit user_id override")
    save_cmd.add_argument("--user-name", default="", help="explicit user_name override")
    subcommands.add_parser("logout", help="remove local token")
    subcommands.add_parser("status", help="print local login status")

    args = parser.parse_args()
    token_file = resolve_token_file(args.local)
    user_id_file = resolve_user_id_file(args.local)
    user_name_file = resolve_user_name_file(args.local)
    debug_path = bool(args.debug_path)

    if args.cmd == "login":
        # --with-token only exists on the login subparser, hence getattr.
        with_token = bool(getattr(args, "with_token", False))
        return cmd_login(args.base_url, token_file, user_id_file, user_name_file, with_token, debug_path)
    if args.cmd == "logout":
        return cmd_logout(token_file, user_id_file, user_name_file)
    if args.cmd == "status":
        return cmd_status(args.base_url, token_file, user_id_file, user_name_file, debug_path)
    if args.cmd == "save":
        if getattr(args, "stdin", False):
            token_value = sys.stdin.read().strip()
        else:
            # Fall back to a hidden prompt so the token never lands in shell history.
            token_value = str(getattr(args, "token", "") or "").strip() or getpass.getpass("paste token: ").strip()
        return cmd_save(
            args.base_url,
            token_file,
            user_id_file,
            user_name_file,
            token_value,
            str(getattr(args, "user_id", "") or "").strip(),
            str(getattr(args, "user_name", "") or "").strip(),
            debug_path,
        )

    print(f"unsupported command: {args.cmd}", file=sys.stderr)
    return 2


if __name__ == "__main__":
    try:
        raise SystemExit(main())
    except (RuntimeError, OSError) as exc:
        # Expected failures (server/flow errors, filesystem problems such as a
        # permission error on the credential directory) get a one-line message
        # instead of a traceback.
        print(f"error: {exc}", file=sys.stderr)
        raise SystemExit(1)
    except KeyboardInterrupt:
        # Ctrl+C (e.g. at the "paste code" prompt): exit with the conventional
        # 128 + SIGINT status instead of dumping a traceback.
        print("\ninterrupted", file=sys.stderr)
        raise SystemExit(130)
