diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 000000000..b5dcafc4e --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,164 @@ +# D3D12-Metal Roadmap (DXMT fork) + +## Goal +Get RE4 (Resident Evil 4, AppID 2050650) running on macOS via MetalSharp Wine with a custom D3D12→Metal translation layer. + +## Current Status: Compute Pipeline Working, Graphics Pipeline Needed + +RE4's compute shaders compile and execute correctly. Descriptor resolution works. Command buffers complete (status=4). Game shows ~35 frames of loading screen (pink = uninitialized/cleared RT, expected). Game exits after init when it tries to enter the main rendering loop. + +### What's Working +- D3D12 device creation + ID3D12Device1 QueryInterface +- DXIL→Metal shader compilation via macOS metal-shaderconverter (v3.1.1) +- Compute PSO creation (CS_FastClear, CS_ZeroFill) with correct threadgroup sizes (256x1x1) +- Descriptor heap + stride fix (sizeof(D3D12Descriptor) instead of hardcoded 64) +- Descriptor table resolution: GPU handles → D3D12Descriptor → Metal resources ✓ +- Command list record + replay with all command types +- Swapchain creation + multi-backbuffer (4 buffers) + Present with blit +- Fence signaling with MTLSharedEvent +- Feature support queries (OPTIONS1-12, shader model 6.5, etc.) +- Root signature creation + serialization +- Resource creation (committed + reserved fallback) +- CBV/SRV/UAV/RTV/DSV creation +- MinGW cross-compilation (build-win64.txt + LLVM 15) + +### What's NOT Working +- **RE4 exits after ~35 frames** — game finishes init, can't enter graphics rendering +- **No graphics PSOs created** — game creates graphics root signatures but dies before CreateGraphicsPipelineState +- **ID3D12Device2-9 not implemented** — RE4 queries all of them, gets E_NOINTERFACE +- No audio + +--- + +## Phase 1: Complete Device Interface Chain + Missing APIs [IN PROGRESS] + +RE4 queries ID3D12Device1-9. Currently only Device1 works. 
Need to implement Device2-9 stubs so the game doesn't bail when it can't get a newer device interface. + +### 1a. Add ID3D12Device2-9 GUIDs to d3d12.h +The MinGW headers only define up to ID3D12Device1. Need to add GUIDs for Device2-9 so QueryInterface can match them. These are well-known GUIDs from Microsoft's d3d12.h SDK headers. + +### 1b. Make MTLD3D12Device inherit ID3D12Device9 +Inherit the full chain. Stub all new virtual methods. Key methods to actually implement (non-stub): +- Device2: `CreatePipelineState` (stream-based desc) — may be how RE4 creates graphics PSOs +- Device4: `CreateCommandList1` — creates command lists without PSO arg +- Device4: `CreateCommittedResource1`, `CreatePlacedResource1` — with heap flags +- Device8: `CreateCommittedResource2` — with enhanced desc + +### 1c. Add missing command list interfaces +RE4 may query ID3D12GraphicsCommandList1-6. Check and add stubs. + +### 1d. CopyDescriptors fix +Currently broken — divides increment by sizeof(D3D12Descriptor) giving 0. Fix to use direct pointer arithmetic. + +--- + +## Phase 2: Graphics Pipeline + +Once RE4 stays alive past init, it will try to create graphics PSOs with VS/PS shaders. + +### 2a. Graphics PSO compilation +- CompileShader already handles VS/PS via metal-shaderconverter +- Need to create WMT RenderPipelineState with vertex + fragment functions +- Wire up blend state, rasterizer state, depth stencil, RTV formats, topology + +### 2b. Render pass encoding +- OMSetRenderTargets → open render encoder with correct attachments +- ClearRenderTargetView → render encoder clear +- ClearDepthStencilView → render encoder clear +- SetPipelineState → render encoder setRenderPipelineState + +### 2c. Draw call replay +- DrawInstanced → WMT render draw command +- DrawIndexedInstanced → WMT render draw indexed command +- IASetVertexBuffers → set vertex buffer offsets/strides +- IASetIndexBuffer → set index buffer +- IASetPrimitiveTopology → triangle/line/point list + +### 2d. 
Root signature binding for graphics +- Graphics root constants → setVertexBytes / setFragmentBytes +- Graphics root CBVs → setVertexBuffer / setFragmentBuffer +- Graphics root descriptor tables → resolve descriptors, bind textures/buffers + +--- + +## Phase 3: Polish +- Per-game DLL routing +- Performance tuning (remove waitUntilCompleted, use async fence) +- Support other DX12 games +- Shader cache warmup (pre-compile all shaders from game data) + +--- + +## Build & Deploy + +### Build (MinGW cross-compile) +```bash +cd /tmp/dxmt-src +rm -f build/src/d3d12/d3d12.dll +ninja -C build src/d3d12/d3d12.dll +``` + +### Deploy +```bash +cp build/src/d3d12/d3d12.dll ~/.metalsharp/runtime/wine/lib/wine/x86_64-windows/d3d12.dll +cp build/src/d3d12/d3d12.dll ~/.metalsharp/prefix-steam/drive_c/windows/system32/d3d12.dll +``` + +### Launch RE4 +```bash +# Kill everything first +kill -9 $(ps aux | grep -iE 'wine|steam|re4|cef|winedevice|steamservice|steamwebhelper' | grep -v grep | grep -v ipcserver | awk '{print $2}') + +# Launch +rm -f /tmp/dxmt_dxgi_trace.log +WINEPREFIX=~/.metalsharp/prefix-steam WINEDEBUG=-all \ + nohup ~/.metalsharp/runtime/wine/bin/wine \ + ~/.metalsharp/prefix-steam/drive_c/Program\ Files\ \(x86\)/Steam/steamapps/common/RESIDENT\ EVIL\ 4\ \ BIOHAZARD\ RE4/re4.exe &>/dev/null & +``` + +### Key Paths +- Source: `/tmp/dxmt-src/` → symlinks to `/Volumes/AverySSD/metalsharp/dxmt-src/` +- Shader cache: `/tmp/dxmt_shader_cache/` +- Trace file: `/tmp/dxmt_dxgi_trace.log` +- Cross file: `build-win64.txt` +- LLVM 15: `/opt/homebrew/opt/llvm@15` +- Fake winebuild: `~/.metalsharp/runtime/wine/bin/winebuild` + +### RE4 Command Profile (per frame during init) +- 8 Dispatch (compute only) +- 1 SetGraphicsRoot32BitConstants +- 14 SetGraphicsRootDescriptorTable +- 1 OMSetRenderTargets +- 1 OMSetStencilRef +- 4 ClearDepthStencilView +- 1 ClearRenderTargetView +- 2 SetPipelineState +- ResourceBarrier (no-op) + +### Device Interface GUIDs (from Microsoft d3d12.h) +``` 
+ID3D12Device = 189819f1-1db6-4b57-be54-1821339b85f7 +ID3D12Device1 = 77acce80-638e-4e65-8895-c1f23386863e +ID3D12Device2 = 30baa41e-b15b-475c-a0bb-1af5c5b64328 ← RE4 queries this +ID3D12Device3 = 81dadc15-2bad-4392-93c5-101345c4aa98 +ID3D12Device4 = e865df17-a9ee-46f9-a463-3098315aa2e5 +ID3D12Device5 = 8b4f173b-2fea-4b80-8f58-4307191ab95d +ID3D12Device6 = c70b221b-40e4-4a17-89af-025a0727a6dc +ID3D12Device7 = 5c014b53-68a1-4b9b-8bd1-dd6046b9358b +ID3D12Device8 = 9218e6bb-f944-4f7e-a75c-b1b2c7b701f3 +ID3D12Device9 = 4c80e962-f032-4f60-bc9e-ebc2cfa1d83c +ID3D10Device = 9b7e4c0f-342c-4106-a19f-4f2704f689f0 (seen in the trace; not part of the ID3D12Device chain) +``` + +Note: not every GUID RE4 queries belongs to the ID3D12Device chain. Matching the trace prefixes against the SDK headers: +- 74eaee3f = unidentified (full GUID 74eaee3f-2f4b-476d-82ba-2b85cb49e310; does not match Device9 above) +- 4c80e962 = ID3D12Device9 +- 9218e6bb = ID3D12Device8 +- c70b221b = ID3D12Device6 +- 8b4f173b = ID3D12Device5 +- e865df17 = ID3D12Device4 +- 81dadc15 = ID3D12Device3 +- 30baa41e = ID3D12Device2 +- db6f6ddb = likely ID3D11Device (db6f6ddb-ac77-4e88-8253-819df9bbf140) +- 9b7e4c0f = ID3D10Device +- 54ec77fa = likely IDXGIDevice (54ec77fa-1377-44e6-8c32-88fd5f44c84c) diff --git a/compile_shaders.sh b/compile_shaders.sh new file mode 100755 index 000000000..5b844f792 --- /dev/null +++ b/compile_shaders.sh @@ -0,0 +1,14 @@ +#!/bin/bash +echo "Pre-compiling shaders from /tmp/dxmt_shader_cache/..." +cd /tmp/dxmt_shader_cache +for f in *.dxbc; do + [ -f "$f" ] || continue + base="${f%.dxbc}" + if [ ! -f "${base}.metallib" ]; then + /usr/local/bin/metal-shaderconverter -o "${base}.metallib" "$f" --output-reflection-file="${base}.json" 2>/dev/null + echo " Compiled $f -> ${base}.metallib" + else + echo " Already cached: ${base}.metallib" + fi +done +echo "Done. Shaders ready for RE4." 
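Review note on compile_shaders.sh: the loop prints "Compiled" even when the converter fails, and it keeps running in the wrong directory if the `cd` fails. A minimal sketch of a stricter variant follows. The `compile_cache` function name and its two positional overrides (cache directory, converter path) are hypothetical, added only so the loop can be exercised without the real tool; the converter invocation itself is copied unchanged from the script.

```bash
#!/bin/bash
# Sketch only: compile_cache and its two optional arguments are assumptions,
# not part of the patch. Defaults match the paths used in compile_shaders.sh.
# The body runs in a subshell so the cd does not leak into the caller.
compile_cache() (
  cache_dir="${1:-/tmp/dxmt_shader_cache}"
  converter="${2:-/usr/local/bin/metal-shaderconverter}"
  failed=0

  # Stop early if the cache directory is missing instead of globbing elsewhere.
  cd "$cache_dir" || exit 1

  for f in *.dxbc; do
    [ -f "$f" ] || continue
    base="${f%.dxbc}"
    if [ -f "${base}.metallib" ]; then
      echo "  Already cached: ${base}.metallib"
    elif "$converter" -o "${base}.metallib" "$f" \
           --output-reflection-file="${base}.json" 2>/dev/null; then
      echo "  Compiled $f -> ${base}.metallib"
    else
      # Report the failure and drop any partial output so a later run retries it.
      echo "  FAILED $f" >&2
      rm -f "${base}.metallib"
      failed=$((failed + 1))
    fi
  done

  # Non-zero exit status when at least one shader did not convert.
  [ "$failed" -eq 0 ]
)
```

Calling `compile_cache` with no arguments uses the same paths as the script; a non-zero exit status means the cache directory was missing or at least one shader did not convert.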
diff --git a/cross-mingw.txt b/cross-mingw.txt new file mode 100644 index 000000000..bc1af71a4 --- /dev/null +++ b/cross-mingw.txt @@ -0,0 +1,15 @@ +[binaries] +c = 'x86_64-w64-mingw32-gcc' +cpp = 'x86_64-w64-mingw32-g++' +ar = 'x86_64-w64-mingw32-ar' +strip = 'x86_64-w64-mingw32-strip' +windres = 'x86_64-w64-mingw32-windres' + +[built-in options] +cpp_args = ['-std=c++20'] + +[host_machine] +system = 'windows' +cpu_family = 'x86_64' +cpu = 'x86_64' +endian = 'little' diff --git a/fake-winebuild.sh b/fake-winebuild.sh new file mode 100755 index 000000000..e8b1945d7 --- /dev/null +++ b/fake-winebuild.sh @@ -0,0 +1,4 @@ +#!/bin/bash +# Fake winebuild -- just copy the input to output for postproc +cp "$2" "$2.postproc" 2>/dev/null +exit 0 diff --git a/meson.build b/meson.build index e10c90191..9aa0c7365 100644 --- a/meson.build +++ b/meson.build @@ -47,6 +47,11 @@ compiler_args = [ '-Wno-extern-c-compat', '-Wno-unused-const-variable', '-Wno-missing-braces', + '-Wno-inconsistent-missing-override', + '-Wno-overriding-method-mismatch', + '-Wno-error=non-virtual-dtor', + '-Wno-covered-switch-default', + '-fpermissive', '-fblocks', '-fmacro-prefix-map=' + source_prefix_path + '/=', ] diff --git a/src/airconv/darwin/meson.build b/src/airconv/darwin/meson.build index f3e39ecd7..224853de5 100644 --- a/src/airconv/darwin/meson.build +++ b/src/airconv/darwin/meson.build @@ -14,31 +14,31 @@ llvm_ld_flags_darwin = [ '-lm', '-lz', '-lcurses', '-lxml2' ] -if native_llvm_path.startswith('/usr/local/opt') - llvm_ld_flags_darwin = [ - llvm_ld_flags_darwin, - '/usr/local/opt/zstd/lib/libzstd.a', '/usr/local/opt/llvm@15/lib/libunwind.a' - ] +llvm_extra_libs = [] +if native_llvm_path.startswith('/usr/local/opt') or native_llvm_path.startswith('/opt/homebrew/opt') + llvm_extra_libs = ['/tmp/zstd-1.5.7/lib/libzstd.a', join_paths(native_llvm_path, 'lib/libunwind.a')] +elif native_llvm_path.startswith('/tmp/llvm15-build') + llvm_extra_libs = ['/tmp/zstd-1.5.7/lib/libzstd.a'] endif 
airconv_lib_darwin = static_library('airconv', airconv_src, include_directories : [ dxmt_include_path, llvm_include_path_darwin ], cpp_args : [ airconv_args ], dependencies : [ dxbc_parser_native_dep ], - link_args : llvm_ld_flags_darwin, + link_args : [ llvm_ld_flags_darwin, llvm_extra_libs ], native : dxmt_crossbuild ) airconv_dep_darwin = declare_dependency( link_with : [ airconv_lib_darwin ], include_directories : [ dxmt_include_path, include_directories('..') ], - link_args : [ llvm_ld_flags_darwin, llvm_deps ] # meh + link_args : [ llvm_ld_flags_darwin, llvm_extra_libs, llvm_deps ] ) executable('airconv', airconv_src + airconv_cli_src, include_directories : [ dxmt_include_path, llvm_include_path_darwin ], cpp_args : [ airconv_args ], dependencies : [ dxbc_parser_native_dep ], - link_args : [ llvm_ld_flags_darwin, llvm_deps ], + link_args : [ llvm_ld_flags_darwin, llvm_extra_libs, llvm_deps ], native : dxmt_crossbuild ) \ No newline at end of file diff --git a/src/airconv/dxil/dxil_container.cpp b/src/airconv/dxil/dxil_container.cpp new file mode 100644 index 000000000..8684450c6 --- /dev/null +++ b/src/airconv/dxil/dxil_container.cpp @@ -0,0 +1,74 @@ +#include "dxil_container.hpp" +#include <cstdio> + +namespace dxmt::dxil { + +std::optional<DXILContainer> DXILContainer::parse(const void *data, size_t size) { + if (!data || size < 24) + return std::nullopt; + + auto *base = static_cast<const uint8_t *>(data); + + const uint32_t *vals = reinterpret_cast<const uint32_t *>(base); + uint32_t program_version = vals[0]; + uint32_t prog_size = vals[1]; + + const uint32_t *dxil_fields = vals + 2; + uint32_t dxil_magic = dxil_fields[0]; + + if (dxil_magic != DXIL_FOURCC) + return std::nullopt; + + uint16_t dxil_minor = *reinterpret_cast<const uint16_t *>(base + 12); + uint16_t dxil_major = *reinterpret_cast<const uint16_t *>(base + 14); + (void)dxil_major; + (void)dxil_minor; + + uint32_t bitcode_offset = *reinterpret_cast<const uint32_t *>(base + 16); + uint32_t bitcode_size = *reinterpret_cast<const uint32_t *>(base + 20); + + FILE *_dbg = fopen("Z:\\tmp\\dxmt_dxil_trace.log", "a"); + if (_dbg) 
{ + fprintf(_dbg, "DXILContainer: ver=0x%08x prog_size=%u dxil_magic=0x%08x bc_off=%u bc_sz=%u blob_size=%zu\n", + program_version, prog_size, dxil_magic, bitcode_offset, bitcode_size, size); + fclose(_dbg); + } + + uint32_t dxil_magic_offset = 8; + uint32_t actual_bitcode_start = dxil_magic_offset + bitcode_offset; + + uint32_t kind_val = (program_version >> 16) & 0xFFFF; + DxilShaderKind kind = static_cast<DxilShaderKind>(kind_val); + + DxilShaderModel sm; + sm.major = (program_version >> 4) & 0xF; + sm.minor = program_version & 0xF; + + if (actual_bitcode_start >= size) + return std::nullopt; + + if (bitcode_size == 0 || bitcode_size > size - actual_bitcode_start) + bitcode_size = size - actual_bitcode_start; + + const uint8_t *bitcode_ptr = base + actual_bitcode_start; + + DXILContainer result; + result.m_shader.kind = kind; + result.m_shader.shader_model = sm; + result.m_shader.bitcode.data = bitcode_ptr; + result.m_shader.bitcode.size = bitcode_size; + + switch (kind) { + case DxilShaderKind::Compute: result.m_shader.entry_point = "cs_main"; break; + case DxilShaderKind::Vertex: result.m_shader.entry_point = "vs_main"; break; + case DxilShaderKind::Pixel: result.m_shader.entry_point = "ps_main"; break; + case DxilShaderKind::Geometry: result.m_shader.entry_point = "gs_main"; break; + case DxilShaderKind::Hull: result.m_shader.entry_point = "hs_main"; break; + case DxilShaderKind::Domain: result.m_shader.entry_point = "ds_main"; break; + default: result.m_shader.entry_point = "main"; break; + } + + return result; +} + +} diff --git a/src/airconv/dxil/dxil_container.hpp b/src/airconv/dxil/dxil_container.hpp new file mode 100644 index 000000000..f38e22168 --- /dev/null +++ b/src/airconv/dxil/dxil_container.hpp @@ -0,0 +1,83 @@ +#pragma once + +#include <cstddef> +#include <cstdint> +#include <optional> +#include <string> +#include <vector> + +namespace dxmt::dxil { + +static constexpr uint32_t DXIL_FOURCC = + (uint32_t)'D' | ((uint32_t)'X' << 8) | ((uint32_t)'I' << 16) | ((uint32_t)'L' << 24); + +static constexpr uint32_t 
DXBC_FOURCC = + (uint32_t)'D' | ((uint32_t)'X' << 8) | ((uint32_t)'B' << 16) | ((uint32_t)'C' << 24); + +struct DXBCChunkHeader { + uint32_t fourCC; + uint32_t size; +}; + +struct DxilProgramHeader { + uint32_t version; + uint32_t size; + uint16_t dxil_major; + uint16_t dxil_minor; +}; + +struct DxilShaderModel { + uint8_t major; + uint8_t minor; +}; + +enum class DxilShaderKind : uint32_t { + Pixel = 0, + Vertex = 1, + Geometry = 2, + Hull = 3, + Domain = 4, + Compute = 5, + Library = 6, + RayGeneration = 7, + Intersection = 8, + AnyHit = 9, + ClosestHit = 10, + Miss = 11, + Callable = 12, + Mesh = 13, + Amplification = 14, + Invalid = 0xFFFFFFFF, +}; + +struct DxilHeader { + DxilProgramHeader program; + uint32_t dxil_version; + uint32_t bitcode_offset; +}; + +struct DxilBitcodeRef { + const uint8_t *data; + uint32_t size; +}; + +struct DxilParsedShader { + DxilShaderKind kind; + DxilShaderModel shader_model; + DxilBitcodeRef bitcode; + std::string entry_point; +}; + +class DXILContainer { +public: + static std::optional<DXILContainer> parse(const void *data, size_t size); + + const DxilParsedShader &shader() const { return m_shader; } + +private: + DXILContainer() = default; + DxilParsedShader m_shader; + std::vector<uint8_t> m_storage; +}; + +} diff --git a/src/airconv/dxil/dxil_to_msl.cpp b/src/airconv/dxil/dxil_to_msl.cpp new file mode 100644 index 000000000..b30286fab --- /dev/null +++ b/src/airconv/dxil/dxil_to_msl.cpp @@ -0,0 +1,987 @@ +#include "dxil_to_msl.hpp" +#include +#include +#include +#include +#include + +#define DXTRACE(fmt, ...) 
do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxil_trace.log", "a"); if (_tf) { fprintf(_tf, fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) + +namespace dxmt::dxil { + +enum DXIntrinsicOpcode { + DXOP_LoadInput = 4, + DXOP_StoreOutput = 5, + DXOP_CreateHandle = 57, + DXOP_CBufferLoadLegacy = 59, + DXOP_ThreadId = 93, + DXOP_GroupId = 94, + DXOP_ThreadIDInGroup = 95, + DXOP_FlattenedThreadIDInGroup = 96, + DXOP_BufferLoad = 68, + DXOP_BufferStore = 69, + DXOP_TextureLoad = 66, + DXOP_TextureStore = 67, + DXOP_TextureGather = 73, + DXOP_TextureSample = 60, + DXOP_TextureSampleCmp = 64, + DXOP_Barrier = 80, + DXOP_Unary = 13, + DXOP_Binary = 14, + DXOP_Tertiary = 15, + DXOP_Dot2 = 54, + DXOP_Dot3 = 55, + DXOP_Dot4 = 56, + DXOP_MakeDouble = 101, + DXOP_SplitDouble = 102, + DXOP_RawBufferLoad = 139, + DXOP_RawBufferStore = 140, + DXOP_AtomicBinOp = 78, + DXOP_AtomicCompareExchange = 79, + DXOP_DerivCoarseX = 83, + DXOP_DerivCoarseY = 84, + DXOP_DerivFineX = 85, + DXOP_DerivFineY = 86, + DXOP_CalcLOD = 81, + DXOP_Texture2DMSGetSamplePosition = 75, + DXOPRenderTargetGetSamplePosition = 76, + DXOP_NumPrimitives = 109, + DXOP_NumOutputVertices = 110, +}; + +static const char *kMetalHeader = R"(#include <metal_stdlib> +using namespace metal; + +)"; + +static std::string escapeName(const std::string &s) { + if (s.empty()) return "_"; + std::string r; + for (char c : s) { + if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || c == '_') + r += c; + else + r += '_'; + } + if (!r.empty() && r[0] >= '0' && r[0] <= '9') + r = "_" + r; + return r; +} + +std::string DXILToMSL::getTypeName(const LLVMType &t, const LLVMModule &mod) { + switch (t.kind) { + case LLVMType::Void: return "void"; + case LLVMType::Float: return "float"; + case LLVMType::Double: return "float64_t"; + case LLVMType::Integer: + if (t.bit_width == 1) return "bool"; + if (t.bit_width == 8) return "char"; + if (t.bit_width == 16) return "short"; + if (t.bit_width == 32) return "int"; + if 
(t.bit_width == 64) return "long"; + return "int"; + case LLVMType::Pointer: return "device char*"; + case LLVMType::Struct: return "char" + std::to_string((uint64_t)&t % 997); + case LLVMType::Array: return "array"; + case LLVMType::Vector: { + if (t.subtypes.empty()) + return "float4"; + return getTypeName(t.subtypes[0], mod) + std::to_string(t.bit_width); + } + case LLVMType::Function: return "void"; + } + return "int"; +} + +std::string DXILToMSL::getVectorTypeName(const LLVMType &elem, uint32_t count, const LLVMModule &mod) { + return getTypeName(elem, mod) + std::to_string(count); +} + +uint32_t DXILToMSL::getTypeSize(const LLVMType &t, const LLVMModule &mod) { + switch (t.kind) { + case LLVMType::Void: return 0; + case LLVMType::Float: return 4; + case LLVMType::Double: return 8; + case LLVMType::Integer: return (t.bit_width + 7) / 8; + case LLVMType::Pointer: return 8; + case LLVMType::Struct: { + uint32_t s = 0; + for (auto &st : t.subtypes) + s += getTypeSize(st, mod); + return s; + } + case LLVMType::Array: return t.bit_width * (t.subtypes.empty() ? 4 : getTypeSize(t.subtypes[0], mod)); + case LLVMType::Vector: return t.bit_width * 4; + case LLVMType::Function: return 0; + } + return 4; +} + +std::string DXILToMSL::emitValue(uint32_t idx) { + if (idx == 0xFFFFFFFF) return "undef"; + return "v" + std::to_string(idx); +} + +std::string DXILToMSL::emitConstant(const std::vector &ops, uint32_t type_id, const LLVMModule &mod) { + if (type_id >= mod.types.size()) + return "0"; + auto &t = mod.types[type_id]; + if (ops.empty()) + return "0"; + switch (t.kind) { + case LLVMType::Integer: + if (t.bit_width == 1) return ops[0] ? 
"true" : "false"; + if (t.bit_width <= 32) return std::to_string((int32_t)ops[0]); + return std::to_string((int64_t)ops[0]); + case LLVMType::Float: { + float f; + uint32_t u = (uint32_t)ops[0]; + memcpy(&f, &u, 4); + char buf[64]; + snprintf(buf, sizeof(buf), "%.9g", (double)f); + if (!strchr(buf, '.') && !strchr(buf, 'e') && !strchr(buf, 'E')) + strcat(buf, ".0"); + return std::string(buf) + "f"; + } + case LLVMType::Double: { + double d; + uint64_t u = ops[0]; + memcpy(&d, &u, 8); + char buf[64]; + snprintf(buf, sizeof(buf), "%.17g", d); + return std::string(buf); + } + default: + return "0"; + } +} + +void DXILToMSL::emitBindings(EmitContext &ctx) { + auto &os = ctx.os; + + if (ctx.shader.kind == DxilShaderKind::Compute) { + os << " uint3 dtid [[thread_position_in_grid]];\n"; + os << " uint3 gtid [[thread_position_in_threadgroup]];\n"; + os << " uint3 ggid [[threadgroup_position_in_grid]];\n"; + os << " uint3 gsz [[threads_per_threadgroup]];\n"; + ctx.uses_thread_id = true; + ctx.uses_group_id = true; + ctx.uses_group_thread_id = true; + ctx.uses_group_size = true; + } + + os << "\n"; +} + +void DXILToMSL::emitFunctionPrologue(EmitContext &ctx) { + auto &os = ctx.os; + os << kMetalHeader; + + os << "struct input_v {\n"; + os << " float4 position [[position]];\n"; + os << " float4 v0;\n float4 v1;\n float4 v2;\n float4 v3;\n"; + os << " float4 v4;\n float4 v5;\n float4 v6;\n float4 v7;\n"; + os << " float2 uv0; float2 uv1; float2 uv2; float2 uv3;\n"; + os << " float4 color0;\n float4 color1;\n float4 color2;\n float4 color3;\n"; + os << "};\n\n"; + + os << "struct output_v {\n"; + os << " float4 position [[position]];\n"; + os << " float4 v0; float4 v1; float4 v2; float4 v3;\n"; + os << " float2 uv0 [[user(locn0)]]; float2 uv1 [[user(locn1)]];\n"; + os << " float2 uv2 [[user(locn2)]]; float2 uv3 [[user(locn3)]];\n"; + os << " float4 color0 [[color(0)]]; float4 color1 [[color(1)]];\n"; + os << " float4 color2 [[color(2)]]; float4 color3 [[color(3)]];\n"; + os << 
"};\n\n"; + + if (ctx.shader.kind == DxilShaderKind::Compute) { + os << "kernel void cs_main(\n"; + os << " device char* buf0 [[buffer(0)]],\n"; + os << " device char* buf1 [[buffer(1)]],\n"; + os << " device char* buf2 [[buffer(2)]],\n"; + os << " device char* buf3 [[buffer(3)]],\n"; + os << " device char* buf4 [[buffer(4)]],\n"; + os << " device char* buf5 [[buffer(5)]],\n"; + os << " device char* buf6 [[buffer(6)]],\n"; + os << " device char* buf7 [[buffer(7)]],\n"; + os << " texture2d tex0 [[texture(0)]],\n"; + os << " texture2d tex1 [[texture(1)]],\n"; + os << " texture2d tex2 [[texture(2)]],\n"; + os << " texture2d tex3 [[texture(3)]],\n"; + os << " texture2d tex4 [[texture(4)]],\n"; + os << " texture2d tex5 [[texture(5)]],\n"; + os << " texture2d tex6 [[texture(6)]],\n"; + os << " texture2d tex7 [[texture(7)]],\n"; + os << " sampler samp0 [[sampler(0)]],\n"; + os << " sampler samp1 [[sampler(1)]],\n"; + os << " sampler samp2 [[sampler(2)]],\n"; + os << " sampler samp3 [[sampler(3)]],\n"; + os << " uint3 dtid [[thread_position_in_grid]],\n"; + os << " uint3 gtid [[thread_position_in_threadgroup]],\n"; + os << " uint3 ggid [[threadgroup_position_in_grid]],\n"; + os << " uint3 gsz [[threads_per_threadgroup]]\n"; + os << ") {\n"; + } else if (ctx.shader.kind == DxilShaderKind::Vertex) { + os << "vertex output_v vs_main(\n"; + os << " uint vid [[vertex_id]],\n"; + os << " device char* buf0 [[buffer(0)]],\n"; + os << " device char* buf1 [[buffer(1)]],\n"; + os << " device char* buf2 [[buffer(2)]],\n"; + os << " device char* buf3 [[buffer(3)]],\n"; + os << " device char* buf4 [[buffer(4)]],\n"; + os << " device char* buf5 [[buffer(5)]],\n"; + os << " device char* buf6 [[buffer(6)]],\n"; + os << " device char* buf7 [[buffer(7)]]\n"; + os << ") {\n"; + os << " output_v out = {};\n"; + } else if (ctx.shader.kind == DxilShaderKind::Pixel) { + os << "fragment float4 ps_main(\n"; + os << " input_v in [[stage_in]],\n"; + os << " device char* buf0 [[buffer(0)]],\n"; + os 
<< " device char* buf1 [[buffer(1)]],\n"; + os << " device char* buf2 [[buffer(2)]],\n"; + os << " device char* buf3 [[buffer(3)]],\n"; + os << " device char* buf4 [[buffer(4)]],\n"; + os << " device char* buf5 [[buffer(5)]],\n"; + os << " device char* buf6 [[buffer(6)]],\n"; + os << " device char* buf7 [[buffer(7)]],\n"; + os << " texture2d tex0 [[texture(0)]],\n"; + os << " texture2d tex1 [[texture(1)]],\n"; + os << " texture2d tex2 [[texture(2)]],\n"; + os << " texture2d tex3 [[texture(3)]],\n"; + os << " texture2d tex4 [[texture(4)]],\n"; + os << " texture2d tex5 [[texture(5)]],\n"; + os << " texture2d tex6 [[texture(6)]],\n"; + os << " texture2d tex7 [[texture(7)]],\n"; + os << " sampler samp0 [[sampler(0)]],\n"; + os << " sampler samp1 [[sampler(1)]],\n"; + os << " sampler samp2 [[sampler(2)]],\n"; + os << " sampler samp3 [[sampler(3)]]\n"; + os << ") {\n"; + os << " float4 result = float4(0,0,0,1);\n"; + } else { + os << "kernel void unknown_main() {\n"; + } +} + +std::string DXILToMSL::translateDXIntrinsic(EmitContext &ctx, uint32_t intrinsic_id, + const std::vector &args) { + auto &os = ctx.os; + + switch (intrinsic_id) { + case DXOP_CreateHandle: { + if (args.size() < 5) return "0"; + uint32_t resource_class = args[1]; + uint32_t range_id = args[2]; + uint32_t index = args[3]; + bool non_uniform = args[4] != 0; + (void)non_uniform; + uint32_t handle_id = ctx.next_binding++; + std::string res_name; + if (resource_class == 0) { + res_name = "buf" + std::to_string(range_id); + } else if (resource_class == 1) { + res_name = "samp" + std::to_string(range_id); + } else if (resource_class == 2) { + res_name = "tex" + std::to_string(range_id); + } else if (resource_class == 3) { + res_name = "buf" + std::to_string(range_id); + } else { + res_name = "buf" + std::to_string(range_id); + } + DXTRACE("DXIL CreateHandle: class=%u range=%u index=%u -> %s", resource_class, range_id, index, res_name.c_str()); + return res_name; + } + + case DXOP_ThreadId: { + 
ctx.uses_thread_id = true; + if (!args.empty()) { + uint32_t component = args[0]; + if (component == 0) return "(int)dtid.x"; + if (component == 1) return "(int)dtid.y"; + if (component == 2) return "(int)dtid.z"; + } + return "(int)dtid.x"; + } + + case DXOP_GroupId: { + ctx.uses_group_id = true; + if (!args.empty()) { + uint32_t component = args[0]; + if (component == 0) return "(int)ggid.x"; + if (component == 1) return "(int)ggid.y"; + if (component == 2) return "(int)ggid.z"; + } + return "(int)ggid.x"; + } + + case DXOP_ThreadIDInGroup: { + ctx.uses_group_thread_id = true; + if (!args.empty()) { + uint32_t component = args[0]; + if (component == 0) return "(int)gtid.x"; + if (component == 1) return "(int)gtid.y"; + if (component == 2) return "(int)gtid.z"; + } + return "(int)gtid.x"; + } + + case DXOP_CBufferLoadLegacy: { + if (args.size() < 2) return "float4(0)"; + auto handle = args[0] < ctx.value_table.size() ? ctx.value_table[args[0]] : "buf0"; + auto reg_idx = args[1] < ctx.value_table.size() ? ctx.value_table[args[1]] : "0"; + return "(reinterpret_cast(" + handle + "[(" + reg_idx + ")*64]))"; + } + + case DXOP_BufferLoad: { + if (args.size() < 3) return "float4(0)"; + auto handle = args[0] < ctx.value_table.size() ? ctx.value_table[args[0]] : "buf0"; + auto index = args[1] < ctx.value_table.size() ? ctx.value_table[args[1]] : "0"; + auto w = args.size() > 2 && args[2] < ctx.value_table.size() ? ctx.value_table[args[2]] : "0"; + return "(reinterpret_cast(" + handle + "[(" + index + ")*16]))"; + } + + case DXOP_TextureLoad: { + if (args.size() < 3) return "float4(0)"; + auto handle = args[0] < ctx.value_table.size() ? ctx.value_table[args[0]] : "tex0"; + auto coord = args[2] < ctx.value_table.size() ? ctx.value_table[args[2]] : "int2(0)"; + return handle + ".read(" + coord + ")"; + } + + case DXOP_TextureSample: { + if (args.size() < 4) return "float4(0)"; + auto handle = args[0] < ctx.value_table.size() ? 
ctx.value_table[args[0]] : "tex0"; + auto sampler = args.size() > 1 && args[1] < ctx.value_table.size() ? ctx.value_table[args[1]] : "samp0"; + auto coord = args[2] < ctx.value_table.size() ? ctx.value_table[args[2]] : "float2(0)"; + return handle + ".sample(" + sampler + ", " + coord + ")"; + } + + case DXOP_Barrier: { + return "threadgroup_barrier(mem_flags::mem_threadgroup)"; + } + + case DXOP_Dot2: { + if (args.size() < 3) return "0.0"; + auto a = args[1] < ctx.value_table.size() ? ctx.value_table[args[1]] : "float2(0)"; + auto b = args[2] < ctx.value_table.size() ? ctx.value_table[args[2]] : "float2(0)"; + return "dot(" + a + ", " + b + ")"; + } + + case DXOP_Dot3: { + if (args.size() < 3) return "0.0"; + auto a = args[1] < ctx.value_table.size() ? ctx.value_table[args[1]] : "float3(0)"; + auto b = args[2] < ctx.value_table.size() ? ctx.value_table[args[2]] : "float3(0)"; + return "dot(" + a + ", " + b + ")"; + } + + case DXOP_Dot4: { + if (args.size() < 3) return "0.0"; + auto a = args[1] < ctx.value_table.size() ? ctx.value_table[args[1]] : "float4(0)"; + auto b = args[2] < ctx.value_table.size() ? ctx.value_table[args[2]] : "float4(0)"; + return "dot(" + a + ", " + b + ")"; + } + + case DXOP_LoadInput: { + if (args.size() < 4) return "float4(0)"; + uint32_t input_id = args[1]; + uint32_t component = args.size() > 3 ? args[3] : 0; + if (ctx.shader.kind == DxilShaderKind::Pixel) { + switch (input_id) { + case 0: return "in.position"; + case 1: return "in.v0"; + case 2: return "in.v1"; + case 3: return "in.v2"; + case 4: return "in.v3"; + default: return "in.v0"; + } + } + return "float4(0)"; + } + + case DXOP_StoreOutput: { + if (args.size() < 4) return ""; + uint32_t output_id = args[1]; + uint32_t component = args.size() > 2 ? args[2] : 0; + auto val = args[3] < ctx.value_table.size() ? 
ctx.value_table[args[3]] : "float4(0)"; + + if (ctx.shader.kind == DxilShaderKind::Vertex) { + switch (output_id) { + case 0: return "out.position = " + val; + case 1: return "out.v0 = " + val; + case 2: return "out.v1 = " + val; + case 3: return "out.v2 = " + val; + default: return "out.v0 = " + val; + } + } + return "result = " + val; + } + + default: + DXTRACE("DXIL unknown intrinsic: %u", intrinsic_id); + break; + } + + return "0 /* unknown dx intrinsic " + std::to_string(intrinsic_id) + " */"; +} + +void DXILToMSL::emitInstruction(EmitContext &ctx, const LLVMInstruction &inst, uint32_t &value_counter) { + auto &os = ctx.os; + std::string result = emitValue(value_counter); + + auto getValue = [&](uint32_t idx) -> std::string { + if (idx < ctx.value_table.size() && !ctx.value_table[idx].empty()) + return ctx.value_table[idx]; + return emitValue(idx); + }; + + auto ensureValueTable = [&](uint32_t needed) { + if (ctx.value_table.size() <= needed) + ctx.value_table.resize(needed + 1); + }; + + switch (inst.opcode) { + case LLVMInstruction::Ret: + if (ctx.shader.kind == DxilShaderKind::Vertex) { + os << " return out;\n"; + } else if (ctx.shader.kind == DxilShaderKind::Pixel) { + os << " return result;\n"; + } else { + os << " return;\n"; + } + break; + + case LLVMInstruction::Call: { + if (inst.operands.empty()) + break; + uint32_t callee = inst.operands[0]; + + std::vector<uint32_t> call_args; + for (size_t i = 2; i < inst.operands.size(); i++) + call_args.push_back(inst.operands[i]); + + if (callee < ctx.value_table.size() && ctx.value_table[callee].substr(0, 5) == "dx.op") { + uint32_t intrinsic_id = 0; + if (call_args.size() > 0) { + std::string id_str = getValue(call_args[0]); + intrinsic_id = (uint32_t)std::stoi(id_str); + } + std::vector<uint32_t> remaining_args(call_args.empty() ? call_args.end() : call_args.begin() + 1, call_args.end()); + + std::string translated = translateDXIntrinsic(ctx, intrinsic_id, remaining_args); + + if (inst.type_id != 0 || translated.find('=') == std::string::npos) { + 
ensureValueTable(value_counter); + if (!translated.empty() && translated[0] != ' ') { + os << " " << result << " = " << translated << ";\n"; + ctx.value_table[value_counter] = result; + } else if (!translated.empty()) { + os << " " << translated << ";\n"; + } + } else { + os << " " << translated << ";\n"; + } + } else { + os << " // call " << getValue(callee) << "("; + for (size_t i = 0; i < call_args.size(); i++) { + if (i) os << ", "; + os << getValue(call_args[i]); + } + os << ")\n"; + ensureValueTable(value_counter); + ctx.value_table[value_counter] = result; + } + value_counter++; + break; + } + + case LLVMInstruction::Add: { + ensureValueTable(value_counter); + os << " " << result << " = " << getValue(inst.operands[0]) << " + " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::Sub: { + ensureValueTable(value_counter); + os << " " << result << " = " << getValue(inst.operands[0]) << " - " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::Mul: { + ensureValueTable(value_counter); + os << " " << result << " = " << getValue(inst.operands[0]) << " * " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::UDiv: { + ensureValueTable(value_counter); + os << " " << result << " = (" << getValue(inst.operands[0]) << ") / (" << getValue(inst.operands[1]) << ");\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::SDiv: { + ensureValueTable(value_counter); + os << " " << result << " = (" << getValue(inst.operands[0]) << ") / (" << getValue(inst.operands[1]) << ");\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::FAdd: { + ensureValueTable(value_counter); + os << " " << result << " = " << 
getValue(inst.operands[0]) << " + " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::FSub: { + ensureValueTable(value_counter); + os << " " << result << " = " << getValue(inst.operands[0]) << " - " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::FMul: { + ensureValueTable(value_counter); + os << " " << result << " = " << getValue(inst.operands[0]) << " * " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::FDiv: { + ensureValueTable(value_counter); + os << " " << result << " = " << getValue(inst.operands[0]) << " / " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::FRem: { + ensureValueTable(value_counter); + os << " " << result << " = fmod(" << getValue(inst.operands[0]) << ", " << getValue(inst.operands[1]) << ");\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::And: { + ensureValueTable(value_counter); + os << " " << result << " = " << getValue(inst.operands[0]) << " & " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::Or: { + ensureValueTable(value_counter); + os << " " << result << " = " << getValue(inst.operands[0]) << " | " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::Xor: { + ensureValueTable(value_counter); + os << " " << result << " = " << getValue(inst.operands[0]) << " ^ " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::Shl: { + ensureValueTable(value_counter); + 
os << " " << result << " = " << getValue(inst.operands[0]) << " << " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::LShr: { + ensureValueTable(value_counter); + os << " " << result << " = (uint)(" << getValue(inst.operands[0]) << ") >> " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::AShr: { + ensureValueTable(value_counter); + os << " " << result << " = (int)(" << getValue(inst.operands[0]) << ") >> " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::BitCast: { + ensureValueTable(value_counter); + if (inst.operands.size() >= 1) { + os << " " << result << " = reinterpret_cast(" << getValue(inst.operands[0]) << ");\n"; + } + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::ZExt: { + ensureValueTable(value_counter); + if (inst.operands.size() >= 1) { + os << " " << result << " = (decltype(" << result << "))(" << getValue(inst.operands[0]) << ");\n"; + } + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::SExt: { + ensureValueTable(value_counter); + if (inst.operands.size() >= 1) { + os << " " << result << " = (decltype(" << result << "))(" << getValue(inst.operands[0]) << ");\n"; + } + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::Trunc: { + ensureValueTable(value_counter); + if (inst.operands.size() >= 1) { + os << " " << result << " = (decltype(" << result << "))(" << getValue(inst.operands[0]) << ");\n"; + } + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::FPToUI: { + ensureValueTable(value_counter); + os << " " << result << " = static_cast(" << getValue(inst.operands[0]) << ");\n"; + 
ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::FPToSI: { + ensureValueTable(value_counter); + os << " " << result << " = static_cast(" << getValue(inst.operands[0]) << ");\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::UIToFP: { + ensureValueTable(value_counter); + os << " " << result << " = static_cast(" << getValue(inst.operands[0]) << ");\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::SIToFP: { + ensureValueTable(value_counter); + os << " " << result << " = static_cast(" << getValue(inst.operands[0]) << ");\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::FPTrunc: { + ensureValueTable(value_counter); + os << " " << result << " = static_cast(" << getValue(inst.operands[0]) << ");\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::FPExt: { + ensureValueTable(value_counter); + os << " " << result << " = static_cast(" << getValue(inst.operands[0]) << ");\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::PtrToInt: { + ensureValueTable(value_counter); + os << " " << result << " = reinterpret_cast(" << getValue(inst.operands[0]) << ");\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::IntToPtr: { + ensureValueTable(value_counter); + os << " " << result << " = reinterpret_cast(static_cast(" << getValue(inst.operands[0]) << "));\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::ICmp: { + ensureValueTable(value_counter); + if (inst.operands.size() >= 3) { + auto pred = inst.operands[0]; + auto lhs = getValue(inst.operands[1]); + auto rhs = getValue(inst.operands[2]); + std::string op; + switch (pred) { + case 32: op = "=="; break; + case 
33: op = "!="; break; + case 34: op = ">"; break; + case 35: op = ">="; break; + case 36: op = "<"; break; + case 37: op = "<="; break; + /* 38-41 are the signed predicates (sgt, sge, slt, sle); signedness comes from the operand types */ + case 38: op = ">"; break; + case 39: op = ">="; break; + case 40: op = "<"; break; + case 41: op = "<="; break; + default: op = "=="; break; + } + os << " bool " << result << " = " << lhs << " " << op << " " << rhs << ";\n"; + ctx.value_table[value_counter] = result; + } + value_counter++; + break; + } + + case LLVMInstruction::FCmp: { + ensureValueTable(value_counter); + if (inst.operands.size() >= 3) { + auto pred = inst.operands[0]; + auto lhs = getValue(inst.operands[1]); + auto rhs = getValue(inst.operands[2]); + std::string op; + /* LLVM fcmp predicates: 0 false, 1-6 ordered (oeq, ogt, oge, olt, ole, one), 7 ord, 8 uno, 9-14 unordered (ueq, ugt, uge, ult, ule, une), 15 true */ + switch (pred) { + case 1: case 9: op = "=="; break; + case 2: case 10: op = ">"; break; + case 3: case 11: op = ">="; break; + case 4: case 12: op = "<"; break; + case 5: case 13: op = "<="; break; + case 6: case 14: op = "!="; break; + default: break; + } + std::string unord = "isunordered(" + lhs + ", " + rhs + ")"; + std::string cmp = "(" + lhs + " " + op + " " + rhs + ")"; + std::string expr = "false"; + if (pred == 15) expr = "true"; + else if (pred == 7) expr = "!" + unord; + else if (pred == 8) expr = unord; + else if (pred >= 1 && pred <= 6) expr = "(!" + unord + " && " + cmp + ")"; + else if (pred >= 9 && pred <= 14) expr = "(" + unord + " || " + cmp + ")"; + os << " bool " << result << " = " << expr << ";\n"; + ctx.value_table[value_counter] = result; + } + value_counter++; + break; + } + + case LLVMInstruction::Select: { + ensureValueTable(value_counter); + if (inst.operands.size() >= 3) { + os << " " << result << " = " << getValue(inst.operands[0]) << " ? 
" << getValue(inst.operands[1]) << " : " << getValue(inst.operands[2]) << ";\n"; + ctx.value_table[value_counter] = result; + } + value_counter++; + break; + } + + case LLVMInstruction::Load: { + ensureValueTable(value_counter); + if (inst.operands.size() >= 1) { + os << " " << result << " = reinterpret_cast(" << getValue(inst.operands[0]) << ");\n"; + } + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::Store: { + if (inst.operands.size() >= 2) { + os << " reinterpret_cast(" << getValue(inst.operands[0]) << ") = " << getValue(inst.operands[1]) << ";\n"; + } + break; + } + + case LLVMInstruction::GEP: + case LLVMInstruction::GetElementPtr: { + ensureValueTable(value_counter); + if (inst.operands.size() >= 2) { + auto base = getValue(inst.operands[0]); + std::string offset = "0"; + if (inst.operands.size() >= 2) + offset = getValue(inst.operands[1]); + for (size_t i = 2; i < inst.operands.size(); i++) { + offset = "(" + offset + " + " + getValue(inst.operands[i]) + ")"; + } + os << " device char* " << result << " = (" << base << " + (" << offset << "));\n"; + ctx.value_table[value_counter] = result; + } + value_counter++; + break; + } + + case LLVMInstruction::Alloca: { + ensureValueTable(value_counter); + os << " thread char* " << result << " = (thread char*)alloca(256);\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::PHI: { + ensureValueTable(value_counter); + os << " auto " << result << " = decltype(" << result << ")(0); // phi\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::Br: { + if (inst.operands.size() == 1) { + // unconditional branch + } else if (inst.operands.size() >= 3) { + auto cond = getValue(inst.operands[0]); + os << " if (" << cond << ") {\n // br true\n } else {\n // br false\n }\n"; + } + break; + } + + case LLVMInstruction::Switch: { + os << " // switch\n"; + break; + } + + case 
LLVMInstruction::ExtractValue: { + ensureValueTable(value_counter); + if (inst.operands.size() >= 2) { + auto agg = getValue(inst.operands[0]); + auto idx = inst.operands[1]; + os << " " << result << " = (" << agg << "); // extractvalue idx=" << idx << "\n"; + ctx.value_table[value_counter] = result; + } + value_counter++; + break; + } + + case LLVMInstruction::InsertValue: { + ensureValueTable(value_counter); + os << " " << result << " = " << (inst.operands.size() >= 1 ? getValue(inst.operands[0]) : "0") << "; // insertvalue\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::ExtractElement: { + ensureValueTable(value_counter); + if (inst.operands.size() >= 2) { + os << " " << result << " = " << getValue(inst.operands[0]) << "[" << getValue(inst.operands[1]) << "];\n"; + ctx.value_table[value_counter] = result; + } + value_counter++; + break; + } + + case LLVMInstruction::InsertElement: { + ensureValueTable(value_counter); + os << " " << result << " = " << (inst.operands.size() >= 1 ? getValue(inst.operands[0]) : "float4(0)") << "; // insertelement\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::ShuffleVector: { + ensureValueTable(value_counter); + os << " " << result << " = " << (inst.operands.size() >= 1 ? 
getValue(inst.operands[0]) : "float4(0)") << "; // shufflevector\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::Unreachable: + os << " // unreachable\n"; + break; + + case LLVMInstruction::FNeg: { + ensureValueTable(value_counter); + if (inst.operands.size() >= 1) { + os << " " << result << " = -(" << getValue(inst.operands[0]) << ");\n"; + ctx.value_table[value_counter] = result; + } + value_counter++; + break; + } + + case LLVMInstruction::URem: + case LLVMInstruction::SRem: { + ensureValueTable(value_counter); + os << " " << result << " = " << getValue(inst.operands[0]) << " % " << getValue(inst.operands[1]) << ";\n"; + ctx.value_table[value_counter] = result; + value_counter++; + break; + } + + case LLVMInstruction::Invoke: { + os << " // invoke\n"; + break; + } + + default: + os << " // unhandled opcode " << (int)inst.opcode << "\n"; + ensureValueTable(value_counter); + ctx.value_table[value_counter] = result; + value_counter++; + break; + } +} + +std::optional DXILToMSL::convert(const LLVMModule &module, + const DxilParsedShader &shader) { + DXTRACE("DXILToMSL::convert: kind=%u sm=%u.%u functions=%zu types=%zu", + (uint32_t)shader.kind, shader.shader_model.major, shader.shader_model.minor, + module.functions.size(), module.types.size()); + + std::ostringstream os; + EmitContext ctx{os, module, shader, {}, {}, 0, false, false, false, false}; + + emitFunctionPrologue(ctx); + + ctx.value_table.resize(256); + + if (!module.functions.empty()) { + for (size_t i = 0; i < module.constants.size(); i++) { + uint32_t val_idx = (uint32_t)i; + if (val_idx < ctx.value_table.size()) { + ctx.value_table[val_idx] = "const_" + std::to_string(i); + } + } + + auto &fn = module.functions.back(); + DXTRACE("DXILToMSL: entry function has %zu blocks", fn.blocks.size()); + + uint32_t value_counter = (uint32_t)module.constants.size(); + + for (auto &block : fn.blocks) { + for (auto &inst : block.instructions) { + 
emitInstruction(ctx, inst, value_counter); + } + } + } else { + os << " // No functions parsed from DXIL bitcode\n"; + DXTRACE("DXILToMSL: no functions in module"); + } + + os << "}\n"; + + MSLShader result; + result.source = os.str(); + result.entry_point = shader.entry_point; + result.tg_size[0] = 1; + result.tg_size[1] = 1; + result.tg_size[2] = 1; + + DXTRACE("DXILToMSL: generated %zu bytes of MSL", result.source.size()); + + return result; +} + +} diff --git a/src/airconv/dxil/dxil_to_msl.hpp b/src/airconv/dxil/dxil_to_msl.hpp new file mode 100644 index 000000000..e66046418 --- /dev/null +++ b/src/airconv/dxil/dxil_to_msl.hpp @@ -0,0 +1,61 @@ +#pragma once + +#include "llvm_bitcode.hpp" +#include "dxil_container.hpp" +#include +#include +#include +#include + +namespace dxmt::dxil { + +struct MSLShader { + std::string source; + std::string entry_point; + uint32_t tg_size[3] = {1, 1, 1}; + uint32_t num_uavs = 0; + uint32_t num_srvs = 0; + uint32_t num_cbuffers = 0; + uint32_t num_samplers = 0; +}; + +struct ResourceBinding { + uint32_t register_space; + uint32_t register_index; + uint32_t count; + enum class Kind { SRV, UAV, CBuffer, Sampler } kind; + std::string name; +}; + +class DXILToMSL { +public: + static std::optional convert(const LLVMModule &module, + const DxilParsedShader &shader); + +private: + struct EmitContext { + std::ostringstream &os; + const LLVMModule &mod; + const DxilParsedShader &shader; + std::vector value_table; + std::vector resource_bindings; + uint32_t next_binding = 0; + bool uses_thread_id = false; + bool uses_group_id = false; + bool uses_group_thread_id = false; + bool uses_group_size = false; + }; + + static std::string getTypeName(const LLVMType &t, const LLVMModule &mod); + static std::string getVectorTypeName(const LLVMType &elem_type, uint32_t count, const LLVMModule &mod); + static uint32_t getTypeSize(const LLVMType &t, const LLVMModule &mod); + static std::string emitValue(uint32_t idx); + static std::string 
emitConstant(const std::vector &ops, uint32_t type_id, const LLVMModule &mod); + static void emitFunctionPrologue(EmitContext &ctx); + static void emitBindings(EmitContext &ctx); + static void emitInstruction(EmitContext &ctx, const LLVMInstruction &inst, uint32_t &value_counter); + static std::string translateDXIntrinsic(EmitContext &ctx, uint32_t intrinsic_id, + const std::vector &args); +}; + +} diff --git a/src/airconv/dxil/llvm_bitcode.cpp b/src/airconv/dxil/llvm_bitcode.cpp new file mode 100644 index 000000000..e380de9be --- /dev/null +++ b/src/airconv/dxil/llvm_bitcode.cpp @@ -0,0 +1,534 @@ +#include "llvm_bitcode.hpp" +#include +#include +#include + +#define DXTRACE(fmt, ...) do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxil_trace.log", "a"); if (_tf) { fprintf(_tf, fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) + +namespace dxmt::dxil { + +class BitstreamReader { +public: + BitstreamReader(const uint8_t *data, uint32_t size) + : m_data(data), m_size(size), m_offset(0), m_cur_byte(0), m_bits_left(0) {} + + uint32_t read(uint32_t num_bits) { + uint32_t result = 0; + uint32_t bits_read = 0; + while (bits_read < num_bits) { + if (m_bits_left == 0) { + if (m_offset >= m_size) return 0; + m_cur_byte = m_data[m_offset++]; + m_bits_left = 8; + } + uint32_t to_read = std::min(num_bits - bits_read, m_bits_left); + result |= (uint32_t)(m_cur_byte & ((1 << to_read) - 1)) << bits_read; + m_cur_byte >>= to_read; + m_bits_left -= to_read; + bits_read += to_read; + } + return result; + } + + uint64_t read64(uint32_t num_bits) { + if (num_bits <= 32) return read(num_bits); + uint64_t lo = read(32); + uint64_t hi = read(num_bits - 32); + return lo | (hi << 32); + } + + uint32_t readVBR(uint32_t width) { + uint32_t result = 0; + uint32_t shift = 0; + uint32_t chunk; + do { + chunk = read(width); + result |= (chunk & ((1u << (width - 1)) - 1)) << shift; + shift += width - 1; + } while (chunk & (1u << (width - 1))); + return result; + } + + uint64_t readVBR64(uint32_t width) { + 
uint64_t result = 0; + uint64_t shift = 0; + uint32_t chunk; + do { + chunk = read(width); + result |= (uint64_t)(chunk & ((1u << (width - 1)) - 1)) << shift; + shift += width - 1; + } while (chunk & (1u << (width - 1))); + return result; + } + + void align32() { + while (m_offset % 4 != 0 && m_offset < m_size) + m_offset++; + m_bits_left = 0; + } + + uint32_t tell() const { return m_offset * 8 - m_bits_left; } + void seek(uint32_t bit_pos) { + m_offset = bit_pos / 8; + m_bits_left = 0; + uint32_t skip = bit_pos % 8; + if (skip) read(skip); + } + + bool atEnd() const { return m_offset >= m_size && m_bits_left == 0; } + +private: + const uint8_t *m_data; + uint32_t m_size; + uint32_t m_offset; + uint8_t m_cur_byte; + uint32_t m_bits_left; +}; + +static constexpr uint32_t kEnterSubBlock = 1; +static constexpr uint32_t kEndBlock = 0; +static constexpr uint32_t kDefineAbbrev = 2; +static constexpr uint32_t kUnabbrevRecord = 3; + +static constexpr uint32_t kBlockID_Module = 8; +static constexpr uint32_t kBlockID_BlockInfo = 0; +static constexpr uint32_t kBlockID_ValueSymTab = 14; +static constexpr uint32_t kBlockID_Function = 12; +static constexpr uint32_t kBlockID_Type = 17; +static constexpr uint32_t kBlockID_Constants = 11; + +static constexpr uint32_t kTypeCode_Void = 2; +static constexpr uint32_t kTypeCode_Float = 3; +static constexpr uint32_t kTypeCode_Double = 4; +static constexpr uint32_t kTypeCode_Integer = 7; +static constexpr uint32_t kTypeCode_Pointer = 8; +static constexpr uint32_t kTypeCode_Struct = 10; +static constexpr uint32_t kTypeCode_Array = 11; +static constexpr uint32_t kTypeCode_Vector = 12; +static constexpr uint32_t kTypeCode_Function = 21; + +static constexpr uint32_t kModuleCode_Function = 8; +static constexpr uint32_t kModuleCode_GlobalVar = 7; +static constexpr uint32_t kModuleCode_VSTOffset = 13; + +static constexpr uint32_t kFuncCode_DeclareBlocks = 1; +static constexpr uint32_t kFuncCode_InstRet = 10; +static constexpr uint32_t
kFuncCode_InstBr = 11; +static constexpr uint32_t kFuncCode_InstCall = 34; +static constexpr uint32_t kFuncCode_InstPHI = 4; +static constexpr uint32_t kFuncCode_InstBinop = 2; +static constexpr uint32_t kFuncCode_InstCast = 3; +static constexpr uint32_t kFuncCode_InstGEP = 26; +static constexpr uint32_t kFuncCode_InstLoad = 20; +static constexpr uint32_t kFuncCode_InstStore = 44; +static constexpr uint32_t kFuncCode_InstExtractVal = 55; +static constexpr uint32_t kFuncCode_InstInsertVal = 56; +static constexpr uint32_t kFuncCode_InstSelect = 17; +static constexpr uint32_t kFuncCode_InstICmp = 45; +static constexpr uint32_t kFuncCode_InstFCmp = 46; +static constexpr uint32_t kFuncCode_InstUnreachable = 15; +static constexpr uint32_t kFuncCode_InstAlloca = 19; +static constexpr uint32_t kFuncCode_InstExtractElt = 57; +static constexpr uint32_t kFuncCode_InstInsertElt = 58; +static constexpr uint32_t kFuncCode_InstShuffleVec = 59; +static constexpr uint32_t kFuncCode_InstSwitch = 12; +static constexpr uint32_t kFuncCode_InstInvoke = 13; + +static constexpr uint32_t kConstantsCode_SetType = 1; +static constexpr uint32_t kConstantsCode_Null = 2; +static constexpr uint32_t kConstantsCode_Undefined = 3; +static constexpr uint32_t kConstantsCode_Integer = 4; +static constexpr uint32_t kConstantsCode_Float = 6; +static constexpr uint32_t kConstantsCode_Aggregate = 7; +static constexpr uint32_t kConstantsCode_String = 8; +static constexpr uint32_t kConstantsCode_Cast = 11; +static constexpr uint32_t kConstantsCode_GEP = 12; +static constexpr uint32_t kConstantsCode_Data = 15; + +struct Abbrev { + std::vector> ops; +}; + +struct BlockInfo { + uint32_t block_id = 0; + std::vector abbrevs; +}; + +struct ParseContext { + BitstreamReader &reader; + LLVMModule &module; + std::vector cur_abbrevs; + std::vector block_infos; +}; + +static std::optional readBlockHeader(BitstreamReader &r) { + r.align32(); + uint32_t block_id = r.readVBR(8); + uint32_t new_abbrev_len = r.readVBR(4); + 
uint32_t block_len = r.read(32); + (void)block_len; + return new_abbrev_len; +} + +static std::vector readUnabbrevRecord(BitstreamReader &r) { + uint32_t code = r.readVBR(6); + uint32_t num_ops = r.readVBR(6); + std::vector ops; + ops.push_back(code); + for (uint32_t i = 0; i < num_ops; i++) { + ops.push_back(r.readVBR64(6)); + } + return ops; +} + +static bool parseTypeBlock(ParseContext &ctx) { + auto abbrev_len = readBlockHeader(ctx.reader); + if (!abbrev_len) return false; + + while (!ctx.reader.atEnd()) { + uint32_t code = ctx.reader.read(*abbrev_len); + if (code == kEndBlock) { + ctx.reader.align32(); + return true; + } + if (code == kEnterSubBlock || code == kDefineAbbrev) + continue; + + std::vector ops; + if (code == kUnabbrevRecord) { + ops = readUnabbrevRecord(ctx.reader); + } else { + continue; + } + + uint32_t rec_code = (uint32_t)ops[0]; + LLVMType t; + t.kind = LLVMType::Void; + + switch (rec_code) { + case kTypeCode_Void: + t.kind = LLVMType::Void; + ctx.module.types.push_back(t); + break; + case kTypeCode_Float: + t.kind = LLVMType::Float; + t.bit_width = 32; + ctx.module.types.push_back(t); + break; + case kTypeCode_Double: + t.kind = LLVMType::Double; + t.bit_width = 64; + ctx.module.types.push_back(t); + break; + case kTypeCode_Integer: { + t.kind = LLVMType::Integer; + t.bit_width = ops.size() > 1 ? (uint32_t)ops[1] : 32; + ctx.module.types.push_back(t); + break; + } + case kTypeCode_Pointer: { + t.kind = LLVMType::Pointer; + if (ops.size() > 1) + t.subtypes.push_back({LLVMType::Void, 0, {}}); + ctx.module.types.push_back(t); + break; + } + case kTypeCode_Struct: { + t.kind = LLVMType::Struct; + for (size_t i = 1; i < ops.size(); i++) + t.subtypes.push_back({LLVMType::Void, 0, {}}); + ctx.module.types.push_back(t); + break; + } + case kTypeCode_Array: { + t.kind = LLVMType::Array; + t.bit_width = ops.size() > 1 ? 
(uint32_t)ops[1] : 0; + ctx.module.types.push_back(t); + break; + } + case kTypeCode_Vector: { + t.kind = LLVMType::Vector; + t.bit_width = ops.size() > 1 ? (uint32_t)ops[1] : 0; + ctx.module.types.push_back(t); + break; + } + case kTypeCode_Function: { + t.kind = LLVMType::Function; + if (ops.size() > 1) + t.subtypes.push_back({LLVMType::Void, 0, {}}); + ctx.module.types.push_back(t); + break; + } + default: + break; + } + } + return false; +} + +static bool parseConstantsBlock(ParseContext &ctx) { + auto abbrev_len = readBlockHeader(ctx.reader); + if (!abbrev_len) return false; + + uint32_t cur_type = 0; + while (!ctx.reader.atEnd()) { + uint32_t code = ctx.reader.read(*abbrev_len); + if (code == kEndBlock) { + ctx.reader.align32(); + return true; + } + if (code == kEnterSubBlock || code == kDefineAbbrev) + continue; + + std::vector ops; + if (code == kUnabbrevRecord) { + ops = readUnabbrevRecord(ctx.reader); + } else { + continue; + } + + uint32_t rec_code = (uint32_t)ops[0]; + switch (rec_code) { + case kConstantsCode_SetType: + if (ops.size() > 1) cur_type = (uint32_t)ops[1]; + break; + case kConstantsCode_Integer: + case kConstantsCode_Float: + case kConstantsCode_Null: + case kConstantsCode_Undefined: { + LLVMValue v; + v.kind = LLVMValue::Constant; + v.type_id = cur_type; + v.id = (uint32_t)ctx.module.constants.size(); + ctx.module.constants.push_back(v); + break; + } + default: + break; + } + } + return false; +} + +static bool parseFunctionBlock(ParseContext &ctx, LLVMFunction &fn) { + auto abbrev_len = readBlockHeader(ctx.reader); + if (!abbrev_len) return false; + + uint32_t cur_block = 0; + uint32_t value_id = 0; + + while (!ctx.reader.atEnd()) { + uint32_t code = ctx.reader.read(*abbrev_len); + if (code == kEndBlock) { + ctx.reader.align32(); + return true; + } + if (code == kEnterSubBlock) { + continue; + } + if (code == kDefineAbbrev) + continue; + + std::vector ops; + if (code == kUnabbrevRecord) { + ops = readUnabbrevRecord(ctx.reader); + } else { 
+ continue; + } + + uint32_t rec_code = (uint32_t)ops[0]; + + switch (rec_code) { + case kFuncCode_DeclareBlocks: + fn.blocks.resize(ops.size() > 1 ? (size_t)ops[1] : 0); + cur_block = 0; + break; + case kFuncCode_InstRet: + if (cur_block < fn.blocks.size()) { + LLVMInstruction inst; + inst.opcode = LLVMInstruction::Ret; + fn.blocks[cur_block].instructions.push_back(inst); + } + break; + case kFuncCode_InstCall: { + if (cur_block < fn.blocks.size()) { + LLVMInstruction inst; + inst.opcode = LLVMInstruction::Call; + if (ops.size() > 1) inst.type_id = (uint32_t)ops[1]; + for (size_t i = 2; i < ops.size(); i++) + inst.operands.push_back((uint32_t)ops[i]); + fn.blocks[cur_block].instructions.push_back(inst); + } + break; + } + case kFuncCode_InstBinop: { + if (cur_block < fn.blocks.size()) { + LLVMInstruction inst; + inst.opcode = LLVMInstruction::Add; + if (ops.size() > 1) inst.type_id = (uint32_t)ops[1]; + for (size_t i = 2; i < ops.size(); i++) + inst.operands.push_back((uint32_t)ops[i]); + fn.blocks[cur_block].instructions.push_back(inst); + } + break; + } + case kFuncCode_InstCast: { + if (cur_block < fn.blocks.size()) { + LLVMInstruction inst; + inst.opcode = LLVMInstruction::BitCast; + if (ops.size() > 1) inst.type_id = (uint32_t)ops[1]; + for (size_t i = 2; i < ops.size(); i++) + inst.operands.push_back((uint32_t)ops[i]); + fn.blocks[cur_block].instructions.push_back(inst); + } + break; + } + case kFuncCode_InstGEP: { + if (cur_block < fn.blocks.size()) { + LLVMInstruction inst; + inst.opcode = LLVMInstruction::GetElementPtr; + for (size_t i = 1; i < ops.size(); i++) + inst.operands.push_back((uint32_t)ops[i]); + fn.blocks[cur_block].instructions.push_back(inst); + } + break; + } + case kFuncCode_InstLoad: { + if (cur_block < fn.blocks.size()) { + LLVMInstruction inst; + inst.opcode = LLVMInstruction::Load; + if (ops.size() > 1) inst.type_id = (uint32_t)ops[1]; + for (size_t i = 2; i < ops.size(); i++) + inst.operands.push_back((uint32_t)ops[i]); + 
fn.blocks[cur_block].instructions.push_back(inst); + } + break; + } + case kFuncCode_InstStore: { + if (cur_block < fn.blocks.size()) { + LLVMInstruction inst; + inst.opcode = LLVMInstruction::Store; + for (size_t i = 1; i < ops.size(); i++) + inst.operands.push_back((uint32_t)ops[i]); + fn.blocks[cur_block].instructions.push_back(inst); + } + break; + } + case kFuncCode_InstExtractVal: + case kFuncCode_InstInsertVal: + case kFuncCode_InstSelect: + case kFuncCode_InstICmp: + case kFuncCode_InstFCmp: + case kFuncCode_InstUnreachable: + case kFuncCode_InstAlloca: + case kFuncCode_InstExtractElt: + case kFuncCode_InstInsertElt: + case kFuncCode_InstShuffleVec: + case kFuncCode_InstSwitch: + case kFuncCode_InstInvoke: + case kFuncCode_InstPHI: + case kFuncCode_InstBr: + default: + break; + } + } + return false; +} + +std::optional BitcodeReader::parse(const uint8_t *data, uint32_t size) { + LLVMModule module; + + DXTRACE("BitcodeReader::parse size=%u", size); + if (size >= 8) { + DXTRACE(" bytes: %02x %02x %02x %02x %02x %02x %02x %02x", + data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7]); + } + + BitstreamReader reader(data, size); + + uint32_t magic = reader.read(32); + if (magic != 0xDEC04342) + return std::nullopt; + + // DXIL bitcode: after magic, the bitstream starts directly + // No wrapper header — seek past the 4-byte magic + reader.seek(32); + + ParseContext ctx{reader, module, {}, {}}; + + uint32_t bc_abbrev = reader.read(2); + DXTRACE(" bc_abbrev=%u at bit %u", bc_abbrev, reader.tell()); + if (bc_abbrev != kEnterSubBlock) + return std::nullopt; + + auto abbrev_len = readBlockHeader(reader); + DXTRACE(" abbrev_len=%u", abbrev_len.value_or(0)); + if (!abbrev_len) return std::nullopt; + + std::vector pending_fn_types; + + while (!reader.atEnd()) { + uint32_t code = reader.read(*abbrev_len); + if (code == kEndBlock) { + reader.align32(); + break; + } + if (code == kDefineAbbrev) + continue; + + if (code == kEnterSubBlock) { + uint32_t 
block_id = reader.readVBR(8); + uint32_t new_abbrev_len = reader.readVBR(4); + uint32_t block_len = reader.read(32); + + switch (block_id) { + case kBlockID_Type: { + ParseContext type_ctx{reader, module, {}, {}}; + parseTypeBlock(type_ctx); + break; + } + case kBlockID_Constants: { + ParseContext const_ctx{reader, module, {}, {}}; + parseConstantsBlock(const_ctx); + break; + } + case kBlockID_Function: { + if (!pending_fn_types.empty()) { + uint32_t fn_type = pending_fn_types.back(); + pending_fn_types.pop_back(); + LLVMFunction fn; + fn.type_id = fn_type; + fn.is_declaration = false; + ParseContext func_ctx{reader, module, {}, {}}; + parseFunctionBlock(func_ctx, fn); + module.functions.push_back(fn); + } else { + reader.align32(); + reader.seek(reader.tell() + block_len * 8); + } + break; + } + case kBlockID_ValueSymTab: + case kBlockID_BlockInfo: + default: + reader.align32(); + reader.seek(reader.tell() + block_len * 8); + break; + } + continue; + } + + if (code == kUnabbrevRecord) { + auto ops = readUnabbrevRecord(reader); + uint32_t rec_code = (uint32_t)ops[0]; + + if (rec_code == kModuleCode_Function) { + pending_fn_types.push_back(ops.size() > 1 ? 
(uint32_t)ops[1] : 0); + } + } + } + + return module; +} + +} diff --git a/src/airconv/dxil/llvm_bitcode.hpp b/src/airconv/dxil/llvm_bitcode.hpp new file mode 100644 index 000000000..80a46f17a --- /dev/null +++ b/src/airconv/dxil/llvm_bitcode.hpp @@ -0,0 +1,135 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace dxmt::dxil { + +struct LLVMType { + enum Kind { + Void, + Float, + Double, + Integer, + Pointer, + Struct, + Array, + Vector, + Function, + } kind; + uint32_t bit_width = 0; + std::vector subtypes; +}; + +struct LLVMValue { + enum Kind { + Undef, + Constant, + Instruction, + Argument, + BasicBlock, + Function, + } kind; + uint32_t type_id = 0; + uint32_t id = 0; + std::string name; + std::string constant_data; +}; + +struct LLVMInstruction { + enum Opcode { + Ret = 1, + Br = 2, + Switch = 3, + Invoke = 4, + Unreachable = 8, + Add = 9, + Sub = 11, + Mul = 13, + UDiv = 15, + SDiv = 17, + URem = 19, + SRem = 21, + And = 23, + Or = 24, + Xor = 25, + Shl = 26, + LShr = 27, + AShr = 28, + FAdd = 29, + FSub = 30, + FMul = 31, + FDiv = 32, + FRem = 33, + FNeg = 34, + ExtractValue = 42, + InsertValue = 43, + ExtractElement = 44, + InsertElement = 45, + ShuffleVector = 46, + BitCast = 53, + ZExt = 55, + SExt = 56, + Trunc = 57, + FPToUI = 58, + FPToSI = 59, + UIToFP = 60, + SIToFP = 61, + FPTrunc = 62, + FPExt = 63, + PtrToInt = 64, + IntToPtr = 65, + ICmp = 68, + FCmp = 69, + PHI = 71, + Call = 72, + Select = 73, + GEP = 76, + Load = 81, + Store = 82, + Alloca = 83, + GetElementPtr = 84, + } opcode; + + uint32_t type_id = 0; + uint32_t result_id = 0; + std::vector operands; +}; + +struct LLVMBasicBlock { + std::string name; + std::vector instructions; +}; + +struct LLVMFunction { + std::string name; + uint32_t type_id = 0; + uint32_t calling_conv = 0; + bool is_declaration = true; + std::vector param_types; + LLVMType return_type; + std::vector blocks; + std::vector attributes; +}; + +struct LLVMModule { + std::vector types; + 
std::vector constants; + std::vector functions; + std::unordered_map function_map; + std::string source_filename; + std::string target_triple; +}; + +class BitcodeReader { +public: + static std::optional parse(const uint8_t *data, uint32_t size); + +private: + BitcodeReader() = default; +}; + +} diff --git a/src/d3d12/d3d12.cpp b/src/d3d12/d3d12.cpp new file mode 100644 index 000000000..a40944e9c --- /dev/null +++ b/src/d3d12/d3d12.cpp @@ -0,0 +1,352 @@ +#define INITGUID +#include "d3d12_dxgi_device.hpp" +#include "d3d12_device.hpp" +#include "com/com_pointer.hpp" +#include "dxgi_interfaces.h" +#include "dxmt_device.hpp" +#include "log/log.hpp" +#include "util_string.hpp" +#include +#include +#include +#include + +#pragma pack(push, 1) +struct _RSHeader { + uint32_t num_parameters; + uint32_t num_static_samplers; + uint32_t flags; +}; +struct _RSParameter { + uint8_t type; + uint8_t visibility; + union { + struct { uint32_t register_space; uint32_t register_index; uint32_t num_32bit_values; } constants; + struct { uint32_t register_space; uint32_t register_index; } descriptor; + struct { uint32_t num_ranges; } table; + }; +}; +struct _RSDescriptorRange { + uint8_t range_type; + uint32_t num_descriptors; + uint32_t base_register; + uint32_t register_space; + uint32_t offset_in_table; +}; +struct _RSStaticSampler { + uint32_t filter; + uint32_t address_u; + uint32_t address_v; + uint32_t address_w; + float mip_lod_bias; + uint32_t max_anisotropy; + uint32_t comparison_func; + uint32_t border_color; + float min_lod; + float max_lod; + uint32_t register_space; + uint32_t register_index; + uint32_t shader_register_space; + uint8_t shader_visibility; +}; +#pragma pack(pop) + +class _RSBlob : public ID3DBlob { + ULONG m_ref = 1; + std::vector m_data; +public: + _RSBlob(std::vector &&data) : m_data(std::move(data)) {} + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void **ppv) { + if (riid == IID_IUnknown || riid == IID_ID3D10Blob || riid == __uuidof(ID3DBlob)) { 
*ppv = this; AddRef(); return S_OK; } + return E_NOINTERFACE; + } + ULONG STDMETHODCALLTYPE AddRef() { return ++m_ref; } + ULONG STDMETHODCALLTYPE Release() { ULONG r = --m_ref; if (!r) delete this; return r; } + LPVOID STDMETHODCALLTYPE GetBufferPointer() { return m_data.data(); } + SIZE_T STDMETHODCALLTYPE GetBufferSize() { return m_data.size(); } +}; + +static HRESULT _SerializeRootSig(const D3D12_ROOT_SIGNATURE_DESC *desc, ID3DBlob **ppBlob) { + if (!desc || !ppBlob) return E_INVALIDARG; + *ppBlob = nullptr; + + std::vector buf; + size_t total = sizeof(_RSHeader) + desc->NumParameters * sizeof(_RSParameter); + for (UINT i = 0; i < desc->NumParameters; i++) { + if (desc->pParameters[i].ParameterType == D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE) + total += desc->pParameters[i].DescriptorTable.NumDescriptorRanges * sizeof(_RSDescriptorRange); + } + total += desc->NumStaticSamplers * sizeof(_RSStaticSampler); + buf.resize(total); + + auto *hdr = reinterpret_cast<_RSHeader *>(buf.data()); + hdr->num_parameters = desc->NumParameters; + hdr->num_static_samplers = desc->NumStaticSamplers; + hdr->flags = desc->Flags; + + uint8_t *ptr = buf.data() + sizeof(_RSHeader); + for (UINT i = 0; i < desc->NumParameters; i++) { + auto &p = desc->pParameters[i]; + auto *out = reinterpret_cast<_RSParameter *>(ptr); + out->type = (uint8_t)p.ParameterType; + out->visibility = (uint8_t)p.ShaderVisibility; + switch (p.ParameterType) { + case D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS: + out->constants.register_space = p.Constants.RegisterSpace; + out->constants.register_index = p.Constants.ShaderRegister; + out->constants.num_32bit_values = p.Constants.Num32BitValues; + break; + case D3D12_ROOT_PARAMETER_TYPE_CBV: + case D3D12_ROOT_PARAMETER_TYPE_SRV: + case D3D12_ROOT_PARAMETER_TYPE_UAV: + out->descriptor.register_space = p.Descriptor.RegisterSpace; + out->descriptor.register_index = p.Descriptor.ShaderRegister; + break; + case D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE: + 
out->table.num_ranges = p.DescriptorTable.NumDescriptorRanges; + ptr += sizeof(_RSParameter); + for (UINT r = 0; r < p.DescriptorTable.NumDescriptorRanges; r++) { + auto &rng = p.DescriptorTable.pDescriptorRanges[r]; + auto *orng = reinterpret_cast<_RSDescriptorRange *>(ptr); + orng->range_type = (uint8_t)rng.RangeType; + orng->num_descriptors = rng.NumDescriptors; + orng->base_register = rng.BaseShaderRegister; + orng->register_space = rng.RegisterSpace; + orng->offset_in_table = rng.OffsetInDescriptorsFromTableStart; + ptr += sizeof(_RSDescriptorRange); + } + continue; + } + ptr += sizeof(_RSParameter); + } + + for (UINT i = 0; i < desc->NumStaticSamplers; i++) { + auto &s = desc->pStaticSamplers[i]; + auto *out = reinterpret_cast<_RSStaticSampler *>(ptr); + out->filter = s.Filter; + out->address_u = s.AddressU; + out->address_v = s.AddressV; + out->address_w = s.AddressW; + out->mip_lod_bias = s.MipLODBias; + out->max_anisotropy = s.MaxAnisotropy; + out->comparison_func = s.ComparisonFunc; + out->border_color = s.BorderColor; + out->min_lod = s.MinLOD; + out->max_lod = s.MaxLOD; + out->register_space = s.RegisterSpace; + out->register_index = s.ShaderRegister; + out->shader_register_space = s.RegisterSpace; + out->shader_visibility = (uint8_t)s.ShaderVisibility; + ptr += sizeof(_RSStaticSampler); + } + + *ppBlob = new _RSBlob(std::move(buf)); + return S_OK; +} + +using namespace dxmt; + +extern "C" HRESULT WINAPI +D3D12CreateDevice(IUnknown *pAdapter, D3D_FEATURE_LEVEL MinimumFeatureLevel, + REFIID riid, void **ppDevice) { + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "=== D3D12CreateDevice CALLED FL=%d adapter=%p riid=%s ===\n", MinimumFeatureLevel, pAdapter, str::format(riid).c_str()); fclose(f); } + } + if (!ppDevice) + return E_POINTER; + *ppDevice = nullptr; + + Com dxgi_adapter; + + if (pAdapter) { + if (FAILED(pAdapter->QueryInterface(IID_PPV_ARGS(&dxgi_adapter)))) { + ERR("D3D12CreateDevice: adapter is not a DXMT 
adapter"); + return E_INVALIDARG; + } + } else { + Com factory; + if (FAILED(CreateDXGIFactory1(IID_PPV_ARGS(&factory)))) { + ERR("D3D12CreateDevice: failed to create DXGI factory"); + return E_FAIL; + } + Com adapter; + if (FAILED(factory->EnumAdapters(0, &adapter))) { + ERR("D3D12CreateDevice: no adapters available"); + return E_FAIL; + } + if (FAILED(adapter->QueryInterface(IID_PPV_ARGS(&dxgi_adapter)))) { + ERR("D3D12CreateDevice: default adapter is not DXMT"); + return E_FAIL; + } + } + + try { + void *device_mem = VirtualAlloc((void*)0x500000000ULL, sizeof(MTLD3D12DXGIDevice), + MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); + if (!device_mem) { + device_mem = VirtualAlloc((void*)0x200000000ULL, sizeof(MTLD3D12DXGIDevice), + MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); + } + if (!device_mem) + device_mem = VirtualAlloc(nullptr, sizeof(MTLD3D12DXGIDevice), + MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "Device allocated at %p size=%zu\n", device_mem, sizeof(MTLD3D12DXGIDevice)); fclose(f); } + } + if (!device_mem) return E_OUTOFMEMORY; // all three VirtualAlloc attempts failed; never placement-new at nullptr + auto dxgi_device = new (device_mem) MTLD3D12DXGIDevice( + CreateDXMTDevice({.device = dxgi_adapter->GetMTLDevice()}), + dxgi_adapter.ptr()); + + HRESULT hr = dxgi_device->QueryInterface(riid, ppDevice); + if (FAILED(hr)) { + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "D3D12CreateDevice QI FAILED hr=0x%lx FL=%d\n", hr, MinimumFeatureLevel); fclose(f); } + } + dxgi_device->Release(); + return hr; + } + + Logger::info(str::format("D3D12CreateDevice: created device with FL ", + MinimumFeatureLevel)); + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "D3D12CreateDevice SUCCESS FL=%d\n", MinimumFeatureLevel); fclose(f); } + } + return S_OK; + } catch (const std::exception &e) { + Logger::err(str::format("D3D12CreateDevice: exception: ", e.what())); + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { 
fprintf(f, "D3D12CreateDevice EXCEPTION: %s FL=%d\n", e.what(), MinimumFeatureLevel); fclose(f); } + } + return E_FAIL; + } +} + +extern "C" HRESULT WINAPI +D3D12SerializeRootSignature(const D3D12_ROOT_SIGNATURE_DESC *pRootSignature, + D3D_ROOT_SIGNATURE_VERSION Version, + ID3DBlob **ppBlob, ID3DBlob **ppErrorBlob) { + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "D3D12SerializeRootSignature version=%u params=%u\n", Version, pRootSignature ? pRootSignature->NumParameters : 0); fclose(f); } + } + return _SerializeRootSig(pRootSignature, ppBlob); +} + +extern "C" HRESULT WINAPI +D3D12SerializeVersionedRootSignature( + const D3D12_VERSIONED_ROOT_SIGNATURE_DESC *pRootSignature, + ID3DBlob **ppBlob, ID3DBlob **ppErrorBlob) { + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "D3D12SerializeVersionedRootSignature version=%u\n", pRootSignature ? pRootSignature->Version : 0); fclose(f); } + } + if (!pRootSignature) return E_INVALIDARG; + if (pRootSignature->Version == D3D_ROOT_SIGNATURE_VERSION_1_0) + return _SerializeRootSig(&pRootSignature->Desc_1_0, ppBlob); + if (pRootSignature->Version == D3D_ROOT_SIGNATURE_VERSION_1_1) { + const auto &d1 = pRootSignature->Desc_1_1; + D3D12_ROOT_SIGNATURE_DESC desc0 = {}; + desc0.NumParameters = d1.NumParameters; + desc0.NumStaticSamplers = d1.NumStaticSamplers; + desc0.pStaticSamplers = d1.pStaticSamplers; + desc0.Flags = d1.Flags; + std::vector params(d1.NumParameters); + std::vector ranges; + for (UINT i = 0; i < d1.NumParameters; i++) { + auto &src = d1.pParameters[i]; + auto &dst = params[i]; + dst.ParameterType = src.ParameterType; + dst.ShaderVisibility = src.ShaderVisibility; + switch (src.ParameterType) { + case D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS: + dst.Constants = src.Constants; break; + case D3D12_ROOT_PARAMETER_TYPE_CBV: + case D3D12_ROOT_PARAMETER_TYPE_SRV: + case D3D12_ROOT_PARAMETER_TYPE_UAV: + dst.Descriptor.ShaderRegister = 
src.Descriptor.ShaderRegister; + dst.Descriptor.RegisterSpace = src.Descriptor.RegisterSpace; break; + case D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE: { + dst.DescriptorTable.NumDescriptorRanges = src.DescriptorTable.NumDescriptorRanges; + for (UINT r = 0; r < src.DescriptorTable.NumDescriptorRanges; r++) { + auto &rs = src.DescriptorTable.pDescriptorRanges[r]; + D3D12_DESCRIPTOR_RANGE dr = {}; + dr.RangeType = rs.RangeType; + dr.NumDescriptors = rs.NumDescriptors; + dr.BaseShaderRegister = rs.BaseShaderRegister; + dr.RegisterSpace = rs.RegisterSpace; + dr.OffsetInDescriptorsFromTableStart = rs.OffsetInDescriptorsFromTableStart; + ranges.push_back(dr); + } + break; + } + } + } + size_t base = 0; // Resolve range pointers only now: push_back above may have reallocated `ranges`. + for (auto &p : params) { if (p.ParameterType != D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE) continue; p.DescriptorTable.pDescriptorRanges = ranges.data() + base; base += p.DescriptorTable.NumDescriptorRanges; } + desc0.pParameters = params.data(); + return _SerializeRootSig(&desc0, ppBlob); + } + return E_INVALIDARG; +} + +extern "C" HRESULT WINAPI +D3D12CreateRootSignatureDeserializer(const void *pData, SIZE_T NumBytes, + REFIID riid, void **ppDeserializer) { + Logger::info("D3D12CreateRootSignatureDeserializer: stub"); + return E_NOTIMPL; +} + +extern "C" HRESULT WINAPI D3D12GetDebugInterface(REFIID riid, + void **ppDebug) { + return E_NOINTERFACE; +} + +extern "C" UINT WINAPI D3D12SDKVersion() { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "D3D12SDKVersion() -> 606\n"); fclose(f); } + return 606; +} + +extern "C" HRESULT WINAPI D3D12GetInterface(REFCLSID clsid, REFIID riid, void **ppv) { + if (!ppv) + return E_POINTER; + *ppv = nullptr; + Logger::warn(str::format("D3D12GetInterface: clsid=", clsid, " riid=", riid, " -> E_NOINTERFACE")); + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "D3D12GetInterface clsid=%s riid=%s -> E_NOINTERFACE\n", str::format(clsid).c_str(), str::format(riid).c_str()); fclose(f); } + } + return E_NOINTERFACE; +} + +#ifdef _WIN32 +extern void install_crash_handler(); +BOOL WINAPI DllMain(HINSTANCE instance, 
DWORD reason, LPVOID reserved) { + if (reason == DLL_PROCESS_ATTACH) { + DisableThreadLibraryCalls(instance); + install_crash_handler(); + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + char exe[MAX_PATH]; + GetModuleFileNameA(NULL, exe, MAX_PATH); + fprintf(f, "=== d3d12.dll DllMain PROCESS_ATTACH pid=%lu exe=[%s] ===\n", GetCurrentProcessId(), exe); + fclose(f); + } + } else if (reason == DLL_PROCESS_DETACH) { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + char exe[MAX_PATH]; + GetModuleFileNameA(NULL, exe, MAX_PATH); + fprintf(f, "=== d3d12.dll DllMain PROCESS_DETACH pid=%lu exe=[%s] ===\n", GetCurrentProcessId(), exe); + fclose(f); + } + } + return TRUE; +} +#endif diff --git a/src/d3d12/d3d12.def b/src/d3d12/d3d12.def new file mode 100644 index 000000000..229ca9314 --- /dev/null +++ b/src/d3d12/d3d12.def @@ -0,0 +1,9 @@ +LIBRARY D3D12.DLL +EXPORTS + D3D12CreateDevice + D3D12CreateRootSignatureDeserializer + D3D12GetDebugInterface + D3D12GetInterface + D3D12SDKVersion DATA + D3D12SerializeRootSignature + D3D12SerializeVersionedRootSignature diff --git a/src/d3d12/d3d12_command_allocator.cpp b/src/d3d12/d3d12_command_allocator.cpp new file mode 100644 index 000000000..c6a380f8d --- /dev/null +++ b/src/d3d12/d3d12_command_allocator.cpp @@ -0,0 +1,85 @@ +#include "d3d12_command_allocator.hpp" +#include "d3d12_device.hpp" +#include "log/log.hpp" +#include "util_string.hpp" + +#define CATRACE(fmt, ...) 
do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); if (_tf) { fprintf(_tf, "CmdAlloc::" fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) + +namespace dxmt { + +MTLD3D12CommandAllocator::MTLD3D12CommandAllocator(MTLD3D12Device *device, + D3D12_COMMAND_LIST_TYPE type) + : m_device(device), m_type(type) { + m_device->AddRef(); +} + +MTLD3D12CommandAllocator::~MTLD3D12CommandAllocator() { + m_device->Release(); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandAllocator::QueryInterface(REFIID riid, void **ppvObject) { + if (!ppvObject) + return E_POINTER; + *ppvObject = nullptr; + + if (riid == IID_IUnknown || riid == IID_ID3D12Object || + riid == IID_ID3D12DeviceChild || riid == IID_ID3D12Pageable || + riid == IID_ID3D12CommandAllocator) { + *ppvObject = ref(this); + return S_OK; + } + CATRACE("QI unknown IID %s -> E_NOINTERFACE", str::format(riid).c_str()); + return E_NOINTERFACE; +} + +ULONG STDMETHODCALLTYPE MTLD3D12CommandAllocator::AddRef() { + return ++m_refCount; +} + +ULONG STDMETHODCALLTYPE MTLD3D12CommandAllocator::Release() { + uint32_t rc = --m_refCount; + if (!rc) { + uint32_t rp = --m_refPrivate; + if (!rp) { + m_refPrivate += 0x80000000; + delete this; + } + } + return rc; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandAllocator::GetPrivateData(REFGUID guid, UINT *data_size, + void *data) { + CATRACE("GetPrivateData E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandAllocator::SetPrivateData(REFGUID guid, UINT data_size, + const void *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandAllocator::SetPrivateDataInterface(REFGUID guid, + const IUnknown *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandAllocator::SetName(LPCWSTR name) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandAllocator::GetDevice(REFIID riid, void **device) { + return m_device->QueryInterface(riid, device); +} + +HRESULT STDMETHODCALLTYPE MTLD3D12CommandAllocator::Reset() { + return 
S_OK; +} + +} // namespace dxmt diff --git a/src/d3d12/d3d12_command_allocator.hpp b/src/d3d12/d3d12_command_allocator.hpp new file mode 100644 index 000000000..d6c3c991e --- /dev/null +++ b/src/d3d12/d3d12_command_allocator.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include "com/com_pointer.hpp" +#include "d3d12.h" +#include + +namespace dxmt { + +class MTLD3D12Device; + +class MTLD3D12CommandAllocator : public ID3D12CommandAllocator { +public: + MTLD3D12CommandAllocator(MTLD3D12Device *device, + D3D12_COMMAND_LIST_TYPE type); + ~MTLD3D12CommandAllocator(); + + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, + void **ppvObject) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID guid, UINT *data_size, + void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateData(REFGUID guid, UINT data_size, + const void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface( + REFGUID guid, const IUnknown *data) override; + HRESULT STDMETHODCALLTYPE SetName(LPCWSTR name) override; + + HRESULT STDMETHODCALLTYPE GetDevice(REFIID riid, void **device) override; + HRESULT STDMETHODCALLTYPE Reset() override; + + D3D12_COMMAND_LIST_TYPE GetType() const { return m_type; } + +private: + MTLD3D12Device *m_device; + D3D12_COMMAND_LIST_TYPE m_type; + std::atomic m_refCount = {1ul}; + std::atomic m_refPrivate = {1ul}; +}; + +} // namespace dxmt diff --git a/src/d3d12/d3d12_command_list.cpp b/src/d3d12/d3d12_command_list.cpp new file mode 100644 index 000000000..2208d28b0 --- /dev/null +++ b/src/d3d12/d3d12_command_list.cpp @@ -0,0 +1,710 @@ +#include "d3d12_command_list.hpp" +#include "d3d12_command_allocator.hpp" +#include "d3d12_device.hpp" +#include "d3d12_pipeline_state.hpp" +#include "d3d12_resource.hpp" +#include "d3d12_root_signature.hpp" +#include "d3d12_descriptor_heap.hpp" +#include "log/log.hpp" +#include "util_string.hpp" + +#define CLTRACE(fmt, ...) 
do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); if (_tf) { fprintf(_tf, "CmdList::" fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) + +namespace dxmt { + +MTLD3D12GraphicsCommandList::MTLD3D12GraphicsCommandList( + MTLD3D12Device *device, MTLD3D12CommandAllocator *allocator, + D3D12_COMMAND_LIST_TYPE type, ID3D12PipelineState *initial_state) + : m_device(device), m_allocator(allocator), m_type(type) { + m_device->AddRef(); + if (m_allocator) + m_allocator->AddRef(); +} + +MTLD3D12GraphicsCommandList::~MTLD3D12GraphicsCommandList() { + if (m_allocator) + m_allocator->Release(); + m_device->Release(); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::QueryInterface(REFIID riid, void **ppvObject) { + if (!ppvObject) + return E_POINTER; + *ppvObject = nullptr; + + if (riid == IID_IUnknown || riid == IID_ID3D12Object || + riid == IID_ID3D12DeviceChild || riid == IID_ID3D12CommandList || + riid == IID_ID3D12GraphicsCommandList || + riid == IID_ID3D12GraphicsCommandList1 || + riid == IID_ID3D12GraphicsCommandList2 || + riid == IID_ID3D12GraphicsCommandList3 || + riid == IID_ID3D12GraphicsCommandList4 || + riid == IID_ID3D12GraphicsCommandList5 || + riid == IID_ID3D12GraphicsCommandList6) { + *ppvObject = ref(this); + return S_OK; + } + CLTRACE("QI unknown IID %s -> E_NOINTERFACE", str::format(riid).c_str()); + return E_NOINTERFACE; +} + +ULONG STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::AddRef() { + return ++m_refCount; +} + +ULONG STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::Release() { + uint32_t rc = --m_refCount; + if (!rc) { + uint32_t rp = --m_refPrivate; + if (!rp) { + m_refPrivate += 0x80000000; + delete this; + } + } + return rc; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::GetPrivateData(REFGUID guid, UINT *data_size, + void *data) { + CLTRACE("GetPrivateData E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::SetPrivateData(REFGUID guid, UINT data_size, + const void 
*data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::SetPrivateDataInterface(REFGUID guid, + const IUnknown *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::SetName(LPCWSTR name) { return S_OK; } + +HRESULT STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::GetDevice(REFIID riid, void **device) { + return m_device->QueryInterface(riid, device); +} + +D3D12_COMMAND_LIST_TYPE STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::GetType() { + return m_type; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::Close() { + CLTRACE("Close"); + m_closed = true; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::Reset( + ID3D12CommandAllocator *allocator, ID3D12PipelineState *initial_state) { + CLTRACE("Reset"); + m_closed = false; + m_cmds.clear(); + return S_OK; +} + +void STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::ClearState(ID3D12PipelineState *pipeline_state) { + m_cmds.clear(); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::DrawInstanced( + UINT vertex_count, UINT instance_count, UINT start_vertex, + UINT start_instance) { + CLTRACE("DrawInstanced v=%u i=%u", vertex_count, instance_count); + CmdDrawInstanced cmd = {}; + cmd.header = {CmdType::DrawInstanced, sizeof(cmd)}; + cmd.vertex_count = vertex_count; + cmd.instance_count = instance_count; + cmd.start_vertex = start_vertex; + cmd.start_instance = start_instance; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::DrawIndexedInstanced( + UINT index_count, UINT instance_count, UINT start_vertex, + INT base_vertex, UINT start_instance) { + CLTRACE("DrawIndexedInstanced idx=%u i=%u", index_count, instance_count); + CmdDrawIndexedInstanced cmd = {}; + cmd.header = {CmdType::DrawIndexedInstanced, sizeof(cmd)}; + cmd.index_count = index_count; + cmd.instance_count = instance_count; + cmd.start_vertex = start_vertex; + cmd.base_vertex = base_vertex; + cmd.start_instance = start_instance; + 
Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::Dispatch(UINT x, UINT y, + UINT z) { + CLTRACE("Dispatch %ux%ux%u", x, y, z); + CmdDispatch cmd = {}; + cmd.header = {CmdType::Dispatch, sizeof(cmd)}; + cmd.x = x; + cmd.y = y; + cmd.z = z; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::CopyBufferRegion( + ID3D12Resource *dst, UINT64 dst_offset, ID3D12Resource *src, + UINT64 src_offset, UINT64 byte_count) { + CmdCopyBufferRegion cmd = {}; + cmd.header = {CmdType::CopyBufferRegion, sizeof(cmd)}; + cmd.dst = dst; + cmd.dst_offset = dst_offset; + cmd.src = src; + cmd.src_offset = src_offset; + cmd.byte_count = byte_count; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::CopyTextureRegion( + const D3D12_TEXTURE_COPY_LOCATION *dst, UINT dst_x, UINT dst_y, + UINT dst_z, const D3D12_TEXTURE_COPY_LOCATION *src, + const D3D12_BOX *src_box) { + if (!dst || !src) return; + CmdCopyTextureRegion cmd = {}; + cmd.header = {CmdType::CopyTextureRegion, sizeof(cmd)}; + cmd.dst_resource = dst->pResource; + cmd.dst_type = dst->Type; + cmd.dst_x = dst_x; + cmd.dst_y = dst_y; + cmd.dst_z = dst_z; + if (dst->Type == D3D12_TEXTURE_COPY_TYPE_SUBRESOURCE_INDEX) { + cmd.dst_subresource = dst->SubresourceIndex; + } else { + cmd.dst_offset = dst->PlacedFootprint.Offset; + cmd.dst_footprint_width = dst->PlacedFootprint.Footprint.Width; + cmd.dst_footprint_height = dst->PlacedFootprint.Footprint.Height; + cmd.dst_footprint_depth = dst->PlacedFootprint.Footprint.Depth; + cmd.dst_footprint_row_pitch = dst->PlacedFootprint.Footprint.RowPitch; + } + cmd.src_resource = src->pResource; + cmd.src_type = src->Type; + if (src->Type == D3D12_TEXTURE_COPY_TYPE_SUBRESOURCE_INDEX) { + cmd.src_subresource = src->SubresourceIndex; + } else { + cmd.src_offset = src->PlacedFootprint.Offset; + cmd.src_footprint_width = src->PlacedFootprint.Footprint.Width; + cmd.src_footprint_height = src->PlacedFootprint.Footprint.Height; + cmd.src_footprint_depth = 
src->PlacedFootprint.Footprint.Depth; + cmd.src_footprint_row_pitch = src->PlacedFootprint.Footprint.RowPitch; + } + if (src_box) { + cmd.src_box = *src_box; + cmd.has_src_box = 1; + } + Emit(cmd); +} + +void STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::CopyResource(ID3D12Resource *dst, + ID3D12Resource *src) { + if (!dst || !src) return; + CmdCopyResource cmd = {}; + cmd.header = {CmdType::CopyResource, sizeof(cmd)}; + cmd.dst = dst; + cmd.src = src; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::CopyTiles( + ID3D12Resource *tiled_resource, + const D3D12_TILED_RESOURCE_COORDINATE *tile_region_start_coordinate, + const D3D12_TILE_REGION_SIZE *tile_region_size, + ID3D12Resource *buffer, UINT64 buffer_offset, + D3D12_TILE_COPY_FLAGS flags) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::ResolveSubresource( + ID3D12Resource *dst, UINT dst_sub, ID3D12Resource *src, UINT src_sub, + DXGI_FORMAT format) { + CmdResolveSubresource cmd = {}; + cmd.header = {CmdType::ResolveSubresource, sizeof(cmd)}; + cmd.dst = dst; + cmd.dst_sub = dst_sub; + cmd.src = src; + cmd.src_sub = src_sub; + cmd.format = format; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::IASetPrimitiveTopology( + D3D12_PRIMITIVE_TOPOLOGY topology) { + CmdIASetPrimitiveTopology cmd = {}; + cmd.header = {CmdType::IASetPrimitiveTopology, sizeof(cmd)}; + cmd.topology = topology; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::RSSetViewports( + UINT count, const D3D12_VIEWPORT *viewports) { + size_t extra = count * sizeof(D3D12_VIEWPORT); + auto total = sizeof(CmdRSSetViewports) - sizeof(D3D12_VIEWPORT) + extra; + auto offset = m_cmds.size(); + m_cmds.resize(offset + total); + CmdRSSetViewports cmd = {}; + cmd.header = {CmdType::RSSetViewports, (uint32_t)total}; + cmd.count = count; + memcpy(m_cmds.data() + offset, &cmd, sizeof(CmdRSSetViewports) - sizeof(D3D12_VIEWPORT)); + memcpy(m_cmds.data() + offset + sizeof(CmdRSSetViewports) - 
sizeof(D3D12_VIEWPORT), + viewports, extra); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::RSSetScissorRects( + UINT count, const D3D12_RECT *rects) { + size_t extra = count * sizeof(D3D12_RECT); + auto total = sizeof(CmdRSSetScissorRects) - sizeof(D3D12_RECT) + extra; + auto offset = m_cmds.size(); + m_cmds.resize(offset + total); + CmdRSSetScissorRects cmd = {}; + cmd.header = {CmdType::RSSetScissorRects, (uint32_t)total}; + cmd.count = count; + memcpy(m_cmds.data() + offset, &cmd, sizeof(CmdRSSetScissorRects) - sizeof(D3D12_RECT)); + memcpy(m_cmds.data() + offset + sizeof(CmdRSSetScissorRects) - sizeof(D3D12_RECT), + rects, extra); +} + +void STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::OMSetBlendFactor(const FLOAT blend_factor[4]) { + CmdOMBlendFactor cmd = {}; + cmd.header = {CmdType::OMSetBlendFactor, sizeof(cmd)}; + memcpy(cmd.factor, blend_factor, 16); + Emit(cmd); +} + +void STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::OMSetStencilRef(UINT stencil_ref) { + CmdOMStencilRef cmd = {}; + cmd.header = {CmdType::OMSetStencilRef, sizeof(cmd)}; + cmd.stencil_ref = stencil_ref; + Emit(cmd); +} + +void STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::SetPipelineState( + ID3D12PipelineState *pipeline_state) { + CmdSetPipelineState cmd = {}; + cmd.header = {CmdType::SetPipelineState, sizeof(cmd)}; + cmd.pso = pipeline_state; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::ResourceBarrier( + UINT barrier_count, const D3D12_RESOURCE_BARRIER *barriers) { + size_t extra = barrier_count * sizeof(D3D12_RESOURCE_BARRIER); + auto total = sizeof(CmdResourceBarrier) - sizeof(D3D12_RESOURCE_BARRIER) + extra; + auto offset = m_cmds.size(); + m_cmds.resize(offset + total); + CmdResourceBarrier cmd = {}; + cmd.header = {CmdType::ResourceBarrier, (uint32_t)total}; + cmd.count = barrier_count; + memcpy(m_cmds.data() + offset, &cmd, sizeof(CmdResourceBarrier) - sizeof(D3D12_RESOURCE_BARRIER)); + memcpy(m_cmds.data() + offset + 
sizeof(CmdResourceBarrier) - sizeof(D3D12_RESOURCE_BARRIER), + barriers, extra); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::ExecuteBundle( + ID3D12GraphicsCommandList *command_list) { + CLTRACE("ExecuteBundle cmds=%zu", command_list ? static_cast<MTLD3D12GraphicsCommandList *>(command_list)->GetCommands().size() : 0); + if (command_list) { + auto *bundle = static_cast<MTLD3D12GraphicsCommandList *>(command_list); + const auto &bundle_cmds = bundle->GetCommands(); + m_cmds.insert(m_cmds.end(), bundle_cmds.begin(), bundle_cmds.end()); + } +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetDescriptorHeaps( + UINT heap_count, ID3D12DescriptorHeap *const *heaps) { + size_t extra = heap_count * sizeof(ID3D12DescriptorHeap *); + auto total = sizeof(CmdSetDescriptorHeaps) - sizeof(ID3D12DescriptorHeap *) + extra; + auto offset = m_cmds.size(); + m_cmds.resize(offset + total); + CmdSetDescriptorHeaps cmd = {}; + cmd.header = {CmdType::SetDescriptorHeaps, (uint32_t)total}; + cmd.count = heap_count; + memcpy(m_cmds.data() + offset, &cmd, sizeof(CmdSetDescriptorHeaps) - sizeof(ID3D12DescriptorHeap *)); + memcpy(m_cmds.data() + offset + sizeof(CmdSetDescriptorHeaps) - sizeof(ID3D12DescriptorHeap *), + heaps, extra); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetComputeRootSignature( + ID3D12RootSignature *root_signature) { + CmdSetRootSignature cmd = {}; + cmd.header = {CmdType::SetComputeRootSignature, sizeof(cmd)}; + cmd.root_sig = root_signature; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetGraphicsRootSignature( + ID3D12RootSignature *root_signature) { + CmdSetRootSignature cmd = {}; + cmd.header = {CmdType::SetGraphicsRootSignature, sizeof(cmd)}; + cmd.root_sig = root_signature; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetComputeRootDescriptorTable( + UINT root_parameter_index, + D3D12_GPU_DESCRIPTOR_HANDLE base_descriptor) { + CmdSetRootDescriptorTable cmd = {}; + cmd.header = {CmdType::SetComputeRootDescriptorTable, sizeof(cmd)}; + 
cmd.root_param_index = root_parameter_index; + cmd.base_descriptor = base_descriptor; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetGraphicsRootDescriptorTable( + UINT root_parameter_index, + D3D12_GPU_DESCRIPTOR_HANDLE base_descriptor) { + CmdSetRootDescriptorTable cmd = {}; + cmd.header = {CmdType::SetGraphicsRootDescriptorTable, sizeof(cmd)}; + cmd.root_param_index = root_parameter_index; + cmd.base_descriptor = base_descriptor; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetComputeRoot32BitConstant( + UINT root_parameter_index, UINT data, UINT dst_offset) { + CmdSetRoot32BitConstants cmd = {}; + cmd.header = {CmdType::SetComputeRoot32BitConstants, sizeof(cmd)}; + cmd.root_param_index = root_parameter_index; + cmd.count = 1; + cmd.dst_offset = dst_offset; + memcpy(cmd.data, &data, 4); + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetGraphicsRoot32BitConstant( + UINT root_parameter_index, UINT data, UINT dst_offset) { + CmdSetRoot32BitConstants cmd = {}; + cmd.header = {CmdType::SetGraphicsRoot32BitConstants, sizeof(cmd)}; + cmd.root_param_index = root_parameter_index; + cmd.count = 1; + cmd.dst_offset = dst_offset; + memcpy(cmd.data, &data, 4); + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetComputeRoot32BitConstants( + UINT root_parameter_index, UINT constant_count, const void *data, + UINT dst_offset) { + size_t extra = constant_count * 4; + auto total = sizeof(CmdSetRoot32BitConstants) - 1 + extra; + auto offset = m_cmds.size(); + m_cmds.resize(offset + total); + CmdSetRoot32BitConstants cmd = {}; + cmd.header = {CmdType::SetComputeRoot32BitConstants, (uint32_t)total}; + cmd.root_param_index = root_parameter_index; + cmd.count = constant_count; + cmd.dst_offset = dst_offset; + memcpy(m_cmds.data() + offset, &cmd, sizeof(CmdSetRoot32BitConstants) - 1); + memcpy(m_cmds.data() + offset + sizeof(CmdSetRoot32BitConstants) - 1, data, extra); +} + +void 
STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetGraphicsRoot32BitConstants( + UINT root_parameter_index, UINT constant_count, const void *data, + UINT dst_offset) { + size_t extra = constant_count * 4; + auto total = sizeof(CmdSetRoot32BitConstants) - 1 + extra; + auto offset = m_cmds.size(); + m_cmds.resize(offset + total); + CmdSetRoot32BitConstants cmd = {}; + cmd.header = {CmdType::SetGraphicsRoot32BitConstants, (uint32_t)total}; + cmd.root_param_index = root_parameter_index; + cmd.count = constant_count; + cmd.dst_offset = dst_offset; + memcpy(m_cmds.data() + offset, &cmd, sizeof(CmdSetRoot32BitConstants) - 1); + memcpy(m_cmds.data() + offset + sizeof(CmdSetRoot32BitConstants) - 1, data, extra); +} + +void STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::SetComputeRootConstantBufferView( + UINT root_parameter_index, D3D12_GPU_VIRTUAL_ADDRESS address) { + CmdSetRootCBV cmd = {}; + cmd.header = {CmdType::SetComputeRootConstantBufferView, sizeof(cmd)}; + cmd.root_param_index = root_parameter_index; + cmd.address = address; + Emit(cmd); +} + +void STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::SetGraphicsRootConstantBufferView( + UINT root_parameter_index, D3D12_GPU_VIRTUAL_ADDRESS address) { + CmdSetRootCBV cmd = {}; + cmd.header = {CmdType::SetGraphicsRootConstantBufferView, sizeof(cmd)}; + cmd.root_param_index = root_parameter_index; + cmd.address = address; + Emit(cmd); +} + +void STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::SetComputeRootShaderResourceView( + UINT root_parameter_index, D3D12_GPU_VIRTUAL_ADDRESS address) { + CmdSetRootCBV cmd = {}; + cmd.header = {CmdType::SetComputeRootShaderResourceView, sizeof(cmd)}; + cmd.root_param_index = root_parameter_index; + cmd.address = address; + Emit(cmd); +} + +void STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::SetGraphicsRootShaderResourceView( + UINT root_parameter_index, D3D12_GPU_VIRTUAL_ADDRESS address) { + CmdSetRootCBV cmd = {}; + cmd.header = {CmdType::SetGraphicsRootShaderResourceView, 
sizeof(cmd)}; + cmd.root_param_index = root_parameter_index; + cmd.address = address; + Emit(cmd); +} + +void STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::SetComputeRootUnorderedAccessView( + UINT root_parameter_index, D3D12_GPU_VIRTUAL_ADDRESS address) { + CmdSetRootCBV cmd = {}; + cmd.header = {CmdType::SetComputeRootUnorderedAccessView, sizeof(cmd)}; + cmd.root_param_index = root_parameter_index; + cmd.address = address; + Emit(cmd); +} + +void STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::SetGraphicsRootUnorderedAccessView( + UINT root_parameter_index, D3D12_GPU_VIRTUAL_ADDRESS address) { + CmdSetRootCBV cmd = {}; + cmd.header = {CmdType::SetGraphicsRootUnorderedAccessView, sizeof(cmd)}; + cmd.root_param_index = root_parameter_index; + cmd.address = address; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::IASetIndexBuffer( + const D3D12_INDEX_BUFFER_VIEW *view) { + CmdIASetIndexBuffer cmd = {}; + cmd.header = {CmdType::IASetIndexBuffer, sizeof(cmd)}; + if (view) + cmd.view = *view; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::IASetVertexBuffers( + UINT start_slot, UINT count, + const D3D12_VERTEX_BUFFER_VIEW *views) { + size_t extra = count * sizeof(D3D12_VERTEX_BUFFER_VIEW); + auto total = sizeof(CmdIASetVertexBuffers) - sizeof(D3D12_VERTEX_BUFFER_VIEW) + extra; + auto offset = m_cmds.size(); + m_cmds.resize(offset + total); + CmdIASetVertexBuffers cmd = {}; + cmd.header = {CmdType::IASetVertexBuffers, (uint32_t)total}; + cmd.start_slot = start_slot; + cmd.count = count; + memcpy(m_cmds.data() + offset, &cmd, sizeof(CmdIASetVertexBuffers) - sizeof(D3D12_VERTEX_BUFFER_VIEW)); + if (views) + memcpy(m_cmds.data() + offset + sizeof(CmdIASetVertexBuffers) - sizeof(D3D12_VERTEX_BUFFER_VIEW), + views, extra); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SOSetTargets( + UINT start_slot, UINT view_count, + const D3D12_STREAM_OUTPUT_BUFFER_VIEW *views) {} + +void STDMETHODCALLTYPE 
MTLD3D12GraphicsCommandList::OMSetRenderTargets( + UINT rt_count, const D3D12_CPU_DESCRIPTOR_HANDLE *rts, + WINBOOL single_handle, + const D3D12_CPU_DESCRIPTOR_HANDLE *dsv) { + CmdOMSetRenderTargets cmd = {}; + cmd.header = {CmdType::OMSetRenderTargets, sizeof(cmd)}; + cmd.rt_count = rt_count; + cmd.single_handle = single_handle != 0; + cmd.has_dsv = dsv != nullptr; + if (rts) { + for (UINT i = 0; i < rt_count && i < 8; i++) + cmd.rts[i] = rts[i]; + } + if (dsv) + cmd.dsv = *dsv; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::ClearDepthStencilView( + D3D12_CPU_DESCRIPTOR_HANDLE dsv, D3D12_CLEAR_FLAGS flags, FLOAT depth, + UINT8 stencil, UINT rect_count, const D3D12_RECT *rects) { + CmdClearDSV cmd = {}; + cmd.header = {CmdType::ClearDepthStencilView, sizeof(cmd)}; + cmd.dsv = dsv; + cmd.flags = flags; + cmd.depth = depth; + cmd.stencil = stencil; + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::ClearRenderTargetView( + D3D12_CPU_DESCRIPTOR_HANDLE rtv, const FLOAT color[4], UINT rect_count, + const D3D12_RECT *rects) { + CmdClearRTV cmd = {}; + cmd.header = {CmdType::ClearRenderTargetView, sizeof(cmd)}; + cmd.rtv = rtv; + memcpy(cmd.color, color, 16); + Emit(cmd); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::ClearUnorderedAccessViewUint( + D3D12_GPU_DESCRIPTOR_HANDLE gpu_handle, + D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle, ID3D12Resource *resource, + const UINT values[4], UINT rect_count, const D3D12_RECT *rects) {} + +void STDMETHODCALLTYPE +MTLD3D12GraphicsCommandList::ClearUnorderedAccessViewFloat( + D3D12_GPU_DESCRIPTOR_HANDLE gpu_handle, + D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle, ID3D12Resource *resource, + const float values[4], UINT rect_count, const D3D12_RECT *rects) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::DiscardResource( + ID3D12Resource *resource, const D3D12_DISCARD_REGION *region) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::BeginQuery( + ID3D12QueryHeap *heap, 
D3D12_QUERY_TYPE type, UINT index) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::EndQuery( + ID3D12QueryHeap *heap, D3D12_QUERY_TYPE type, UINT index) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::ResolveQueryData( + ID3D12QueryHeap *heap, D3D12_QUERY_TYPE type, UINT start_index, + UINT query_count, ID3D12Resource *dst_buffer, + UINT64 aligned_dst_buffer_offset) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetPredication( + ID3D12Resource *buffer, UINT64 aligned_buffer_offset, + D3D12_PREDICATION_OP operation) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetMarker( + UINT metadata, const void *data, UINT size) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::BeginEvent( + UINT metadata, const void *data, UINT size) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::EndEvent() {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::ExecuteIndirect( + ID3D12CommandSignature *command_signature, UINT max_command_count, + ID3D12Resource *arg_buffer, UINT64 arg_buffer_offset, + ID3D12Resource *count_buffer, UINT64 count_buffer_offset) { + CLTRACE("ExecuteIndirect max=%u", max_command_count); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::AtomicCopyBufferUINT( + ID3D12Resource *dst_buffer, UINT64 dst_offset, + ID3D12Resource *src_buffer, UINT64 src_offset, + UINT dependent_resource_count, + ID3D12Resource *const *dependent_resources, + const D3D12_SUBRESOURCE_RANGE_UINT64 *dependent_sub_resource_ranges) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::AtomicCopyBufferUINT64( + ID3D12Resource *dst_buffer, UINT64 dst_offset, + ID3D12Resource *src_buffer, UINT64 src_offset, + UINT dependent_resource_count, + ID3D12Resource *const *dependent_resources, + const D3D12_SUBRESOURCE_RANGE_UINT64 *dependent_sub_resource_ranges) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::OMSetDepthBounds( + FLOAT min, FLOAT max) {} + +void STDMETHODCALLTYPE 
MTLD3D12GraphicsCommandList::SetSamplePositions( + UINT sample_count, UINT pixel_count, + D3D12_SAMPLE_POSITION *sample_positions) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::ResolveSubresourceRegion( + ID3D12Resource *dst_resource, UINT dst_sub_resource_idx, + UINT dst_x, UINT dst_y, + ID3D12Resource *src_resource, UINT src_sub_resource_idx, + D3D12_RECT *src_rect, DXGI_FORMAT format, + D3D12_RESOLVE_MODE mode) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetViewInstanceMask( + UINT mask) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::WriteBufferImmediate( + UINT count, const D3D12_WRITEBUFFERIMMEDIATE_PARAMETER *parameters, + const D3D12_WRITEBUFFERIMMEDIATE_MODE *modes) {} + +/*** ID3D12GraphicsCommandList3 ***/ +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetProtectedResourceSession( + ID3D12ProtectedResourceSession *protected_session) { + CLTRACE("SetProtectedResourceSession -> noop"); +} + +/*** ID3D12GraphicsCommandList4 ***/ +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::BeginRenderPass( + UINT num_render_targets, + const D3D12_RENDER_PASS_RENDER_TARGET_DESC *render_targets, + const D3D12_RENDER_PASS_DEPTH_STENCIL_DESC *depth_stencil, + D3D12_RENDER_PASS_FLAGS flags) { + CLTRACE("BeginRenderPass numRT=%u flags=0x%x", num_render_targets, (unsigned)flags); + + if (render_targets && num_render_targets > 0) { + D3D12_CPU_DESCRIPTOR_HANDLE rt_handles[8]; + for (UINT i = 0; i < num_render_targets && i < 8; i++) { + rt_handles[i] = render_targets[i].cpuDescriptor; + } + OMSetRenderTargets(num_render_targets, rt_handles, FALSE, + depth_stencil ? 
&depth_stencil->cpuDescriptor : nullptr); + } +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::EndRenderPass() { + CLTRACE("EndRenderPass"); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::InitializeMetaCommand( + ID3D12MetaCommand *meta_command, const void *initialization_parameters_data, + SIZE_T initialization_parameters_data_size_in_bytes) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::ExecuteMetaCommand( + ID3D12MetaCommand *meta_command, const void *execution_parameters_data, + SIZE_T execution_parameters_data_size_in_bytes) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::BuildRaytracingAccelerationStructure( + const D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC *desc, + UINT num_post_build_info_descs, + const D3D12_RAYTRACING_ACCELERATION_STRUCTURE_POSTBUILD_INFO_DESC *post_build_info_descs) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::EmitRaytracingAccelerationStructurePostbuildInfo( + const D3D12_RAYTRACING_ACCELERATION_STRUCTURE_POSTBUILD_INFO_DESC *descs, + UINT num_acceleration_structures, + const D3D12_GPU_VIRTUAL_ADDRESS *source_acceleration_structure_data) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::CopyRaytracingAccelerationStructure( + D3D12_GPU_VIRTUAL_ADDRESS dest_acceleration_structure_data, + D3D12_GPU_VIRTUAL_ADDRESS source_acceleration_structure_data, + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_COPY_MODE mode) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::SetPipelineState1( + ID3D12StateObject *state_object) { + CLTRACE("SetPipelineState1 -> noop (raytracing)"); +} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::DispatchRays( + const D3D12_DISPATCH_RAYS_DESC *desc) {} + +/*** ID3D12GraphicsCommandList5 ***/ +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::RSSetShadingRate( + D3D12_SHADING_RATE base_shading_rate, + const D3D12_SHADING_RATE_COMBINER *combiners) {} + +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::RSSetShadingRateImage( + 
ID3D12Resource *shading_rate_image) {} + +/*** ID3D12GraphicsCommandList6 ***/ +void STDMETHODCALLTYPE MTLD3D12GraphicsCommandList::DispatchMesh( + UINT thread_group_count_x, UINT thread_group_count_y, + UINT thread_group_count_z) {} + +} // namespace dxmt diff --git a/src/d3d12/d3d12_command_list.hpp b/src/d3d12/d3d12_command_list.hpp new file mode 100644 index 000000000..6766d1cd5 --- /dev/null +++ b/src/d3d12/d3d12_command_list.hpp @@ -0,0 +1,506 @@ +#pragma once + +#include "com/com_pointer.hpp" +#include "d3d12.h" +#include "Metal.hpp" +#include <atomic> +#include <cstdint> +#include <cstring> +#include <vector> + +namespace dxmt { + +class MTLD3D12Device; +class MTLD3D12CommandAllocator; + +enum class CmdType : uint32_t { + DrawInstanced, + DrawIndexedInstanced, + Dispatch, + CopyBufferRegion, + CopyTextureRegion, + CopyResource, + SetPipelineState, + SetGraphicsRootSignature, + SetComputeRootSignature, + SetGraphicsRoot32BitConstants, + SetComputeRoot32BitConstants, + SetGraphicsRootConstantBufferView, + SetComputeRootConstantBufferView, + SetGraphicsRootShaderResourceView, + SetComputeRootShaderResourceView, + SetGraphicsRootUnorderedAccessView, + SetComputeRootUnorderedAccessView, + SetGraphicsRootDescriptorTable, + SetComputeRootDescriptorTable, + IASetPrimitiveTopology, + IASetVertexBuffers, + IASetIndexBuffer, + RSSetViewports, + RSSetScissorRects, + OMSetRenderTargets, + OMSetBlendFactor, + OMSetStencilRef, + ClearRenderTargetView, + ClearDepthStencilView, + ResourceBarrier, + SetDescriptorHeaps, + ResolveSubresource, +}; + +struct CmdHeader { + CmdType type; + uint32_t size; +}; + +struct CmdDrawInstanced { + CmdHeader header; + uint32_t vertex_count; + uint32_t instance_count; + uint32_t start_vertex; + uint32_t start_instance; +}; + +struct CmdDrawIndexedInstanced { + CmdHeader header; + uint32_t index_count; + uint32_t instance_count; + uint32_t start_vertex; + int32_t base_vertex; + uint32_t start_instance; +}; + +struct CmdDispatch { + CmdHeader header; + uint32_t x, y, z; +}; + 
+struct CmdCopyBufferRegion { + CmdHeader header; + ID3D12Resource *dst; + uint64_t dst_offset; + ID3D12Resource *src; + uint64_t src_offset; + uint64_t byte_count; +}; + +struct CmdCopyTextureRegion { + CmdHeader header; + ID3D12Resource *dst_resource; + D3D12_TEXTURE_COPY_TYPE dst_type; + UINT dst_subresource; + UINT64 dst_offset; + UINT dst_footprint_width; + UINT dst_footprint_height; + UINT dst_footprint_depth; + UINT dst_footprint_row_pitch; + UINT dst_x, dst_y, dst_z; + ID3D12Resource *src_resource; + D3D12_TEXTURE_COPY_TYPE src_type; + UINT src_subresource; + UINT64 src_offset; + UINT src_footprint_width; + UINT src_footprint_height; + UINT src_footprint_depth; + UINT src_footprint_row_pitch; + D3D12_BOX src_box; + UINT8 has_src_box; +}; + +struct CmdCopyResource { + CmdHeader header; + ID3D12Resource *dst; + ID3D12Resource *src; +}; + +struct CmdSetPipelineState { + CmdHeader header; + ID3D12PipelineState *pso; +}; + +struct CmdSetRootSignature { + CmdHeader header; + ID3D12RootSignature *root_sig; +}; + +struct CmdSetRoot32BitConstants { + CmdHeader header; + uint32_t root_param_index; + uint32_t count; + uint32_t dst_offset; + uint8_t data[1]; +}; + +struct CmdSetRootCBV { + CmdHeader header; + uint32_t root_param_index; + D3D12_GPU_VIRTUAL_ADDRESS address; +}; + +struct CmdSetRootDescriptorTable { + CmdHeader header; + uint32_t root_param_index; + D3D12_GPU_DESCRIPTOR_HANDLE base_descriptor; +}; + +struct CmdIASetPrimitiveTopology { + CmdHeader header; + D3D12_PRIMITIVE_TOPOLOGY topology; +}; + +struct CmdIASetVertexBuffers { + CmdHeader header; + uint32_t start_slot; + uint32_t count; + D3D12_VERTEX_BUFFER_VIEW views[1]; +}; + +struct CmdIASetIndexBuffer { + CmdHeader header; + D3D12_INDEX_BUFFER_VIEW view; +}; + +struct CmdRSSetViewports { + CmdHeader header; + uint32_t count; + D3D12_VIEWPORT viewports[1]; +}; + +struct CmdRSSetScissorRects { + CmdHeader header; + uint32_t count; + D3D12_RECT rects[1]; +}; + +struct CmdOMSetRenderTargets { + 
CmdHeader header; + uint32_t rt_count; + bool single_handle; + D3D12_CPU_DESCRIPTOR_HANDLE rts[8]; + D3D12_CPU_DESCRIPTOR_HANDLE dsv; + bool has_dsv; +}; + +struct CmdOMBlendFactor { + CmdHeader header; + float factor[4]; +}; + +struct CmdOMStencilRef { + CmdHeader header; + uint32_t stencil_ref; +}; + +struct CmdClearRTV { + CmdHeader header; + D3D12_CPU_DESCRIPTOR_HANDLE rtv; + float color[4]; +}; + +struct CmdClearDSV { + CmdHeader header; + D3D12_CPU_DESCRIPTOR_HANDLE dsv; + D3D12_CLEAR_FLAGS flags; + float depth; + uint8_t stencil; +}; + +struct CmdResourceBarrier { + CmdHeader header; + uint32_t count; + D3D12_RESOURCE_BARRIER barriers[1]; +}; + +struct CmdSetDescriptorHeaps { + CmdHeader header; + uint32_t count; + ID3D12DescriptorHeap *heaps[1]; +}; + +struct CmdResolveSubresource { + CmdHeader header; + ID3D12Resource *dst; + uint32_t dst_sub; + ID3D12Resource *src; + uint32_t src_sub; + DXGI_FORMAT format; +}; + +class MTLD3D12GraphicsCommandList : public ID3D12GraphicsCommandList6 { +public: + MTLD3D12GraphicsCommandList(MTLD3D12Device *device, + MTLD3D12CommandAllocator *allocator, + D3D12_COMMAND_LIST_TYPE type, + ID3D12PipelineState *initial_state); + ~MTLD3D12GraphicsCommandList(); + + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, + void **ppvObject) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID guid, UINT *data_size, + void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateData(REFGUID guid, UINT data_size, + const void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface( + REFGUID guid, const IUnknown *data) override; + HRESULT STDMETHODCALLTYPE SetName(LPCWSTR name) override; + + HRESULT STDMETHODCALLTYPE GetDevice(REFIID riid, void **device) override; + + D3D12_COMMAND_LIST_TYPE STDMETHODCALLTYPE GetType() override; + + HRESULT STDMETHODCALLTYPE Close() override; + HRESULT STDMETHODCALLTYPE 
Reset(ID3D12CommandAllocator *allocator, + ID3D12PipelineState *initial_state) override; + void STDMETHODCALLTYPE ClearState(ID3D12PipelineState *pipeline_state) override; + + void STDMETHODCALLTYPE DrawInstanced(UINT vertex_count_per_instance, + UINT instance_count, + UINT start_vertex_location, + UINT start_instance_location) override; + void STDMETHODCALLTYPE DrawIndexedInstanced(UINT index_count_per_instance, + UINT instance_count, + UINT start_index_location, + INT base_vertex_location, + UINT start_instance_location) override; + void STDMETHODCALLTYPE Dispatch(UINT x, UINT y, UINT z) override; + void STDMETHODCALLTYPE CopyBufferRegion(ID3D12Resource *dst_buffer, + UINT64 dst_offset, + ID3D12Resource *src_buffer, + UINT64 src_offset, + UINT64 byte_count) override; + void STDMETHODCALLTYPE CopyTextureRegion( + const D3D12_TEXTURE_COPY_LOCATION *dst, UINT dst_x, UINT dst_y, + UINT dst_z, const D3D12_TEXTURE_COPY_LOCATION *src, + const D3D12_BOX *src_box) override; + void STDMETHODCALLTYPE CopyResource(ID3D12Resource *dst_resource, + ID3D12Resource *src_resource) override; + void STDMETHODCALLTYPE CopyTiles( + ID3D12Resource *tiled_resource, + const D3D12_TILED_RESOURCE_COORDINATE *tile_region_start_coordinate, + const D3D12_TILE_REGION_SIZE *tile_region_size, + ID3D12Resource *buffer, UINT64 buffer_offset, + D3D12_TILE_COPY_FLAGS flags) override; + void STDMETHODCALLTYPE ResolveSubresource(ID3D12Resource *dst_resource, + UINT dst_sub_resource, + ID3D12Resource *src_resource, + UINT src_sub_resource, + DXGI_FORMAT format) override; + void STDMETHODCALLTYPE + IASetPrimitiveTopology(D3D12_PRIMITIVE_TOPOLOGY primitive_topology) override; + void STDMETHODCALLTYPE RSSetViewports(UINT viewport_count, + const D3D12_VIEWPORT *viewports) override; + void STDMETHODCALLTYPE RSSetScissorRects(UINT rect_count, + const D3D12_RECT *rects) override; + void STDMETHODCALLTYPE OMSetBlendFactor(const FLOAT blend_factor[4]) override; + void STDMETHODCALLTYPE OMSetStencilRef(UINT 
stencil_ref) override; + void STDMETHODCALLTYPE SetPipelineState( + ID3D12PipelineState *pipeline_state) override; + void STDMETHODCALLTYPE ResourceBarrier( + UINT barrier_count, + const D3D12_RESOURCE_BARRIER *barriers) override; + void STDMETHODCALLTYPE ExecuteBundle( + ID3D12GraphicsCommandList *command_list) override; + void STDMETHODCALLTYPE SetDescriptorHeaps( + UINT heap_count, + ID3D12DescriptorHeap *const *heaps) override; + void STDMETHODCALLTYPE + SetComputeRootSignature(ID3D12RootSignature *root_signature) override; + void STDMETHODCALLTYPE + SetGraphicsRootSignature(ID3D12RootSignature *root_signature) override; + void STDMETHODCALLTYPE SetComputeRootDescriptorTable( + UINT root_parameter_index, + D3D12_GPU_DESCRIPTOR_HANDLE base_descriptor) override; + void STDMETHODCALLTYPE SetGraphicsRootDescriptorTable( + UINT root_parameter_index, + D3D12_GPU_DESCRIPTOR_HANDLE base_descriptor) override; + void STDMETHODCALLTYPE SetComputeRoot32BitConstant(UINT root_parameter_index, + UINT data, + UINT dst_offset) override; + void STDMETHODCALLTYPE SetGraphicsRoot32BitConstant( + UINT root_parameter_index, UINT data, UINT dst_offset) override; + void STDMETHODCALLTYPE SetComputeRoot32BitConstants( + UINT root_parameter_index, UINT constant_count, const void *data, + UINT dst_offset) override; + void STDMETHODCALLTYPE SetGraphicsRoot32BitConstants( + UINT root_parameter_index, UINT constant_count, const void *data, + UINT dst_offset) override; + void STDMETHODCALLTYPE SetComputeRootConstantBufferView( + UINT root_parameter_index, + D3D12_GPU_VIRTUAL_ADDRESS address) override; + void STDMETHODCALLTYPE SetGraphicsRootConstantBufferView( + UINT root_parameter_index, + D3D12_GPU_VIRTUAL_ADDRESS address) override; + void STDMETHODCALLTYPE SetComputeRootShaderResourceView( + UINT root_parameter_index, + D3D12_GPU_VIRTUAL_ADDRESS address) override; + void STDMETHODCALLTYPE SetGraphicsRootShaderResourceView( + UINT root_parameter_index, + D3D12_GPU_VIRTUAL_ADDRESS address) 
override; + void STDMETHODCALLTYPE SetComputeRootUnorderedAccessView( + UINT root_parameter_index, + D3D12_GPU_VIRTUAL_ADDRESS address) override; + void STDMETHODCALLTYPE SetGraphicsRootUnorderedAccessView( + UINT root_parameter_index, + D3D12_GPU_VIRTUAL_ADDRESS address) override; + void STDMETHODCALLTYPE + IASetIndexBuffer(const D3D12_INDEX_BUFFER_VIEW *view) override; + void STDMETHODCALLTYPE IASetVertexBuffers(UINT start_slot, UINT view_count, + const D3D12_VERTEX_BUFFER_VIEW *views) override; + void STDMETHODCALLTYPE SOSetTargets( + UINT start_slot, UINT view_count, + const D3D12_STREAM_OUTPUT_BUFFER_VIEW *views) override; + void STDMETHODCALLTYPE OMSetRenderTargets( + UINT render_target_descriptor_count, + const D3D12_CPU_DESCRIPTOR_HANDLE *render_target_descriptors, + WINBOOL single_descriptor_handle, + const D3D12_CPU_DESCRIPTOR_HANDLE *depth_stencil_descriptor) override; + void STDMETHODCALLTYPE ClearDepthStencilView(D3D12_CPU_DESCRIPTOR_HANDLE dsv, + D3D12_CLEAR_FLAGS flags, + FLOAT depth, UINT8 stencil, + UINT rect_count, + const D3D12_RECT *rects) override; + void STDMETHODCALLTYPE ClearRenderTargetView( + D3D12_CPU_DESCRIPTOR_HANDLE rtv, const FLOAT color[4], UINT rect_count, + const D3D12_RECT *rects) override; + void STDMETHODCALLTYPE ClearUnorderedAccessViewUint( + D3D12_GPU_DESCRIPTOR_HANDLE gpu_handle, + D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle, ID3D12Resource *resource, + const UINT values[4], UINT rect_count, + const D3D12_RECT *rects) override; + void STDMETHODCALLTYPE ClearUnorderedAccessViewFloat( + D3D12_GPU_DESCRIPTOR_HANDLE gpu_handle, + D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle, ID3D12Resource *resource, + const float values[4], UINT rect_count, + const D3D12_RECT *rects) override; + void STDMETHODCALLTYPE DiscardResource(ID3D12Resource *resource, + const D3D12_DISCARD_REGION *region) override; + void STDMETHODCALLTYPE BeginQuery(ID3D12QueryHeap *heap, + D3D12_QUERY_TYPE type, + UINT index) override; + void STDMETHODCALLTYPE 
EndQuery(ID3D12QueryHeap *heap, + D3D12_QUERY_TYPE type, + UINT index) override; + void STDMETHODCALLTYPE ResolveQueryData(ID3D12QueryHeap *heap, + D3D12_QUERY_TYPE type, + UINT start_index, UINT query_count, + ID3D12Resource *dst_buffer, + UINT64 aligned_dst_buffer_offset) override; + void STDMETHODCALLTYPE SetPredication(ID3D12Resource *buffer, + UINT64 aligned_buffer_offset, + D3D12_PREDICATION_OP operation) override; + void STDMETHODCALLTYPE SetMarker(UINT metadata, const void *data, + UINT size) override; + void STDMETHODCALLTYPE BeginEvent(UINT metadata, const void *data, + UINT size) override; + void STDMETHODCALLTYPE EndEvent() override; + void STDMETHODCALLTYPE ExecuteIndirect( + ID3D12CommandSignature *command_signature, UINT max_command_count, + ID3D12Resource *arg_buffer, UINT64 arg_buffer_offset, + ID3D12Resource *count_buffer, + UINT64 count_buffer_offset) override; + + void STDMETHODCALLTYPE AtomicCopyBufferUINT( + ID3D12Resource *dst_buffer, UINT64 dst_offset, + ID3D12Resource *src_buffer, UINT64 src_offset, + UINT dependent_resource_count, + ID3D12Resource *const *dependent_resources, + const D3D12_SUBRESOURCE_RANGE_UINT64 *dependent_sub_resource_ranges) override; + void STDMETHODCALLTYPE AtomicCopyBufferUINT64( + ID3D12Resource *dst_buffer, UINT64 dst_offset, + ID3D12Resource *src_buffer, UINT64 src_offset, + UINT dependent_resource_count, + ID3D12Resource *const *dependent_resources, + const D3D12_SUBRESOURCE_RANGE_UINT64 *dependent_sub_resource_ranges) override; + void STDMETHODCALLTYPE OMSetDepthBounds(FLOAT min, FLOAT max) override; + void STDMETHODCALLTYPE SetSamplePositions( + UINT sample_count, UINT pixel_count, + D3D12_SAMPLE_POSITION *sample_positions) override; + void STDMETHODCALLTYPE ResolveSubresourceRegion( + ID3D12Resource *dst_resource, UINT dst_sub_resource_idx, + UINT dst_x, UINT dst_y, + ID3D12Resource *src_resource, UINT src_sub_resource_idx, + D3D12_RECT *src_rect, DXGI_FORMAT format, + D3D12_RESOLVE_MODE mode) override; + 
void STDMETHODCALLTYPE SetViewInstanceMask(UINT mask) override; + void STDMETHODCALLTYPE WriteBufferImmediate( + UINT count, const D3D12_WRITEBUFFERIMMEDIATE_PARAMETER *parameters, + const D3D12_WRITEBUFFERIMMEDIATE_MODE *modes) override; + + /*** ID3D12GraphicsCommandList3 ***/ + void STDMETHODCALLTYPE SetProtectedResourceSession( + ID3D12ProtectedResourceSession *protected_session) override; + + /*** ID3D12GraphicsCommandList4 ***/ + void STDMETHODCALLTYPE BeginRenderPass( + UINT num_render_targets, + const D3D12_RENDER_PASS_RENDER_TARGET_DESC *render_targets, + const D3D12_RENDER_PASS_DEPTH_STENCIL_DESC *depth_stencil, + D3D12_RENDER_PASS_FLAGS flags) override; + void STDMETHODCALLTYPE EndRenderPass() override; + void STDMETHODCALLTYPE InitializeMetaCommand( + ID3D12MetaCommand *meta_command, const void *initialization_parameters_data, + SIZE_T initialization_parameters_data_size_in_bytes) override; + void STDMETHODCALLTYPE ExecuteMetaCommand( + ID3D12MetaCommand *meta_command, const void *execution_parameters_data, + SIZE_T execution_parameters_data_size_in_bytes) override; + void STDMETHODCALLTYPE BuildRaytracingAccelerationStructure( + const D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC *desc, + UINT num_post_build_info_descs, + const D3D12_RAYTRACING_ACCELERATION_STRUCTURE_POSTBUILD_INFO_DESC *post_build_info_descs) override; + void STDMETHODCALLTYPE EmitRaytracingAccelerationStructurePostbuildInfo( + const D3D12_RAYTRACING_ACCELERATION_STRUCTURE_POSTBUILD_INFO_DESC *descs, + UINT num_acceleration_structures, + const D3D12_GPU_VIRTUAL_ADDRESS *source_acceleration_structure_data) override; + void STDMETHODCALLTYPE CopyRaytracingAccelerationStructure( + D3D12_GPU_VIRTUAL_ADDRESS dest_acceleration_structure_data, + D3D12_GPU_VIRTUAL_ADDRESS source_acceleration_structure_data, + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_COPY_MODE mode) override; + void STDMETHODCALLTYPE SetPipelineState1( + ID3D12StateObject *state_object) override; + void STDMETHODCALLTYPE 
DispatchRays( + const D3D12_DISPATCH_RAYS_DESC *desc) override; + + /*** ID3D12GraphicsCommandList5 ***/ + void STDMETHODCALLTYPE RSSetShadingRate( + D3D12_SHADING_RATE base_shading_rate, + const D3D12_SHADING_RATE_COMBINER *combiners) override; + void STDMETHODCALLTYPE RSSetShadingRateImage( + ID3D12Resource *shading_rate_image) override; + + /*** ID3D12GraphicsCommandList6 ***/ + void STDMETHODCALLTYPE DispatchMesh( + UINT thread_group_count_x, UINT thread_group_count_y, + UINT thread_group_count_z) override; + + const std::vector<uint8_t> &GetCommands() const { return m_cmds; } + void ClearCommands() { m_cmds.clear(); } + +private: + template <typename T> void Emit(const T &cmd) { + auto offset = m_cmds.size(); + m_cmds.resize(offset + sizeof(T)); + memcpy(m_cmds.data() + offset, &cmd, sizeof(T)); + } + + template <typename T> + void EmitVar(T &cmd, const void *extra, uint32_t extra_size) { + auto offset = m_cmds.size(); + m_cmds.resize(offset + sizeof(T) - 1 + extra_size); + cmd.header.size = sizeof(T) - 1 + extra_size; + memcpy(m_cmds.data() + offset, &cmd, sizeof(T) - 1); + memcpy(m_cmds.data() + offset + sizeof(T) - 1, extra, extra_size); + } + + MTLD3D12Device *m_device; + MTLD3D12CommandAllocator *m_allocator; + D3D12_COMMAND_LIST_TYPE m_type; + bool m_closed = false; + std::vector<uint8_t> m_cmds; + std::atomic<uint32_t> m_refCount = {1ul}; + std::atomic<uint32_t> m_refPrivate = {1ul}; +}; + +} // namespace dxmt diff --git a/src/d3d12/d3d12_command_queue.cpp b/src/d3d12/d3d12_command_queue.cpp new file mode 100644 index 000000000..7c4e8e5c0 --- /dev/null +++ b/src/d3d12/d3d12_command_queue.cpp @@ -0,0 +1,1126 @@ +#include "d3d12_command_queue.hpp" +#include "d3d12_command_list.hpp" +#include "d3d12_descriptor_heap.hpp" +#include "d3d12_device.hpp" +#include "d3d12_pipeline_state.hpp" +#include "d3d12_resource.hpp" +#include "d3d12_root_signature.hpp" +#include "log/log.hpp" +#include "util_string.hpp" +#include "Metal.hpp" + +#define QTRACE(fmt, ...) 
do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); if (_tf) { fprintf(_tf, fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) + +static uint64_t g_enc_id = 0; +#define ENC_CREATE(type, handle) do { uint64_t _eid = __atomic_add_fetch(&g_enc_id, 1, __ATOMIC_SEQ_CST); QTRACE("[ENC+%llu] CREATE %s handle=%llu", (unsigned long long)_eid, type, (unsigned long long)(handle)); } while(0) +#define ENC_END(handle) do { QTRACE("[ENC] END handle=%llu", (unsigned long long)(handle)); } while(0) +#define ENC_COMMIT(cmdbuf_handle) do { QTRACE("[ENC] COMMIT cmdbuf=%llu", (unsigned long long)(cmdbuf_handle)); } while(0) + +namespace dxmt { + +namespace { + +struct ReplayState { + WMT::CommandBuffer cmdbuf; + WMT::RenderCommandEncoder render_enc; + bool render_enc_open = false; + + MTLD3D12PipelineState *pso = nullptr; + MTLD3D12RootSignature *graphics_root_sig = nullptr; + D3D12_PRIMITIVE_TOPOLOGY topology = D3D_PRIMITIVE_TOPOLOGY_UNDEFINED; + D3D12_VERTEX_BUFFER_VIEW vbs[16] = {}; + D3D12_INDEX_BUFFER_VIEW ib = {}; + D3D12_VIEWPORT viewports[16] = {}; + uint32_t viewport_count = 0; + D3D12_RECT scissor_rects[16] = {}; + uint32_t scissor_count = 0; + float blend_factor[4] = {1, 1, 1, 1}; + uint32_t stencil_ref = 0; + + D3D12_CPU_DESCRIPTOR_HANDLE rt_handles[8] = {}; + D3D12_CPU_DESCRIPTOR_HANDLE dsv_handle = {}; + uint32_t rt_count = 0; + bool has_dsv = false; + + ID3D12DescriptorHeap *desc_heaps[2] = {}; + uint32_t desc_heap_count = 0; + + D3D12_GPU_VIRTUAL_ADDRESS root_cbvs[16] = {}; + D3D12_GPU_DESCRIPTOR_HANDLE root_tables[16] = {}; + uint8_t root_constants_buf[16 * 64] = {}; + uint32_t root_constant_offsets[16] = {}; + uint32_t root_constant_sizes[16] = {}; + bool root_constant_set[16] = {}; + bool root_cbv_set[16] = {}; + bool root_table_set[16] = {}; + + MTLD3D12RootSignature *compute_root_sig = nullptr; + D3D12_GPU_VIRTUAL_ADDRESS comp_cbvs[16] = {}; + D3D12_GPU_DESCRIPTOR_HANDLE comp_tables[16] = {}; + uint8_t comp_constants_buf[16 * 64] = {}; + uint32_t 
comp_constant_offsets[16] = {}; + uint32_t comp_constant_sizes[16] = {}; + bool comp_constant_set[16] = {}; + bool comp_cbv_set[16] = {}; + bool comp_table_set[16] = {}; + bool comp_uav_root[16] = {}; + + void CloseRenderEncoder() { + if (render_enc_open) { + ENC_END(render_enc.handle); + render_enc.endEncoding(); + render_enc_open = false; + } + } + + WMTPrimitiveType GetMetalPrimitiveType() { + switch (topology) { + case D3D_PRIMITIVE_TOPOLOGY_POINTLIST: return WMTPrimitiveTypePoint; + case D3D_PRIMITIVE_TOPOLOGY_LINELIST: return WMTPrimitiveTypeLine; + case D3D_PRIMITIVE_TOPOLOGY_LINESTRIP: return WMTPrimitiveTypeLineStrip; + case D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST: return WMTPrimitiveTypeTriangle; + case D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP: return WMTPrimitiveTypeTriangleStrip; + default: return WMTPrimitiveTypeTriangle; + } + } + + void EnsureRenderEncoder() { + if (render_enc_open) + return; + + if (rt_count == 0) { + QTRACE("EnsureRenderEncoder: no render targets set, skipping"); + return; + } + + WMTRenderPassInfo rp = {}; + for (uint32_t i = 0; i < 8; i++) { + rp.colors[i].texture = NULL_OBJECT_HANDLE; + rp.colors[i].load_action = WMTLoadActionLoad; + rp.colors[i].store_action = WMTStoreActionStore; + rp.colors[i].level = 0; + rp.colors[i].slice = 0; + } + rp.depth.texture = NULL_OBJECT_HANDLE; + rp.depth.load_action = WMTLoadActionLoad; + rp.depth.store_action = WMTStoreActionStore; + rp.stencil.texture = NULL_OBJECT_HANDLE; + rp.stencil.load_action = WMTLoadActionLoad; + rp.stencil.store_action = WMTStoreActionStore; + + bool has_valid_rt = false; + for (uint32_t i = 0; i < rt_count && i < 8; i++) { + auto *desc = reinterpret_cast<D3D12Descriptor *>(rt_handles[i].ptr); + if (desc && desc->resource) { + auto *res = static_cast<MTLD3D12Resource *>(desc->resource); + auto tex = res->GetMTLTexture(); + if (tex.handle) { + rp.colors[i].texture = tex.handle; + has_valid_rt = true; + } + } + } + + if (has_dsv) { + auto *desc = reinterpret_cast<D3D12Descriptor *>(dsv_handle.ptr); + if (desc && desc->resource) { + 
auto *res = static_cast<MTLD3D12Resource *>(desc->resource); + if (res->GetMTLTexture().handle) { + rp.depth.texture = res->GetMTLTexture().handle; + rp.stencil.texture = res->GetMTLTexture().handle; + has_valid_rt = true; + } + } + } + + if (!has_valid_rt) { + QTRACE("EnsureRenderEncoder: no valid RT texture found, skipping"); + return; + } + + QTRACE("EnsureRenderEncoder: creating render encoder rt_count=%u", rt_count); + render_enc = cmdbuf.renderCommandEncoder(rp); + ENC_CREATE("render_ensure", render_enc.handle); + if (!render_enc.handle) { + QTRACE("EnsureRenderEncoder: FAILED to create render encoder!"); + return; + } + render_enc_open = true; + + if (pso && pso->IsCompiled() && pso->GetRenderPSO().handle) { + render_enc.setRenderPipelineState(pso->GetRenderPSO()); + } + + if (viewport_count > 0) { + for (uint32_t i = 0; i < viewport_count; i++) { + WMTViewport vp = {(double)viewports[i].TopLeftX, + (double)viewports[i].TopLeftY, + (double)viewports[i].Width, + (double)viewports[i].Height, + viewports[i].MinDepth, + viewports[i].MaxDepth}; + render_enc.setViewport(vp); + } + } + } + + void ApplyRootBindings(MTLD3D12Device *device) { + if (!render_enc_open || !pso) + return; + + for (uint32_t i = 0; i < 16; i++) { + if (root_constant_set[i] && root_constant_sizes[i] > 0) { + render_enc.setFragmentBytes(root_constants_buf + root_constant_offsets[i], + root_constant_sizes[i], i); + } + + if (root_cbv_set[i] && root_cbvs[i]) { + auto *res = device->LookupResourceByGPUAddress(root_cbvs[i]); + if (res && res->GetMTLBuffer().handle) { + uint64_t offset = root_cbvs[i] - res->GetGPUVirtualAddress(); + render_enc.setVertexBuffer(res->GetMTLBuffer(), offset, i); + render_enc.setFragmentBuffer(res->GetMTLBuffer(), offset, i); + } + } + + if (root_table_set[i] && desc_heap_count > 0) { + for (uint32_t h = 0; h < desc_heap_count; h++) { + auto *heap = static_cast<MTLD3D12DescriptorHeap *>(desc_heaps[h]); + if (!heap) continue; + auto *desc = heap->GetDescriptorFromGPUHandle(root_tables[i]); + if (!desc || 
!desc->resource) continue; + auto *res = static_cast<MTLD3D12Resource *>(desc->resource); + if (res->GetMTLBuffer().handle) { + uint64_t off = 0; + if (desc->cbv.BufferLocation) { + auto *cbv_res = device->LookupResourceByGPUAddress(desc->cbv.BufferLocation); + if (cbv_res) off = desc->cbv.BufferLocation - cbv_res->GetGPUVirtualAddress(); + } + render_enc.setVertexBuffer(res->GetMTLBuffer(), off, i); + render_enc.setFragmentBuffer(res->GetMTLBuffer(), off, i); + } else if (res->GetMTLTexture().handle) { + render_enc.setFragmentTexture(res->GetMTLTexture(), i); + } + } + } + } + } + + void ApplyVertexBuffers(MTLD3D12Device *device) { + if (!render_enc_open) + return; + for (uint32_t i = 0; i < 16; i++) { + if (vbs[i].BufferLocation) { + auto *res = device->LookupResourceByGPUAddress(vbs[i].BufferLocation); + if (res && res->GetMTLBuffer().handle) { + uint64_t offset = vbs[i].BufferLocation - res->GetGPUVirtualAddress(); + render_enc.setVertexBuffer(res->GetMTLBuffer(), offset, i); + } + } + } + } +}; + +WMTIndexType DXGIToWMTIndexFormat(DXGI_FORMAT fmt) { + switch (fmt) { + case DXGI_FORMAT_R16_UINT: return WMTIndexTypeUInt16; + case DXGI_FORMAT_R32_UINT: return WMTIndexTypeUInt32; + default: return WMTIndexTypeUInt16; + } +} + +} // anonymous namespace + +static bool rt_handles_match(D3D12_CPU_DESCRIPTOR_HANDLE a, + D3D12_CPU_DESCRIPTOR_HANDLE b) { + return a.ptr == b.ptr; +} + +MTLD3D12CommandQueue::MTLD3D12CommandQueue(MTLD3D12Device *device, + CommandQueue &queue, + D3D12_COMMAND_QUEUE_DESC desc) + : m_device(device), m_queue(queue), m_desc(desc) { + m_device->AddRef(); + auto wmt_dev = m_device->GetDXMTDevice().device(); + m_wmt_queue = wmt_dev.newCommandQueue(1); + Logger::info("D3D12CommandQueue created"); +} + +MTLD3D12CommandQueue::~MTLD3D12CommandQueue() { + m_device->Release(); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandQueue::QueryInterface(REFIID riid, void **ppvObject) { + if (!ppvObject) + return E_POINTER; + *ppvObject = nullptr; + + if (riid == IID_IUnknown || 
riid == IID_ID3D12Object || + riid == IID_ID3D12DeviceChild || riid == IID_ID3D12Pageable || + riid == IID_ID3D12CommandQueue) { + *ppvObject = ref(this); + return S_OK; + } + + if (riid == __uuidof(IMTLDXGIDevice)) { + return m_device->QueryInterface(riid, ppvObject); + } + QTRACE("CmdQueue::QI unknown IID %s -> E_NOINTERFACE", str::format(riid).c_str()); + return E_NOINTERFACE; +} + +ULONG STDMETHODCALLTYPE MTLD3D12CommandQueue::AddRef() { + return ++m_refCount; +} + +ULONG STDMETHODCALLTYPE MTLD3D12CommandQueue::Release() { + uint32_t rc = --m_refCount; + if (!rc) { + uint32_t rp = --m_refPrivate; + if (!rp) { + m_refPrivate += 0x80000000; + delete this; + } + } + return rc; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandQueue::GetPrivateData(REFGUID guid, UINT *data_size, + void *data) { + QTRACE("CmdQueue::GetPrivateData E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandQueue::SetPrivateData(REFGUID guid, UINT data_size, + const void *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandQueue::SetPrivateDataInterface(REFGUID guid, + const IUnknown *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12CommandQueue::SetName(LPCWSTR name) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandQueue::GetDevice(REFIID riid, void **device) { + return m_device->QueryInterface(riid, device); +} + +void STDMETHODCALLTYPE MTLD3D12CommandQueue::UpdateTileMappings( + ID3D12Resource *resource, UINT region_count, + const D3D12_TILED_RESOURCE_COORDINATE *region_start_coordinates, + const D3D12_TILE_REGION_SIZE *region_sizes, ID3D12Heap *heap, + UINT range_count, const D3D12_TILE_RANGE_FLAGS *range_flags, + const UINT *heap_range_offsets, const UINT *range_tile_counts, + D3D12_TILE_MAPPING_FLAGS flags) {} + +void STDMETHODCALLTYPE MTLD3D12CommandQueue::CopyTileMappings( + ID3D12Resource *dst_resource, + const D3D12_TILED_RESOURCE_COORDINATE *dst_region_start_coordinate, + ID3D12Resource *src_resource, + 
const D3D12_TILED_RESOURCE_COORDINATE *src_region_start_coordinate, + const D3D12_TILE_REGION_SIZE *region_size, + D3D12_TILE_MAPPING_FLAGS flags) {} + +void STDMETHODCALLTYPE MTLD3D12CommandQueue::ExecuteCommandLists( + UINT command_list_count, + ID3D12CommandList *const *command_lists) { + QTRACE("ExecuteCommandLists count=%u", command_list_count); + + for (UINT li = 0; li < command_list_count; li++) { + QTRACE("ECL: processing list %u", li); + auto *list = static_cast(command_lists[li]); + if (!list) { + QTRACE("ECL: list %u is null, skipping", li); + continue; + } + + QTRACE("ECL: creating cmdbuf from m_wmt_queue"); + auto cmdbuf = m_wmt_queue.commandBuffer(); + QTRACE("ECL: cmdbuf handle=%llu", (unsigned long long)cmdbuf.handle); + if (!cmdbuf.handle) { + Logger::err("ExecuteCommandLists: failed to create Metal command buffer"); + continue; + } + + const auto cmds = list->GetCommands(); + QTRACE("ExecuteCommandLists: cmds.size=%zu empty=%d", cmds.size(), cmds.empty()); + if (cmds.empty()) { + QTRACE("ExecuteCommandLists: empty cmdlist, committing"); + cmdbuf.commit(); + QTRACE("ExecuteCommandLists: empty cmdlist committed ok"); + continue; + } + + ReplayState st; + st.cmdbuf = cmdbuf; + + QTRACE("ExecuteCommandLists: cmd_size=%zu", cmds.size()); + size_t offset = 0; + size_t cmd_count = 0; + uint32_t type_counts[30] = {}; + while (offset < cmds.size()) { + if (offset + sizeof(CmdHeader) > cmds.size()) + break; + auto *header = reinterpret_cast(cmds.data() + offset); + if (header->size < sizeof(CmdHeader) || header->size > 65536 || offset + header->size > cmds.size()) { + QTRACE("ECL: corrupt cmd at offset=%zu type=%d size=%zu cmds_size=%zu — skipping rest", + offset, (int)header->type, header->size, cmds.size()); + break; + } + + if ((uint32_t)header->type < 30) + type_counts[(uint32_t)header->type]++; + cmd_count++; + + if (cmd_count <= 5 || (cmd_count % 50) == 0) + QTRACE("ECL cmd[%zu] type=%d size=%u offset=%zu", cmd_count, (int)header->type, 
(unsigned)header->size, offset); + + switch (header->type) { + case CmdType::DrawInstanced: { + auto *cmd = reinterpret_cast(header); + st.EnsureRenderEncoder(); + st.ApplyRootBindings(m_device); + st.ApplyVertexBuffers(m_device); + QTRACE("DrawInstanced v=%u i=%u enc_open=%d", cmd->vertex_count, cmd->instance_count, st.render_enc_open); + + if (cmd->instance_count > 0 && cmd->vertex_count > 0 && st.render_enc_open) { + struct wmtcmd_render_draw draw = {}; + draw.type = WMTRenderCommandDraw; + draw.next.set(nullptr); + draw.primitive_type = st.GetMetalPrimitiveType(); + draw.vertex_start = cmd->start_vertex; + draw.vertex_count = cmd->vertex_count; + draw.base_instance = cmd->start_instance; + draw.instance_count = cmd->instance_count; + st.render_enc.encodeCommands( + reinterpret_cast(&draw)); + } + break; + } + case CmdType::DrawIndexedInstanced: { + auto *cmd = reinterpret_cast(header); + st.EnsureRenderEncoder(); + st.ApplyRootBindings(m_device); + st.ApplyVertexBuffers(m_device); + + if (cmd->instance_count > 0 && cmd->index_count > 0 && st.ib.BufferLocation) { + auto *ib_res = m_device->LookupResourceByGPUAddress(st.ib.BufferLocation); + if (!ib_res && st.ib.BufferLocation) { + ib_res = reinterpret_cast(st.ib.BufferLocation); + } + struct wmtcmd_render_draw_indexed draw = {}; + draw.type = WMTRenderCommandDrawIndexed; + draw.next.set(nullptr); + draw.primitive_type = st.GetMetalPrimitiveType(); + draw.index_type = DXGIToWMTIndexFormat(st.ib.Format); + draw.index_count = cmd->index_count; + draw.index_buffer = ib_res ? ib_res->GetMTLBuffer().handle : NULL_OBJECT_HANDLE; + draw.index_buffer_offset = st.ib.SizeInBytes ? 
0 : 0; + draw.instance_count = cmd->instance_count; + draw.base_vertex = cmd->base_vertex; + draw.base_instance = cmd->start_instance; + st.render_enc.encodeCommands( + reinterpret_cast(&draw)); + } + break; + } + case CmdType::Dispatch: { + auto *cmd = reinterpret_cast(header); + QTRACE("Dispatch x=%u y=%u z=%u pso=%p compiled=%d compute=%d heaps=%u", + cmd->x, cmd->y, cmd->z, (void*)st.pso, + st.pso ? st.pso->IsCompiled() : 0, + st.pso ? st.pso->IsCompute() : 0, + st.desc_heap_count); + if (st.pso && st.pso->IsCompiled() && st.pso->IsCompute() && + st.pso->GetComputePSO().handle) { + st.CloseRenderEncoder(); + auto comp = cmdbuf.computeCommandEncoder(false); + ENC_CREATE("compute_dispatch", comp.handle); + + uint8_t cmd_buf[4096]; + uint8_t *cmd_ptr = cmd_buf; + wmtcmd_compute_nop *chain_head = nullptr; + wmtcmd_base *chain_tail = nullptr; + + auto append_cmd = [&](void *data, size_t sz) -> wmtcmd_base * { + auto *c = (wmtcmd_base *)cmd_ptr; + memcpy(cmd_ptr, data, sz); + cmd_ptr += sz; + c->next.set(nullptr); + if (chain_tail) + chain_tail->next.set(c); + else + chain_head = (wmtcmd_compute_nop *)c; + chain_tail = c; + return c; + }; + + struct wmtcmd_compute_setpso setpso = {}; + setpso.type = WMTComputeCommandSetPSO; + setpso.pso = st.pso->GetComputePSO(); + setpso.threadgroup_size = st.pso->GetThreadgroupSize(); + append_cmd(&setpso, sizeof(setpso)); + + bool is_uav_slot[16] = {}; + if (st.compute_root_sig) { + auto &params = st.compute_root_sig->GetParameters(); + QTRACE("ECL UAV scan: root_sig=%p num_params=%u", (void*)st.compute_root_sig, (uint32_t)params.size()); + for (uint32_t p = 0; p < params.size() && p < 16; p++) { + QTRACE(" param[%u] type=%u range_type=%u vis=%u", p, params[p].type, params[p].range_type, params[p].shader_visibility); + if (params[p].type == D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE && + params[p].range_type == D3D12_DESCRIPTOR_RANGE_TYPE_UAV) { + is_uav_slot[p] = true; + } else if (params[p].type == D3D12_ROOT_PARAMETER_TYPE_UAV) { +
is_uav_slot[p] = true; + } + } + } else { + QTRACE("ECL UAV scan: no compute_root_sig set!"); + } + + for (uint32_t i = 0; i < 16; i++) { + bool const_set = st.comp_constant_set[i] || st.root_constant_set[i]; + uint32_t const_size = st.comp_constant_set[i] ? st.comp_constant_sizes[i] : st.root_constant_sizes[i]; + uint32_t const_off = st.comp_constant_set[i] ? st.comp_constant_offsets[i] : st.root_constant_offsets[i]; + uint8_t *const_buf = st.comp_constant_set[i] ? st.comp_constants_buf : st.root_constants_buf; + + bool cbv_set = st.comp_cbv_set[i] || st.root_cbv_set[i]; + D3D12_GPU_VIRTUAL_ADDRESS cbv_addr = st.comp_cbv_set[i] ? st.comp_cbvs[i] : st.root_cbvs[i]; + + bool tbl_set = st.comp_table_set[i] || st.root_table_set[i]; + D3D12_GPU_DESCRIPTOR_HANDLE tbl_handle = st.comp_table_set[i] ? st.comp_tables[i] : st.root_tables[i]; + + if (const_set && const_size > 0) { + struct wmtcmd_compute_setbytes sb = {}; + sb.type = WMTComputeCommandSetBytes; + sb.length = const_size; + sb.index = i; + sb.bytes.ptr = (void *)(const_buf + const_off); + append_cmd(&sb, sizeof(sb)); + } + if (cbv_set && cbv_addr) { + auto *res = m_device->LookupResourceByGPUAddress(cbv_addr); + if (res && res->GetMTLBuffer().handle) { + struct wmtcmd_compute_setbuffer sbuf = {}; + sbuf.type = WMTComputeCommandSetBuffer; + sbuf.buffer = res->GetMTLBuffer().handle; + sbuf.offset = cbv_addr - res->GetGPUVirtualAddress(); + sbuf.index = i; + append_cmd(&sbuf, sizeof(sbuf)); + if (st.comp_uav_root[i]) { + struct wmtcmd_compute_useresource use = {}; + use.type = WMTComputeCommandUseResource; + use.resource = res->GetMTLBuffer().handle; + use.usage = (WMTResourceUsage)(WMTResourceUsageRead | WMTResourceUsageWrite); + append_cmd(&use, sizeof(use)); + QTRACE(" UAV UseResource root buf slot=%u handle=%llu", i, (unsigned long long)res->GetMTLBuffer().handle); + } + } + } + if (tbl_set && st.desc_heap_count > 0) { + for (uint32_t h = 0; h < st.desc_heap_count; h++) { + auto *heap = 
static_cast(st.desc_heaps[h]); + if (heap) { + auto *desc = heap->GetDescriptorFromGPUHandle(tbl_handle); + QTRACE(" tbl[%u] heap=%u handle=0x%llx desc=%p res=%p", i, h, + (unsigned long long)tbl_handle.ptr, (void*)desc, + desc ? (void*)desc->resource : nullptr); + if (desc && desc->resource) { + auto *res = static_cast(desc->resource); + if (res->GetMTLBuffer().handle) { + struct wmtcmd_compute_setbuffer sbuf = {}; + sbuf.type = WMTComputeCommandSetBuffer; + sbuf.buffer = res->GetMTLBuffer().handle; + sbuf.offset = 0; + sbuf.index = i; + append_cmd(&sbuf, sizeof(sbuf)); + if (is_uav_slot[i]) { + struct wmtcmd_compute_useresource use = {}; + use.type = WMTComputeCommandUseResource; + use.resource = res->GetMTLBuffer().handle; + use.usage = (WMTResourceUsage)(WMTResourceUsageRead | WMTResourceUsageWrite); + append_cmd(&use, sizeof(use)); + } + } else if (res->GetMTLTexture().handle) { + struct wmtcmd_compute_settexture stex = {}; + stex.type = WMTComputeCommandSetTexture; + stex.texture = res->GetMTLTexture().handle; + stex.index = i; + append_cmd(&stex, sizeof(stex)); + if (is_uav_slot[i]) { + QTRACE(" UAV UseResource tex slot=%u handle=%llu", i, (unsigned long long)res->GetMTLTexture().handle); + struct wmtcmd_compute_useresource use = {}; + use.type = WMTComputeCommandUseResource; + use.resource = res->GetMTLTexture().handle; + use.usage = (WMTResourceUsage)(WMTResourceUsageRead | WMTResourceUsageWrite); + append_cmd(&use, sizeof(use)); + } + } + } + } + } + } + } + + int num_consts = 0, num_cbvs = 0, num_tables = 0; + for (uint32_t i = 0; i < 16; i++) { + if ((st.comp_constant_set[i] || st.root_constant_set[i]) && + (st.comp_constant_sizes[i] > 0 || st.root_constant_sizes[i] > 0)) + num_consts++; + if ((st.comp_cbv_set[i] && st.comp_cbvs[i]) || + (st.root_cbv_set[i] && st.root_cbvs[i])) + num_cbvs++; + if (st.comp_table_set[i] || st.root_table_set[i]) + num_tables++; + } + QTRACE(" bindings: consts=%d cbvs=%d tables=%d tg=%ux%ux%u", + num_consts, num_cbvs, 
num_tables, + st.pso->GetThreadgroupSize().width, + st.pso->GetThreadgroupSize().height, + st.pso->GetThreadgroupSize().depth); + + struct wmtcmd_compute_dispatch disp = {}; + disp.type = WMTComputeCommandDispatch; + disp.size = {(uint64_t)cmd->x, (uint64_t)cmd->y, (uint64_t)cmd->z}; + append_cmd(&disp, sizeof(disp)); + + if (chain_head) + comp.encodeCommands(chain_head); + ENC_END(comp.handle); + comp.endEncoding(); + } + break; + } + case CmdType::CopyBufferRegion: { + auto *cmd = reinterpret_cast(header); + QTRACE("CopyBufferRegion dst=%p +%llu src=%p +%llu bytes=%llu", (void*)cmd->dst, (unsigned long long)cmd->dst_offset, (void*)cmd->src, (unsigned long long)cmd->src_offset, (unsigned long long)cmd->byte_count); + if (cmd->dst && cmd->src) { + st.CloseRenderEncoder(); + auto *dst_res = static_cast(cmd->dst); + auto *src_res = static_cast(cmd->src); + if (dst_res->GetMTLBuffer().handle && src_res->GetMTLBuffer().handle) { + auto blit = cmdbuf.blitCommandEncoder(); + ENC_CREATE("blit_copybuf", blit.handle); + struct wmtcmd_blit_copy_from_buffer_to_buffer copy = {}; + copy.type = WMTBlitCommandCopyFromBufferToBuffer; + copy.next.set(nullptr); + copy.src = src_res->GetMTLBuffer().handle; + copy.src_offset = cmd->src_offset; + copy.dst = dst_res->GetMTLBuffer().handle; + copy.dst_offset = cmd->dst_offset; + copy.copy_length = cmd->byte_count; + blit.encodeCommands(reinterpret_cast(&copy)); + ENC_END(blit.handle); + blit.endEncoding(); + } + } + break; + } + case CmdType::CopyTextureRegion: { + auto *cmd = reinterpret_cast(header); + auto *dst_res = static_cast(cmd->dst_resource); + auto *src_res = static_cast(cmd->src_resource); + QTRACE("CopyTextureRegion dst=%p(%p) src=%p(%p) dst_type=%u src_type=%u", + (void*)dst_res, dst_res ? (void*)dst_res->GetMTLTexture().handle : nullptr, + (void*)src_res, src_res ?
(void*)src_res->GetMTLTexture().handle : nullptr, + cmd->dst_type, cmd->src_type); + if (!dst_res || !src_res) break; + + QTRACE("CopyTextureRegion dst=%p src=%p dst_type=%u src_type=%u", + (void*)dst_res, (void*)src_res, cmd->dst_type, cmd->src_type); + + st.CloseRenderEncoder(); + auto blit = cmdbuf.blitCommandEncoder(); + ENC_CREATE("blit_copytex", blit.handle); + if (!blit.handle) { + QTRACE("CopyTextureRegion: FAILED to create blit encoder"); + break; + } + + bool src_is_buffer = (cmd->src_type == D3D12_TEXTURE_COPY_TYPE_PLACED_FOOTPRINT); + bool dst_is_buffer = (cmd->dst_type == D3D12_TEXTURE_COPY_TYPE_PLACED_FOOTPRINT); + + auto src_tex = src_res->GetMTLTexture(); + auto dst_tex = dst_res->GetMTLTexture(); + auto src_buf = src_res->GetMTLBuffer(); + auto dst_buf = dst_res->GetMTLBuffer(); + + if (!src_is_buffer && !src_tex.handle) src_is_buffer = (src_buf.handle != 0); + if (!dst_is_buffer && !dst_tex.handle) dst_is_buffer = (dst_buf.handle != 0); + + QTRACE("CopyTextureRegion src_tex=%llu src_buf=%llu dst_tex=%llu dst_buf=%llu src_buf_flag=%d dst_buf_flag=%d", + (unsigned long long)src_tex.handle, (unsigned long long)src_buf.handle, + (unsigned long long)dst_tex.handle, (unsigned long long)dst_buf.handle, + src_is_buffer, dst_is_buffer); + + UINT copy_w, copy_h, copy_d; + if (cmd->has_src_box) { + copy_w = cmd->src_box.right - cmd->src_box.left; + copy_h = cmd->src_box.bottom - cmd->src_box.top; + copy_d = cmd->src_box.back - cmd->src_box.front; + } else { + D3D12_RESOURCE_DESC tex_desc; + if (!dst_is_buffer && dst_tex.handle) { + dst_res->GetDesc(&tex_desc); + copy_w = tex_desc.Width; + copy_h = tex_desc.Height; + copy_d = 1; + } else if (!src_is_buffer && src_tex.handle) { + src_res->GetDesc(&tex_desc); + copy_w = tex_desc.Width; + copy_h = tex_desc.Height; + copy_d = 1; + } else { + copy_w = 1; + copy_h = 1; + copy_d = 1; + } + if (copy_w == 0) copy_w = 1; + if (copy_h == 0) copy_h = 1; + } + + if (src_is_buffer && !dst_is_buffer && dst_tex.handle) { + 
UINT row_pitch = cmd->src_footprint_row_pitch; + if (row_pitch == 0) row_pitch = copy_w * 4; + struct wmtcmd_blit_copy_from_buffer_to_texture copy = {}; + copy.type = WMTBlitCommandCopyFromBufferToTexture; + copy.next.set(nullptr); + copy.src = src_buf.handle; + copy.src_offset = cmd->src_offset; + copy.bytes_per_row = row_pitch; + copy.bytes_per_image = row_pitch * copy_h; + copy.size = {copy_w, copy_h, copy_d}; + copy.dst = dst_tex.handle; + copy.slice = 0; + copy.level = 0; + copy.origin = {cmd->dst_x, cmd->dst_y, cmd->dst_z}; + blit.encodeCommands(reinterpret_cast(&copy)); + } else if (!src_is_buffer && dst_is_buffer && src_tex.handle) { + struct wmtcmd_blit_copy_from_texture_to_buffer copy = {}; + copy.type = WMTBlitCommandCopyFromTextureToBuffer; + copy.next.set(nullptr); + copy.src = src_tex.handle; + copy.slice = 0; + copy.level = 0; + UINT src_x = cmd->has_src_box ? cmd->src_box.left : 0; + UINT src_y = cmd->has_src_box ? cmd->src_box.top : 0; + UINT src_z = cmd->has_src_box ? cmd->src_box.front : 0; + copy.origin = {src_x, src_y, src_z}; + copy.size = {copy_w, copy_h, copy_d}; + copy.dst = dst_buf.handle; + copy.offset = cmd->dst_offset; + copy.bytes_per_row = cmd->dst_footprint_row_pitch; + copy.bytes_per_image = cmd->dst_footprint_row_pitch * copy_h; + blit.encodeCommands(reinterpret_cast(&copy)); + } else if (!src_is_buffer && !dst_is_buffer && src_tex.handle && dst_tex.handle) { + struct wmtcmd_blit_copy_from_texture_to_texture copy = {}; + copy.type = WMTBlitCommandCopyFromTextureToTexture; + copy.next.set(nullptr); + copy.src = src_tex.handle; + copy.src_slice = 0; + copy.src_level = 0; + UINT src_x = cmd->has_src_box ? cmd->src_box.left : 0; + UINT src_y = cmd->has_src_box ? cmd->src_box.top : 0; + UINT src_z = cmd->has_src_box ?
cmd->src_box.front : 0; + copy.src_origin = {src_x, src_y, src_z}; + copy.src_size = {copy_w, copy_h, copy_d}; + copy.dst = dst_tex.handle; + copy.dst_slice = 0; + copy.dst_level = 0; + copy.dst_origin = {cmd->dst_x, cmd->dst_y, cmd->dst_z}; + blit.encodeCommands(reinterpret_cast(&copy)); + } else { + QTRACE("CopyTextureRegion: unhandled buffer-to-buffer or null resources"); + } + + QTRACE("CopyTextureRegion: blit.endEncoding src_buf=%d dst_buf=%d w=%u h=%u d=%u", + src_is_buffer, dst_is_buffer, copy_w, copy_h, copy_d); + ENC_END(blit.handle); + blit.endEncoding(); + break; + } + case CmdType::CopyResource: { + auto *cmd = reinterpret_cast(header); + auto *dst_res = static_cast(cmd->dst); + auto *src_res = static_cast(cmd->src); + if (!dst_res || !src_res) break; + st.CloseRenderEncoder(); + + if (dst_res->GetMTLBuffer().handle && src_res->GetMTLBuffer().handle) { + auto blit = cmdbuf.blitCommandEncoder(); + ENC_CREATE("blit_copyres_buf", blit.handle); + struct wmtcmd_blit_copy_from_buffer_to_buffer copy = {}; + copy.type = WMTBlitCommandCopyFromBufferToBuffer; + copy.next.set(nullptr); + copy.src = src_res->GetMTLBuffer().handle; + copy.src_offset = 0; + copy.dst = dst_res->GetMTLBuffer().handle; + copy.dst_offset = 0; + D3D12_RESOURCE_DESC src_desc; + src_res->GetDesc(&src_desc); + copy.copy_length = src_desc.Width; + blit.encodeCommands(reinterpret_cast(&copy)); + ENC_END(blit.handle); + blit.endEncoding(); + } else if (dst_res->GetMTLTexture().handle && src_res->GetMTLTexture().handle) { + auto blit = cmdbuf.blitCommandEncoder(); + ENC_CREATE("blit_copyres_tex", blit.handle); + D3D12_RESOURCE_DESC src_desc; + src_res->GetDesc(&src_desc); + struct wmtcmd_blit_copy_from_texture_to_texture copy = {}; + copy.type = WMTBlitCommandCopyFromTextureToTexture; + copy.next.set(nullptr); + copy.src = src_res->GetMTLTexture().handle; + copy.src_slice = 0; + copy.src_level = 0; + copy.src_origin = {0, 0, 0}; + copy.src_size = {src_desc.Width, src_desc.Height, 1}; + copy.dst =
dst_res->GetMTLTexture().handle; + copy.dst_slice = 0; + copy.dst_level = 0; + copy.dst_origin = {0, 0, 0}; + blit.encodeCommands(reinterpret_cast(&copy)); + ENC_END(blit.handle); + blit.endEncoding(); + } + break; + } + case CmdType::SetPipelineState: { + auto *cmd = reinterpret_cast(header); + st.pso = static_cast(cmd->pso); + if (st.render_enc_open && st.pso && st.pso->IsCompiled() && + st.pso->GetRenderPSO().handle) { + st.render_enc.setRenderPipelineState(st.pso->GetRenderPSO()); + } + break; + } + case CmdType::ResourceBarrier: { + break; + } + case CmdType::OMSetRenderTargets: { + auto *cmd = reinterpret_cast(header); + st.CloseRenderEncoder(); + st.rt_count = cmd->rt_count; + for (uint32_t i = 0; i < cmd->rt_count && i < 8; i++) + st.rt_handles[i] = cmd->rts[i]; + st.has_dsv = cmd->has_dsv; + if (cmd->has_dsv) + st.dsv_handle = cmd->dsv; + break; + } + case CmdType::ClearRenderTargetView: { + auto *cmd = reinterpret_cast(header); + st.CloseRenderEncoder(); + + WMTRenderPassInfo rp = {}; + for (uint32_t i = 0; i < 8; i++) { + rp.colors[i].texture = NULL_OBJECT_HANDLE; + rp.colors[i].load_action = WMTLoadActionDontCare; + rp.colors[i].store_action = WMTStoreActionDontCare; + } + rp.depth.texture = NULL_OBJECT_HANDLE; + rp.depth.load_action = WMTLoadActionDontCare; + rp.depth.store_action = WMTStoreActionDontCare; + rp.stencil.texture = NULL_OBJECT_HANDLE; + rp.stencil.load_action = WMTLoadActionDontCare; + rp.stencil.store_action = WMTStoreActionDontCare; + + for (uint32_t i = 0; i < st.rt_count && i < 8; i++) { + auto *desc = reinterpret_cast(st.rt_handles[i].ptr); + if (desc && desc->resource) { + auto *res = static_cast(desc->resource); + if (res->GetMTLTexture().handle) { + rp.colors[i].texture = res->GetMTLTexture().handle; + if (rt_handles_match(st.rt_handles[i], cmd->rtv)) { + rp.colors[i].load_action = WMTLoadActionClear; + rp.colors[i].store_action = WMTStoreActionStore; + rp.colors[i].clear_color = {cmd->color[0], cmd->color[1], + cmd->color[2], +
cmd->color[3]}; + } else { + rp.colors[i].load_action = WMTLoadActionLoad; + rp.colors[i].store_action = WMTStoreActionStore; + } + } + } + } + + if (st.has_dsv) { + auto *desc = reinterpret_cast(st.dsv_handle.ptr); + if (desc && desc->resource) { + auto *res = static_cast(desc->resource); + if (res->GetMTLTexture().handle) { + rp.depth.texture = res->GetMTLTexture().handle; + rp.depth.load_action = WMTLoadActionLoad; + rp.depth.store_action = WMTStoreActionStore; + } + } + } + + auto enc = cmdbuf.renderCommandEncoder(rp); + ENC_CREATE("render_clearrtv", enc.handle); + ENC_END(enc.handle); + enc.endEncoding(); + break; + } + case CmdType::ClearDepthStencilView: { + auto *cmd = reinterpret_cast(header); + break; + } + case CmdType::RSSetViewports: { + auto *cmd = reinterpret_cast(header); + auto *vps = reinterpret_cast( + reinterpret_cast(cmd) + + sizeof(CmdRSSetViewports) - sizeof(D3D12_VIEWPORT)); + st.viewport_count = cmd->count > 16 ? 16 : cmd->count; + for (uint32_t i = 0; i < st.viewport_count; i++) + st.viewports[i] = vps[i]; + if (st.render_enc_open) { + for (uint32_t i = 0; i < st.viewport_count; i++) { + WMTViewport vp = {(double)vps[i].TopLeftX, (double)vps[i].TopLeftY, + (double)vps[i].Width, (double)vps[i].Height, + vps[i].MinDepth, vps[i].MaxDepth}; + st.render_enc.setViewport(vp); + } + } + break; + } + case CmdType::RSSetScissorRects: { + auto *cmd = reinterpret_cast(header); + st.scissor_count = cmd->count > 16 ? 
16 : cmd->count; + break; + } + case CmdType::IASetPrimitiveTopology: { + auto *cmd = reinterpret_cast(header); + st.topology = cmd->topology; + break; + } + case CmdType::SetGraphicsRootSignature: { + auto *cmd = reinterpret_cast(header); + st.graphics_root_sig = static_cast(cmd->root_sig); + break; + } + case CmdType::SetGraphicsRoot32BitConstants: { + auto *cmd = reinterpret_cast(header); + QTRACE("SetGraphicsRoot32BitConstants idx=%u count=%u", cmd->root_param_index, cmd->count); + if (cmd->root_param_index < 16) { + uint32_t sz = cmd->count * 4; + uint32_t off = cmd->dst_offset * 4; + if (off + sz <= sizeof(st.root_constants_buf)) { + memcpy(st.root_constants_buf + off, cmd->data, sz); + st.root_constant_offsets[cmd->root_param_index] = off; + st.root_constant_sizes[cmd->root_param_index] = sz; + st.root_constant_set[cmd->root_param_index] = true; + } + } + break; + } + case CmdType::SetGraphicsRootConstantBufferView: { + auto *cmd = reinterpret_cast(header); + if (cmd->root_param_index < 16) { + st.root_cbvs[cmd->root_param_index] = cmd->address; + st.root_cbv_set[cmd->root_param_index] = true; + } + break; + } + case CmdType::SetGraphicsRootShaderResourceView: { + auto *cmd = reinterpret_cast(header); + if (cmd->root_param_index < 16) { + st.root_cbvs[cmd->root_param_index] = cmd->address; + st.root_cbv_set[cmd->root_param_index] = true; + } + break; + } + case CmdType::SetGraphicsRootUnorderedAccessView: { + auto *cmd = reinterpret_cast(header); + if (cmd->root_param_index < 16) { + st.root_cbvs[cmd->root_param_index] = cmd->address; + st.root_cbv_set[cmd->root_param_index] = true; + } + break; + } + case CmdType::SetGraphicsRootDescriptorTable: { + auto *cmd = reinterpret_cast(header); + QTRACE("SetGraphicsRootDescriptorTable idx=%u handle=0x%llx", cmd->root_param_index, (unsigned long long)cmd->base_descriptor.ptr); + if (cmd->root_param_index < 16) { + st.root_tables[cmd->root_param_index] = cmd->base_descriptor; + 
st.root_table_set[cmd->root_param_index] = true; + } + break; + } + case CmdType::SetComputeRootSignature: { + auto *cmd = reinterpret_cast(header); + st.compute_root_sig = static_cast(cmd->root_sig); + break; + } + case CmdType::SetComputeRoot32BitConstants: { + auto *cmd = reinterpret_cast(header); + if (cmd->root_param_index < 16) { + uint32_t sz = cmd->count * 4; + uint32_t off = cmd->dst_offset * 4; + if (off + sz <= sizeof(st.comp_constants_buf)) { + memcpy(st.comp_constants_buf + off, cmd->data, sz); + st.comp_constant_offsets[cmd->root_param_index] = off; + st.comp_constant_sizes[cmd->root_param_index] = sz; + st.comp_constant_set[cmd->root_param_index] = true; + } + } + break; + } + case CmdType::SetComputeRootConstantBufferView: { + auto *cmd = reinterpret_cast(header); + if (cmd->root_param_index < 16) { + st.comp_cbvs[cmd->root_param_index] = cmd->address; + st.comp_cbv_set[cmd->root_param_index] = true; + st.comp_uav_root[cmd->root_param_index] = false; + } + break; + } + case CmdType::SetComputeRootShaderResourceView: { + auto *cmd = reinterpret_cast(header); + if (cmd->root_param_index < 16) { + st.comp_cbvs[cmd->root_param_index] = cmd->address; + st.comp_cbv_set[cmd->root_param_index] = true; + st.comp_uav_root[cmd->root_param_index] = false; + } + break; + } + case CmdType::SetComputeRootUnorderedAccessView: { + auto *cmd = reinterpret_cast(header); + if (cmd->root_param_index < 16) { + st.comp_cbvs[cmd->root_param_index] = cmd->address; + st.comp_cbv_set[cmd->root_param_index] = true; + st.comp_uav_root[cmd->root_param_index] = true; + } + break; + } + case CmdType::SetComputeRootDescriptorTable: { + auto *cmd = reinterpret_cast(header); + if (cmd->root_param_index < 16) { + st.comp_tables[cmd->root_param_index] = cmd->base_descriptor; + st.comp_table_set[cmd->root_param_index] = true; + } + break; + } + case CmdType::IASetVertexBuffers: { + auto *cmd = reinterpret_cast(header); + auto *views = reinterpret_cast( + reinterpret_cast(cmd) + + 
sizeof(CmdIASetVertexBuffers) - sizeof(D3D12_VERTEX_BUFFER_VIEW)); + for (uint32_t i = 0; i < cmd->count && (cmd->start_slot + i) < 16; i++) + st.vbs[cmd->start_slot + i] = views[i]; + break; + } + case CmdType::IASetIndexBuffer: { + auto *cmd = reinterpret_cast(header); + st.ib = cmd->view; + break; + } + case CmdType::OMSetBlendFactor: { + auto *cmd = reinterpret_cast(header); + memcpy(st.blend_factor, cmd->factor, 16); + break; + } + case CmdType::OMSetStencilRef: { + auto *cmd = reinterpret_cast(header); + st.stencil_ref = cmd->stencil_ref; + break; + } + case CmdType::SetDescriptorHeaps: { + auto *cmd = reinterpret_cast(header); + st.desc_heap_count = cmd->count > 2 ? 2 : cmd->count; + auto *heaps = reinterpret_cast( + reinterpret_cast(cmd) + + sizeof(CmdSetDescriptorHeaps) - sizeof(ID3D12DescriptorHeap *)); + for (uint32_t i = 0; i < st.desc_heap_count; i++) + st.desc_heaps[i] = heaps[i]; + break; + } + default: + break; + } + + offset += header->size; + } + + QTRACE("ECL: replayed %zu cmds, types:", cmd_count); + for (int i = 0; i < 30; i++) + if (type_counts[i]) + QTRACE(" type[%d]=%u", i, type_counts[i]); + + st.CloseRenderEncoder(); + QTRACE("ExecuteCommandLists: committing cmdbuf"); + ENC_COMMIT(cmdbuf.handle); + cmdbuf.commit(); + cmdbuf.waitUntilCompleted(); + + auto status = cmdbuf.status(); + QTRACE("ExecuteCommandLists: cmdbuf status=%d", (int)status); + if (status != WMTCommandBufferStatusCompleted) { + auto err = cmdbuf.error(); + Logger::err(str::format("ExecuteCommandLists: cmdbuf status=", status, " error_handle=", err.handle)); + } + } +} + +void STDMETHODCALLTYPE MTLD3D12CommandQueue::SetMarker(UINT metadata, + const void *data, + UINT size) {} + +void STDMETHODCALLTYPE MTLD3D12CommandQueue::BeginEvent(UINT metadata, + const void *data, + UINT size) {} + +void STDMETHODCALLTYPE MTLD3D12CommandQueue::EndEvent() {} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandQueue::Signal(ID3D12Fence *fence, UINT64 value) { + QTRACE("CmdQueue::Signal 
value=%llu fence_iface=%p", (unsigned long long)value, (void *)fence); + if (!fence) + return E_POINTER; + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "CmdQueue::Signal value=%llu fence=%p\n", (unsigned long long)value, (void *)fence); fclose(f); } + } + return fence->Signal(value); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandQueue::Wait(ID3D12Fence *fence, UINT64 value) { + if (!fence) + return E_POINTER; + return fence->SetEventOnCompletion(value, nullptr); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandQueue::GetTimestampFrequency(UINT64 *frequency) { + if (frequency) + *frequency = 1000000000; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12CommandQueue::GetClockCalibration(UINT64 *gpu_timestamp, + UINT64 *cpu_timestamp) { + if (gpu_timestamp) + *gpu_timestamp = 0; + if (cpu_timestamp) + *cpu_timestamp = 0; + return S_OK; +} + +D3D12_COMMAND_QUEUE_DESC *STDMETHODCALLTYPE +MTLD3D12CommandQueue::GetDesc(D3D12_COMMAND_QUEUE_DESC *__ret) { + *__ret = m_desc; + return __ret; +} + +} // namespace dxmt diff --git a/src/d3d12/d3d12_command_queue.hpp b/src/d3d12/d3d12_command_queue.hpp new file mode 100644 index 000000000..38453a989 --- /dev/null +++ b/src/d3d12/d3d12_command_queue.hpp @@ -0,0 +1,88 @@ +#pragma once + +#include "com/com_pointer.hpp" +#include "d3d12.h" +#include "dxgi_interfaces.h" +#include "dxmt_command_queue.hpp" +#include "Metal.hpp" +#include + +namespace dxmt { + +class MTLD3D12Device; + +class MTLD3D12CommandQueue : public ID3D12CommandQueue { +public: + MTLD3D12CommandQueue(MTLD3D12Device *device, CommandQueue &queue, + D3D12_COMMAND_QUEUE_DESC desc); + ~MTLD3D12CommandQueue(); + + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, + void **ppvObject) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID guid, UINT *data_size, + void *data) override; + HRESULT STDMETHODCALLTYPE 
SetPrivateData(REFGUID guid, UINT data_size, + const void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface( + REFGUID guid, const IUnknown *data) override; + HRESULT STDMETHODCALLTYPE SetName(LPCWSTR name) override; + + HRESULT STDMETHODCALLTYPE GetDevice(REFIID riid, void **device) override; + + void STDMETHODCALLTYPE UpdateTileMappings( + ID3D12Resource *resource, UINT region_count, + const D3D12_TILED_RESOURCE_COORDINATE *region_start_coordinates, + const D3D12_TILE_REGION_SIZE *region_sizes, ID3D12Heap *heap, + UINT range_count, const D3D12_TILE_RANGE_FLAGS *range_flags, + const UINT *heap_range_offsets, const UINT *range_tile_counts, + D3D12_TILE_MAPPING_FLAGS flags) override; + + void STDMETHODCALLTYPE CopyTileMappings( + ID3D12Resource *dst_resource, + const D3D12_TILED_RESOURCE_COORDINATE *dst_region_start_coordinate, + ID3D12Resource *src_resource, + const D3D12_TILED_RESOURCE_COORDINATE *src_region_start_coordinate, + const D3D12_TILE_REGION_SIZE *region_size, + D3D12_TILE_MAPPING_FLAGS flags) override; + + void STDMETHODCALLTYPE ExecuteCommandLists( + UINT command_list_count, + ID3D12CommandList *const *command_lists) override; + + void STDMETHODCALLTYPE SetMarker(UINT metadata, const void *data, + UINT size) override; + + void STDMETHODCALLTYPE BeginEvent(UINT metadata, const void *data, + UINT size) override; + + void STDMETHODCALLTYPE EndEvent() override; + + HRESULT STDMETHODCALLTYPE Signal(ID3D12Fence *fence, + UINT64 value) override; + + HRESULT STDMETHODCALLTYPE Wait(ID3D12Fence *fence, UINT64 value) override; + + HRESULT STDMETHODCALLTYPE + GetTimestampFrequency(UINT64 *frequency) override; + + HRESULT STDMETHODCALLTYPE + GetClockCalibration(UINT64 *gpu_timestamp, UINT64 *cpu_timestamp) override; + + D3D12_COMMAND_QUEUE_DESC *STDMETHODCALLTYPE + GetDesc(D3D12_COMMAND_QUEUE_DESC *__ret) override; + + CommandQueue &GetDXMTCommandQueue() { return m_queue; } + +private: + MTLD3D12Device *m_device; + CommandQueue &m_queue; + 
D3D12_COMMAND_QUEUE_DESC m_desc; + std::atomic m_refCount = {1ul}; + std::atomic m_refPrivate = {1ul}; + WMT::Reference m_wmt_queue; +}; + +} // namespace dxmt diff --git a/src/d3d12/d3d12_descriptor_heap.cpp b/src/d3d12/d3d12_descriptor_heap.cpp new file mode 100644 index 000000000..8c3ce258f --- /dev/null +++ b/src/d3d12/d3d12_descriptor_heap.cpp @@ -0,0 +1,129 @@ +#include "d3d12_descriptor_heap.hpp" +#include "d3d12_device.hpp" +#include "log/log.hpp" +#include "util_string.hpp" +#include +#include +#include + +#define HTRACE(fmt, ...) do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); if (_tf) { fprintf(_tf, "DescHeap::" fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) + +namespace dxmt { + +MTLD3D12DescriptorHeap::MTLD3D12DescriptorHeap( + MTLD3D12Device *device, const D3D12_DESCRIPTOR_HEAP_DESC &desc) + : m_device(device), m_desc(desc), m_data(nullptr), m_data_is_virtual(false), m_data_size(0) { + HTRACE("DescriptorHeap ctor: device=%p type=%u num=%u this=%p", (void*)m_device, desc.Type, desc.NumDescriptors, (void*)this); + size_t alloc_size = (size_t)desc.NumDescriptors * sizeof(D3D12Descriptor); + HTRACE("DescriptorHeap ctor: alloc_size computed"); + m_data_size = alloc_size; + HTRACE("DescriptorHeap ctor: type=%u num=%u alloc=%u bytes", (unsigned)desc.Type, (unsigned)desc.NumDescriptors, (unsigned)alloc_size); + if (alloc_size >= 1024 * 1024) { + m_data_is_virtual = true; + m_data = (D3D12Descriptor *)VirtualAlloc(nullptr, alloc_size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); + HTRACE("DescriptorHeap ctor: VirtualAlloc data=%p", (void *)m_data); + } else { + m_data = (D3D12Descriptor *)HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, alloc_size); + HTRACE("DescriptorHeap ctor: HeapAlloc data=%p", (void *)m_data); + } + if (m_data) { + HTRACE("DescriptorHeap ctor: OK data=%p", (void *)m_data); + } else { + HTRACE("DescriptorHeap ctor: ALLOC FAILED for %u bytes!", (unsigned)alloc_size); + } + Logger::info(str::format("D3D12DescriptorHeap: type=", 
desc.Type, + " count=", desc.NumDescriptors, + " flags=", desc.Flags, + " data=", (void *)m_data)); +} + +MTLD3D12DescriptorHeap::~MTLD3D12DescriptorHeap() { + if (m_data) { + if (m_data_is_virtual) { + VirtualFree(m_data, 0, MEM_RELEASE); + } else { + HeapFree(GetProcessHeap(), 0, m_data); + } + } +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DescriptorHeap::QueryInterface(REFIID riid, void **ppvObject) { + if (!ppvObject) + return E_POINTER; + *ppvObject = nullptr; + + if (riid == IID_IUnknown || riid == IID_ID3D12Object || + riid == IID_ID3D12DeviceChild || riid == IID_ID3D12Pageable || + riid == IID_ID3D12DescriptorHeap) { + *ppvObject = ref(this); + return S_OK; + } + HTRACE("QI unknown IID %s -> E_NOINTERFACE", str::format(riid).c_str()); + return E_NOINTERFACE; +} + +ULONG STDMETHODCALLTYPE +MTLD3D12DescriptorHeap::AddRef() { return ++m_refCount; } + +ULONG STDMETHODCALLTYPE MTLD3D12DescriptorHeap::Release() { + uint32_t rc = --m_refCount; + if (!rc) { + this->~MTLD3D12DescriptorHeap(); + HeapFree(GetProcessHeap(), 0, this); + } + return rc; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DescriptorHeap::GetPrivateData(REFGUID guid, UINT *data_size, + void *data) { + HTRACE("GetPrivateData E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DescriptorHeap::SetPrivateData(REFGUID guid, UINT data_size, + const void *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DescriptorHeap::SetPrivateDataInterface(REFGUID guid, + const IUnknown *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DescriptorHeap::SetName(LPCWSTR name) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DescriptorHeap::GetDevice(REFIID riid, void **device) { + return m_device->QueryInterface(riid, device); +} + +D3D12_DESCRIPTOR_HEAP_DESC *STDMETHODCALLTYPE +MTLD3D12DescriptorHeap::GetDesc(D3D12_DESCRIPTOR_HEAP_DESC *__ret) { + *__ret = m_desc; + return __ret; +} + +D3D12_CPU_DESCRIPTOR_HANDLE *STDMETHODCALLTYPE 
+MTLD3D12DescriptorHeap::GetCPUDescriptorHandleForHeapStart( + D3D12_CPU_DESCRIPTOR_HANDLE *__ret) { + HTRACE("GetCPUDescriptorHandleForHeapStart"); + __ret->ptr = reinterpret_cast(m_data); + return __ret; +} + +D3D12_GPU_DESCRIPTOR_HANDLE *STDMETHODCALLTYPE +MTLD3D12DescriptorHeap::GetGPUDescriptorHandleForHeapStart( + D3D12_GPU_DESCRIPTOR_HANDLE *__ret) { + HTRACE("GetGPUDescriptorHandleForHeapStart ptr=0x%llx", (unsigned long long)reinterpret_cast(m_data)); + __ret->ptr = reinterpret_cast(m_data); + return __ret; +} + +} // namespace dxmt diff --git a/src/d3d12/d3d12_descriptor_heap.hpp b/src/d3d12/d3d12_descriptor_heap.hpp new file mode 100644 index 000000000..81bc5b348 --- /dev/null +++ b/src/d3d12/d3d12_descriptor_heap.hpp @@ -0,0 +1,82 @@ +#pragma once + +#include "com/com_pointer.hpp" +#include "d3d12.h" +#include +#include + +namespace dxmt { + +class MTLD3D12Device; + +struct D3D12Descriptor { + D3D12_DESCRIPTOR_HEAP_TYPE type; + union { + D3D12_CONSTANT_BUFFER_VIEW_DESC cbv; + D3D12_SHADER_RESOURCE_VIEW_DESC srv; + D3D12_UNORDERED_ACCESS_VIEW_DESC uav; + D3D12_RENDER_TARGET_VIEW_DESC rtv; + D3D12_DEPTH_STENCIL_VIEW_DESC dsv; + D3D12_SAMPLER_DESC sampler; + }; + ID3D12Resource *resource = nullptr; + ID3D12Resource *resource_uav_counter = nullptr; +}; + +class MTLD3D12DescriptorHeap : public ID3D12DescriptorHeap { +public: + MTLD3D12DescriptorHeap(MTLD3D12Device *device, + const D3D12_DESCRIPTOR_HEAP_DESC &desc); + ~MTLD3D12DescriptorHeap(); + + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, + void **ppvObject) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID guid, UINT *data_size, + void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateData(REFGUID guid, UINT data_size, + const void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface( + REFGUID guid, const IUnknown *data) override; + HRESULT STDMETHODCALLTYPE 
SetName(LPCWSTR name) override; + + HRESULT STDMETHODCALLTYPE GetDevice(REFIID riid, void **device) override; + + D3D12_DESCRIPTOR_HEAP_DESC *STDMETHODCALLTYPE + GetDesc(D3D12_DESCRIPTOR_HEAP_DESC *__ret) override; + + D3D12_CPU_DESCRIPTOR_HANDLE *STDMETHODCALLTYPE + GetCPUDescriptorHandleForHeapStart(D3D12_CPU_DESCRIPTOR_HANDLE *__ret) override; + D3D12_GPU_DESCRIPTOR_HANDLE *STDMETHODCALLTYPE + GetGPUDescriptorHandleForHeapStart(D3D12_GPU_DESCRIPTOR_HANDLE *__ret) override; + + D3D12Descriptor *GetDescriptors() { return m_data; } + uint32_t GetDescriptorCount() { return m_desc.NumDescriptors; } + + D3D12Descriptor *GetDescriptorFromGPUHandle(D3D12_GPU_DESCRIPTOR_HANDLE handle) { + auto *desc = reinterpret_cast<D3D12Descriptor *>(handle.ptr); + auto *end = m_data + m_desc.NumDescriptors; + if (desc < m_data || desc >= end) + return nullptr; + return desc; + } + D3D12Descriptor *GetDescriptorFromCPUHandle(D3D12_CPU_DESCRIPTOR_HANDLE handle) { + auto *desc = reinterpret_cast<D3D12Descriptor *>(handle.ptr); + auto *end = m_data + m_desc.NumDescriptors; + if (desc < m_data || desc >= end) + return nullptr; + return desc; + } + +private: + MTLD3D12Device *m_device; + D3D12_DESCRIPTOR_HEAP_DESC m_desc; + D3D12Descriptor *m_data = nullptr; + bool m_data_is_virtual = false; + size_t m_data_size = 0; + std::atomic<uint32_t> m_refCount = {1ul}; +}; + +} // namespace dxmt diff --git a/src/d3d12/d3d12_device.cpp b/src/d3d12/d3d12_device.cpp new file mode 100644 index 000000000..06ba1a47e --- /dev/null +++ b/src/d3d12/d3d12_device.cpp @@ -0,0 +1,1559 @@ +#define INITGUID +#include "d3d12_command_queue.hpp" +#include "d3d12_command_allocator.hpp" +#include "d3d12_command_list.hpp" +#include "d3d12_descriptor_heap.hpp" +#include "d3d12_device.hpp" +#include "d3d12_fence.hpp" +#include "d3d12_heap.hpp" +#include "d3d12_pipeline_state.hpp" +#include "d3d12_query_heap.hpp" +#include "d3d12_resource.hpp" + +#define TRACE(fmt, ...)
do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); if (_tf) { fprintf(_tf, fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) +#include "d3d12_root_signature.hpp" +#include "com/com_object.hpp" +#include "log/log.hpp" +#include "util_string.hpp" +#include "d3d12_resource.hpp" +#include +#include +#include + +static LONG WINAPI crash_handler(EXCEPTION_POINTERS *ep) { + if (ep->ExceptionRecord->ExceptionCode == EXCEPTION_ACCESS_VIOLATION || + ep->ExceptionRecord->ExceptionCode == EXCEPTION_STACK_OVERFLOW) { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + fprintf(f, "!!! EXCEPTION code=0x%lx addr=%p flags=0x%lx\n", + ep->ExceptionRecord->ExceptionCode, + ep->ExceptionRecord->ExceptionAddress, + ep->ExceptionRecord->ExceptionFlags); + void *buf[32]; + ULONG n = RtlCaptureStackBackTrace(0, 32, buf, nullptr); + for (ULONG i = 0; i < n; i++) { + fprintf(f, " [%lu] %p\n", (unsigned long)i, buf[i]); + } + fclose(f); + } + } + return EXCEPTION_CONTINUE_SEARCH; +} + +void install_crash_handler() { + AddVectoredExceptionHandler(1, crash_handler); +} + +namespace dxmt { + +static const GUID IID_ID3D12Device2_ = {0x30baa41e, 0xb15b, 0x475c, {0xa0, 0xbb, 0x1a, 0xf5, 0xc5, 0xb6, 0x43, 0x28}}; +static const GUID IID_ID3D12Device3_ = {0x81dadc15, 0x2bad, 0x4392, {0x93, 0xc5, 0x10, 0x13, 0x45, 0xc4, 0xaa, 0x98}}; +static const GUID IID_ID3D12Device4_ = {0xe865df17, 0xa9ee, 0x46f9, {0xa4, 0x63, 0x30, 0x98, 0x31, 0x5a, 0xa2, 0xe5}}; +static const GUID IID_ID3D12Device5_ = {0x8b4f173b, 0x2fea, 0x4b80, {0x8f, 0x58, 0x43, 0x07, 0x19, 0x1a, 0xb9, 0x5d}}; +static const GUID IID_ID3D12Device6_ = {0xc70b221b, 0x40e4, 0x4a17, {0x89, 0xaf, 0x02, 0x5a, 0x07, 0x27, 0xa6, 0xdc}}; +static const GUID IID_ID3D12Device7_ = {0x5c014b53, 0x68a1, 0x4b9b, {0x8b, 0xd1, 0xdd, 0x60, 0x46, 0xb9, 0x35, 0x8b}}; +static const GUID IID_ID3D12Device8_ = {0x9218e6bb, 0xf944, 0x4f7e, {0xa7, 0x5c, 0xb1, 0xb2, 0xc7, 0xb7, 0x01, 0xf3}}; +static const GUID IID_ID3D12Device9_ = 
{0x4c80e962, 0xf032, 0x4f60, {0xbc, 0x9e, 0xeb, 0xc2, 0xcf, 0xa1, 0xd8, 0x3c}}; +static const GUID IID_ID3D12Device10_ = {0x517f8718, 0xaa66, 0x49f9, {0xb0, 0x2b, 0xa7, 0xab, 0x89, 0xc0, 0x60, 0x31}}; +static const GUID IID_ID3D12Device11_ = {0x5405c344, 0xd457, 0x444e, {0xb4, 0xdd, 0x23, 0x66, 0xe4, 0x5a, 0xee, 0x39}}; +static const GUID IID_ID3D12Device12_ = {0x5af5c532, 0x4c91, 0x4cd0, {0xb5, 0x41, 0x15, 0xa4, 0x05, 0x39, 0x5f, 0xc5}}; + +Logger Logger::s_instance("d3d12.log"); + +class MTLD3D12CommandSignature : public ComObject { +public: + MTLD3D12CommandSignature(MTLD3D12Device *device, const D3D12_COMMAND_SIGNATURE_DESC &desc) + : m_device(device), m_desc(desc) { m_device->AddRef(); } + ~MTLD3D12CommandSignature() { m_device->Release(); } + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void **ppv) override { + if (!ppv) return E_POINTER; + if (riid == IID_IUnknown || riid == IID_ID3D12Object || + riid == IID_ID3D12DeviceChild || riid == IID_ID3D12CommandSignature) { + *ppv = ref(this); return S_OK; + } + return E_NOINTERFACE; + } + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID, UINT *, void *) override { return E_NOTIMPL; } + HRESULT STDMETHODCALLTYPE SetPrivateData(REFGUID, UINT, const void *) override { return S_OK; } + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface(REFGUID, const IUnknown *) override { return S_OK; } + HRESULT STDMETHODCALLTYPE SetName(LPCWSTR) override { return S_OK; } + HRESULT STDMETHODCALLTYPE GetDevice(REFIID riid, void **device) override { return m_device->QueryInterface(riid, device); } +private: + MTLD3D12Device *m_device; + D3D12_COMMAND_SIGNATURE_DESC m_desc; +}; + +} // namespace dxmt + +namespace dxmt { + +static void *g_device_this = nullptr; +static void *g_device_expected_vtable = nullptr; +static uint64_t g_device_expected_m_device = 0; +static std::atomic g_device_watcher_running{false}; +static int g_watcher_restore_count = 0; + +static void device_vtable_watcher() { + int check_count = 0; + int 
snapshot_count = 0; + while (g_device_watcher_running.load()) { + if (g_device_this) { + void *current = *(void**)g_device_this; + uint64_t current_m_device = *((uint64_t*)((char*)g_device_this + 8)); + bool vtable_bad = (current != g_device_expected_vtable); + bool m_device_bad = (g_device_expected_m_device != 0 && current_m_device != g_device_expected_m_device); + if (vtable_bad || m_device_bad) { + g_watcher_restore_count++; + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + fprintf(f, "!!! CORRUPTION #%d after %d checks: this=%p vtable_expected=%p vtable_now=%p m_device_expected=0x%llx m_device_now=0x%llx watcher_tid=%lu\n", + g_watcher_restore_count, check_count, + g_device_this, g_device_expected_vtable, current, + (unsigned long long)g_device_expected_m_device, (unsigned long long)current_m_device, + (unsigned long)GetCurrentThreadId()); + unsigned char *raw = (unsigned char *)g_device_this; + fprintf(f, "!!! DEVICE DUMP [0x00-0x3F]:"); + for (int i = 0; i < 64; i += 8) { + fprintf(f, " %02x%02x%02x%02x%02x%02x%02x%02x", + raw[i], raw[i+1], raw[i+2], raw[i+3], raw[i+4], raw[i+5], raw[i+6], raw[i+7]); + } + fprintf(f, "\n!!! 
DEVICE DUMP [0x40-0x7F]:"); + for (int i = 64; i < 128; i += 8) { + fprintf(f, " %02x%02x%02x%02x%02x%02x%02x%02x", + raw[i], raw[i+1], raw[i+2], raw[i+3], raw[i+4], raw[i+5], raw[i+6], raw[i+7]); + } + fprintf(f, "\n"); + fclose(f); + } + *(void**)g_device_this = g_device_expected_vtable; + if (g_device_expected_m_device != 0) { + *((uint64_t*)((char*)g_device_this + 8)) = g_device_expected_m_device; + } + check_count = 0; + continue; + } + check_count++; + snapshot_count++; + if (snapshot_count % 10000 == 0) { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + fprintf(f, "watcher snapshot #%d: vtable=%p m_device=0x%llx OK\n", + snapshot_count, current, (unsigned long long)current_m_device); + fclose(f); + } + } + } + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } +} + +MTLD3D12Device::MTLD3D12Device(std::unique_ptr &&device, + IMTLDXGIAdapter *pAdapter) + : m_device(std::move(device)), m_adapter(pAdapter) { + if (m_adapter) + m_adapter->AddRef(); + m_expected_vtable = *(void**)this; + g_device_this = (void*)this; + g_device_expected_vtable = m_expected_vtable; + g_device_expected_m_device = (uint64_t)m_device.get(); + TRACE("Device ctor: this=%p vtable=%p m_device=%p sizeof=%zu", + (void*)this, m_expected_vtable, (void*)m_device.get(), sizeof(MTLD3D12Device)); + extern void *g_d3d12_device_addr; + extern size_t g_d3d12_device_size; + g_d3d12_device_addr = (void*)this; + g_d3d12_device_size = sizeof(MTLD3D12Device); + TRACE("Device ctor: registered device guard at %p size=%zu", g_d3d12_device_addr, g_d3d12_device_size); + g_device_watcher_running.store(true); + std::thread watcher(device_vtable_watcher); + watcher.detach(); + Logger::info("D3D12 device created via DXMT Metal backend"); +} + +MTLD3D12Device::~MTLD3D12Device() { + void *current_vt = *(void**)this; + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + fprintf(f, "Device REAL DTOR this=%p vtable=%p expected=%p m_refCount=%u\n", + (void*)this, 
current_vt, m_expected_vtable, m_refCount.load()); + fclose(f); + } + Logger::info("D3D12 device destroyed"); +} + +void MTLD3D12Device::CheckVtable(const char *where) { + void *current = *(void**)this; + if (current != m_expected_vtable) { + TRACE("VTABLE CORRUPTION at %s: expected=%p got=%p this=%p — AUTO-RESTORING", where, m_expected_vtable, current, (void*)this); + *(void**)this = m_expected_vtable; + } +} + +WMT::Device MTLD3D12Device::GetMTLDevice() { + return m_device->device(); +} + +Device &MTLD3D12Device::GetDXMTDevice() { return *m_device; } + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::QueryInterface(REFIID riid, void **ppvObject) { + if (!ppvObject) + return E_POINTER; + *ppvObject = nullptr; + + TRACE("D3D12Device::QI(%s)", str::format(riid).c_str()); + + if (riid == IID_IUnknown || riid == IID_ID3D12Object || + riid == IID_ID3D12DeviceChild || riid == IID_ID3D12Pageable || + riid == IID_ID3D12Device || riid == IID_ID3D12Device1 || + riid == IID_ID3D12Device2_ || riid == IID_ID3D12Device3_ || + riid == IID_ID3D12Device4_ || riid == IID_ID3D12Device5_ || + riid == IID_ID3D12Device6_ || riid == IID_ID3D12Device7_ || + riid == IID_ID3D12Device8_ || riid == IID_ID3D12Device9_ || + riid == IID_ID3D12Device10_ || riid == IID_ID3D12Device11_ || + riid == IID_ID3D12Device12_) { + *ppvObject = ref(this); + TRACE("D3D12Device::QI(%s) -> S_OK (device)", str::format(riid).c_str()); + return S_OK; + } + + if (riid == __uuidof(IMTLDXGIDevice) && m_dxgi_device) { + TRACE("D3D12Device::QI(%s) -> delegating to dxgi_device", str::format(riid).c_str()); + return m_dxgi_device->QueryInterface(riid, ppvObject); + } + + if (m_dxgi_device) { + if (riid == IID_IDXGIDevice) { + TRACE("D3D12Device::QI(%s) -> delegating DXGI to dxgi_device", str::format(riid).c_str()); + return m_dxgi_device->QueryInterface(riid, ppvObject); + } + } + + Logger::warn(str::format("D3D12Device::QueryInterface: unknown IID ", riid)); + TRACE("D3D12Device::QI(%s) -> E_NOINTERFACE", 
str::format(riid).c_str()); + return E_NOINTERFACE; +} + +ULONG STDMETHODCALLTYPE MTLD3D12Device::AddRef() { + CheckVtable("AddRef"); + uint32_t rc = m_refCount++; + if (!rc) + ++m_refPrivate; + return rc + 1; +} + +ULONG STDMETHODCALLTYPE MTLD3D12Device::Release() { + CheckVtable("Release"); + uint32_t rc = --m_refCount; + if (rc <= 1) TRACE("Device::Release rc=%u this=%p", rc, (void*)this); + if (!rc) { + uint32_t rp = --m_refPrivate; + if (!rp) { + TRACE("Device::Release DELETING this=%p", (void*)this); + m_refPrivate += 0x80000000; + delete this; + } + } + return rc; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::GetPrivateData(REFGUID guid, UINT *data_size, void *data) { + TRACE("GetPrivateData E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::SetPrivateData(REFGUID guid, UINT data_size, + const void *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::SetPrivateDataInterface(REFGUID guid, const IUnknown *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::SetName(LPCWSTR name) { + return S_OK; +} + +UINT STDMETHODCALLTYPE MTLD3D12Device::GetNodeCount() { return 1; } + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::CreateCommandQueue(const D3D12_COMMAND_QUEUE_DESC *desc, + REFIID riid, void **command_queue) { + TRACE("CreateCommandQueue type=%u", desc ? 
desc->Type : 0xFF); + if (!desc || !command_queue) + return E_POINTER; + InitReturnPtr(command_queue); + + auto queue = new MTLD3D12CommandQueue(this, m_device->queue(), *desc); + HRESULT hr = queue->QueryInterface(riid, command_queue); + if (FAILED(hr)) + queue->Release(); + return hr; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE type, + REFIID riid, + void **command_allocator) { + TRACE("CreateCommandAllocator type=%u", type); + if (!command_allocator) + return E_POINTER; + InitReturnPtr(command_allocator); + + auto allocator = new MTLD3D12CommandAllocator(this, type); + HRESULT hr = allocator->QueryInterface(riid, command_allocator); + if (FAILED(hr)) + allocator->Release(); + return hr; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateGraphicsPipelineState( + const D3D12_GRAPHICS_PIPELINE_STATE_DESC *desc, REFIID riid, + void **pipeline_state) { + if (!desc || !pipeline_state) + return E_POINTER; + InitReturnPtr(pipeline_state); + + TRACE("CreateGraphicsPSO ENTER: VS=%p(%zu) PS=%p(%zu) NumRT=%u DSV=%u Topo=%u", + desc->VS.pShaderBytecode, desc->VS.BytecodeLength, + desc->PS.pShaderBytecode, desc->PS.BytecodeLength, + desc->NumRenderTargets, (unsigned)desc->DSVFormat, + (unsigned)desc->PrimitiveTopologyType); + + auto pso = new MTLD3D12PipelineState(this, false); + pso->SetGraphicsDesc(*desc); + bool compiled = pso->Compile(); + TRACE("CreateGraphicsPSO: compile=%d VS=%p PS=%p", compiled, desc->VS.pShaderBytecode, desc->PS.pShaderBytecode); + if (!compiled) { + Logger::warn("CreateGraphicsPipelineState: shader compilation deferred/failed"); + } + HRESULT hr = pso->QueryInterface(riid, pipeline_state); + if (FAILED(hr)) + pso->Release(); + TRACE("CreateGraphicsPSO EXIT hr=0x%lx pso=%p", hr, *pipeline_state); + return hr; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateComputePipelineState( + const D3D12_COMPUTE_PIPELINE_STATE_DESC *desc, REFIID riid, + void **pipeline_state) { + if (!desc || 
!pipeline_state) + return E_POINTER; + InitReturnPtr(pipeline_state); + + auto pso = new MTLD3D12PipelineState(this, true); + pso->SetComputeDesc(*desc); + bool compiled = pso->Compile(); + TRACE("CreateComputePSO: compile=%d CS=%p", compiled, desc->CS.pShaderBytecode); + if (!compiled) { + Logger::warn("CreateComputePipelineState: shader compilation deferred/failed"); + } + HRESULT hr = pso->QueryInterface(riid, pipeline_state); + if (FAILED(hr)) + pso->Release(); + return hr; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateCommandList( + UINT node_mask, D3D12_COMMAND_LIST_TYPE type, + ID3D12CommandAllocator *command_allocator, + ID3D12PipelineState *initial_pipeline_state, REFIID riid, + void **command_list) { + TRACE("CreateCommandList type=%u", type); + if (!command_list) + return E_POINTER; + InitReturnPtr(command_list); + + auto allocator = static_cast<MTLD3D12CommandAllocator *>(command_allocator); + auto list = new MTLD3D12GraphicsCommandList(this, allocator, type, + initial_pipeline_state); + HRESULT hr = list->QueryInterface(riid, command_list); + if (FAILED(hr)) + list->Release(); + return hr; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::CheckFeatureSupport(D3D12_FEATURE feature, + void *feature_data, + UINT feature_data_size) { + TRACE("CheckFeatureSupport this=%p feature=%u data=%p size=%u", (void*)this, (unsigned)feature, feature_data, feature_data_size); + if ((UINT_PTR)feature_data > 0 && (UINT_PTR)feature_data < 0x10000) { + TRACE("!!! 
SUSPICIOUS CheckFeatureSupport: feature_data=%p looks like row_pitch (small int), this=%p — probable vtable slot 13 collision (ReadFromSubresource->CheckFeatureSupport)", feature_data, (void*)this); + } + switch (feature) { + case D3D12_FEATURE_D3D12_OPTIONS: { + auto *opts = (D3D12_FEATURE_DATA_D3D12_OPTIONS *)feature_data; + if (feature_data_size < sizeof(*opts)) + return E_INVALIDARG; + opts->DoublePrecisionFloatShaderOps = FALSE; + opts->OutputMergerLogicOp = TRUE; + opts->MinPrecisionSupport = D3D12_SHADER_MIN_PRECISION_SUPPORT_10_BIT; + opts->TiledResourcesTier = D3D12_TILED_RESOURCES_TIER_2; + opts->ResourceBindingTier = D3D12_RESOURCE_BINDING_TIER_3; + opts->PSSpecifiedStencilRefSupported = TRUE; + opts->TypedUAVLoadAdditionalFormats = TRUE; + opts->ROVsSupported = TRUE; + opts->ConservativeRasterizationTier = D3D12_CONSERVATIVE_RASTERIZATION_TIER_3; + opts->MaxGPUVirtualAddressBitsPerResource = 40; + opts->StandardSwizzle64KBSupported = FALSE; + opts->CrossNodeSharingTier = D3D12_CROSS_NODE_SHARING_TIER_NOT_SUPPORTED; + opts->CrossAdapterRowMajorTextureSupported = FALSE; + opts->VPAndRTArrayIndexFromAnyShaderFeedingRasterizerSupportedWithoutGSEmulation = TRUE; + opts->ResourceHeapTier = D3D12_RESOURCE_HEAP_TIER_2; + TRACE(" OPTIONS: DoubleFP=%d LogicOp=%d TiledTier=%u BindingTier=%u PSStencilRef=%d TypedUAV=%d ROV=%d ConsRaster=%u VAbit=%u HeapTier=%u", + opts->DoublePrecisionFloatShaderOps, opts->OutputMergerLogicOp, + opts->TiledResourcesTier, opts->ResourceBindingTier, + opts->PSSpecifiedStencilRefSupported, opts->TypedUAVLoadAdditionalFormats, + opts->ROVsSupported, opts->ConservativeRasterizationTier, + opts->MaxGPUVirtualAddressBitsPerResource, opts->ResourceHeapTier); + return S_OK; + } + case D3D12_FEATURE_ARCHITECTURE: { + auto *arch = (D3D12_FEATURE_DATA_ARCHITECTURE *)feature_data; + if (feature_data_size < sizeof(*arch)) + return E_INVALIDARG; + arch->NodeIndex = 0; + arch->TileBasedRenderer = FALSE; + arch->UMA = TRUE; + arch->CacheCoherentUMA 
= TRUE; + return S_OK; + } + case D3D12_FEATURE_FEATURE_LEVELS: { + auto *fl = (D3D12_FEATURE_DATA_FEATURE_LEVELS *)feature_data; + if (feature_data_size < sizeof(*fl)) + return E_INVALIDARG; + fl->MaxSupportedFeatureLevel = D3D_FEATURE_LEVEL_9_1; + for (UINT i = 0; i < fl->NumFeatureLevels; i++) { + if (fl->pFeatureLevelsRequested[i] <= D3D_FEATURE_LEVEL_12_1 && + fl->pFeatureLevelsRequested[i] > fl->MaxSupportedFeatureLevel) { + fl->MaxSupportedFeatureLevel = fl->pFeatureLevelsRequested[i]; + } + } + TRACE(" FEATURE_LEVELS: MaxSupported=%u (from %u requested)", (unsigned)fl->MaxSupportedFeatureLevel, fl->NumFeatureLevels); + return S_OK; + } + case D3D12_FEATURE_FORMAT_SUPPORT: { + auto *fmt = (D3D12_FEATURE_DATA_FORMAT_SUPPORT *)feature_data; + if (feature_data_size < sizeof(*fmt)) + return E_INVALIDARG; + TRACE(" FORMAT_SUPPORT: format=%u", (unsigned)fmt->Format); + fmt->Support1 = (D3D12_FORMAT_SUPPORT1)( + D3D12_FORMAT_SUPPORT1_TEXTURE2D | D3D12_FORMAT_SUPPORT1_RENDER_TARGET | + D3D12_FORMAT_SUPPORT1_DEPTH_STENCIL | + D3D12_FORMAT_SUPPORT1_SHADER_SAMPLE | + D3D12_FORMAT_SUPPORT1_SHADER_LOAD | + D3D12_FORMAT_SUPPORT1_SHADER_SAMPLE_COMPARISON | + D3D12_FORMAT_SUPPORT1_BUFFER | + D3D12_FORMAT_SUPPORT1_IA_INDEX_BUFFER | + D3D12_FORMAT_SUPPORT1_IA_VERTEX_BUFFER | + D3D12_FORMAT_SUPPORT1_TYPED_UNORDERED_ACCESS_VIEW | + D3D12_FORMAT_SUPPORT1_MULTISAMPLE_RENDERTARGET | + D3D12_FORMAT_SUPPORT1_MULTISAMPLE_RESOLVE | + D3D12_FORMAT_SUPPORT1_DISPLAY); + fmt->Support2 = (D3D12_FORMAT_SUPPORT2)( + D3D12_FORMAT_SUPPORT2_UAV_TYPED_LOAD | D3D12_FORMAT_SUPPORT2_UAV_TYPED_STORE | + D3D12_FORMAT_SUPPORT2_UAV_ATOMIC_ADD | + D3D12_FORMAT_SUPPORT2_UAV_ATOMIC_BITWISE_OPS | + D3D12_FORMAT_SUPPORT2_UAV_ATOMIC_COMPARE_STORE_OR_COMPARE_EXCHANGE | + D3D12_FORMAT_SUPPORT2_UAV_ATOMIC_EXCHANGE | + D3D12_FORMAT_SUPPORT2_UAV_ATOMIC_SIGNED_MIN_OR_MAX | + D3D12_FORMAT_SUPPORT2_UAV_ATOMIC_UNSIGNED_MIN_OR_MAX); + TRACE(" FORMAT_SUPPORT: format=%u Support1=0x%x Support2=0x%x", + 
(unsigned)fmt->Format, (unsigned)fmt->Support1, (unsigned)fmt->Support2); + return S_OK; + } + case D3D12_FEATURE_MULTISAMPLE_QUALITY_LEVELS: { + auto *ms = + (D3D12_FEATURE_DATA_MULTISAMPLE_QUALITY_LEVELS *)feature_data; + if (feature_data_size < sizeof(*ms)) + return E_INVALIDARG; + ms->NumQualityLevels = 1; + return S_OK; + } + case D3D12_FEATURE_FORMAT_INFO: { + auto *fi = (D3D12_FEATURE_DATA_FORMAT_INFO *)feature_data; + if (feature_data_size < sizeof(*fi)) + return E_INVALIDARG; + fi->PlaneCount = 1; + return S_OK; + } + case D3D12_FEATURE_GPU_VIRTUAL_ADDRESS_SUPPORT: { + auto *va = (D3D12_FEATURE_DATA_GPU_VIRTUAL_ADDRESS_SUPPORT *)feature_data; + if (feature_data_size < sizeof(*va)) + return E_INVALIDARG; + va->MaxGPUVirtualAddressBitsPerResource = 40; + va->MaxGPUVirtualAddressBitsPerProcess = 40; + return S_OK; + } + case D3D12_FEATURE_SHADER_MODEL: { + auto *sm = (D3D12_FEATURE_DATA_SHADER_MODEL *)feature_data; + if (feature_data_size < sizeof(*sm)) + return E_INVALIDARG; + sm->HighestShaderModel = D3D_SHADER_MODEL_6_5; + TRACE(" SHADER_MODEL: HighestSM=%u", (unsigned)sm->HighestShaderModel); + return S_OK; + } + case D3D12_FEATURE_D3D12_OPTIONS1: { + auto *o = (D3D12_FEATURE_DATA_D3D12_OPTIONS1 *)feature_data; + if (feature_data_size < sizeof(*o)) + return E_INVALIDARG; + o->WaveOps = TRUE; + o->WaveLaneCountMin = 4; + o->WaveLaneCountMax = 64; + o->TotalLaneCount = 256; + o->ExpandedComputeResourceStates = FALSE; + o->Int64ShaderOps = TRUE; + return S_OK; + } + case D3D12_FEATURE_ROOT_SIGNATURE: { + auto *rs = (D3D12_FEATURE_DATA_ROOT_SIGNATURE *)feature_data; + if (feature_data_size < sizeof(*rs)) + return E_INVALIDARG; + rs->HighestVersion = D3D_ROOT_SIGNATURE_VERSION_1_1; + return S_OK; + } + case D3D12_FEATURE_ARCHITECTURE1: { + auto *a = (D3D12_FEATURE_DATA_ARCHITECTURE1 *)feature_data; + if (feature_data_size < sizeof(*a)) + return E_INVALIDARG; + a->NodeIndex = 0; + a->TileBasedRenderer = FALSE; + a->UMA = TRUE; + a->CacheCoherentUMA = TRUE; + 
a->IsolatedMMU = FALSE; + return S_OK; + } + case D3D12_FEATURE_D3D12_OPTIONS2: { + auto *o = (D3D12_FEATURE_DATA_D3D12_OPTIONS2 *)feature_data; + if (feature_data_size < sizeof(*o)) + return E_INVALIDARG; + o->DepthBoundsTestSupported = TRUE; + o->ProgrammableSamplePositionsTier = D3D12_PROGRAMMABLE_SAMPLE_POSITIONS_TIER_1; + return S_OK; + } + case D3D12_FEATURE_SHADER_CACHE: { + auto *sc = (D3D12_FEATURE_DATA_SHADER_CACHE *)feature_data; + if (feature_data_size < sizeof(*sc)) + return E_INVALIDARG; + sc->SupportFlags = D3D12_SHADER_CACHE_SUPPORT_NONE; + return S_OK; + } + case D3D12_FEATURE_D3D12_OPTIONS3: { + auto *o = (D3D12_FEATURE_DATA_D3D12_OPTIONS3 *)feature_data; + if (feature_data_size < sizeof(*o)) + return E_INVALIDARG; + o->CopyQueueTimestampQueriesSupported = FALSE; + o->CastingFullyTypedFormatSupported = TRUE; + o->WriteBufferImmediateSupportFlags = D3D12_COMMAND_LIST_SUPPORT_FLAG_DIRECT; + o->ViewInstancingTier = D3D12_VIEW_INSTANCING_TIER_NOT_SUPPORTED; + o->BarycentricsSupported = FALSE; + TRACE(" OPTIONS3: CopyQueueTS=%d CastFullyTyped=%d WriteBufImm=0x%x ViewInstTier=%u Bary=%d", + o->CopyQueueTimestampQueriesSupported, o->CastingFullyTypedFormatSupported, + o->WriteBufferImmediateSupportFlags, o->ViewInstancingTier, o->BarycentricsSupported); + return S_OK; + } + case D3D12_FEATURE_D3D12_OPTIONS4: { + auto *o = (D3D12_FEATURE_DATA_D3D12_OPTIONS4 *)feature_data; + if (feature_data_size < sizeof(*o)) + return E_INVALIDARG; + o->MSAA64KBAlignedTextureSupported = FALSE; + o->SharedResourceCompatibilityTier = D3D12_SHARED_RESOURCE_COMPATIBILITY_TIER_0; + o->Native16BitShaderOpsSupported = FALSE; + return S_OK; + } + case D3D12_FEATURE_SERIALIZATION: { + auto *s = (D3D12_FEATURE_DATA_SERIALIZATION *)feature_data; + if (feature_data_size < sizeof(*s)) + return E_INVALIDARG; + s->HeapSerializationTier = D3D12_HEAP_SERIALIZATION_TIER_0; + return S_OK; + } + case D3D12_FEATURE_D3D12_OPTIONS5: { + auto *o = (D3D12_FEATURE_DATA_D3D12_OPTIONS5 
*)feature_data; + if (feature_data_size < sizeof(*o)) + return E_INVALIDARG; + o->SRVOnlyTiledResourceTier3 = FALSE; + o->RenderPassesTier = D3D12_RENDER_PASS_TIER_1; + o->RaytracingTier = D3D12_RAYTRACING_TIER_NOT_SUPPORTED; + TRACE(" OPTIONS5: SRVTiled3=%d RenderPassesTier=%u RayTier=%u", + o->SRVOnlyTiledResourceTier3, o->RenderPassesTier, o->RaytracingTier); + return S_OK; + } + case D3D12_FEATURE_D3D12_OPTIONS6: { + auto *o = (D3D12_FEATURE_DATA_D3D12_OPTIONS6 *)feature_data; + if (feature_data_size < sizeof(*o)) + return E_INVALIDARG; + o->AdditionalShadingRatesSupported = FALSE; + o->PerPrimitiveShadingRateSupportedWithViewportIndexing = FALSE; + o->VariableShadingRateTier = D3D12_VARIABLE_SHADING_RATE_TIER_NOT_SUPPORTED; + o->ShadingRateImageTileSize = 0; + o->BackgroundProcessingSupported = FALSE; + return S_OK; + } + case D3D12_FEATURE_D3D12_OPTIONS7: { + auto *o = (D3D12_FEATURE_DATA_D3D12_OPTIONS7 *)feature_data; + if (feature_data_size < sizeof(*o)) + return E_INVALIDARG; + o->MeshShaderTier = D3D12_MESH_SHADER_TIER_NOT_SUPPORTED; + o->SamplerFeedbackTier = D3D12_SAMPLER_FEEDBACK_TIER_NOT_SUPPORTED; + return S_OK; + } + case D3D12_FEATURE_D3D12_OPTIONS8: { + auto *o = (D3D12_FEATURE_DATA_D3D12_OPTIONS8 *)feature_data; + if (feature_data_size < sizeof(*o)) + return E_INVALIDARG; + o->UnalignedBlockTexturesSupported = FALSE; + return S_OK; + } + case D3D12_FEATURE_D3D12_OPTIONS9: { + auto *o = (D3D12_FEATURE_DATA_D3D12_OPTIONS9 *)feature_data; + if (feature_data_size < sizeof(*o)) + return E_INVALIDARG; + o->MeshShaderPipelineStatsSupported = FALSE; + o->MeshShaderSupportsFullRangeRenderTargetArrayIndex = FALSE; + o->AtomicInt64OnTypedResourceSupported = FALSE; + o->AtomicInt64OnGroupSharedSupported = FALSE; + o->DerivativesInMeshAndAmplificationShadersSupported = FALSE; + return S_OK; + } + case D3D12_FEATURE_D3D12_OPTIONS10: { + auto *o = (D3D12_FEATURE_DATA_D3D12_OPTIONS10 *)feature_data; + if (feature_data_size < sizeof(*o)) + return E_INVALIDARG; + 
o->VariableRateShadingSumCombinerSupported = FALSE; + o->MeshShaderPerPrimitiveShadingRateSupported = FALSE; + return S_OK; + } + case D3D12_FEATURE_D3D12_OPTIONS11: { + auto *o = (D3D12_FEATURE_DATA_D3D12_OPTIONS11 *)feature_data; + if (feature_data_size < sizeof(*o)) + return E_INVALIDARG; + o->AtomicInt64OnDescriptorHeapResourceSupported = FALSE; + return S_OK; + } + case D3D12_FEATURE_PROTECTED_RESOURCE_SESSION_SUPPORT: { + auto *p = (D3D12_FEATURE_DATA_PROTECTED_RESOURCE_SESSION_SUPPORT *)feature_data; + if (feature_data_size < sizeof(*p)) + return E_INVALIDARG; + p->Support = D3D12_PROTECTED_RESOURCE_SESSION_SUPPORT_FLAG_NONE; + return S_OK; + } + default: + TRACE("CheckFeatureSupport UNHANDLED feature=%u size=%u -> zeroing and returning S_OK", feature, feature_data_size); + if (feature_data && feature_data_size > 0) + memset(feature_data, 0, feature_data_size); + return S_OK; + } +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::CreateDescriptorHeap(const D3D12_DESCRIPTOR_HEAP_DESC *desc, + REFIID riid, void **descriptor_heap) { + if (!desc || !descriptor_heap) + return E_POINTER; + CheckVtable("CreateDescriptorHeap"); + TRACE("CreateDescriptorHeap type=%u num=%u starting", desc->Type, desc->NumDescriptors); + InitReturnPtr(descriptor_heap); + + TRACE("CreateDescriptorHeap: about to allocate %u bytes for object", (unsigned)sizeof(MTLD3D12DescriptorHeap)); + void *raw = HeapAlloc(GetProcessHeap(), 0, sizeof(MTLD3D12DescriptorHeap)); + TRACE("CreateDescriptorHeap: HeapAlloc returned %p", raw); + if (!raw) { + TRACE("CreateDescriptorHeap: HeapAlloc for object FAILED"); + return E_OUTOFMEMORY; + } + TRACE("CreateDescriptorHeap: about to placement-new, sizeof=%u", (unsigned)sizeof(MTLD3D12DescriptorHeap)); + MTLD3D12DescriptorHeap *heap = new (raw) MTLD3D12DescriptorHeap(this, *desc); + TRACE("CreateDescriptorHeap: heap=%p data=%p", (void *)heap, heap->GetDescriptors()); + HRESULT hr = heap->QueryInterface(riid, descriptor_heap); + if (FAILED(hr)) { + 
heap->~MTLD3D12DescriptorHeap(); + HeapFree(GetProcessHeap(), 0, raw); + } + return hr; +} + +UINT STDMETHODCALLTYPE MTLD3D12Device::GetDescriptorHandleIncrementSize( + D3D12_DESCRIPTOR_HEAP_TYPE descriptor_heap_type) { + return sizeof(D3D12Descriptor); +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateRootSignature( + UINT node_mask, const void *bytecode, SIZE_T bytecode_length, + REFIID riid, void **root_signature) { + TRACE("CreateRootSignature len=%llu", (unsigned long long)bytecode_length); + if (!bytecode || !root_signature) + return E_POINTER; + InitReturnPtr(root_signature); + + auto rs = + new MTLD3D12RootSignature(this, bytecode, bytecode_length); + HRESULT hr = rs->QueryInterface(riid, root_signature); + if (FAILED(hr)) + rs->Release(); + TRACE("CreateRootSignature DONE hr=0x%lx rs=%p out=%p", hr, (void*)rs, root_signature ? *root_signature : nullptr); + return hr; +} + +void STDMETHODCALLTYPE MTLD3D12Device::CreateConstantBufferView( + const D3D12_CONSTANT_BUFFER_VIEW_DESC *desc, + D3D12_CPU_DESCRIPTOR_HANDLE descriptor) { + TRACE("CreateConstantBufferView"); + CheckVtable("CreateConstantBufferView"); + if (!desc) + return; + auto *d = reinterpret_cast<D3D12Descriptor *>(descriptor.ptr); + if (d) { + d->cbv = *desc; + d->type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + } +} + +void STDMETHODCALLTYPE MTLD3D12Device::CreateShaderResourceView( + ID3D12Resource *resource, const D3D12_SHADER_RESOURCE_VIEW_DESC *desc, + D3D12_CPU_DESCRIPTOR_HANDLE descriptor) { + TRACE("CreateShaderResourceView res=%p handle=0x%llx device=%p", (void*)resource, (unsigned long long)descriptor.ptr, (void*)this); + if ((void*)resource == (void*)this) { + TRACE("!!!
LEAK DETECTED: CreateShaderResourceView called with device pointer as resource!"); + } + CheckVtable("CreateShaderResourceView"); + auto *d = reinterpret_cast<D3D12Descriptor *>(descriptor.ptr); + if (d) { + d->resource = resource; + if (desc) + d->srv = *desc; + d->type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + } +} + +void STDMETHODCALLTYPE MTLD3D12Device::CreateUnorderedAccessView( + ID3D12Resource *resource, ID3D12Resource *counter_resource, + const D3D12_UNORDERED_ACCESS_VIEW_DESC *desc, + D3D12_CPU_DESCRIPTOR_HANDLE descriptor) { + TRACE("CreateUnorderedAccessView res=%p counter=%p handle=0x%llx device=%p", (void*)resource, (void*)counter_resource, (unsigned long long)descriptor.ptr, (void*)this); + if ((void*)resource == (void*)this || (void*)counter_resource == (void*)this) { + TRACE("!!! LEAK DETECTED: CreateUnorderedAccessView called with device pointer as resource!"); + } + CheckVtable("CreateUnorderedAccessView"); + auto *d = reinterpret_cast<D3D12Descriptor *>(descriptor.ptr); + if (d) { + d->resource = resource; + d->resource_uav_counter = counter_resource; + if (desc) + d->uav = *desc; + d->type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + } +} + +void STDMETHODCALLTYPE MTLD3D12Device::CreateRenderTargetView( + ID3D12Resource *resource, const D3D12_RENDER_TARGET_VIEW_DESC *desc, + D3D12_CPU_DESCRIPTOR_HANDLE descriptor) { + TRACE("CreateRenderTargetView res=%p device=%p", (void*)resource, (void*)this); + if ((void*)resource == (void*)this) { + TRACE("!!!
LEAK DETECTED: CreateRenderTargetView called with device pointer as resource!"); + } + auto *d = reinterpret_cast<D3D12Descriptor *>(descriptor.ptr); + if (d) { + d->resource = resource; + if (desc) + d->rtv = *desc; + d->type = D3D12_DESCRIPTOR_HEAP_TYPE_RTV; + } +} + +void STDMETHODCALLTYPE MTLD3D12Device::CreateDepthStencilView( + ID3D12Resource *resource, const D3D12_DEPTH_STENCIL_VIEW_DESC *desc, + D3D12_CPU_DESCRIPTOR_HANDLE descriptor) { + TRACE("CreateDepthStencilView res=%p device=%p", (void*)resource, (void*)this); + if ((void*)resource == (void*)this) { + TRACE("!!! LEAK DETECTED: CreateDepthStencilView called with device pointer as resource!"); + } + auto *d = reinterpret_cast<D3D12Descriptor *>(descriptor.ptr); + if (d) { + d->resource = resource; + if (desc) + d->dsv = *desc; + d->type = D3D12_DESCRIPTOR_HEAP_TYPE_DSV; + } +} + +void STDMETHODCALLTYPE +MTLD3D12Device::CreateSampler(const D3D12_SAMPLER_DESC *desc, + D3D12_CPU_DESCRIPTOR_HANDLE descriptor) { + auto *d = reinterpret_cast<D3D12Descriptor *>(descriptor.ptr); + if (d && desc) { + d->sampler = *desc; + d->type = D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER; + } +} + +void STDMETHODCALLTYPE MTLD3D12Device::CopyDescriptors( + UINT dst_descriptor_range_count, + const D3D12_CPU_DESCRIPTOR_HANDLE *dst_descriptor_range_offsets, + const UINT *dst_descriptor_range_sizes, + UINT src_descriptor_range_count, + const D3D12_CPU_DESCRIPTOR_HANDLE *src_descriptor_range_offsets, + const UINT *src_descriptor_range_sizes, + D3D12_DESCRIPTOR_HEAP_TYPE descriptor_heap_type) { + // Walk the destination and source range lists independently: the two lists may be split differently. + // A null size array means every range in that list holds exactly one descriptor. + UINT dst_range = 0, dst_off = 0; + UINT src_range = 0, src_off = 0; + while (dst_range < dst_descriptor_range_count && + src_range < src_descriptor_range_count) { + UINT dst_size = dst_descriptor_range_sizes ? dst_descriptor_range_sizes[dst_range] : 1; + UINT src_size = src_descriptor_range_sizes ? src_descriptor_range_sizes[src_range] : 1; + if (dst_off >= dst_size) { + dst_range++; + dst_off = 0; + continue; + } + if (src_off >= src_size) { + src_range++; + src_off = 0; + continue; + } + // Handles are spaced sizeof(D3D12Descriptor) apart, so typed pointer arithmetic indexes them directly. + auto *dst = reinterpret_cast<D3D12Descriptor *>( + dst_descriptor_range_offsets[dst_range].ptr) + dst_off++; + auto *src = reinterpret_cast<D3D12Descriptor *>( + src_descriptor_range_offsets[src_range].ptr) + src_off++; + if (src->resource && (void*)src->resource == (void*)this) { + TRACE("!!! CopyDescriptors: src descriptor at %p has device pointer as resource! copying to dst %p", (void*)src, (void*)dst); + } + *dst = *src; + } +} +
+void STDMETHODCALLTYPE MTLD3D12Device::CopyDescriptorsSimple( + UINT descriptor_count, + const D3D12_CPU_DESCRIPTOR_HANDLE dst_descriptor_range_offset, + const D3D12_CPU_DESCRIPTOR_HANDLE src_descriptor_range_offset, + D3D12_DESCRIPTOR_HEAP_TYPE descriptor_heap_type) { + CopyDescriptors(1, &dst_descriptor_range_offset, &descriptor_count, 1, + &src_descriptor_range_offset, &descriptor_count, + descriptor_heap_type); +} + +D3D12_RESOURCE_ALLOCATION_INFO* STDMETHODCALLTYPE +MTLD3D12Device::GetResourceAllocationInfo( + D3D12_RESOURCE_ALLOCATION_INFO *__ret, UINT visible_mask, + UINT resource_desc_count, + const D3D12_RESOURCE_DESC *resource_descs) { + __ret->Alignment = D3D12_DEFAULT_RESOURCE_PLACEMENT_ALIGNMENT; + __ret->SizeInBytes = 0; + for (UINT i = 0; i < resource_desc_count; i++) { + if (resource_descs[i].Dimension == D3D12_RESOURCE_DIMENSION_BUFFER) { + __ret->SizeInBytes += resource_descs[i].Width; + } else { + __ret->SizeInBytes += + resource_descs[i].Width * resource_descs[i].Height * + resource_descs[i].DepthOrArraySize; + } + } + return __ret; +} + +D3D12_HEAP_PROPERTIES* STDMETHODCALLTYPE +MTLD3D12Device::GetCustomHeapProperties(D3D12_HEAP_PROPERTIES *__ret, + UINT node_mask, + D3D12_HEAP_TYPE heap_type) { + __ret->Type = heap_type; + __ret->CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN; + __ret->MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN; + __ret->CreationNodeMask = 1; + __ret->VisibleNodeMask = 1; + return __ret; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateCommittedResource( + const D3D12_HEAP_PROPERTIES *heap_properties, D3D12_HEAP_FLAGS heap_flags, + const D3D12_RESOURCE_DESC *desc, D3D12_RESOURCE_STATES
initial_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, REFIID riid, + void **resource) { + TRACE("CreateCommittedResource dim=%u fmt=%u width=%llu state=0x%x heap_type=%u", desc ? desc->Dimension : 0xFF, desc ? desc->Format : 0, desc ? desc->Width : 0, initial_state, heap_properties ? heap_properties->Type : 0xFF); + CheckVtable("CreateCommittedResource"); + if (!desc || !resource) + return E_POINTER; + InitReturnPtr(resource); + + auto res = new MTLD3D12Resource(this, *desc, initial_state, + heap_properties ? *heap_properties + : D3D12_HEAP_PROPERTIES{}); + HRESULT hr = res->QueryInterface(riid, resource); + TRACE("CreateCommittedResource res_obj=%p out=%p hr=0x%lx", (void*)res, resource ? *resource : nullptr, hr); + if (resource && *resource == (void*)this) { + TRACE("!!! LEAK DETECTED: CreateCommittedResource returned device pointer %p as resource!", (void*)this); + } + if (FAILED(hr)) + res->Release(); + return hr; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::CreateHeap(const D3D12_HEAP_DESC *desc, REFIID riid, + void **heap) { + TRACE("CreateHeap size=%llu type=%u flags=0x%x", + desc ? (unsigned long long)desc->SizeInBytes : 0, + desc ? desc->Properties.Type : 0xFF, + desc ? desc->Flags : 0); + if (!desc || !heap) + return E_POINTER; + InitReturnPtr(heap); + + auto h = new MTLD3D12Heap(this, *desc); + HRESULT hr = h->QueryInterface(riid, heap); + if (FAILED(hr)) + h->Release(); + return hr; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreatePlacedResource( + ID3D12Heap *heap, UINT64 heap_offset, const D3D12_RESOURCE_DESC *desc, + D3D12_RESOURCE_STATES initial_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, REFIID riid, + void **resource) { + TRACE("CreatePlacedResource heap=%p offset=%llu dim=%u fmt=%u w=%llu", + (void*)heap, (unsigned long long)heap_offset, + desc ? desc->Dimension : 0, desc ? desc->Format : 0, + desc ? 
desc->Width : 0); + if (!desc || !resource || !heap) + return E_POINTER; + InitReturnPtr(resource); + + D3D12_HEAP_PROPERTIES heap_props = {}; + heap_props.Type = D3D12_HEAP_TYPE_DEFAULT; + auto mt_heap = static_cast<MTLD3D12Heap *>(heap); + if (mt_heap) { + heap_props = mt_heap->GetHeapDesc().Properties; + } + + auto res = new MTLD3D12Resource(this, *desc, initial_state, heap_props); + HRESULT hr = res->QueryInterface(riid, resource); + if (FAILED(hr)) + res->Release(); + return hr; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateReservedResource( + const D3D12_RESOURCE_DESC *desc, D3D12_RESOURCE_STATES initial_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, REFIID riid, + void **resource) { + TRACE("CreateReservedResource dim=%u fmt=%u w=%llu (fallback to committed)", desc ? desc->Dimension : 0, desc ? desc->Format : 0, desc ? desc->Width : 0); + CheckVtable("CreateReservedResource"); + if (!desc || !resource) + return E_POINTER; + InitReturnPtr(resource); + D3D12_HEAP_PROPERTIES heap_props = {}; + heap_props.Type = D3D12_HEAP_TYPE_DEFAULT; + auto res = new MTLD3D12Resource(this, *desc, initial_state, heap_props); + HRESULT hr = res->QueryInterface(riid, resource); + if (FAILED(hr)) + res->Release(); + return hr; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateSharedHandle( + ID3D12DeviceChild *object, const SECURITY_ATTRIBUTES *attributes, + DWORD access, const WCHAR *name, HANDLE *handle) { + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::OpenSharedHandle(HANDLE handle, REFIID riid, void **object) { + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::OpenSharedHandleByName(const WCHAR *name, DWORD access, + HANDLE *handle) { + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::MakeResident(UINT object_count, + ID3D12Pageable *const *objects) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::Evict(UINT object_count, ID3D12Pageable *const *objects) { + return S_OK; +} + +HRESULT
STDMETHODCALLTYPE +MTLD3D12Device::CreateFence(UINT64 initial_value, D3D12_FENCE_FLAGS flags, + REFIID riid, void **fence) { + if (!fence) + return E_POINTER; + InitReturnPtr(fence); + + auto f = new MTLD3D12Fence(this, initial_value, flags); + TRACE("CreateFence init=%llu fence=%p", (unsigned long long)initial_value, (void *)f); + HRESULT hr = f->QueryInterface(riid, fence); + if (FAILED(hr)) + f->Release(); + return hr; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::GetDeviceRemovedReason() { + return S_OK; +} + +void STDMETHODCALLTYPE MTLD3D12Device::GetCopyableFootprints( + const D3D12_RESOURCE_DESC *desc, UINT first_sub_resource, + UINT sub_resource_count, UINT64 base_offset, + D3D12_PLACED_SUBRESOURCE_FOOTPRINT *layouts, UINT *row_count, + UINT64 *row_size, UINT64 *total_bytes) { + if (total_bytes) + *total_bytes = 0; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::CreateQueryHeap(const D3D12_QUERY_HEAP_DESC *desc, + REFIID riid, void **heap) { + if (!desc || !heap) + return E_POINTER; + InitReturnPtr(heap); + + auto qh = new MTLD3D12QueryHeap(this, *desc); + HRESULT hr = qh->QueryInterface(riid, heap); + if (FAILED(hr)) + qh->Release(); + return hr; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Device::SetStablePowerState(WINBOOL enable) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateCommandSignature( + const D3D12_COMMAND_SIGNATURE_DESC *desc, + ID3D12RootSignature *root_signature, REFIID riid, + void **command_signature) { + if (!command_signature) + return E_POINTER; + InitReturnPtr(command_signature); + TRACE("CreateCommandSignature stride=%u num_args=%u", desc ? desc->ByteStride : 0, desc ? 
desc->NumArgumentDescs : 0); + if (!desc) + return E_INVALIDARG; + auto *obj = new MTLD3D12CommandSignature(this, *desc); + HRESULT hr = obj->QueryInterface(riid, command_signature); + if (FAILED(hr)) + obj->Release(); + return hr; +} + +void STDMETHODCALLTYPE MTLD3D12Device::GetResourceTiling( + ID3D12Resource *resource, UINT *total_tile_count, + D3D12_PACKED_MIP_INFO *packed_mip_info, + D3D12_TILE_SHAPE *standard_tile_shape, + UINT *sub_resource_tiling_count, UINT first_sub_resource_tiling, + D3D12_SUBRESOURCE_TILING *sub_resource_tilings) {} + +LUID* STDMETHODCALLTYPE MTLD3D12Device::GetAdapterLuid(LUID *__ret) { + *__ret = {}; + return __ret; +} + +void MTLD3D12Device::RegisterResource(MTLD3D12Resource *res) { + if (!res) return; + D3D12_GPU_VIRTUAL_ADDRESS addr = res->GetGPUVirtualAddress(); + if (addr) { + std::lock_guard lock(m_resource_mutex); + m_resources_by_gpu_addr[addr] = res; + } +} + +void MTLD3D12Device::UnregisterResource(MTLD3D12Resource *res) { + if (!res) return; + D3D12_GPU_VIRTUAL_ADDRESS addr = res->GetGPUVirtualAddress(); + if (addr) { + std::lock_guard lock(m_resource_mutex); + m_resources_by_gpu_addr.erase(addr); + } +} + +MTLD3D12Resource *MTLD3D12Device::LookupResourceByGPUAddress(D3D12_GPU_VIRTUAL_ADDRESS addr) { + if (!addr) return nullptr; + std::lock_guard lock(m_resource_mutex); + auto it = m_resources_by_gpu_addr.find(addr); + if (it != m_resources_by_gpu_addr.end()) + return it->second; + for (auto &[gpu_addr, res] : m_resources_by_gpu_addr) { + auto *desc = res->GetDesc({}); + if (desc && desc->Dimension == D3D12_RESOURCE_DIMENSION_BUFFER) { + if (addr >= gpu_addr && addr < gpu_addr + desc->Width) + return res; + } + } + return nullptr; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreatePipelineLibrary( + const void *blob, SIZE_T blob_size, REFIID riid, void **lib) { + TRACE("CreatePipelineLibrary -> DXGI_ERROR_UNSUPPORTED"); + if (lib) *lib = nullptr; + return DXGI_ERROR_UNSUPPORTED; +} + +HRESULT STDMETHODCALLTYPE 
MTLD3D12Device::SetEventOnMultipleFenceCompletion( + ID3D12Fence *const *fences, const UINT64 *values, UINT fence_count, + D3D12_MULTIPLE_FENCE_WAIT_FLAGS flags, HANDLE event) { + TRACE("SetEventOnMultipleFenceCompletion count=%u flags=0x%x event=%p", + fence_count, flags, (void*)(uintptr_t)event); + if (!fences || !values || !event) + return E_POINTER; + + bool all_signaled = true; + for (UINT i = 0; i < fence_count; i++) { + if (fences[i]->GetCompletedValue() < values[i]) { + all_signaled = false; + break; + } + } + + if (all_signaled) { + SetEvent(event); + return S_OK; + } + + if (flags == D3D12_MULTIPLE_FENCE_WAIT_FLAG_ALL) { + for (UINT i = 0; i < fence_count; i++) { + fences[i]->SetEventOnCompletion(values[i], nullptr); + } + SetEvent(event); + } else { + for (UINT i = 0; i < fence_count; i++) { + if (fences[i]->GetCompletedValue() < values[i]) { + fences[i]->SetEventOnCompletion(values[i], nullptr); + SetEvent(event); + break; + } + } + } + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::SetResidencyPriority( + UINT object_count, ID3D12Pageable *const *objects, + const D3D12_RESIDENCY_PRIORITY *priorities) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreatePipelineState( + const D3D12_PIPELINE_STATE_STREAM_DESC *desc, REFIID riid, + void **ppPipelineState) { + TRACE("ID3D12Device2::CreatePipelineState ENTER: size=%zu", desc ? 
desc->SizeInBytes : 0); + + if (!desc || !desc->pPipelineStateSubobjectStream || !ppPipelineState) + return E_INVALIDARG; + + *ppPipelineState = nullptr; + + struct SubobjectHeader { + UINT Type; + UINT Size; + }; + + auto *stream = (uint8_t *)desc->pPipelineStateSubobjectStream; + auto *end = stream + desc->SizeInBytes; + + D3D12_GRAPHICS_PIPELINE_STATE_DESC graphics_desc = {}; + D3D12_COMPUTE_PIPELINE_STATE_DESC compute_desc = {}; + bool has_cs = false; + bool has_vs = false; + bool is_compute = true; + + while (stream + sizeof(SubobjectHeader) <= end) { + auto *header = (SubobjectHeader *)stream; + stream += sizeof(SubobjectHeader); + + if (stream + header->Size > end) + break; + + switch (header->Type) { + case 0: { // D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_ROOT_SIGNATURE + if (header->Size >= sizeof(ID3D12RootSignature*)) + graphics_desc.pRootSignature = *(ID3D12RootSignature**)stream; + break; + } + case 1: { // VS + if (header->Size >= sizeof(D3D12_SHADER_BYTECODE)) { + graphics_desc.VS = *(D3D12_SHADER_BYTECODE*)stream; + has_vs = true; + is_compute = false; + } + break; + } + case 2: { // PS + if (header->Size >= sizeof(D3D12_SHADER_BYTECODE)) + graphics_desc.PS = *(D3D12_SHADER_BYTECODE*)stream; + break; + } + case 3: { // DS + if (header->Size >= sizeof(D3D12_SHADER_BYTECODE)) + graphics_desc.DS = *(D3D12_SHADER_BYTECODE*)stream; + break; + } + case 4: { // HS + if (header->Size >= sizeof(D3D12_SHADER_BYTECODE)) + graphics_desc.HS = *(D3D12_SHADER_BYTECODE*)stream; + break; + } + case 5: { // GS + if (header->Size >= sizeof(D3D12_SHADER_BYTECODE)) + graphics_desc.GS = *(D3D12_SHADER_BYTECODE*)stream; + break; + } + case 6: { // CS + if (header->Size >= sizeof(D3D12_SHADER_BYTECODE)) { + compute_desc.CS = *(D3D12_SHADER_BYTECODE*)stream; + has_cs = true; + } + break; + } + case 7: { // STREAM_OUTPUT + if (header->Size >= sizeof(D3D12_STREAM_OUTPUT_DESC)) + graphics_desc.StreamOutput = *(D3D12_STREAM_OUTPUT_DESC*)stream; + break; + } + case 8: { // BLEND + if 
(header->Size >= sizeof(D3D12_BLEND_DESC)) + graphics_desc.BlendState = *(D3D12_BLEND_DESC*)stream; + break; + } + case 9: { // SAMPLE_MASK + if (header->Size >= sizeof(UINT)) + graphics_desc.SampleMask = *(UINT*)stream; + break; + } + case 10: { // RASTERIZER + if (header->Size >= sizeof(D3D12_RASTERIZER_DESC)) + graphics_desc.RasterizerState = *(D3D12_RASTERIZER_DESC*)stream; + break; + } + case 11: { // DEPTH_STENCIL + if (header->Size >= sizeof(D3D12_DEPTH_STENCIL_DESC)) + graphics_desc.DepthStencilState = *(D3D12_DEPTH_STENCIL_DESC*)stream; + break; + } + case 12: { // INPUT_LAYOUT + if (header->Size >= sizeof(D3D12_INPUT_LAYOUT_DESC)) + graphics_desc.InputLayout = *(D3D12_INPUT_LAYOUT_DESC*)stream; + break; + } + case 13: { // IB_STRIP_CUT_VALUE + if (header->Size >= sizeof(D3D12_INDEX_BUFFER_STRIP_CUT_VALUE)) + graphics_desc.IBStripCutValue = *(D3D12_INDEX_BUFFER_STRIP_CUT_VALUE*)stream; + break; + } + case 14: { // PRIMITIVE_TOPOLOGY + if (header->Size >= sizeof(D3D12_PRIMITIVE_TOPOLOGY_TYPE)) + graphics_desc.PrimitiveTopologyType = *(D3D12_PRIMITIVE_TOPOLOGY_TYPE*)stream; + break; + } + case 15: { // RENDER_TARGET_FORMATS + struct RTVFormats { UINT NumRenderTargets; DXGI_FORMAT RTFormats[8]; }; + if (header->Size >= sizeof(RTVFormats)) { + auto *fmt = (RTVFormats*)stream; + graphics_desc.NumRenderTargets = fmt->NumRenderTargets; + for (UINT i = 0; i < 8 && i < fmt->NumRenderTargets; i++) + graphics_desc.RTVFormats[i] = fmt->RTFormats[i]; + } + break; + } + case 16: { // DEPTH_STENCIL_FORMAT + if (header->Size >= sizeof(DXGI_FORMAT)) + graphics_desc.DSVFormat = *(DXGI_FORMAT*)stream; + break; + } + case 17: { // SAMPLE_DESC + if (header->Size >= sizeof(DXGI_SAMPLE_DESC)) + graphics_desc.SampleDesc = *(DXGI_SAMPLE_DESC*)stream; + break; + } + case 18: { // NODE_MASK + if (header->Size >= sizeof(UINT)) + graphics_desc.NodeMask = *(UINT*)stream; + break; + } + case 19: { // CACHED_PSO + if (header->Size >= sizeof(D3D12_CACHED_PIPELINE_STATE)) + 
graphics_desc.CachedPSO = *(D3D12_CACHED_PIPELINE_STATE*)stream; + break; + } + case 20: { // FLAGS + if (header->Size >= sizeof(D3D12_PIPELINE_STATE_FLAGS)) + graphics_desc.Flags = *(D3D12_PIPELINE_STATE_FLAGS*)stream; + break; + } + default: + TRACE("CreatePipelineState: unknown subobject type %u size %u", header->Type, header->Size); + break; + } + stream += header->Size; + } + + if (has_cs && is_compute) { + compute_desc.pRootSignature = graphics_desc.pRootSignature; + TRACE("ID3D12Device2::CreatePipelineState -> delegating to CreateComputePSO CS=%p", compute_desc.CS.pShaderBytecode); + return CreateComputePipelineState(&compute_desc, riid, ppPipelineState); + } + + TRACE("ID3D12Device2::CreatePipelineState -> delegating to CreateGraphicsPSO VS=%p PS=%p NumRT=%u", + graphics_desc.VS.pShaderBytecode, graphics_desc.PS.pShaderBytecode, graphics_desc.NumRenderTargets); + return CreateGraphicsPipelineState(&graphics_desc, riid, ppPipelineState); +} + +/*** ID3D12Device3 ***/ +HRESULT STDMETHODCALLTYPE MTLD3D12Device::OpenExistingHeapFromAddress( + const void *address, REFIID riid, void **heap) { + TRACE("ID3D12Device3::OpenExistingHeapFromAddress -> E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::OpenExistingHeapFromFileMapping( + HANDLE file_mapping, REFIID riid, void **heap) { + TRACE("ID3D12Device3::OpenExistingHeapFromFileMapping -> E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::EnqueueMakeResident( + D3D12_RESIDENCY_FLAGS flags, UINT num_objects, + ID3D12Pageable *const *objects, ID3D12Fence *fence, + UINT64 fence_value) { + TRACE("ID3D12Device3::EnqueueMakeResident -> S_OK (delegating to MakeResident)"); + HRESULT hr = MakeResident(num_objects, objects); + if (SUCCEEDED(hr) && fence) { + fence->Signal(fence_value); + } + return hr; +} + +/*** ID3D12Device4 ***/ +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateCommandList1( + UINT node_mask, D3D12_COMMAND_LIST_TYPE type, + 
D3D12_COMMAND_LIST_FLAGS flags, REFIID riid, + void **command_list) { + TRACE("ID3D12Device4::CreateCommandList1 -> delegating to CreateCommandList"); + return CreateCommandList(node_mask, type, nullptr, nullptr, riid, command_list); +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateProtectedResourceSession( + const D3D12_PROTECTED_RESOURCE_SESSION_DESC *desc, REFIID riid, + void **session) { + TRACE("ID3D12Device4::CreateProtectedResourceSession -> E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateCommittedResource1( + const D3D12_HEAP_PROPERTIES *heap_properties, + D3D12_HEAP_FLAGS heap_flags, const D3D12_RESOURCE_DESC *desc, + D3D12_RESOURCE_STATES initial_resource_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, + ID3D12ProtectedResourceSession *protected_session, + REFIID riid_resource, void **resource) { + if (protected_session) { + TRACE("ID3D12Device4::CreateCommittedResource1 -> E_NOTIMPL (protected session)"); + return E_NOTIMPL; + } + return CreateCommittedResource(heap_properties, heap_flags, desc, + initial_resource_state, optimized_clear_value, riid_resource, resource); +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateHeap1( + const D3D12_HEAP_DESC *desc, + ID3D12ProtectedResourceSession *protected_session, + REFIID riid, void **heap) { + if (protected_session) { + TRACE("ID3D12Device4::CreateHeap1 -> E_NOTIMPL (protected session)"); + return E_NOTIMPL; + } + return CreateHeap(desc, riid, heap); +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateReservedResource1( + const D3D12_RESOURCE_DESC *desc, D3D12_RESOURCE_STATES initial_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, + ID3D12ProtectedResourceSession *protected_session, + REFIID riid, void **resource) { + if (protected_session) { + TRACE("ID3D12Device4::CreateReservedResource1 -> E_NOTIMPL (protected session)"); + return E_NOTIMPL; + } + return CreateReservedResource(desc, initial_state, optimized_clear_value, riid, resource); +} 
+ +D3D12_RESOURCE_ALLOCATION_INFO* STDMETHODCALLTYPE MTLD3D12Device::GetResourceAllocationInfo1( + D3D12_RESOURCE_ALLOCATION_INFO *__ret, UINT visible_mask, + UINT resource_descs_count, const D3D12_RESOURCE_DESC *resource_descs, + D3D12_RESOURCE_ALLOCATION_INFO1 *resource_allocation_info1) { + TRACE("ID3D12Device4::GetResourceAllocationInfo1 -> delegating"); + return GetResourceAllocationInfo(__ret, visible_mask, resource_descs_count, resource_descs); +} + +/*** ID3D12Device5 ***/ +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateLifetimeTracker( + ID3D12LifetimeOwner *owner, REFIID riid, void **tracker) { + TRACE("ID3D12Device5::CreateLifetimeTracker -> E_NOTIMPL"); + return E_NOTIMPL; +} + +void STDMETHODCALLTYPE MTLD3D12Device::RemoveDevice() { + TRACE("ID3D12Device5::RemoveDevice"); +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::EnumerateMetaCommands( + UINT *meta_commands_count, D3D12_META_COMMAND_DESC *descs) { + TRACE("ID3D12Device5::EnumerateMetaCommands -> E_NOTIMPL"); + if (meta_commands_count) *meta_commands_count = 0; + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::EnumerateMetaCommandParameters( + REFGUID command_id, D3D12_META_COMMAND_PARAMETER_STAGE stage, + UINT *total_structure_size_in_bytes, UINT *parameter_count, + D3D12_META_COMMAND_PARAMETER_DESC *parameter_descs) { + TRACE("ID3D12Device5::EnumerateMetaCommandParameters -> E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateMetaCommand( + REFGUID command_id, UINT node_mask, + const void *creation_parameters_data, + SIZE_T creation_parameters_data_size_in_bytes, + REFIID riid, void **meta_command) { + TRACE("ID3D12Device5::CreateMetaCommand -> E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateStateObject( + const D3D12_STATE_OBJECT_DESC *desc, REFIID riid, + void **state_object) { + TRACE("ID3D12Device5::CreateStateObject -> E_NOTIMPL"); + return E_NOTIMPL; +} + +void STDMETHODCALLTYPE 
MTLD3D12Device::GetRaytracingAccelerationStructurePrebuildInfo( + const D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS *desc, + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO *info) { + TRACE("ID3D12Device5::GetRaytracingAccelerationStructurePrebuildInfo"); + if (info) { + memset(info, 0, sizeof(*info)); + } +} + +D3D12_DRIVER_MATCHING_IDENTIFIER_STATUS STDMETHODCALLTYPE MTLD3D12Device::CheckDriverMatchingIdentifier( + D3D12_SERIALIZED_DATA_TYPE serialized_data_type, + const D3D12_SERIALIZED_DATA_DRIVER_MATCHING_IDENTIFIER *identifier_to_check) { + TRACE("ID3D12Device5::CheckDriverMatchingIdentifier -> UNRECOGNIZED"); + return D3D12_DRIVER_MATCHING_IDENTIFIER_UNRECOGNIZED; +} + +/*** ID3D12Device6 ***/ +HRESULT STDMETHODCALLTYPE MTLD3D12Device::SetBackgroundProcessingMode( + D3D12_BACKGROUND_PROCESSING_MODE mode, + D3D12_MEASUREMENTS_ACTION action, HANDLE event, + WINBOOL *further_measurements_desired) { + TRACE("ID3D12Device6::SetBackgroundProcessingMode -> E_NOTIMPL"); + return E_NOTIMPL; +} + +/*** ID3D12Device7 ***/ +HRESULT STDMETHODCALLTYPE MTLD3D12Device::AddToStateObject( + const D3D12_STATE_OBJECT_DESC *addition, + ID3D12StateObject *state_object_to_grow_from, + REFIID riid, void **new_state_object) { + TRACE("ID3D12Device7::AddToStateObject -> E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateProtectedResourceSession1( + const D3D12_PROTECTED_RESOURCE_SESSION_DESC1 *desc, + REFIID riid, void **session) { + TRACE("ID3D12Device7::CreateProtectedResourceSession1 -> E_NOTIMPL"); + return E_NOTIMPL; +} + +/*** ID3D12Device8 ***/ +static const int MAX_DESCS = 256; + +D3D12_RESOURCE_ALLOCATION_INFO* STDMETHODCALLTYPE MTLD3D12Device::GetResourceAllocationInfo2( + D3D12_RESOURCE_ALLOCATION_INFO *__ret, UINT visible_mask, + UINT resource_descs_count, const D3D12_RESOURCE_DESC1 *resource_descs, + D3D12_RESOURCE_ALLOCATION_INFO1 *resource_allocation_info1) { + TRACE("ID3D12Device8::GetResourceAllocationInfo2 -> 
delegating"); + // Clamp to the scratch array so the delegate never reads uninitialized entries past MAX_DESCS. + UINT count = resource_descs_count < (UINT)MAX_DESCS ? resource_descs_count : (UINT)MAX_DESCS; + D3D12_RESOURCE_DESC descs_compat[MAX_DESCS]; + for (UINT i = 0; i < count; i++) { + memcpy(&descs_compat[i], &resource_descs[i], sizeof(D3D12_RESOURCE_DESC)); + } + return GetResourceAllocationInfo(__ret, visible_mask, count, descs_compat); +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateCommittedResource2( + const D3D12_HEAP_PROPERTIES *heap_properties, + D3D12_HEAP_FLAGS heap_flags, const D3D12_RESOURCE_DESC1 *desc, + D3D12_RESOURCE_STATES initial_resource_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, + ID3D12ProtectedResourceSession *protected_session, + REFIID riid_resource, void **resource) { + if (protected_session) { + TRACE("ID3D12Device8::CreateCommittedResource2 -> E_NOTIMPL (protected session)"); + return E_NOTIMPL; + } + // Guard before memcpy: the delegate's own null check would come too late. + if (!desc) + return E_POINTER; + D3D12_RESOURCE_DESC desc_compat; + memcpy(&desc_compat, desc, sizeof(D3D12_RESOURCE_DESC)); + return CreateCommittedResource(heap_properties, heap_flags, &desc_compat, + initial_resource_state, optimized_clear_value, riid_resource, resource); +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreatePlacedResource1( + ID3D12Heap *heap, UINT64 heap_offset, + const D3D12_RESOURCE_DESC1 *desc, + D3D12_RESOURCE_STATES initial_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, + REFIID riid, void **resource) { + // Guard before memcpy: the delegate's own null check would come too late. + if (!desc) + return E_POINTER; + D3D12_RESOURCE_DESC desc_compat; + memcpy(&desc_compat, desc, sizeof(D3D12_RESOURCE_DESC)); + return CreatePlacedResource(heap, heap_offset, &desc_compat, + initial_state, optimized_clear_value, riid, resource); +} + +void STDMETHODCALLTYPE MTLD3D12Device::CreateSamplerFeedbackUnorderedAccessView( + ID3D12Resource *targeted_resource, ID3D12Resource *feedback_resource, + D3D12_CPU_DESCRIPTOR_HANDLE dst_descriptor) { + TRACE("ID3D12Device8::CreateSamplerFeedbackUnorderedAccessView -> noop"); +} + +void STDMETHODCALLTYPE MTLD3D12Device::GetCopyableFootprints1( + const D3D12_RESOURCE_DESC1 *resource_desc, UINT first_subresource, + UINT subresources_count,
UINT64 base_offset, + D3D12_PLACED_SUBRESOURCE_FOOTPRINT *layouts, UINT *rows_count, + UINT64 *row_size_in_bytes, UINT64 *total_bytes) { + D3D12_RESOURCE_DESC desc_compat; + memcpy(&desc_compat, resource_desc, sizeof(D3D12_RESOURCE_DESC)); + GetCopyableFootprints(&desc_compat, first_subresource, subresources_count, + base_offset, layouts, rows_count, row_size_in_bytes, total_bytes); +} + +/*** ID3D12Device9 ***/ +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateShaderCacheSession( + const D3D12_SHADER_CACHE_SESSION_DESC *desc, REFIID riid, + void **session) { + TRACE("ID3D12Device9::CreateShaderCacheSession -> E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::ShaderCacheControl( + D3D12_SHADER_CACHE_KIND_FLAGS kinds, + D3D12_SHADER_CACHE_CONTROL_FLAGS control) { + TRACE("ID3D12Device9::ShaderCacheControl -> E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Device::CreateCommandQueue1( + const D3D12_COMMAND_QUEUE_DESC *desc, REFIID creator_id, + REFIID riid, void **command_queue) { + TRACE("ID3D12Device9::CreateCommandQueue1 -> delegating to CreateCommandQueue"); + return CreateCommandQueue(desc, riid, command_queue); +} + +} // namespace dxmt diff --git a/src/d3d12/d3d12_device.hpp b/src/d3d12/d3d12_device.hpp new file mode 100644 index 000000000..9a0fd1e5f --- /dev/null +++ b/src/d3d12/d3d12_device.hpp @@ -0,0 +1,364 @@ +#pragma once + +#include "com/com_pointer.hpp" +#include "d3d12.h" +#include "dxgi_interfaces.h" +#include "dxmt_device.hpp" +#include +#include +#include +#include + +namespace dxmt { + +class MTLD3D12Resource; + +class MTLD3D12Device : public ID3D12Device9 { +public: + MTLD3D12Device(std::unique_ptr &&device, + IMTLDXGIAdapter *pAdapter); + ~MTLD3D12Device(); + + void *operator new(size_t size) { + void *ptr = VirtualAlloc(nullptr, size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); + if (!ptr) throw std::bad_alloc(); + return ptr; + } + void operator delete(void *ptr) { + VirtualFree(ptr, 0, 
MEM_RELEASE); + } + + void SetDXGIDevice(IMTLDXGIDevice *dxgi_device) { m_dxgi_device = dxgi_device; } + + WMT::Device GetMTLDevice(); + Device &GetDXMTDevice(); + + void RegisterResource(MTLD3D12Resource *res); + void UnregisterResource(MTLD3D12Resource *res); + MTLD3D12Resource *LookupResourceByGPUAddress(D3D12_GPU_VIRTUAL_ADDRESS addr); + + /*** IUnknown ***/ + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, + void **ppvObject) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + /*** ID3D12Object ***/ + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID guid, UINT *data_size, + void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateData(REFGUID guid, UINT data_size, + const void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface( + REFGUID guid, const IUnknown *data) override; + HRESULT STDMETHODCALLTYPE SetName(LPCWSTR name) override; + + /*** ID3D12Device ***/ + UINT STDMETHODCALLTYPE GetNodeCount() override; + + HRESULT STDMETHODCALLTYPE CreateCommandQueue( + const D3D12_COMMAND_QUEUE_DESC *desc, REFIID riid, + void **command_queue) override; + + HRESULT STDMETHODCALLTYPE CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE type, REFIID riid, + void **command_allocator) override; + + HRESULT STDMETHODCALLTYPE CreateGraphicsPipelineState( + const D3D12_GRAPHICS_PIPELINE_STATE_DESC *desc, REFIID riid, + void **pipeline_state) override; + + HRESULT STDMETHODCALLTYPE CreateComputePipelineState( + const D3D12_COMPUTE_PIPELINE_STATE_DESC *desc, REFIID riid, + void **pipeline_state) override; + + HRESULT STDMETHODCALLTYPE CreateCommandList( + UINT node_mask, D3D12_COMMAND_LIST_TYPE type, + ID3D12CommandAllocator *command_allocator, + ID3D12PipelineState *initial_pipeline_state, REFIID riid, + void **command_list) override; + + HRESULT STDMETHODCALLTYPE CheckFeatureSupport( + D3D12_FEATURE feature, void *feature_data, + UINT feature_data_size) override; + + HRESULT STDMETHODCALLTYPE 
CreateDescriptorHeap( + const D3D12_DESCRIPTOR_HEAP_DESC *desc, REFIID riid, + void **descriptor_heap) override; + + UINT STDMETHODCALLTYPE GetDescriptorHandleIncrementSize( + D3D12_DESCRIPTOR_HEAP_TYPE descriptor_heap_type) override; + + HRESULT STDMETHODCALLTYPE CreateRootSignature( + UINT node_mask, const void *bytecode, SIZE_T bytecode_length, + REFIID riid, void **root_signature) override; + + void STDMETHODCALLTYPE CreateConstantBufferView( + const D3D12_CONSTANT_BUFFER_VIEW_DESC *desc, + D3D12_CPU_DESCRIPTOR_HANDLE descriptor) override; + + void STDMETHODCALLTYPE CreateShaderResourceView( + ID3D12Resource *resource, + const D3D12_SHADER_RESOURCE_VIEW_DESC *desc, + D3D12_CPU_DESCRIPTOR_HANDLE descriptor) override; + + void STDMETHODCALLTYPE CreateUnorderedAccessView( + ID3D12Resource *resource, ID3D12Resource *counter_resource, + const D3D12_UNORDERED_ACCESS_VIEW_DESC *desc, + D3D12_CPU_DESCRIPTOR_HANDLE descriptor) override; + + void STDMETHODCALLTYPE CreateRenderTargetView( + ID3D12Resource *resource, + const D3D12_RENDER_TARGET_VIEW_DESC *desc, + D3D12_CPU_DESCRIPTOR_HANDLE descriptor) override; + + void STDMETHODCALLTYPE CreateDepthStencilView( + ID3D12Resource *resource, + const D3D12_DEPTH_STENCIL_VIEW_DESC *desc, + D3D12_CPU_DESCRIPTOR_HANDLE descriptor) override; + + void STDMETHODCALLTYPE CreateSampler( + const D3D12_SAMPLER_DESC *desc, + D3D12_CPU_DESCRIPTOR_HANDLE descriptor) override; + + void STDMETHODCALLTYPE CopyDescriptors( + UINT dst_descriptor_range_count, + const D3D12_CPU_DESCRIPTOR_HANDLE *dst_descriptor_range_offsets, + const UINT *dst_descriptor_range_sizes, + UINT src_descriptor_range_count, + const D3D12_CPU_DESCRIPTOR_HANDLE *src_descriptor_range_offsets, + const UINT *src_descriptor_range_sizes, + D3D12_DESCRIPTOR_HEAP_TYPE descriptor_heap_type) override; + + void STDMETHODCALLTYPE CopyDescriptorsSimple( + UINT descriptor_count, + const D3D12_CPU_DESCRIPTOR_HANDLE dst_descriptor_range_offset, + const D3D12_CPU_DESCRIPTOR_HANDLE 
src_descriptor_range_offset, + D3D12_DESCRIPTOR_HEAP_TYPE descriptor_heap_type) override; + + D3D12_RESOURCE_ALLOCATION_INFO* STDMETHODCALLTYPE GetResourceAllocationInfo( + D3D12_RESOURCE_ALLOCATION_INFO *__ret, UINT visible_mask, + UINT resource_desc_count, + const D3D12_RESOURCE_DESC *resource_descs) override; + + D3D12_HEAP_PROPERTIES* STDMETHODCALLTYPE GetCustomHeapProperties( + D3D12_HEAP_PROPERTIES *__ret, UINT node_mask, + D3D12_HEAP_TYPE heap_type) override; + + HRESULT STDMETHODCALLTYPE CreateCommittedResource( + const D3D12_HEAP_PROPERTIES *heap_properties, + D3D12_HEAP_FLAGS heap_flags, const D3D12_RESOURCE_DESC *desc, + D3D12_RESOURCE_STATES initial_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, REFIID riid, + void **resource) override; + + HRESULT STDMETHODCALLTYPE CreateHeap(const D3D12_HEAP_DESC *desc, + REFIID riid, + void **heap) override; + + HRESULT STDMETHODCALLTYPE CreatePlacedResource( + ID3D12Heap *heap, UINT64 heap_offset, + const D3D12_RESOURCE_DESC *desc, D3D12_RESOURCE_STATES initial_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, REFIID riid, + void **resource) override; + + HRESULT STDMETHODCALLTYPE CreateReservedResource( + const D3D12_RESOURCE_DESC *desc, D3D12_RESOURCE_STATES initial_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, REFIID riid, + void **resource) override; + + HRESULT STDMETHODCALLTYPE CreateSharedHandle( + ID3D12DeviceChild *object, const SECURITY_ATTRIBUTES *attributes, + DWORD access, const WCHAR *name, HANDLE *handle) override; + + HRESULT STDMETHODCALLTYPE OpenSharedHandle(HANDLE handle, REFIID riid, + void **object) override; + + HRESULT STDMETHODCALLTYPE OpenSharedHandleByName(const WCHAR *name, + DWORD access, + HANDLE *handle) override; + + HRESULT STDMETHODCALLTYPE MakeResident( + UINT object_count, ID3D12Pageable *const *objects) override; + + HRESULT STDMETHODCALLTYPE Evict(UINT object_count, + ID3D12Pageable *const *objects) override; + + HRESULT STDMETHODCALLTYPE 
CreateFence(UINT64 initial_value, + D3D12_FENCE_FLAGS flags, REFIID riid, + void **fence) override; + + HRESULT STDMETHODCALLTYPE GetDeviceRemovedReason() override; + + void STDMETHODCALLTYPE GetCopyableFootprints( + const D3D12_RESOURCE_DESC *desc, UINT first_sub_resource, + UINT sub_resource_count, UINT64 base_offset, + D3D12_PLACED_SUBRESOURCE_FOOTPRINT *layouts, UINT *row_count, + UINT64 *row_size, UINT64 *total_bytes) override; + + HRESULT STDMETHODCALLTYPE CreateQueryHeap(const D3D12_QUERY_HEAP_DESC *desc, + REFIID riid, + void **heap) override; + + HRESULT STDMETHODCALLTYPE SetStablePowerState(WINBOOL enable) override; + + HRESULT STDMETHODCALLTYPE CreateCommandSignature( + const D3D12_COMMAND_SIGNATURE_DESC *desc, + ID3D12RootSignature *root_signature, REFIID riid, + void **command_signature) override; + + void STDMETHODCALLTYPE GetResourceTiling( + ID3D12Resource *resource, UINT *total_tile_count, + D3D12_PACKED_MIP_INFO *packed_mip_info, + D3D12_TILE_SHAPE *standard_tile_shape, + UINT *sub_resource_tiling_count, + UINT first_sub_resource_tiling, + D3D12_SUBRESOURCE_TILING *sub_resource_tilings) override; + + LUID* STDMETHODCALLTYPE GetAdapterLuid(LUID *__ret) override; + + HRESULT STDMETHODCALLTYPE CreatePipelineLibrary( + const void *blob, SIZE_T blob_size, REFIID riid, + void **lib) override; + + HRESULT STDMETHODCALLTYPE SetEventOnMultipleFenceCompletion( + ID3D12Fence *const *fences, const UINT64 *values, UINT fence_count, + D3D12_MULTIPLE_FENCE_WAIT_FLAGS flags, HANDLE event) override; + + HRESULT STDMETHODCALLTYPE SetResidencyPriority( + UINT object_count, ID3D12Pageable *const *objects, + const D3D12_RESIDENCY_PRIORITY *priorities) override; + + /*** ID3D12Device2 ***/ + HRESULT STDMETHODCALLTYPE CreatePipelineState( + const D3D12_PIPELINE_STATE_STREAM_DESC *desc, REFIID riid, + void **ppPipelineState) override; + + /*** ID3D12Device3 ***/ + HRESULT STDMETHODCALLTYPE OpenExistingHeapFromAddress( + const void *address, REFIID riid, void **heap) 
override; + HRESULT STDMETHODCALLTYPE OpenExistingHeapFromFileMapping( + HANDLE file_mapping, REFIID riid, void **heap) override; + HRESULT STDMETHODCALLTYPE EnqueueMakeResident( + D3D12_RESIDENCY_FLAGS flags, UINT num_objects, + ID3D12Pageable *const *objects, ID3D12Fence *fence, + UINT64 fence_value) override; + + /*** ID3D12Device4 ***/ + HRESULT STDMETHODCALLTYPE CreateCommandList1( + UINT node_mask, D3D12_COMMAND_LIST_TYPE type, + D3D12_COMMAND_LIST_FLAGS flags, REFIID riid, + void **command_list) override; + HRESULT STDMETHODCALLTYPE CreateProtectedResourceSession( + const D3D12_PROTECTED_RESOURCE_SESSION_DESC *desc, REFIID riid, + void **session) override; + HRESULT STDMETHODCALLTYPE CreateCommittedResource1( + const D3D12_HEAP_PROPERTIES *heap_properties, + D3D12_HEAP_FLAGS heap_flags, const D3D12_RESOURCE_DESC *desc, + D3D12_RESOURCE_STATES initial_resource_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, + ID3D12ProtectedResourceSession *protected_session, + REFIID riid_resource, void **resource) override; + HRESULT STDMETHODCALLTYPE CreateHeap1(const D3D12_HEAP_DESC *desc, + ID3D12ProtectedResourceSession *protected_session, + REFIID riid, void **heap) override; + HRESULT STDMETHODCALLTYPE CreateReservedResource1( + const D3D12_RESOURCE_DESC *desc, D3D12_RESOURCE_STATES initial_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, + ID3D12ProtectedResourceSession *protected_session, + REFIID riid, void **resource) override; + D3D12_RESOURCE_ALLOCATION_INFO* STDMETHODCALLTYPE GetResourceAllocationInfo1( + D3D12_RESOURCE_ALLOCATION_INFO *__ret, UINT visible_mask, + UINT resource_descs_count, const D3D12_RESOURCE_DESC *resource_descs, + D3D12_RESOURCE_ALLOCATION_INFO1 *resource_allocation_info1) override; + + /*** ID3D12Device5 ***/ + HRESULT STDMETHODCALLTYPE CreateLifetimeTracker( + ID3D12LifetimeOwner *owner, REFIID riid, void **tracker) override; + void STDMETHODCALLTYPE RemoveDevice() override; + HRESULT STDMETHODCALLTYPE 
EnumerateMetaCommands( + UINT *meta_commands_count, D3D12_META_COMMAND_DESC *descs) override; + HRESULT STDMETHODCALLTYPE EnumerateMetaCommandParameters( + REFGUID command_id, D3D12_META_COMMAND_PARAMETER_STAGE stage, + UINT *total_structure_size_in_bytes, UINT *parameter_count, + D3D12_META_COMMAND_PARAMETER_DESC *parameter_descs) override; + HRESULT STDMETHODCALLTYPE CreateMetaCommand( + REFGUID command_id, UINT node_mask, + const void *creation_parameters_data, + SIZE_T creation_parameters_data_size_in_bytes, + REFIID riid, void **meta_command) override; + HRESULT STDMETHODCALLTYPE CreateStateObject( + const D3D12_STATE_OBJECT_DESC *desc, REFIID riid, + void **state_object) override; + void STDMETHODCALLTYPE GetRaytracingAccelerationStructurePrebuildInfo( + const D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS *desc, + D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO *info) override; + D3D12_DRIVER_MATCHING_IDENTIFIER_STATUS STDMETHODCALLTYPE CheckDriverMatchingIdentifier( + D3D12_SERIALIZED_DATA_TYPE serialized_data_type, + const D3D12_SERIALIZED_DATA_DRIVER_MATCHING_IDENTIFIER *identifier_to_check) override; + + /*** ID3D12Device6 ***/ + HRESULT STDMETHODCALLTYPE SetBackgroundProcessingMode( + D3D12_BACKGROUND_PROCESSING_MODE mode, + D3D12_MEASUREMENTS_ACTION action, HANDLE event, + WINBOOL *further_measurements_desired) override; + + /*** ID3D12Device7 ***/ + HRESULT STDMETHODCALLTYPE AddToStateObject( + const D3D12_STATE_OBJECT_DESC *addition, + ID3D12StateObject *state_object_to_grow_from, + REFIID riid, void **new_state_object) override; + HRESULT STDMETHODCALLTYPE CreateProtectedResourceSession1( + const D3D12_PROTECTED_RESOURCE_SESSION_DESC1 *desc, + REFIID riid, void **session) override; + + /*** ID3D12Device8 ***/ + D3D12_RESOURCE_ALLOCATION_INFO* STDMETHODCALLTYPE GetResourceAllocationInfo2( + D3D12_RESOURCE_ALLOCATION_INFO *__ret, UINT visible_mask, + UINT resource_descs_count, const D3D12_RESOURCE_DESC1 *resource_descs, + 
D3D12_RESOURCE_ALLOCATION_INFO1 *resource_allocation_info1) override; + HRESULT STDMETHODCALLTYPE CreateCommittedResource2( + const D3D12_HEAP_PROPERTIES *heap_properties, + D3D12_HEAP_FLAGS heap_flags, const D3D12_RESOURCE_DESC1 *desc, + D3D12_RESOURCE_STATES initial_resource_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, + ID3D12ProtectedResourceSession *protected_session, + REFIID riid_resource, void **resource) override; + HRESULT STDMETHODCALLTYPE CreatePlacedResource1( + ID3D12Heap *heap, UINT64 heap_offset, + const D3D12_RESOURCE_DESC1 *desc, + D3D12_RESOURCE_STATES initial_state, + const D3D12_CLEAR_VALUE *optimized_clear_value, + REFIID riid, void **resource) override; + void STDMETHODCALLTYPE CreateSamplerFeedbackUnorderedAccessView( + ID3D12Resource *targeted_resource, ID3D12Resource *feedback_resource, + D3D12_CPU_DESCRIPTOR_HANDLE dst_descriptor) override; + void STDMETHODCALLTYPE GetCopyableFootprints1( + const D3D12_RESOURCE_DESC1 *resource_desc, UINT first_subresource, + UINT subresources_count, UINT64 base_offset, + D3D12_PLACED_SUBRESOURCE_FOOTPRINT *layouts, UINT *rows_count, + UINT64 *row_size_in_bytes, UINT64 *total_bytes) override; + + /*** ID3D12Device9 ***/ + HRESULT STDMETHODCALLTYPE CreateShaderCacheSession( + const D3D12_SHADER_CACHE_SESSION_DESC *desc, REFIID riid, + void **session) override; + HRESULT STDMETHODCALLTYPE ShaderCacheControl( + D3D12_SHADER_CACHE_KIND_FLAGS kinds, + D3D12_SHADER_CACHE_CONTROL_FLAGS control) override; + HRESULT STDMETHODCALLTYPE CreateCommandQueue1( + const D3D12_COMMAND_QUEUE_DESC *desc, REFIID creator_id, + REFIID riid, void **command_queue) override; + +private: + std::unique_ptr m_device; + Com m_adapter; + IMTLDXGIDevice *m_dxgi_device = nullptr; + std::atomic m_refCount = {1ul}; + std::atomic m_refPrivate = {1ul}; + std::mutex m_resource_mutex; + std::unordered_map m_resources_by_gpu_addr; + void *m_expected_vtable = nullptr; + void CheckVtable(const char *where); +}; + +} // namespace dxmt 
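Note on the header above: `MTLD3D12Device` now derives from `ID3D12Device9`, so one object can answer `QueryInterface` for every level from `ID3D12Device` through `ID3D12Device9`. The following is a minimal sketch of that idea only, not the code in `d3d12_device.cpp`. `Guid`, `QiResult`, `PlaceholderDeviceIid`, and `FakeDevice` are hypothetical stand-ins so the example builds without Windows headers, and the IID values are placeholders, not the real Microsoft GUIDs.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Stand-in for the Windows GUID layout (hypothetical; no Windows headers needed).
struct Guid {
  std::uint32_t data1;
  std::uint16_t data2;
  std::uint16_t data3;
  std::uint8_t data4[8];
};

// Field-wise comparison, so the result does not depend on struct padding.
inline bool SameGuid(const Guid &a, const Guid &b) {
  return a.data1 == b.data1 && a.data2 == b.data2 && a.data3 == b.data3 &&
         std::memcmp(a.data4, b.data4, sizeof(a.data4)) == 0;
}

// Stand-in for S_OK / E_NOINTERFACE / E_POINTER.
enum class QiResult { Ok, NoInterface, NullPointer };

// Levels 0..9 model ID3D12Device .. ID3D12Device9.
constexpr std::uint32_t kDeviceLevels = 10;

// Placeholder IIDs: these are NOT the real Microsoft GUIDs.
constexpr Guid PlaceholderDeviceIid(std::uint32_t level) {
  return Guid{0xD3D12000u + level, 0, 0, {0, 0, 0, 0, 0, 0, 0, 0}};
}

// Hypothetical device that accepts every interface level in the chain.
class FakeDevice {
public:
  QiResult QueryInterface(const Guid &riid, void **out) {
    if (!out)
      return QiResult::NullPointer;
    *out = nullptr;
    for (std::uint32_t level = 0; level < kDeviceLevels; ++level) {
      if (SameGuid(riid, PlaceholderDeviceIid(level))) {
        // With a single-inheritance chain, every level shares one pointer.
        AddRef();
        *out = this;
        return QiResult::Ok;
      }
    }
    return QiResult::NoInterface;
  }

  unsigned AddRef() { return ++ref_count_; }
  unsigned ref_count() const { return ref_count_; }

private:
  unsigned ref_count_ = 1;
};
```

A table-style loop keeps the accepted levels in one place, so adding a later interface level is a one-line change. A real implementation would compare against the SDK's `IID_ID3D12Device*` constants and return `HRESULT` values.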
diff --git a/src/d3d12/d3d12_dxgi_device.cpp b/src/d3d12/d3d12_dxgi_device.cpp new file mode 100644 index 000000000..e99bd4cd8 --- /dev/null +++ b/src/d3d12/d3d12_dxgi_device.cpp @@ -0,0 +1,173 @@ +#include "d3d12_dxgi_device.hpp" +#include "d3d12_device.hpp" +#include "d3d12_swapchain.hpp" +#include "log/log.hpp" +#include "util_string.hpp" +#include + +#define DDTRACE(fmt, ...) do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); if (_tf) { fprintf(_tf, "DXGIDev::" fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) + +namespace dxmt { + +MTLD3D12DXGIDevice::MTLD3D12DXGIDevice(std::unique_ptr &&device, + IMTLDXGIAdapter *adapter) + : m_adapter(adapter) { + if (m_adapter) + m_adapter->AddRef(); + void *dev_mem = VirtualAlloc((void*)0x500000000ULL, sizeof(MTLD3D12Device), + MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); + if (!dev_mem) dev_mem = VirtualAlloc(nullptr, sizeof(MTLD3D12Device), + MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); + m_d3d12_device = ::new(dev_mem) MTLD3D12Device(std::move(device), m_adapter.ptr()); + DDTRACE("D3D12Device at %p (VirtualAlloc)", (void*)m_d3d12_device); + m_d3d12_device->SetDXGIDevice(this); + Logger::info("D3D12DXGIDevice created"); +} + +MTLD3D12DXGIDevice::~MTLD3D12DXGIDevice() { + if (m_d3d12_device) + m_d3d12_device->Release(); +} + +ULONG STDMETHODCALLTYPE MTLD3D12DXGIDevice::AddRef() { return ++m_refCount; } + +ULONG STDMETHODCALLTYPE MTLD3D12DXGIDevice::Release() { + uint32_t rc = --m_refCount; + if (!rc) + delete this; + return rc; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::GetPrivateData(REFGUID Name, UINT *pDataSize, void *pData) { + DDTRACE("GetPrivateData E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::SetPrivateData(REFGUID Name, UINT DataSize, const void *pData) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::SetPrivateDataInterface(REFGUID Name, const IUnknown *pUnknown) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE 
+MTLD3D12DXGIDevice::QueryInterface(REFIID riid, void **ppvObject) { + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "DXGIDevice::QI(%s)\n", str::format(riid).c_str()); fclose(f); } + } + if (!ppvObject) + return E_POINTER; + *ppvObject = nullptr; + + if (riid == __uuidof(IUnknown) || riid == __uuidof(IDXGIObject) || + riid == __uuidof(IDXGIDevice) || riid == __uuidof(IDXGIDevice1) || + riid == __uuidof(IDXGIDevice2) || riid == __uuidof(IDXGIDevice3) || + riid == __uuidof(IMTLDXGIDevice)) { + *ppvObject = ref(this); + return S_OK; + } + + if (riid == __uuidof(ID3D12Device) || riid == __uuidof(ID3D12Object) || + riid == __uuidof(ID3D12DeviceChild)) { + return m_d3d12_device->QueryInterface(riid, ppvObject); + } + + Logger::warn(str::format("D3D12DXGIDevice::QueryInterface: unknown ", riid)); + return E_NOINTERFACE; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::GetParent(REFIID riid, void **ppParent) { + return m_adapter->QueryInterface(riid, ppParent); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::GetAdapter(IDXGIAdapter **pAdapter) { + if (!pAdapter) + return DXGI_ERROR_INVALID_CALL; + *pAdapter = m_adapter.ref(); + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::CreateSurface(const DXGI_SURFACE_DESC *desc, + UINT surface_count, DXGI_USAGE usage, + const DXGI_SHARED_RESOURCE *shared_resource, + IDXGISurface **surface) { + DDTRACE("CreateSurface E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::QueryResourceResidency(IUnknown *const *ppResources, + DXGI_RESIDENCY *pResidency, + UINT ResourceCount) { + if (!ppResources || !pResidency) + return E_INVALIDARG; + for (UINT i = 0; i < ResourceCount; i++) + pResidency[i] = DXGI_RESIDENCY_FULLY_RESIDENT; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::SetGPUThreadPriority(INT Priority) { return S_OK; } + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::GetGPUThreadPriority(INT *pPriority) { + if 
(pPriority) + *pPriority = 0; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::SetMaximumFrameLatency(UINT MaxLatency) { return S_OK; } + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::GetMaximumFrameLatency(UINT *pMaxLatency) { + if (pMaxLatency) + *pMaxLatency = 2; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::OfferResources(UINT NumResources, + IDXGIResource *const *ppResources, + DXGI_OFFER_RESOURCE_PRIORITY Priority) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::ReclaimResources(UINT NumResources, + IDXGIResource *const *ppResources, + WINBOOL *pDiscarded) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12DXGIDevice::EnqueueSetEvent(HANDLE hEvent) { return E_FAIL; } + +void STDMETHODCALLTYPE MTLD3D12DXGIDevice::Trim() {} + +WMT::Device STDMETHODCALLTYPE MTLD3D12DXGIDevice::GetMTLDevice() { + return m_adapter->GetMTLDevice(); +} + +D3DKMT_HANDLE STDMETHODCALLTYPE MTLD3D12DXGIDevice::GetLocalD3DKMT() { + return m_kmt; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12DXGIDevice::CreateSwapChain( + IDXGIFactory1 *pFactory, HWND hWnd, const DXGI_SWAP_CHAIN_DESC1 *pDesc, + const DXGI_SWAP_CHAIN_FULLSCREEN_DESC *pFullscreenDesc, + IDXGISwapChain1 **ppSwapChain) { + DDTRACE("CreateSwapChain called"); + return dxmt::CreateD3D12SwapChain(pFactory, m_d3d12_device, this, hWnd, + pDesc, pFullscreenDesc, ppSwapChain); +} + +} // namespace dxmt diff --git a/src/d3d12/d3d12_dxgi_device.hpp b/src/d3d12/d3d12_dxgi_device.hpp new file mode 100644 index 000000000..641d3f3b6 --- /dev/null +++ b/src/d3d12/d3d12_dxgi_device.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include "com/com_pointer.hpp" +#include "dxgi_interfaces.h" +#include "dxmt_device.hpp" +#include "d3d12.h" +#include "Metal.hpp" +#include +#include + +namespace dxmt { + +class MTLD3D12Device; + +class MTLD3D12DXGIDevice : public IMTLDXGIDevice { +public: + MTLD3D12DXGIDevice(std::unique_ptr &&device, + IMTLDXGIAdapter *adapter); + ~MTLD3D12DXGIDevice(); + + 
HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void **ppvObject) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID Name, UINT *pDataSize, void *pData) override; + HRESULT STDMETHODCALLTYPE SetPrivateData(REFGUID Name, UINT DataSize, const void *pData) override; + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface(REFGUID Name, const IUnknown *pUnknown) override; + HRESULT STDMETHODCALLTYPE GetParent(REFIID riid, void **ppParent) override; + + HRESULT STDMETHODCALLTYPE GetAdapter(IDXGIAdapter **pAdapter) override; + HRESULT STDMETHODCALLTYPE CreateSurface(const DXGI_SURFACE_DESC *desc, + UINT surface_count, DXGI_USAGE usage, + const DXGI_SHARED_RESOURCE *shared_resource, + IDXGISurface **surface) override; + HRESULT STDMETHODCALLTYPE QueryResourceResidency(IUnknown *const *ppResources, + DXGI_RESIDENCY *pResidency, + UINT ResourceCount) override; + HRESULT STDMETHODCALLTYPE SetGPUThreadPriority(INT Priority) override; + HRESULT STDMETHODCALLTYPE GetGPUThreadPriority(INT *pPriority) override; + HRESULT STDMETHODCALLTYPE SetMaximumFrameLatency(UINT MaxLatency) override; + HRESULT STDMETHODCALLTYPE GetMaximumFrameLatency(UINT *pMaxLatency) override; + HRESULT STDMETHODCALLTYPE OfferResources(UINT NumResources, + IDXGIResource *const *ppResources, + DXGI_OFFER_RESOURCE_PRIORITY Priority) override; + HRESULT STDMETHODCALLTYPE ReclaimResources(UINT NumResources, + IDXGIResource *const *ppResources, + WINBOOL *pDiscarded) override; + HRESULT STDMETHODCALLTYPE EnqueueSetEvent(HANDLE hEvent) override; + void STDMETHODCALLTYPE Trim() override; + + WMT::Device STDMETHODCALLTYPE GetMTLDevice() override; + D3DKMT_HANDLE STDMETHODCALLTYPE GetLocalD3DKMT() override; + HRESULT STDMETHODCALLTYPE CreateSwapChain( + IDXGIFactory1 *pFactory, HWND hWnd, const DXGI_SWAP_CHAIN_DESC1 *pDesc, + const DXGI_SWAP_CHAIN_FULLSCREEN_DESC *pFullscreenDesc, + IDXGISwapChain1 **ppSwapChain) 
override; + + MTLD3D12Device *GetD3D12Device() { return m_d3d12_device; } + +private: + Com m_adapter; + MTLD3D12Device *m_d3d12_device; + D3DKMT_HANDLE m_kmt = 0; + std::atomic m_refCount = {1ul}; +}; + +} // namespace dxmt diff --git a/src/d3d12/d3d12_fence.cpp b/src/d3d12/d3d12_fence.cpp new file mode 100644 index 000000000..1d119e025 --- /dev/null +++ b/src/d3d12/d3d12_fence.cpp @@ -0,0 +1,116 @@ +#include "d3d12_fence.hpp" +#include "d3d12_device.hpp" +#include "log/log.hpp" +#include "util_string.hpp" + +#define FTRACE(fmt, ...) do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); if (_tf) { fprintf(_tf, "Fence::" fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) + +namespace dxmt { + +MTLD3D12Fence::MTLD3D12Fence(MTLD3D12Device *device, uint64_t initial_value, + D3D12_FENCE_FLAGS flags) + : m_device(device), m_flags(flags), m_value(initial_value) { + m_device->AddRef(); + auto wmt_device = m_device->GetDXMTDevice().device(); + m_shared_event = wmt_device.newSharedEvent(); + m_shared_event.signalValue(initial_value); + Logger::info(str::format("D3D12Fence: created value=", initial_value)); +} + +MTLD3D12Fence::~MTLD3D12Fence() { + m_shared_event = nullptr; + m_device->Release(); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Fence::QueryInterface(REFIID riid, void **ppvObject) { + if (!ppvObject) + return E_POINTER; + *ppvObject = nullptr; + + if (riid == IID_IUnknown || riid == IID_ID3D12Object || + riid == IID_ID3D12DeviceChild || riid == IID_ID3D12Pageable || + riid == IID_ID3D12Fence) { + *ppvObject = ref(this); + return S_OK; + } + FTRACE("QI unknown IID %s -> E_NOINTERFACE", str::format(riid).c_str()); + return E_NOINTERFACE; +} + +ULONG STDMETHODCALLTYPE MTLD3D12Fence::AddRef() { return ++m_refCount; } + +ULONG STDMETHODCALLTYPE MTLD3D12Fence::Release() { + uint32_t rc = --m_refCount; + if (!rc) + delete this; + return rc; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Fence::GetPrivateData(REFGUID guid, UINT *data_size, void *data) { + 
FTRACE("GetPrivateData E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Fence::SetPrivateData(REFGUID guid, UINT data_size, + const void *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Fence::SetPrivateDataInterface(REFGUID guid, const IUnknown *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Fence::SetName(LPCWSTR name) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Fence::GetDevice(REFIID riid, void **device) { + return m_device->QueryInterface(riid, device); +} + +uint64_t STDMETHODCALLTYPE MTLD3D12Fence::GetCompletedValue() { + FTRACE("GetCompletedValue -> %llu", (unsigned long long)m_value.load(std::memory_order_acquire)); + return m_value.load(std::memory_order_acquire); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Fence::SetEventOnCompletion(uint64_t value, HANDLE event) { + FTRACE("SetEventOnCompletion value=%llu current=%llu this=%p event=%p", (unsigned long long)value, (unsigned long long)m_value.load(), (void *)this, (void *)(uintptr_t)event); + if (m_value.load(std::memory_order_acquire) >= value) { + if (event) + SetEvent(event); + return S_OK; + } + if (!event) { + if (m_value.load(std::memory_order_acquire) < value) { + FTRACE("SetEventOnCompletion: null event, auto-signaling fence %p from %llu to %llu (sync replay already done)", + (void *)this, (unsigned long long)m_value.load(), (unsigned long long)value); + m_value.store(value, std::memory_order_release); + if (m_shared_event.handle) { + m_shared_event.signalValue(value); + } + } + return S_OK; + } + if (m_shared_event.handle) { + m_shared_event.waitUntilSignaledValue(value, UINT64_MAX); + } + SetEvent(event); + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Fence::Signal(uint64_t value) { + FTRACE("Signal value=%llu this=%p", (unsigned long long)value, (void *)this); + m_value.store(value, std::memory_order_release); + if (m_shared_event.handle) { + m_shared_event.signalValue(value); + } + return S_OK; +} + +} // namespace 
dxmt diff --git a/src/d3d12/d3d12_fence.hpp b/src/d3d12/d3d12_fence.hpp new file mode 100644 index 000000000..49d954784 --- /dev/null +++ b/src/d3d12/d3d12_fence.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include "com/com_pointer.hpp" +#include "d3d12.h" +#include "Metal.hpp" +#include + +namespace dxmt { + +class MTLD3D12Device; + +class MTLD3D12Fence : public ID3D12Fence { +public: + MTLD3D12Fence(MTLD3D12Device *device, uint64_t initial_value, + D3D12_FENCE_FLAGS flags); + ~MTLD3D12Fence(); + + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, + void **ppvObject) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID guid, UINT *data_size, + void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateData(REFGUID guid, UINT data_size, + const void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface( + REFGUID guid, const IUnknown *data) override; + HRESULT STDMETHODCALLTYPE SetName(LPCWSTR name) override; + + HRESULT STDMETHODCALLTYPE GetDevice(REFIID riid, void **device) override; + + uint64_t STDMETHODCALLTYPE GetCompletedValue() override; + HRESULT STDMETHODCALLTYPE SetEventOnCompletion(uint64_t value, + HANDLE event) override; + HRESULT STDMETHODCALLTYPE Signal(uint64_t value) override; + + WMT::Reference GetMTLSharedEvent() { + return m_shared_event; + } + +private: + MTLD3D12Device *m_device; + D3D12_FENCE_FLAGS m_flags; + std::atomic m_value; + WMT::Reference m_shared_event; + std::atomic m_refCount = {1ul}; +}; + +} // namespace dxmt diff --git a/src/d3d12/d3d12_heap.cpp b/src/d3d12/d3d12_heap.cpp new file mode 100644 index 000000000..c9496b285 --- /dev/null +++ b/src/d3d12/d3d12_heap.cpp @@ -0,0 +1,78 @@ +#include "d3d12_heap.hpp" +#include "d3d12_device.hpp" +#include "log/log.hpp" +#include "util_string.hpp" + +#define HTRACE(fmt, ...) 
do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); if (_tf) { fprintf(_tf, "Heap::" fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) + +namespace dxmt { + +MTLD3D12Heap::MTLD3D12Heap(MTLD3D12Device *device, const D3D12_HEAP_DESC &desc) + : m_device(device), m_desc(desc) { + m_device->AddRef(); + HTRACE("ctor: size=%llu alignment=%llu type=%u flags=0x%x", + (unsigned long long)desc.SizeInBytes, (unsigned long long)desc.Alignment, + desc.Properties.Type, desc.Flags); +} + +MTLD3D12Heap::~MTLD3D12Heap() { + HTRACE("dtor"); + m_device->Release(); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Heap::QueryInterface(REFIID riid, void **ppvObject) { + if (!ppvObject) + return E_POINTER; + *ppvObject = nullptr; + + if (riid == IID_IUnknown || riid == IID_ID3D12Object || + riid == IID_ID3D12DeviceChild || riid == IID_ID3D12Pageable || + riid == IID_ID3D12Heap) { + *ppvObject = ref(this); + return S_OK; + } + return E_NOINTERFACE; +} + +ULONG STDMETHODCALLTYPE MTLD3D12Heap::AddRef() { return ++m_refCount; } + +ULONG STDMETHODCALLTYPE MTLD3D12Heap::Release() { + uint32_t rc = --m_refCount; + if (!rc) + delete this; + return rc; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Heap::GetPrivateData(REFGUID guid, UINT *data_size, void *data) { + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Heap::SetPrivateData(REFGUID guid, UINT data_size, + const void *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Heap::SetPrivateDataInterface(REFGUID guid, const IUnknown *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Heap::SetName(LPCWSTR name) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Heap::GetDevice(REFIID riid, void **device) { + return m_device->QueryInterface(riid, device); +} + +D3D12_HEAP_DESC *STDMETHODCALLTYPE +MTLD3D12Heap::GetDesc(D3D12_HEAP_DESC *__ret) { + *__ret = m_desc; + return __ret; +} + +} // namespace dxmt diff --git a/src/d3d12/d3d12_heap.hpp b/src/d3d12/d3d12_heap.hpp new file mode 100644 index 
000000000..52873b617 --- /dev/null +++ b/src/d3d12/d3d12_heap.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include "com/com_pointer.hpp" +#include "d3d12.h" +#include + +namespace dxmt { + +class MTLD3D12Device; + +class MTLD3D12Heap : public ID3D12Heap { +public: + MTLD3D12Heap(MTLD3D12Device *device, const D3D12_HEAP_DESC &desc); + ~MTLD3D12Heap(); + + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, + void **ppvObject) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID guid, UINT *data_size, + void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateData(REFGUID guid, UINT data_size, + const void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface( + REFGUID guid, const IUnknown *data) override; + HRESULT STDMETHODCALLTYPE SetName(LPCWSTR name) override; + + HRESULT STDMETHODCALLTYPE GetDevice(REFIID riid, void **device) override; + + D3D12_HEAP_DESC *STDMETHODCALLTYPE GetDesc(D3D12_HEAP_DESC *__ret) override; + + const D3D12_HEAP_DESC &GetHeapDesc() const { return m_desc; } + +private: + MTLD3D12Device *m_device; + D3D12_HEAP_DESC m_desc; + std::atomic m_refCount = {1ul}; +}; + +} // namespace dxmt diff --git a/src/d3d12/d3d12_pipeline_state.cpp b/src/d3d12/d3d12_pipeline_state.cpp new file mode 100644 index 000000000..0e0cc0dde --- /dev/null +++ b/src/d3d12/d3d12_pipeline_state.cpp @@ -0,0 +1,623 @@ +#include "d3d12_pipeline_state.hpp" +#include "d3d12_device.hpp" +#include "log/log.hpp" +#include "util_string.hpp" +#include "Metal.hpp" +#include "airconv_public.h" +#include "dxmt_format.hpp" +#include "dxil/dxil_container.hpp" +#include "dxil/llvm_bitcode.hpp" +#include "dxil/dxil_to_msl.hpp" +#include "../../libs/DXBCParser/BlobContainer.h" +#include +#include +#include +#include +#include + +#define PSTRACE(fmt, ...) 
do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); if (_tf) { fprintf(_tf, fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) + +namespace dxmt { + +std::mutex MTLD3D12PipelineState::s_shader_mutex; +std::unordered_map> MTLD3D12PipelineState::s_shader_cache; + +MTLD3D12PipelineState::MTLD3D12PipelineState(MTLD3D12Device *device, + bool is_compute) + : m_device(device), m_is_compute(is_compute) { + m_device->AddRef(); +} + +MTLD3D12PipelineState::~MTLD3D12PipelineState() { + if (m_root_sig) + m_root_sig->Release(); + m_render_pso = nullptr; + m_compute_pso = nullptr; + m_device->Release(); +} + +WMTPixelFormat MTLD3D12PipelineState::DXGIToMTLPixelFormat(DXGI_FORMAT format) { + switch (format) { + case DXGI_FORMAT_R8G8B8A8_UNORM: return WMTPixelFormatRGBA8Unorm; + case DXGI_FORMAT_R8G8B8A8_UNORM_SRGB: return WMTPixelFormatRGBA8Unorm_sRGB; + case DXGI_FORMAT_B8G8R8A8_UNORM: return WMTPixelFormatBGRA8Unorm; + case DXGI_FORMAT_B8G8R8A8_UNORM_SRGB: return WMTPixelFormatBGRA8Unorm_sRGB; + case DXGI_FORMAT_R16G16B16A16_FLOAT: return WMTPixelFormatRGBA16Float; + case DXGI_FORMAT_R32G32B32A32_FLOAT: return WMTPixelFormatRGBA32Float; + case DXGI_FORMAT_R10G10B10A2_UNORM: return WMTPixelFormatRGB10A2Unorm; + case DXGI_FORMAT_R11G11B10_FLOAT: return WMTPixelFormatRG11B10Float; + case DXGI_FORMAT_R8_UNORM: return WMTPixelFormatR8Unorm; + case DXGI_FORMAT_R16_FLOAT: return WMTPixelFormatR16Float; + case DXGI_FORMAT_R32_FLOAT: return WMTPixelFormatR32Float; + case DXGI_FORMAT_D32_FLOAT: return WMTPixelFormatDepth32Float; + case DXGI_FORMAT_D24_UNORM_S8_UINT: return WMTPixelFormatDepth24Unorm_Stencil8; + case DXGI_FORMAT_D32_FLOAT_S8X24_UINT: return WMTPixelFormatDepth32Float_Stencil8; + case DXGI_FORMAT_D16_UNORM: return WMTPixelFormatDepth16Unorm; + case DXGI_FORMAT_R16G16_FLOAT: return WMTPixelFormatRG16Float; + case DXGI_FORMAT_R16G16_UNORM: return WMTPixelFormatRG16Unorm; + case DXGI_FORMAT_R8G8_UNORM: return WMTPixelFormatRG8Unorm; + case DXGI_FORMAT_BC1_UNORM: 
return WMTPixelFormatBC1_RGBA; + case DXGI_FORMAT_BC1_UNORM_SRGB: return WMTPixelFormatBC1_RGBA_sRGB; + case DXGI_FORMAT_BC2_UNORM: return WMTPixelFormatBC2_RGBA; + case DXGI_FORMAT_BC2_UNORM_SRGB: return WMTPixelFormatBC2_RGBA_sRGB; + case DXGI_FORMAT_BC3_UNORM: return WMTPixelFormatBC3_RGBA; + case DXGI_FORMAT_BC3_UNORM_SRGB: return WMTPixelFormatBC3_RGBA_sRGB; + case DXGI_FORMAT_BC4_UNORM: return WMTPixelFormatBC4_RUnorm; + case DXGI_FORMAT_BC4_SNORM: return WMTPixelFormatBC4_RSnorm; + case DXGI_FORMAT_BC5_UNORM: return WMTPixelFormatBC5_RGUnorm; + case DXGI_FORMAT_BC5_SNORM: return WMTPixelFormatBC5_RGSnorm; + case DXGI_FORMAT_BC6H_UF16: return WMTPixelFormatBC6H_RGBUfloat; + case DXGI_FORMAT_BC6H_SF16: return WMTPixelFormatBC6H_RGBFloat; + case DXGI_FORMAT_BC7_UNORM: return WMTPixelFormatBC7_RGBAUnorm; + case DXGI_FORMAT_BC7_UNORM_SRGB: return WMTPixelFormatBC7_RGBAUnorm_sRGB; + default: return WMTPixelFormatInvalid; + } +} + +bool MTLD3D12PipelineState::CompileShader(const void *bytecode, SIZE_T size, + ShaderType type, + const char *func_name, + WMT::Reference &out_func) { + size_t hash = 0; + if (bytecode && size > 0) { + const uint8_t *p = (const uint8_t *)bytecode; + for (SIZE_T i = 0; i < size; i++) + hash = hash * 131 + p[i]; + } + { + std::lock_guard lock(s_shader_mutex); + PSTRACE("CompileShader: %s hash=0x%zx size=%zu cache_entries=%zu", func_name, hash, size, s_shader_cache.size()); + auto it = s_shader_cache.find(hash); + if (it != s_shader_cache.end()) { + out_func = it->second; + PSTRACE("CompileShader: %s CACHE HIT hash=0x%zx", func_name, hash); + return true; + } + } + + if (bytecode && size >= 4) { + auto *magic = (const uint32_t *)bytecode; + PSTRACE("CompileShader: %s size=%zu magic=0x%08x (DXBC=0x43425844 DXIL=0x4C495844)", func_name, size, *magic); + if (*magic == 0x43425844 && size >= 32) { + auto *chunks = (const uint32_t *)bytecode; + uint32_t num_chunks = chunks[7]; // 32-byte header: fourcc, 16-byte digest, version, total size, chunk count + PSTRACE(" DXBC: num_chunks=%u", num_chunks); + for (uint32_t i = 0;
i < num_chunks && i < 16 && 36 + 4 * (SIZE_T)i <= size; i++) { + uint32_t offset = chunks[8 + i]; // chunk offset table starts right after the 32-byte header + if ((SIZE_T)offset + 8 <= size) { + char tag[5] = {}; + memcpy(tag, (const char *)bytecode + offset, 4); + uint32_t chunk_size = *((const uint32_t *)bytecode + offset/4 + 1); + PSTRACE(" chunk[%u]: tag='%s' offset=%u size=%u", i, tag, offset, chunk_size); + } + } + } + } + sm50_error_t sm50_err = nullptr; + sm50_shader_t shader = nullptr; + MTL_SHADER_REFLECTION reflection = {}; + + if (SM50Initialize(bytecode, size, &shader, &reflection, &sm50_err)) { + char err_buf[256] = {}; + SM50GetErrorMessage(sm50_err, err_buf, sizeof(err_buf)); + SM50FreeError(sm50_err); + + bool has_dxil = false; + using namespace microsoft; + CDXBCParser dxbcParser; + if (SUCCEEDED(dxbcParser.ReadDXBC(bytecode, size))) { + for (UINT32 i = 0; i < dxbcParser.GetBlobCount(); i++) { + if (dxbcParser.GetBlobFourCC(i) == dxmt::dxil::DXIL_FOURCC) { + has_dxil = true; + const void *blob = dxbcParser.GetBlob(i); + UINT32 blob_size = dxbcParser.GetBlobSize(i); + PSTRACE("DXIL blob found index=%u size=%u", i, blob_size); + + auto wmt_device = m_device->GetDXMTDevice().device(); + + char cache_path[256]; + snprintf(cache_path, sizeof(cache_path), "/tmp/dxmt_shader_cache/%016zx", hash); + char dxbc_path[256], metallib_path[256], reflection_path[256]; + snprintf(dxbc_path, sizeof(dxbc_path), "%s.dxbc", cache_path); + snprintf(metallib_path, sizeof(metallib_path), "%s.metallib", cache_path); + snprintf(reflection_path, sizeof(reflection_path), "%s.json", cache_path); + + FILE *mf = fopen(metallib_path, "rb"); + if (!mf) { + PSTRACE(" metallib not cached, attempting DXIL->MSL compilation"); + + auto container = dxmt::dxil::DXILContainer::parse(blob, blob_size); + if (!container) { + PSTRACE(" DXILContainer::parse FAILED for %s", func_name); + FILE *df = fopen(dxbc_path, "wb"); + if (df) { fwrite(bytecode, 1, size, df); fclose(df); } + return false; + } + + auto &shader_info = container->shader(); + PSTRACE(" DXIL container parsed: kind=%u
sm=%u.%u bc_size=%u", + (uint32_t)shader_info.kind, shader_info.shader_model.major, + shader_info.shader_model.minor, shader_info.bitcode.size); + + auto module = dxmt::dxil::BitcodeReader::parse( + shader_info.bitcode.data, shader_info.bitcode.size); + if (!module) { + PSTRACE(" BitcodeReader::parse FAILED"); + FILE *df = fopen(dxbc_path, "wb"); + if (df) { fwrite(bytecode, 1, size, df); fclose(df); } + return false; + } + + PSTRACE(" Bitcode parsed: types=%zu functions=%zu constants=%zu", + module->types.size(), module->functions.size(), module->constants.size()); + + auto msl_result = dxmt::dxil::DXILToMSL::convert(*module, shader_info); + if (!msl_result) { + PSTRACE(" DXILToMSL::convert FAILED"); + FILE *df = fopen(dxbc_path, "wb"); + if (df) { fwrite(bytecode, 1, size, df); fclose(df); } + return false; + } + + PSTRACE(" MSL generated: %zu bytes, entry=%s", msl_result->source.size(), msl_result->entry_point.c_str()); + + { + char msl_path[256]; + snprintf(msl_path, sizeof(msl_path), "%s.msl", cache_path); + FILE *msl_file = fopen(msl_path, "w"); + if (msl_file) { + fwrite(msl_result->source.c_str(), 1, msl_result->source.size(), msl_file); + fclose(msl_file); + PSTRACE(" MSL source written to %s", msl_path); + } + } + + WMT::Reference compile_err; + auto library = wmt_device.newLibraryWithSource( + msl_result->source.c_str(), msl_result->source.size(), compile_err); + + if (compile_err.handle) { + char *err_desc = (char *)NSObject_description(compile_err.handle); + PSTRACE(" newLibraryWithSource FAILED: %s", err_desc ? err_desc : "unknown"); + Logger::err(str::format("DXIL MSL compilation failed for ", func_name, ": ", + err_desc ? 
err_desc : "unknown error")); + FILE *df = fopen(dxbc_path, "wb"); + if (df) { fwrite(bytecode, 1, size, df); fclose(df); } + return false; + } + + PSTRACE(" Metal library compiled OK from source"); + + const char *entry_name = msl_result->entry_point.c_str(); + if (strcmp(entry_name, "cs_main") != 0 && + strcmp(entry_name, "vs_main") != 0 && + strcmp(entry_name, "ps_main") != 0) { + switch (shader_info.kind) { + case dxmt::dxil::DxilShaderKind::Compute: entry_name = "cs_main"; break; + case dxmt::dxil::DxilShaderKind::Vertex: entry_name = "vs_main"; break; + case dxmt::dxil::DxilShaderKind::Pixel: entry_name = "ps_main"; break; + default: break; + } + } + + out_func = library.newFunction(entry_name); + if (!out_func.handle) { + PSTRACE(" newFunction(%s) returned null, trying alternatives", entry_name); + out_func = library.newFunction("main"); + if (!out_func.handle) + out_func = library.newFunction("cs_main"); + if (!out_func.handle) + out_func = library.newFunction("vs_main"); + if (!out_func.handle) + out_func = library.newFunction("ps_main"); + } + + if (out_func.handle) { + PSTRACE(" DXIL shader compiled OK! 
entry=%s", entry_name); + /* s_shader_cache is shared across PSOs; take the same lock used for lookups. */ + { + std::lock_guard lock(s_shader_mutex); + s_shader_cache[hash] = out_func; + } + + if (shader_info.kind == dxmt::dxil::DxilShaderKind::Compute) { + m_threadgroup_size.width = msl_result->tg_size[0]; + m_threadgroup_size.height = msl_result->tg_size[1]; + m_threadgroup_size.depth = msl_result->tg_size[2]; + } + return true; + } else { + PSTRACE(" newFunction returned null for all entry points"); + Logger::err(str::format("DXIL: failed to get function from compiled library for ", func_name)); + return false; + } + } + + PSTRACE(" loading cached metallib from %s", metallib_path); + fseek(mf, 0, SEEK_END); + long lib_size = ftell(mf); + fseek(mf, 0, SEEK_SET); + PSTRACE(" metallib size=%ld", lib_size); + if (lib_size > 0) { + std::vector lib_data(lib_size); + fread(lib_data.data(), 1, lib_size, mf); + fclose(mf); + auto dispatch_data = WMT::MakeDispatchData(lib_data.data(), lib_size); + WMT::Reference err; + auto library = wmt_device.newLibrary(dispatch_data, err); + if (!err.handle) { + char actual_entry[256] = {}; + char rbuf[4096] = {}; + FILE *rf = fopen(reflection_path, "r"); + if (rf) { + fread(rbuf, 1, sizeof(rbuf)-1, rf); + fclose(rf); + char *ep = strstr(rbuf, "\"EntryPoint\""); + if (ep) { + /* Start right after the 12-character key so the search never begins past the terminator. */ + char *q1 = strchr(ep + 12, '"'); + char *q2 = q1 ? strchr(q1+1, '"') : nullptr; + if (q1 && q2) { + size_t len = q2 - q1 - 1; + if (len < sizeof(actual_entry)) { + memcpy(actual_entry, q1+1, len); + actual_entry[len] = 0; + } + } + } + } + const char *fn_name = actual_entry[0] ? actual_entry : func_name; + PSTRACE(" trying newFunction(%s)", fn_name); + out_func = library.newFunction(fn_name); + if (!out_func.handle && actual_entry[0]) { + out_func = library.newFunction(func_name); + } + if (out_func.handle) { + PSTRACE(" DXIL loaded from cache OK! 
entry=%s", fn_name); + /* s_shader_cache is shared across PSOs; take the same lock used for lookups. */ + { + std::lock_guard lock(s_shader_mutex); + s_shader_cache[hash] = out_func; + } + char *tg = strstr(rbuf, "\"tg_size\""); + if (tg) { + int tw=1,th=1,td=1; + if (sscanf(tg, "\"tg_size\": [%d, %d, %d]", &tw, &th, &td) == 3 || + sscanf(tg, "\"tg_size\":[%d,%d,%d]", &tw, &th, &td) == 3) { + m_threadgroup_size.width = tw; + m_threadgroup_size.height = th; + m_threadgroup_size.depth = td; + PSTRACE(" threadgroup_size from reflection: %dx%dx%d", tw, th, td); + } + } + return true; + } else { + PSTRACE(" WMT newFunction returned null"); + } + } else { + PSTRACE(" WMT newLibrary FAILED"); + } + } else { + fclose(mf); + } + break; + } + } + } + if (!has_dxil) { + PSTRACE("SM50Init FAILED for %s: %s (no DXIL chunk)", func_name, err_buf); + } + return false; + } + + SM50_SHADER_COMMON_DATA common = {}; + common.next = nullptr; + common.type = SM50_SHADER_COMMON; + common.metal_version = SM50_SHADER_METAL_310; + common.flags = {}; + + sm50_bitcode_t compile_result = nullptr; + if (SM50Compile(shader, (SM50_SHADER_COMPILATION_ARGUMENT_DATA *)&common, + func_name, &compile_result, &sm50_err)) { + char err_buf[256] = {}; + SM50GetErrorMessage(sm50_err, err_buf, sizeof(err_buf)); + Logger::err(str::format("SM50Compile failed for ", func_name, ": ", err_buf)); + SM50FreeError(sm50_err); + SM50Destroy(shader); + return false; + } + + SM50_COMPILED_BITCODE bitcode = {}; + SM50GetCompiledBitcode(compile_result, &bitcode); + + auto wmt_device = m_device->GetDXMTDevice().device(); + WMT::Reference err; + auto lib_data = WMT::MakeDispatchData(bitcode.Data, bitcode.Size); + auto library = wmt_device.newLibrary(lib_data, err); + + if (err.handle) { + Logger::err(str::format("Failed to create Metal library for ", func_name)); + SM50DestroyBitcode(compile_result); + SM50Destroy(shader); + return false; + } + + out_func = library.newFunction(func_name); + SM50DestroyBitcode(compile_result); + SM50Destroy(shader); + + if (!out_func.handle) { + Logger::err(str::format("Failed to get function ", func_name)); + return 
false; + } + + Logger::info(str::format(" Compiled ", func_name, " OK")); + { + std::lock_guard lock(s_shader_mutex); + s_shader_cache[hash] = out_func; + } + return true; +} + +bool MTLD3D12PipelineState::Compile() { + if (m_compiled) + return true; + + auto wmt_device = m_device->GetDXMTDevice().device(); + WMT::Reference err; + + if (m_is_compute) { + if (m_cs.empty()) { + Logger::err("Compute PSO has no CS bytecode"); + return false; + } + + WMT::Reference cs_func; + if (!CompileShader(m_cs.data(), m_cs.size(), ShaderType::Compute, + "cs_main", cs_func)) + return false; + + WMTComputePipelineInfo info = {}; + WMT::InitializeComputePipelineInfo(info); + info.compute_function = cs_func.handle; + + m_compute_pso = wmt_device.newComputePipelineState(info, err); + if (!m_compute_pso.handle) { + Logger::err("Failed to create compute PSO"); + return false; + } + + m_compiled = true; + Logger::info("Compute PSO compiled successfully"); + return true; + } + + WMT::Reference vs_func, ps_func; + + if (!m_vs.empty()) { + if (!CompileShader(m_vs.data(), m_vs.size(), ShaderType::Vertex, + "vs_main", vs_func)) + return false; + } + + if (!m_ps.empty()) { + if (!CompileShader(m_ps.data(), m_ps.size(), ShaderType::Pixel, + "ps_main", ps_func)) + return false; + } + + WMTRenderPipelineInfo info; + WMT::InitializeRenderPipelineInfo(info); + + if (vs_func.handle) + info.vertex_function = vs_func.handle; + if (ps_func.handle) + info.fragment_function = ps_func.handle; + + /* Wireframe PSOs still rasterize in D3D12; disabling rasterization would drop all of their output. Fill mode is left to the render encoder. */ + info.rasterization_enabled = true; + info.raster_sample_count = m_sample_count ? 
m_sample_count : 1; + + for (UINT i = 0; i < m_num_render_targets && i < 8; i++) { + auto fmt = DXGIToMTLPixelFormat(m_rtv_formats[i]); + if (fmt != WMTPixelFormatInvalid) + info.colors[i].pixel_format = fmt; + } + + auto depth_fmt = DXGIToMTLPixelFormat(m_dsv_format); + if (depth_fmt != WMTPixelFormatInvalid) { + info.depth_pixel_format = depth_fmt; + if (m_dsv_format == DXGI_FORMAT_D24_UNORM_S8_UINT || + m_dsv_format == DXGI_FORMAT_D32_FLOAT_S8X24_UINT) + info.stencil_pixel_format = depth_fmt; + } + + /* Always program per-target state: the write mask applies even when blending is off, and D3D12 uses RenderTarget[0] for every target unless IndependentBlendEnable is set. */ + { + for (UINT i = 0; i < m_num_render_targets && i < 8; i++) { + auto &rt = m_blend_desc.RenderTarget[m_blend_desc.IndependentBlendEnable ? i : 0]; + info.colors[i].blending_enabled = rt.BlendEnable ? true : false; + info.colors[i].write_mask = rt.RenderTargetWriteMask; + + auto map_blend = [](D3D12_BLEND b) -> WMTBlendFactor { + switch (b) { + case D3D12_BLEND_ZERO: return WMTBlendFactorZero; + case D3D12_BLEND_ONE: return WMTBlendFactorOne; + case D3D12_BLEND_SRC_COLOR: return WMTBlendFactorSourceColor; + case D3D12_BLEND_INV_SRC_COLOR: return WMTBlendFactorOneMinusSourceColor; + case D3D12_BLEND_SRC_ALPHA: return WMTBlendFactorSourceAlpha; + case D3D12_BLEND_INV_SRC_ALPHA: return WMTBlendFactorOneMinusSourceAlpha; + case D3D12_BLEND_DEST_ALPHA: return WMTBlendFactorDestinationAlpha; + case D3D12_BLEND_INV_DEST_ALPHA: return WMTBlendFactorOneMinusDestinationAlpha; + case D3D12_BLEND_DEST_COLOR: return WMTBlendFactorDestinationColor; + case D3D12_BLEND_INV_DEST_COLOR: return WMTBlendFactorOneMinusDestinationColor; + case D3D12_BLEND_SRC_ALPHA_SAT: return WMTBlendFactorSourceAlphaSaturated; + case D3D12_BLEND_BLEND_FACTOR: return WMTBlendFactorBlendColor; + case D3D12_BLEND_INV_BLEND_FACTOR: return WMTBlendFactorOneMinusBlendColor; + default: return WMTBlendFactorOne; + } + }; + + auto map_op = [](D3D12_BLEND_OP op) -> WMTBlendOperation { + switch (op) { + case D3D12_BLEND_OP_ADD: return WMTBlendOperationAdd; + case D3D12_BLEND_OP_SUBTRACT: return 
WMTBlendOperationSubtract; + case D3D12_BLEND_OP_REV_SUBTRACT: return WMTBlendOperationReverseSubtract; + case D3D12_BLEND_OP_MIN: return WMTBlendOperationMin; + case D3D12_BLEND_OP_MAX: return WMTBlendOperationMax; + default: return WMTBlendOperationAdd; + } + }; + + info.colors[i].src_rgb_blend_factor = map_blend(rt.SrcBlend); + info.colors[i].dst_rgb_blend_factor = map_blend(rt.DestBlend); + info.colors[i].rgb_blend_operation = map_op(rt.BlendOp); + info.colors[i].src_alpha_blend_factor = map_blend(rt.SrcBlendAlpha); + info.colors[i].dst_alpha_blend_factor = map_blend(rt.DestBlendAlpha); + info.colors[i].alpha_blend_operation = map_op(rt.BlendOpAlpha); + } + } + + switch (m_topology) { + case D3D12_PRIMITIVE_TOPOLOGY_TYPE_POINT: info.input_primitive_topology = WMTPrimitiveTopologyClassPoint; break; + case D3D12_PRIMITIVE_TOPOLOGY_TYPE_LINE: info.input_primitive_topology = WMTPrimitiveTopologyClassLine; break; + case D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE: info.input_primitive_topology = WMTPrimitiveTopologyClassTriangle; break; + default: info.input_primitive_topology = WMTPrimitiveTopologyClassUnspecified; break; + } + + info.immutable_vertex_buffers = (1 << 16) | (1 << 29) | (1 << 30); + info.immutable_fragment_buffers = (1 << 29) | (1 << 30); + + m_render_pso = wmt_device.newRenderPipelineState(info, err); + if (!m_render_pso.handle) { + Logger::err("Failed to create render PSO"); + return false; + } + + m_compiled = true; + Logger::info(str::format("Graphics PSO compiled: RTs=", m_num_render_targets, + " DSV=", (int)m_dsv_format, + " samples=", m_sample_count)); + return true; +} + +void MTLD3D12PipelineState::SetGraphicsDesc( + const D3D12_GRAPHICS_PIPELINE_STATE_DESC &desc) { + if (desc.pRootSignature) { + m_root_sig = desc.pRootSignature; + m_root_sig->AddRef(); + } + + if (desc.VS.pShaderBytecode && desc.VS.BytecodeLength) { + m_vs.resize(desc.VS.BytecodeLength); + memcpy(m_vs.data(), desc.VS.pShaderBytecode, desc.VS.BytecodeLength); + } + if 
(desc.PS.pShaderBytecode && desc.PS.BytecodeLength) { + m_ps.resize(desc.PS.BytecodeLength); + memcpy(m_ps.data(), desc.PS.pShaderBytecode, desc.PS.BytecodeLength); + } + if (desc.GS.pShaderBytecode && desc.GS.BytecodeLength) { + m_gs.resize(desc.GS.BytecodeLength); + memcpy(m_gs.data(), desc.GS.pShaderBytecode, desc.GS.BytecodeLength); + } + if (desc.HS.pShaderBytecode && desc.HS.BytecodeLength) { + m_hs.resize(desc.HS.BytecodeLength); + memcpy(m_hs.data(), desc.HS.pShaderBytecode, desc.HS.BytecodeLength); + } + if (desc.DS.pShaderBytecode && desc.DS.BytecodeLength) { + m_ds.resize(desc.DS.BytecodeLength); + memcpy(m_ds.data(), desc.DS.pShaderBytecode, desc.DS.BytecodeLength); + } + + m_blend_desc = desc.BlendState; + m_rasterizer_desc = desc.RasterizerState; + m_depth_stencil_desc = desc.DepthStencilState; + m_input_layout = desc.InputLayout; + m_strip_cut_value = desc.IBStripCutValue; + m_topology = desc.PrimitiveTopologyType; + m_num_render_targets = desc.NumRenderTargets; + memcpy(m_rtv_formats, desc.RTVFormats, sizeof(m_rtv_formats)); + m_dsv_format = desc.DSVFormat; + m_sample_mask = desc.SampleMask; + m_sample_count = desc.SampleDesc.Count ? 
desc.SampleDesc.Count : 1; +} + +void MTLD3D12PipelineState::SetComputeDesc( + const D3D12_COMPUTE_PIPELINE_STATE_DESC &desc) { + if (desc.pRootSignature) { + m_root_sig = desc.pRootSignature; + m_root_sig->AddRef(); + } + if (desc.CS.pShaderBytecode && desc.CS.BytecodeLength) { + m_cs.resize(desc.CS.BytecodeLength); + memcpy(m_cs.data(), desc.CS.pShaderBytecode, desc.CS.BytecodeLength); + } +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12PipelineState::QueryInterface(REFIID riid, void **ppvObject) { + if (!ppvObject) + return E_POINTER; + *ppvObject = nullptr; + + if (riid == IID_IUnknown || riid == IID_ID3D12Object || + riid == IID_ID3D12DeviceChild || riid == IID_ID3D12Pageable || + riid == IID_ID3D12PipelineState) { + *ppvObject = ref(this); + return S_OK; + } + return E_NOINTERFACE; +} + +ULONG STDMETHODCALLTYPE +MTLD3D12PipelineState::AddRef() { return ++m_refCount; } + +ULONG STDMETHODCALLTYPE MTLD3D12PipelineState::Release() { + uint32_t rc = --m_refCount; + if (!rc) + delete this; + return rc; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12PipelineState::GetPrivateData(REFGUID guid, UINT *data_size, + void *data) { + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12PipelineState::SetPrivateData(REFGUID guid, UINT data_size, + const void *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12PipelineState::SetPrivateDataInterface(REFGUID guid, + const IUnknown *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12PipelineState::SetName(LPCWSTR name) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12PipelineState::GetDevice(REFIID riid, void **device) { + return m_device->QueryInterface(riid, device); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12PipelineState::GetCachedBlob(ID3DBlob **blob) { + return E_NOTIMPL; +} + +} // namespace dxmt diff --git a/src/d3d12/d3d12_pipeline_state.hpp b/src/d3d12/d3d12_pipeline_state.hpp new file mode 100644 index 000000000..aa76fb3da --- /dev/null +++ b/src/d3d12/d3d12_pipeline_state.hpp @@ -0,0 +1,99 
@@ +#pragma once + +#include "com/com_pointer.hpp" +#include "d3d12.h" +#include "Metal.hpp" +#include "airconv_public.h" +#include +#include +#include +#include + +namespace dxmt { + +class MTLD3D12Device; + +struct CompiledShader { + sm50_shader_t handle = nullptr; + MTL_SHADER_REFLECTION reflection = {}; +}; + +class MTLD3D12PipelineState : public ID3D12PipelineState { +public: + MTLD3D12PipelineState(MTLD3D12Device *device, bool is_compute); + ~MTLD3D12PipelineState(); + + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, + void **ppvObject) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID guid, UINT *data_size, + void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateData(REFGUID guid, UINT data_size, + const void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface( + REFGUID guid, const IUnknown *data) override; + HRESULT STDMETHODCALLTYPE SetName(LPCWSTR name) override; + + HRESULT STDMETHODCALLTYPE GetDevice(REFIID riid, void **device) override; + HRESULT STDMETHODCALLTYPE GetCachedBlob(ID3DBlob **blob) override; + + void SetGraphicsDesc(const D3D12_GRAPHICS_PIPELINE_STATE_DESC &desc); + void SetComputeDesc(const D3D12_COMPUTE_PIPELINE_STATE_DESC &desc); + + bool Compile(); + + bool EnsureCompiled() { + if (!m_compiled) Compile(); + return m_compiled; + } + + bool IsCompute() const { return m_is_compute; } + bool IsCompiled() const { return m_compiled; } + + WMT::Reference GetRenderPSO() const { + return m_render_pso; + } + WMT::Reference GetComputePSO() const { + return m_compute_pso; + } + ID3D12RootSignature *GetRootSignature() const { return m_root_sig; } + struct WMTSize GetThreadgroupSize() const { + return {(uint64_t)m_threadgroup_size.width, (uint64_t)m_threadgroup_size.height, (uint64_t)m_threadgroup_size.depth}; + } + + static WMTPixelFormat DXGIToMTLPixelFormat(DXGI_FORMAT format); + +private: + bool 
CompileShader(const void *bytecode, SIZE_T size, ShaderType type, + const char *func_name, WMT::Reference &out_func); + + static std::mutex s_shader_mutex; + static std::unordered_map> s_shader_cache; + + MTLD3D12Device *m_device; + bool m_is_compute; + bool m_compiled = false; + ID3D12RootSignature *m_root_sig = nullptr; + std::vector m_vs, m_ps, m_gs, m_hs, m_ds, m_cs; + D3D12_BLEND_DESC m_blend_desc = {}; + D3D12_RASTERIZER_DESC m_rasterizer_desc = {}; + D3D12_DEPTH_STENCIL_DESC m_depth_stencil_desc = {}; + D3D12_INPUT_LAYOUT_DESC m_input_layout = {}; + D3D12_INDEX_BUFFER_STRIP_CUT_VALUE m_strip_cut_value = {}; + D3D12_PRIMITIVE_TOPOLOGY_TYPE m_topology = {}; + UINT m_num_render_targets = 0; + DXGI_FORMAT m_rtv_formats[8] = {}; + DXGI_FORMAT m_dsv_format = DXGI_FORMAT_UNKNOWN; + UINT m_sample_mask = UINT_MAX; + UINT m_sample_count = 1; + + WMT::Reference m_render_pso; + WMT::Reference m_compute_pso; + struct { uint32_t width = 1, height = 1, depth = 1; } m_threadgroup_size; + + std::atomic m_refCount = {1ul}; +}; + +} // namespace dxmt diff --git a/src/d3d12/d3d12_query_heap.cpp b/src/d3d12/d3d12_query_heap.cpp new file mode 100644 index 000000000..97ad1866e --- /dev/null +++ b/src/d3d12/d3d12_query_heap.cpp @@ -0,0 +1,69 @@ +#include "d3d12_query_heap.hpp" +#include "d3d12_device.hpp" +#include "log/log.hpp" +#include "util_string.hpp" + +namespace dxmt { + +MTLD3D12QueryHeap::MTLD3D12QueryHeap(MTLD3D12Device *device, + const D3D12_QUERY_HEAP_DESC &desc) + : m_device(device), m_desc(desc) { + m_device->AddRef(); + m_data.resize(desc.Count, 0); + Logger::info(str::format("D3D12QueryHeap: type=", desc.Type, + " count=", desc.Count)); +} + +MTLD3D12QueryHeap::~MTLD3D12QueryHeap() { m_device->Release(); } + +HRESULT STDMETHODCALLTYPE +MTLD3D12QueryHeap::QueryInterface(REFIID riid, void **ppvObject) { + if (!ppvObject) + return E_POINTER; + *ppvObject = nullptr; + + if (riid == IID_IUnknown || riid == IID_ID3D12Object || + riid == IID_ID3D12DeviceChild || riid == 
IID_ID3D12Pageable || + riid == IID_ID3D12QueryHeap) { + *ppvObject = ref(this); + return S_OK; + } + return E_NOINTERFACE; +} + +ULONG STDMETHODCALLTYPE MTLD3D12QueryHeap::AddRef() { return ++m_refCount; } + +ULONG STDMETHODCALLTYPE MTLD3D12QueryHeap::Release() { + uint32_t rc = --m_refCount; + if (!rc) + delete this; + return rc; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12QueryHeap::GetPrivateData(REFGUID guid, UINT *data_size, void *data) { + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12QueryHeap::SetPrivateData(REFGUID guid, UINT data_size, + const void *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12QueryHeap::SetPrivateDataInterface(REFGUID guid, + const IUnknown *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12QueryHeap::SetName(LPCWSTR name) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12QueryHeap::GetDevice(REFIID riid, void **device) { + return m_device->QueryInterface(riid, device); +} + +} // namespace dxmt diff --git a/src/d3d12/d3d12_query_heap.hpp b/src/d3d12/d3d12_query_heap.hpp new file mode 100644 index 000000000..7dc198fb7 --- /dev/null +++ b/src/d3d12/d3d12_query_heap.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include "com/com_pointer.hpp" +#include "d3d12.h" +#include +#include + +namespace dxmt { + +class MTLD3D12Device; + +class MTLD3D12QueryHeap : public ID3D12QueryHeap { +public: + MTLD3D12QueryHeap(MTLD3D12Device *device, + const D3D12_QUERY_HEAP_DESC &desc); + ~MTLD3D12QueryHeap(); + + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, + void **ppvObject) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID guid, UINT *data_size, + void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateData(REFGUID guid, UINT data_size, + const void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface( + REFGUID guid, const IUnknown *data) override; + HRESULT STDMETHODCALLTYPE 
SetName(LPCWSTR name) override; + + HRESULT STDMETHODCALLTYPE GetDevice(REFIID riid, void **device) override; + + D3D12_QUERY_HEAP_TYPE GetType() const { return m_desc.Type; } + UINT GetCount() const { return m_desc.Count; } + uint64_t *GetData() { return m_data.data(); } + +private: + MTLD3D12Device *m_device; + D3D12_QUERY_HEAP_DESC m_desc; + std::vector m_data; + std::atomic m_refCount = {1ul}; +}; + +} // namespace dxmt diff --git a/src/d3d12/d3d12_resource.cpp b/src/d3d12/d3d12_resource.cpp new file mode 100644 index 000000000..0c0cd7bc6 --- /dev/null +++ b/src/d3d12/d3d12_resource.cpp @@ -0,0 +1,316 @@ +#include "d3d12_resource.hpp" +#include "d3d12_device.hpp" +#include "d3d12_pipeline_state.hpp" +#include "log/log.hpp" +#include "util_string.hpp" + +#define RTRACE(fmt, ...) do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); if (_tf) { fprintf(_tf, "Resource::" fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) + +namespace dxmt { + +MTLD3D12Resource::MTLD3D12Resource( + MTLD3D12Device *device, const D3D12_RESOURCE_DESC &desc, + D3D12_RESOURCE_STATES initial_state, + D3D12_HEAP_PROPERTIES heap_properties) + : m_device(device), m_desc(desc), m_state(initial_state), + m_heap_properties(heap_properties) { + m_device->AddRef(); + + auto wmt_device = m_device->GetDXMTDevice().device(); + RTRACE("ctor: wmt_device=%llu dim=%u fmt=%u w=%llu h=%u depth_or_arr=%u", + (unsigned long long)wmt_device.handle, desc.Dimension, desc.Format, + desc.Width, desc.Height, desc.DepthOrArraySize); + + if (desc.Dimension == D3D12_RESOURCE_DIMENSION_BUFFER) { + WMTBufferInfo buf_info = {}; + buf_info.length = desc.Width ? 
desc.Width : 256; + buf_info.options = WMTResourceStorageModeShared; + m_mtl_buffer = wmt_device.newBuffer(buf_info); + m_cpu_addr = buf_info.memory.get_accessible_or_null(); + m_gpu_addr = buf_info.gpu_address; + m_buf_info = buf_info; + RTRACE("ctor: buffer cpu=%p gpu=0x%llx len=%llu", m_cpu_addr, (unsigned long long)m_gpu_addr, (unsigned long long)desc.Width); + } else { + bool cpu_accessible = (heap_properties.Type == D3D12_HEAP_TYPE_UPLOAD || + heap_properties.Type == D3D12_HEAP_TYPE_READBACK); + WMTTextureInfo tex_info = {}; + tex_info.width = desc.Width; + tex_info.height = desc.Height; + tex_info.depth = (desc.Dimension == D3D12_RESOURCE_DIMENSION_TEXTURE3D) + ? desc.DepthOrArraySize + : 1; + tex_info.array_length = + (desc.Dimension == D3D12_RESOURCE_DIMENSION_TEXTURE2D) + ? desc.DepthOrArraySize + : 1; + tex_info.mipmap_level_count = desc.MipLevels ? desc.MipLevels : 1; + tex_info.sample_count = desc.SampleDesc.Count ? desc.SampleDesc.Count : 1; + switch (desc.Dimension) { + case D3D12_RESOURCE_DIMENSION_TEXTURE1D: + tex_info.type = (desc.DepthOrArraySize > 1) ? WMTTextureType1DArray : WMTTextureType1D; + break; + case D3D12_RESOURCE_DIMENSION_TEXTURE3D: + tex_info.type = WMTTextureType3D; + break; + default: + tex_info.type = (desc.DepthOrArraySize > 1) ? WMTTextureType2DArray : WMTTextureType2D; + break; + } + tex_info.usage = (WMTTextureUsage)(WMTTextureUsageRenderTarget | + WMTTextureUsageShaderRead | + WMTTextureUsageShaderWrite); + tex_info.options = cpu_accessible ? 
WMTResourceStorageModeShared : WMTResourceStorageModePrivate; + tex_info.pixel_format = MTLD3D12PipelineState::DXGIToMTLPixelFormat(static_cast(desc.Format)); + if (tex_info.pixel_format == WMTPixelFormatInvalid) + tex_info.pixel_format = WMTPixelFormatBGRA8Unorm; + + RTRACE("ctor: about to newTexture type=%u fmt=%u %ux%u depth=%u arr=%u mip=%u sample=%u opts=%u", + tex_info.type, tex_info.pixel_format, (unsigned)tex_info.width, (unsigned)tex_info.height, + (unsigned)tex_info.depth, (unsigned)tex_info.array_length, + (unsigned)tex_info.mipmap_level_count, (unsigned)tex_info.sample_count, (unsigned)tex_info.options); + m_mtl_texture = wmt_device.newTexture(tex_info); + if (!m_mtl_texture.handle) { + RTRACE("ctor: texture creation FAILED type=%u fmt=%u %ux%u arr=%u", + tex_info.type, tex_info.pixel_format, (unsigned)tex_info.width, (unsigned)tex_info.height, (unsigned)tex_info.array_length); + } else { + RTRACE("ctor: texture created fmt=%u %ux%u arr=%u handle=%llu %s", + tex_info.pixel_format, (unsigned)tex_info.width, (unsigned)tex_info.height, (unsigned)tex_info.array_length, + (unsigned long long)m_mtl_texture.handle, cpu_accessible ? 
"cpu" : "gpu"); + } + { + uint64_t fake_size = (uint64_t)tex_info.width * tex_info.height * 4; + if (fake_size < 256) fake_size = 256; + WMTBufferInfo fake_buf = {}; + fake_buf.length = fake_size; + fake_buf.options = WMTResourceStorageModeShared; + m_fake_buffer = wmt_device.newBuffer(fake_buf); + m_gpu_addr = fake_buf.gpu_address; + RTRACE("ctor: texture fake gpu_addr=0x%llx (from fake buffer %llu bytes)", (unsigned long long)m_gpu_addr, (unsigned long long)fake_size); + } + } + + Logger::info(str::format("D3D12Resource: dim=", desc.Dimension, + " ", desc.Width, "x", desc.Height, + " gpu=", m_gpu_addr)); + m_device->RegisterResource(this); +} + +WMT::Reference MTLD3D12Resource::GetMTLTexture() { + if (!m_mtl_texture.handle && m_desc.Dimension != D3D12_RESOURCE_DIMENSION_BUFFER) { + bool cpu_accessible = (m_heap_properties.Type == D3D12_HEAP_TYPE_UPLOAD || + m_heap_properties.Type == D3D12_HEAP_TYPE_READBACK); + auto wmt_device = m_device->GetDXMTDevice().device(); + WMTTextureInfo tex_info = {}; + tex_info.width = m_desc.Width ? m_desc.Width : 1; + tex_info.height = m_desc.Height ? m_desc.Height : 1; + tex_info.depth = (m_desc.Dimension == D3D12_RESOURCE_DIMENSION_TEXTURE3D) + ? m_desc.DepthOrArraySize : 1; + tex_info.array_length = (m_desc.Dimension == D3D12_RESOURCE_DIMENSION_TEXTURE2D) + ? m_desc.DepthOrArraySize : 1; + tex_info.mipmap_level_count = m_desc.MipLevels ? m_desc.MipLevels : 1; + tex_info.sample_count = m_desc.SampleDesc.Count ? m_desc.SampleDesc.Count : 1; + switch (m_desc.Dimension) { + case D3D12_RESOURCE_DIMENSION_TEXTURE1D: + tex_info.type = (m_desc.DepthOrArraySize > 1) ? WMTTextureType1DArray : WMTTextureType1D; + break; + case D3D12_RESOURCE_DIMENSION_TEXTURE3D: + tex_info.type = WMTTextureType3D; + break; + default: + tex_info.type = (m_desc.DepthOrArraySize > 1) ? 
WMTTextureType2DArray : WMTTextureType2D; + break; + } + tex_info.usage = (WMTTextureUsage)(WMTTextureUsageRenderTarget | WMTTextureUsageShaderRead | WMTTextureUsageShaderWrite); + tex_info.options = cpu_accessible ? WMTResourceStorageModeShared : WMTResourceStorageModePrivate; + tex_info.pixel_format = MTLD3D12PipelineState::DXGIToMTLPixelFormat(static_cast(m_desc.Format)); + if (tex_info.pixel_format == WMTPixelFormatInvalid) + tex_info.pixel_format = WMTPixelFormatBGRA8Unorm; + RTRACE("GetMTLTexture: creating type=%u fmt=%u %ux%ux%u arr=%u mip=%u sample=%u opts=%u", + tex_info.type, tex_info.pixel_format, (unsigned)tex_info.width, (unsigned)tex_info.height, + (unsigned)tex_info.depth, (unsigned)tex_info.array_length, (unsigned)tex_info.mipmap_level_count, + (unsigned)tex_info.sample_count, (unsigned)tex_info.options); + m_mtl_texture = wmt_device.newTexture(tex_info); + if (!m_mtl_texture.handle) { + RTRACE("GetMTLTexture: newTexture returned NULL handle"); + return m_mtl_texture; + } + RTRACE("GetMTLTexture: handle=%llu", (unsigned long long)m_mtl_texture.handle); + } + return m_mtl_texture; +} + +MTLD3D12Resource::~MTLD3D12Resource() { + m_device->UnregisterResource(this); + m_mtl_buffer = nullptr; + m_mtl_texture = nullptr; + m_device->Release(); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Resource::QueryInterface(REFIID riid, void **ppvObject) { + if (!ppvObject) + return E_POINTER; + *ppvObject = nullptr; + + if (riid == IID_IUnknown || riid == IID_ID3D12Object || + riid == IID_ID3D12DeviceChild || riid == IID_ID3D12Pageable || + riid == IID_ID3D12Resource || riid == IID_ID3D12Resource1 || + riid == IID_ID3D12Resource2) { + *ppvObject = ref(this); + return S_OK; + } + RTRACE("QI unknown IID %s -> E_NOINTERFACE", str::format(riid).c_str()); + return E_NOINTERFACE; +} + +ULONG STDMETHODCALLTYPE MTLD3D12Resource::AddRef() { return ++m_refCount; } + +ULONG STDMETHODCALLTYPE MTLD3D12Resource::Release() { + uint32_t rc = --m_refCount; + if (!rc) { + uint32_t rp = 
--m_refPrivate; + if (!rp) { + m_refPrivate += 0x80000000; + delete this; + } + } + return rc; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Resource::GetPrivateData(REFGUID guid, UINT *data_size, void *data) { + RTRACE("GetPrivateData E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Resource::SetPrivateData(REFGUID guid, UINT data_size, + const void *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Resource::SetPrivateDataInterface(REFGUID guid, const IUnknown *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Resource::SetName(LPCWSTR name) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Resource::GetDevice(REFIID riid, void **device) { + return m_device->QueryInterface(riid, device); +} +HRESULT STDMETHODCALLTYPE +MTLD3D12Resource::Map(UINT sub_resource, + const D3D12_RANGE *read_range, + void **data) { + RTRACE("Map sub=%u", sub_resource); + if (!data) + return E_POINTER; + if (m_desc.Dimension > D3D12_RESOURCE_DIMENSION_TEXTURE3D) { + RTRACE("Map: invalid dimension, returning fake pointer"); + *data = (void*)1; + return S_OK; + } + if (m_cpu_addr) { + *data = m_cpu_addr; + RTRACE("Map returning cpu_addr=%p gpu_addr=0x%llx", m_cpu_addr, (unsigned long long)m_gpu_addr); + return S_OK; + } + RTRACE("Map FAILED - no cpu_addr"); + return E_FAIL; +} + +void STDMETHODCALLTYPE MTLD3D12Resource::Unmap( + UINT sub_resource, const D3D12_RANGE *written_range) {} + +D3D12_RESOURCE_DESC *STDMETHODCALLTYPE +MTLD3D12Resource::GetDesc(D3D12_RESOURCE_DESC *__ret) { + *__ret = m_desc; + return __ret; +} + +D3D12_GPU_VIRTUAL_ADDRESS STDMETHODCALLTYPE +MTLD3D12Resource::GetGPUVirtualAddress() { + RTRACE("GetGPUVirtualAddress -> 0x%llx this=%p is_buffer=%d", (unsigned long long)m_gpu_addr, (void*)this, IsBuffer()); + return m_gpu_addr; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Resource::WriteToSubresource( + UINT dst_sub_resource, const D3D12_BOX *dst_box, const void *src_data, + UINT src_row_pitch, UINT src_slice_pitch) { + 
RTRACE("WriteToSubresource sub=%u box=%p", dst_sub_resource, dst_box); + if (!src_data) + return E_POINTER; + if (m_desc.Dimension > D3D12_RESOURCE_DIMENSION_TEXTURE3D) { + return S_OK; + } + if (m_cpu_addr) { + if (dst_box) { + UINT rows = dst_box->bottom - dst_box->top; + UINT depth = dst_box->back - dst_box->front; + for (UINT z = 0; z < depth; z++) { + for (UINT y = 0; y < rows; y++) { + memcpy((char *)m_cpu_addr + (dst_box->front + z) * src_slice_pitch + (dst_box->top + y) * src_row_pitch + dst_box->left, + (char *)src_data + z * src_slice_pitch + y * src_row_pitch, + dst_box->right - dst_box->left); + } + } + } else { + memcpy(m_cpu_addr, src_data, src_slice_pitch ? src_slice_pitch : src_row_pitch); + } + return S_OK; + } + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12Resource::ReadFromSubresource( + void *dst_data, UINT dst_row_pitch, UINT dst_slice_pitch, + UINT src_sub_resource, const D3D12_BOX *src_box) { + void *vtable = *(void**)this; + RTRACE("ReadFromSubresource dst=%p row_pitch=%u slice_pitch=%u sub=%u box=%p dim=%u this=%p vtable=%p", dst_data, dst_row_pitch, dst_slice_pitch, src_sub_resource, src_box, m_desc.Dimension, (void*)this, vtable); + if (!dst_data) + return E_POINTER; + if (m_desc.Dimension > D3D12_RESOURCE_DIMENSION_TEXTURE3D) { + RTRACE("ReadFromSubresource: invalid dimension %u, this=%p is NOT a resource! Skipping.", m_desc.Dimension, (void*)this); + return S_OK; + } + if (m_cpu_addr) { + UINT rows = m_desc.Height ? 
m_desc.Height : 1; + if (src_box) { + UINT copy_rows = src_box->bottom - src_box->top; + UINT copy_depth = src_box->back - src_box->front; + UINT copy_width = src_box->right - src_box->left; + for (UINT z = 0; z < copy_depth; z++) { + for (UINT y = 0; y < copy_rows; y++) { + memcpy((char *)dst_data + z * dst_slice_pitch + y * dst_row_pitch, + (char *)m_cpu_addr + (src_box->front + z) * rows * dst_row_pitch + (src_box->top + y) * dst_row_pitch + src_box->left, + copy_width); + } + } + } else { + memcpy(dst_data, m_cpu_addr, dst_slice_pitch ? dst_slice_pitch : dst_row_pitch); + } + return S_OK; + } + if (m_desc.Dimension != D3D12_RESOURCE_DIMENSION_BUFFER) { + UINT total = dst_slice_pitch ? dst_slice_pitch : dst_row_pitch * (m_desc.Height ? m_desc.Height : 1); + if (total) memset(dst_data, 0, total); + return S_OK; + } + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12Resource::GetHeapProperties(D3D12_HEAP_PROPERTIES *heap_properties, + D3D12_HEAP_FLAGS *flags) { + if (heap_properties) + *heap_properties = m_heap_properties; + if (flags) + *flags = D3D12_HEAP_FLAG_NONE; + return S_OK; +} + +} // namespace dxmt diff --git a/src/d3d12/d3d12_resource.hpp b/src/d3d12/d3d12_resource.hpp new file mode 100644 index 000000000..79989b2b4 --- /dev/null +++ b/src/d3d12/d3d12_resource.hpp @@ -0,0 +1,77 @@ +#pragma once + +#include "com/com_pointer.hpp" +#include "d3d12.h" +#include "Metal.hpp" +#include "winemetal.h" +#include + +namespace dxmt { + +class MTLD3D12Device; + +class MTLD3D12Resource : public ID3D12Resource { +public: + MTLD3D12Resource(MTLD3D12Device *device, const D3D12_RESOURCE_DESC &desc, + D3D12_RESOURCE_STATES initial_state, + D3D12_HEAP_PROPERTIES heap_properties); + ~MTLD3D12Resource(); + + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, + void **ppvObject) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID guid, UINT *data_size, + void *data) 
override; + HRESULT STDMETHODCALLTYPE SetPrivateData(REFGUID guid, UINT data_size, + const void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface( + REFGUID guid, const IUnknown *data) override; + HRESULT STDMETHODCALLTYPE SetName(LPCWSTR name) override; + + HRESULT STDMETHODCALLTYPE GetDevice(REFIID riid, void **device) override; + + HRESULT STDMETHODCALLTYPE Map(UINT sub_resource, + const D3D12_RANGE *read_range, + void **data) override; + void STDMETHODCALLTYPE Unmap(UINT sub_resource, + const D3D12_RANGE *written_range) override; + D3D12_RESOURCE_DESC *STDMETHODCALLTYPE + GetDesc(D3D12_RESOURCE_DESC *__ret) override; + D3D12_GPU_VIRTUAL_ADDRESS STDMETHODCALLTYPE + GetGPUVirtualAddress() override; + HRESULT STDMETHODCALLTYPE WriteToSubresource( + UINT dst_sub_resource, const D3D12_BOX *dst_box, const void *src_data, + UINT src_row_pitch, UINT src_slice_pitch) override; + HRESULT STDMETHODCALLTYPE ReadFromSubresource( + void *dst_data, UINT dst_row_pitch, UINT dst_slice_pitch, + UINT src_sub_resource, const D3D12_BOX *src_box) override; + HRESULT STDMETHODCALLTYPE + GetHeapProperties(D3D12_HEAP_PROPERTIES *heap_properties, + D3D12_HEAP_FLAGS *flags) override; + + bool IsBuffer() const { + return m_desc.Dimension == D3D12_RESOURCE_DIMENSION_BUFFER; + } + + WMT::Reference GetMTLBuffer() { return m_mtl_buffer; } + WMT::Reference GetMTLTexture(); + +private: + MTLD3D12Device *m_device; + D3D12_RESOURCE_DESC m_desc; + D3D12_RESOURCE_STATES m_state; + D3D12_HEAP_PROPERTIES m_heap_properties; + WMTBufferInfo m_buf_info = {}; + WMT::Reference m_mtl_buffer; + WMT::Reference m_mtl_texture; + WMT::Reference m_fake_buffer; + + void *m_cpu_addr = nullptr; + uint64_t m_gpu_addr = 0; + std::atomic m_refCount = {1ul}; + std::atomic m_refPrivate = {1ul}; +}; + +} // namespace dxmt diff --git a/src/d3d12/d3d12_root_signature.cpp b/src/d3d12/d3d12_root_signature.cpp new file mode 100644 index 000000000..00095f1bb --- /dev/null +++ 
b/src/d3d12/d3d12_root_signature.cpp @@ -0,0 +1,149 @@ +#include "d3d12_root_signature.hpp" +#include "d3d12_device.hpp" +#include "log/log.hpp" +#include "util_string.hpp" +#include + +namespace dxmt { + +#pragma pack(push, 1) +struct RSHeader { + uint32_t num_parameters; + uint32_t num_static_samplers; + uint32_t flags; +}; + +struct RSParameter { + uint8_t type; + uint8_t visibility; + union { + struct { + uint32_t register_space; + uint32_t register_index; + uint32_t num_32bit_values; + } constants; + struct { + uint32_t register_space; + uint32_t register_index; + } descriptor; + struct { + uint32_t num_ranges; + } table; + }; +}; + +struct RSDescriptorRange { + uint8_t range_type; + uint32_t num_descriptors; + uint32_t base_register; + uint32_t register_space; + uint32_t offset_in_table; +}; +#pragma pack(pop) + +MTLD3D12RootSignature::MTLD3D12RootSignature(MTLD3D12Device *device, + const void *blob, SIZE_T blob_size) + : m_device(device) { + m_device->AddRef(); + Parse(blob, blob_size); + Logger::info(str::format("D3D12RootSignature: ", m_parameters.size(), + " params, ", m_num_static_samplers, + " static samplers, flags=", m_flags)); +} + +MTLD3D12RootSignature::~MTLD3D12RootSignature() { m_device->Release(); } + +void MTLD3D12RootSignature::Parse(const void *blob, SIZE_T blob_size) { + if (blob_size < sizeof(RSHeader)) + return; + + auto header = static_cast(blob); + m_num_static_samplers = header->num_static_samplers; + m_flags = static_cast(header->flags); + + auto params = reinterpret_cast(blob) + sizeof(RSHeader); + for (uint32_t i = 0; i < header->num_parameters; i++) { + auto p = reinterpret_cast(params); + RootParameter rp = {}; + rp.type = static_cast(p->type); + rp.shader_visibility = p->visibility; + + if (p->type == D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS) { + rp.register_space = p->constants.register_space; + rp.register_index = p->constants.register_index; + } else if (p->type == D3D12_ROOT_PARAMETER_TYPE_CBV || + p->type == 
D3D12_ROOT_PARAMETER_TYPE_SRV || + p->type == D3D12_ROOT_PARAMETER_TYPE_UAV) { + rp.register_space = p->descriptor.register_space; + rp.register_index = p->descriptor.register_index; + } else if (p->type == D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE) { + auto ranges = reinterpret_cast( + params + sizeof(RSParameter)); + rp.descriptor_table_entries = p->table.num_ranges; + if (p->table.num_ranges > 0) { + rp.range_type = + static_cast(ranges[0].range_type); + rp.num_descriptors = ranges[0].num_descriptors; + rp.register_space = ranges[0].register_space; + rp.register_index = ranges[0].base_register; + } + params += p->table.num_ranges * sizeof(RSDescriptorRange); + } + m_parameters.push_back(rp); + params += sizeof(RSParameter); + } +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12RootSignature::QueryInterface(REFIID riid, void **ppvObject) { + if (!ppvObject) + return E_POINTER; + *ppvObject = nullptr; + + if (riid == IID_IUnknown || riid == IID_ID3D12Object || + riid == IID_ID3D12DeviceChild || riid == IID_ID3D12RootSignature) { + *ppvObject = ref(this); + return S_OK; + } + return E_NOINTERFACE; +} + +ULONG STDMETHODCALLTYPE +MTLD3D12RootSignature::AddRef() { return ++m_refCount; } + +ULONG STDMETHODCALLTYPE MTLD3D12RootSignature::Release() { + uint32_t rc = --m_refCount; + if (!rc) + delete this; + return rc; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12RootSignature::GetPrivateData(REFGUID guid, UINT *data_size, + void *data) { + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12RootSignature::SetPrivateData(REFGUID guid, UINT data_size, + const void *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12RootSignature::SetPrivateDataInterface(REFGUID guid, + const IUnknown *data) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12RootSignature::SetName(LPCWSTR name) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12RootSignature::GetDevice(REFIID riid, void **device) { + return m_device->QueryInterface(riid, device); +} + +} // namespace 
dxmt diff --git a/src/d3d12/d3d12_root_signature.hpp b/src/d3d12/d3d12_root_signature.hpp new file mode 100644 index 000000000..073212956 --- /dev/null +++ b/src/d3d12/d3d12_root_signature.hpp @@ -0,0 +1,60 @@ +#pragma once + +#include "com/com_pointer.hpp" +#include "d3d12.h" +#include +#include + +namespace dxmt { + +class MTLD3D12Device; + +struct RootParameter { + D3D12_ROOT_PARAMETER_TYPE type; + uint32_t shader_visibility; + uint32_t register_space; + uint32_t register_index; + uint32_t num_descriptors; + D3D12_DESCRIPTOR_RANGE_TYPE range_type; + uint32_t descriptor_table_entries; +}; + +class MTLD3D12RootSignature : public ID3D12RootSignature { +public: + MTLD3D12RootSignature(MTLD3D12Device *device, const void *blob, + SIZE_T blob_size); + ~MTLD3D12RootSignature(); + + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, + void **ppvObject) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID guid, UINT *data_size, + void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateData(REFGUID guid, UINT data_size, + const void *data) override; + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface( + REFGUID guid, const IUnknown *data) override; + HRESULT STDMETHODCALLTYPE SetName(LPCWSTR name) override; + + HRESULT STDMETHODCALLTYPE GetDevice(REFIID riid, void **device) override; + + const std::vector &GetParameters() const { + return m_parameters; + } + uint32_t GetNumParameters() const { return m_parameters.size(); } + uint32_t GetNumStaticSamplers() const { return m_num_static_samplers; } + D3D12_ROOT_SIGNATURE_FLAGS GetFlags() const { return m_flags; } + +private: + void Parse(const void *blob, SIZE_T blob_size); + + MTLD3D12Device *m_device; + std::vector m_parameters; + uint32_t m_num_static_samplers = 0; + D3D12_ROOT_SIGNATURE_FLAGS m_flags = D3D12_ROOT_SIGNATURE_FLAG_NONE; + std::atomic m_refCount = {1ul}; +}; + +} // namespace dxmt diff --git 
a/src/d3d12/d3d12_swapchain.cpp b/src/d3d12/d3d12_swapchain.cpp new file mode 100644 index 000000000..683c81406 --- /dev/null +++ b/src/d3d12/d3d12_swapchain.cpp @@ -0,0 +1,397 @@ +#include "d3d12_swapchain.hpp" +#include "d3d12_device.hpp" +#include "d3d12_resource.hpp" +#include "log/log.hpp" +#include "util_string.hpp" +#include "Metal.hpp" + +#define SCTRACE(fmt, ...) do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); if (_tf) { fprintf(_tf, "SwapChain::" fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) + +static uint64_t g_sc_enc_id = 0; + +namespace dxmt { + +static WMTPixelFormat DXGIToMTL(DXGI_FORMAT fmt) { + switch (fmt) { + case DXGI_FORMAT_R8G8B8A8_UNORM: return WMTPixelFormatRGBA8Unorm; + case DXGI_FORMAT_R8G8B8A8_UNORM_SRGB: return WMTPixelFormatRGBA8Unorm_sRGB; + case DXGI_FORMAT_B8G8R8A8_UNORM: return WMTPixelFormatBGRA8Unorm; + case DXGI_FORMAT_B8G8R8A8_UNORM_SRGB: return WMTPixelFormatBGRA8Unorm_sRGB; + case DXGI_FORMAT_R16G16B16A16_FLOAT: return WMTPixelFormatRGBA16Float; + case DXGI_FORMAT_R10G10B10A2_UNORM: return WMTPixelFormatRGB10A2Unorm; + default: return WMTPixelFormatBGRA8Unorm; + } +} + +MTLD3D12SwapChain::MTLD3D12SwapChain( + IDXGIFactory1 *factory, MTLD3D12Device *device, + IMTLDXGIDevice *dxgi_device, HWND hWnd, + const DXGI_SWAP_CHAIN_DESC1 *desc, + const DXGI_SWAP_CHAIN_FULLSCREEN_DESC *fs_desc) + : m_factory(factory), m_dxgi_device(dxgi_device), m_device(device), + m_hwnd(hWnd), m_desc(*desc) { + if (m_factory) + m_factory->AddRef(); + if (m_dxgi_device) + m_dxgi_device->AddRef(); + if (m_device) + m_device->AddRef(); + + if (fs_desc) { + m_fs_desc = *fs_desc; + } else { + m_fs_desc = {}; + m_fs_desc.Windowed = true; + } + + m_native_view = WMT::CreateMetalViewFromHWND( + (intptr_t)hWnd, dxgi_device->GetMTLDevice(), m_layer); + + auto wmt_dev = dxgi_device->GetMTLDevice(); + m_present_queue = wmt_dev.newCommandQueue(1); + + ResizeBuffers(0, m_desc.Width, m_desc.Height, m_desc.Format, m_desc.Flags); + 
Logger::info(str::format("D3D12SwapChain: ", m_desc.Width, "x", m_desc.Height, + " fmt=", m_desc.Format, " hwnd=", (void*)hWnd)); +} + +MTLD3D12SwapChain::~MTLD3D12SwapChain() { + for (uint32_t i = 0; i < 4; i++) + m_backbuffers[i] = nullptr; + if (m_native_view.handle) + WMT::ReleaseMetalView(m_native_view); + if (m_device) + m_device->Release(); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::QueryInterface(REFIID riid, void **ppv) { + if (!ppv) + return E_POINTER; + *ppv = nullptr; + + if (riid == __uuidof(IUnknown) || riid == __uuidof(IDXGIObject) || + riid == __uuidof(IDXGIDeviceSubObject) || + riid == __uuidof(IDXGISwapChain) || riid == __uuidof(IDXGISwapChain1) || + riid == __uuidof(IDXGISwapChain2) || riid == __uuidof(IDXGISwapChain3) || + riid == __uuidof(IDXGISwapChain4)) { + *ppv = ref(this); + return S_OK; + } + return E_NOINTERFACE; +} + +ULONG STDMETHODCALLTYPE MTLD3D12SwapChain::AddRef() { return ++m_refCount; } + +ULONG STDMETHODCALLTYPE MTLD3D12SwapChain::Release() { + uint32_t rc = --m_refCount; + if (!rc) + delete this; + return rc; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetPrivateData(REFGUID Name, UINT *pDataSize, void *pData) { + SCTRACE("GetPrivateData E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::SetPrivateData(REFGUID Name, UINT DataSize, const void *pData) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::SetPrivateDataInterface(REFGUID Name, const IUnknown *pUnknown) { + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetDevice(REFIID riid, void **device) { + return m_device->QueryInterface(riid, device); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetParent(REFIID riid, void **ppParent) { + return m_factory->QueryInterface(riid, ppParent); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::Present(UINT sync_interval, UINT flags) { + SCTRACE("Present sync=%u flags=0x%x", sync_interval, flags); + return Present1(sync_interval, flags, 
nullptr); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetBuffer(UINT buffer_idx, REFIID riid, void **surface) { + SCTRACE("GetBuffer idx=%u", buffer_idx); + if (!surface) + return E_POINTER; + if (buffer_idx >= 4 || !m_backbuffers[buffer_idx]) { + SCTRACE("GetBuffer idx=%u FAILED (no buffer)", buffer_idx); + return DXGI_ERROR_INVALID_CALL; + } + return m_backbuffers[buffer_idx]->QueryInterface(riid, surface); +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::SetFullscreenState(BOOL fullscreen, IDXGIOutput *target) { + m_fs_desc.Windowed = !fullscreen; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetFullscreenState(BOOL *fullscreen, IDXGIOutput **target) { + if (fullscreen) + *fullscreen = !m_fs_desc.Windowed; + if (target) + *target = nullptr; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetDesc(DXGI_SWAP_CHAIN_DESC *desc) { + desc->BufferDesc.Width = m_desc.Width; + desc->BufferDesc.Height = m_desc.Height; + desc->BufferDesc.RefreshRate = m_fs_desc.RefreshRate; + desc->BufferDesc.Format = m_desc.Format; + desc->BufferDesc.ScanlineOrdering = m_fs_desc.ScanlineOrdering; + desc->BufferDesc.Scaling = m_fs_desc.Scaling; + desc->SampleDesc = m_desc.SampleDesc; + desc->BufferUsage = m_desc.BufferUsage; + desc->BufferCount = m_desc.BufferCount; + desc->OutputWindow = m_hwnd; + desc->Windowed = m_fs_desc.Windowed; + desc->SwapEffect = m_desc.SwapEffect; + desc->Flags = m_desc.Flags; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::ResizeBuffers(UINT buffer_count, UINT width, UINT height, + DXGI_FORMAT format, UINT flags) { + SCTRACE("ResizeBuffers count=%u w=%u h=%u fmt=%u flags=0x%x (old: w=%u h=%u)", + buffer_count, width, height, (unsigned)format, flags, m_desc.Width, m_desc.Height); + for (uint32_t i = 0; i < 4; i++) + m_backbuffers[i] = nullptr; + + if (buffer_count) + m_desc.BufferCount = buffer_count; + if (format != DXGI_FORMAT_UNKNOWN) + m_desc.Format = format; + + if (width == 0 || height == 
0) { + RECT rect; + if (GetClientRect(m_hwnd, &rect)) { + width = rect.right - rect.left; + height = rect.bottom - rect.top; + } + if (width == 0) + width = 1; + if (height == 0) + height = 1; + } + m_desc.Width = width; + m_desc.Height = height; + + D3D12_RESOURCE_DESC res_desc = {}; + res_desc.Dimension = D3D12_RESOURCE_DIMENSION_TEXTURE2D; + res_desc.Width = m_desc.Width; + res_desc.Height = m_desc.Height; + res_desc.DepthOrArraySize = 1; + res_desc.MipLevels = 1; + res_desc.Format = m_desc.Format; + res_desc.SampleDesc.Count = 1; + res_desc.SampleDesc.Quality = 0; + res_desc.Layout = D3D12_TEXTURE_LAYOUT_UNKNOWN; + res_desc.Flags = D3D12_RESOURCE_FLAG_ALLOW_RENDER_TARGET; + + D3D12_HEAP_PROPERTIES heap_props = {}; + heap_props.Type = D3D12_HEAP_TYPE_DEFAULT; + + uint32_t count = m_desc.BufferCount ? m_desc.BufferCount : 2; + if (count > 4) count = 4; + for (uint32_t i = 0; i < count; i++) { + m_backbuffers[i] = new MTLD3D12Resource(m_device, res_desc, + D3D12_RESOURCE_STATE_RENDER_TARGET, + heap_props); + } + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::ResizeTarget(const DXGI_MODE_DESC *new_target_params) { + SCTRACE("ResizeTarget w=%u h=%u", new_target_params ? new_target_params->Width : 0, new_target_params ? new_target_params->Height : 0); + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetContainingOutput(IDXGIOutput **output) { + if (!output) + return E_POINTER; + *output = nullptr; + Com adapter; + HRESULT hr = m_factory->EnumAdapters(0, &adapter); + if (FAILED(hr)) return hr; + hr = adapter->EnumOutputs(0, output); + SCTRACE("GetContainingOutput -> hr=0x%lx output=%p", hr, output ? 
*output : nullptr); + return hr; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetFrameStatistics(DXGI_FRAME_STATISTICS *stats) { + if (stats) + memset(stats, 0, sizeof(*stats)); + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetLastPresentCount(UINT *last_present_count) { + if (last_present_count) + *last_present_count = (UINT)m_present_count; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetDesc1(DXGI_SWAP_CHAIN_DESC1 *desc) { + *desc = m_desc; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetFullscreenDesc(DXGI_SWAP_CHAIN_FULLSCREEN_DESC *desc) { + *desc = m_fs_desc; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetHwnd(HWND *hWnd) { + if (hWnd) + *hWnd = m_hwnd; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetCoreWindow(REFIID riid, void **core_window) { + SCTRACE("GetCoreWindow E_NOTIMPL"); + return E_NOTIMPL; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::Present1(UINT sync_interval, UINT flags, + const DXGI_PRESENT_PARAMETERS *params) { + m_present_count++; + + if (!m_backbuffers[m_current_buffer]) { + SCTRACE("SwapChain::Present sync=%u flags=0x%x NO BACKBUFFER idx=%u", sync_interval, flags, m_current_buffer); + return S_OK; + } + + if (!m_present_queue.handle) { + auto wmt_device = m_dxgi_device->GetMTLDevice(); + m_present_queue = wmt_device.newCommandQueue(1); + } + + auto cmdbuf = m_present_queue.commandBuffer(); + + auto drawable = m_layer.nextDrawable(); + if (!drawable.handle) { + SCTRACE("SwapChain::Present sync=%u flags=0x%x NO DRAWABLE", sync_interval, flags); + cmdbuf.commit(); + return S_OK; + } + + auto dst_texture = drawable.texture(); + auto *res = static_cast(m_backbuffers[m_current_buffer].ptr()); + auto src_texture = res->GetMTLTexture(); + + SCTRACE("SwapChain::Present blit: idx=%u src=%p dst=%p w=%u h=%u", + m_current_buffer, src_texture.handle, dst_texture.handle, m_desc.Width, m_desc.Height); + + if (src_texture.handle && 
dst_texture.handle) { + auto blit = cmdbuf.blitCommandEncoder(); + uint64_t _sc_eid = __atomic_add_fetch(&g_sc_enc_id, 1, __ATOMIC_SEQ_CST); + SCTRACE("[SC_ENC+%llu] CREATE blit handle=%llu", (unsigned long long)_sc_eid, (unsigned long long)blit.handle); + struct wmtcmd_blit_copy_from_texture_to_texture copy = {}; + copy.type = WMTBlitCommandCopyFromTextureToTexture; + copy.next.set(nullptr); + copy.src = src_texture; + copy.src_slice = 0; + copy.src_level = 0; + copy.src_origin = {0, 0, 0}; + copy.src_size = {m_desc.Width, m_desc.Height, 1}; + copy.dst = dst_texture; + copy.dst_slice = 0; + copy.dst_level = 0; + copy.dst_origin = {0, 0, 0}; + blit.encodeCommands(reinterpret_cast(&copy)); + SCTRACE("[SC_ENC] END handle=%llu", (unsigned long long)blit.handle); + blit.endEncoding(); + } + + cmdbuf.presentDrawable(drawable); + SCTRACE("[SC_ENC] COMMIT cmdbuf=%llu", (unsigned long long)cmdbuf.handle); + cmdbuf.commit(); + + m_current_buffer = (m_current_buffer + 1) % (m_desc.BufferCount ? m_desc.BufferCount : 2); + + return S_OK; +} + +WINBOOL STDMETHODCALLTYPE MTLD3D12SwapChain::IsTemporaryMonoSupported() { + return false; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetRestrictToOutput(IDXGIOutput **output) { + *output = nullptr; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::SetBackgroundColor(const DXGI_RGBA *color) { return S_OK; } + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetBackgroundColor(DXGI_RGBA *color) { return S_OK; } + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::SetRotation(DXGI_MODE_ROTATION rotation) { return S_OK; } + +HRESULT STDMETHODCALLTYPE +MTLD3D12SwapChain::GetRotation(DXGI_MODE_ROTATION *rotation) { + if (rotation) + *rotation = DXGI_MODE_ROTATION_IDENTITY; + return S_OK; +} + +HRESULT STDMETHODCALLTYPE MTLD3D12SwapChain::SetSourceSize(UINT Width, UINT Height) { return S_OK; } +HRESULT STDMETHODCALLTYPE MTLD3D12SwapChain::GetSourceSize(UINT *pWidth, UINT *pHeight) { return S_OK; } +HRESULT STDMETHODCALLTYPE 
MTLD3D12SwapChain::SetMaximumFrameLatency(UINT MaxLatency) { return S_OK; } +HRESULT STDMETHODCALLTYPE MTLD3D12SwapChain::GetMaximumFrameLatency(UINT *pMaxLatency) { if (pMaxLatency) *pMaxLatency = 1; return S_OK; } +HANDLE STDMETHODCALLTYPE MTLD3D12SwapChain::GetFrameLatencyWaitableObject() { return nullptr; } +HRESULT STDMETHODCALLTYPE MTLD3D12SwapChain::SetMatrixTransform(const DXGI_MATRIX_3X2_F *pMatrix) { return S_OK; } +HRESULT STDMETHODCALLTYPE MTLD3D12SwapChain::GetMatrixTransform(DXGI_MATRIX_3X2_F *pMatrix) { return S_OK; } +UINT STDMETHODCALLTYPE MTLD3D12SwapChain::GetCurrentBackBufferIndex() { return m_current_buffer; } +HRESULT STDMETHODCALLTYPE MTLD3D12SwapChain::CheckColorSpaceSupport(DXGI_COLOR_SPACE_TYPE ColorSpace, UINT *pSupport) { if (pSupport) *pSupport = 0; return S_OK; } +HRESULT STDMETHODCALLTYPE MTLD3D12SwapChain::SetColorSpace1(DXGI_COLOR_SPACE_TYPE ColorSpace) { return S_OK; } +HRESULT STDMETHODCALLTYPE MTLD3D12SwapChain::ResizeBuffers1(UINT, UINT, UINT, DXGI_FORMAT, UINT, const UINT *, IUnknown *const *) { return S_OK; } +HRESULT STDMETHODCALLTYPE MTLD3D12SwapChain::SetHDRMetaData(DXGI_HDR_METADATA_TYPE, UINT, void *) { return S_OK; } + +HRESULT CreateD3D12SwapChain(IDXGIFactory1 *factory, MTLD3D12Device *device, + IMTLDXGIDevice *dxgi_device, HWND hWnd, + const DXGI_SWAP_CHAIN_DESC1 *desc, + const DXGI_SWAP_CHAIN_FULLSCREEN_DESC *fs_desc, + IDXGISwapChain1 **pp_swap_chain) { + if (!pp_swap_chain) + return E_POINTER; + *pp_swap_chain = nullptr; + + auto swapchain = new MTLD3D12SwapChain(factory, device, dxgi_device, hWnd, + desc, fs_desc); + HRESULT hr = swapchain->QueryInterface(IID_PPV_ARGS(pp_swap_chain)); + if (FAILED(hr)) + swapchain->Release(); + return hr; +} + +} // namespace dxmt diff --git a/src/d3d12/d3d12_swapchain.hpp b/src/d3d12/d3d12_swapchain.hpp new file mode 100644 index 000000000..39b9fd86a --- /dev/null +++ b/src/d3d12/d3d12_swapchain.hpp @@ -0,0 +1,105 @@ +#pragma once + +#include "com/com_pointer.hpp" +#include 
"d3d12.h" +#include "dxgi_interfaces.h" +#include "Metal.hpp" +#include +#include + +namespace dxmt { + +class MTLD3D12Device; +class MTLD3D12Resource; + +class MTLD3D12SwapChain : public IDXGISwapChain4 { +public: + MTLD3D12SwapChain(IDXGIFactory1 *factory, MTLD3D12Device *device, + IMTLDXGIDevice *dxgi_device, HWND hWnd, + const DXGI_SWAP_CHAIN_DESC1 *desc, + const DXGI_SWAP_CHAIN_FULLSCREEN_DESC *fs_desc); + ~MTLD3D12SwapChain(); + + /*** IUnknown ***/ + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void **ppv) override; + ULONG STDMETHODCALLTYPE AddRef() override; + ULONG STDMETHODCALLTYPE Release() override; + + /*** IDXGIObject ***/ + HRESULT STDMETHODCALLTYPE SetPrivateData(REFGUID Name, UINT DataSize, const void *pData) override; + HRESULT STDMETHODCALLTYPE SetPrivateDataInterface(REFGUID Name, const IUnknown *pUnknown) override; + HRESULT STDMETHODCALLTYPE GetPrivateData(REFGUID Name, UINT *pDataSize, void *pData) override; + HRESULT STDMETHODCALLTYPE GetParent(REFIID riid, void **ppParent) override; + + /*** IDXGIDeviceSubObject ***/ + HRESULT STDMETHODCALLTYPE GetDevice(REFIID riid, void **device) override; + + /*** IDXGISwapChain ***/ + HRESULT STDMETHODCALLTYPE Present(UINT SyncInterval, UINT Flags) override; + HRESULT STDMETHODCALLTYPE GetBuffer(UINT Buffer, REFIID riid, void **ppSurface) override; + HRESULT STDMETHODCALLTYPE SetFullscreenState(BOOL Fullscreen, IDXGIOutput *pTarget) override; + HRESULT STDMETHODCALLTYPE GetFullscreenState(BOOL *pFullscreen, IDXGIOutput **ppTarget) override; + HRESULT STDMETHODCALLTYPE GetDesc(DXGI_SWAP_CHAIN_DESC *pDesc) override; + HRESULT STDMETHODCALLTYPE ResizeBuffers(UINT BufferCount, UINT Width, UINT Height, DXGI_FORMAT NewFormat, UINT SwapChainFlags) override; + HRESULT STDMETHODCALLTYPE ResizeTarget(const DXGI_MODE_DESC *pNewTargetParameters) override; + HRESULT STDMETHODCALLTYPE GetContainingOutput(IDXGIOutput **ppOutput) override; + HRESULT STDMETHODCALLTYPE GetFrameStatistics(DXGI_FRAME_STATISTICS 
*pStats) override; + HRESULT STDMETHODCALLTYPE GetLastPresentCount(UINT *pLastPresentCount) override; + + /*** IDXGISwapChain1 ***/ + HRESULT STDMETHODCALLTYPE GetDesc1(DXGI_SWAP_CHAIN_DESC1 *pDesc) override; + HRESULT STDMETHODCALLTYPE GetFullscreenDesc(DXGI_SWAP_CHAIN_FULLSCREEN_DESC *pDesc) override; + HRESULT STDMETHODCALLTYPE GetHwnd(HWND *pHwnd) override; + HRESULT STDMETHODCALLTYPE GetCoreWindow(REFIID riid, void **ppUnk) override; + HRESULT STDMETHODCALLTYPE Present1(UINT SyncInterval, UINT PresentFlags, const DXGI_PRESENT_PARAMETERS *pPresentParameters) override; + + WINBOOL STDMETHODCALLTYPE IsTemporaryMonoSupported() override; + HRESULT STDMETHODCALLTYPE GetRestrictToOutput(IDXGIOutput **ppOutput) override; + HRESULT STDMETHODCALLTYPE SetBackgroundColor(const DXGI_RGBA *pColor) override; + HRESULT STDMETHODCALLTYPE GetBackgroundColor(DXGI_RGBA *pColor) override; + HRESULT STDMETHODCALLTYPE SetRotation(DXGI_MODE_ROTATION Rotation) override; + HRESULT STDMETHODCALLTYPE GetRotation(DXGI_MODE_ROTATION *pRotation) override; + + /*** IDXGISwapChain2 ***/ + HRESULT STDMETHODCALLTYPE SetSourceSize(UINT Width, UINT Height) override; + HRESULT STDMETHODCALLTYPE GetSourceSize(UINT *pWidth, UINT *pHeight) override; + HRESULT STDMETHODCALLTYPE SetMaximumFrameLatency(UINT MaxLatency) override; + HRESULT STDMETHODCALLTYPE GetMaximumFrameLatency(UINT *pMaxLatency) override; + + HANDLE STDMETHODCALLTYPE GetFrameLatencyWaitableObject() override; + HRESULT STDMETHODCALLTYPE SetMatrixTransform(const DXGI_MATRIX_3X2_F *pMatrix) override; + HRESULT STDMETHODCALLTYPE GetMatrixTransform(DXGI_MATRIX_3X2_F *pMatrix) override; + + /*** IDXGISwapChain3 ***/ + UINT STDMETHODCALLTYPE GetCurrentBackBufferIndex() override; + HRESULT STDMETHODCALLTYPE CheckColorSpaceSupport(DXGI_COLOR_SPACE_TYPE ColorSpace, UINT *pColorSpaceSupport) override; + HRESULT STDMETHODCALLTYPE SetColorSpace1(DXGI_COLOR_SPACE_TYPE ColorSpace) override; + HRESULT STDMETHODCALLTYPE ResizeBuffers1(UINT 
BufferCount, UINT Width, UINT Height, DXGI_FORMAT Format, UINT SwapChainFlags, const UINT *pCreationNodeMask, IUnknown *const *ppPresentQueue) override; + + /*** IDXGISwapChain4 ***/ + HRESULT STDMETHODCALLTYPE SetHDRMetaData(DXGI_HDR_METADATA_TYPE Type, UINT Size, void *pMetaData) override; + +private: + std::atomic m_refCount = {1ul}; + Com m_factory; + Com m_dxgi_device; + MTLD3D12Device *m_device = nullptr; + HWND m_hwnd = nullptr; + WMT::MetalLayer m_layer = {}; + WMT::Object m_native_view; + DXGI_SWAP_CHAIN_DESC1 m_desc = {}; + DXGI_SWAP_CHAIN_FULLSCREEN_DESC m_fs_desc = {}; + uint64_t m_present_count = 0; + + std::array, 4> m_backbuffers; + uint32_t m_current_buffer = 0; + WMT::Reference m_present_queue; +}; + +HRESULT CreateD3D12SwapChain(IDXGIFactory1 *factory, MTLD3D12Device *device, + IMTLDXGIDevice *dxgi_device, HWND hWnd, + const DXGI_SWAP_CHAIN_DESC1 *desc, + const DXGI_SWAP_CHAIN_FULLSCREEN_DESC *fs_desc, + IDXGISwapChain1 **pp_swap_chain); + +} // namespace dxmt diff --git a/src/d3d12/meson.build b/src/d3d12/meson.build new file mode 100644 index 000000000..2d6b0279b --- /dev/null +++ b/src/d3d12/meson.build @@ -0,0 +1,56 @@ +d3d12_res = wrc_generator.process('version.rc') + +d3d12_src = [ + 'd3d12.cpp', + 'd3d12_device.cpp', + 'd3d12_command_queue.cpp', + 'd3d12_command_allocator.cpp', + 'd3d12_command_list.cpp', + 'd3d12_descriptor_heap.cpp', + 'd3d12_dxgi_device.cpp', + 'd3d12_fence.cpp', + 'd3d12_heap.cpp', + 'd3d12_pipeline_state.cpp', + 'd3d12_query_heap.cpp', + 'd3d12_resource.cpp', + 'd3d12_root_signature.cpp', + 'd3d12_swapchain.cpp', + '../airconv/dxil/dxil_container.cpp', + '../airconv/dxil/llvm_bitcode.cpp', + '../airconv/dxil/dxil_to_msl.cpp', +] + +d3d12_ld_args = ['-L' + windows_native_install_dir] +d3d12_link_depends = [] + +d3d12_dll = shared_library('d3d12', d3d12_src, d3d12_res, + name_prefix : '', + dependencies : [ + dxgi_dep, + dxbc_parser_dep, + airconv_forward_dep, + dxmt_dep, + winemetal_dep, + util_dep ], + 
include_directories : [ dxmt_include_path ], + install : true, + install_dir : windows_native_install_dir, + vs_module_defs : 'd3d12.def', + link_args : d3d12_ld_args, + link_depends : [ d3d12_link_depends ], +) + +d3d12_dep = declare_dependency( + link_with : [ d3d12_dll ], + include_directories : [ dxmt_include_path ], +) + +if wine_builtin_dll +custom_target('postprocess_libd3d12', + input : d3d12_dll, + output: 'd3d12.dll.postproc', + command : [ winebuild, '--builtin', '@INPUT@' ], + depends : d3d12_dll, + build_by_default : true +) +endif diff --git a/src/d3d12/version.rc b/src/d3d12/version.rc new file mode 100644 index 000000000..a80519c60 --- /dev/null +++ b/src/d3d12/version.rc @@ -0,0 +1,30 @@ +#include + +VS_VERSION_INFO VERSIONINFO +FILEVERSION 10,0,17763,1 +PRODUCTVERSION 10,0,17763,1 +FILEFLAGSMASK VS_FFI_FILEFLAGSMASK +FILEFLAGS 0 +FILEOS VOS_NT_WINDOWS32 +FILETYPE VFT_DLL +FILESUBTYPE VFT2_UNKNOWN +BEGIN + BLOCK "StringFileInfo" + BEGIN + BLOCK "080904b0" + BEGIN + VALUE "CompanyName", "DXMT" + VALUE "FileDescription", "Direct3D 12 Runtime" + VALUE "FileVersion", "10.0.17763.1 (WinBuild.160101.0800)" + VALUE "InternalName", "D3D12.dll" + VALUE "LegalCopyright", "MIT License" + VALUE "OriginalFilename", "D3D12.dll" + VALUE "ProductName", "DXMT" + VALUE "ProductVersion", "10.0.17763.1" + END + END + BLOCK "VarFileInfo" + BEGIN + VALUE "Translation", 0x0809, 1200 + END +END diff --git a/src/dxgi/dxgi_adapter.cpp b/src/dxgi/dxgi_adapter.cpp index 1b3d37cdb..9d964bedb 100644 --- a/src/dxgi/dxgi_adapter.cpp +++ b/src/dxgi/dxgi_adapter.cpp @@ -7,6 +7,7 @@ #include "dxgi_interfaces.h" #include "dxgi_object.hpp" #include "d3d10_1.h" +#include "d3d12.h" #include "Metal.hpp" namespace dxmt { @@ -151,7 +152,7 @@ class MTLDXGIAdapter : public MTLDXGIObject { if (options_.customVendorId >= 0) { pDesc->VendorId = options_.customVendorId; } else { - pDesc->VendorId = 0x106B; + pDesc->VendorId = 0x1002; if (g_extension_enabled == VendorExtension::Nvidia) { 
pDesc->VendorId = 0x10DE; } @@ -176,6 +177,16 @@ class MTLDXGIAdapter : public MTLDXGIObject { pDesc->GraphicsPreemptionGranularity = DXGI_GRAPHICS_PREEMPTION_DMA_BUFFER_BOUNDARY; pDesc->ComputePreemptionGranularity = DXGI_COMPUTE_PREEMPTION_DMA_BUFFER_BOUNDARY; + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + fwprintf(f, L"GetDesc3: %s VendorId=0x%04x DeviceId=0x%04x VRAM=%lluMB Flags=0x%x\n", + pDesc->Description, pDesc->VendorId, pDesc->DeviceId, + (unsigned long long)pDesc->DedicatedVideoMemory/(1024*1024), pDesc->Flags); + fclose(f); + } + } + return S_OK; } @@ -195,10 +206,14 @@ class MTLDXGIAdapter : public MTLDXGIObject { } HRESULT STDMETHODCALLTYPE CheckInterfaceSupport(const GUID &guid, LARGE_INTEGER *umd_version) final { + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "CheckInterfaceSupport: %08lx-%04x-%04x\n", guid.Data1, guid.Data2, guid.Data3); fclose(f); } + } HRESULT hr = DXGI_ERROR_UNSUPPORTED; if (guid == __uuidof(IDXGIDevice) || guid == __uuidof(ID3D10Device) || - guid == __uuidof(ID3D10Device1)) + guid == __uuidof(ID3D10Device1) || guid == __uuidof(ID3D12Device)) hr = S_OK; // We can't really reconstruct the version numbers diff --git a/src/dxgi/dxgi_factory.cpp b/src/dxgi/dxgi_factory.cpp index 0b3fa0975..85db98bb4 100644 --- a/src/dxgi/dxgi_factory.cpp +++ b/src/dxgi/dxgi_factory.cpp @@ -8,6 +8,8 @@ #include "wsi_window.hpp" #include "Metal.hpp" +#define DGTRACE(fmt, ...) 
do { FILE *_tf = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); if (_tf) { fprintf(_tf, "DXGI::" fmt "\n", ##__VA_ARGS__); fclose(_tf); } } while(0) + namespace dxmt { Com CreateAdapter(WMT::Device Device, @@ -24,6 +26,11 @@ class MTLDXGIFactory : public MTLDXGIObject { return E_POINTER; *ppvObject = nullptr; + + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "DXGIFactory::QI riid=%s\n", str::format(riid).c_str()); fclose(f); } + } if (riid == __uuidof(IUnknown) || riid == __uuidof(IDXGIObject) || riid == __uuidof(IDXGIFactory) || riid == __uuidof(IDXGIFactory1) || @@ -67,6 +74,7 @@ class MTLDXGIFactory : public MTLDXGIObject { HRESULT STDMETHODCALLTYPE CreateSwapChain(IUnknown *pDevice, DXGI_SWAP_CHAIN_DESC *pDesc, IDXGISwapChain **ppSwapChain) final { + DGTRACE("CreateSwapChain (legacy) called"); if (ppSwapChain == nullptr || pDesc == nullptr || pDevice == nullptr) return DXGI_ERROR_INVALID_CALL; @@ -97,17 +105,23 @@ class MTLDXGIFactory : public MTLDXGIObject { return hr; } - HRESULT STDMETHODCALLTYPE CreateSwapChainForHwnd( + HRESULT STDMETHODCALLTYPE + CreateSwapChainForHwnd( IUnknown *pDevice, HWND hWnd, const DXGI_SWAP_CHAIN_DESC1 *pDesc, const DXGI_SWAP_CHAIN_FULLSCREEN_DESC *pFullscreenDesc, IDXGIOutput *pRestrictToOutput, IDXGISwapChain1 **ppSwapChain) final { + DGTRACE("CreateSwapChainForHwnd called"); InitReturnPtr(ppSwapChain); - if (!ppSwapChain || !pDesc || !hWnd || !pDevice) + if (!ppSwapChain || !pDesc || !hWnd || !pDevice) { + DGTRACE("CreateSwapChainForHwnd -> DXGI_ERROR_INVALID_CALL (null args)"); return DXGI_ERROR_INVALID_CALL; + } Com metal_dxgi_device; - if (FAILED(pDevice->QueryInterface(IID_PPV_ARGS(&metal_dxgi_device)))) { + HRESULT qhr = pDevice->QueryInterface(IID_PPV_ARGS(&metal_dxgi_device)); + if (FAILED(qhr)) { + DGTRACE("CreateSwapChainForHwnd -> QI IMTLDXGIDevice FAILED hr=0x%lx", qhr); ERR("Unsupported device type"); return DXGI_ERROR_UNSUPPORTED; } @@ -155,6 +169,10 @@ class MTLDXGIFactory : public 
MTLDXGIObject { HRESULT STDMETHODCALLTYPE EnumAdapters(UINT Adapter, IDXGIAdapter **ppAdapter) final { + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "EnumAdapters(%u, %p)\n", Adapter, ppAdapter); fclose(f); } + } InitReturnPtr(ppAdapter); if (ppAdapter == nullptr) @@ -163,18 +181,37 @@ class MTLDXGIFactory : public MTLDXGIObject { IDXGIAdapter1 *handle = nullptr; HRESULT hr = this->EnumAdapters1(Adapter, &handle); *ppAdapter = handle; + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "EnumAdapters(%u) -> hr=0x%lx adapter=%p\n", Adapter, hr, *ppAdapter); fclose(f); } + } return hr; } HRESULT STDMETHODCALLTYPE EnumAdapters1(UINT Adapter, IDXGIAdapter1 **ppAdapter) final { + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "EnumAdapters1(%u, %p)\n", Adapter, ppAdapter); fclose(f); } + } InitReturnPtr(ppAdapter); auto devices = WMT::CopyAllDevices(); UINT adapter_count = devices.count(); + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "EnumAdapters1: adapter_count=%u\n", adapter_count); fclose(f); } + } - if (Adapter >= adapter_count) - return DXGI_ERROR_NOT_FOUND; + if (Adapter >= adapter_count) { + if (adapter_count == 1 && Adapter == 1) { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "EnumAdapters1: mapping index 1 -> 0 (single adapter)\n"); fclose(f); } + Adapter = 0; + } else { + return DXGI_ERROR_NOT_FOUND; + } + } UINT adjusted_adapter = Adapter; if (adapter_count > 1) { @@ -309,11 +346,20 @@ class MTLDXGIFactory : public MTLDXGIObject { extern "C" HRESULT __stdcall CreateDXGIFactory2(UINT Flags, REFIID riid, void **ppFactory) { + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "CreateDXGIFactory2 called Flags=0x%x riid=%s\n", Flags, str::format(riid).c_str()); fclose(f); } + } try { MTLDXGIFactory* factory = new MTLDXGIFactory(Flags); HRESULT hr = 
factory->QueryInterface(riid, ppFactory); factory->Release(); + { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { fprintf(f, "QI hr=0x%lx\n", hr); fclose(f); } + } + if (FAILED(hr)) return hr; diff --git a/src/dxmt/dxmt_allocation.cpp b/src/dxmt/dxmt_allocation.cpp index 036b76b51..d1bfc7327 100644 --- a/src/dxmt/dxmt_allocation.cpp +++ b/src/dxmt/dxmt_allocation.cpp @@ -19,8 +19,44 @@ #include "dxmt_allocation.hpp" #include "util_likely.hpp" #include +#include #include #include +#include + +namespace dxmt { + +void *g_d3d12_device_addr = nullptr; +size_t g_d3d12_device_size = 0; + +} // namespace dxmt + +void operator delete(void *ptr) noexcept { + if (dxmt::g_d3d12_device_addr && ptr) { + uintptr_t p = (uintptr_t)ptr; + uintptr_t d = (uintptr_t)dxmt::g_d3d12_device_addr; + if (p >= d && p < d + dxmt::g_d3d12_device_size) { + DWORD tid = GetCurrentThreadId(); + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + fprintf(f, "!!! GLOBAL operator delete ON DEVICE! ptr=%p device=%p size=%zu tid=%lu\n", + ptr, dxmt::g_d3d12_device_addr, dxmt::g_d3d12_device_size, (unsigned long)tid); + void *buf[16]; + ULONG n = RtlCaptureStackBackTrace(1, 16, buf, nullptr); + fprintf(f, " stack[%lu]=", (unsigned long)n); + for (ULONG i = 0; i < n; i++) fprintf(f, "%p ", buf[i]); + fprintf(f, "\n"); + fclose(f); + } + return; + } + } + free(ptr); +} + +void operator delete(void *ptr, std::align_val_t) noexcept { + free(ptr); +} namespace dxmt { @@ -31,8 +67,26 @@ Allocation::incRef() { void Allocation::decRef() { - if (refcount_.fetch_sub(1u, std::memory_order_release) == 1u) - this->free(); + if (g_d3d12_device_addr != nullptr && + (uintptr_t)this >= (uintptr_t)g_d3d12_device_addr && + (uintptr_t)this < (uintptr_t)g_d3d12_device_addr + g_d3d12_device_size) { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + fprintf(f, "!!! 
decRef BLOCKED on DEVICE this=%p vtable=%p refcount=%u tid=%lu\n", + (void*)this, *(void**)this, refcount_.load(), (unsigned long)GetCurrentThreadId()); + void *buf[16]; + ULONG n = RtlCaptureStackBackTrace(1, 16, buf, nullptr); + fprintf(f, " stack[%lu]=", (unsigned long)n); + for (ULONG i = 0; i < n; i++) fprintf(f, "%p ", buf[i]); + fprintf(f, "\n"); + fclose(f); + } + return; + } + uint32_t prev = refcount_.fetch_sub(1u, std::memory_order_release); + if (prev == 1u) { + this->destroy(); + } }; AllocationRefTracking::AllocationRefTracking() { @@ -65,15 +119,19 @@ AllocationRefTracking::addStorage(void *ptr, size_t length) { void AllocationRefTracking::clear() { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); RefAddChunk<> *chunk = reinterpret_cast<RefAddChunk<> *>(&chunk_placed); while (chunk) { dxmt::Allocation **list = std::launder(chunk->allocations); - for (unsigned i = 0; i < chunk->size; i++) + for (unsigned i = 0; i < chunk->size; i++) { + if (f) fprintf(f, "RefTracking::clear decRef on alloc[%u]=%p vtable=%p\n", + i, (void*)list[i], list[i] ? *(void**)list[i] : nullptr); list[i]->decRef(); + } chunk->size = 0; chunk = chunk->next_chunk; } - // reset state + if (f) { fprintf(f, "RefTracking::clear done\n"); fclose(f); } chunk_last = reinterpret_cast<RefAddChunk<> *>(&chunk_placed); chunk_placed.next_chunk = nullptr; }; diff --git a/src/dxmt/dxmt_allocation.hpp b/src/dxmt/dxmt_allocation.hpp index 327946dca..4cbd5ba7a 100644 --- a/src/dxmt/dxmt_allocation.hpp +++ b/src/dxmt/dxmt_allocation.hpp @@ -23,14 +23,20 @@ namespace dxmt { +extern void *g_d3d12_device_addr; +extern size_t g_d3d12_device_size; + class Allocation { public: + ~Allocation(){}; + + virtual void destroy() = 0; + void incRef(); void decRef(); bool checkRetained(uint64_t seq_id) { - // FIXME: is a compare-and-swap necessary? 
if (seq_id == last_retained_seq_id) return true; last_retained_seq_id = seq_id; diff --git a/src/dxmt/dxmt_buffer.cpp b/src/dxmt/dxmt_buffer.cpp index 8589c32e4..b20df3c1b 100644 --- a/src/dxmt/dxmt_buffer.cpp +++ b/src/dxmt/dxmt_buffer.cpp @@ -18,11 +18,13 @@ #include "dxmt_buffer.hpp" #include "dxmt_format.hpp" +#include "log/log.hpp" #include "thread.hpp" #include "util_likely.hpp" #include "util_math.hpp" #include "wsi_platform.hpp" #include +#include #include namespace dxmt { @@ -48,11 +50,82 @@ BufferAllocation::BufferAllocation(WMT::Device device, const WMTBufferInfo &info obj_ = device.newBuffer(info_); gpuAddress_ = info_.gpu_address; mappedMemory_ = info_.memory.get_accessible_or_null(); + { + extern void *g_d3d12_device_addr; + extern size_t g_d3d12_device_size; + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + fprintf(f, "BufferAlloc CTOR %p: placed_buf=%p mappedMem=%p gpuAddr=0x%llx len=%llu flags=%u", + this, placed_buffer, mappedMemory_, (unsigned long long)gpuAddress_, (unsigned long long)info_.length, (uint32_t)flags_.raw()); + if (g_d3d12_device_addr) { + uintptr_t dstart = (uintptr_t)g_d3d12_device_addr; + uintptr_t dend = dstart + g_d3d12_device_size; + uintptr_t astart = (uintptr_t)this; + uintptr_t aend = astart + sizeof(BufferAllocation); + bool overlaps = (astart < dend && aend > dstart); + fprintf(f, " device_dist=%lld overlap=%d", (long long)(astart - dstart), overlaps); + if (overlaps) { + fprintf(f, " !!! 
ALLOCATION OVERLAPS DEVICE !!!"); + } + } + fprintf(f, "\n"); + fclose(f); + } + } }; void BufferAllocation::free() { + { + extern void *g_d3d12_device_addr; + extern size_t g_d3d12_device_size; + DWORD tid = GetCurrentThreadId(); + void *ret_addr = __builtin_return_address(0); + void *ret_addr2 = __builtin_return_address(1); + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + fprintf(f, "BufferAlloc FREE ENTRY %p: canary=0x%08x ret0=%p ret1=%p tid=%lu vtable_now=%p\n", + this, canary_, ret_addr, ret_addr2, (unsigned long)tid, *(void**)this); + void *buf[16]; + ULONG n = RtlCaptureStackBackTrace(1, 16, buf, nullptr); + fprintf(f, " stack[%lu]=", (unsigned long)n); + for (ULONG i = 0; i < n; i++) fprintf(f, "%p ", buf[i]); + if (g_d3d12_device_addr) { + uintptr_t dstart = (uintptr_t)g_d3d12_device_addr; + uintptr_t dend = dstart + g_d3d12_device_size; + uintptr_t astart = (uintptr_t)this; + bool overlaps = (astart >= dstart && astart < dend); + fprintf(f, "\n device_dist=%lld overlap=%d sizeof_this=%zu", + (long long)(astart - dstart), overlaps, sizeof(BufferAllocation)); + if (overlaps) { + fprintf(f, " !!! FREE ON DEVICE MEMORY — ABORTING !!!\n"); + fclose(f); + return; + } + } + fprintf(f, "\n"); + fclose(f); + } + } + if (canary_ != 0xDEADBEEF) { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + fprintf(f, " CORRUPTED! canary=0x%08x expected 0xDEADBEEF, skipping free\n", canary_); + fclose(f); + } + return; + } if (placed_buffer) { + uintptr_t p = reinterpret_cast<uintptr_t>(placed_buffer); + if (p & 0xFFF) { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + fprintf(f, " placed_buf=%p NOT page-aligned! 
heap corruption, skipping free\n", placed_buffer); + fclose(f); + } + placed_buffer = nullptr; + return; + } wsi::aligned_free(placed_buffer); placed_buffer = nullptr; } diff --git a/src/dxmt/dxmt_buffer.hpp b/src/dxmt/dxmt_buffer.hpp index b6726d73b..f9cec7d81 100644 --- a/src/dxmt/dxmt_buffer.hpp +++ b/src/dxmt/dxmt_buffer.hpp @@ -22,6 +22,8 @@ #include "dxmt_deptrack.hpp" #include "dxmt_residency.hpp" #include "dxmt_allocation.hpp" +#include +#include #include "rc/util_rc_ptr.hpp" #include "thread.hpp" #include "util_flags.hpp" @@ -127,6 +129,28 @@ class BufferAllocation final : public Allocation { BufferAllocation(WMT::Device device, const WMTBufferInfo &info, Flags flags); void free(); + void destroy() override { + if (dxmt::g_d3d12_device_addr) { + uintptr_t a = (uintptr_t)this; + uintptr_t d = (uintptr_t)dxmt::g_d3d12_device_addr; + if (a >= d && a < d + dxmt::g_d3d12_device_size) { + FILE *f = fopen("Z:\\tmp\\dxmt_dxgi_trace.log", "a"); + if (f) { + fprintf(f, "!!! BLOCKED BufferAllocation::destroy() on DEVICE this=%p device=%p tid=%lu\n", + (void*)this, dxmt::g_d3d12_device_addr, (unsigned long)GetCurrentThreadId()); + void *buf[16]; + ULONG n = RtlCaptureStackBackTrace(1, 16, buf, nullptr); + fprintf(f, " stack[%lu]=", (unsigned long)n); + for (ULONG i = 0; i < n; i++) fprintf(f, "%p ", buf[i]); + fprintf(f, "\n"); + fclose(f); + } + return; + } + } + delete this; + } + BufferAllocation(const BufferAllocation &) = delete; BufferAllocation(BufferAllocation &&) = delete; @@ -142,6 +166,7 @@ class BufferAllocation final : public Allocation { uint32_t suballocation_count_ = 1; void * placed_buffer = nullptr; + uint32_t canary_ = 0xDEADBEEF; }; class Buffer { diff --git a/src/dxmt/dxmt_texture.hpp b/src/dxmt/dxmt_texture.hpp index 0fa103037..bd5be67bc 100644 --- a/src/dxmt/dxmt_texture.hpp +++ b/src/dxmt/dxmt_texture.hpp @@ -162,6 +162,8 @@ class TextureAllocation final : public Allocation { void free(); + void destroy() override { delete this; } + 
TextureAllocation(const TextureAllocation &) = delete; TextureAllocation(TextureAllocation &&) = delete; diff --git a/src/meson.build b/src/meson.build index cf0edbe44..399221a75 100644 --- a/src/meson.build +++ b/src/meson.build @@ -10,6 +10,7 @@ endif subdir('dxmt') subdir('dxgi') subdir('d3d11') +subdir('d3d12') subdir('d3d10') if get_option('enable_nvapi') diff --git a/src/winemetal/Metal.hpp b/src/winemetal/Metal.hpp index 3d3aa9895..cafeb4632 100644 --- a/src/winemetal/Metal.hpp +++ b/src/winemetal/Metal.hpp @@ -786,6 +786,11 @@ class Device : public Object { return Reference(MTLDevice_newLibrary(handle, data, &error.handle)); } + Reference + newLibraryWithSource(const char *source, uint64_t source_length, Error &error) { + return Reference(MTLDevice_newLibraryWithSource(handle, source, source_length, &error.handle)); + } + Reference newComputePipelineState(const Function &compute_function, Error &error) { WMTComputePipelineInfo info; diff --git a/src/winemetal/unix/winemetal_unix.c b/src/winemetal/unix/winemetal_unix.c index 35a934bad..b714d5532 100644 --- a/src/winemetal/unix/winemetal_unix.c +++ b/src/winemetal/unix/winemetal_unix.c @@ -363,6 +363,23 @@ _MTLDevice_newLibrary(void *obj) { return STATUS_SUCCESS; } +static NTSTATUS +_MTLDevice_newLibraryWithSource(void *obj) { + struct unixcall_mtldevice_newlibrary_source *params = obj; + id device = (id)params->device; + NSError *err = NULL; + NSString *source = [[NSString alloc] initWithBytes:params->source.ptr + length:params->source_length + encoding:NSUTF8StringEncoding]; + MTLCompileOptions *options = [[MTLCompileOptions alloc] init]; + [options setLanguageVersion:MTLLanguageVersion3_1]; + params->ret_library = (obj_handle_t)[device newLibraryWithSource:source options:options error:&err]; + params->ret_error = (obj_handle_t)err; + [source release]; + [options release]; + return STATUS_SUCCESS; +} + static NTSTATUS _MTLLibrary_newFunction(void *obj) { struct unixcall_generic_obj_uint64_obj_ret *params = 
obj; @@ -2974,6 +2991,7 @@ const void *__wine_unix_call_funcs[] = { &_MTLCommandBuffer_blitCommandEncoderWithSampleBuffers, &_MTLCommandBuffer_property, &_MTLDevice_newTileRenderPipelineState, + &_MTLDevice_newLibraryWithSource, }; #ifndef DXMT_NATIVE @@ -3110,5 +3128,6 @@ const void *__wine_unix_call_wow64_funcs[] = { &_MTLCommandBuffer_blitCommandEncoderWithSampleBuffers, &_MTLCommandBuffer_property, &_MTLDevice_newTileRenderPipelineState, + &_MTLDevice_newLibraryWithSource, }; #endif diff --git a/src/winemetal/winemetal.h b/src/winemetal/winemetal.h index c70893157..6b56cee89 100644 --- a/src/winemetal/winemetal.h +++ b/src/winemetal/winemetal.h @@ -582,6 +582,8 @@ enum WMTAttributeFormat : uint32_t { WINEMETAL_API obj_handle_t MTLDevice_newLibrary(obj_handle_t device, obj_handle_t data, obj_handle_t *err_out); +WINEMETAL_API obj_handle_t MTLDevice_newLibraryWithSource(obj_handle_t device, const char *source, uint64_t source_length, obj_handle_t *err_out); + WINEMETAL_API obj_handle_t MTLLibrary_newFunction(obj_handle_t library, const char *name); WINEMETAL_API uint64_t NSString_lengthOfBytesUsingEncoding(obj_handle_t str, enum WMTStringEncoding encoding); diff --git a/src/winemetal/winemetal_thunks.c b/src/winemetal/winemetal_thunks.c index af3166ee8..69f1a99ae 100644 --- a/src/winemetal/winemetal_thunks.c +++ b/src/winemetal/winemetal_thunks.c @@ -268,6 +268,22 @@ MTLDevice_newLibrary( return params.ret_library; } +WINEMETAL_API obj_handle_t +MTLDevice_newLibraryWithSource( + obj_handle_t device, const char *source, uint64_t source_length, obj_handle_t *err_out +) { + struct unixcall_mtldevice_newlibrary_source params; + params.device = device; + params.source.ptr = source; + params.source_length = source_length; + params.ret_error = 0; + params.ret_library = 0; + UNIX_CALL(133, &params); + if (err_out) + *err_out = params.ret_error; + return params.ret_library; +} + WINEMETAL_API obj_handle_t MTLLibrary_newFunction(obj_handle_t library, const char *name) { struct 
unixcall_generic_obj_uint64_obj_ret params; diff --git a/src/winemetal/winemetal_thunks.h b/src/winemetal/winemetal_thunks.h index 9397d1379..1eed7d7f3 100644 --- a/src/winemetal/winemetal_thunks.h +++ b/src/winemetal/winemetal_thunks.h @@ -371,6 +371,14 @@ struct unixcall_mtlcommandbuffer_blitcommandencoderwithsamplebuffers { obj_handle_t ret; }; +struct unixcall_mtldevice_newlibrary_source { + obj_handle_t device; + struct WMTConstMemoryPointer source; + uint64_t source_length; + obj_handle_t ret_error; + obj_handle_t ret_library; +}; + #pragma pack(pop) #endif