#[compute]
#version 450

// Workgroup shape: 8 x 8 x 1 invocations; main( ) maps each invocation's
// global XY id to one pixel coordinate.
layout( local_size_x = 8, local_size_y = 8, local_size_z = 1 ) in;

// Image that main( ) reads with imageLoad (rgba16f, set 0 / binding 0).
layout( rgba16f, set = 0, binding = 0 ) uniform image2D inputImage;

// Image that main( ) writes with imageStore (rgba16f, set 1 / binding 0).
layout( rgba16f, set = 1, binding = 0 ) uniform image2D outputImage;


// Push-constant parameters consumed by main( ).
layout( push_constant, std430 ) uniform Parameters
{
  // Blend factor handed to mix( ): 0.0 yields the original pixel,
  // 1.0 yields the fully processed pixel.
  float amount;

  // Upper bound on the length of the RGB difference between the original
  // and the sharpened pixel; larger differences are scaled back to this length.
  float maxDifference;
} parameters;


// Loads a texel of inputImage at `coord`, clamping the coordinate to
// [0, size - 1] so that taps next to the border reuse the nearest edge texel
// instead of reading outside the image.
vec4 loadClamped( ivec2 coord, ivec2 size )
{
  return imageLoad( inputImage, clamp( coord, ivec2( 0 ), size - ivec2( 1 ) ) );
}

// Sharpens inputImage into outputImage, one invocation per pixel.
//
// Steps:
//   1. Sum the 8 surrounding texels (edge-clamped at the image border).
//   2. Sharpened RGB = clamp( 9 * original - neighbourSum, 0, 1 ).
//   3. If the RGB change is longer than parameters.maxDifference, shorten it
//      to exactly that length (same direction).
//   4. Blend original -> sharpened by parameters.amount and store the result.
//
// Alpha is never filtered: the stored pixel always carries the original alpha.
void main( ) 
{
  ivec2 size = imageSize( inputImage );
  ivec2 pixelUV = ivec2( gl_GlobalInvocationID.xy );

  // The dispatch may cover more invocations than there are pixels.
  if ( any( greaterThanEqual( pixelUV, size ) ) ) 
  {
    return;
  }

  // Out-of-range image loads do not return useful data, which made the
  // neighbour sum too small on border pixels and over-sharpened the edges.
  // Clamped taps keep the kernel weights summing to 1 everywhere.
  vec4 neighbourSum = loadClamped( pixelUV + ivec2( -1,  0 ), size ) +
                      loadClamped( pixelUV + ivec2(  0, -1 ), size ) +
                      loadClamped( pixelUV + ivec2(  1,  0 ), size ) +
                      loadClamped( pixelUV + ivec2(  0,  1 ), size ) +
                      loadClamped( pixelUV + ivec2( -1, -1 ), size ) +
                      loadClamped( pixelUV + ivec2(  1,  1 ), size ) +
                      loadClamped( pixelUV + ivec2( -1,  1 ), size ) +
                      loadClamped( pixelUV + ivec2(  1, -1 ), size );

  vec4 originalColor = imageLoad( inputImage, pixelUV );

  // 3x3 sharpen kernel: centre weight 9, every neighbour weight -1.
  // NOTE(review): the [0, 1] clamp also limits values above 1.0 even though
  // the images are rgba16f - kept as in the original, confirm it is intended.
  vec3 sharpened = clamp( originalColor.rgb * 9.0 - neighbourSum.rgb, 0.0, 1.0 );

  vec3 difference = originalColor.rgb - sharpened;
  float differenceLength = length( difference );

  // Limit how far sharpening may move the colour. The length is reused for the
  // normalisation, and the zero check avoids dividing by zero when
  // maxDifference is negative and the pixel did not change.
  if ( differenceLength > parameters.maxDifference && differenceLength > 0.0 )
  {
    sharpened = originalColor.rgb - ( difference / differenceLength ) * parameters.maxDifference;
  }

  // Re-attach the untouched alpha so neither the kernel nor the limiter can
  // alter it; the mix below therefore leaves alpha equal to the original.
  vec4 color = vec4( sharpened, originalColor.a );

  vec4 mixedColor = mix( originalColor, color, parameters.amount );

  imageStore( outputImage, pixelUV, mixedColor );
}