import package::shaders::pbr::pbr_material;
import package::shaders::math::{calculate_camera_pos_worldspace, rotate_xyz_matrix};


// Per-vertex attributes. The @location numbers must match the render pipeline's
// vertex buffer layout on the host side.
struct VertexInput {
    // Model-space position (vs_main transforms it with w = 1).
    @location(0) a_position: vec3<f32>,
    // Model-space normal (vs_main transforms it with w = 0, i.e. as a direction).
    @location(1) a_normal: vec3<f32>,
    // Model-space tangent (transformed like the normal).
    @location(2) a_tangent: vec3<f32>,
    // Texture coordinates, passed through unchanged to the fragment stage as v_uv.
    @location(3) a_uv: vec2<f32>,
};

// Vertex-to-fragment interface written by vs_main and read by fs_main / fs_main_noop.
struct VertexOutput {
    // Clip-space position consumed by the rasterizer.
    @builtin(position) position: vec4<f32>,
    @location(0) v_uv: vec2<f32>,
    // World-space normal and tangent. vs_main does not normalize them.
    @location(1) v_normal_worldspace: vec3<f32>,
    @location(2) v_tangent_worldspace: vec3<f32>,
    // World-space position, used for shading and for the shadow-map lookup.
    @location(3) v_pos_worldspace: vec3<f32>,
    // Camera position in world space (the same for every vertex of a draw).
    @location(4) v_camera_pos_worldspace: vec3<f32>,
    // NOTE(review): neither written by vs_main nor read by fs_main in this file,
    // so it always carries the zero value. Confirm whether it is still needed.
    @location(5) v_material_adjustment: vec3<f32>,
};

// Vertex-stage uniform block (bound below as `context`, group 0 / binding 0).
// The host buffer must follow WGSL uniform layout rules: each mat4x4 takes 64
// bytes, and vec3 fields are 16-byte aligned. As a result, 8 bytes of implicit
// padding sit between `time` and `instance_move`.
struct VsUniforms {
    // Used by vs_main to produce the clip-space position.
    g_projection_from_world: mat4x4<f32>,
    // Not read by the code in this file.
    g_projection_from_model: mat4x4<f32>,
    // Not read by the code in this file.
    g_camera_from_model: mat4x4<f32>,
    // Passed to calculate_camera_pos_worldspace in vs_main.
    g_camera_from_world: mat4x4<f32>,
    // Base model-to-world transform; vs_main composes it with a per-instance rotation.
    g_world_from_model: mat4x4<f32>,
    // Not read by the code in this file (the fragment stage has its own copy in FsUniforms).
    g_light_dir_worldspace_norm: vec3<f32>,
    // Not read by the code in this file.
    g_app_time: f32,
    // Not read by the code in this file.
    g_simulation_frame_ratio: f32,
    // Drives the per-instance placement in vs_main (only its floor() is used).
    time: f32,
    // Not read by the code in this file.
    instance_move: vec3<f32>,
};

// Particle state record.
// NOTE(review): not referenced by any code visible in this file. It presumably
// mirrors a host- or compute-side particle layout; confirm before changing field
// order or types.
struct Particle {
    position: vec3<f32>,
    velocity: vec3<f32>,
    upvector: vec3<f32>,
}

// Vertex-stage uniforms. vs_main reads `time`, `g_world_from_model`,
// `g_projection_from_world` and `g_camera_from_world` from this block.
@group(0) @binding(0) var<uniform> context: VsUniforms;


// Fragment-stage uniform block (bound below as `u`, group 1 / binding 0).
// The host buffer must follow WGSL uniform layout rules. vec3/vec4 fields are
// 16-byte aligned, so 8 bytes of implicit padding follow `g_app_time`.
struct FsUniforms {
    // Projects world positions into the shadow map (see sample_shadow_map).
    g_light_projection_from_world: mat4x4<f32>,
    // Not read by the code in this file.
    g_camera_from_world: mat4x4<f32>,
    // Not read by the code in this file.
    g_projection_from_camera: mat4x4<f32>,
    // Not read by the code in this file.
    g_chart_time: f32,
    // Not read by the code in this file.
    g_app_time: f32,
    // Not read by the code in this file.
    g_light_dir_camspace_norm: vec3<f32>,
    // Light direction passed to pbr_material by fs_main.
    g_light_dir_worldspace_norm: vec3<f32>,
    // Only .rgb is used (scaled by shadow visibility in fs_main). Alpha is ignored.
    light_color: vec4<f32>,
    // Scalar material parameters forwarded to pbr_material.
    roughness: f32,
    metallic: f32,
    // Broadcast to a grey vec3 ambient term in fs_main.
    ambient: f32,
    normal_strength: f32,
    // Depth bias for the shadow comparison; sample_shadow_map scales it by 0.001.
    shadow_bias: f32,
    // Not read by the code in this file.
    color: vec3<f32>,
};

// Fragment-stage resources, all in bind group 1.
// Bindings 8-10 are unused here. NOTE(review): presumably reserved, or used
// by other shaders sharing this layout; confirm.
@group(1) @binding(0) var<uniform> u: FsUniforms;
@group(1) @binding(1) var envmap: texture_2d<f32>;
// The conditional below guards only the `shadow` declaration that follows it.
// When ENTRY_POINT_FS_MAIN_NOOP is set, neither this binding nor fs_main is compiled.
// NOTE(review): presumably because the noop entry point renders into this
// depth texture (shadow pass) and cannot also bind it; confirm.
@if(!ENTRY_POINT_FS_MAIN_NOOP) 
@group(1) @binding(2) var shadow: texture_depth_2d;
// PBR material textures, sampled through pbr_material with sampler_repeat.
@group(1) @binding(3) var base_color_map: texture_2d<f32>;
@group(1) @binding(4) var roughness_map: texture_2d<f32>;
@group(1) @binding(5) var metallic_map: texture_2d<f32>;
@group(1) @binding(6) var normal_map: texture_2d<f32>;
// BRDF lookup table, forwarded to pbr_material.
@group(1) @binding(7) var brdf_lut: texture_2d<f32>;

// Samplers: one for the environment map, a comparison sampler for the shadow
// map, and a repeat sampler for the material textures (per the names and
// their uses in fs_main / sample_shadow_map).
@group(1) @binding(11) var sampler_envmap: sampler;
@group(1) @binding(12) var sampler_shadow: sampler_comparison;
@group(1) @binding(13) var sampler_repeat: sampler;

// Returns the shadow-map comparison result for a world-space position.
//
// Steps:
//   1. Project `world_pos` with `u.g_light_projection_from_world` into light clip space.
//   2. Apply the perspective divide to get NDC. For orthographic light
//      projections w == 1, so this leaves the result unchanged. Perspective light
//      matrices (e.g. spot lights) need it to produce correct coordinates.
//   3. Map NDC x/y from [-1, 1] to texture UV in [0, 1], flipping Y because clip
//      space Y points up while texture V points down.
//   4. Pull the reference depth towards the light by `u.shadow_bias * 0.001` to
//      reduce self-shadowing artifacts.
//
// The returned value is what `textureSampleCompare` yields with `sampler_shadow`.
// Its meaning (1.0 = lit vs. 1.0 = shadowed) depends on the sampler's compare
// function, which is configured on the host. fs_main scales light radiance by it,
// so 1.0 = lit is presumably the intended convention; confirm.
//
// The depth texture is taken as a parameter (shadowing the module-scope `shadow`
// binding) so this function does not itself reference that conditionally
// compiled binding.
fn sample_shadow_map(world_pos: vec3<f32>, shadow: texture_depth_2d) -> f32 {
    let clip_pos = u.g_light_projection_from_world * vec4<f32>(world_pos, 1.0);
    let ndc_pos = clip_pos.xyz / clip_pos.w;
    let shadow_uv = ndc_pos.xy * vec2f(0.5, -0.5) + vec2f(0.5, 0.5);
    let depth_ref = ndc_pos.z - u.shadow_bias * 0.001;
    return textureSampleCompare(shadow, sampler_shadow, shadow_uv, depth_ref);
}

// Main PBR fragment entry point. Like the `shadow` binding it reads, it is
// compiled only when ENTRY_POINT_FS_MAIN_NOOP is not set.
//
// The light colour is attenuated by the shadow-map lookup at this fragment's
// world position. A grey ambient term is built from `u.ambient`. Everything is
// then handed to pbr_material together with the material textures and samplers.
// The shaded colour is written out with alpha fixed at 1.0.
@if(!ENTRY_POINT_FS_MAIN_NOOP) 
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
    let shadow_visibility = sample_shadow_map(in.v_pos_worldspace, shadow);
    let light_radiance = u.light_color.rgb * shadow_visibility;
    let ambient_light = vec3f(u.ambient);

    let shaded = pbr_material(
        in.v_uv,
        in.v_pos_worldspace,
        in.v_normal_worldspace,
        in.v_tangent_worldspace,
        in.v_camera_pos_worldspace,
        u.g_light_dir_worldspace_norm,
        u.normal_strength,
        light_radiance,
        ambient_light,
        u.roughness,
        u.metallic,
        base_color_map,
        roughness_map,
        metallic_map,
        normal_map,
        envmap,
        brdf_lut,
        sampler_repeat,
        sampler_envmap
    );

    return vec4<f32>(shaded, 1.0);
}

// Fragment entry point with no colour outputs and no work: only the
// rasterized depth of the draw is produced. Unlike fs_main, it has no @if
// guard and is always compiled.
// NOTE(review): presumably used for the shadow-map depth pass; confirm.
@fragment
fn fs_main_noop(_in: VertexOutput) {
}

// Vertex entry point that scatters instances of the mesh with a per-instance
// offset and orientation.
//
// Each instance gets a seed: floor(context.time + 5.34241) * (instance_index + 1).
// The seed only changes when `context.time` crosses an integer boundary, so
// placements jump rather than animate smoothly. The cosines of the seed along
// three differently scaled components drive both:
//   * a model-space offset in [-200, 200] per axis, added to the vertex position;
//   * a rotation from rotate_xyz_matrix(normalize(cosines)). This is presumably
//     XYZ Euler angles in radians; confirm against the math module.
// The offset is applied before the rotation and the world transform, so it is
// rotated together with the mesh.
//
// Normals and tangents go through the same matrix with w = 0 and are not
// renormalized. `v_material_adjustment` is never assigned and keeps the zero
// value of the uninitialized `var`.
@vertex
fn vs_main(input: VertexInput, @builtin(instance_index) instance_index: u32) -> VertexOutput {
    let seed = floor(context.time + 5.34241) * f32(instance_index + 1);
    let wave = cos(vec3f(seed, seed * 1.12312, seed * 3.123333));

    let offset_modelspace = wave * 200.;
    let world_from_instance = context.g_world_from_model * rotate_xyz_matrix(normalize(wave));

    var vertex_out: VertexOutput;
    vertex_out.v_uv = input.a_uv;
    vertex_out.v_pos_worldspace = (world_from_instance * vec4<f32>(input.a_position + offset_modelspace, 1.0)).xyz;
    vertex_out.position = context.g_projection_from_world * vec4<f32>(vertex_out.v_pos_worldspace, 1.0);
    // w = 0: directions pick up the rotation/scale of the matrix but not its translation.
    vertex_out.v_normal_worldspace = (world_from_instance * vec4<f32>(input.a_normal, 0.0)).xyz;
    vertex_out.v_tangent_worldspace = (world_from_instance * vec4<f32>(input.a_tangent, 0.0)).xyz;
    vertex_out.v_camera_pos_worldspace = calculate_camera_pos_worldspace(context.g_camera_from_world);
    return vertex_out;
}
