#=------------------------------------------------------------------------------
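    Batch script (meant to be run on a cluster) comparing integrators for penalized Langevin dynamics; optional argument: random seed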

    Details can be found in the article:
    [ ] A. Laurent
        A uniformly accurate scheme for the numerical integration of penalized Langevin dynamics
        To appear in SIAM J. Sci. Comput.

    PLEASE CITE THE ABOVE PAPER WHEN USING THIS PROGRAM ! :-)
    Version of 10.08.2022

    Uses Julia 1.5.4
=#
#=------------------------------------------------------------------------------
    Manifold/Diffusion parameters
=#
using LinearAlgebra

# Choose experiment number: 1-sphere, 2-torus or 3-orthogonal group
experiment=1

#= Every experiment defines the following globals, which are read by the main
   script and (NOTE(review): presumably) by the integrators in Functions.jl:
     dim, codim : ambient dimension and number of scalar constraints
     zeta       : constraint map; the manifold is {x : zeta(x) = 0}
     g          : gradient of zeta (a vector when codim == 1, a dim×codim matrix otherwise)
     Dg         : derivative of g (a matrix for codim == 1; for SO(n) a function h -> Dg(x)h)
     G          : Gram matrix g(x)'*g(x)
     divg       : divergence of (each column of) g
     f          : force field of the dynamics
     X0         : initial condition, chosen on the manifold (zeta(X0) == 0)
     phi        : observable whose mean at the final time is estimated
=#

"""
    so_constraint(X)

Return the entries of `X'X - I` on and above the diagonal, listed row by row
(`(1,1), (1,2), …, (1,n), (2,2), …, (n,n)`), for a square matrix `X`.
Its zero set is the orthogonal group.
"""
function so_constraint(X::AbstractMatrix)
    n = size(X, 1)
    Y = X'*X - I
    return [Y[i,j] for i in 1:n for j in i:n]
end

"""
    so_constraint_gradient(X)

Return the `n^2 × n(n+1)/2` gradient of [`so_constraint`](@ref) at the `n×n` matrix `X`.
Column `c` is the gradient of the `c`-th constraint; rows follow Julia's
column-major ordering of `vec(X)`, i.e. the same ordering used by `reshape(x, n, n)`.
"""
function so_constraint_gradient(X::AbstractMatrix)
    n = size(X, 1)
    index_pairs = [(i, j) for i in 1:n for j in i:n]
    Y = zeros(n^2, length(index_pairs))
    for (col, (i, j)) in enumerate(index_pairs), q in 1:n, p in 1:n
        # d(X'X)[i,j] / dX[p,q] = X[p,j]*(q==i) + X[p,i]*(q==j).
        # X[p,q] is stored at x[(q-1)*n + p] (column-major), not at x[(p-1)*n + q].
        Y[(q-1)*n + p, col] = X[p,j]*(q == i) + X[p,i]*(q == j)
    end
    return Y
end

"""
    so_constraint_divergence(n)

Divergence of each column of [`so_constraint_gradient`](@ref) (which is linear in `X`):
`2n` for the diagonal constraints `(i,i)`, `0` for the off-diagonal ones.
"""
so_constraint_divergence(n::Integer) = [2*n*(i==j) for i in 1:n for j in i:n]

if experiment==1
    # SPHERE: zeta(x) = (|x|^2 - 1)/2
    dim=3
    codim=1

    zeta = x -> (norm(x)^2-1)/2
    g = x -> x     # g=nabla zeta
    Dg = x -> Matrix{Float64}(I,dim,dim)

    G = x -> g(x)'*g(x)
    divg = x -> dim

    # f = x -> -[(x[1]^2+x[2]^2-1),(x[1]^2+x[2]^2-1),x[3]]/(2*norm(x)^2)      # be careful of the explored zone when epsilon~1
    f = x -> -25*[x[1]-1,x[2],x[3]]

    X0=[1.;0.;0.]

    phi = x -> x[1]
elseif experiment==2
    # TORUS: zeta(x) = (|x|^2 + R^2 - r^2)^2 - 4R^2(x1^2 + x2^2)
    dim=3
    codim=1

    r=1.
    R=3.
    zeta = x -> (x[1]^2+x[2]^2+x[3]^2+R^2-r^2)^2-4*R^2*(x[1]^2+x[2]^2)
    g = x -> [4*x[1]*(x[1]^2+x[2]^2+x[3]^2-R^2-r^2);
        4*x[2]*(x[1]^2+x[2]^2+x[3]^2-R^2-r^2);
        4*x[3]*(x[1]^2+x[2]^2+x[3]^2+R^2-r^2)]      # g=nabla zeta
    Dg = x -> [4*(3*x[1]^2+x[2]^2+x[3]^2-R^2-r^2) 8*x[1]*x[2] 8*x[1]*x[3];
        8*x[1]*x[2] 4*(x[1]^2+3*x[2]^2+x[3]^2-R^2-r^2) 8*x[2]*x[3];
        8*x[1]*x[3] 8*x[2]*x[3] 4*(x[1]^2+x[2]^2+3*x[3]^2+R^2-r^2)]

    G = x -> g(x)'*g(x)
    divg = x -> tr(Dg(x))

    f = x -> -25*[x[1]-R+r,x[2],x[3]]

    X0=[R-r;0.;0.]

    phi = x -> norm(x)^2
elseif experiment==3
    # SO(n): an n×n matrix X is stored as the vector x = vec(X) (column-major)
    n_SO=3
    dim=n_SO^2
    codim=div(n_SO*(n_SO+1),2)
    # NOTE(review): no longer used in this file; kept in case Functions.jl reads it.
    List_index=[[p1,p2] for p1 in 1:n_SO for p2 in 1:n_SO]

    zeta = x -> so_constraint(reshape(x,n_SO,n_SO))
    g = x -> so_constraint_gradient(reshape(x,n_SO,n_SO))     # g=nabla zeta
    # g is linear in x, so its derivative in direction h is g(h), whatever x is.
    Dg = x -> (h -> g(h))

    G = x -> g(x)'*g(x)
    # g is linear, hence its divergence is a constant vector (one entry per constraint).
    divg = so_constraint_divergence(n_SO)

    f = x -> -100*(x-reshape(Matrix{Float64}(I,n_SO,n_SO),dim))
    # f = x -> (X=reshape(x,n_SO,n_SO);-det(X)*reshape(inv(X'),dim))

    X0=reshape(Matrix{Float64}(I,n_SO,n_SO),dim)

    phi = x -> tr(reshape(x,n_SO,n_SO))
else
    throw(ArgumentError("experiment must be 1, 2 or 3, got $experiment"))
end

# Diffusion coefficient (read as a global; NOTE(review): presumably by Functions.jl)
sigma = sqrt(2)
#=------------------------------------------------------------------------------
    Numerical parameters of the problem
=#
T = 10                                             # final time
Liste_N = [floor(Int, T * 2^p) for p in 6:10]      # tested step counts: T*2^6, ..., T*2^10
Liste_h = T ./ Liste_N                             # matching step sizes h = T/N
Liste_eps = [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0]  # tested values of epsilon
N_exa = floor(Int, T * 2^12)                       # step count of the reference solution

# Lighter configuration for quick tests:
# T = 1
# Liste_N = [floor(Int, T * 2^p) for p in 7:7]
# Liste_h = T ./ Liste_N
# Liste_eps = [0.005]
# N_exa = floor(Int, T * 2^9)


# Tolerance and iteration cap passed to the UA and constrained integrators
tol = 1e-8
maxiter = 20

# Number of Monte Carlo trajectories per (epsilon, step count) pair
N_trajectories = 10
#=------------------------------------------------------------------------------
     Functions
=#
using Random
using DelimitedFiles

include("Functions.jl")
#=------------------------------------------------------------------------------
     Initialize random number generator
=#
# The optional first command-line argument is the integer seed of the global RNG
# (it also names the output file); without arguments, seed 0 is used.
Random.seed!(isempty(ARGS) ? 0 : parse(Int64, ARGS[1]))
#=------------------------------------------------------------------------------
     Main Script
=#
totaltime=time()

# Integrators (defined in Functions.jl) chosen by the codimension of the constraint:
# scalar constraints (sphere, torus) use the `_codim_1` variants, vector-valued ones
# (SO(n)) the matrix-form variants. Both families take the same arguments, so the
# experiment loop below is shared.
if codim==1
    integrator_UA = Integrator_UA_codim_1
    integrator_EEE = Integrator_Explicit_Euler_codim_1
    integrator_manifold = Integrator_Euler_Manifold_codim_1
else
    integrator_UA = Integrator_UA_matrix_form
    integrator_EEE = Integrator_Explicit_Euler
    integrator_manifold = Integrator_Euler_Manifold
end

"""
    run_experiments(integrator_UA, integrator_EEE, integrator_manifold)

Sum `phi` of the final state over `N_trajectories` runs for every epsilon in
`Liste_eps` and every step count in `Liste_N`, and return the averages stacked as:

- rows `1:length(Liste_eps)`: reference solution (`integrator_UA` with `N_exa` steps),
  stored in column 1 only (other columns are zero);
- next `length(Liste_eps)` rows: `integrator_UA`, one column per step count;
- next `length(Liste_eps)` rows: `integrator_EEE`;
- last row: `integrator_manifold` (it does not take epsilon).

Progress and per-stage wall times are printed. All problem data (`X0`, `phi`, `T`, `tol`, ...)
are read from globals.
"""
function run_experiments(integrator_UA, integrator_EEE, integrator_manifold)
    n_eps = length(Liste_eps)
    n_N = length(Liste_N)

    print("Exact solution\n")
    Xexa = zeros(n_eps, n_N)
    for (k, epsi) in enumerate(Liste_eps)
        println("k=",k)
        start = time()
        for m in 1:N_trajectories
            Xexa[k,1] += phi(integrator_UA(X0,T/N_exa,epsi,N_exa,tol,maxiter))
        end
        println(time() - start)
    end

    print("\nNumerical integrators\n")
    Xnum_UA = zeros(n_eps, n_N)
    Xnum_EEE = zeros(n_eps, n_N)
    for (k, epsi) in enumerate(Liste_eps)
        println("k=",k)
        start = time()
        for n in 1:n_N
            N = Liste_N[n]
            h = Liste_h[n]
            for m in 1:N_trajectories
                Xnum_UA[k,n] += phi(integrator_UA(X0,h,epsi,N,tol,maxiter))
                Xnum_EEE[k,n] += phi(integrator_EEE(X0,h,epsi,N))
            end
        end
        println(time() - start)
    end

    print("\nConstrained integrator\n")
    Xnum_EE_Manifold = zeros(n_N)
    start = time()
    for n in 1:n_N
        N = Liste_N[n]
        h = Liste_h[n]
        for m in 1:N_trajectories
            Xnum_EE_Manifold[n] += phi(integrator_manifold(X0,h,N,tol,maxiter))
        end
    end
    println(time() - start)

    # Xnum_EE_Manifold' is a 1×n_N row, so it stacks under the n_eps×n_N blocks.
    return [Xexa; Xnum_UA; Xnum_EEE; Xnum_EE_Manifold'] / N_trajectories
end

Snum = run_experiments(integrator_UA, integrator_EEE, integrator_manifold)
println("\ntotal time = ",time()-totaltime)
#=------------------------------------------------------------------------------
     Storing results
=#
# The output directory may be missing (fresh checkout or new cluster node); create it
# so that a long run does not fail, and lose its results, at the very last step.
mkpath("./Data")
if isempty(ARGS)
    writedlm("./Data/data.txt", Snum)
else
    # One file per seed, presumably so that parallel jobs do not overwrite each other.
    writedlm(string("./Data/res_",ARGS[1],".txt"), Snum)
end

println("end")
