package simulate;
import java.awt.Graphics;
import java.awt.Color;
import java.util.Random;

public class Space2D extends Space {
    
    public static final int D = 2;
    public final int D() {return D;}
    
    public double sphereVolume(double r) {return Math.PI*r*r;}  //volume of a sphere of radius r
    public double sphereArea(double r) {return 2.0*Math.PI*r;}  //surface area of sphere of radius r (used for differential shell volume)
    public Space.Vector makeVector() {return new Vector();}
    public Space.Coordinate makeCoordinate(Space.Occupant o) {return new Coordinate(o);}
    public Space.CoordinatePair makeCoordinatePair(Space.Boundary boundary) {return new CoordinatePair((Boundary)boundary);}
    public Space.Boundary makeBoundary(int b) {
        switch(b) {
            case(Boundary.NONE):            return new BoundaryNone();
            case(Boundary.PERIODIC_SQUARE): return new BoundaryPeriodicSquare();
            case(Boundary.HARD):            return new BoundaryHard();
            case(Boundary.SLIDING_BRICK):   return new BoundarySlidingBrick();
            default:                        return null;
        }
    }
        
    public final static class Vector extends Space.Vector {  //declared final for efficient method calls
        public static final Random random = new Random();
        public static final Vector ORIGIN = new Vector(0.0,0.0);
        public double x, y;
        public Vector () {x = 0.0; y = 0.0;}
        public Vector (double a1, double a2) {x = a1; y = a2;}
        public double component(int i) {return (i==0) ? x : y;}
        public void setComponent(int i, double d) {if(i==0) x=d; else y=d;}
        public void E(Vector u) {x = u.x; y = u.y;}
        public void E(double a) {x = a; y = a;}
        public void E(int i, double a) {if(i==0) x = a; else y = a;}  //assumes i = 0 or 1
        public void Ea1Tv1(double a1, Space.Vector u) {Vector u1=(Vector)u; x = a1*u1.x; y = a1*u1.y;}
        public void PEa1Tv1(double a1, Space.Vector u) {Vector u1=(Vector)u; x += a1*u1.x; y += a1*u1.y;}
        public void PE(Vector u) {x += u.x; y += u.y;}
        public void ME(Vector u) {x -= u.x; y -= u.y;}
        public void PE(int i, double a) {if(i==0) x += a; else y += a;}
        public void TE(double a) {x *= a; y *= a;}
        public void TE(int i, double a) {if(i==0) x *= a; else y *= a;}
        public void DE(double a) {x /= a; y /= a;}
        public double squared() {return x*x + y*y;}
        public double dot(Vector u) {return x*u.x + y*u.y;}
        public void randomStep(double d) {x += (2.*random.nextDouble()-1.0)*d; y+= (2.*random.nextDouble()-1.0)*d;} //uniformly distributed random step in x and y, within +/- d
        public void setRandom(double d) {x = random.nextDouble()*d; y = random.nextDouble()*d;}
        public void setRandom(double dx, double dy) {x = random.nextDouble()*dx; y = random.nextDouble()*dy;}
        public void setRandom(Vector u) {setRandom(u.x,u.y);}
        public void setRandomCube() {
            x = random.nextDouble() - 0.5; 
            y = random.nextDouble() - 0.5;
        }
        public void setRandomSphere() {
            x = Math.cos(2*Math.PI*random.nextDouble()); 
            y = Math.sqrt(1.0 - x*x);
            if(random.nextDouble() < 0.5) y = -y;
        }
        public void setToOrigin() {x = ORIGIN.x; y = ORIGIN.y;}
        public void randomDirection() {
            x = Math.cos(Math.PI*random.nextDouble()); 
            y = Math.sqrt(1.0 - x*x);
            if(random.nextDouble() < 0.5) y = -y;
        }
        public void E(Space.Vector u) {E((Vector)u);}
        public void PE(Space.Vector u) {PE((Vector)u);}
        public void ME(Space.Vector u) {ME((Vector)u);}
        public void TE(Space.Vector u) {TE((Vector)u);}
        public void DE(Space.Vector u) {DE((Vector)u);}
        public double dot(Space.Vector u) {return dot((Vector)u);}
    }
    
    public static final class CoordinatePair extends Space.CoordinatePair {  
        Coordinate c1;
        Coordinate c2;
        final Boundary boundary;
        final Vector dimensions;   //assumes this is not transferred between phases
        private final Vector dr = new Vector();  //note that dr is not cloned if this is cloned -- should change this if using dr in clones; also this makes cloned coordinatePairs not thread-safe
        private double drx, dry, dvx, dvy;
        public CoordinatePair() {boundary = new BoundaryNone(); dimensions = (Vector)boundary.dimensions();}
        public CoordinatePair(Space.Boundary b) {boundary = (Boundary)b; dimensions = (Vector)boundary.dimensions();}
        public void reset(Space.Coordinate coord1, Space.Coordinate coord2) {  //don't usually use this; instead set c1 and c2 directly, without a cast
            c1 = (Coordinate)coord1;
            c2 = (Coordinate)coord2;
            reset();
        }
        public void reset() {
            dr.x = c2.r.x - c1.r.x;
            dr.y = c2.r.y - c1.r.y;
            boundary.nearestImage(dr);
            drx = dr.x; 
            dry = dr.y;
            r2 = drx*drx + dry*dry;
            double rm1 = c1.parent().rm();
            double rm2 = c2.parent().rm();
            dvx = (rm2*c2.p.x - rm1*c1.p.x);  
            dvy = (rm2*c2.p.y - rm1*c1.p.y);  
        }
        /**
         * Recomputes pair separation, with atom 2 shifted by the given vector
         * Does not apply any PBC, regardless of boundary chosen for space
         */
        public void reset(Space2D.Vector M) {
            dr.x = c2.r.x - c1.r.x + M.x;
            dr.y = c2.r.y - c1.r.y + M.y;
            drx = dr.x;
            dry = dr.y;
            r2 = drx*drx + dry*dry;
        }
        public Space.Vector dr() {return dr;}
        public double dr(int i) {return (i==0) ? drx : dry;}
        public double dv(int i) {return (i==0) ? dvx : dvy;}
        public double v2() {
            return dvx*dvx + dvy*dvy;
        }
        public double vDotr() {
            return drx*dvx + dry*dvy;
        }
        public void push(double impulse) {  //changes momentum in the direction joining the atoms
            c1.p.x += impulse*drx;
            c1.p.y += impulse*dry;
            c2.p.x -= impulse*drx;
            c2.p.y -= impulse*dry;
        }
        public void setSeparation(double r2New) {
            double ratio = c2.parent().mass()*c1.parent().rm();  // (mass2/mass1)
            double delta = (Math.sqrt(r2New/r2()) - 1.0)/(1+ratio);
            c1.r.x -= ratio*delta*drx;
            c1.r.y -= ratio*delta*dry;
            c2.r.x += delta*drx;
            c2.r.y += delta*dry;
            //need call reset?
        }
    }

    static class Coordinate extends Space.Coordinate {
        public final Vector r = new Vector();  //Cartesian coordinates
        public final Vector p = new Vector();  //Momentum vector
        public Coordinate(Space.Occupant o) {super(o);}
        public Space.Vector position() {return r;}
        public Space.Vector momentum() {return p;}
        public double position(int i) {return r.component(i);}
        public double momentum(int i) {return p.component(i);}
        public Space.Vector makeVector() {return new Vector();}
        public final double kineticEnergy(double mass) {return 0.5*p.squared()/mass;}
    } 
    
    public static abstract class Boundary extends Space.Boundary {
        public static final int NONE = 0;
        public static final int PERIODIC_SQUARE = 1;
        public static final int HARD = 2;
        public static final int SLIDING_BRICK = 3;
        public abstract void nearestImage(Vector dr);
        public abstract void centralImage(Vector r);
        public abstract void centralImage(Coordinate c);
    }

    /**
     * Class for implementing no periodic boundary conditions
     */
    protected static final class BoundaryNone extends Boundary {
        private final Vector temp = new Vector();
        private final double[][] shift0 = new double[0][D];
        public final Vector dimensions = new Vector();
        public final Space.Vector dimensions() {return dimensions;}
        public static final Random random = new Random();
        public void nearestImage(Space.Vector dr) {}
        public void centralImage(Space.Vector r) {}
        public void nearestImage(Vector dr) {}
        public void centralImage(Vector r) {}
        public void centralImage(Coordinate c) {}
        public double volume() {return Double.MAX_VALUE;}
        public void inflate(double s) {}
        public double[][] imageOrigins(int nShells) {return new double[0][D];}
        public double[][] getOverflowShifts(Space.Vector rr, double distance) {return shift0;}
        public Space.Vector randomPosition() {  //arbitrary choice for this method in this boundary
            temp.x = random.nextDouble(); 
            temp.y = random.nextDouble(); 
            return temp;
        }
        public void draw(Graphics g, int[] origin, double scale) {}
    }
    
    public static final class BoundaryHard extends BoundaryPeriodicSquare {
        private double collisionRadius = 0.0;
        public BoundaryHard() {super();}
        public BoundaryHard(double lx, double ly) {super(lx,ly);}
        public void nearestImage(Space.Vector dr) {}
        public void centralImage(Space.Vector r) {}
        public void nearestImage(Vector dr) {}
        public void centralImage(Vector r) {}
        public void centralImage(Coordinate c) {}
        public void setCollisionRadius(double d) {collisionRadius = d;}
        public double getCollisionRadius() {return collisionRadius;}

        //the pair passed to this method has both atoms pointing to the same atom
        //use atom1
        public double collisionTime(AtomPair pair) {
            Atom a = pair.atom1();
            Vector r = (Vector)a.coordinate().position();
            Vector p = (Vector)a.coordinate().momentum();
            double tx = (p.x > 0.0) ? (dimensions.x - r.x - collisionRadius)/(p.x*a.rm()) : (-r.x + collisionRadius)/(p.x*a.rm());
            double ty = (p.y > 0.0) ? (dimensions.y - r.y - collisionRadius)/(p.y*a.rm()) : (-r.y + collisionRadius)/(p.y*a.rm());
            return Math.min(tx,ty);
        }
        public void bump(AtomPair pair) {
            Atom a = pair.atom1();
            Vector r = (Vector)a.coordinate().position();
            Vector p = (Vector)a.coordinate().momentum();
            double dx = (p.x > 0.0) ? Math.abs(dimensions.x - r.x - collisionRadius) : Math.abs(-r.x + collisionRadius);
            double dy = (p.y > 0.0) ? Math.abs(dimensions.y - r.y - collisionRadius) : Math.abs(-r.y + collisionRadius);
//            double dx = Math.abs(r.x/dimensions.x-0.5);   //determine which component is further from center
//            double dy = Math.abs(r.y/dimensions.y-0.5);
//            if(dx > dy) {
            if(dx < dy) {
                pAccumulator += 2*Math.abs(p.x);
                p.x *= -1;
            }
            else {
                pAccumulator += 2*Math.abs(p.y);
                p.y *= -1;
            }
        }
        //not yet implemented
        public double lastCollisionVirial() {return 0.0;}

        public double pAccumulator = 0.0;
    }

    /**
     * Class for implementing rectangular periodic boundary conditions
     */
    protected static class BoundaryPeriodicSquare extends Boundary implements PotentialHard {
        private final Vector temp = new Vector();
        public static final Random random = new Random();
        private final double[][] shift0 = new double[0][D];
        private final double[][] shift1 = new double[1][D]; //used by getOverflowShifts
        private final double[][] shift3 = new double[3][D];
        public BoundaryPeriodicSquare() {this(Default.BOX_SIZE,Default.BOX_SIZE);}
        public BoundaryPeriodicSquare(double lx, double ly) {dimensions.x = lx; dimensions.y = ly;}
        public final Vector dimensions = new Vector();
        public final Space.Vector dimensions() {return dimensions;}
        public Space.Vector randomPosition() {
            temp.x = dimensions.x*random.nextDouble(); 
            temp.y = dimensions.y*random.nextDouble(); 
            return temp;}
        public void nearestImage(Space.Vector dr) {nearestImage((Vector)dr);}
        public void nearestImage(Vector dr) {
            dr.x -= dimensions.x * ((dr.x > 0.0) ? Math.floor(dr.x/dimensions.x+0.5) : Math.ceil(dr.x/dimensions.x-0.5));
            dr.y -= dimensions.y * ((dr.y > 0.0) ? Math.floor(dr.y/dimensions.y+0.5) : Math.ceil(dr.y/dimensions.y-0.5));
        }
        public void centralImage(Coordinate c) {centralImage(c.r);}
        public void centralImage(Space.Vector r) {centralImage((Vector)r);}
        public void centralImage(Vector r) {
            r.x -= dimensions.x * ((r.x > 0.0) ? Math.floor(r.x/dimensions.x) : Math.ceil(r.x/dimensions.x-1.0));
            r.y -= dimensions.y * ((r.y > 0.0) ? Math.floor(r.y/dimensions.y) : Math.ceil(r.y/dimensions.y-1.0));
        }
        public void inflate(double scale) {dimensions.TE(scale);}
        public double volume() {return dimensions.x * dimensions.y;}
        
        //"Collision" whenever atom travels half the edge length of the simulation volume
        public double collisionTime(AtomPair pair) {
            Atom a = pair.atom1();
            Vector p = (Vector)a.coordinate().momentum();
            double diameter = ((AtomType.Disk)a.type).diameter();
            //assumes range of potential is .le. diameter, simulation box is square (or x is smaller dimension)
            return 0.5*(dimensions.y-1.0001*diameter)/(a.rm()*Math.sqrt(p.squared()));  
        }
        //No action needed at collision, just want to update neighbor list
        public void bump(AtomPair pair) {}
        public double lastCollisionVirial() {return 0.0;}
        public double energy(AtomPair pair) {return 0.0;}
        public double energyLRC(int n1, int n2, double V) {return 0.0;}
        public boolean overlap(AtomPair pair) {return false;}
        public Simulation getParentSimulation() {return simulation;}
        public void setParentSimulation(Simulation s) {simulation = s;}
        private Simulation simulation;

        public void draw(Graphics g, int[] origin, double scale) {
            g.setColor(Color.gray);
            double toPixels = scale*DisplayConfiguration.SIM2PIXELS;
            g.drawRect(origin[0],origin[1],(int)(toPixels*dimensions.component(0))-1,(int)(toPixels*dimensions.component(1))-1);
            }
        /** Computes origins for periodic images
        */
        public double[][] imageOrigins(int nShells) {
            int nImages = (2*nShells+1)*(2*nShells+1)-1;
            double[][] origins = new double[nImages][D];
            int k = 0;
            for(int i=-nShells; i<=nShells; i++) {
                for(int j=-nShells; j<=nShells; j++) {
                    if(i==0 && j==0) {continue;}
                    origins[k][0] = i*dimensions.x;
                    origins[k][1] = j*dimensions.y;
                    k++;
                }
            }
            return origins;
        }

        /** Returns coordinate shifts needed to draw all images that overflow into central image
         * 0, 1, or 3 shifts may be returned
         */
        public double[][] getOverflowShifts(Space.Vector rr, double distance) {
            Vector r = (Vector)rr;
            int shiftX = 0;
            int shiftY = 0;
            if(r.x-distance < 0.0) {shiftX = +1;}
            else if(r.x+distance > dimensions.x) {shiftX = -1;}
            
            if(r.y-distance < 0.0) {shiftY = +1;}
            else if(r.y+distance > dimensions.y) {shiftY = -1;}
            
            if(shiftX == 0) {
                if(shiftY == 0) {
                    return shift0;
                }
                else {
                    shift1[0][0] = 0.0;
                    shift1[0][1] = shiftY*dimensions.y;
                    return shift1;
                }
            }
            else { //shiftX != 0
                if(shiftY == 0) {
                    shift1[0][0] = shiftX*dimensions.x;
                    shift1[0][1] = 0.0;
                    return shift1;
                }
                else {
                    shift3[0][0] = shiftX*dimensions.x;
                    shift3[0][1] = 0.0;
                    shift3[1][0] = 0.0;
                    shift3[1][1] = shiftY*dimensions.y;
                    shift3[2][0] = shift3[0][0];
                    shift3[2][1] = shift3[1][1];
                    return shift3;
                }
            }
        } //end of getOverflowShifts
    }  //end of BoundaryPeriodicSquare
    
    public static final class BoundarySlidingBrick extends BoundaryPeriodicSquare {
        private double gamma = 0.0;
        private double delvx;
        private IntegratorMD.ChronoMeter timer;
        public void setShearRate(double g) {gamma = g; computeDelvx();}
        public double getShearRate() {return gamma;}
        private void computeDelvx() {delvx = gamma*dimensions.y;}
        
        public void setTimer(IntegratorMD.ChronoMeter t) {timer = t;}
        
        public void nearestImage(Vector dr) {
            double delrx = delvx*timer.currentValue();
            double cory = ((dr.y > 0.0) ? Math.floor(dr.y/dimensions.y+0.5) : Math.ceil(dr.y/dimensions.y-0.5));
            dr.x -= cory*delrx;
            dr.x -= dimensions.x * ((dr.x > 0.0) ? Math.floor(dr.x/dimensions.x+0.5) : Math.ceil(dr.x/dimensions.x-0.5));
            dr.y -= dimensions.y * cory;
        }
        public void centralImage(Vector r) {
            double delrx = delvx*timer.currentValue();
            double cory = ((r.y > 0.0) ? Math.floor(r.y/dimensions.y) : Math.ceil(r.y/dimensions.y-1.0));
//            if(cory != 0.0) System.out.println(delrx*cory);
            r.x -= cory*delrx;
            r.x -= dimensions.x * ((r.x > 0.0) ? Math.floor(r.x/dimensions.x) : Math.ceil(r.x/dimensions.x-1.0));
            r.y -= dimensions.y * cory;
        }
        public void centralImage(Coordinate c) {
            Vector r = c.r;
            double cory = ((r.y > 0.0) ? Math.floor(r.y/dimensions.y) : Math.ceil(r.y/dimensions.y-1.0));
            double corx = ((r.x > 0.0) ? Math.floor(r.x/dimensions.x) : Math.ceil(r.x/dimensions.x-1.0));
            if(corx==0.0 && cory==0.0) return;
            double delrx = delvx*timer.currentValue();
            Vector p = c.p;
            r.x -= cory*delrx;
            r.x -= dimensions.x * corx; 
            r.y -= dimensions.y * cory;
            p.x -= cory*delvx;
        }
        
        public double[][] imageOrigins(int nShells) {
            int nImages = (2*nShells+1)*(2*nShells+1)-1;
            double[][] origins = new double[nImages][D];
            int k = 0;
            for(int i=-nShells; i<=nShells; i++) {
                for(int j=-nShells; j<=nShells; j++) {
                    if(i==0 && j==0) {continue;}
                    origins[k][0] = i*dimensions.x + j*delvx*timer.currentValue();
                    origins[k][1] = j*dimensions.y;
                    k++;
                }
            }
            return origins;
        }
    }
            
}