#[compute]
#version 450

// Tints a colour with a three-point gradient keyed on its average brightness:
// black -> `color` -> white.
// Inputs averaging below 0.5 are blended from black towards `color`; inputs
// averaging 0.5 or more are blended from `color` towards white.
vec3 colorize( vec3 original, vec3 color )
{
  float brightness = ( original.r + original.g + original.b ) / 3.0;

  if ( brightness < 0.5 )
  {
    // Dark half: 0 -> 0.5 maps onto black -> color.
    return mix( vec3( 0.0 ), color, brightness * 2.0 );
  }

  // Bright half: 0.5 -> 1 maps onto color -> white.
  return mix( color, vec3( 1.0 ), ( brightness - 0.5 ) * 2.0 );
}

// Saturates `value` to the unit interval [0, 1].
// Written as min(max(...)), which is how the GLSL spec defines clamp().
float clamp01( float value )
{
  float atLeastZero = max( value, 0.0 );
  return min( atLeastZero, 1.0 );
}

// Linearly maps `value` so that rangeMin -> 0 and rangeMax -> 1 (unclamped;
// values outside the range extrapolate, and rangeMax < rangeMin inverts).
// The parameters are deliberately not called min/max, which would hide the
// built-in functions of the same name inside this body.
// NOTE(review): no guard for rangeMin == rangeMax; that divides by zero.
float normalizeToRange( float value, float rangeMin, float rangeMax )
{
  float offset = value - rangeMin;
  float span   = rangeMax - rangeMin;
  return offset / span;
}

// Like normalizeToRange(), but clamped to [0, 1]: 0 at or below rangeMin,
// 1 at or above rangeMax, linear in between (an inverted range inverts).
//
// A degenerate range (rangeMin == rangeMax) would divide by zero; GLSL leaves
// 0/0 and clamp() of NaN undefined. It is therefore treated as the limit of
// an ever-narrower ramp: a hard step at rangeMax (0 below, 1 at/above).
float normalizeToRange01( float value, float rangeMin, float rangeMax )
{
  if ( rangeMax - rangeMin == 0.0 )
  {
    return step( rangeMax, value );
  }

  return clamp01( normalizeToRange( value, rangeMin, rangeMax ) );
}

// Remaps `value` from [inMin, inMax] onto [outMin, outMax] without clamping,
// so inputs outside the input range extrapolate beyond the output range.
float map( float value, float inMin, float inMax, float outMin, float outMax )
{
  float t = normalizeToRange( value, inMin, inMax );
  return mix( outMin, outMax, t );
}

// Remaps `value` from [inMin, inMax] onto [outMin, outMax], clamping the
// interpolation factor to [0, 1] so the result stays between outMin and outMax.
float mapClamped( float value, float inMin, float inMax, float outMin, float outMax )
{
  float t = normalizeToRange01( value, inMin, inMax );
  return mix( outMin, outMax, t );
}




// Cheap sin-based hash: maps a 2D coordinate to a pseudo-random value in [0, 1).
// Quality depends on the GPU's sin() precision; not suitable for anything
// that needs well-distributed random numbers.
float random( vec2 uv )
{
  const vec2  kSeed  = vec2( 12.9898, 78.233 );
  const float kScale = 43758.5453123;

  float phase = dot( uv, kSeed );
  return fract( sin( phase ) * kScale );
}

// 2D sin-based hash: returns a pseudo-random vector with each component in
// [-1, 1). Used by perlin() as the per-lattice-point gradient.
vec2 random_v2( vec2 uv )
{
  float hx = dot( uv, vec2( 127.1, 311.7 ) );
  float hy = dot( uv, vec2( 269.5, 183.3 ) );

  vec2 unit = fract( sin( vec2( hx, hy ) ) * 43758.5453123 );
  return unit * 2.0 - 1.0;
}


// 3D variant of random_v2(): three sin-hash projections of `uvw`, each
// component remapped to [-1, 1).
vec3 random_v3( vec3 uvw )
{
  vec3 projected = vec3
  (
    dot( uvw, vec3( 127.1, 311.7, 513.7 ) ),
    dot( uvw, vec3( 269.5, 183.3, 396.5 ) ),
    dot( uvw, vec3( 421.3, 314.1, 119.7 ) )
  );

  vec3 unit = fract( sin( projected ) * 43758.5453123 );
  return unit * 2.0 - 1.0;
}

// 2D gradient ("Perlin-style") noise.
// A pseudo-random gradient from random_v2() sits on each integer lattice point.
// Each corner contributes dot(gradient, offset-to-point), and the four
// contributions are blended with a cubic smoothstep. The result is shifted by
// +0.5 so it is centred on 0.5. The gradients are not normalised, so the output
// is not guaranteed to stay inside [0, 1].
float perlin( vec2 uv )
{
  vec2 cell  = floor( uv );
  vec2 inner = fract( uv );

  const vec2 c00 = vec2( 0.0, 0.0 );
  const vec2 c10 = vec2( 1.0, 0.0 );
  const vec2 c01 = vec2( 0.0, 1.0 );
  const vec2 c11 = vec2( 1.0, 1.0 );

  float d00 = dot( random_v2( cell + c00 ), inner - c00 );
  float d10 = dot( random_v2( cell + c10 ), inner - c10 );
  float d01 = dot( random_v2( cell + c01 ), inner - c01 );
  float d11 = dot( random_v2( cell + c11 ), inner - c11 );

  vec2 weight = smoothstep( 0.0, 1.0, inner );

  float row0 = mix( d00, d10, weight.x );
  float row1 = mix( d01, d11, weight.x );

  return mix( row0, row1, weight.y ) + 0.5;
}

// Fractal sum of perlin() layers, normalised by the total amplitude.
// Each successive layer multiplies the sampling frequency by `scale` and the
// amplitude by `gain`.
// NOTE(review): the base layer is always sampled and `octaves` counts the
// layers added on top of it, so octaves = N evaluates N + 1 layers (a negative
// count behaves like 0). Confirm this is intended before relying on the name.
float perlinOctaves( vec2 uv, int octaves, float scale, float gain )
{
  int layers = max( octaves, 0 ) + 1;

  float frequency = 1.0;
  float amplitude = 1.0;
  float total     = 0.0;
  float weight    = 0.0;

  for ( int layer = 0; layer < layers; ++layer )
  {
    total  += perlin( uv * frequency ) * amplitude;
    weight += amplitude;

    frequency *= scale;
    amplitude *= gain;
  }

  return total / weight;
}

// Shifts `uv` back by seamFade on the axes selected by `type` (each component
// 0 or 1), then wraps the result into [0, seamRange).
vec2 seamLessCoordinate( vec2 uv, vec2 seamRange, vec2 seamFade, vec2 type )
{
  vec2 shift   = seamFade * type;
  vec2 shifted = uv - shift;
  return mod( shifted, seamRange );
}

// Tileable fractal noise: perlinOctaves() made periodic with period
// `seamRange` on both axes.
//
// uv        : noise-space coordinate (any range; it is wrapped into one tile).
// seamRange : tile size, in uv units.
// seamFade  : width of the cross-fade band as a FRACTION of seamRange
//             (e.g. 0.2 fades over the last 20% of each tile).
//
// In the last fade band of a tile, the noise is cross-faded with a partner
// sample taken one full tile earlier (tileUV - seamRange). At the tile edge the
// partner coordinate reaches 0. That is exactly where the next tile starts, so
// the result is continuous across the seam.
// The partner must NOT be wrapped with mod(): that would map it back onto
// tileUV (or, with a fade-sized offset, onto noise(seamRange - fade)) and the
// seam would reappear.
float perlinOctavesSeamless( vec2 uv, int octaves, float scale, float gain, vec2 seamRange, vec2 seamFade )
{
  vec2 fadeWidth = seamFade * seamRange;
  vec2 tileUV    = mod( uv, seamRange );

  // 0 over most of the tile, ramping to 1 across the final fadeWidth units.
  vec2 fading;
  fading.x = normalizeToRange01( tileUV.x, seamRange.x - fadeWidth.x, seamRange.x );
  fading.y = normalizeToRange01( tileUV.y, seamRange.y - fadeWidth.y, seamRange.y );

  // Partner samples: one tile back along each faded axis, deliberately unwrapped.
  vec2 p00 = tileUV;
  vec2 p10 = tileUV - vec2( seamRange.x, 0.0 );
  vec2 p01 = tileUV - vec2( 0.0, seamRange.y );
  vec2 p11 = tileUV - seamRange;

  float n00 = perlinOctaves( p00, octaves, scale, gain );
  float n10 = perlinOctaves( p10, octaves, scale, gain );
  float n01 = perlinOctaves( p01, octaves, scale, gain );
  float n11 = perlinOctaves( p11, octaves, scale, gain );

  float n0 = mix( n00, n10, fading.x );
  float n1 = mix( n01, n11, fading.x );

  return mix( n0, n1, fading.y );
}



// Each work group covers an 8x8 tile of pixels. main() discards invocations
// that fall outside the image, so the image size need not be a multiple of 8.
layout( local_size_x = 8, local_size_y = 8, local_size_z = 1 ) in;

// Source image, read with imageLoad() in main().
layout( rgba16f, set = 0, binding = 0 ) 
uniform image2D inputImage;

// Destination image, written with imageStore() in main(); bound in descriptor set 1.
layout( rgba16f, set = 1, binding = 0 ) 
uniform image2D outputImage;

// Disabled binding for an external noise texture; main() has a matching
// commented-out sampling line.
// layout( rgba16f, set = 2, binding = 0 ) 
// uniform image2D noiseImage;

// Per-dispatch parameters: twelve floats laid out with std430 rules. Member
// offsets follow declaration order, so the host-side data (not visible here)
// must be packed in exactly this order.
layout( push_constant, std430 ) 
uniform Parameters 
{
  float r;        // tint colour, red (becomes replaced.r in main)
  float g;        // tint colour, green
  float b;        // tint colour, blue
  float a;        // blend strength; also the threshold the noise is compared against
  float replace;  // weight of the "replace with tint" result
  float add;      // weight of the "tint + source" result
  float multiply; // weight of the "tint * source" result
  float colorize; // weight of the colorize( source, tint ) result

  float noiseAmount; // 0 = uniform alpha `a`, 1 = noise-thresholded alpha
  float noiseScale;  // multiplier applied to uv before sampling the noise
  float noiseX;      // noise-space offset, x
  float noiseY;      // noise-space offset, y

} parameters;

// Per-pixel entry point.
// 1. Blends the source pixel towards a weighted mix of four tint modes
//    (replace / add / multiply / colorize).
// 2. Sets the blend amount either uniformly from `a`, or from a soft threshold
//    of perlin noise against `a`; noiseAmount chooses between the two.
void main( )
{
    ivec2 size = imageSize( inputImage );
    ivec2 texel_coord = ivec2( gl_GlobalInvocationID.xy );

    // Work groups are 8x8, so invocations past the image edge can exist when the
    // image size is not a multiple of 8; skip them.
    if ( any( greaterThanEqual( texel_coord, size ) ) )
    {
      return;
    }

    // Texel-centre uv in (0, 1).
    vec2 uv = ( vec2( texel_coord ) + 0.5 ) / vec2( size );

    // Aspect correction (short side = 1) so noise cells stay square on
    // non-square images.
    vec2 ratio = vec2( size ) / float( min( size.x, size.y ) );

    vec2 noiseUV = uv * parameters.noiseScale * ratio + vec2( parameters.noiseX, parameters.noiseY );
    float noise = perlin( noiseUV );

    // Seamless alternative (disabled). perlinOctavesSeamless() multiplies
    // seamFade by seamRange, so seamFade is a fraction of the tile (e.g. 0.2),
    // not an absolute width. Arguments: octaves, scale, gain, seamRange, seamFade.
    // float noise = perlinOctavesSeamless( noiseUV, 3, 0.5, 0.5, vec2( 100.0 ), vec2( 0.2 ) );
    //
    // Texture alternative (disabled; needs the noiseImage binding):
    // float noise = imageLoad( noiseImage, ivec2( uv * imageSize( noiseImage ) ) ).r;

    // Read the source pixel directly by its integer coordinate.
    vec4 color = imageLoad( inputImage, texel_coord );

    vec4 replaced   = vec4( parameters.r, parameters.g, parameters.b, 1.0 );
    vec4 multiplied = replaced * color;
    vec4 added      = replaced + color;
    vec4 colorized  = vec4( colorize( color.rgb, replaced.rgb ), 1.0 );

    // Weighted sum of the four modes. The weights are used as given and are
    // not normalised.
    vec4 mixed = replaced * parameters.replace +
                 added * parameters.add +
                 multiplied * parameters.multiply +
                 colorized * parameters.colorize;

    float noiseAmount = parameters.noiseAmount;

    // Soft threshold: ~1 where the noise is below `a`, ~0 above, with a
    // 0.1-wide transition centred on `a`.
    float noiseAlpha = 1.0 - smoothstep( parameters.a - 0.05, parameters.a + 0.05, noise );

    // Near the extremes of `a`, push noiseAlpha to fully off (a -> 0) or fully
    // on (a -> 1), so the noise cannot leave stray pixels at 0 or 1.
    float edgeFade = 0.1;
    float startMix = mapClamped( parameters.a, 0.0, edgeFade, 1.0, 0.0 );
    float endMix   = mapClamped( parameters.a, 1.0 - edgeFade, 1.0, 0.0, 1.0 );

    noiseAlpha = mix( noiseAlpha, 0.0, startMix );
    noiseAlpha = mix( noiseAlpha, 1.0, endMix );

    float alpha = mix( parameters.a, noiseAlpha, noiseAmount );

    vec4 blended = mix( color, mixed, alpha );
    // blended = clamp( blended, vec4( 0.0, 0.0, 0.0, 1.0 ), vec4( 1.0 ) );

    imageStore( outputImage, texel_coord, blended );
}