// Each "#pragma kernel" line tells Unity which function to compile as a compute kernel; a file can declare many kernels.
#pragma use_dxc
#include "UnityCG.cginc"
#define PI 3.14159
#pragma kernel CSMain

// Linear voxel data, one uint per voxel (x fastest, then y, then z); CopyToTexture reinterprets each entry with asfloat.
RWStructuredBuffer<uint> DDABuffer;
// Per-voxel accumulated density toward the sun, stored as raw float bits; written by ShadeComputation, read by MarchDDA.
RWStructuredBuffer<uint> ShadowBuffer;
// Final colour target written by CSMain, one texel per screen pixel.
RWTexture2D<float4> Result;
// Writable view of the density volume, filled by CopyToTexture.
RWTexture3D<float> DDATextureWrite;
// Readable density volume sampled by GetIndex; presumably the same resource as DDATextureWrite bound read-only — TODO confirm on the C# side.
Texture3D<float> DDATexture;
float4x4 _CameraInverseProjection;
float4x4 CameraToWorld;
// Output resolution in pixels, used by CSMain to map thread IDs to NDC.
uint ScreenWidth;
uint ScreenHeight;
// Volume dimensions in voxels; the volume occupies the world-space box [0, Size] at one voxel per unit.
float3 Size;
// A ray: start point and direction. CreateCameraRay produces a world-space origin and a normalized
// direction; CreateRay itself stores whatever it is given.
struct Ray {
    float3 origin;
    float3 direction;
};
// Packs an origin (A) and a direction (B) into a Ray. B is stored exactly as passed (no normalization).
Ray CreateRay(float3 A, float3 B) {
    Ray result = { A, B };
    return result;
}
// Builds the primary ray for one pixel.
// uv: the pixel position in NDC, i.e. [-1, 1] on both axes (see CSMain).
// Returns a ray that starts at the camera position and has a normalized world-space direction.
Ray CreateCameraRay(float2 uv) {
    // mul(M, (0,0,0,1)) picks the translation column of the camera-to-world matrix: the camera position.
    const float3 cameraPos = mul(CameraToWorld, float4(0.0f, 0.0f, 0.0f, 1.0f)).xyz;

    // Un-project the NDC point into a view-space direction ...
    const float3 viewSpaceDir = mul(_CameraInverseProjection, float4(uv, 0.0f, 1.0f)).xyz;
    // ... then rotate it into world space (w = 0 drops the translation) and normalize.
    const float3 worldSpaceDir = normalize(mul(CameraToWorld, float4(viewSpaceDir, 0.0f)).xyz);

    return CreateRay(cameraPos, worldSpaceDir);
}



// Step length, in voxels, of the shadow march in MarchShadowDDA; each sampled density is also weighted by it.
int ShadowDistanceOffset;


// Slab test of a ray against the axis-aligned box [Min, Max].
// ray_orig : ray start point.
// inv_dir  : component-wise reciprocal of the ray direction.
// tMax     : upper clamp for the exit distance.
// t0, t1   : receive the entry distance (clamped to >= 0) and the exit distance (clamped to <= tMax).
// Returns true when the clamped interval [t0, t1] is non-empty.
inline bool rayBoxIntersection(const float3 ray_orig, const float3 inv_dir, const float3 Min, const float3 Max, float tMax, inout float t0, inout float t1) {
    // Ray parameters at which the ray crosses each pair of slab planes.
    const float3 tAtMin = (Min - ray_orig) * inv_dir;
    const float3 tAtMax = (Max - ray_orig) * inv_dir;
    const float3 tNear = min(tAtMin, tAtMax);
    const float3 tFar = max(tAtMin, tAtMax);

    // Leave through the first far plane hit; enter through the last near plane hit.
    t1 = min(tFar.x, min(tFar.y, min(tFar.z, tMax)));
    t0 = max(tNear.x, max(tNear.y, max(tNear.z, 0))); // Usually ray_tmin = 0

    return t1 >= t0;
}

// Returns the density stored for voxel xyz in DDATexture (mip level 0).
// Under D3D11 resource-access rules, a Load outside the texture returns 0.
inline float GetIndex(const int3 xyz) {
    // DDATexture is a Texture3D: its Load takes an int4 (x, y, z, mip).
    // The previous LOAD_TEXTURE2D macro targets 2D textures and builds a 3-component coordinate.
    return DDATexture.Load(int4(xyz, 0));
}

// Sun direction. -SunDir points toward the sun: CSMain draws the sun disk around -SunDir, and
// ShadeComputation marches along -SunDir.
float3 SunDir;

// Marches from voxel corner mapPos along rayDir and accumulates the density of the voxels crossed.
// It steps ShadowDistanceOffset voxels at a time (at least 1).
// mapPos : integer voxel coordinate to start from; its own density is not sampled.
// rayDir : march direction (ShadeComputation passes -SunDir).
// Returns the accumulated density, with each sample weighted by the step length. The march stops
// early once the total reaches 10. The result is a scalar; it promotes implicitly wherever a float3
// is expected.
inline float MarchShadowDDA(int3 mapPos, const float3 rayDir) {
    float Density = 0;
    float t0, t1;
    if(rayBoxIntersection(mapPos, rcp(rayDir), 0, Size, 99999, t0, t1)) {
        // A step of 0 would advance neither sideDist nor mapPos, so the loop below would never
        // terminate (GPU hang). Clamp the step to at least one voxel.
        const int stepSize = max(ShadowDistanceOffset, 1);
        const int3 dirSign = int3(sign(rayDir));
        const float3 deltaDist = abs(rcp(rayDir)) * stepSize;
        const int3 rayStep = dirSign * stepSize;
        // The march starts exactly on a cell corner. The first boundary is therefore one full cell
        // (deltaDist) away on positive axes and 0 away on negative axes. Use the unscaled sign here:
        // scaling it by the step size made the negative-axis start distance negative, which shifted
        // the grid by half a cell on those axes.
        float3 sideDist = (dirSign * 0.5 + 0.5) * deltaDist;

        float minDist = min(min(sideDist.x, sideDist.y), sideDist.z);
        while (minDist < t1) {
            // Cross every boundary that is at the current minimum distance (ties step several axes).
            const bool3 mask = (sideDist.xyz <= minDist);
            sideDist += mask * deltaDist;
            mapPos += mask * rayStep;
            Density += GetIndex(mapPos) * stepSize;
            minDist = min(min(sideDist.x, sideDist.y), sideDist.z);
            // Beyond this the voxel is treated as fully shadowed; stop marching.
            if (Density >= 10)
                break;
        }
    }
    return Density;
}


// Marches a camera ray through the voxel volume [0, Size] with a one-voxel DDA. It accumulates sun
// light scattered toward the camera, using the per-voxel light density precomputed in ShadowBuffer.
// ray       : camera ray; its direction is expected to be normalized (CreateCameraRay does this).
// Luminance : overwritten with the accumulated in-scattered light (0 when the ray misses the volume).
// Returns the remaining transmission along the ray (1 = nothing absorbed). CSMain uses it to
// attenuate the sky and the sun behind the volume.
// Note: sampling begins after the first voxel-boundary crossing, so the voxel containing the entry
// point is skipped.
inline float3 MarchDDA(Ray ray, inout float3 Luminance) {
    float t0, t1;
    float3 Transmission = 1;
    Luminance = 0;
    // Cosine between the view ray and the direction toward the sun (assuming both are normalized).
    float mu = dot(-ray.direction, SunDir);
    if(rayBoxIntersection(ray.origin, rcp(ray.direction), 0, Size, 99999, t0, t1)) {
        // Restart the ray at the box entry point; t1 becomes the distance remaining inside the box.
        ray.origin += ray.direction * t0;
        t1 = t1 - t0;
        int3 mapPos = int3(floor(ray.origin));

        // Integer volume dimensions keep the linear ShadowBuffer index exact. The former float
        // arithmetic on Size loses precision once the index exceeds 2^24.
        const uint3 dims = uint3(Size);

        const float3 deltaDist = abs(rcp(ray.direction));

        const int3 rayStep = int3(sign(ray.direction));

        // Distance along the ray to the first boundary crossing on each axis.
        float3 sideDist = (rayStep * (mapPos - ray.origin) + (rayStep * 0.5) + 0.5) * deltaDist;

        bool3 mask;
        float minDist = min(min(sideDist.x, sideDist.y), sideDist.z);

        while(minDist < t1) {
            // Branchless DDA step: advance along the axis whose next boundary is closest.
            mask = (sideDist.xyz <= min(sideDist.yzx, sideDist.zxy));
            sideDist += mask * deltaDist;
            minDist = min(min(sideDist.x, sideDist.y), sideDist.z);
            mapPos += mask * rayStep;
            float Density = GetIndex(mapPos);
            if(Density > 0.01f) {
                // Same x-fastest linear layout ShadeComputation writes.
                const uint shadowIndex = uint(mapPos.x) + dims.x * (uint(mapPos.y) + dims.y * uint(mapPos.z));
                float LightDensity = asfloat(ShadowBuffer[shadowIndex]);
                float3 beerslaw = MultipleOctaveScattering(LightDensity, mu);
                float3 powder = 1.0f - exp(-LightDensity * 2 * float3(0.8, 0.8, 1));
                // float3 transmittance = 0;
                // float3 Radiance = GetSkyRadiance(mapPos, -SunDir, 0, -SunDir, transmittance);
                float3 lum = 12.0f * beerslaw * lerp(2 * powder, 1, remap(mu, -1, 1, 0, 1));// * transmittance + Radiance;

                // Fraction of light surviving this voxel, and the light it scatters toward the camera.
                float3 transmittance2 = exp(-Density * float3(0.8, 0.8, 1));
                float3 integscatter = (lum - lum * transmittance2);

                Luminance += integscatter * Transmission;
                Transmission *= transmittance2;
                // Early out once almost nothing behind this point can be seen.
                if(Transmission.x < 0.01f) break;
            }
        }

    }
    return Transmission;
}


// 2x2 pixel offsets, indexed by CurFrame % 4 in the commented-out ProperID expression in CSMain.
// Currently unused while that interleaving is disabled.
static const int2 PerFrameOffset[4] = {
    {int2(0,0), int2(1,0), int2(0,1), int2(1,1)  }
};

// Frame counter set by the host; only referenced by the commented-out interleaving code in CSMain.
int CurFrame;
// Renders one pixel. It composites the sky, the sun disk and the voxel volume into Result.
[numthreads(8,8,1)]
void CSMain (uint3 id : SV_DispatchThreadID)
{
    int2 ProperID = id.xy;/// * 2 + PerFrameOffset[CurFrame % 4];

    // Dispatches are rounded up to whole 8x8 groups; threads past the screen edge have no pixel.
    if (any(uint2(ProperID) >= uint2(ScreenWidth, ScreenHeight)))
        return;

    // Aim through the pixel centre (+0.5) and map to NDC in [-1, 1].
    float2 uv = (float2(ProperID) + 0.5f) / float2(ScreenWidth, ScreenHeight) * 2.0f - 1.0f;
    Ray ray = CreateCameraRay(uv);
    float3 Luminance;
    // bool ValidPixel = (int(id.x)/3 + int(id.y)/3)%2==(CurFrame % 2);
    // if(ValidPixel) {
    float3 transmittance = 0;
    float3 Radiance = GetSkyRadiance(ray.origin, ray.direction, 0, -SunDir, transmittance);
    // Exponential exposure curve; for non-negative radiance the result lies in [0, 1).
    Radiance = float3(1, 1, 1) - exp(-Radiance * 10.0f);
    float3 SkyBoxCol = saturate(Radiance);// * Direct + trans2;

    // Sun disk: directions within SunHalfAngle (radians) of the direction toward the sun (-SunDir).
    const float SunHalfAngle = 0.0235f / 2.0f;
    float3 Sun = 0;//(saturate(max(min(exp(-acos(max(dot(-SunDir, -ray.direction), 0.0f))* 60.0f),12.0f),0) * transmittance));
    if (dot(ray.direction, -SunDir) > cos(SunHalfAngle)) {
        Sun = saturate(Radiance + transmittance * (1.5f / (PI * SunHalfAngle * SunHalfAngle)));
    }

    // Attenuate the background (sky + sun) by the volume's transmission and add its in-scattered light.
    float3 Transmission = MarchDDA(ray, Luminance);
    Result[ProperID] = float4(SkyBoxCol * Transmission + Luminance + Sun * Transmission, 1);
    // }
}


#pragma kernel CopyToTexture

// Copies the linear DDABuffer (float bits stored as uint, x fastest, then y, then z) into the 3D
// texture DDATextureWrite, one voxel per thread.
[numthreads(8,8,8)]
void CopyToTexture (uint3 id : SV_DispatchThreadID)
{
    const uint3 dims = uint3(Size);
    // Dispatches are rounded up to whole 8x8x8 groups. Threads outside the volume would otherwise
    // read unrelated buffer entries (the row-major index wraps into the next row/slice).
    if (any(id >= dims))
        return;

    // Integer index: exact for any volume size, unlike float arithmetic on Size.
    DDATextureWrite[id.xyz] = asfloat(DDABuffer[id.x + dims.x * (id.y + dims.y * id.z)]);
}

#pragma kernel ShadeComputation

// Voxel coordinates to shade, stored as float3 (presumably integer-valued — TODO confirm on the C# side).
StructuredBuffer<float3> NonZeroVoxels;

// For every listed voxel, marches toward the sun (-SunDir) with MarchShadowDDA. It stores the
// accumulated density in ShadowBuffer as raw float bits, at the linear index MarchDDA reads.
[numthreads(1023,1,1)]
void ShadeComputation (uint3 id : SV_DispatchThreadID)
{
    // The group count is rounded up, so trailing threads have no voxel. Without this guard they
    // read 0 from the out-of-range element and rewrite the entry for voxel (0,0,0).
    uint voxelCount, stride;
    NonZeroVoxels.GetDimensions(voxelCount, stride);
    if (id.x >= voxelCount)
        return;

    // Truncate once to the integer voxel coordinate the march starts from. The index below then
    // matches that voxel exactly and is computed in integers, avoiding float precision loss.
    const int3 voxel = int3(NonZeroVoxels[id.x]);
    const uint3 dims = uint3(Size);
    const uint index = uint(voxel.x) + dims.x * (uint(voxel.y) + dims.y * uint(voxel.z));

    // .x keeps this a scalar whether MarchShadowDDA returns float or float3.
    const float shadowDensity = MarchShadowDDA(voxel, -SunDir).x;
    ShadowBuffer[index] = asuint(shadowDensity);
}