///
/// This file is part of Rheolef.
///
/// Copyright (C) 2000-2009 Pierre Saramito <Pierre.Saramito@imag.fr>
///
/// Rheolef is free software; you can redistribute it and/or modify
/// it under the terms of the GNU General Public License as published by
/// the Free Software Foundation; either version 2 of the License, or
/// (at your option) any later version.
///
/// Rheolef is distributed in the hope that it will be useful,
/// but WITHOUT ANY WARRANTY; without even the implied warranty of
/// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
/// GNU General Public License for more details.
///
/// You should have received a copy of the GNU General Public License
/// along with Rheolef; if not, write to the Free Software
/// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
/// 
/// =========================================================================
#include "rheolef/geo_domain.h"
#include "rheolef/interpolate.h"
#include "rheolef/piola.h"

#ifdef _RHEOLEF_HAVE_MPI
#include "rheolef/mpi_scatter_init.h"
#include "rheolef/mpi_scatter_begin.h"
#include "rheolef/mpi_scatter_end.h"
#endif // _RHEOLEF_HAVE_MPI

namespace rheolef {

// ============================================================================
// specialized interpolate algorithm:
//   interpolate one field on a disarray of nodes with element locations
// and put the result into a disarray of values
// => used by characteristic
// ============================================================================
namespace details {

template<class T, class M>
void
interpolate_pass1_symbolic (
    const geo_basic<T,M>&                      omega,
    const disarray<point_basic<T>,M>&          x,
    const disarray<geo_element::size_type,M>&  ix2dis_ie, // x has been already localized in K
          disarray<index_set,M>&               ie2dis_ix, // K -> list of ix
          disarray<point_basic<T>,M>&          hat_x)
{
  typedef typename field_basic<T,M>::size_type size_type;
  // -----------------------------------------------------------------------------
  // 1) in order to call only one time per K element the dis_inod() function,
  //     we group the x(i) per K element of omega
  // -----------------------------------------------------------------------------
  size_type first_dis_ix = x.ownership().first_index();
  distributor ie_ownership = omega.geo_element_ownership (omega.map_dimension());
  size_type first_dis_ie = ie_ownership.first_index();
  std::set<size_type> ext_ie_set;
  index_set empty_set;
  ie2dis_ix.resize (ie_ownership, empty_set); 
  for (size_type ix = 0, nx = ix2dis_ie.size(); ix < nx; ix++) {
    size_type dis_ie = ix2dis_ie [ix];
    size_type dis_ix = first_dis_ix + ix;
    index_set dis_ix_set; dis_ix_set.insert(dis_ix);
    check_macro (dis_ie != std::numeric_limits<size_type>::max(), "node("<<ix<<")="<<x[ix]
       <<" cannot be localized in "<<omega.name() << " (HINT: curved geometry ?)");
    ie2dis_ix.dis_entry (dis_ie) += dis_ix_set;
  }
  ie2dis_ix.dis_entry_assembly();
  // -----------------------------------------------------------------------------
  // 2) loop on K and list external x(i)
  // -----------------------------------------------------------------------------
  // TODO: it would be more efficient to send (dis_ix,x) together : less mpi calls
  //  => how to do an assembly() with a disarray<pair<ix,x> > ?
  //     ... it does not support it yet
  // also: would use more memory: massive coordinates copy...
  communicator comm = ie_ownership.comm();
  distributor ix_ownership = x.ownership();
  index_set ext_dis_ix_set;
  if (is_distributed<M>::value && comm.size() > 1) {
    for (size_type ie = 0, ne = ie2dis_ix.size(); ie < ne; ie++) {
      const index_set& dis_ix_set = ie2dis_ix [ie];
      for (typename index_set::const_iterator iter = dis_ix_set.begin(), last = dis_ix_set.end(); iter != last; ++iter) {
        size_type dis_ix = *iter;
        if (ix_ownership.is_owned (dis_ix)) continue;
        ext_dis_ix_set.insert (dis_ix);
      }
    }
    // TODO: when x=xdofs lives in space, then it modifies (enlarges)
    // definitively the set of externals indexes & values
    // => massive permanent storage
    // otherwise: manage the scatter in a separate map<ext_ix,x>
    //    map<ix,x> ext_x_map;
    //    x.append_dis_indexes (ext_dis_ix_set, ext_x_map);
    // bug:
    // x.append_dis_indexes (ext_dis_ix_set);
    x.set_dis_indexes (ext_dis_ix_set);
  }
  // -----------------------------------------------------------------------------
  // 3) loop on K and compute hat_x(dis_ix)
  // -----------------------------------------------------------------------------
  T infty = std::numeric_limits<T>::max();
  hat_x.resize (ix_ownership, point_basic<T>(infty,infty,infty));
  std::vector<size_type> dis_inod;
  for (size_type ie = 0, ne = ie2dis_ix.size(); ie < ne; ie++) {
    const index_set&   dis_ix_set = ie2dis_ix [ie];
    if (dis_ix_set.size() == 0) continue;
    const geo_element& K          = omega [ie];
    omega.dis_inod (K, dis_inod);
    for (typename index_set::const_iterator iter = dis_ix_set.begin(), last = dis_ix_set.end(); iter != last; ++iter) {
      size_type dis_ix = *iter;
      const point_basic<T>& xi = x.dis_at(dis_ix);
      hat_x.dis_entry(dis_ix) = inverse_piola_transformation (omega, K, dis_inod, xi);
    }
  }
  hat_x.dis_entry_assembly();
  hat_x.set_dis_indexes (ext_dis_ix_set);
}
template<class T, class M>
void
interpolate_pass2_valued (
    const field_basic<T,M>&                uh,
    const disarray<point_basic<T>,M>&      x,
    const disarray<index_set,M>&           ie2dis_ix, // K -> list of ix
    const disarray<point_basic<T>,M>&      hat_x,     // ix -> hat_x
          disarray<T,M>&                   ux)
{
  typedef typename field_basic<T,M>::size_type size_type;
  uh.dis_dof_update();
  T infty = std::numeric_limits<T>::max();
  distributor x_ownership = x.ownership();
  size_type n_comp = 1;
  switch (uh.valued_tag()) {
    case space_constant::scalar: {
      ux.resize (x.ownership(), infty);
      break;
    }
    case space_constant::vector: 
    case space_constant::tensor:
    case space_constant::unsymmetric_tensor: {
      n_comp = uh.size();
      distributor value_ownership (n_comp*x.dis_size(), x.comm(), n_comp*x.size());
      ux.resize (value_ownership, infty);
      break;
    }
    // TODO: tensor, etc: code should also work, but code is not yet tested here.
    default: error_macro ("interpolate: unsupported "<<uh.valued()<<"-valued field");
  }
  // -----------------------------------------------------------------------------
  // on locally managed K, evaluate at hat_x that could be external
  // -----------------------------------------------------------------------------
  const geo_basic<T,M>& omega = uh.get_geo();
  const space_basic<T,M>& Xh = uh.get_space();
  const numbering<T,M>& fem = uh.get_space().get_numbering();
  const basis_basic<T>& b = fem.get_basis();
  std::vector<size_type> dis_idof;
  std::vector<T>         dof;
  std::vector<T>         b_value;
  for (size_type ie = 0, ne = ie2dis_ix.size(); ie < ne; ie++) {
    const index_set&   dis_ix_set = ie2dis_ix [ie];
    if (dis_ix_set.size() == 0) continue;
    // extract dis_idof[] & dof[] one time for all on K
    const geo_element& K          = omega [ie];
    Xh.dis_idof (K, dis_idof);
    size_type loc_ndof = dis_idof.size();
    dof.resize(loc_ndof);
    for (size_type loc_idof = 0; loc_idof < loc_ndof; loc_idof++) {
      dof [loc_idof] = uh.dis_dof (dis_idof [loc_idof]);
    }
    assert_macro (loc_ndof % n_comp == 0, "invalid component count");
    size_type loc_comp_ndof = loc_ndof / n_comp;
    b_value.resize (loc_comp_ndof);
    // loop on all hat_x in K: eval basis at hat_x and combine with dof[]
    for (typename index_set::const_iterator iter = dis_ix_set.begin(), last = dis_ix_set.end(); iter != last; ++iter) {
      size_type dis_ix           = *iter;
      size_type iproc            = x_ownership.find_owner(dis_ix);
      size_type nx               = x_ownership.size(iproc);
      size_type first_dis_ix     = x_ownership.first_index(iproc);
      size_type ix               = dis_ix - first_dis_ix;
      const point_basic<T>& hat_xi = hat_x.dis_at(dis_ix);
      b.eval (K, hat_xi, b_value);
      if (n_comp == 1) { // scalar-valued
        T value = 0;
        for (size_type loc_idof = 0; loc_idof < loc_ndof; loc_idof++) {
          value += dof [loc_idof] * b_value[loc_idof]; // sum_i w_coef(i)*hat_phi(hat_x)
        }
        ux.dis_entry (dis_ix) = value;
      } else { // multi-component case: 
        size_type dis_iux = n_comp*first_dis_ix + ix;
        for (size_type i_comp = 0, loc_idof = 0; i_comp < n_comp; i_comp++, dis_iux += nx) {
          T value = 0;
          for (size_type loc_comp_idof = 0; loc_comp_idof < loc_comp_ndof; loc_comp_idof++, loc_idof++) {
            value += dof [loc_idof] * b_value[loc_comp_idof];
          }
          ux.dis_entry (dis_iux) = value;
        }
      }
    }
  }
  ux.dis_entry_assembly();
}
template<class T, class M>
void
interpolate_on_a_different_mesh (
    const field_basic<T,M>&                   uh,
    const disarray<point_basic<T>,M>&         x,
    const disarray<geo_element::size_type,M>& ix2dis_ie,
          disarray<T,M>&                      ux)
{
  const geo_basic<T,M>&       omega = uh.get_geo();
  disarray<index_set,M>       ie2dis_ix;
  disarray<point_basic<T>,M>  hat_x;
  interpolate_pass1_symbolic (omega, x, ix2dis_ie, ie2dis_ix, hat_x);
  interpolate_pass2_valued   (uh,    x,            ie2dis_ix, hat_x, ux);
}

} // namespace details

// ============================================================================
// interpolate function:
// re-interpolate one field on another mesh and space:
//   field u2h = interpolate (V2h, u1h);
// ============================================================================
// Three situations are handled, in this order:
//  a) u1h already lives in V2h: u1h is returned, or copied dof by dof
//     when the sizes of the unknown parts differ;
//  b) both meshes share the same background mesh: the call is forwarded
//     to the interpolate() version that takes a nonlinear expression;
//  c) otherwise: each xdof of V2h is localized in the mesh of u1h and
//     u1h is evaluated there.
template<class T, class M>
field_basic<T,M>
interpolate (const space_basic<T,M>& V2h, const field_basic<T,M>& u1h)
{
  typedef typename field_basic<T,M>::size_type size_type;
  u1h.dis_dof_update();
  // -----------------------------------------------------------------------------
  // a) same space
  // -----------------------------------------------------------------------------
  if (u1h.get_space() == V2h) {
    // 1 when the unknown part of u1h has the size expected by V2h, else 0:
    // an integer flag, since it goes through a min-reduction in the distributed case
    size_type same_dofs_flag = (u1h.u().size() == V2h.iu_ownership().size()) ? 1 : 0;
#ifdef _RHEOLEF_HAVE_MPI
    if (is_distributed<M>::value) {
      // the flag stays at 1 only when it is 1 on all the processes
      same_dofs_flag = mpi::all_reduce (V2h.ownership().comm(), same_dofs_flag, mpi::minimum<size_type>());
    }
#endif // _RHEOLEF_HAVE_MPI
    if (same_dofs_flag == 0) {
      // spaces differ only by blocked/unblocked dofs: a copy is needed
      field_basic<T,M> v2h (V2h);
      const size_type n_dof = V2h.ndof();
      size_type i_dof = 0;
      while (i_dof < n_dof) {
        v2h.dof(i_dof) = u1h.dof(i_dof);
        ++i_dof;
      }
      v2h.dis_dof_update();
      return v2h;
    }
    // spaces are exactly the same: neither re-interpolation nor dof copy
    return u1h;
  }
  // -----------------------------------------------------------------------------
  // b) compatible meshes: build a wrapper expression and call back
  //    interpolate with the general nonlinear specialized version
  // -----------------------------------------------------------------------------
  if (u1h.get_geo().get_background_geo() == V2h.get_geo().get_background_geo()) {
    return interpolate (V2h, details::field_expr_v2_nonlinear_terminal_field<T,M>(u1h));
  }
  // -----------------------------------------------------------------------------
  // c.1) locate each xdof of V2h in the mesh of u1h
  // -----------------------------------------------------------------------------
  trace_macro ("reinterpolate with locate: "<<u1h.get_space().stamp()<<" --> " << V2h.stamp());
  const disarray<point_basic<T>,M>& target_xdof = V2h.get_xdofs();
  const geo_basic<T,M>&             source_mesh = u1h.get_geo();
  disarray<point_basic<T>,M>        nearest_xdof (target_xdof.ownership());
  disarray<size_type,M>             xdof2dis_ie;
  source_mesh.nearest (target_xdof, nearest_xdof, xdof2dis_ie);
  // -----------------------------------------------------------------------------
  // c.2) evaluate u1h at the nearest points: the values come in a disarray
  // -----------------------------------------------------------------------------
  disarray<T,M> values;
  details::interpolate_on_a_different_mesh (u1h, nearest_xdof, xdof2dis_ie, values);
  // -----------------------------------------------------------------------------
  // c.3) transfer the values into a new field of V2h
  // -----------------------------------------------------------------------------
  field_basic<T,M> v2h (V2h);
  assert_macro (v2h.ownership().size() ==  values.size(), "invalid size");
  copy (values.begin(), values.end(), v2h.begin_dof());
  v2h.dis_dof_update();
  trace_macro ("reinterpolate with locate done");
  return v2h;
}
// ----------------------------------------------------------------------------
// explicit instantiation in library
//
// For a given (T,M) pair, the macro below explicitly instantiates:
//   - details::interpolate_pass1_symbolic
//   - details::interpolate_pass2_valued
//   - interpolate (space, field)
// It is expanded for (Float,sequential) and, when _RHEOLEF_HAVE_MPI is
// defined, also for (Float,distributed).
// NOTE: details::interpolate_on_a_different_mesh is not part of this
// explicit list.
// ----------------------------------------------------------------------------
#define _RHEOLEF_instanciation(T,M) 				\
template							\
void								\
details::interpolate_pass1_symbolic (				\
    const geo_basic<T,M>&,					\
    const disarray<point_basic<T>,M>&,				\
    const disarray<geo_element::size_type,M>&,			\
          disarray<index_set,M>&,				\
          disarray<point_basic<T>,M>&);				\
template							\
void								\
details::interpolate_pass2_valued (				\
    const field_basic<T,M>&,					\
    const disarray<point_basic<T>,M>&,				\
    const disarray<index_set,M>&,				\
    const disarray<point_basic<T>,M>&,				\
          disarray<T,M>&);					\
template							\
field_basic<T,M>						\
interpolate (const space_basic<T,M>&, const field_basic<T,M>&);

_RHEOLEF_instanciation(Float,sequential)
#ifdef _RHEOLEF_HAVE_MPI
_RHEOLEF_instanciation(Float,distributed)
#endif // _RHEOLEF_HAVE_MPI

} // namespace rheolef
