(*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 *)

(* This cache maps from spreads (indexed by spread ids, which are generated by eval_destructor
 * and Statement.ml) to one of the lower bounds of their operands in each index and the reasons
 * that reach that spread operand.
 *
 * cache(spread_id)(resolve_idx) gets you a slice and reasons for a spread index. If any
 * two indices have more than one lower bound, then we will get at least quadratic results
 * by computing the spreads. Instead of going through with the computation, we emit an error.
 *
 * We need to track the slice explicitly in addition to the reasons because we
 * only count an index as having multiple lower bounds if those lower bounds are actually
 * different. This is important: if an early index receives multiple lower bounds, it will
 * kick off a resolution job for each. Both of those resolution jobs share a spread id,
 * so it's important that we don't double count a later operand being resolved as multiple
 * lower bounds when it's just resolved twice because of the separate resolution jobs.
 *)

open Utils_js

(* Reasons recorded for one operand index of a spread: the reason that first reached it
 * and, once the index is reached again, the most recent later reason. *)
type reason_state = ALoc.t Flow_intermediate_error_types.exponential_spread_reason_group

(* Per-spread state.
 * - [Ok (idx, slices)]: [slices] maps each operand index seen so far to the first slice
 *   recorded for it plus its reasons. [idx] is the single index (if any) that has been
 *   found to have multiple distinct lower bounds.
 * - [Error (g1, g2)]: two different indices each had multiple lower bounds. This holds
 *   the reason groups of both. Once entered, this state is never left. *)
type cache_state =
  (int option * (Type.Object.slice * reason_state) IMap.t, reason_state * reason_state) result

(* The whole cache, keyed by spread id. *)
type t = cache_state IMap.t

(* Record that operand [resolve_idx] of spread [spread_id] resolved to the slices in
 * [objtypes], reached through reason [r]. [cache] is updated in place.
 *
 * A spread already in the [Error] state is left unchanged. Otherwise the first slice seen
 * for [resolve_idx] is stored along with its reasons. The index counts as having multiple
 * lower bounds when [objtypes] itself contains structurally distinct slices, or when any
 * slice differs from the one stored earlier for that index.
 *
 * If some *other* index was already recorded as having multiple lower bounds, the spread
 * moves to the [Error] state carrying both indices' reason groups. Otherwise
 * [resolve_idx] becomes the recorded index. *)
let add_lower_bound cache spread_id resolve_idx r objtypes =
  let state =
    match IMap.find_opt spread_id !cache with
    | None -> Ok (None, IMap.empty)
    | Some state -> state
  in
  let state' =
    match state with
    | Error error_groups -> Error error_groups
    | Ok (idx_option, map) ->
      let (new_entry, has_multiple_lower_bounds) =
        let open Flow_intermediate_error_types in
        match IMap.find_opt resolve_idx map with
        | None ->
          let first = Nel.hd objtypes in
          (* Count only structurally distinct slices, consistent with the [Some] case
           * below. Duplicate slices yield identical spread results, so they must not
           * count as multiple lower bounds. *)
          let has_distinct_slices =
            List.exists (fun s -> s <> first) (List.tl (Nel.to_list objtypes))
          in
          ((first, { first_reason = r; second_reason = None }), has_distinct_slices)
        | Some (slice, reason_state) ->
          let has_new_lower_bound =
            Nel.fold_left (fun acc s -> acc || slice <> s) false objtypes
          in
          (* Remember the latest reason that reached this index again. *)
          let reason_state = { reason_state with second_reason = Some r } in
          ((slice, reason_state), has_new_lower_bound)
      in
      let map' = IMap.add resolve_idx new_entry map in
      (* If we have multiple lower bounds and this isn't the only idx with multiple
       * lower bounds then we enter our error state *)
      if has_multiple_lower_bounds then
        match idx_option with
        | Some idx when idx <> resolve_idx ->
          let group1 = snd new_entry in
          let group2 = snd (IMap.find idx map') in
          Error (group1, group2)
        | _ -> Ok (Some resolve_idx, map')
      else
        Ok (idx_option, map')
  in
  cache := IMap.add spread_id state' !cache

(* Return the pair of reason groups stored when [spread_id] entered the error state.
 * Callers must first check that [can_spread] is false for this spread. Any other state,
 * including an unknown spread id, is reported through [assert_false]. *)
let get_error_groups cache spread_id =
  match IMap.find_opt spread_id !cache with
  | Some (Error groups) -> groups
  | Some (Ok _)
  | None ->
    assert_false
      "Invariant violation: make sure can_spread is false before calling get_error_groups"

(* A spread may be computed unless its cache entry is in the error state. Spread ids with
 * no entry yet are allowed. *)
let can_spread cache spread_id =
  let errored =
    match IMap.find_opt spread_id !cache with
    | Some (Error _) -> true
    | Some (Ok _)
    | None ->
      false
  in
  not errored
