(**************************************************************************)
(*                                                                        *)
(*  This file is part of Frama-C.                                         *)
(*                                                                        *)
(*  Copyright (C) 2007-2020                                               *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  you can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

open Cil_types
open Logic_ptree

(* High-level form of an [slevel] annotation. Each constructor corresponds to
   one of the arguments accepted by the [slevel] parser of this file. *)
type slevel_annotation =
  | SlevelMerge (* written [merge] *)
  | SlevelDefault (* written [default] *)
  | SlevelLocal of int (* written as an integer literal; the parser rejects
                          negative values *)
  | SlevelFull (* written [full] *)

(* High-level form of an [unroll] annotation. *)
type unroll_annotation =
  | UnrollAmount of Cil_types.term (* the annotation was given one term *)
  | UnrollFull (* the annotation was given no argument *)

(* High-level form of the [split] and [merge] annotations; each one carries
   the single term given as argument of the annotation. *)
type flow_annotation =
  | FlowSplit of term (* built from a [split] annotation *)
  | FlowMerge of term (* built from a [merge] annotation *)

(* Argument of an [eva_allocate] annotation. The textual names are, in order:
   [by_stack], [fresh], [fresh_weak] and [imprecise]. *)
type allocation_kind = By_stack | Fresh | Fresh_weak | Imprecise

(* We use two representations for annotations:
   - the high level representation (HL) which is exported from this module
   - the low level representation (Cil) which is used by the kernel to store
     any annotation

   Annotations in this module define the export and import functions to go from
   one to another. Then, the parse and print functions work directly on the
   high level representation.

             add  --+
                    |
   ACSL --> parse --+--> HL --> export --> Cil --> import --+--> HL --> print
                                                            |
                                                            +--> get
*)

(* Raised by the [parse] functions of this file when the arguments of an
   annotation are not of the expected form. [Register.typer] catches it and
   reports an "Invalid ... directive" error through the typing context. *)
exception Parse_error

(* Description of one kind of annotation, as expected by [Register]. *)
module type Annotation =
sig
  (* High-level representation of the annotation. *)
  type t

  (* Name under which the ACSL extension is registered. *)
  val name : string
  (* When true, the extension is registered with
     [Acsl_extension.register_code_annot_next_loop]; otherwise with
     [Acsl_extension.register_code_annot_next_stmt]. *)
  val is_loop_annot : bool
  (* Builds the high-level representation from the arguments of the
     annotation. The implementations in this file raise [Parse_error] on
     unexpected arguments. *)
  val parse : typing_context:Logic_typing.typing_context -> lexpr list -> t
  (* Converts the high-level representation into the low-level one. *)
  val export : t -> acsl_extension_kind
  (* Converts the low-level representation back into the high-level one. *)
  val import : acsl_extension_kind -> t
  (* Prints the high-level representation. *)
  val print : Format.formatter -> t -> unit
end

(* Registers the annotation described by [M] as an ACSL extension when the
   functor is applied, and provides [get] and [add] to read and write such
   annotations on statements. All components of [M] are re-exported. *)
module Register (M : Annotation) =
struct
  include M

  (* Typing callback: parses the arguments into the high-level representation
     and immediately exports it. A [Parse_error] is turned into an error
     reported through the typing context. *)
  let typer typing_context loc args =
    try export (parse ~typing_context args)
    with Parse_error ->
      typing_context.Logic_typing.error loc "Invalid %s directive" name

  (* Printing callback: imports the low-level representation and prints it
     with [M.print]. The first argument is ignored. *)
  let printer _pp fmt lp =
    print fmt (import lp)

  (* Side effect at functor application: registers the extension either as a
     loop annotation or as a statement annotation, depending on
     [is_loop_annot]. *)
  let () =
    if is_loop_annot then
      Acsl_extension.register_code_annot_next_loop name typer ~printer false
    else
      Acsl_extension.register_code_annot_next_stmt name typer ~printer false

  (* Returns the high-level form of every code annotation of [stmt] that is an
     extension named [name] with the same loop flag as [is_loop_annot]. The
     final [List.rev] restores the order in which the fold visited them. *)
  let get stmt =
    let filter_add _emitter annot acc =
      match annot.annot_content with
      | Cil_types.AExtended (_, is_loop_annot', {ext_name=name'; ext_kind})
        when name' = name && is_loop_annot' = is_loop_annot ->
        import ext_kind :: acc
      | _ -> acc
    in
    List.rev (Annotations.fold_code_annot filter_add stmt [])

  (* Exports [annot] and attaches it to [stmt] as a code annotation emitted by
     [emitter]. The extension is created with this annotation's own [name] and
     [is_loop_annot], exactly what [get] filters on, so that an annotation
     added here is found again by [get]. *)
  let add ~emitter ~loc stmt annot =
    let param = export annot in
    let extension = Logic_const.new_acsl_extension name loc false param in
    let annot_node = Cil_types.AExtended ([], is_loop_annot, extension) in
    let code_annotation = Logic_const.new_code_annotation annot_node in
    Annotations.add_code_annot emitter stmt code_annotation
end


(* The [slevel] statement annotation. Its single argument is one of the
   keywords [default], [merge], [full], or a non-negative integer literal. *)
module Slevel = Register (struct
    type t = slevel_annotation

    let name = "slevel"
    let is_loop_annot = false

    (* Exactly one argument is accepted; any other shape, an unknown keyword,
       an unreadable integer or a negative one raises [Parse_error]. *)
    let parse ~typing_context:_ args =
      match args with
      | [{lexpr_node = PLvar keyword}] ->
        begin match keyword with
          | "default" -> SlevelDefault
          | "merge" -> SlevelMerge
          | "full" -> SlevelFull
          | _ -> raise Parse_error
        end
      | [{lexpr_node = PLconstant (IntConstant literal)}] ->
        begin match int_of_string literal with
          | exception Failure _ -> raise Parse_error
          | n -> if n < 0 then raise Parse_error else SlevelLocal n
        end
      | _ -> raise Parse_error

    (* Keywords are stored as string terms, the local slevel as an integer
       term; in every case the extension holds exactly one term. *)
    let export annot =
      let term =
        match annot with
        | SlevelDefault -> Logic_const.tstring "default"
        | SlevelMerge -> Logic_const.tstring "merge"
        | SlevelLocal n -> Logic_const.tinteger n
        | SlevelFull -> Logic_const.tstring "full"
      in
      Ext_terms [term]

    (* Reads back a single-term extension. An unrecognised term is leniently
       mapped to [SlevelDefault], since a visitor may have rewritten the term
       into something else; any other extension shape is considered
       impossible. *)
    let import ext =
      match ext with
      | Ext_terms [{term_node = TConst (LStr "default")}] -> SlevelDefault
      | Ext_terms [{term_node = TConst (LStr "merge")}] -> SlevelMerge
      | Ext_terms [{term_node = TConst (LStr "full")}] -> SlevelFull
      | Ext_terms [{term_node = TConst (Integer (z, _))}] ->
        SlevelLocal (Integer.to_int z)
      | Ext_terms [_] -> SlevelDefault
      | _ -> assert false

    (* Prints the argument the way it is written in an annotation. *)
    let print fmt annot =
      let keyword s = Format.pp_print_string fmt s in
      match annot with
      | SlevelDefault -> keyword "default"
      | SlevelMerge -> keyword "merge"
      | SlevelFull -> keyword "full"
      | SlevelLocal n -> Format.pp_print_int fmt n
  end)

(* Shared implementation for annotations whose only argument is a term. It is
   included by [Split] and [Merge], which add [name] and [is_loop_annot]. *)
module SimpleTermAnnotation =
struct
  type t = term

  (* Types the single argument with the typing context's [type_term], in its
     [pre_state]; any other number of arguments raises [Parse_error]. *)
  let parse ~typing_context args =
    match args with
    | [lexpr] ->
      typing_context.Logic_typing.type_term
        typing_context typing_context.Logic_typing.pre_state lexpr
    | _ -> raise Parse_error

  (* The term is stored as the only element of the extension. *)
  let export term = Ext_terms [term]

  (* Inverse of [export]; any other extension shape is considered
     impossible. *)
  let import ext =
    match ext with
    | Ext_terms [term] -> term
    | _ -> assert false

  let print = Printer.pp_term
end

(* The [unroll] loop annotation, given either no argument or one term. *)
module Unroll = Register (struct
    type t = unroll_annotation

    let name = "unroll"
    let is_loop_annot = true

    (* No argument means [UnrollFull]; a single argument is typed as a term
       in the typing context's [pre_state]; more arguments raise
       [Parse_error]. *)
    let parse ~typing_context args =
      match args with
      | [] -> UnrollFull
      | [bound] ->
        let typed_bound =
          typing_context.Logic_typing.type_term
            typing_context typing_context.Logic_typing.pre_state bound
        in
        UnrollAmount typed_bound
      | _ -> raise Parse_error

    (* The extension holds zero terms for [UnrollFull], one otherwise. *)
    let export annot =
      let terms =
        match annot with
        | UnrollFull -> []
        | UnrollAmount bound -> [bound]
      in
      Ext_terms terms

    (* Inverse of [export]; any other extension shape is considered
       impossible. *)
    let import ext =
      match ext with
      | Ext_terms [] -> UnrollFull
      | Ext_terms [bound] -> UnrollAmount bound
      | _ -> assert false

    (* [UnrollFull] has no argument, hence nothing is printed for it. *)
    let print fmt annot =
      match annot with
      | UnrollAmount bound -> Printer.pp_term fmt bound
      | UnrollFull -> ()
  end)

(* The [split] statement annotation: a single term, parsed, stored and printed
   by [SimpleTermAnnotation]. *)
module Split = Register (struct
    include SimpleTermAnnotation
    let name = "split"
    let is_loop_annot = false
  end)

(* The [merge] statement annotation: a single term, parsed, stored and printed
   by [SimpleTermAnnotation]. *)
module Merge = Register (struct
    include SimpleTermAnnotation
    let name = "merge"
    let is_loop_annot = false
  end)


(* Returns the first annotation listed by [Slevel.get] for [stmt], or [None]
   when [List.hd] fails, i.e. when that list is empty. *)
let get_slevel_annot stmt =
  match List.hd (Slevel.get stmt) with
  | first -> Some first
  | exception Failure _ -> None

(* Returns every annotation listed by [Unroll.get] for [stmt]. *)
let get_unroll_annot stmt = Unroll.get stmt

(* Returns the flow annotations of [stmt]: all the [split] annotations wrapped
   in [FlowSplit], followed by all the [merge] annotations wrapped in
   [FlowMerge]. *)
let get_flow_annot stmt =
  let splits = List.map (fun term -> FlowSplit term) (Split.get stmt) in
  let merges = List.map (fun term -> FlowMerge term) (Merge.get stmt) in
  splits @ merges


(* Alias of [Slevel.add]: exports the given [slevel_annotation] and attaches
   it to a statement as a code annotation of the given emitter. *)
let add_slevel_annot = Slevel.add

(* Alias of [Unroll.add]: exports the given [unroll_annotation] and attaches
   it to a statement as a code annotation of the given emitter. *)
let add_unroll_annot = Unroll.add

(* Dispatches a flow annotation to [Split.add] or [Merge.add] according to its
   constructor, passing along the term it carries. *)
let add_flow_annot ~emitter ~loc stmt flow =
  match flow with
  | FlowSplit term -> Split.add ~emitter ~loc stmt term
  | FlowMerge term -> Merge.add ~emitter ~loc stmt term


(* The [subdivide] statement annotation, whose single argument is a
   non-negative integer literal. *)
module Subdivision = Register (struct
    type t = int
    let name = "subdivide"
    let is_loop_annot = false

    (* Accepts exactly one integer literal; an unreadable or negative integer,
       or any other argument shape, raises [Parse_error]. *)
    let parse ~typing_context:_ args =
      match args with
      | [{lexpr_node = PLconstant (IntConstant literal)}] ->
        begin match int_of_string literal with
          | exception Failure _ -> raise Parse_error
          | n -> if n < 0 then raise Parse_error else n
        end
      | _ -> raise Parse_error

    (* The value is stored as a single integer term. *)
    let export n = Ext_terms [Logic_const.tinteger n]

    (* Inverse of [export]; unlike [Slevel.import], no fallback is provided
       for a term that is no longer an integer constant. *)
    let import ext =
      match ext with
      | Ext_terms [{term_node = TConst (Integer (z, _))}] -> Integer.to_int z
      | _ -> assert false

    let print = Format.pp_print_int
  end)

(* Alias of [Subdivision.get]: the list of values imported from the
   [subdivide] annotations of a statement. *)
let get_subdivision_annot = Subdivision.get
(* Alias of [Subdivision.add]: exports the given integer and attaches it to a
   statement as a code annotation of the given emitter. *)
let add_subdivision_annot = Subdivision.add


(* The [eva_allocate] statement annotation, whose single argument names an
   [allocation_kind]. The [get] exported by this module returns a single kind
   instead of the list returned by [Register]. *)
module Allocation = struct
  (* Reads the textual name of an allocation kind; [None] for any other
     string. *)
  let of_string = function
    | "by_stack"   -> Some By_stack
    | "fresh"      -> Some Fresh
    | "fresh_weak" -> Some Fresh_weak
    | "imprecise"  -> Some Imprecise
    | _            -> None

  (* Textual name of an allocation kind; inverse of [of_string]. *)
  let to_string = function
    | By_stack   -> "by_stack"
    | Fresh      -> "fresh"
    | Fresh_weak -> "fresh_weak"
    | Imprecise  -> "imprecise"

  include Register (struct
      type t = allocation_kind
      let name = "eva_allocate"
      let is_loop_annot = false

      (* Accepts exactly one identifier. NOTE(review): [Extlib.the ~exn] is
         presumed to raise [Parse_error] when [of_string] returns [None] —
         confirm against Extlib. *)
      let parse ~typing_context:_ = function
        | [{lexpr_node = PLvar s}] -> Extlib.the ~exn:Parse_error (of_string s)
        | _ -> raise Parse_error

      (* The kind is stored as a single string term holding its name. *)
      let export alloc_kind =
        Ext_terms [Logic_const.tstring (to_string alloc_kind)]

      let import = function
        | Ext_terms [{term_node}] ->
          (* Be kind and return By_stack by default. Someone is bound to write a
             visitor that will simplify our term into something unrecognizable. *)
          begin match term_node with
            | TConst (LStr s) -> Extlib.opt_conv By_stack (of_string s)
            | _ -> By_stack
          end
        | _ -> assert false

      let print fmt alloc_kind =
        Format.pp_print_string fmt (to_string alloc_kind)
    end)

  (* Shadows the [get] of [Register]. Without any annotation on [stmt], the
     kind is read from the [Value_parameters.AllocBuiltin] parameter
     (NOTE(review): [Extlib.the] presumably fails if that string is not a
     known kind — confirm the parameter is validated). With several
     annotations, the first one is used and a warning is emitted. *)
  let get stmt =
    match get stmt with
    | [] -> Extlib.the (of_string (Value_parameters.AllocBuiltin.get ()))
    | [x] -> x
    | x :: _ ->
      (* The space before the line-continuation backslash is required: the
         escape skips the leading blanks of the next line, so without it the
         selected kind and "and" would be glued together in the message. *)
      Value_parameters.warning ~current:true ~once:true
        "Several eva_allocate annotations at the same statement; selecting %s \
         and ignoring the others." (to_string x);
      x
end

(* Alias of [Allocation.get]: the allocation kind selected for a statement. *)
let get_allocation = Allocation.get
