package net.vulkanmod.vulkan.shader;

import com.mojang.blaze3d.vertex.VertexFormat;
import com.mojang.blaze3d.vertex.VertexFormatElement;
import com.mojang.blaze3d.vertex.VertexFormatElement.Type;
import it.unimi.dsi.fastutil.objects.Object2LongMap;
import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap;
import java.nio.ByteBuffer;
import java.nio.LongBuffer;
import java.util.List;
import net.minecraft.class_290;
import net.vulkanmod.interfaces.VertexFormatMixed;
import net.vulkanmod.vulkan.Renderer;
import net.vulkanmod.vulkan.Vulkan;
import net.vulkanmod.vulkan.device.DeviceManager;
import org.lwjgl.system.MemoryStack;
import org.lwjgl.system.MemoryUtil;
import org.lwjgl.vulkan.VK10;
import org.lwjgl.vulkan.VkAllocationCallbacks;
import org.lwjgl.vulkan.VkGraphicsPipelineCreateInfo;
import org.lwjgl.vulkan.VkPipelineColorBlendAttachmentState;
import org.lwjgl.vulkan.VkPipelineColorBlendStateCreateInfo;
import org.lwjgl.vulkan.VkPipelineDepthStencilStateCreateInfo;
import org.lwjgl.vulkan.VkPipelineDynamicStateCreateInfo;
import org.lwjgl.vulkan.VkPipelineInputAssemblyStateCreateInfo;
import org.lwjgl.vulkan.VkPipelineMultisampleStateCreateInfo;
import org.lwjgl.vulkan.VkPipelineRasterizationStateCreateInfo;
import org.lwjgl.vulkan.VkPipelineShaderStageCreateInfo;
import org.lwjgl.vulkan.VkPipelineVertexInputStateCreateInfo;
import org.lwjgl.vulkan.VkPipelineViewportStateCreateInfo;
import org.lwjgl.vulkan.VkVertexInputAttributeDescription;
import org.lwjgl.vulkan.VkVertexInputBindingDescription;

public class GraphicsPipeline extends Pipeline {
   private final Object2LongMap<PipelineState> graphicsPipelines = new Object2LongOpenHashMap();
   private final VertexFormat vertexFormat;
   private final VertexInputDescription vertexInputDescription;
   private long vertShaderModule = 0L;
   private long fragShaderModule = 0L;

   GraphicsPipeline(Pipeline.Builder builder) {
      super(builder.shaderPath);
      this.buffers = builder.UBOs;
      this.manualUBO = builder.manualUBO;
      this.imageDescriptors = builder.imageDescriptors;
      this.pushConstants = builder.pushConstants;
      this.vertexFormat = builder.vertexFormat;
      this.vertexInputDescription = new VertexInputDescription(this.vertexFormat);
      this.createDescriptorSetLayout();
      this.createPipelineLayout();
      this.createShaderModules(builder.vertShaderSPIRV, builder.fragShaderSPIRV);
      if (builder.renderPass != null) {
         this.graphicsPipelines.computeIfAbsent(PipelineState.DEFAULT, this::createGraphicsPipeline);
      }

      this.createDescriptorSets(Renderer.getFramesNum());
      PIPELINES.add(this);
   }

   public long getHandle(PipelineState state) {
      return this.graphicsPipelines.computeIfAbsent(state, this::createGraphicsPipeline);
   }

   private long createGraphicsPipeline(PipelineState state) {
      MemoryStack stack = MemoryStack.stackPush();

      long var21;
      try {
         ByteBuffer entryPoint = stack.UTF8("main");
         VkPipelineShaderStageCreateInfo.Buffer shaderStages = VkPipelineShaderStageCreateInfo.calloc(2, stack);
         VkPipelineShaderStageCreateInfo vertShaderStageInfo = (VkPipelineShaderStageCreateInfo)shaderStages.get(0);
         vertShaderStageInfo.sType(18);
         vertShaderStageInfo.stage(1);
         vertShaderStageInfo.module(this.vertShaderModule);
         vertShaderStageInfo.pName(entryPoint);
         VkPipelineShaderStageCreateInfo fragShaderStageInfo = (VkPipelineShaderStageCreateInfo)shaderStages.get(1);
         fragShaderStageInfo.sType(18);
         fragShaderStageInfo.stage(16);
         fragShaderStageInfo.module(this.fragShaderModule);
         fragShaderStageInfo.pName(entryPoint);
         VkPipelineVertexInputStateCreateInfo vertexInputInfo = VkPipelineVertexInputStateCreateInfo.calloc(stack);
         vertexInputInfo.sType(19);
         if (this.vertexInputDescription != null) {
            vertexInputInfo.pVertexBindingDescriptions(this.vertexInputDescription.bindingDescriptions);
            vertexInputInfo.pVertexAttributeDescriptions(this.vertexInputDescription.attributeDescriptions);
         }

         int topology = PipelineState.AssemblyRasterState.decodeTopology(state.assemblyRasterState);
         VkPipelineInputAssemblyStateCreateInfo inputAssembly = VkPipelineInputAssemblyStateCreateInfo.calloc(stack);
         inputAssembly.sType(20);
         inputAssembly.topology(topology);
         inputAssembly.primitiveRestartEnable(false);
         VkPipelineViewportStateCreateInfo viewportState = VkPipelineViewportStateCreateInfo.calloc(stack);
         viewportState.sType(22);
         viewportState.viewportCount(1);
         viewportState.scissorCount(1);
         int polygonMode = PipelineState.AssemblyRasterState.decodePolygonMode(state.assemblyRasterState);
         int cullMode = PipelineState.AssemblyRasterState.decodeCullMode(state.assemblyRasterState);
         VkPipelineRasterizationStateCreateInfo rasterizer = VkPipelineRasterizationStateCreateInfo.calloc(stack);
         rasterizer.sType(23);
         rasterizer.depthClampEnable(false);
         rasterizer.rasterizerDiscardEnable(false);
         rasterizer.polygonMode(polygonMode);
         rasterizer.lineWidth(1.0F);
         rasterizer.cullMode(cullMode);
         rasterizer.frontFace(0);
         rasterizer.depthBiasEnable(true);
         VkPipelineMultisampleStateCreateInfo multisampling = VkPipelineMultisampleStateCreateInfo.calloc(stack);
         multisampling.sType(24);
         multisampling.sampleShadingEnable(false);
         multisampling.rasterizationSamples(1);
         VkPipelineDepthStencilStateCreateInfo depthStencil = VkPipelineDepthStencilStateCreateInfo.calloc(stack);
         depthStencil.sType(25);
         depthStencil.depthTestEnable(PipelineState.DepthState.depthTest(state.depthState_i));
         depthStencil.depthWriteEnable(PipelineState.DepthState.depthMask(state.depthState_i));
         depthStencil.depthCompareOp(PipelineState.DepthState.decodeDepthFun(state.depthState_i));
         depthStencil.depthBoundsTestEnable(false);
         depthStencil.minDepthBounds(0.0F);
         depthStencil.maxDepthBounds(1.0F);
         depthStencil.stencilTestEnable(false);
         boolean hasColorAttachment = state.renderPass.hasColorAttachment();
         VkPipelineColorBlendStateCreateInfo colorBlending = VkPipelineColorBlendStateCreateInfo.calloc(stack);
         colorBlending.sType(26);
         colorBlending.logicOpEnable(PipelineState.LogicOpState.enable(state.logicOp_i));
         colorBlending.logicOp(PipelineState.LogicOpState.decodeFun(state.logicOp_i));
         colorBlending.blendConstants(stack.floats(0.0F, 0.0F, 0.0F, 0.0F));
         if (hasColorAttachment) {
            VkPipelineColorBlendAttachmentState.Buffer colorBlendAttachment = VkPipelineColorBlendAttachmentState.calloc(1, stack);
            colorBlendAttachment.colorWriteMask(state.colorMask_i);
            if (PipelineState.BlendState.enable(state.blendState_i)) {
               colorBlendAttachment.blendEnable(true);
               colorBlendAttachment.srcColorBlendFactor(PipelineState.BlendState.getSrcRgbFactor(state.blendState_i));
               colorBlendAttachment.dstColorBlendFactor(PipelineState.BlendState.getDstRgbFactor(state.blendState_i));
               colorBlendAttachment.colorBlendOp(PipelineState.BlendState.blendOp(state.blendState_i));
               colorBlendAttachment.srcAlphaBlendFactor(PipelineState.BlendState.getSrcAlphaFactor(state.blendState_i));
               colorBlendAttachment.dstAlphaBlendFactor(PipelineState.BlendState.getDstAlphaFactor(state.blendState_i));
               colorBlendAttachment.alphaBlendOp(PipelineState.BlendState.blendOp(state.blendState_i));
            } else {
               colorBlendAttachment.blendEnable(false);
            }

            colorBlending.pAttachments(colorBlendAttachment);
         }

         VkPipelineDynamicStateCreateInfo dynamicStates = VkPipelineDynamicStateCreateInfo.calloc(stack);
         dynamicStates.sType(27);
         if (topology != 1 && polygonMode != 1) {
            dynamicStates.pDynamicStates(stack.ints(3, 0, 1));
         } else {
            dynamicStates.pDynamicStates(stack.ints(3, 0, 1, 2));
         }

         VkGraphicsPipelineCreateInfo.Buffer pipelineInfo = VkGraphicsPipelineCreateInfo.calloc(1, stack);
         pipelineInfo.sType(28);
         pipelineInfo.pStages(shaderStages);
         pipelineInfo.pVertexInputState(vertexInputInfo);
         pipelineInfo.pInputAssemblyState(inputAssembly);
         pipelineInfo.pViewportState(viewportState);
         pipelineInfo.pRasterizationState(rasterizer);
         pipelineInfo.pMultisampleState(multisampling);
         pipelineInfo.pDepthStencilState(depthStencil);
         pipelineInfo.pColorBlendState(colorBlending);
         pipelineInfo.pDynamicState(dynamicStates);
         pipelineInfo.layout(this.pipelineLayout);
         pipelineInfo.basePipelineHandle(0L);
         pipelineInfo.basePipelineIndex(-1);
         pipelineInfo.renderPass(state.renderPass.getId());
         pipelineInfo.subpass(0);
         LongBuffer pGraphicsPipeline = stack.mallocLong(1);
         Vulkan.checkResult(VK10.vkCreateGraphicsPipelines(DeviceManager.vkDevice, PIPELINE_CACHE, pipelineInfo, (VkAllocationCallbacks)null, pGraphicsPipeline), "Failed to create graphics pipeline " + this.name);
         var21 = pGraphicsPipeline.get(0);
      } catch (Throwable var24) {
         if (stack != null) {
            try {
               stack.close();
            } catch (Throwable var23) {
               var24.addSuppressed(var23);
            }
         }

         throw var24;
      }

      if (stack != null) {
         stack.close();
      }

      return var21;
   }

   private void createShaderModules(SPIRVUtils.SPIRV vertSpirv, SPIRVUtils.SPIRV fragSpirv) {
      this.vertShaderModule = createShaderModule(vertSpirv.bytecode());
      this.fragShaderModule = createShaderModule(fragSpirv.bytecode());
   }

   public void cleanUp() {
      VK10.vkDestroyShaderModule(DeviceManager.vkDevice, this.vertShaderModule, (VkAllocationCallbacks)null);
      VK10.vkDestroyShaderModule(DeviceManager.vkDevice, this.fragShaderModule, (VkAllocationCallbacks)null);
      this.vertexInputDescription.cleanUp();
      this.destroyDescriptorSets();
      this.graphicsPipelines.forEach((state, pipeline) -> VK10.vkDestroyPipeline(DeviceManager.vkDevice, pipeline, (VkAllocationCallbacks)null));
      this.graphicsPipelines.clear();
      VK10.vkDestroyDescriptorSetLayout(DeviceManager.vkDevice, this.descriptorSetLayout, (VkAllocationCallbacks)null);
      VK10.vkDestroyPipelineLayout(DeviceManager.vkDevice, this.pipelineLayout, (VkAllocationCallbacks)null);
      PIPELINES.remove(this);
      Renderer.getInstance().removeUsedPipeline(this);
   }

   private static VkVertexInputBindingDescription.Buffer getBindingDescription(VertexFormat vertexFormat) {
      VkVertexInputBindingDescription.Buffer bindingDescription = VkVertexInputBindingDescription.calloc(1);
      bindingDescription.binding(0);
      bindingDescription.stride(vertexFormat.getVertexSize());
      bindingDescription.inputRate(0);
      return bindingDescription;
   }

   private static VkVertexInputAttributeDescription.Buffer getAttributeDescriptions(VertexFormat vertexFormat) {
      List<VertexFormatElement> elements = vertexFormat.getElements();
      int size = elements.size();
      VkVertexInputAttributeDescription.Buffer attributeDescriptions = VkVertexInputAttributeDescription.calloc(size);
      int offset = 0;

      for(int i = 0; i < size; ++i) {
         VkVertexInputAttributeDescription posDescription;
         posDescription = (VkVertexInputAttributeDescription)attributeDescriptions.get(i);
         posDescription.binding(0);
         posDescription.location(i);
         VertexFormatElement formatElement = (VertexFormatElement)elements.get(i);
         VertexFormatElement.Usage usage = formatElement.usage();
         VertexFormatElement.Type type = formatElement.type();
         int elementCount = formatElement.count();
         label55:
         switch (usage) {
            case POSITION:
               switch (type) {
                  case FLOAT:
                     posDescription.format(106);
                     posDescription.offset(offset);
                     offset += 12;
                     break label55;
                  case SHORT:
                     posDescription.format(96);
                     posDescription.offset(offset);
                     offset += 8;
                     break label55;
                  case BYTE:
                     posDescription.format(42);
                     posDescription.offset(offset);
                     offset += 4;
                  default:
                     break label55;
               }
            case COLOR:
               switch (type) {
                  case UBYTE:
                     posDescription.format(37);
                     posDescription.offset(offset);
                     offset += 4;
                     break label55;
                  case UINT:
                     posDescription.format(98);
                     posDescription.offset(offset);
                     offset += 4;
                  default:
                     break label55;
               }
            case UV:
               switch (type) {
                  case FLOAT:
                     posDescription.format(103);
                     posDescription.offset(offset);
                     offset += 8;
                     break label55;
                  case SHORT:
                     posDescription.format(82);
                     posDescription.offset(offset);
                     offset += 4;
                  case BYTE:
                  case UBYTE:
                  default:
                     break label55;
                  case UINT:
                     posDescription.format(98);
                     posDescription.offset(offset);
                     offset += 4;
                     break label55;
                  case USHORT:
                     posDescription.format(81);
                     posDescription.offset(offset);
                     offset += 4;
                     break label55;
               }
            case NORMAL:
               posDescription.format(38);
               posDescription.offset(offset);
               offset += 4;
               break;
            case GENERIC:
               if (type == Type.SHORT && elementCount == 1) {
                  posDescription.format(75);
                  posDescription.offset(offset);
                  offset += 2;
                  break;
               }

               if (type != Type.INT || elementCount != 1) {
                  throw new RuntimeException(String.format("Unknown format: %s", usage));
               }

               posDescription.format(99);
               posDescription.offset(offset);
               offset += 4;
               break;
            default:
               throw new RuntimeException(String.format("Unknown format: %s", usage));
         }

         posDescription.offset(((VertexFormatMixed)vertexFormat).getOffset(i));
      }

      return (VkVertexInputAttributeDescription.Buffer)attributeDescriptions.rewind();
   }

   static class VertexInputDescription {
      final VkVertexInputAttributeDescription.Buffer attributeDescriptions;
      final VkVertexInputBindingDescription.Buffer bindingDescriptions;

      VertexInputDescription(VertexFormat vertexFormat) {
         if (vertexFormat != class_290.field_60033) {
            this.bindingDescriptions = GraphicsPipeline.getBindingDescription(vertexFormat);
            this.attributeDescriptions = GraphicsPipeline.getAttributeDescriptions(vertexFormat);
         } else {
            this.bindingDescriptions = null;
            this.attributeDescriptions = null;
         }

      }

      void cleanUp() {
         if (this.bindingDescriptions != null) {
            MemoryUtil.memFree(this.bindingDescriptions);
            MemoryUtil.memFree(this.attributeDescriptions);
         }

      }
   }
}
