#[compute]
#version 450

// Each workgroup is 8 x 8 x 1 invocations; main() maps one invocation to one
// texel and discards invocations that fall outside the image, so the dispatch
// may overshoot the image size.
layout( local_size_x = 8, local_size_y = 8, local_size_z = 1 ) in;

// Source image (descriptor set 0, binding 0, rgba16f). Only read by main().
layout( rgba16f, set = 0, binding = 0 ) 
uniform image2D inputImage;
// Destination image (descriptor set 1, binding 0, rgba16f). Only written by main().
// NOTE(review): main() indexes it with coordinates bounded by inputImage's size,
// so it is presumably at least as large as inputImage - confirm on the host side.
layout( rgba16f, set = 1, binding = 0 ) 
uniform image2D outputImage;


// Push-constant parameters for the chromatic shift.
// Member order and types define the byte layout the host has to fill in
// (three vec2 followed by three floats = 36 bytes of payload under std430).
// NOTE(review): some hosts require the push-constant buffer to be padded to a
// multiple of 16 bytes - confirm against the code that dispatches this shader.
layout( push_constant, std430 ) 
uniform Params 
{
  // Per-channel offsets, added to the normalized (0..1) texel-centre UV
  // after being multiplied by shiftAll.
  vec2 rShift;
  vec2 gShift;
  vec2 bShift;
  // Blend factor passed to mix(): 0 keeps the unshifted colour,
  // 1 uses the channel-shifted colour.
  float amount;
  // Common multiplier applied to all three channel offsets.
  float shiftAll;
  // Not read by the visible shader code.
  // NOTE(review): presumably reserved for attenuating the shift near the
  // image centre - confirm intent or remove together with the host side.
  float unshiftCenter;

} params;

// Converts a normalized UV into an integer texel coordinate that is guaranteed
// to lie inside [0, maxTexel]. The clamp is done in the float domain so that
// very large or negative shifts never reach the float -> int conversion.
ivec2 clampedTexel( vec2 uv, vec2 sizeF, ivec2 maxTexel )
{
    vec2 texel = floor( uv * sizeF );
    return ivec2( clamp( texel, vec2( 0.0 ), vec2( maxTexel ) ) );
}

// Chromatic shift pass.
// For the texel owned by this invocation, reads the red, green and blue
// channels of inputImage at three individually offset positions, builds a
// colour from them (alpha forced to 1) and blends it with the unshifted texel
// using params.amount. The result is stored to outputImage at the same texel.
// Shifted positions that fall outside the image are clamped to the nearest
// edge texel instead of issuing an out-of-range imageLoad.
void main( ) 
{
    ivec2 size = imageSize( inputImage );
    ivec2 texel_coord = ivec2( gl_GlobalInvocationID.xy );
    
    // The dispatch is rounded up to whole 8x8 workgroups; drop the overhang.
    if ( any( greaterThanEqual( texel_coord, size ) )  ) 
    {
      return;
    }

    vec2 sizeF = vec2( size );
    ivec2 maxTexel = size - ivec2( 1 );

    // Normalized coordinate of the texel centre.
    vec2 uv = ( vec2( texel_coord ) + 0.5 ) / sizeF;

    // Per-channel sample positions; shiftAll scales all three offsets at once.
    vec2 uvR = uv + params.rShift * params.shiftAll;
    vec2 uvG = uv + params.gShift * params.shiftAll;
    vec2 uvB = uv + params.bShift * params.shiftAll;

    // Out-of-range image loads are not guaranteed to return anything useful,
    // so every shifted position is clamped into the image first.
    float r = imageLoad( inputImage, clampedTexel( uvR, sizeF, maxTexel ) ).r;
    float g = imageLoad( inputImage, clampedTexel( uvG, sizeF, maxTexel ) ).g;
    float b = imageLoad( inputImage, clampedTexel( uvB, sizeF, maxTexel ) ).b;

    vec4 chromaticColor = vec4( r, g, b, 1 );

    // The unshifted sample is exactly this invocation's texel, so load it
    // directly rather than going through a UV round-trip.
    vec4 color = imageLoad( inputImage, texel_coord );
    vec4 blended = mix( color, chromaticColor, params.amount );

    imageStore( outputImage, texel_coord, blended );
}