
using Godot;
using Godot.Collections;
using System.Collections.Generic;

namespace Rokojori
{  
  /// <summary>
  /// Compositor effect that blurs the rendered color layer with the compute shader
  /// "Blur/BlurEffect.glsl", using three dispatches per view:
  /// pass 1 (horizontal blur): screen color layer -> first backbuffer set,
  /// pass 2 (vertical blur): first backbuffer set -> second backbuffer set,
  /// pass 3 (draw to screen): second backbuffer set -> screen color layer.
  /// Passes 1 and 2 operate at the internal size divided by <see cref="mipLevel"/>.
  /// Shader layout used here: set 0 / binding 0 is the storage image that is written,
  /// set 1 / binding 0 is the sampled input (linear sampler + texture).
  /// </summary>
  [Tool]
  [GlobalClass]
  public abstract partial class BlurCompositorEffect:CompositorEffect
  {
    public enum BlurType
    {
      Box,
      Gaussian
    } 

    // Work-group edge length used for the dispatch math; presumably matches
    // local_size_x / local_size_y declared in BlurEffect.glsl (TODO confirm).
    const int ThreadGroupSize = 16;

    [ExportGroup( "Blur Properties" )]
    [Export]
    public BlurType blurType = BlurType.Box;

    /// <summary>Sample count sent to the shader; never more than <see cref="blurWidth"/>.</summary>
    [Export( PropertyHint.Range, "1,80" )]
    public int blurSamples = 15;

    /// <summary>Blur width forwarded to the shader.</summary>
    [Export( PropertyHint.Range, "1,80" )]
    public int blurWidth = 80;

    /// <summary>Dither amount forwarded to the shader.</summary>
    [Export( PropertyHint.Range, "0,0.5" )]
    public float dither = 0;

    /// <summary>Downsampling divisor for the blur passes (values below 1 are treated as 1).</summary>
    [Export( PropertyHint.Range, "1,4" )]
    public int mipLevel = 1;

    [Export]
    public bool logDebugInfo = false;

    // Size and mip level the backbuffers were last created for; a mismatch triggers a rebuild.
    public Vector2I sizeCache = new Vector2I();
    public int mipCache;
    
    public RenderingDevice rd;
    public Rid shader;
    public Rid pipeline;

    // Backbuffer texture RIDs: entries [0, viewCount) are the pass-1 targets,
    // entries [viewCount, 2 * viewCount) are the pass-2 targets.
    public Array backbuffers = new Array();
    public RDTextureFormat backbufferFormat;
    public RDTextureView texview;
    public RDSamplerState samplerState;
    public Rid linearSampler;
    

    /// <summary>Queues GPU resource setup onto the render thread.</summary>
    public BlurCompositorEffect()
    {  
      RJLog.Log( "_Init" );
      RenderingServer.CallOnRenderThread( Callable.From( InitializeComputeShader ) );    
    }
    
    /// <summary>
    /// Loads the blur shader and creates the compute pipeline and the linear sampler.
    /// Leaves <see cref="pipeline"/> invalid on any failure, which makes <see cref="DoPass"/> a no-op.
    /// </summary>
    public void InitializeComputeShader()
    {  
      RJLog.Log( "InitializeComputeShader" );

      rd = RenderingServer.GetRenderingDevice();

      if ( rd == null )
      {
        RJLog.Log( "Initializing failed" );
        return;
      }

      RJLog.Log( "Initializing succeed, loading shader" );

      // Make sure this is correctly pointing to the GLSL file
      var glslFile = GD.Load<RDShaderFile>( RokojoriCompositorEffect.Path( "Blur/BlurEffect.glsl" ) );

      if ( glslFile == null )
      {
        RJLog.Log( "Initializing failed, shader file could not be loaded" );
        return;
      }

      shader = rd.ShaderCreateFromSpirV( glslFile.GetSpirV() );

      if ( ! shader.IsValid )
      {
        RJLog.Log( "Initializing failed, shader could not be created" );
        return;
      }

      pipeline = rd.ComputePipelineCreate( shader );
      
      samplerState = new RDSamplerState();
      samplerState.MinFilter = RenderingDevice.SamplerFilter.Linear;
      samplerState.MagFilter = RenderingDevice.SamplerFilter.Linear;

      linearSampler = rd.SamplerCreate( samplerState );

      RJLog.Log( "Initializing done", shader, pipeline, samplerState, linearSampler );
    }
    
    /// <summary>
    /// On predelete, releases every RenderingDevice resource this effect owns.
    /// Each resource is freed independently so a failed shader setup does not leak the rest.
    /// </summary>
    public override void _Notification( int what )
    {  
      if ( what != NotificationPredelete || rd == null )
      {
        return;
      }

      // The pipeline is created from the shader; RenderingDevice frees dependent
      // resources such as the compute pipeline together with the shader.
      if ( shader.IsValid )
      {
        rd.FreeRid( shader );
      }

      if ( linearSampler.IsValid )
      {
        rd.FreeRid( linearSampler );
      }

      FreeBackbuffers();
    }

    
    public override void _RenderCallback( int effectCallbackType, RenderData renderData )
    {  
      DoPass( renderData, 1 ); // horizontal blur
      DoPass( renderData, 2 ); // vertical blur
      DoPass( renderData, 3 ); // draw buffers to screen
    }
    
    /// <summary>
    /// Runs one blur pass for every view of the current scene buffers.
    /// </summary>
    /// <param name="renderData">Render data handed to the compositor callback.</param>
    /// <param name="passNum">1 = horizontal blur, 2 = vertical blur, 3 = draw to screen; other values are ignored.</param>
    public void DoPass( RenderData renderData, int passNum )
    {  
      if ( rd == null )
      {
        this.LogInfoIf( logDebugInfo, "No RD" );
        return;
      }

      if ( ! pipeline.IsValid )
      {
        this.LogInfoIf( logDebugInfo, "No valid pipeline" );
        return;
      }

      if ( passNum < 1 || passNum > 3 )
      {
        this.LogInfoIf( logDebugInfo, "Unknown pass: " + passNum );
        return;
      }
      
      // Get fresh scene buffers and data for this pass
      var sceneBuffers = renderData.GetRenderSceneBuffers() as RenderSceneBuffersRD;
      var sceneData = renderData.GetRenderSceneData() as RenderSceneDataRD;
     
      // Either one missing means there is nothing valid to render into.
      if ( sceneBuffers == null || sceneData == null )
      {
        this.LogInfoIf( logDebugInfo, "sceneBuffers == null || sceneData == null" );
        return;
      }
      
      var size = sceneBuffers.GetInternalSize();

      if ( size.X == 0 || size.Y == 0 )
      {
        this.LogInfoIf( logDebugInfo, "size.X == 0 || size.Y == 0" );
        return;
      }

      // Passes 1 and 2 run at backbuffer resolution, pass 3 writes the full-resolution screen.
      var workSize = passNum == 3 ? size : GetBackbufferSize( size );

      // Ceiling division: enough groups to cover every pixel, without a surplus group.
      uint xGroups = ( uint ) ( ( workSize.X + ThreadGroupSize - 1 ) / ThreadGroupSize );
      uint yGroups = ( uint ) ( ( workSize.Y + ThreadGroupSize - 1 ) / ThreadGroupSize );

      int viewCount = ( int ) sceneBuffers.GetViewCount();

      if ( backbuffers.Count < viewCount * 2 || sizeCache != size || mipCache != mipLevel )
      {
        InitBackbuffer( viewCount * 2, size );      
      }

      var pushConstant = BuildPushConstant( workSize, passNum );

      for ( int view = 0; view < viewCount; view++ )
      {
        Rid screenTex = sceneBuffers.GetColorLayer( ( uint ) view );
        Rid horizontalBuffer = ( Rid ) backbuffers[ view ];
        Rid verticalBuffer = ( Rid ) backbuffers[ viewCount + view ];

        Rid outputImage;
        Rid inputTexture;

        switch ( passNum )
        {
          case 1:
            outputImage = horizontalBuffer;
            inputTexture = screenTex;
            break;

          case 2:
            outputImage = verticalBuffer;
            inputTexture = horizontalBuffer;
            break;

          default:
            outputImage = screenTex;
            inputTexture = verticalBuffer;
            break;
        }

        Rid outputSet = CreateImageUniformSet( outputImage );
        Rid inputSet = CreateSamplerUniformSet( inputTexture );

        long computeList = rd.ComputeListBegin();
        rd.ComputeListBindComputePipeline( computeList, pipeline );
        rd.ComputeListBindUniformSet( computeList, outputSet, 0 );
        rd.ComputeListBindUniformSet( computeList, inputSet, 1 );
        rd.ComputeListSetPushConstant( computeList, pushConstant, ( uint ) pushConstant.Length );
        rd.ComputeListDispatch( computeList, xGroups, yGroups, 1 );
        rd.ComputeListEnd();
      }
    }

    // Packs the push-constant data: floats (width, height, dither) followed by
    // ints (gaussian flag, sample count, blur width, pass number), zero-padded to 32 bytes.
    byte[] BuildPushConstant( Vector2I workSize, int passNum )
    {
      var packedFloats = new List<float>
      {
        workSize.X,
        workSize.Y,
        dither
      };

      var packedInts = new List<int>
      {
        blurType == BlurType.Gaussian ? 1 : 0,
        Mathf.Min( blurSamples, blurWidth ),
        blurWidth,
        passNum
      };

      var packedBytes = new List<byte>();
      packedBytes.AddRange( Bytes.Convert( packedFloats ) );
      packedBytes.AddRange( Bytes.Convert( packedInts ) );

      // Presumably the size of the push-constant block declared in BlurEffect.glsl (TODO confirm).
      while ( packedBytes.Count < 32 )
      {
        packedBytes.Add( 0 );
      }

      return packedBytes.ToArray();
    }

    // Downsampled backbuffer size; guards against a mip level below 1 and zero-sized textures.
    Vector2I GetBackbufferSize( Vector2I size )
    {
      int mip = Mathf.Max( 1, mipLevel );
      return new Vector2I( Mathf.Max( 1, size.X / mip ), Mathf.Max( 1, size.Y / mip ) );
    }

    // Uniform set 0: storage image (written by the shader) at binding 0.
    Rid CreateImageUniformSet( Rid image )
    {
      var uniform = new RDUniform();
      uniform.UniformType = RenderingDevice.UniformType.Image;
      uniform.Binding = 0;
      uniform.AddId( image );
      return UniformSetCacheRD.GetCache( shader, 0, new Array<RDUniform>(){ uniform } );
    }

    // Uniform set 1: linear sampler + texture (read by the shader) at binding 0.
    Rid CreateSamplerUniformSet( Rid texture )
    {
      var uniform = new RDUniform();
      uniform.UniformType = RenderingDevice.UniformType.SamplerWithTexture;
      uniform.Binding = 0;
      uniform.AddId( linearSampler );
      uniform.AddId( texture );
      return UniformSetCacheRD.GetCache( shader, 1, new Array<RDUniform>(){ uniform } );
    }

    // Frees every valid backbuffer texture and empties the list.
    void FreeBackbuffers()
    {
      foreach ( var entry in backbuffers )
      {
        var texture = ( Rid ) entry;

        if ( texture.IsValid )
        {
          rd.FreeRid( texture );
        }
      }

      backbuffers.Clear();
    }
    
    /// <summary>
    /// Recreates <paramref name="count"/> backbuffer textures for the given internal size,
    /// freeing the previous ones first, and records the size/mip level they were built for.
    /// </summary>
    public void InitBackbuffer( int count, Vector2I size )
    {  
      // Free the old buffers first, otherwise every rebuild leaks GPU memory.
      FreeBackbuffers();
      
      if ( backbufferFormat == null )
      {
        backbufferFormat = new RDTextureFormat();
      }

      var bufferSize = GetBackbufferSize( size );

      // RGBA with 16-bit unsigned-normalized channels (values 0.0 - 1.0).
      backbufferFormat.Format = RenderingDevice.DataFormat.R16G16B16A16Unorm;
      backbufferFormat.Width = ( uint ) bufferSize.X;
      backbufferFormat.Height = ( uint ) bufferSize.Y;
      backbufferFormat.UsageBits = 
        RenderingDevice.TextureUsageBits.StorageBit |
        RenderingDevice.TextureUsageBits.SamplingBit;
      
      if ( texview == null )
      {
        texview = new RDTextureView();        
      }
      
      for ( int i = 0; i < count; i++ )
      {
        backbuffers.Add( rd.TextureCreate( backbufferFormat, texview ) );      
      }

      mipCache = mipLevel;
      sizeCache = size;
    }
	
  }

}