#[compute]
#version 450

// Tints `original` with `color`, driven by the plain average of its RGB
// channels: a black input stays black, an average of exactly 0.5 returns
// `color`, and a white input stays white. Both halves interpolate linearly.
vec3 colorize( vec3 original, vec3 color )
{
  float brightness = ( original.r + original.g + original.b ) / 3.0;

  // Dark half: fade from black up to the tint colour.
  if ( brightness < 0.5 )
  {
    return mix( vec3( 0.0 ), color, brightness * 2.0 );
  }

  // Bright half: fade from the tint colour up to white.
  return mix( color, vec3( 1.0 ), ( brightness - 0.5 ) * 2.0 );
}

// Saturates `value` into the closed interval [0, 1].
// (GLSL defines clamp(x, lo, hi) as min(max(x, lo), hi); spelled out here.)
float clamp01( float value )
{
  float atLeastZero = max( value, 0.0 );
  return min( atLeastZero, 1.0 );
}

// Returns where `value` lies relative to [rangeMin, rangeMax] as a fraction:
// 0.0 at rangeMin and 1.0 at rangeMax. Values outside the range extrapolate
// linearly. A reversed range (rangeMin > rangeMax) simply inverts the mapping.
//
// Zero-width range (rangeMin == rangeMax): the plain formula divides by zero.
// That gives +/-Inf, or NaN when value == rangeMin, and clamp() of NaN is
// undefined. Instead we return a hard step at rangeMin (0.0 below, 1.0 at or
// above). This is what a clamped result converges to as the range shrinks.
//
// Parameters are deliberately not named `min`/`max`, which would hide the
// built-in functions of the same name inside this body.
float normalizeToRange( float value, float rangeMin, float rangeMax )
{
  float width = rangeMax - rangeMin;

  if ( width == 0.0 )
  {
    return step( rangeMin, value );
  }

  return ( value - rangeMin ) / width;
}

// Position of `value` within [rangeMin, rangeMax] as a fraction, saturated to
// [0, 1]. Inputs below the range give 0.0 and inputs above it give 1.0.
float normalizeToRange01( float value, float rangeMin, float rangeMax )
{
  float fraction = normalizeToRange( value, rangeMin, rangeMax );
  return clamp01( fraction );
}

// Linearly remaps `value` from [inMin, inMax] onto [outMin, outMax].
// No clamping: inputs outside the source range extrapolate past the target.
float map( float value, float inMin, float inMax, float outMin, float outMax )
{
  float t = normalizeToRange( value, inMin, inMax );
  return mix( outMin, outMax, t );
}

// Linearly remaps `value` from [inMin, inMax] onto [outMin, outMax]. The
// interpolation factor is saturated first, so the result never leaves the
// target range.
float mapClamped( float value, float inMin, float inMax, float outMin, float outMax )
{
  float t = normalizeToRange01( value, inMin, inMax );
  return mix( outMin, outMax, t );
}

// Rotates `uv` about the origin by `angle` radians. A positive angle turns
// ( 1, 0 ) toward ( 0, 1 ), i.e. counter-clockwise when +y points up.
vec2 rotate_v2( vec2 uv, float angle )
{
  float sine   = sin( angle );
  float cosine = cos( angle );

  return vec2( cosine * uv.x - sine * uv.y,
               sine * uv.x + cosine * uv.y );
}

// Rotates `uv` by `angle` radians around `pivot` rather than the origin:
// translate so the pivot sits at the origin, rotate, then translate back.
vec2 rotateAround_v2( vec2 uv, float angle, vec2 pivot )
{
  return rotate_v2( uv - pivot, angle ) + pivot;
}


// 8x8 invocations per work group; main() discards those past the image edge.
layout( local_size_x = 8, local_size_y = 8, local_size_z = 1 ) in;

// Source image. This shader only reads it (imageSize / imageLoad).
layout( rgba16f, set = 0, binding = 0 ) 
uniform readonly image2D inputImage;

// Destination image. This shader only writes it (imageStore).
layout( rgba16f, set = 1, binding = 0 ) 
uniform writeonly image2D outputImage;

// layout( rgba16f, set = 2, binding = 0 ) 
// uniform image2D noiseImage;

// Per-dispatch parameters. Field order and types define the std430 layout,
// so the host-side push-constant buffer must match it exactly.
layout( push_constant, std430 ) 
uniform Parameters 
{
  // Tint colour fed into every blend mode in main().
  float r;
  float g;
  float b;
  // Not used as colour alpha: main() uses it as the centre of the fade
  // threshold band.
  float a;
  // Weights for the replace / add / multiply / colorize results, which
  // main() sums.
  float replace;
  float add;
  float multiply;
  float colorize;

  // Centre of the radial fade (compared against aspect-corrected coordinates
  // in main()).
  float centerX;
  float centerY;
  // Width control for the fade band; main() uses a +/- 2 * fade.
  float fade;

} parameters;

// Currently unused by this shader; const so it cannot be modified by accident.
const float PI = 3.141592653589793;


// Maps a point from image UV space ([0, 1] on both axes) into the
// aspect-corrected space used for the radial fade. y is scaled about 0.5 by
// height / width, so distances along both axes are in units of image width.
vec2 toAspectSpace( vec2 point, float aspect )
{
  return vec2( point.x, ( point.y - 0.5 ) * aspect + 0.5 );
}

// One invocation per texel. Builds a weighted mix of tint blend modes and
// blends it over the input texel. The blend amount (alpha) comes from a radial
// fade around ( centerX, centerY ): texels whose fade value falls below the
// threshold band get the full effect, and texels above it keep the input.
void main( ) 
{
  ivec2 size = imageSize( inputImage );
  ivec2 texel_coord = ivec2( gl_GlobalInvocationID.xy );

  // The dispatch covers whole 8x8 groups; skip invocations past the image edge.
  if ( any( greaterThanEqual( texel_coord, size ) ) ) 
  {
    return;
  }

  // Texel centre in normalized [0, 1] coordinates.
  vec2 uv = ( vec2( texel_coord ) + 0.5 ) / vec2( size );

  float aspect = float( size.y ) / float( size.x );
  vec2 circleUV = toAspectSpace( uv, aspect );

  // NOTE(review): center is compared directly with aspect-corrected
  // coordinates, so centerY is interpreted in that space, not plain UV.
  vec2 center = vec2( parameters.centerX, parameters.centerY );

  // Distance to the farthest image corner. The corners must pass through the
  // same aspect correction as circleUV. Otherwise the normalized fade below
  // never reaches 0 on landscape images and goes negative on portrait ones.
  float d0 = distance( center, toAspectSpace( vec2( 0.0, 0.0 ), aspect ) );
  float d1 = distance( center, toAspectSpace( vec2( 1.0, 0.0 ), aspect ) );
  float d2 = distance( center, toAspectSpace( vec2( 0.0, 1.0 ), aspect ) );
  float d3 = distance( center, toAspectSpace( vec2( 1.0, 1.0 ), aspect ) );

  float max_radius = max( max( d0, d1 ), max( d2, d3 ) );

  // 1.0 at the centre, falling linearly to 0.0 at the farthest corner.
  float radius = distance( circleUV, center );
  float fadeValue = 1.0 - radius / max_radius;

  // Compress the fade into [2 * fade, 1 - 2 * fade] before thresholding.
  float minFade = parameters.fade * 2.0;
  float maxFade = 1.0 - parameters.fade * 2.0;
  fadeValue = mapClamped( fadeValue, 0.0, 1.0, minFade, maxFade );

  // Threshold band centred on parameters.a. alpha is 1.0 below the band and
  // 0.0 above it, with a linear ramp inside.
  float fadeStart = parameters.a - parameters.fade * 2.0;
  float fadeEnd   = parameters.a + parameters.fade * 2.0;
  float alpha = mapClamped( fadeValue, fadeStart, fadeEnd, 1.0, 0.0 );

  // uv was built from texel centres, so the source texel is texel_coord itself.
  vec4 color = imageLoad( inputImage, texel_coord );
  vec4 replaced = vec4( parameters.r, parameters.g, parameters.b, 1.0 );

  // Candidate blend modes, combined as a weighted sum.
  vec4 multiplied = replaced * color;
  vec4 added      = replaced + color;
  vec4 colorized  = vec4( colorize( color.rgb, replaced.rgb ), 1.0 );

  vec4 mixed = replaced   * parameters.replace +
               added      * parameters.add +
               multiplied * parameters.multiply +
               colorized  * parameters.colorize;

  vec4 blended = mix( color, mixed, alpha );

  imageStore( outputImage, texel_coord, blended );
}