from numpy.fft import rfft, irfft
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
from math import sin, cos, pi, exp
from functools import partial

##Matplotlib Animation Example
##
##author: Jake Vanderplas
##email: vanderplas@astro.washington.edu
##website: http://jakevdp.github.com
##license: BSD
##Please feel free to use and modify this, but keep the above information. Thanks!

# Geometry of the medium (perhaps a string) and simulation settings.
a = 0.0             # left end of the medium, m
b = 20.0            # right end of the medium, m
L = b - a           # length of the medium, m
N = 1000            # number of segments the medium is divided into
dx = L / N          # width of one segment, m
v = 5.0             # propagation speed of waves in the medium
delay = 20          # animation frame interval, milliseconds

# 1-D Discrete Sine Transform, Type I (DST-I), computed with a real FFT of the odd extension of y
def dst(y):
    """Return the type-I discrete sine transform coefficients of y.

    y[0] is ignored (the odd extension forces that sample to zero).  The
    samples y[1:] are mirrored into an odd sequence of length 2*len(y), whose
    real FFT is purely imaginary, giving for each k

        coeff[k] = 2 * sum_n y[n] * sin(pi * k * n / len(y)).

    coeff[0] is always set to zero.  The result has len(y) entries.
    """
    size = len(y)
    # Odd extension: 0, y[1], ..., y[size-1], 0, -y[size-1], ..., -y[1]
    odd = np.zeros(2 * size, float)
    odd[1:size] = y[1:]
    odd[size + 1:] = -y[:0:-1]
    # rfft returns size+1 bins; the sine coefficients are minus the
    # imaginary parts of the first `size` of them.
    coeffs = -rfft(odd).imag[:size]
    coeffs[0] = 0.0
    return coeffs

# 1-D Inverse Discrete Sine Transform Type-I
def idst(a):
    """Invert dst(): rebuild samples from type-I sine transform coefficients.

    For n in range(len(a)) the result is

        y[n] = (1 / len(a)) * sum_{k=1}^{len(a)-1} a[k] * sin(pi * k * n / len(a)),

    so a[0] is ignored and y[0] is always zero.
    """
    size = len(a)
    # Half spectrum of the odd extension: purely imaginary, and zero at
    # both the DC bin (0) and the Nyquist bin (size).
    spectrum = np.zeros(size + 1, complex)
    spectrum[1:size] = a[1:] * -1j
    # irfft of size+1 bins yields 2*size samples; keep the first half.
    samples = irfft(spectrum)[:size]
    samples[0] = 0.0
    return samples

def standingwave(x, Amp, a, b, n):
    """Initial state at position x (time t = 0) of a standing sine wave.

    The wave has amplitude Amp and fits n half-wavelengths between a and b,
    so its displacement is zero at x = a.  Returns the pair
    (displacement, transverse velocity); the velocity is always 0.0.
    """
    wavenumber = n * pi / (b - a)   # radians per unit length
    return Amp * sin(wavenumber * (x - a)), 0.0


##def readnums(fname):   # Opens csv file named fname and
##    # reads data (one value per line) from it.
##    # Returns array of the data and its dimension.
##    xarr, N = None, None    # Initialize in case they don't fill
##    x, y, = [],[]
##    with open(fname) as infile:     # Open the input file
##        for line in infile:         # Read line by line
##            line.strip()            # Get rid of nonprinting characters        
##            x.append(float(line))
##        xarr = np.copy(x)
##        N = len(x)
##    return N, xarr    # Number of data points, x- and y-data

def y_of_it(t, omega, c, d):
    """Return the displacement of the medium at every sample point at time t.

    Once the initial sine transforms are known, each sine mode k evolves
    independently as

        e_k(t) = c[k] * cos(omega[k] * t) + (d[k] / omega[k]) * sin(omega[k] * t),

    and the displacement is recovered with an inverse sine transform (idst).

    Parameters:
        t: time at which to evaluate the solution.
        omega: angular frequency of each mode, same length as c.  Entry 0 is
            skipped, so it may be zero.
        c: sine transform of the initial positions y(x, t=0).
        d: sine transform of the initial velocities v_y(x, t=0).

    Returns:
        Array of len(c) displacements; entry 0 is always zero.
    """
    q = t * omega
    # Size the buffer from the inputs rather than the module-level N, so the
    # function works for coefficient arrays of any length.
    e = np.zeros(len(c), float)    # mode 0 stays zero in a sine series
    e[1:] = c[1:]*np.cos(q[1:]) + np.sin(q[1:])*d[1:]/omega[1:]
    return idst(e)    # reconstitute the solution at time t

# Sample the medium at N evenly spaced points and record, for each point,
# the initial displacement y0 and transverse velocity vy0 (a standing wave of
# unit amplitude with 3 half-wavelengths), plus omega[k] = k*pi*v/L, the
# angular frequency of sine mode k.
omcoeff = pi*v/L
index = np.arange(N)
x = a + index*dx
omega = omcoeff*index
initial_state = [standingwave(xi, 1., a, b, 3) for xi in x]
y0 = np.array([disp for disp, _ in initial_state], float)
vy0 = np.array([vel for _, vel in initial_state], float)

c = dst(y0)     # sine-transform coefficients of the initial positions
d = dst(vy0)    # sine-transform coefficients of the initial velocities

# Show the magnitude of each set of coefficients in its own window.
for coeffs, heading in ((c, "FT of initial positions"),
                        (d, "FT of initial velocities")):
    plt.plot(abs(coeffs))
    plt.ylabel("coefficient")
    plt.xlabel("number")
    plt.suptitle(heading)
    plt.show()

# Animation frame times span 2L/v, one period of the fundamental mode
# (omega[1] = pi*v/L): forward and back.
times3 = np.arange(0.0, 2.0*L/v, delay/1000)

# Static plot of the displacement y(x) at time t = 0.
times = [0.0]
for t in times:
    y = y_of_it(t, omega, c, d)     # inverse FT gives y at every x
    plt.plot(x, y)
    plt.ylabel("amplitude")
    plt.xlabel("position, m")
    plt.suptitle(f"time = {t} s")
    plt.show()
    
# Figure, axes (spanning the medium, +/-1.1 in amplitude) and the single
# line object that every animation frame redraws.
fig = plt.figure()
ax = plt.axes(xlim=(a, b), ylim=(-1.10, 1.10))
[line] = ax.plot([], [], lw=2)


def init():
    """Blank the line; this is the background plotted for each frame."""
    line.set_data([], [])
    return (line,)


# y_t(t) == y_of_it(t, omega, c, d): the mode data is bound once so that
# each frame only has to supply the time.
y_t = partial(y_of_it, omega=omega, c=c, d=d)


def animate(t):
    """Draw frame t: the displacement of the medium at time t."""
    displacement = y_t(t)
    line.set_data(x, displacement)
    return (line,)


anim = animation.FuncAnimation(
    fig,
    animate,
    init_func=init,
    frames=times3,
    interval=delay,
    blit=True,
    repeat=True,
)
plt.show()
print("Done.")
