#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module contains utility functions for symbolic computations,
mostly a thin layer on top of swiginac.
"""

# Copyright (C) 2007-2009 Martin Sandve Alnes and Simula Research Laboratory
#
# This file is part of SyFi.
#
# SyFi is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# SyFi is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SyFi. If not, see <http://www.gnu.org/licenses/>.
#
# First added:  2007-10-16
# Last changed: 2009-03-19

import swiginac
import SyFi

from sfc.symbolic_utils.symbol_factory import symbol, symbols, symbolic_vector, symbolic_matrix

# ... Utilities for symbolic computations, mostly on matrices

# Spatial coordinate symbols x, y, z, indexed by coordinate number k;
# ddx, grad and div differentiate with respect to _p[k].
_p = symbols("x y z".split())

# ... Type inspection and conversion helpers (these arguably belong in a separate module)

def is_indexed_type(A):
    """Check whether A is an indexed (container-like) object.

    Returns True for Python lists and tuples, swiginac.lst objects, and any
    expression whose evalm() yields a swiginac.matrix. Objects without an
    evalm() method (e.g. plain Python numbers) are reported as non-indexed
    instead of raising AttributeError.
    """
    if isinstance(A, (list, tuple)) or isinstance(A, swiginac.lst):
        return True
    # Plain Python scalars have no evalm(); treat them as non-indexed so that
    # callers such as add/sub/mul can mix numeric literals with expressions.
    evalm = getattr(A, "evalm", None)
    if evalm is None:
        return False
    return isinstance(evalm(), swiginac.matrix)

def is_scalar(e):
    """Return True when e, after evalm(), is not an indexed (matrix/lst/sequence) type."""
    evaluated = e.evalm()
    indexed = is_indexed_type(evaluated)
    return not indexed

def as_matrix(A):
    """Convert A to an evaluated swiginac.matrix.

    - swiginac.lst: converted with swiginac.lst_to_matrix.
    - list / tuple: turned into a len(A) x 1 column matrix.
    - any expression whose evalm() is a swiginac.matrix: that matrix.

    Raises TypeError for anything else, including objects without an
    evalm() method such as plain Python numbers.
    """
    if isinstance(A, swiginac.lst):
        return swiginac.lst_to_matrix(A).evalm()
    if isinstance(A, (list, tuple)):
        # Flat sequences become column vectors. TODO: handle list of lists?
        return swiginac.matrix(len(A), 1, A).evalm()
    evalm = getattr(A, "evalm", None)
    if evalm is not None:
        evaluated = evalm()
        if isinstance(evaluated, swiginac.matrix):
            return evaluated
    raise TypeError("ERROR: cannot convert A to a swiginac.matrix, A is of type %s" % str(type(A)))

def zeros(m, n):
    """Return a fresh m x n swiginac.matrix, whose entries are all zero."""
    zero_matrix = swiginac.matrix(m, n)
    return zero_matrix

def ones(m, n):
    """Return an m x n swiginac.matrix with every entry set to 1."""
    result = swiginac.matrix(m, n)
    # Visit entries row by row and overwrite each one with 1.
    for row, col in ((r, c) for r in xrange(m) for c in xrange(n)):
        result[row, col] = 1
    return result

def Id(n):
    """Return the n x n identity matrix (swiginac.unit_matrix)."""
    identity = swiginac.unit_matrix(n)
    return identity

def add(A, B):
    """Return A + B evaluated with evalm().

    Operands recognised by is_indexed_type are converted with as_matrix
    first, so lists/tuples/lst objects can be combined with matrices.
    """
    left = as_matrix(A) if is_indexed_type(A) else A
    right = as_matrix(B) if is_indexed_type(B) else B
    return (left + right).evalm()

def sub(A, B):
    """Return A - B evaluated with evalm().

    Operands recognised by is_indexed_type are converted with as_matrix
    first, so lists/tuples/lst objects can be combined with matrices.
    """
    left = as_matrix(A) if is_indexed_type(A) else A
    right = as_matrix(B) if is_indexed_type(B) else B
    return (left - right).evalm()

def mul(A, B):
    """Return A * B evaluated with evalm().

    Operands recognised by is_indexed_type are converted with as_matrix
    first, so lists/tuples/lst objects can be combined with matrices.
    """
    left = as_matrix(A) if is_indexed_type(A) else A
    right = as_matrix(B) if is_indexed_type(B) else B
    return (left * right).evalm()

def cross(a, b):
    """Return the cross product a x b as a 3 x 1 swiginac.matrix.

    Both arguments are converted with as_matrix. Each component is the
    determinant of a 2 x 2 minor built from the other two components.

    Raises TypeError unless both arguments have exactly three entries.
    """
    u = as_matrix(a)
    v = as_matrix(b)
    if len(u) != 3 or len(v) != 3:
        raise TypeError("Need 3D vectors a and b in cross(a,b).")
    # Index pairs (1,2), (2,0), (0,1) yield the first, second and third component.
    components = []
    for p, q in ((1, 2), (2, 0), (0, 1)):
        minor = swiginac.matrix(2, 2, [u[p], u[q], v[p], v[q]])
        components.append(minor.determinant())
    return swiginac.matrix(3, 1, components)

def inner(A, B): # FIXME: use multiplication for matrices instead of contraction?
    """Return the inner product of A and B.

    - Scalars: ordinary multiplication, evaluated with evalm().
    - Two matrices with the same number of entries: full contraction,
      sum_i A.op(i) * B.op(i).
    - Otherwise, after turning a column-vector A into a row vector and a
      row-vector B into a column vector: the matrix product A * B, if the
      sizes are compatible.

    Lists, tuples and lst objects are converted with as_matrix first.

    Raises ValueError if two matrices have incompatible sizes.
    """
    if is_indexed_type(A): A = as_matrix(A)
    if is_indexed_type(B): B = as_matrix(B)
    if isinstance(A, swiginac.matrix) and isinstance(B, swiginac.matrix):
        if len(A) == len(B):
            return sum(A.op(i)*B.op(i) for i in xrange(len(A)))
        if A.cols() == 1: A = A.transpose() # column vector to row vector
        if B.rows() == 1: B = B.transpose() # row vector to column vector
        if A.cols() == B.rows(): # matrix-vector or vector-matrix inner product
            return mul(A, B)
        raise ValueError("Unmatching size of operands")
    return (A*B).evalm()

def dot(a, b): # FIXME: use multiplication for matrices instead of contraction?
    """Return the dot product of a and b; this is an alias for inner(a, b)."""
    result = inner(a, b)
    return result

def contract(A, B):
    """Return the contraction of A and B; this is an alias for inner(A, B)."""
    contracted = inner(A, B)
    return contracted

def det(A):
    """Return the determinant of A, computed after calling A.evalm()."""
    evaluated = A.evalm()
    return evaluated.determinant()

def abs(x):
    """Return the symbolic absolute value of x via swiginac.abs.

    Note: this module-level definition shadows the builtin abs() within this module.
    """
    absolute_value = swiginac.abs(x)
    return absolute_value

def transpose(A):
    """Return swiginac.transpose applied to A.evalm()."""
    evaluated = A.evalm()
    return swiginac.transpose(evaluated)

def trace(A):
    """Return swiginac.trace applied to A.evalm()."""
    evaluated = A.evalm()
    return swiginac.trace(evaluated)

def inverse(A):
    """Return the inverse of A.

    After evalm(), a swiginac.matrix is inverted with .inverse(). Any other
    expression yields the reciprocal 1.0 / A, computed with the float
    literal 1.0.
    """
    evaluated = A.evalm()
    if not isinstance(evaluated, swiginac.matrix):
        return 1.0 / evaluated
    return evaluated.evalm().inverse()

def rank(x):
    """Return the tensor rank of x.

    Returns:
    - 2 for a swiginac.matrix with more than one row and more than one column;
    - 1 for a row or column vector with more than one entry;
    - 0 otherwise, including 1 x 1 matrices and non-matrix objects.
    """
    if not isinstance(x, swiginac.matrix):
        return 0
    rows, cols = x.rows(), x.cols()
    if rows > 1 and cols > 1:
        return 2
    return 1 if rows * cols > 1 else 0

def shape(x):
    """Return the shape of x as a tuple.

    Returns:
    - (rows, cols) for a swiginac.matrix with more than one row and column;
    - (rows*cols,) for a vector with more than one entry;
    - (1,) otherwise, including 1 x 1 matrices and non-matrix objects.
    """
    if not isinstance(x, swiginac.matrix):
        return (1,)
    rows, cols = x.rows(), x.cols()
    if rows > 1 and cols > 1:
        return (rows, cols)
    return (rows * cols,) if rows * cols > 1 else (1,)

def diff(f, x):
    """Return df/dx, where x is a symbol or a matrix of symbols.

    If x is a swiginac.matrix, the result is a matrix of the same shape whose
    (i, j) entry is swiginac.diff(f, x[i,j]). Only combinations with
    rank(f) + rank(x) <= 2 are accepted.

    Raises:
        RuntimeError: if rank(f) + rank(x) > 2.
        TypeError: if x, or an entry of a matrix x, is not a swiginac.symbol.
    """
    if isinstance(x, swiginac.matrix):
        f_rank = rank(f)
        x_rank = rank(x)
        if f_rank+x_rank > 2: raise RuntimeError("Cannot apply diff to f,x with ranks %d,%d" % (f_rank, x_rank))
        # FIXME: support different combinations of ranks for f,x
        A = zeros(x.rows(), x.cols())
        for i in range(A.rows()):
            for j in range(A.cols()):
                sym = x[i,j]
                if not isinstance(sym, swiginac.symbol):
                    raise TypeError("diff: entry (%d,%d) of x is not a symbol but %s" % (i, j, type(sym)))
                # FIXME: f is differentiated as a whole; no per-entry selection yet.
                A[i,j] = swiginac.diff(f, sym)
        return A
    if not isinstance(x, swiginac.symbol):
        raise TypeError("diff: x must be a symbol or a matrix of symbols, got %s" % type(x))
    return swiginac.diff(f, x)

def ddx(f, i, GinvT=None):
    """Return df/dx_i, where i is the index of the coordinate (0, 1 or 2).

    GinvT is an optional geometry mapping. When it is given, the result is
    sum_k GinvT[i,k] * df/dp_k, where k ranges over GinvT.rows() and p_k is
    the k-th coordinate symbol.
    """
    # Without mapping: differentiate directly w.r.t. the i-th coordinate symbol.
    if GinvT is None:
        return diff(f, _p[i])
    # With mapping: combine the coordinate derivatives through row i of GinvT.
    dfdx = 0
    for k in range(GinvT.rows()):
        dfdx += GinvT[i, k] * swiginac.diff(f, _p[k])
    return dfdx


def grad(u, GinvT=None):
    """Returns the gradient of u w.r.t. the spatial variables x,y[,z].

    For a scalar u (rank 0), the result is an n x 1 matrix with
    n = SyFi.cvar.nsd, the dimension set by SyFi.initSyFi(nsd).
    For a vector u (rank 1, row or column) of length n, the result is an
    n x n matrix.

    Entry (i, j) is sum_k GinvT[i,k] * d u_j / d p_k. GinvT is an optional
    geometry mapping and defaults to the n x n identity, giving
    d u_j / d x_i.

    Raises ValueError for rank-2 (matrix) arguments.
    """
    ran = rank(u)
    sh  = shape(u)

    if ran == 0:
        n = SyFi.cvar.nsd
        r, c = n, 1
        u = swiginac.matrix(1,1,[u])
    elif ran == 1:
        n = sh[0]
        r, c = n, n
    else:
        raise ValueError("Taking gradient of a matrix is not supported.")

    if GinvT is None:
        GinvT = Id(n)

    g = zeros(r, c)
    for i in range(r):
        for j in range(c):
            # Linear indexing works for both column (n x 1) and row (1 x n)
            # vectors; u[j,0] would be out of range for a row vector.
            f = u.op(j)
            for k in range(n):
                g[i,j] += GinvT[i,k] * diff(f, _p[k])
    return g


def div(u, GinvT=None):
    """Return the divergence of a vector or square-matrix field u.

    - Rank 1 (row or column vector of length n): the scalar
      sum_{i,k} GinvT[i,k] * d u_i / d p_k.
    - Rank 2 (square r x r matrix): an r x 1 matrix with entry
      j = sum_{i,k} GinvT[i,k] * d u[i,j] / d p_k.

    GinvT is an optional geometry mapping and defaults to the identity.

    Raises RuntimeError for scalars (rank 0) and for non-square matrices.
    """
    ran = rank(u)
    # Check rank before touching rows()/cols(): plain scalars have no such methods.
    if ran == 0:
        raise RuntimeError("Taking divergence of a scalar.")
    r = u.rows()
    c = u.cols()
    if ran == 2:
        if r != c:
            raise RuntimeError("Taking divergence of a non-square matrix.")
        n = r
        if GinvT is None:
            GinvT = Id(n)
        d = zeros(r, 1)
        for j in range(r):
            for i in range(r):
                for k in range(r):
                    d[j,0] += GinvT[i,k] * diff(u[i,j], _p[k])
    else:
        n = r*c
        if GinvT is None:
            GinvT = Id(n)
        d = 0
        for i in range(n):
            # Linear indexing handles both row and column vectors.
            entry = u.op(i)
            for k in range(n):
                d += GinvT[i,k] * diff(entry, _p[k])
    return d


def curl(u, GinvT=None):
    """Return the 2 x 1 matrix (-g[1,0], g[0,0]), where g = grad(u, GinvT).

    For a scalar u in 2D this is (-du/dy, du/dx).
    NOTE(review): for vector u only the first column of the gradient
    (derivatives of u_0) is used, which is probably not a meaningful curl;
    confirm the intended inputs.
    """
    gradient = grad(u, GinvT)
    result = swiginac.matrix(2,1)
    result[0,0] = -gradient[1,0]
    result[1,0] = gradient[0,0]
    return result


def laplace(u, GinvT=None):
    """Return the Laplacian of u as div(grad(u)), using the same optional GinvT mapping for both."""
    gradient_of_u = grad(u, GinvT)
    return div(gradient_of_u, GinvT)



if __name__ == '__main__':
    a = swiginac.matrix(3,1,[1,0,0])
    b = swiginac.matrix(3,1,[0,1,0])
    c = swiginac.matrix(3,1,[0,0,1])

    print ""
    print cross(a,b)
    print cross(a,c)
    print cross(c,b)
    print cross(b,c)

    import SyFi
    SyFi.initSyFi(2)

    x, y, z = symbols(["x", "y", "z"])
    x2, y2, z2 = symbols(["x", "y", "z"])
    assert (x-x2).is_zero()
    assert (y-y2).is_zero()
    assert (z-z2).is_zero()

    p2 = swiginac.matrix(2, 1, [x,y])
    p3 = swiginac.matrix(3, 1, [x,y,z])


    A = swiginac.matrix(2,2, [x*x*x, x*x*y, x*y*y, y*y*y])
    sh = shape(A)
    assert rank(A) == 2
    assert len(sh) == 2
    assert sh[0] == 2
    assert sh[1] == 2
    print ""
    print "A", A
    print "div  A", div(A)


    v = swiginac.matrix(2, 1, [3*x,5*y])
    sh = shape(v)
    assert rank(v) == 1
    assert len(sh) == 1
    assert sh[0] == 2
    print ""
    print "v", v
    print "div  v", div(v)
    print "grad  v", grad(v)
    print "dv/dx  ", diff(v, p2)


    v = swiginac.matrix(2, 1, [3*y,5*x])
    sh = shape(v)
    assert rank(v) == 1
    assert len(sh) == 1
    assert sh[0] == 2
    print ""
    print "v", v
    print "div  v", div(v)
    print "grad  v", grad(v)
    print "dv/dx  ", diff(v, p2)

