/*
 * arithmetic.cpp: basic arithmetic operations. Implemented by means
 * of a special-purpose spigot core, which takes two sources and can
 * fetch a matrix from either one.
 */

#include <stdio.h>
#include <stdlib.h>
#include <assert.h>

#include "spigot.h"
#include "funcs.h"
#include "error.h"

/*
 * Core2: a spigot core fed by two input sources, x and y, which it
 * combines through the eight-element tensor described in the comment
 * following this class. It can absorb a 4-element matrix from either
 * source, and accepts 4-element output transforms via premultiply().
 */
class Core2 : public Core {
    // The two input sources. Both are deleted by the destructor (sy
    // only if it is non-NULL).
    Source *sx, *sy;
    bool same;                         // means sy is just the same as sx
    // Set by refine() once the sources' interval bounds have been
    // fetched into xbot/xtop/ybot/ytop.
    bool started;
    // The tensor as passed to the constructor; clone() hands this to
    // the new Core2.
    int orig_matrix[8];
    // The current tensor (a,b,c,d,e,f,g,h), updated by premultiply()
    // and by refine().
    bigint matrix[8];
    // Interval bounds obtained from gen_interval() on each source.
    bigint xbot, xtop, ybot, ytop;
    // True when the corresponding top bound was returned as 0, which
    // refine() treats as the special case meaning infinity.
    bool xtop_inf, ytop_inf;
    // Evaluate tensor T at (x,y), where xinf/yinf mark a coordinate as
    // infinite. Writes numerator/denominator pairs to 'out' and returns
    // how many pairs it wrote (2 is possible only when both
    // coordinates are infinite, so 'out' needs room for 4 bigints in
    // that case).
    int eval_endpoint(bigint out[2], const bigint T[8],
                      const bigint &x, bool xinf,
                      const bigint &y, bool yinf);
  public:
    // 'm' is the initial tensor; 'same' indicates that y is the same
    // value as x, in which case only ax is ever consulted.
    Core2(Source *ax, Source *ay, const int m[8], bool same = false);
    ~Core2();
    // Replace the tensor T with M(T), for the 4-element matrix M.
    virtual void premultiply(const bigint matrix[4]);
    // Build a fresh Core2 from clones of the sources and the original
    // tensor.
    virtual Core *clone();
    // Absorb more input from one or both sources.
    virtual void refine();
    // Upper bound on the number of endpoints endpoints() may write.
    virtual int max_endpoints();
    // Write numerator/denominator pairs for the image of each corner of
    // the current input intervals; returns the number of pairs.
    virtual int endpoints(bigint *endpoints);
};

/*
 * Core2 works by having an eight-element 'matrix' (really a tensor,
 * I'm told, though that's beyond my maths) in which (a,b,c,d,e,f,g,h)
 * represents a sort of 'dyadic Mobius transformation'
 *
 *           a x y + b x + c y + d
 *  T(x,y) = ---------------------
 *           e x y + f x + g y + h
 *
 * We need to be able to do three things to this matrix: premultiply
 * in an output transform provided by our owning Generator,
 * postmultiply in an input transform from the x source, and ditto the
 * y source. All of these operations involve an ordinary 4-element
 * matrix as the other parameter. So let's suppose we have a matrix
 * (p,q,r,s) representing the ordinary Mobius transformation
 *
 *         p x + q
 *  M(x) = -------
 *         r x + s
 *
 * Then premultiplication means finding a replacement 8-element matrix
 * representing T'(x,y) = M(T(x,y)), and the two postmultiplications
 * similarly find T(M(x),y) and T(x,M(y)) respectively.
 *
 * Getting Maxima to do the tedious algebra for us, the following
 * commands will produce the rational functions we need, from which
 * the numerator and denominator terms are easily collected back up
 * into coefficients of xy, x, y and 1:
 *
 *    T : lambda([x,y], (a*x*y + b*x + c*y + d) / (e*x*y + f*x + g*y + h));
 *    M : lambda([x], (p*x + q) / (r*x + s));
 *    factor(M(T(x,y)));
 *    factor(T(M(x),y));
 *    factor(T(x,M(y)));
 */

/*
 * Premultiplication: compute the tensor representing M(T(x,y)), where
 * M = (p,q,r,s) is an ordinary Mobius matrix, and store it in 'out'.
 *
 * Collecting terms in Maxima's expansion of M(T(x,y)) shows that each
 * monomial (xy, x, y, 1) is handled independently of the other three:
 * if T has numerator coefficient n and denominator coefficient d for
 * some monomial, then the result has
 *
 *   numerator coefficient    n p + d q
 *   denominator coefficient  n r + d s
 *
 * for that same monomial.
 *
 * 'out' may be the very same array as 'T' (premultiply() calls us
 * that way), so each coefficient pair is copied out of T before
 * anything derived from it is written back.
 */
static void tensor_pre(bigint out[8], const bigint M[4], const bigint T[8])
{
    const bigint &p = M[0], &q = M[1], &r = M[2], &s = M[3];

    for (int term = 0; term < 4; term++) {
        /* T[term] is a numerator coefficient, T[term+4] its
         * denominator counterpart. */
        bigint num = T[term], den = T[term + 4];
        out[term] = num*p + den*q;
        out[term + 4] = num*r + den*s;
    }
}

/*
 * Postmultiplication on the x side: compute the tensor representing
 * T(M(x),y), where M = (p,q,r,s), and store it in 'out'.
 *
 * In Maxima's expansion of T(M(x),y), the coefficients pair up as
 * (with x) / (without x): (a,c), (b,d) in the numerator and (e,g),
 * (f,h) in the denominator. For such a pair (u,v) the replacement
 * pair is
 *
 *   u' = u p + v r
 *   v' = u q + v s
 *
 * In array terms the partners sit two slots apart, starting at
 * indices 0 and 1 (numerator) and 4 and 5 (denominator).
 *
 * 'out' may be the same array as 'T', so each pair is copied out
 * before it is overwritten.
 */
static void tensor_postx(bigint out[8], const bigint T[8], const bigint M[4])
{
    const bigint &p = M[0], &q = M[1], &r = M[2], &s = M[3];

    for (int half = 0; half < 8; half += 4) {
        for (int with_x = half; with_x < half + 2; with_x++) {
            const int without_x = with_x + 2;
            bigint u = T[with_x], v = T[without_x];
            out[with_x] = u*p + v*r;
            out[without_x] = u*q + v*s;
        }
    }
}

/*
 * Postmultiplication on the y side: compute the tensor representing
 * T(x,M(y)), where M = (p,q,r,s), and store it in 'out'.
 *
 * In Maxima's expansion of T(x,M(y)), the coefficients pair up as
 * (with y) / (without y): (a,b), (c,d) in the numerator and (e,f),
 * (g,h) in the denominator. For such a pair (u,v) the replacement
 * pair is
 *
 *   u' = u p + v r
 *   v' = u q + v s
 *
 * In array terms the partners are adjacent, starting at each even
 * index.
 *
 * 'out' may be the same array as 'T', so each pair is copied out
 * before it is overwritten.
 */
static void tensor_posty(bigint out[8], const bigint T[8], const bigint M[4])
{
    const bigint &p = M[0], &q = M[1], &r = M[2], &s = M[3];

    for (int with_y = 0; with_y < 8; with_y += 2) {
        const int without_y = with_y + 1;
        bigint u = T[with_y], v = T[without_y];
        out[with_y] = u*p + v*r;
        out[without_y] = u*q + v*s;
    }
}

/*
 * Construct a Core2 reading from sources ax and ay (which the new
 * object takes ownership of and deletes in its destructor), starting
 * from the tensor m. 'asame' indicates that y is the same value as x,
 * in which case only ax is ever consulted.
 *
 * The initialiser list is written in member declaration order, since
 * that is the order in which members are actually initialised
 * regardless of how the list is spelled. The 'top is infinite' flags
 * are given a definite value here so that they are never read in an
 * indeterminate state before refine() has fetched the real interval
 * bounds.
 */
Core2::Core2(Source *ax, Source *ay, const int m[8], bool asame)
    : sx(ax)
    , sy(ay)
    , same(asame)
    , started(false)
    , xtop_inf(false)
    , ytop_inf(false)
{
    /* Keep the original int tensor too, so that clone() can start a
     * fresh Core2 from the same point. */
    for (int i = 0; i < 8; i++)
        matrix[i] = orig_matrix[i] = m[i];
    dprint("hello Core2 %p %p %8m [same=%B]", sx, sy, matrix, same);
}

/*
 * Release both sources. sy may be NULL (spigot_quadratic and
 * spigot_invsquare construct us that way); applying delete to a null
 * pointer does nothing, so no explicit check is needed.
 */
Core2::~Core2()
{
    delete sx;
    delete sy;
}

/*
 * Fold an output transform into the tensor: replace the current
 * tensor T with the one representing M(T(x,y)), where M is the
 * 4-element matrix 'inmatrix'. The tensor is updated in place
 * (tensor_pre copies its input coefficients before writing).
 */
void Core2::premultiply(const bigint inmatrix[4])
{
    dprint("premultiply: %4m %8m", inmatrix, matrix);
    tensor_pre(matrix, inmatrix, matrix);
    dprint("matrix after premult %8m", matrix);
}

/*
 * Make an independent Core2 that starts again from the tensor we were
 * originally constructed with, reading from clones of our sources.
 * If we have no separate y source, neither does the clone.
 */
Core *Core2::clone()
{
    Source *sx_copy = sx->clone();
    Source *sy_copy = NULL;
    if (sy)
        sy_copy = sy->clone();
    return new Core2(sx_copy, sy_copy, orig_matrix, same);
}

/*
 * Absorb more input into the tensor.
 *
 * On the first call this fetches the interval bounds of the source(s)
 * via gen_interval(). After that, every pass round the loop
 * postmultiplies at least one matrix from sx and/or sy into the
 * tensor. The return values of gen_interval() and gen_matrix() are
 * kept in force_absorb_x / force_absorb_y, and the loop carries on
 * for as long as either of them is true (presumably the source's way
 * of demanding that another matrix be absorbed immediately - confirm
 * against the Source interface).
 */
void Core2::refine()
{
    bool force_absorb_x = false, force_absorb_y = false;

    dprint("refine started");

    if (!started) {
        /*
         * Fetch the interval bounds.
         */
        force_absorb_x = sx->gen_interval(&xbot, &xtop);
        xtop_inf = (xtop == 0);  /* special case meaning infinity */
        if (same) {
            /* y is the same value as x, so it has the same bounds. */
            ybot = xbot;
            ytop = xtop;
            ytop_inf = xtop_inf;
        } else {
            force_absorb_y = sy->gen_interval(&ybot, &ytop);
            ytop_inf = (ytop == 0);
        }
        started = true;
        dprint("Core2 init: intervals [%b,%b] [%b,%b]",
               &xbot, &xtop, &ybot, &ytop);
    }

    do {
        bigint inmatrix[4];
        if (same) {
            /*
             * If our two sources are the same source, then we have no
             * choice anyway about which one to fetch a matrix from,
             * because there is no possible answer except 'both at
             * once'.
             */
            force_absorb_x = true;
            force_absorb_y = false;
        } else if (!force_absorb_x && !force_absorb_y) {
            /*
             * We need to fetch something from _one_ of our two
             * sources, but which? To answer that, we'll evaluate our
             * current interval endpoints and see which ones are the
             * furthest apart.
             */
            /* Room for max_endpoints() == 5 numerator/denominator
             * pairs: ends[2k] is the numerator and ends[2k+1] the
             * denominator of endpoint k. */
            bigint ends[10];
            int n_ends = endpoints(ends);
            if (n_ends == 5) {
                /*
                 * This is the special case in which both starting
                 * intervals are infinite and something exceptionally
                 * annoying has happened to the tensor. Grab another
                 * matrix from both sources in the hope that things
                 * settle down.
                 */
                force_absorb_x = force_absorb_y = true;
            } else {
                /*
                 * OK, only four endpoints, which are respectively
                 * from (xbot,ybot), (xbot,ytop), (xtop,ybot) and
                 * (xtop,ytop).
                 *
                 * If there's a pole between either pair of endpoints
                 * differing in the x value, then we should fetch from
                 * x to try to get rid of it. Similarly y.
                 *
                 * (A pole _at_ any endpoint is treated as between
                 * that endpoint and everything else, and causes a
                 * fetch from both sources.)
                 */
                /* The tests below compare the signs of the endpoint
                 * denominators (odd indices of ends[]): a product
                 * other than 1 means they differ in sign or one of
                 * them is zero. */
                if (bigint_sign(ends[1]) * bigint_sign(ends[5]) != 1 ||
                    bigint_sign(ends[3]) * bigint_sign(ends[7]) != 1)
                    force_absorb_x = true;
                if (bigint_sign(ends[1]) * bigint_sign(ends[3]) != 1 ||
                    bigint_sign(ends[5]) * bigint_sign(ends[7]) != 1)
                    force_absorb_y = true;
                if (!force_absorb_x && !force_absorb_y) {
                    /*
                     * If that still hasn't settled the matter, we'll
                     * have to look at the actual numeric differences.
                     */
                    /* Lower-case letter = bottom of that interval,
                     * upper-case = top: e.g. exY is the value at
                     * (xbot,ytop). */
                    bigint exy = fdiv(ends[0], ends[1]);
                    bigint exY = fdiv(ends[2], ends[3]);
                    bigint eXy = fdiv(ends[4], ends[5]);
                    bigint eXY = fdiv(ends[6], ends[7]);
                    bigint xdiff = bigint_abs(eXy-exy) + bigint_abs(eXY-exY);
                    bigint ydiff = bigint_abs(exY-exy) + bigint_abs(eXY-eXy);
                    dprint("decide: xdiff=%b ydiff=%b", &xdiff, &ydiff);
                    /* Fetch from whichever source accounts for the
                     * larger spread; on a tie, fetch from both. */
                    if (xdiff >= ydiff)
                        force_absorb_x = true;
                    if (ydiff >= xdiff)
                        force_absorb_y = true;
                }
            }
        }
        if (force_absorb_x) {
            force_absorb_x = sx->gen_matrix(inmatrix);
            dprint("postmultiply x: %8m %4m", matrix, inmatrix);
            tensor_postx(matrix, matrix, inmatrix);
            dprint("matrix after postmult %8m", matrix);
            if (same) {
                /* y is the same value as x, so the same matrix has to
                 * be applied on the y side as well. */
                dprint("postmultiply y: %8m %4m", matrix, inmatrix);
                tensor_posty(matrix, matrix, inmatrix);
                dprint("matrix after postmult %8m", matrix);
            }
        }
        if (force_absorb_y) {
            force_absorb_y = sy->gen_matrix(inmatrix);
            dprint("postmultiply y: %8m %4m", matrix, inmatrix);
            tensor_posty(matrix, matrix, inmatrix);
            dprint("matrix after postmult %8m", matrix);
        }
    } while (force_absorb_x || force_absorb_y);
}

/*
 * Report the largest number of endpoints (numerator/denominator
 * pairs) that endpoints() can ever write, so that a caller can size
 * its buffer accordingly.
 */
int Core2::max_endpoints()
{
    /*
     * We usually return 4 endpoints, but in one special case if both
     * starting intervals have infinite top ends, we might have to
     * return 5.
     */
    return 5;
}

/*
 * Work out the image of the point (x,y) under the tensor T, as a
 * numerator/denominator pair written to out[0] and out[1]. xinf and
 * yinf say that the corresponding coordinate is to be treated as
 * infinite (in which case the bigint value passed for it is not
 * used).
 *
 * Returns the number of pairs written. That is always 1, except that
 * when both coordinates are infinite it can be 2, with the second
 * pair going into out[2] and out[3].
 *
 * This is annoyingly fiddly because x or y or both or neither could
 * be infinite, and each infinite case has fallbacks of its own
 * (similarly to Core1 in spigot.cpp).
 */
int Core2::eval_endpoint(bigint out[2], const bigint T[8],
                         const bigint &x, bool xinf,
                         const bigint &y, bool yinf)
{
    const bigint &a = T[0], &b = T[1], &c = T[2], &d = T[3];
    const bigint &e = T[4], &f = T[5], &g = T[6], &h = T[7];

    if (!xinf && !yinf) {
        /*
         * Both coordinates finite: just evaluate the rational
         * function, grouped so as to multiply by x only once per
         * line.
         */
        out[0] = (a*y + b) * x + (c*y + d);
        out[1] = (e*y + f) * x + (g*y + h);
        dprint("  finite endpoint %b / %b", &out[0], &out[1]);
        return 1;
    }

    if (!xinf) {
        /*
         * Finite x, infinite y. The terms containing y dominate,
         * giving (ax+c)/(ex+g); should that be 0/0, the remaining
         * terms (bx+d)/(fx+h) are what's left.
         *
         * NOTE(review): the debug label below says 'x-infinite'
         * although it is y that is infinite on this path; the label
         * is left as it was.
         */
        out[0] = a*x + c;
        out[1] = e*x + g;
        if (out[0] == 0 && out[1] == 0) {
            out[0] = b*x + d;
            out[1] = f*x + h;
        }
        dprint("  x-infinite endpoint %b / %b", &out[0], &out[1]);
        return 1;
    }

    if (!yinf) {
        /*
         * Infinite x, finite y: the mirror image of the previous
         * case. (ay+b)/(ey+f), falling back to (cy+d)/(gy+h) on 0/0.
         *
         * NOTE(review): likewise, the debug label here says
         * 'y-infinite' although it is x that is infinite.
         */
        out[0] = a*y + b;
        out[1] = e*y + f;
        if (out[0] == 0 && out[1] == 0) {
            out[0] = c*y + d;
            out[1] = g*y + h;
        }
        dprint("  y-infinite endpoint %b / %b", &out[0], &out[1]);
        return 1;
    }

    /*
     * From here on, both x and y are infinite.
     */

    if (a != 0 || e != 0) {
        /* The xy terms dominate everything else: the answer is a/e. */
        out[0] = a;
        out[1] = e;
        dprint("  xy-infinite endpoint a/e %b / %b", &out[0], &out[1]);
        return 1;
    }

    /* a and e are both zero. Which linear terms are present? */
    const bool have_x_terms = (b != 0 || f != 0);
    const bool have_y_terms = (c != 0 || g != 0);

    if (!have_x_terms && !have_y_terms) {
        /* No dependence on x or y at all: the constant d/h. */
        out[0] = d;
        out[1] = h;
        dprint("  xy-infinite endpoint d/h %b / %b", &out[0], &out[1]);
        return 1;
    }

    /*
     * We want the limit of (bx+cy+d)/(fx+gy+h) as x and y both tend
     * to infinity.
     *
     * If only the x terms are present, there is no dependence on y
     * and the limit is b/f. If only the y terms are present, it is
     * c/g by symmetry.
     *
     * If both are present, the limit could be b/f or c/g depending
     * on which of x,y tends to infinity 'fastest', or something in
     * between if they grow in some particular relationship. The code
     * in iterate_spigot_algorithm can cope with that provided we give
     * it the two extremes of the range, so in that case we return
     * _both_ b/f and c/g. This is what makes it possible for Core2 to
     * report five endpoints rather than four.
     */
    int written = 0;
    if (have_x_terms) {
        bigint *pair = out + 2*written;
        pair[0] = b;
        pair[1] = f;
        dprint("  xy-infinite endpoint b/f %b / %b",
               &pair[0], &pair[1]);
        written++;
    }
    if (have_y_terms) {
        bigint *pair = out + 2*written;
        pair[0] = c;
        pair[1] = g;
        dprint("  xy-infinite endpoint c/g %b / %b",
               &pair[0], &pair[1]);
        written++;
    }
    return written;
}

/*
 * Write out the images under the current tensor of the corners of
 * the input intervals, as consecutive numerator/denominator pairs,
 * and return the number of pairs written (4 or 5).
 *
 * The corners are visited in the order (xbot,ybot), (xbot,ytop),
 * (xtop,ybot), (xtop,ytop). The code in refine() which decides which
 * source to fetch from relies on that order, so don't rearrange these
 * calls casually!
 */
int Core2::endpoints(bigint *endpoints)
{
    dprint("endpoints for %8m", matrix);

    /*
     * The first three corners each have at least one finite
     * coordinate, so eval_endpoint writes exactly one pair for each
     * of them and its return value can be ignored.
     */
    eval_endpoint(&endpoints[0], matrix, xbot, false, ybot, false);
    eval_endpoint(&endpoints[2], matrix, xbot, false, ytop, ytop_inf);
    eval_endpoint(&endpoints[4], matrix, xtop, xtop_inf, ybot, false);

    /*
     * The last corner is the only one at which both coordinates can
     * be infinite, so it alone can yield two pairs. Whatever it
     * yields is added to the three pairs already written.
     */
    return 3 + eval_endpoint(&endpoints[6], matrix,
                             xtop, xtop_inf, ytop, ytop_inf);
}

/*
 * Return a spigot for a+b, consuming both arguments.
 *
 * A rational operand is folded into exact arithmetic or into a
 * Mobius transformation of the other operand; a Core2 is only built
 * when neither operand is rational.
 */
Spigot *spigot_add(Spigot *a, Spigot *b)
{
    bigint an, ad, bn, bd;
    const bool a_rational = a->is_rational(&an, &ad);
    const bool b_rational = b->is_rational(&bn, &bd);

    if (!a_rational && !b_rational) {
        /* (0xy+1x+1y+0)/(0xy+0x+0y+1) == x+y */
        int m[8] = {0,1,1,0,0,0,0,1};
        return new Core2(a->toSource(), b->toSource(), m);
    }

    if (!a_rational) {
        /* Only b is rational: a + bn/bd == (bd*a + bn) / bd. */
        delete b;
        return spigot_mobius(a, bd, bn, 0, bd);
    }

    delete a;
    if (!b_rational) {
        /* Only a is rational: an/ad + b == (ad*b + an) / ad. */
        return spigot_mobius(b, ad, an, 0, ad);
    }

    /* Both rational: ordinary fraction addition. */
    delete b;
    return spigot_rational(an * bd + bn * ad, ad * bd);
}

/*
 * Return a spigot for a-b, consuming both arguments.
 *
 * A rational operand is folded into exact arithmetic or into a
 * Mobius transformation of the other operand; a Core2 is only built
 * when neither operand is rational.
 */
Spigot *spigot_sub(Spigot *a, Spigot *b)
{
    bigint an, ad, bn, bd;
    const bool a_rational = a->is_rational(&an, &ad);
    const bool b_rational = b->is_rational(&bn, &bd);

    if (!a_rational && !b_rational) {
        /* (0xy+1x-1y+0)/(0xy+0x+0y+1) == x-y */
        int m[8] = {0,1,-1,0,0,0,0,1};
        return new Core2(a->toSource(), b->toSource(), m);
    }

    if (!a_rational) {
        /* Only b is rational: a - bn/bd == (bd*a - bn) / bd. */
        delete b;
        return spigot_mobius(a, bd, -bn, 0, bd);
    }

    delete a;
    if (!b_rational) {
        /* Only a is rational: an/ad - b == (-ad*b + an) / ad. */
        return spigot_mobius(b, -ad, an, 0, ad);
    }

    /* Both rational: ordinary fraction subtraction. */
    delete b;
    return spigot_rational(an * bd - bn * ad, ad * bd);
}

/*
 * Return a spigot for a*b, consuming both arguments.
 *
 * A rational operand is folded into exact arithmetic or into a
 * Mobius transformation of the other operand; a Core2 is only built
 * when neither operand is rational.
 */
Spigot *spigot_mul(Spigot *a, Spigot *b)
{
    bigint an, ad, bn, bd;
    const bool a_rational = a->is_rational(&an, &ad);
    const bool b_rational = b->is_rational(&bn, &bd);

    if (!a_rational && !b_rational) {
        /* (1xy+0x+0y+0)/(0xy+0x+0y+1) == xy */
        int m[8] = {1,0,0,0,0,0,0,1};
        return new Core2(a->toSource(), b->toSource(), m);
    }

    if (!a_rational) {
        /* Only b is rational: a * bn/bd == (bn*a) / bd. */
        delete b;
        return spigot_mobius(a, bn, 0, 0, bd);
    }

    delete a;
    if (!b_rational) {
        /* Only a is rational: an/ad * b == (an*b) / ad. */
        return spigot_mobius(b, an, 0, 0, ad);
    }

    /* Both rational: ordinary fraction multiplication. */
    delete b;
    return spigot_rational(an * bn, ad * bd);
}

/*
 * Return a spigot for a2*x^2 + a1*x + a0, where x is the value of
 * 'a'. Consumes 'a'.
 */
Spigot *spigot_quadratic(Spigot *a, int a2, int a1, int a0)
{
    bigint an, ad;

    if (!a->is_rational(&an, &ad)) {
        /*
         * Feed the same source in as both x and y of a Core2, so
         * that the xy term of the tensor supplies the square:
         * (a2 xy + a1 x + 0y + a0) / 1 with y == x.
         */
        int m[8] = {a2,a1,0,a0,0,0,0,1};
        return new Core2(a->toSource(), NULL, m, true);
    }

    /*
     * Rational input an/ad: put the whole polynomial over the common
     * denominator ad^2.
     */
    delete a;
    return spigot_rational(a2 * an * an + a1 * an * ad + a0 * ad * ad,
                           ad * ad);
}

/*
 * Return a spigot for the square of 'a', as the special case
 * 1*x^2 + 0*x + 0 of spigot_quadratic (which takes over 'a').
 */
Spigot *spigot_square(Spigot *a)
{
    return spigot_quadratic(a, 1, 0, 0);
}

/*
 * Return a spigot for 1/x^2, where x is the value of 'a'. Consumes
 * 'a' on every path, including the error path.
 *
 * Throws spigot_error("division by zero") if 'a' is the rational
 * zero.
 */
Spigot *spigot_invsquare(Spigot *a)
{
    bigint an, ad;
    bool arat = a->is_rational(&an, &ad);

    /*
     * Of course in principle I could have written a completely
     * general spigot_quadratic_rational_function() which took all six
     * of the coefficients that define an arbitrary function of the
     * form
     *
     *          n2 x^2 + n1 x + n0
     *    x |-> ------------------
     *          d2 x^2 + d1 x + d0
     *
     * and then I could have implemented both spigot_invsquare *and*
     * spigot_square as one-line wrappers on it. The reason I didn't
     * is because the expression for the rational special case in
     * spigot_quadratic_rational_function() would have been more
     * horrible than I felt like dealing with!
     */

    if (arat) {
        /*
         * We have an and ad, so 'a' itself is finished with whichever
         * way this goes. Delete it exactly once, here, before
         * deciding between the error and the result.
         */
        delete a;
        if (an == 0)
            throw spigot_error("division by zero");
        /* 1 / (an/ad)^2 == ad^2 / an^2 */
        return spigot_rational(ad * ad, an * an);
    }

    /*
     * Feed the same source in as both x and y of a Core2:
     * (0xy+0x+0y+1)/(1xy+0x+0y+0) == 1/(xy), with y == x.
     */
    int m[8] = {0,0,0,1,1,0,0,0};
    return new Core2(a->toSource(), NULL, m, true);
}

/*
 * Return a spigot for a/b, consuming both arguments on every path,
 * including the error path.
 *
 * Throws spigot_error("division by zero") if 'b' is the rational
 * zero.
 */
Spigot *spigot_div(Spigot *a, Spigot *b)
{
    bigint an, ad, bn, bd;
    const bool a_rational = a->is_rational(&an, &ad);
    const bool b_rational = b->is_rational(&bn, &bd);

    if (!b_rational) {
        if (!a_rational) {
            /* (0xy+1x+0y+0)/(0xy+0x+1y+0) = x/y */
            int m[8] = {0,1,0,0,0,0,1,0};
            return new Core2(a->toSource(), b->toSource(), m);
        }

        /* Only a is rational: (an/ad) / b == an / (ad*b). */
        delete a;
        return spigot_mobius(b, 0, an, ad, 0);
    }

    /* The divisor is rational, so we can check it for zero now. */
    delete b;
    if (bn == 0) {
        delete a;
        throw spigot_error("division by zero");
    }

    if (!a_rational) {
        /* Only b is rational: a / (bn/bd) == (bd*a) / bn. */
        return spigot_mobius(a, bd, 0, 0, bn);
    }

    /* Both rational: ordinary fraction division. */
    delete a;
    return spigot_rational(an * bd, ad * bn);
}

/*
 * Return a spigot for the general combination
 *
 *   nxy xy + nx x + ny y + nc
 *   -------------------------
 *   dxy xy + dx x + dy y + dc
 *
 * of a (as x) and b (as y), consuming both arguments.
 *
 * Unlike the other arithmetic entry points, this one makes no attempt
 * to spot rational inputs or other special cases: it is only called
 * from the internals of other spigot routines, which are assumed to
 * deal with the interesting cases themselves.
 */
Spigot *spigot_combine(Spigot *a, Spigot *b,
                       int nxy, int nx, int ny, int nc,
                       int dxy, int dx, int dy, int dc)
{
    const int coeffs[8] = {
        nxy, nx, ny, nc,        /* numerator:   xy, x, y, constant */
        dxy, dx, dy, dc         /* denominator: xy, x, y, constant */
    };
    return new Core2(a->toSource(), b->toSource(), coeffs);
}
