(*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 *)

module Make (L : Loc_sig.S) = struct
  module L = L

  (* A scope id: a key of [info.scopes]. The tree-building functions below
     treat scope [0] as the root. *)
  type scope = int

  type use = L.t

  type uses = L.LSet.t

  module Def = struct
    type t = {
      (* Declaration locations of this definition. *)
      locs: L.t Nel.t;
      (* Presumably the distinct name id counted by [info.max_distinct] -- TODO confirm. *)
      name: int;
      (* The definition's name as a string; compared against the keys of
         [Scope.defs] in [compute_free_and_bound_variables]. *)
      actual_name: string;
      kind: Bindings.kind;
    }
    [@@deriving show]

    (* Orders definitions lexicographically by their declaration locations only.
       [name], [actual_name] and [kind] do not participate, so two definitions
       with the same [locs] compare equal (and collide as [DefMap] keys). *)
    let compare t1 t2 = Base.List.compare L.compare (Nel.to_list t1.locs) (Nel.to_list t2.locs)

    (* [is x t] holds when [x] is one of [t]'s declaration locations. *)
    let is x t = Nel.exists (L.equal x) t.locs
  end

  (* Maps keyed by definitions, ordered with [Def.compare]. *)
  module DefMap = WrappedMap.Make (Def)

  (* Maps each use location to the definition it resolves to. *)
  type use_def_map = Def.t L.LMap.t [@@deriving show]

  module Scope = struct
    type t = {
      lexical: bool;
      (* Id of the enclosing scope; [None] ends the chain walked by
         [scope_within] and [fold_scope_chain]. *)
      parent: int option;
      (* Definitions declared in this scope, keyed by name. *)
      defs: Def.t SMap.t;
      (* Local uses in this scope, each mapped to its definition. *)
      locals: use_def_map;
      (* Global names used in this scope. *)
      globals: SSet.t;
      (* The scope's own location; matched by [scope_of_loc] and
         [closest_enclosing_scope]. *)
      loc: L.t;
    }
    [@@deriving show]
  end

  type info = {
    (* number of distinct name ids *)
    max_distinct: int;
    (* map of scope ids to local scopes *)
    scopes: Scope.t IMap.t;
  }
  [@@deriving show]

  (* The set of every use location recorded in the [locals] of any scope. *)
  let all_uses { scopes; _ } =
    IMap.fold
      (fun _ scope acc ->
        L.LMap.fold (fun use _ uses -> L.LSet.add use uses) scope.Scope.locals acc)
      scopes
      L.LSet.empty

  (* A single use -> definition map obtained by merging every scope's [locals]. *)
  let defs_of_all_uses { scopes; _ } =
    IMap.fold (fun _ scope acc -> L.LMap.union scope.Scope.locals acc) scopes L.LMap.empty

  (* Inverse of [defs_of_all_uses]: maps each definition to the set of uses
     that resolve to it. Nothing is filtered out. *)
  let uses_of_all_defs info =
    let use_def_map = defs_of_all_uses info in
    L.LMap.fold
      (fun use def def_uses_map ->
        match DefMap.find_opt def def_uses_map with
        | None -> DefMap.add def (L.LSet.singleton use) def_uses_map
        | Some uses -> DefMap.add def (L.LSet.add use uses) def_uses_map)
      use_def_map
      DefMap.empty

  (* Raised by [def_of_use] (and so [use_is_def]) when no scope records a use. *)
  exception Missing_def of info * use

  (* The definition [use] resolves to, found in the [locals] of the first
     scope (in [IMap.fold] order) that records it, or [None]. *)
  let def_of_use_opt { scopes; _ } use =
    IMap.fold
      (fun _ scope acc ->
        match acc with
        | Some _ -> acc
        | None -> L.LMap.find_opt use scope.Scope.locals)
      scopes
      None

  (* Like [def_of_use_opt], but raises [Missing_def (info, use)] instead of
     returning [None]. *)
  let def_of_use info use =
    match def_of_use_opt info use with
    | Some def -> def
    | None -> raise (Missing_def (info, use))

  (* Whether [use] is one of the declaration locations of the definition it
     resolves to. Raises [Missing_def] if [use] is not recorded anywhere. *)
  let use_is_def info use =
    let def = def_of_use info use in
    Def.is use def

  (* Every use, across all scopes, that resolves to a definition comparing
     equal to [def]. With [~exclude_def:true], uses that are themselves
     declaration locations of their definition are skipped. *)
  let uses_of_def { scopes; _ } ?(exclude_def = false) def =
    IMap.fold
      (fun _ scope acc ->
        L.LMap.fold
          (fun use def' uses ->
            if exclude_def && Def.is use def' then
              uses
            else if Def.compare def def' = 0 then
              L.LSet.add use uses
            else
              uses)
          scope.Scope.locals
          acc)
      scopes
      L.LSet.empty

  (* Ids of the scopes whose [locals] contain a use of [def] that is not one
     of [def]'s own declaration locations. *)
  let scopes_of_uses_of_def { scopes; _ } def =
    IMap.fold
      (fun scope_id scope acc ->
        L.LMap.fold
          (fun use def' scopes ->
            if Def.is use def' then
              scopes
            else if Def.compare def def' = 0 then
              ISet.add scope_id scopes
            else
              scopes)
          scope.Scope.locals
          acc)
      scopes
      ISet.empty

  (* All uses sharing [use]'s definition (see [uses_of_def]); the empty set if
     [use] is not recorded in any scope. *)
  let uses_of_use info ?exclude_def use =
    try
      let def = def_of_use info use in
      uses_of_def info ?exclude_def def
    with
    | Missing_def _ -> L.LSet.empty

  (* Whether [def] has no uses other than its own declaration locations. *)
  let def_is_unused info def = L.LSet.is_empty (uses_of_def info ~exclude_def:true def)

  let toplevel_scopes = [0]

  (* Looks up scope [scope_id]; fails with [Failure "Scope <id> not found"]
     when it is absent. *)
  let scope info scope_id =
    match IMap.find_opt scope_id info.scopes with
    | Some s -> s
    | None -> failwith ("Scope " ^ string_of_int scope_id ^ " not found")

  (* Whether [scope_id] is a proper ancestor of [s], following [parent] links
     starting from [s]'s parent. Fails (see [scope]) on a dangling parent id. *)
  let rec scope_within info scope_id s =
    match s.Scope.parent with
    | None -> false
    | Some p ->
      if p = scope_id then
        true
      else
        scope_within info scope_id (scope info p)

  (* Ids of every scope whose [loc] equals [scope_loc], in [IMap.fold] order.
     Uses [L.equal] (as [Def.is] does) rather than polymorphic equality, which
     is not reliable for abstract location representations. *)
  let scope_of_loc info scope_loc =
    let rev_matches =
      IMap.fold
        (fun scope_id scope acc ->
          if L.equal scope.Scope.loc scope_loc then
            scope_id :: acc
          else
            acc)
        info.scopes
        []
    in
    List.rev rev_matches

  (* Id of the innermost scope enclosing [loc]. [in_range a b] presumably
     holds when [a] lies within [b] -- TODO confirm with callers. Starting from
     scope 0, a scope replaces the current candidate when it contains [loc]
     and is itself contained in the candidate. *)
  let closest_enclosing_scope info loc in_range =
    let (scope_id, _) =
      IMap.fold
        (fun this_scope_id this_scope (prev_scope_id, prev_scope) ->
          if in_range loc this_scope.Scope.loc && in_range this_scope.Scope.loc prev_scope.Scope.loc
          then
            (this_scope_id, this_scope)
          else
            (prev_scope_id, prev_scope))
        info.scopes
        (0, scope info 0)
    in
    scope_id

  (* Whether [use] is recorded in the [locals] of some scope. *)
  let is_local_use { scopes; _ } use =
    IMap.exists (fun _ scope -> L.LMap.mem use scope.Scope.locals) scopes

  (* Folds [f scope_id scope acc] over scope [scope_id] and then each of its
     ancestors, innermost first. *)
  let rec fold_scope_chain info f scope_id acc =
    let s = scope info scope_id in
    let acc = f scope_id s acc in
    match s.Scope.parent with
    | Some parent_id -> fold_scope_chain info f parent_id acc
    | None -> acc

  (* Maps each scope id to the ids of its direct children (scopes whose
     [parent] is that id). Children are prepended while visiting scopes in
     [IMap.fold] order, so each list is in reverse visit order. *)
  let rev_scope_pointers scopes =
    IMap.fold
      (fun id scope acc ->
        match scope.Scope.parent with
        | Some scope_id ->
          let children' =
            match IMap.find_opt scope_id acc with
            | Some children -> children
            | None -> []
          in
          IMap.add scope_id (id :: children') acc
        | None -> acc)
      scopes
      IMap.empty

  (* Builds the tree of scopes rooted at scope 0. [List.rev_map] undoes the
     reversal done by [rev_scope_pointers], restoring [IMap.fold] order among
     siblings. Raises [Not_found] if scope 0 is missing. *)
  let build_scope_tree info =
    let scopes = info.scopes in
    let children_map = rev_scope_pointers scopes in
    let rec build_scope_tree scope_id =
      let children =
        match IMap.find_opt scope_id children_map with
        | None -> []
        | Some children_scope_ids -> List.rev_map build_scope_tree children_scope_ids
      in
      Tree.Node (IMap.find scope_id scopes, children)
    in
    build_scope_tree 0

  (* Maps each scope reachable from scope 0 to every binding visible in it:
     its own [defs] plus those visible in its parent, with the scope's own
     definitions shadowing the parent's on name clashes. *)
  let compute_all_in_scope_bindings_per_scope info =
    let scopes = info.scopes in
    let children_map = rev_scope_pointers scopes in
    let rec build_bindings_per_scope scope_id parent_in_scope_bindings acc =
      let scope = IMap.find scope_id scopes in
      let all_in_scope_bindings =
        SMap.union ~combine:(fun _ l _ -> Some l) scope.Scope.defs parent_in_scope_bindings
      in
      let acc = IMap.add scope_id all_in_scope_bindings acc in
      match IMap.find_opt scope_id children_map with
      | None -> acc
      | Some children_scope_ids ->
        Base.List.fold_left
          ~init:acc
          ~f:(fun acc id -> build_bindings_per_scope id all_in_scope_bindings acc)
          children_scope_ids
    in
    build_bindings_per_scope 0 SMap.empty IMap.empty

  (* Annotates every node of a scope tree (as built by [build_scope_tree]) with
     [(defs, free, bound)]. Let D be the names declared in the scope ([defs]).

     The free variables are G + ((L + F') - D), where:
     * G contains the global names used in that scope
     * L contains the names of the local uses in that scope
     * F' contains the free variables of its children
     G is added without subtracting D; presumably a name declared in a scope
     never appears among that scope's globals, in which case this equals
     (G + L + F') - D -- TODO confirm.

     The bound variables are B' + D, where:
     * B' contains the bound variables of its children
  *)
  let rec compute_free_and_bound_variables = function
    | Tree.Node (scope, children) ->
      let children' = Base.List.map ~f:compute_free_and_bound_variables children in
      let (free_children, bound_children) =
        List.fold_left
          (fun (facc, bacc) -> function
            | Tree.Node ((_, free, bound), _) -> (SSet.union free facc, SSet.union bound bacc))
          (SSet.empty, SSet.empty)
          children'
      in
      let def_locals = scope.Scope.defs in
      (* [defs] is keyed by name, so membership is a direct map lookup rather
         than a scan over every binding. *)
      let is_def_local use_name = SMap.mem use_name def_locals in
      let add_if_free use_name acc =
        if is_def_local use_name then
          acc
        else
          SSet.add use_name acc
      in
      let free =
        scope.Scope.globals
        |> L.LMap.fold
             (fun _loc use_def acc -> add_if_free use_def.Def.actual_name acc)
             scope.Scope.locals
        |> SSet.fold add_if_free free_children
      in
      let bound = SMap.fold (fun name _def acc -> SSet.add name acc) def_locals bound_children in
      Tree.Node ((def_locals, free, bound), children')
end

(* Instantiations of [Make] for each location signature provided by [Loc_sig]. *)
module With_Loc = Make (Loc_sig.LocS)
module With_ALoc = Make (Loc_sig.ALocS)
module With_ILoc = Make (Loc_sig.ILocS)
