#[compute]
#version 450

// Tints `original` with `color` using the average of its RGB channels as a
// luminance ramp: grey 0 -> black, grey 0.5 -> exactly `color`, grey 1 -> white,
// interpolated linearly on each half.
vec3 colorize( vec3 original, vec3 color ) 
{
  float grey = ( original.r + original.g + original.b ) / 3.0;

  if ( grey < 0.5 )
  {
    // Dark half: ramp from black up to the tint colour.
    return mix( vec3( 0.0 ), color, grey * 2.0 );
  }

  // Bright half: ramp from the tint colour up to white.
  float towardsWhite = ( grey - 0.5 ) * 2.0;
  return mix( color, vec3( 1.0 ), towardsWhite );
}
 
// Limits `value` to the closed interval [0, 1].
// GLSL defines clamp( x, lo, hi ) as min( max( x, lo ), hi ); it is spelled out here.
float clamp01( float value )
{
  float atLeastZero = max( value, 0.0 );
  return min( atLeastZero, 1.0 );
}

// Returns where `value` sits in [min, max] as a fraction: 0 at min, 1 at max.
// The result is NOT clamped, so values outside the range extrapolate.
// When min == max the division is by zero and the result is not meaningful.
float normalizeToRange( float value, float min, float max )
{
  float offset = value - min;
  float span   = max - min;

  return offset / span;
}

// Returns where `value` sits in [min, max] as a fraction clamped to [0, 1].
//
// A zero-width range (min == max) is treated as a hard edge. The result is 0
// when value <= min and 1 when value > min, which is the limit of the general
// formula as the range width shrinks to zero.
// Without this guard the division below is by zero, yielding Inf/NaN, and
// clamp() of NaN is undefined. main() reaches this case whenever
// parameters.fade == 0.
float normalizeToRange01( float value, float min, float max )
{
  if ( max == min )
  {
    return value > min ? 1.0 : 0.0;
  }

  return clamp01( normalizeToRange( value, min, max ) );
}

// Linearly remaps `value` from [inMin, inMax] to [outMin, outMax].
// The result is not clamped; inputs outside the input range extrapolate.
float map( float value, float inMin, float inMax, float outMin, float outMax )
{
  float t = normalizeToRange( value, inMin, inMax );
  return mix( outMin, outMax, t );
}

// Linearly remaps `value` from [inMin, inMax] to [outMin, outMax].
// The normalised position is clamped to [0, 1] first, so the output always
// lies between outMin and outMax (in either order).
float mapClamped( float value, float inMin, float inMax, float outMin, float outMax )
{
  float t = normalizeToRange01( value, inMin, inMax );
  return mix( outMin, outMax, t );
}

// Rotates `uv` about the origin by `angle` radians.
// A positive angle turns +x towards +y.
vec2 rotate_v2( vec2 uv, float angle )
{
  float c = cos( angle );
  float s = sin( angle );

  return vec2( c * uv.x - s * uv.y,
               s * uv.x + c * uv.y );
}

// Rotates `uv` by `angle` radians around `pivot` instead of the origin.
// It translates so the pivot is at the origin, rotates, then translates back.
vec2 rotateAround_v2( vec2 uv, float angle, vec2 pivot )
{
  return rotate_v2( uv - pivot, angle ) + pivot;
}


// Each work group covers an 8x8 tile of texels. main() returns early for
// invocations whose texel lies outside the image.
layout( local_size_x = 8, local_size_y = 8, local_size_z = 1 ) in;

// Source image. main() reads it with imageLoad() and queries its size.
layout( rgba16f, set = 0, binding = 0 ) 
uniform image2D inputImage;

// Destination image, written with imageStore() in main().
// NOTE(review): only inputImage's size is checked, so this presumably must be at
// least as large as inputImage — confirm on the host side.
layout( rgba16f, set = 1, binding = 0 ) 
uniform image2D outputImage;

// layout( rgba16f, set = 2, binding = 0 ) 
// uniform image2D noiseImage;

// Per-dispatch parameters. Under std430 these 11 floats pack into 44 bytes with
// no padding; the host-side push-constant data must match this layout.
layout( push_constant, std430 ) 
uniform Parameters 
{
  float r;           // Tint colour, red channel.
  float g;           // Tint colour, green channel.
  float b;           // Tint colour, blue channel.
  float a;           // Not an alpha: main() uses it as the position of the fade edge along the gradient.
  float replace;     // Weight of the flat tint colour in the effect mix.
  float add;         // Weight of ( tint + source ) in the effect mix.
  float multiply;    // Weight of ( tint * source ) in the effect mix.
  float colorize;    // Weight of colorize( source, tint ) in the effect mix.
  
  float angle;       // Gradient direction in degrees (computeFadeValue converts with PI / 180).
  float fade;        // Softness of the transition; also insets the gradient to [fade, 1 - fade].
  float doubleSided; // Treated as a boolean (> 0.5): evaluate the gradient from angle and angle + 180.

} parameters;

// Pi, used to convert degrees to radians. Declared const so it is a
// compile-time constant rather than a mutable per-invocation global.
const float PI = 3.141592653589793;

// Computes a linear gradient value for `uv` along the direction `angle`
// (degrees).
//
// The projection of `uv` onto the direction is normalised so that the unit
// square's extreme corners map to 0 and 1. That value is then remapped
// (clamped) into [parameters.fade, 1 - parameters.fade].
float computeFadeValue( vec2 uv, float angle )
{
  float radiansAngle = angle * PI / 180.0;
  vec2 direction = vec2( cos( radiansAngle ), sin( radiansAngle ) );

  // The corners of the unit square project to 0, dx, dy and dx + dy.
  // Their minimum (maximum) is the sum of each axis's own minimum (maximum)
  // contribution, where an axis contributes either 0 or its component.
  float lowest  = min( 0.0, direction.x ) + min( 0.0, direction.y );
  float highest = max( 0.0, direction.x ) + max( 0.0, direction.y );

  float gradient = ( dot( uv, direction ) - lowest ) / ( highest - lowest );

  float insetLow  = parameters.fade;
  float insetHigh = 1.0 - parameters.fade;

  return mapClamped( gradient, 0.0, 1.0, insetLow, insetHigh );
}

// Applies a tint to inputImage that fades in along a linear gradient, and
// writes the result to outputImage.
//
// For each texel:
//   1. A blend weight `alpha` is derived from the gradient position (one- or
//      two-sided) and the parameters.a / parameters.fade edge.
//   2. The source colour is combined with the tint (r, g, b) via a weighted
//      sum of replace / add / multiply / colorize.
//   3. Source and effect are blended by `alpha`.
void main( ) 
{
  ivec2 size = imageSize( inputImage );
  ivec2 texel_coord = ivec2( gl_GlobalInvocationID.xy );

  // Work groups are 8x8, so when the image size is not a multiple of 8 some
  // invocations land outside the image and must do nothing.
  if ( any( greaterThanEqual( texel_coord, size ) ) ) 
  {
    return;
  }

  // Texel-centre coordinate normalised to (0, 1), used only to evaluate the gradient.
  vec2 uv = ( vec2( texel_coord ) + 0.5 ) / vec2( size );

  float alpha = 0.0;

  if ( parameters.doubleSided > 0.5 )
  {
    // Evaluate the gradient from both ends and keep the stronger weight.
    float fadeValueA = computeFadeValue( uv, parameters.angle );
    float fadeValueB = computeFadeValue( uv, parameters.angle + 180.0 );

    float fadeRange = 0.5 + parameters.fade;
    float fadeStart = parameters.a * fadeRange - parameters.fade * 0.5;
    float fadeEnd   = parameters.a * fadeRange + parameters.fade * 0.5;

    float alphaA = mapClamped( fadeValueA, fadeStart, fadeEnd, 1.0, 0.0 );
    float alphaB = mapClamped( fadeValueB, fadeStart, fadeEnd, 1.0, 0.0 );

    alpha = max( alphaA, alphaB );
  }
  else
  {
    float fadeValue = computeFadeValue( uv, parameters.angle );

    float fadeStart = parameters.a - parameters.fade;
    float fadeEnd   = parameters.a + parameters.fade;

    alpha = mapClamped( fadeValue, fadeStart, fadeEnd, 1.0, 0.0 );
  }

  // Load this invocation's own texel directly. The previous form,
  // ivec2( uv * size ), rebuilt it through a float divide/multiply round-trip
  // that only happens to truncate back to texel_coord.
  vec4 color = imageLoad( inputImage, texel_coord );
  vec4 replaced = vec4( parameters.r, parameters.g, parameters.b, 1.0 );

  vec4 multiplied = replaced * color;
  vec4 added = replaced + color;
  vec4 colorized = vec4( colorize( color.rgb, replaced.rgb ), 1.0 );

  // NOTE(review): the alpha channels differ per term. `added.a` is 1.0 + color.a,
  // `multiplied.a` is color.a, and the others are 1.0. As a result the written
  // alpha can exceed 1 when parameters.add > 0 — confirm this is intended.
  vec4 mixed = replaced * parameters.replace + 
                added * parameters.add + 
                multiplied * parameters.multiply + 
                colorized * parameters.colorize;

  vec4 blended = mix( color, mixed, alpha );

  imageStore( outputImage, texel_coord, blended );
}