#include "timestep.h"
#include "variables.h"
#include "grid.h"
#include "exit-codes.h"
#include <stdio.h>
#include <math.h>

#ifndef SOLVER
#define SOLVER

/* Physical flux F(U) of one conservative state U = (u1, u2, u3, u4):

     F1 = u2
     F2 = u2^2/u1 + (sos^2/kappa) * u1
     F3 = u4
     F4 = u4^2/u3 + (sos^2/kappa) * u1*u3/(u3def - u3) + sigma

   sos, kappa, u3def and sigma are not defined in this function
   (presumably globals from variables.h -- confirm).

   u1 and u3 appear as divisors, so the program is terminated with
   EXIT_U1_EQ_0 / EXIT_U3_EQ_0 when either is exactly zero.  */
ConservVars
cell_flux(ConservVars vars)
{
  /* These are model errors, not failed library calls: perror() would
     append ": <strerror(errno)>" with a meaningless errno, so write the
     message to stderr directly.  */
  if (vars.u1 == 0)
    {
      fputs("cell_flux: u1 == 0\n", stderr);
      exit(EXIT_U1_EQ_0);
    }
  if (vars.u3 == 0)
    {
      fputs("cell_flux: u3 == 0\n", stderr);
      exit(EXIT_U3_EQ_0);
    }
  /* NOTE(review): (u3def - u3) is also a divisor and is not checked;
     u3 == u3def yields inf/nan in the last component.  */
  return (ConservVars){
    vars.u2, pow(vars.u2,2)/vars.u1 + pow(sos,2)/kappa*vars.u1,
    vars.u4, pow(vars.u4,2)/vars.u3 + pow(sos,2)/kappa * vars.u1*vars.u3/(u3def-vars.u3) + sigma
  };
}

/* Numerical flux through face i, the left face of cell i, for
   i = 0 .. grid.N.

   Face i lies between a left state and a right state:
     left  = grid.lbc.vars       when i == 0,      else grid.cells[i-1].vars
     right = grid.rbc.vars       when i == grid.N, else grid.cells[i].vars
   so i == grid.N is the right face of the last cell (the right boundary).

   The flux has the Lax-Friedrichs form
     F = 1/2 (F(U_L) + F(U_R)) - h/(2 tau) (U_R - U_L)
   with F(.) given by cell_flux().  */
ConservVars
count_left_flux(Grid1D grid, unsigned int i, double tau)
{
  ConservVars left_shot = (i == 0) ? grid.lbc.vars : grid.cells[i-1].vars;
  ConservVars right_shot = (i == grid.N) ? grid.rbc.vars : grid.cells[i].vars;

  return 0.5*(cell_flux(left_shot)+cell_flux(right_shot)) - 0.5*grid.h*(right_shot-left_shot)/tau;
}

/* Advance the solution by one time step of length tau using the
   conservative update

     U_i^{n+1} = U_i^n - tau/h * (F_{i+1/2} - F_{i-1/2}),   i = 0 .. N-1,

   where F_{i-1/2} = count_left_flux(grid, i, tau).  The returned grid has
   iter incremented by one and time advanced by tau.

   Ownership: the result is built by grid_skel(grid) (presumably with a
   freshly allocated cell array -- confirm in grid.h).  The input's
   grid.cells is freed here, so the caller must not use it afterwards.  */
Grid1D
make_step(Grid1D grid, double tau)
{
  Grid1D newgrid = grid_skel(grid);
  newgrid.iter += 1;
  newgrid.time += tau;

  double h = grid.h;

  if (grid.N > 0)
    {
      /* Each face flux is computed once: the right flux of cell i is
         reused as the left flux of cell i+1.  */
      ConservVars rf = count_left_flux(grid, 0, tau);
      for (unsigned int i = 0; i < grid.N; i++)
        {
          ConservVars lf = rf;
          /* For the last cell, i+1 == grid.N selects the right boundary
             face, so every cell (including the last) is updated.  */
          rf = count_left_flux(grid, i+1, tau);
          newgrid.cells[i].vars = grid.cells[i].vars - tau*(rf-lf)/h;
        }
    }

  free(grid.cells);
  return newgrid;
}

#endif // SOLVER