package montecarlo;

import java.util.Random;

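/**
 * Numerical quadrature of a {@code Function} over the interval [xMin, xMax] using the
 * rectangle (midpoint) rule, the trapezoid rule, and several Monte Carlo estimators
 * (simple, importance-sampled, and Markov-chain importance-sampled).
 */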

public class Quadrature {

    private Function integrand; // function being integrated
    private double xMin, xMax;  // lower and upper limit of integration
    private double[] x;         // x values used in the most recent quadrature
    private int nPoints;        // number of quadrature points (function evaluations)

    /**
     * Creates a quadrature object for the given integrand, initially using 20 points
     * on the interval [0, 1].
     *
     * @param i the function to integrate
     */
    public Quadrature(Function i) {
        setFunction(i);
        setN(20);
        setXMin(0.0);
        setXMax(1.0);
    }

    /**
     * Sets the function to be integrated.
     *
     * @param i the new integrand
     */
    public void setFunction(Function i) {integrand = i;}

    /**
     * Sets the number of quadrature points and allocates a fresh sample array of that size.
     *
     * @param n number of points (function evaluations); must be at least 1
     * @throws IllegalArgumentException if {@code n < 1}
     */
    public void setN(int n) {
        if (n < 1) {
            throw new IllegalArgumentException("number of quadrature points must be >= 1, got " + n);
        }
        nPoints = n;
        x = new double[nPoints];
    }

    /** @return the current number of quadrature points */
    public int getN() {return nPoints;}

    /**
     * Returns the sample points filled in by the most recent quadrature call.
     * This is the internal array, not a copy; {@link #setN(int)} replaces it.
     *
     * @return the x values used in the last quadrature
     */
    public double[] getX() {return x;}

    /** @param xmin new lower limit of integration */
    public void setXMin(double xmin) {xMin = xmin;}
    /** @return the lower limit of integration */
    public double getXMin() {return xMin;}
    /** @param xmax new upper limit of integration */
    public void setXMax(double xmax) {xMax = xmax;}
    /** @return the upper limit of integration */
    public double getXMax() {return xMax;}

    /**
     * Sets the number of points to {@code n} and applies the rectangle rule.
     *
     * @param n number of points; must be at least 1
     * @return the rectangle-rule estimate of the integral
     */
    public double rectangle(int n) {setN(n); return rectangle();}

    /**
     * Rectangle (midpoint) rule: the integrand is evaluated at the centre of each of
     * {@code nPoints} equal-width subintervals.
     *
     * @return the rectangle-rule estimate of the integral over [xMin, xMax]
     */
    public double rectangle() {
        double deltaX = (xMax-xMin)/(double)nPoints;
        double sum = 0.0;
        for (int i = 0; i < nPoints; i++) {
            x[i] = xMin + (i+0.5)*deltaX;
            sum += integrand.f(x[i]);
        }
        return sum*deltaX;
    }

    /**
     * Sets the number of points to {@code n} and applies the trapezoid rule.
     *
     * @param n number of points; must be at least 2
     * @return the trapezoid-rule estimate of the integral
     * @throws IllegalArgumentException if {@code n < 1}
     * @throws IllegalStateException if {@code n == 1}
     */
    public double trapezoid(int n) {setN(n); return trapezoid();}

    /**
     * Trapezoid rule on {@code nPoints} equally spaced points that include both endpoints.
     * The endpoint samples carry half weight.
     *
     * @return the trapezoid-rule estimate of the integral over [xMin, xMax]
     * @throws IllegalStateException if fewer than 2 points are configured (the spacing
     *         {@code (xMax - xMin) / (nPoints - 1)} would be undefined)
     */
    public double trapezoid() {
        if (nPoints < 2) {
            throw new IllegalStateException("trapezoid rule needs at least 2 points, got " + nPoints);
        }
        double deltaX = (xMax-xMin)/((double)nPoints-1.0);
        double sum = 0.0;
        for (int i = 0; i < nPoints; i++) {
            x[i] = xMin + i*deltaX;
            sum += (i==0 || i==nPoints-1) ? 0.5*integrand.f(x[i]) : integrand.f(x[i]);
        }
        return sum*deltaX;
    }

    /**
     * Sets the number of points to {@code n} and computes a simple Monte Carlo estimate.
     *
     * @param n number of random samples; must be at least 1
     * @return the Monte Carlo estimate of the integral
     */
    public double simpleMC(int n) {setN(n); return simpleMC();}

    /**
     * Simple Monte Carlo estimate: {@code (xMax - xMin)} times the mean of the integrand
     * at {@code nPoints} points {@code xMin + (xMax - xMin) * u}, with u drawn from a
     * fresh {@code MyRandom}.
     * NOTE(review): this is unbiased provided MyRandom.nextDouble() is uniform on [0, 1),
     * and a new generator is created per call — confirm how MyRandom is seeded.
     *
     * @return the Monte Carlo estimate of the integral over [xMin, xMax]
     */
    public double simpleMC() {
        Random rand = new MyRandom();
        double deltaX = (xMax-xMin);
        double sum = 0.0;
        for (int i = 0; i < nPoints; i++) {
            x[i] = xMin + deltaX*rand.nextDouble();
            sum += integrand.f(x[i]);
        }
        return sum*deltaX/(double)nPoints;
    }

    /**
     * Sets the number of points to {@code n} and computes an importance-sampling estimate.
     *
     * @param n number of samples; must be at least 1
     * @param rand generator whose {@code nextDouble()} supplies the sample positions
     * @param w weight function that divides the integrand at each sample
     * @return the importance-sampling estimate
     */
    public double importanceMC(int n, Random rand, Function w) {setN(n); return importanceMC(rand, w);}

    /**
     * Importance-sampling Monte Carlo estimate.
     *
     * <p>Each sample is {@code xMin + (xMax - xMin) * rand.nextDouble()} and contributes
     * {@code f(x) / w(x)}. The result is {@code (xMax - xMin)} times the mean contribution.
     * NOTE(review): this equals the integral of f only if the sample density is
     * {@code w(x) / (xMax - xMin)}. That requires the caller to pass a generator whose
     * output is distributed according to w, with w averaging 1 over the interval.
     * With a uniform generator, the method instead estimates the integral of f/w.
     *
     * @param rand generator whose {@code nextDouble()} supplies the sample positions
     * @param w weight function that divides the integrand at each sample
     * @return the importance-sampling estimate
     */
    public double importanceMC(Random rand, Function w) {
        double deltaX = (xMax-xMin);
        double sum = 0.0;
        for (int i = 0; i < nPoints; i++) {
            x[i] = xMin + deltaX*rand.nextDouble();
            sum += integrand.f(x[i])/w.f(x[i]);
        }
        return sum*deltaX/(double)nPoints;
    }

    /**
     * Sets the number of points to {@code n} and runs the Markov-chain estimator.
     *
     * @param n number of chain steps (samples); must be at least 1
     * @param x starting point of the chain
     * @param step maximum half-width of a trial move
     * @param w non-negative weight function the chain samples in proportion to
     * @return the Markov-chain importance-sampling estimate
     */
    public double markovMC(int n, double x, double step, Function w) {setN(n); return markovMC(x, step, w);}

    /**
     * Importance-sampling quadrature using a Metropolis Markov chain.
     *
     * <p>Trial moves are uniform in [-step, step] around the current point and are wrapped
     * periodically into [xMin, xMax]. A move is accepted with probability
     * {@code min(1, w(trial) / w(current))}. Every step, including the first, is recorded
     * in the sample array returned by {@link #getX()}, and no burn-in is discarded.
     * The result is {@code (xMax - xMin)} times the mean of {@code f / w} along the chain.
     * It therefore equals the integral of f only if w is normalized so that its integral
     * over [xMin, xMax] equals {@code xMax - xMin}.
     *
     * @param x starting point of the chain (wrapped into [xMin, xMax] if outside)
     * @param step maximum half-width of a trial move
     * @param w non-negative weight function; must be non-zero at the starting point
     * @return the Markov-chain importance-sampling estimate
     * @throws IllegalStateException if {@code xMax <= xMin} (the periodic wrap would not terminate)
     */
    public double markovMC(double x, double step, Function w) {
        Random rand = new MyRandom();
        double deltaX = (xMax - xMin);
        if (!(deltaX > 0.0)) {
            throw new IllegalStateException(
                "markovMC requires xMin < xMax, got [" + xMin + ", " + xMax + "]");
        }
        // Keep the chain state separate from the field x (the sample array) it used to shadow.
        double xCurrent = wrapIntoRange(x, deltaX);
        double wCurrent = w.f(xCurrent); // cached so w is evaluated once per step
        double sum = 0.0;
        for (int i = 0; i < nPoints; i++) {
            double xTrial = wrapIntoRange(xCurrent + (2.*rand.nextDouble()-1.)*step, deltaX);
            double wTrial = w.f(xTrial);
            double R = wTrial/wCurrent;
            // Short-circuit keeps the RNG call sequence: no acceptance draw when R > 1.
            if (R > 1.0 || R > rand.nextDouble()) {
                xCurrent = xTrial;
                wCurrent = wTrial;
            }
            this.x[i] = xCurrent;
            sum += integrand.f(xCurrent)/wCurrent;
        }
        return sum*deltaX/(double)nPoints;
    }

    /** Periodically shifts v by multiples of width until it lies within [xMin, xMax]. */
    private double wrapIntoRange(double v, double width) {
        while (v > xMax) v -= width;
        while (v < xMin) v += width;
        return v;
    }
}
            
        