From db76b1e000556683ae013137dd5cd2a96daaf54a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 17:48:08 +0000 Subject: [PATCH 1/2] Initial plan From 39bdf575df0f056b0e22a52cb9d142b58315d3f7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 18:14:23 +0000 Subject: [PATCH 2/2] Add shader graph node system core and editor integration Co-authored-by: bluesky013 <35895395+bluesky013@users.noreply.github.com> --- .../render/adaptor/assets/ShaderGraphAsset.h | 32 ++ .../adaptor/src/assets/ShaderGraphAsset.cpp | 28 + .../render/editor/ShaderGraphCreator.h | 21 + .../shadergraph/ShaderGraphEditWindow.h | 19 + .../render/editor/src/RenderEditorModule.cpp | 5 + .../render/editor/src/ShaderGraphCreator.cpp | 30 + .../src/shadergraph/ShaderGraphEditWindow.cpp | 22 + .../src/shadergraph/ShaderGraphNodeModel.cpp | 201 +++++++ .../src/shadergraph/ShaderGraphNodeModel.h | 159 ++++++ .../src/shadergraph/ShaderGraphWidget.cpp | 82 +++ .../src/shadergraph/ShaderGraphWidget.h | 35 ++ .../include/shader/shadergraph/ShaderGraph.h | 65 +++ .../shadergraph/ShaderGraphInputNodes.h | 246 ++++++++ .../shader/shadergraph/ShaderGraphMathNodes.h | 261 +++++++++ .../shader/shadergraph/ShaderGraphNode.h | 82 +++ .../shadergraph/ShaderGraphOutputNode.h | 31 + .../shader/shadergraph/ShaderGraphPin.h | 47 ++ .../shader/shadergraph/ShaderGraphTypes.h | 89 +++ .../shader/src/shadergraph/ShaderGraph.cpp | 298 ++++++++++ .../src/shadergraph/ShaderGraphInputNodes.cpp | 436 ++++++++++++++ .../src/shadergraph/ShaderGraphMathNodes.cpp | 531 ++++++++++++++++++ .../src/shadergraph/ShaderGraphNode.cpp | 31 + .../src/shadergraph/ShaderGraphOutputNode.cpp | 39 ++ 23 files changed, 2790 insertions(+) create mode 100644 engine/render/adaptor/include/render/adaptor/assets/ShaderGraphAsset.h create mode 100644 engine/render/adaptor/src/assets/ShaderGraphAsset.cpp create mode 100644 engine/render/editor/include/render/editor/ShaderGraphCreator.h create mode 100644 engine/render/editor/include/render/editor/shadergraph/ShaderGraphEditWindow.h create mode 100644 engine/render/editor/src/ShaderGraphCreator.cpp create mode 100644 engine/render/editor/src/shadergraph/ShaderGraphEditWindow.cpp create mode 100644 engine/render/editor/src/shadergraph/ShaderGraphNodeModel.cpp create mode 100644 engine/render/editor/src/shadergraph/ShaderGraphNodeModel.h create mode 100644 engine/render/editor/src/shadergraph/ShaderGraphWidget.cpp create mode 100644 engine/render/editor/src/shadergraph/ShaderGraphWidget.h create mode 100644 engine/render/shader/include/shader/shadergraph/ShaderGraph.h create mode 100644 engine/render/shader/include/shader/shadergraph/ShaderGraphInputNodes.h create mode 100644 engine/render/shader/include/shader/shadergraph/ShaderGraphMathNodes.h create mode 100644 engine/render/shader/include/shader/shadergraph/ShaderGraphNode.h create mode 100644 engine/render/shader/include/shader/shadergraph/ShaderGraphOutputNode.h create mode 100644 engine/render/shader/include/shader/shadergraph/ShaderGraphPin.h create mode 100644 engine/render/shader/include/shader/shadergraph/ShaderGraphTypes.h create mode 100644 engine/render/shader/src/shadergraph/ShaderGraph.cpp create mode 100644 engine/render/shader/src/shadergraph/ShaderGraphInputNodes.cpp create mode 100644 engine/render/shader/src/shadergraph/ShaderGraphMathNodes.cpp create mode 100644 engine/render/shader/src/shadergraph/ShaderGraphNode.cpp create mode 100644 engine/render/shader/src/shadergraph/ShaderGraphOutputNode.cpp diff --git a/engine/render/adaptor/include/render/adaptor/assets/ShaderGraphAsset.h b/engine/render/adaptor/include/render/adaptor/assets/ShaderGraphAsset.h new file mode 100644 index 00000000..cdeedc0f --- /dev/null +++ b/engine/render/adaptor/include/render/adaptor/assets/ShaderGraphAsset.h @@ -0,0 +1,32 @@ +// +// Created by blues on 2026/3/10. +// + +#pragma once + +#include +#include + +namespace sky { + class JsonInputArchive; + class JsonOutputArchive; + + struct ShaderGraphAssetData { + uint32_t version = 1; + sg::ShaderGraph graph; + + void LoadJson(JsonInputArchive& archive); + void SaveJson(JsonOutputArchive& archive) const; + }; + + // Placeholder type tag for the shader graph asset + struct ShaderGraphAssetTag {}; + + template <> + struct AssetTraits { + using DataType = ShaderGraphAssetData; + static constexpr std::string_view ASSET_TYPE = "ShaderGraph"; + static constexpr SerializeType SERIALIZE_TYPE = SerializeType::JSON; + }; + +} // namespace sky diff --git a/engine/render/adaptor/src/assets/ShaderGraphAsset.cpp b/engine/render/adaptor/src/assets/ShaderGraphAsset.cpp new file mode 100644 index 00000000..e60591ad --- /dev/null +++ b/engine/render/adaptor/src/assets/ShaderGraphAsset.cpp @@ -0,0 +1,28 @@ +// +// Created by blues on 2026/3/10. +// + +#include +#include + +namespace sky { + + void ShaderGraphAssetData::LoadJson(JsonInputArchive& archive) + { + archive.LoadKeyValue("Version", version); + if (archive.Start("Graph")) { + graph.LoadJson(archive); + archive.End(); + } + } + + void ShaderGraphAssetData::SaveJson(JsonOutputArchive& archive) const + { + archive.StartObject(); + archive.SaveValueObject("Version", version); + archive.Key("Graph"); + graph.SaveJson(archive); + archive.EndObject(); + } + +} // namespace sky diff --git a/engine/render/editor/include/render/editor/ShaderGraphCreator.h b/engine/render/editor/include/render/editor/ShaderGraphCreator.h new file mode 100644 index 00000000..286cf5f5 --- /dev/null +++ b/engine/render/editor/include/render/editor/ShaderGraphCreator.h @@ -0,0 +1,21 @@ +// +// Created by blues on 2026/3/10. +// + +#pragma once + +#include + +namespace sky::editor { + + class ShaderGraphCreator : public AssetCreatorBase { + public: + ShaderGraphCreator() = default; + ~ShaderGraphCreator() override = default; + + private: + void CreateAsset(const FilePath& path) override; + std::string GetExtension() const override { return ".shadergraph"; } + }; + +} // namespace sky::editor diff --git a/engine/render/editor/include/render/editor/shadergraph/ShaderGraphEditWindow.h b/engine/render/editor/include/render/editor/shadergraph/ShaderGraphEditWindow.h new file mode 100644 index 00000000..b6f0ff21 --- /dev/null +++ b/engine/render/editor/include/render/editor/shadergraph/ShaderGraphEditWindow.h @@ -0,0 +1,19 @@ +// +// Created by blues on 2026/3/10. +// + +#pragma once + +#include + +namespace sky::editor { + + class ShaderGraphEditWindow : public IAssetPreviewWndFactory { + public: + ShaderGraphEditWindow() = default; + ~ShaderGraphEditWindow() override = default; + + bool SetupWidget(AssetPreviewWidget& widget, const AssetSourcePtr& src) override; + }; + +} // namespace sky::editor diff --git a/engine/render/editor/src/RenderEditorModule.cpp b/engine/render/editor/src/RenderEditorModule.cpp index 59aa96d2..8676a48c 100644 --- a/engine/render/editor/src/RenderEditorModule.cpp +++ b/engine/render/editor/src/RenderEditorModule.cpp @@ -6,13 +6,16 @@ #include #include #include +#include #include +#include #include #include #include #include +#include namespace sky::editor { @@ -26,6 +29,7 @@ namespace sky::editor { // asset AssetCreatorManager::Get()->RegisterTool(Name("Material"), new MaterialInstanceCreator()); AssetCreatorManager::Get()->RegisterTool(Name("Animation Graph"), new AnimationGraphCreator()); + AssetCreatorManager::Get()->RegisterTool(Name("Shader Graph"), new ShaderGraphCreator()); // create RegisterActorCreators(BuiltinGeometryType::CUBE); @@ -33,6 +37,7 @@ namespace sky::editor { // preview AssetPreviewManager::Get()->Register(AssetTraits::ASSET_TYPE, new SkeletonPreviewWindow()); AssetPreviewManager::Get()->Register(AssetTraits::ASSET_TYPE, new GraphEditWindow()); + AssetPreviewManager::Get()->Register(AssetTraits::ASSET_TYPE, new ShaderGraphEditWindow()); return true; } diff --git a/engine/render/editor/src/ShaderGraphCreator.cpp b/engine/render/editor/src/ShaderGraphCreator.cpp new file mode 100644 index 00000000..76ed32fd --- /dev/null +++ b/engine/render/editor/src/ShaderGraphCreator.cpp @@ -0,0 +1,30 @@ +// +// Created by blues on 2026/3/10. +// + +#include +#include +#include +#include + +namespace sky::editor { + + void ShaderGraphCreator::CreateAsset(const FilePath& path) + { + AssetSourcePath sourcePath = {}; + sourcePath.bundle = SourceAssetBundle::WORKSPACE; + sourcePath.path = path; + + auto file = AssetDataBase::Get()->CreateOrOpenFile(sourcePath); + + ShaderGraphAssetData data = {}; + data.version = 1; + + auto archive = file->WriteAsArchive(); + JsonOutputArchive json(*archive); + data.SaveJson(json); + + AssetDataBase::Get()->RegisterAsset(sourcePath); + } + +} // namespace sky::editor diff --git a/engine/render/editor/src/shadergraph/ShaderGraphEditWindow.cpp b/engine/render/editor/src/shadergraph/ShaderGraphEditWindow.cpp new file mode 100644 index 00000000..0b402097 --- /dev/null +++ b/engine/render/editor/src/shadergraph/ShaderGraphEditWindow.cpp @@ -0,0 +1,22 @@ +// +// Created by blues on 2026/3/10. +// + +#include +#include "ShaderGraphWidget.h" + +#include + +namespace sky::editor { + + bool ShaderGraphEditWindow::SetupWidget(AssetPreviewWidget& widget, const AssetSourcePtr& src) + { + auto file = AssetDataBase::Get()->OpenFile(src); + if (file) { + widget.SetWidget(new ShaderGraphWidget(file)); + return true; + } + return false; + } + +} // namespace sky::editor diff --git a/engine/render/editor/src/shadergraph/ShaderGraphNodeModel.cpp b/engine/render/editor/src/shadergraph/ShaderGraphNodeModel.cpp new file mode 100644 index 00000000..aca5623c --- /dev/null +++ b/engine/render/editor/src/shadergraph/ShaderGraphNodeModel.cpp @@ -0,0 +1,201 @@ +// +// Created by blues on 2026/3/10. +// + +#include "ShaderGraphNodeModel.h" +#include +#include +#include +#include +#include +#include + +namespace sky::editor { + + // ---- SGNodeModel ---- + + SGNodeModel::SGNodeModel(sg::SGNodePtr n) : node(std::move(n)) + { + } + + QString SGNodeModel::caption() const + { + return QString::fromStdString(node->GetDisplayName()); + } + + QString SGNodeModel::name() const + { + return QString::fromStdString(node->GetTypeName()); + } + + unsigned int SGNodeModel::nPorts(QtNodes::PortType portType) const + { + if (portType == QtNodes::PortType::In) { + return static_cast(node->GetInputPins().size()); + } + return static_cast(node->GetOutputPins().size()); + } + + QtNodes::NodeDataType SGNodeModel::dataType(QtNodes::PortType portType, + QtNodes::PortIndex portIndex) const + { + const auto* pins = (portType == QtNodes::PortType::In) + ? &node->GetInputPins() + : &node->GetOutputPins(); + + if (portIndex < static_cast(pins->size())) { + const auto& pin = (*pins)[portIndex]; + return {sg::SGDataTypeToString(pin.type).c_str(), + pin.name.c_str()}; + } + return {"float", "Value"}; + } + + std::shared_ptr SGNodeModel::outData(QtNodes::PortIndex port) + { + const auto& outPins = node->GetOutputPins(); + if (port < static_cast(outPins.size())) { + return std::make_shared(outPins[port].type); + } + return nullptr; + } + + void SGNodeModel::setInData(std::shared_ptr /*data*/, + QtNodes::PortIndex /*portIndex*/) + { + // Data flow is handled at graph level; nothing to update here + } + + // ---- SGConstantFloatNodeModel ---- + + SGConstantFloatNodeModel::SGConstantFloatNodeModel() + : SGNodeModel(std::make_shared()) + { + } + + QWidget* SGConstantFloatNodeModel::embeddedWidget() + { + if (!spinBox) { + spinBox = new QDoubleSpinBox(); + spinBox->setRange(-1e6, 1e6); + spinBox->setDecimals(4); + spinBox->setValue(0.0); + spinBox->setFixedWidth(100); + + connect(spinBox, static_cast(&QDoubleSpinBox::valueChanged), + this, [this](double v) { + auto* n = static_cast(node.get()); + n->SetValue(static_cast(v)); + emit OnValueChanged(v); + }); + } + return spinBox; + } + + // ---- SGConstantVec3NodeModel ---- + + SGConstantVec3NodeModel::SGConstantVec3NodeModel() + : SGNodeModel(std::make_shared()) + { + } + + QWidget* SGConstantVec3NodeModel::embeddedWidget() + { + if (!container) { + container = new QWidget(); + auto* layout = new QVBoxLayout(container); + layout->setContentsMargins(0, 0, 0, 0); + layout->setSpacing(2); + + spinX = new QDoubleSpinBox(); + spinY = new QDoubleSpinBox(); + spinZ = new QDoubleSpinBox(); + + for (auto* spin : {spinX, spinY, spinZ}) { + spin->setRange(-1e6, 1e6); + spin->setDecimals(4); + spin->setValue(0.0); + spin->setFixedWidth(100); + layout->addWidget(spin); + } + + auto updateNode = [this]() { + auto* n = static_cast(node.get()); + n->SetValue(static_cast(spinX->value()), + static_cast(spinY->value()), + static_cast(spinZ->value())); + }; + + connect(spinX, static_cast(&QDoubleSpinBox::valueChanged), + this, [updateNode](double) { updateNode(); }); + connect(spinY, static_cast(&QDoubleSpinBox::valueChanged), + this, [updateNode](double) { updateNode(); }); + connect(spinZ, static_cast(&QDoubleSpinBox::valueChanged), + this, [updateNode](double) { updateNode(); }); + + container->setLayout(layout); + } + return container; + } + + // ---- SGScalarParamNodeModel ---- + + SGScalarParamNodeModel::SGScalarParamNodeModel() + : SGNodeModel(std::make_shared()) + { + } + + QWidget* SGScalarParamNodeModel::embeddedWidget() + { + if (!nameEdit) { + nameEdit = new QLineEdit("Param"); + nameEdit->setFixedWidth(100); + connect(nameEdit, &QLineEdit::textChanged, this, [this](const QString& text) { + auto* n = static_cast(node.get()); + n->SetParamName(text.toStdString()); + }); + } + return nameEdit; + } + + // ---- SGVectorParamNodeModel ---- + + SGVectorParamNodeModel::SGVectorParamNodeModel() + : SGNodeModel(std::make_shared()) + { + } + + QWidget* SGVectorParamNodeModel::embeddedWidget() + { + if (!nameEdit) { + nameEdit = new QLineEdit("VecParam"); + nameEdit->setFixedWidth(100); + connect(nameEdit, &QLineEdit::textChanged, this, [this](const QString& text) { + auto* n = static_cast(node.get()); + n->SetParamName(text.toStdString()); + }); + } + return nameEdit; + } + + // ---- SGTextureParamNodeModel ---- + + SGTextureParamNodeModel::SGTextureParamNodeModel() + : SGNodeModel(std::make_shared()) + { + } + + QWidget* SGTextureParamNodeModel::embeddedWidget() + { + if (!nameEdit) { + nameEdit = new QLineEdit("Texture"); + nameEdit->setFixedWidth(100); + connect(nameEdit, &QLineEdit::textChanged, this, [this](const QString& text) { + auto* n = static_cast(node.get()); + n->SetParamName(text.toStdString()); + }); + } + return nameEdit; + } + +} // namespace sky::editor diff --git a/engine/render/editor/src/shadergraph/ShaderGraphNodeModel.h b/engine/render/editor/src/shadergraph/ShaderGraphNodeModel.h new file mode 100644 index 00000000..36cb89b7 --- /dev/null +++ b/engine/render/editor/src/shadergraph/ShaderGraphNodeModel.h @@ -0,0 +1,159 @@ +// +// Created by blues on 2026/3/10. +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace sky::editor { + + // ---- NodeData types ---- + + // Wraps a shader graph data type for QtNodes connections + class SGNodeData : public QtNodes::NodeData { + public: + explicit SGNodeData(sg::SGDataType type) : dataType(type) {} + + QtNodes::NodeDataType type() const override + { + return {sg::SGDataTypeToString(dataType).c_str(), + sg::SGDataTypeToHLSL(dataType).c_str()}; + } + + sg::SGDataType GetSGDataType() const { return dataType; } + + private: + sg::SGDataType dataType; + }; + + // ---- Base node model ---- + + // Wraps an SGNode to make it compatible with the QtNodes data-flow graph + class SGNodeModel : public QtNodes::NodeDelegateModel { + Q_OBJECT + public: + explicit SGNodeModel(sg::SGNodePtr node); + ~SGNodeModel() override = default; + + QString caption() const override; + bool captionVisible() const override { return true; } + QString name() const override; + + unsigned int nPorts(QtNodes::PortType portType) const override; + QtNodes::NodeDataType dataType(QtNodes::PortType portType, QtNodes::PortIndex portIndex) const override; + std::shared_ptr outData(QtNodes::PortIndex port) override; + void setInData(std::shared_ptr data, QtNodes::PortIndex portIndex) override; + + QWidget* embeddedWidget() override { return nullptr; } + + sg::SGNodePtr GetNode() const { return node; } + + protected: + sg::SGNodePtr node; + }; + + // Template helper: wraps any SGNode subtype as a default-constructible model + template + class SGConcreteNodeModel : public SGNodeModel { + public: + SGConcreteNodeModel() : SGNodeModel(std::make_shared()) {} + ~SGConcreteNodeModel() override = default; + }; + + // ---- Constant float node model ---- + + class SGConstantFloatNodeModel : public SGNodeModel { + Q_OBJECT + public: + SGConstantFloatNodeModel(); + ~SGConstantFloatNodeModel() override = default; + + QString name() const override { return QStringLiteral("ConstantFloat"); } + + QWidget* embeddedWidget() override; + + Q_SIGNALS: + void OnValueChanged(double value); // NOLINT + + private: + QDoubleSpinBox* spinBox = nullptr; + }; + + // ---- Constant Vec3 node model ---- + + class SGConstantVec3NodeModel : public SGNodeModel { + Q_OBJECT + public: + SGConstantVec3NodeModel(); + ~SGConstantVec3NodeModel() override = default; + + QString name() const override { return QStringLiteral("ConstantVec3"); } + + QWidget* embeddedWidget() override; + + private: + QWidget* container = nullptr; + QDoubleSpinBox* spinX = nullptr; + QDoubleSpinBox* spinY = nullptr; + QDoubleSpinBox* spinZ = nullptr; + }; + + // ---- Scalar parameter node model ---- + + class SGScalarParamNodeModel : public SGNodeModel { + Q_OBJECT + public: + SGScalarParamNodeModel(); + ~SGScalarParamNodeModel() override = default; + + QString name() const override { return QStringLiteral("ScalarParam"); } + + QWidget* embeddedWidget() override; + + private: + QLineEdit* nameEdit = nullptr; + }; + + // ---- Vector parameter node model ---- + + class SGVectorParamNodeModel : public SGNodeModel { + Q_OBJECT + public: + SGVectorParamNodeModel(); + ~SGVectorParamNodeModel() override = default; + + QString name() const override { return QStringLiteral("VectorParam"); } + + QWidget* embeddedWidget() override; + + private: + QLineEdit* nameEdit = nullptr; + }; + + // ---- Texture parameter node model ---- + + class SGTextureParamNodeModel : public SGNodeModel { + Q_OBJECT + public: + SGTextureParamNodeModel(); + ~SGTextureParamNodeModel() override = default; + + QString name() const override { return QStringLiteral("TextureParam"); } + + QWidget* embeddedWidget() override; + + private: + QLineEdit* nameEdit = nullptr; + }; + +} // namespace sky::editor + diff --git a/engine/render/editor/src/shadergraph/ShaderGraphWidget.cpp b/engine/render/editor/src/shadergraph/ShaderGraphWidget.cpp new file mode 100644 index 00000000..c7c92508 --- /dev/null +++ b/engine/render/editor/src/shadergraph/ShaderGraphWidget.cpp @@ -0,0 +1,82 @@ +// +// Created by blues on 2026/3/10. +// + +#include "ShaderGraphWidget.h" +#include "ShaderGraphNodeModel.h" + +#include +#include +#include + +#include +#include + +namespace sky::editor { + + // Register all shader graph node models for the visual editor + static std::shared_ptr RegisterShaderGraphModels() + { + auto ret = std::make_shared(); + + // Math + ret->registerModel>("Math"); + ret->registerModel>("Math"); + ret->registerModel>("Math"); + ret->registerModel>("Math"); + ret->registerModel>("Math"); + ret->registerModel>("Math"); + ret->registerModel>("Math"); + ret->registerModel>("Math"); + ret->registerModel>("Math"); + ret->registerModel>("Math"); + ret->registerModel>("Math"); + ret->registerModel>("Math"); + ret->registerModel>("Math"); + ret->registerModel>("Math"); + + // Input + ret->registerModel>("Input"); + ret->registerModel>("Input"); + ret->registerModel>("Input"); + ret->registerModel>("Input"); + ret->registerModel>("Input"); + + // Constants + ret->registerModel("Constant"); + ret->registerModel>("Constant"); + ret->registerModel("Constant"); + ret->registerModel>("Constant"); + + // Parameters + ret->registerModel("Parameter"); + ret->registerModel("Parameter"); + ret->registerModel("Parameter"); + + // Texture + ret->registerModel>("Texture"); + + // Output + ret->registerModel>("Output"); + + return ret; + } + + ShaderGraphWidget::ShaderGraphWidget(const FilePtr& source) + : registry(RegisterShaderGraphModels()) + , model(new QtNodes::DataFlowGraphModel(registry)) + , scene(new QtNodes::DataFlowGraphicsScene(*model)) + , view(new QtNodes::GraphicsView(scene)) + , asset(source) + { + auto* mainLayout = new QVBoxLayout(this); + mainLayout->addWidget(view); + setLayout(mainLayout); + setBaseSize(1200, 800); + + connect(scene, &QtNodes::DataFlowGraphicsScene::modified, this, [this]() { + setWindowModified(true); + }); + } + +} // namespace sky::editor diff --git a/engine/render/editor/src/shadergraph/ShaderGraphWidget.h b/engine/render/editor/src/shadergraph/ShaderGraphWidget.h new file mode 100644 index 00000000..72f790be --- /dev/null +++ b/engine/render/editor/src/shadergraph/ShaderGraphWidget.h @@ -0,0 +1,35 @@ +// +// Created by blues on 2026/3/10. +// + +#pragma once + +#include +#include + +#include +#include +#include +#include + +namespace sky::editor { + + class ShaderGraphWidget : public AssetPreviewContentWidget { + Q_OBJECT + public: + explicit ShaderGraphWidget(const FilePtr& source); + + private: + void OnClose() override {} + + std::shared_ptr registry; + + QtNodes::DataFlowGraphModel* model; + QtNodes::DataFlowGraphicsScene* scene; + QtNodes::GraphicsView* view; + + FilePtr asset; + sg::ShaderGraph graph; + }; + +} // namespace sky::editor diff --git a/engine/render/shader/include/shader/shadergraph/ShaderGraph.h b/engine/render/shader/include/shader/shadergraph/ShaderGraph.h new file mode 100644 index 00000000..321497b4 --- /dev/null +++ b/engine/render/shader/include/shader/shadergraph/ShaderGraph.h @@ -0,0 +1,65 @@ +// +// Created by blues on 2026/3/10. +// + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace sky::sg { + + // The shader graph – owns all nodes and the connections between them. + class ShaderGraph { + public: + ShaderGraph() = default; + ~ShaderGraph() = default; + + // Node management + void AddNode(const SGNodePtr& node); + void RemoveNode(const Uuid& nodeId); + SGNodePtr FindNode(const Uuid& nodeId) const; + const std::unordered_map& GetNodes() const { return nodes; } + + // Connection management + bool AddConnection(const SGConnection& conn); + void RemoveConnection(const SGConnection& conn); + const std::vector& GetConnections() const { return connections; } + + // Find which output pin feeds a given input pin (returns invalid SGPinID if unconnected) + SGPinID GetSourcePin(const SGPinID& inputPin) const; + + // HLSL code generation – returns the full shader function body and declarations + struct GeneratedCode { + std::string declarations; // textures, samplers, cbuffer params + std::string body; // the SurfaceShader() function body + }; + GeneratedCode GenerateHLSL() const; + + // Serialization + void LoadJson(JsonInputArchive& archive); + void SaveJson(JsonOutputArchive& archive) const; + + // Node registry: maps type name → factory + using FactoryFn = std::function; + static void RegisterNodeType(const std::string& typeName, FactoryFn factory); + static SGNodePtr CreateNodeByType(const std::string& typeName); + + private: + // Topological traversal helper + void CollectInputs(const SGNodePtr& node, + std::vector& inputVars, + SGCodeGenContext& ctx, + std::unordered_map>& cache) const; + + std::unordered_map nodes; + std::vector connections; + + static std::unordered_map& GetRegistry(); + }; + +} // namespace sky::sg diff --git a/engine/render/shader/include/shader/shadergraph/ShaderGraphInputNodes.h b/engine/render/shader/include/shader/shadergraph/ShaderGraphInputNodes.h new file mode 100644 index 00000000..42f785c1 --- /dev/null +++ b/engine/render/shader/include/shader/shadergraph/ShaderGraphInputNodes.h @@ -0,0 +1,246 @@ +// +// Created by blues on 2026/3/10. +// + +#pragma once + +#include +#include +#include + +namespace sky::sg { + + // TexCoord – outputs UV coordinates from vertex input + class SGTexCoordNode : public SGNode { + public: + explicit SGTexCoordNode(uint32_t uvIndex = 0); + ~SGTexCoordNode() override = default; + + std::string GetTypeName() const override { return "TexCoord"; } + std::string GetDisplayName() const override { return "Tex Coord"; } + + void SetUVIndex(uint32_t idx); + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + uint32_t uvIndex = 0; + }; + + // VertexColor – outputs the per-vertex color (float4) + class SGVertexColorNode : public SGNode { + public: + SGVertexColorNode(); + ~SGVertexColorNode() override = default; + + std::string GetTypeName() const override { return "VertexColor"; } + std::string GetDisplayName() const override { return "Vertex Color"; } + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + }; + + // WorldPosition – world-space position of the current pixel + class SGWorldPositionNode : public SGNode { + public: + SGWorldPositionNode(); + ~SGWorldPositionNode() override = default; + + std::string GetTypeName() const override { return "WorldPosition"; } + std::string GetDisplayName() const override { return "World Position"; } + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + }; + + // WorldNormal – world-space normal vector + class SGWorldNormalNode : public SGNode { + public: + SGWorldNormalNode(); + ~SGWorldNormalNode() override = default; + + std::string GetTypeName() const override { return "WorldNormal"; } + std::string GetDisplayName() const override { return "World Normal"; } + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + }; + + // Time – provides shader time value + class SGTimeNode : public SGNode { + public: + SGTimeNode(); + ~SGTimeNode() override = default; + + std::string GetTypeName() const override { return "Time"; } + std::string GetDisplayName() const override { return "Time"; } + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + }; + + // ConstantFloat – a literal float value + class SGConstantFloatNode : public SGNode { + public: + explicit SGConstantFloatNode(float value = 0.0f); + ~SGConstantFloatNode() override = default; + + std::string GetTypeName() const override { return "ConstantFloat"; } + std::string GetDisplayName() const override { return "Constant (float)"; } + + void SetValue(float v) { value = v; } + float GetValue() const { return value; } + + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + float value = 0.0f; + }; + + // ConstantVec2 – a literal float2 value + class SGConstantVec2Node : public SGNode { + public: + explicit SGConstantVec2Node(float x = 0.0f, float y = 0.0f); + ~SGConstantVec2Node() override = default; + + std::string GetTypeName() const override { return "ConstantVec2"; } + std::string GetDisplayName() const override { return "Constant (float2)"; } + + void SetValue(float x, float y) { value[0] = x; value[1] = y; } + + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + std::array value = {0.f, 0.f}; + }; + + // ConstantVec3 – a literal float3 value + class SGConstantVec3Node : public SGNode { + public: + explicit SGConstantVec3Node(float x = 0.0f, float y = 0.0f, float z = 0.0f); + ~SGConstantVec3Node() override = default; + + std::string GetTypeName() const override { return "ConstantVec3"; } + std::string GetDisplayName() const override { return "Constant (float3)"; } + + void SetValue(float x, float y, float z) { value[0] = x; value[1] = y; value[2] = z; } + + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + std::array value = {0.f, 0.f, 0.f}; + }; + + // ConstantVec4 – a literal float4 value + class SGConstantVec4Node : public SGNode { + public: + explicit SGConstantVec4Node(float x = 0.0f, float y = 0.0f, float z = 0.0f, float w = 0.0f); + ~SGConstantVec4Node() override = default; + + std::string GetTypeName() const override { return "ConstantVec4"; } + std::string GetDisplayName() const override { return "Constant (float4)"; } + + void SetValue(float x, float y, float z, float w) { value[0] = x; value[1] = y; value[2] = z; value[3] = w; } + + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + std::array value = {0.f, 0.f, 0.f, 0.f}; + }; + + // ScalarParam – an exposed scalar (float) material parameter + class SGScalarParamNode : public SGNode { + public: + explicit SGScalarParamNode(const std::string& paramName = "Param", float defaultVal = 0.0f); + ~SGScalarParamNode() override = default; + + std::string GetTypeName() const override { return "ScalarParam"; } + std::string GetDisplayName() const override { return "Scalar Parameter"; } + + void SetParamName(const std::string& n); + void SetDefaultValue(float v) { defaultValue = v; } + const std::string& GetParamName() const { return paramName; } + + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + std::string paramName; + float defaultValue = 0.0f; + }; + + // VectorParam – an exposed vector (float4) material parameter + class SGVectorParamNode : public SGNode { + public: + explicit SGVectorParamNode(const std::string& paramName = "VecParam"); + ~SGVectorParamNode() override = default; + + std::string GetTypeName() const override { return "VectorParam"; } + std::string GetDisplayName() const override { return "Vector Parameter"; } + + void SetParamName(const std::string& n); + const std::string& GetParamName() const { return paramName; } + + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + std::string paramName; + std::array defaultValue = {0.f, 0.f, 0.f, 1.f}; + }; + + // TextureParam – an exposed Texture2D material parameter + class SGTextureParamNode : public SGNode { + public: + explicit SGTextureParamNode(const std::string& paramName = "Texture"); + ~SGTextureParamNode() override = default; + + std::string GetTypeName() const override { return "TextureParam"; } + std::string GetDisplayName() const override { return "Texture Parameter"; } + + void SetParamName(const std::string& n); + const std::string& GetParamName() const { return paramName; } + + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + std::string paramName; + }; + + // TextureSample – samples a Texture2D with UV coordinates + // Input 0: Texture2D, Input 1: SamplerState, Input 2: UV (float2) + // Output 0: RGBA (float4), Output 1: RGB (float3), Output 2: R (float) + class SGTextureSampleNode : public SGNode { + public: + SGTextureSampleNode(); + ~SGTextureSampleNode() override = default; + + std::string GetTypeName() const override { return "TextureSample"; } + std::string GetDisplayName() const override { return "Texture Sample"; } + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + }; + +} // namespace sky::sg diff --git a/engine/render/shader/include/shader/shadergraph/ShaderGraphMathNodes.h b/engine/render/shader/include/shader/shadergraph/ShaderGraphMathNodes.h new file mode 100644 index 00000000..db8449f4 --- /dev/null +++ b/engine/render/shader/include/shader/shadergraph/ShaderGraphMathNodes.h @@ -0,0 +1,261 @@ +// +// Created by blues on 2026/3/10. +// + +#pragma once + +#include + +namespace sky::sg { + + // ---- Binary operator base ---- + // All binary nodes follow the pattern: two same-typed inputs → one output of the same type. + class SGBinaryMathNode : public SGNode { + public: + explicit SGBinaryMathNode(SGDataType dataType = SGDataType::FLOAT3); + ~SGBinaryMathNode() override = default; + + void SetDataType(SGDataType dataType); + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + + protected: + void GenerateBinary(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx, + const char* op) const; + + SGDataType dataType; + }; + + // Add: Output = A + B + class SGAddNode : public SGBinaryMathNode { + public: + explicit SGAddNode(SGDataType dataType = SGDataType::FLOAT3); + ~SGAddNode() override = default; + + std::string GetTypeName() const override { return "Add"; } + std::string GetDisplayName() const override { return "Add"; } + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + }; + + // Subtract: Output = A - B + class SGSubtractNode : public SGBinaryMathNode { + public: + explicit SGSubtractNode(SGDataType dataType = SGDataType::FLOAT3); + ~SGSubtractNode() override = default; + + std::string GetTypeName() const override { return "Subtract"; } + std::string GetDisplayName() const override { return "Subtract"; } + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + }; + + // Multiply: Output = A * B + class SGMultiplyNode : public SGBinaryMathNode { + public: + explicit SGMultiplyNode(SGDataType dataType = SGDataType::FLOAT3); + ~SGMultiplyNode() override = default; + + std::string GetTypeName() const override { return "Multiply"; } + std::string GetDisplayName() const override { return "Multiply"; } + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + }; + + // Divide: Output = A / B + class SGDivideNode : public SGBinaryMathNode { + public: + explicit SGDivideNode(SGDataType dataType = SGDataType::FLOAT3); + ~SGDivideNode() override = default; + + std::string GetTypeName() const override { return "Divide"; } + std::string GetDisplayName() const override { return "Divide"; } + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + }; + + // Lerp: Output = lerp(A, B, Alpha) + class SGLerpNode : public SGNode { + public: + explicit SGLerpNode(SGDataType dataType = SGDataType::FLOAT3); + ~SGLerpNode() override = default; + + std::string GetTypeName() const override { return "Lerp"; } + std::string GetDisplayName() const override { return "Lerp"; } + + void SetDataType(SGDataType dataType); + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + SGDataType dataType; + }; + + // Clamp: Output = clamp(Value, Min, Max) + class SGClampNode : public SGNode { + public: + explicit SGClampNode(SGDataType dataType = SGDataType::FLOAT); + ~SGClampNode() override = default; + + std::string GetTypeName() const override { return "Clamp"; } + std::string GetDisplayName() const override { return "Clamp"; } + + void SetDataType(SGDataType dataType); + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + SGDataType dataType; + }; + + // Saturate: Output = saturate(Value) + class SGSaturateNode : public SGNode { + public: + explicit SGSaturateNode(SGDataType dataType = SGDataType::FLOAT); + ~SGSaturateNode() override = default; + + std::string GetTypeName() const override { return "Saturate"; } + std::string GetDisplayName() const override { return "Saturate"; } + + void SetDataType(SGDataType dataType); + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + SGDataType dataType; + }; + + // Abs: Output = abs(Value) + class SGAbsNode : public SGNode { + public: + explicit SGAbsNode(SGDataType dataType = SGDataType::FLOAT); + ~SGAbsNode() override = default; + + std::string GetTypeName() const override { return "Abs"; } + std::string GetDisplayName() const override { return "Abs"; } + + void SetDataType(SGDataType dataType); + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + SGDataType dataType; + }; + + // Power: Output = pow(Base, Exp) – both float + class SGPowerNode : public SGNode { + public: + SGPowerNode(); + ~SGPowerNode() override = default; + + std::string GetTypeName() const override { return "Power"; } + std::string GetDisplayName() const override { return "Power"; } + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + }; + + // Dot: Output = dot(A, B) → float + class SGDotNode : public SGNode { + public: + explicit SGDotNode(SGDataType dataType = SGDataType::FLOAT3); + ~SGDotNode() override = default; + + std::string GetTypeName() const override { return "Dot"; } + std::string GetDisplayName() const override { return "Dot"; } + + void SetDataType(SGDataType dataType); + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + SGDataType dataType; + }; + + // Cross: Output = cross(A, B) → float3 + class SGCrossNode : public SGNode { + public: + SGCrossNode(); + ~SGCrossNode() override = default; + + std::string GetTypeName() const override { return "Cross"; } + std::string GetDisplayName() const override { return "Cross"; } + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + }; + + // Normalize: Output = normalize(V) → float3 + class SGNormalizeNode : public SGNode { + public: + SGNormalizeNode(); + ~SGNormalizeNode() override = default; + + std::string GetTypeName() const override { return "Normalize"; } + std::string GetDisplayName() const override { return "Normalize"; } + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + }; + + // ComponentMask: extract selected channels from a vector + class SGComponentMaskNode : public SGNode { + public: + explicit SGComponentMaskNode(SGDataType inputType = SGDataType::FLOAT4); + ~SGComponentMaskNode() override = default; + + std::string GetTypeName() const override { return "ComponentMask"; } + std::string GetDisplayName() const override { return "Component Mask"; } + + void SetChannels(bool r, bool g, bool b, bool a); + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + void UpdateOutputPin(); + + bool maskR = true; + bool maskG = true; + bool maskB = true; + bool maskA = false; + SGDataType inputType; + }; + + // Append: combine two values into a larger vector + class SGAppendNode : public SGNode { + public: + explicit SGAppendNode(SGDataType aType = SGDataType::FLOAT3, SGDataType bType = SGDataType::FLOAT); + ~SGAppendNode() override = default; + + std::string GetTypeName() const override { return "Append"; } + std::string GetDisplayName() const override { return "Append"; } + + void LoadJson(JsonInputArchive& archive) override; + void SaveJson(JsonOutputArchive& archive) const override; + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + private: + SGDataType aType; + SGDataType bType; + }; + +} // namespace sky::sg diff --git a/engine/render/shader/include/shader/shadergraph/ShaderGraphNode.h b/engine/render/shader/include/shader/shadergraph/ShaderGraphNode.h new file mode 100644 index 00000000..b64e2794 --- /dev/null +++ b/engine/render/shader/include/shader/shadergraph/ShaderGraphNode.h @@ -0,0 +1,82 @@ +// +// Created by blues on 2026/3/10. +// + +#pragma once + +#include +#include +#include +#include +#include + +namespace sky { + class JsonInputArchive; + class JsonOutputArchive; +} + +namespace sky::sg { + + // Context passed to GenerateHLSL - tracks variable counter and accumulates declarations + struct SGCodeGenContext { + uint32_t varCounter = 0; + std::string declarations; // global resource declarations (textures, samplers, params) + std::string bodyCode; // per-pixel body code + + std::string NextVarName() { return "_sg_var" + std::to_string(varCounter++); } + }; + + // Base class for all shader graph nodes + class SGNode { + public: + SGNode(); + virtual ~SGNode() = default; + + const Uuid& GetId() const { return id; } + const std::string& GetName() const { return name; } + float GetPosX() const { return posX; } + float GetPosY() const { return posY; } + + void SetName(const std::string& n) { name = n; } + void SetPosition(float x, float y) { posX = x; posY = y; } + + const std::vector& GetInputPins() const { return inputPins; } + const std::vector& GetOutputPins() const { return outputPins; } + + // Returns the type name used for serialization / registry + virtual std::string GetTypeName() const = 0; + + // Returns a human-readable display name for the editor + virtual std::string GetDisplayName() const { return name; } + + // Generate HLSL for this node. + // inputVars: variable names (or expressions) for each input pin (parallel to inputPins). + // outputVars: filled with the variable names produced for each output pin. + // ctx: shared code-gen context. + virtual void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const = 0; + + virtual void LoadJson(JsonInputArchive& archive); + virtual void SaveJson(JsonOutputArchive& archive) const; + + protected: + Uuid id; + std::string name; + float posX = 0.0f; + float posY = 0.0f; + + std::vector inputPins; + std::vector outputPins; + }; + + using SGNodePtr = std::shared_ptr; + + // Factory interface for creating nodes by type name + class SGNodeFactory { + public: + virtual ~SGNodeFactory() = default; + virtual SGNodePtr Create() const = 0; + }; + +} // namespace sky::sg diff --git a/engine/render/shader/include/shader/shadergraph/ShaderGraphOutputNode.h b/engine/render/shader/include/shader/shadergraph/ShaderGraphOutputNode.h new file mode 100644 index 00000000..e8cbd84e --- /dev/null +++ b/engine/render/shader/include/shader/shadergraph/ShaderGraphOutputNode.h @@ -0,0 +1,31 @@ +// +// Created by blues on 2026/3/10. +// + +#pragma once + +#include +#include +#include +#include + +namespace sky::sg { + + // MaterialOutputNode – the terminal node of the shader graph. + // Each input slot corresponds to a PBR material property. + // This node does not produce output variables; it writes to the surface output struct. + class SGMaterialOutputNode : public SGNode { + public: + SGMaterialOutputNode(); + ~SGMaterialOutputNode() override = default; + + std::string GetTypeName() const override { return "MaterialOutput"; } + std::string GetDisplayName() const override { return "Material Output"; } + + // GenerateHLSL writes surface assignments into ctx.bodyCode + void GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const override; + }; + +} // namespace sky::sg diff --git a/engine/render/shader/include/shader/shadergraph/ShaderGraphPin.h b/engine/render/shader/include/shader/shadergraph/ShaderGraphPin.h new file mode 100644 index 00000000..58211e83 --- /dev/null +++ b/engine/render/shader/include/shader/shadergraph/ShaderGraphPin.h @@ -0,0 +1,47 @@ +// +// Created by blues on 2026/3/10. +// + +#pragma once + +#include +#include +#include + +namespace sky::sg { + + // Unique identifier for a pin on a specific node + struct SGPinID { + Uuid nodeId; + uint32_t pinIndex = 0; + + bool IsValid() const { return static_cast(nodeId); } + + bool operator==(const SGPinID& rhs) const + { + return nodeId == rhs.nodeId && pinIndex == rhs.pinIndex; + } + + bool operator!=(const SGPinID& rhs) const + { + return !(*this == rhs); + } + }; + + // A pin (input or output slot) on a shader graph node + struct SGPin { + std::string name; + SGDataType type = SGDataType::FLOAT; + SGPinDirection direction = SGPinDirection::INPUT; + }; + + // A connection between two pins in the shader graph + struct SGConnection { + SGPinID src; // output pin + SGPinID dst; // input pin + + bool IsValid() const { return src.IsValid() && dst.IsValid(); } + }; + +} // namespace sky::sg + diff --git a/engine/render/shader/include/shader/shadergraph/ShaderGraphTypes.h b/engine/render/shader/include/shader/shadergraph/ShaderGraphTypes.h new file mode 100644 index 00000000..a0527cb6 --- /dev/null +++ b/engine/render/shader/include/shader/shadergraph/ShaderGraphTypes.h @@ -0,0 +1,89 @@ +// +// Created by blues on 2026/3/10. +// + +#pragma once + +#include +#include + +namespace sky::sg { + + enum class SGDataType : uint8_t { + FLOAT, + FLOAT2, + FLOAT3, + FLOAT4, + TEXTURE2D, + SAMPLER_STATE, + }; + + enum class SGPinDirection : uint8_t { + INPUT, + OUTPUT, + }; + + inline std::string SGDataTypeToHLSL(SGDataType type) + { + switch (type) { + case SGDataType::FLOAT: return "float"; + case SGDataType::FLOAT2: return "float2"; + case SGDataType::FLOAT3: return "float3"; + case SGDataType::FLOAT4: return "float4"; + case SGDataType::TEXTURE2D: return "Texture2D"; + case SGDataType::SAMPLER_STATE: return "SamplerState"; + default: return "float"; + } + } + + inline std::string SGDataTypeToString(SGDataType type) + { + switch (type) { + case SGDataType::FLOAT: return "float"; + case SGDataType::FLOAT2: return "float2"; + case SGDataType::FLOAT3: return "float3"; + case SGDataType::FLOAT4: return "float4"; + case SGDataType::TEXTURE2D: return "texture2d"; + case SGDataType::SAMPLER_STATE: return "samplerState"; + default: return "float"; + } + } + + inline uint8_t SGDataTypeComponents(SGDataType type) + { + switch (type) { + case SGDataType::FLOAT: return 1; + case SGDataType::FLOAT2: return 2; + case SGDataType::FLOAT3: return 3; + case SGDataType::FLOAT4: return 4; + default: return 0; + } + } + + // Material output slots following PBR convention + enum class MaterialSlot : uint8_t { + BASE_COLOR, + METALLIC, + ROUGHNESS, + NORMAL, + EMISSIVE, + OPACITY, + OPACITY_MASK, + COUNT + }; + + inline std::string MaterialSlotToString(MaterialSlot slot) + { + switch (slot) { + case MaterialSlot::BASE_COLOR: return "BaseColor"; + case MaterialSlot::METALLIC: return "Metallic"; + case MaterialSlot::ROUGHNESS: return "Roughness"; + case MaterialSlot::NORMAL: return "Normal"; + case MaterialSlot::EMISSIVE: return "Emissive"; + case MaterialSlot::OPACITY: return "Opacity"; + case MaterialSlot::OPACITY_MASK: return "OpacityMask"; + default: return "Unknown"; + } + } + +} // namespace sky::sg diff --git a/engine/render/shader/src/shadergraph/ShaderGraph.cpp b/engine/render/shader/src/shadergraph/ShaderGraph.cpp new file mode 100644 index 00000000..ad489118 --- /dev/null +++ b/engine/render/shader/src/shadergraph/ShaderGraph.cpp @@ -0,0 +1,298 @@ +// +// Created by blues on 2026/3/10. +// + +#include +#include +#include +#include +#include +#include + +namespace sky::sg { + + // ---- Node registry ---- + + std::unordered_map& ShaderGraph::GetRegistry() + { + static std::unordered_map registry; + return registry; + } + + void ShaderGraph::RegisterNodeType(const std::string& typeName, FactoryFn factory) + { + GetRegistry()[typeName] = std::move(factory); + } + + SGNodePtr ShaderGraph::CreateNodeByType(const std::string& typeName) + { + auto& reg = GetRegistry(); + auto it = reg.find(typeName); + if (it != reg.end()) { + return it->second(); + } + return nullptr; + } + + // ---- Node management ---- + + void ShaderGraph::AddNode(const SGNodePtr& node) + { + if (node) { + nodes[node->GetId()] = node; + } + } + + void ShaderGraph::RemoveNode(const Uuid& nodeId) + { + nodes.erase(nodeId); + // Remove all connections touching this node + connections.erase( + std::remove_if(connections.begin(), connections.end(), + [&nodeId](const SGConnection& c) { + return c.src.nodeId == nodeId || c.dst.nodeId == nodeId; + }), + connections.end()); + } + + SGNodePtr ShaderGraph::FindNode(const Uuid& nodeId) const + { + auto it = nodes.find(nodeId); + return it != nodes.end() ? it->second : nullptr; + } + + // ---- Connection management ---- + + bool ShaderGraph::AddConnection(const SGConnection& conn) + { + if (!conn.IsValid()) { + return false; + } + + // Validate nodes and pins exist + auto srcNode = FindNode(conn.src.nodeId); + auto dstNode = FindNode(conn.dst.nodeId); + if (!srcNode || !dstNode) { + return false; + } + if (conn.src.pinIndex >= srcNode->GetOutputPins().size()) { + return false; + } + if (conn.dst.pinIndex >= dstNode->GetInputPins().size()) { + return false; + } + + // Each input pin may only have one incoming connection – remove any existing + connections.erase( + std::remove_if(connections.begin(), connections.end(), + [&conn](const SGConnection& c) { + return c.dst == conn.dst; + }), + connections.end()); + + connections.push_back(conn); + return true; + } + + void ShaderGraph::RemoveConnection(const SGConnection& conn) + { + connections.erase( + std::remove_if(connections.begin(), connections.end(), + [&conn](const SGConnection& c) { + return c.src == conn.src && c.dst == conn.dst; + }), + connections.end()); + } + + SGPinID ShaderGraph::GetSourcePin(const SGPinID& inputPin) const + { + for (const auto& conn : connections) { + if (conn.dst == inputPin) { + return conn.src; + } + } + return {}; + } + + // ---- HLSL code generation ---- + + void ShaderGraph::CollectInputs(const SGNodePtr& node, + std::vector& inputVars, + SGCodeGenContext& ctx, + std::unordered_map>& cache) const + { + const auto& inPins = node->GetInputPins(); + inputVars.resize(inPins.size()); + + for (size_t i = 0; i < inPins.size(); ++i) { + SGPinID inputPin{node->GetId(), static_cast(i)}; + SGPinID srcPin = GetSourcePin(inputPin); + + if (!srcPin.IsValid()) { + inputVars[i] = ""; // unconnected – node will use default + continue; + } + + auto srcNode = FindNode(srcPin.nodeId); + if (!srcNode) { + inputVars[i] = ""; + continue; + } + + // Check cache + auto cacheIt = cache.find(srcPin.nodeId); + if (cacheIt != cache.end()) { + const auto& outs = cacheIt->second; + inputVars[i] = srcPin.pinIndex < outs.size() ? outs[srcPin.pinIndex] : ""; + continue; + } + + // Recursively generate source node + std::vector srcInputs; + CollectInputs(srcNode, srcInputs, ctx, cache); + + std::vector srcOutputs; + srcNode->GenerateHLSL(srcInputs, srcOutputs, ctx); + cache[srcPin.nodeId] = srcOutputs; + + inputVars[i] = srcPin.pinIndex < srcOutputs.size() ? srcOutputs[srcPin.pinIndex] : ""; + } + } + + ShaderGraph::GeneratedCode ShaderGraph::GenerateHLSL() const + { + SGCodeGenContext ctx; + + // Find the material output node + SGNodePtr outputNode; + for (const auto& [id, node] : nodes) { + if (node->GetTypeName() == "MaterialOutput") { + outputNode = node; + break; + } + } + + if (!outputNode) { + return {}; + } + + std::unordered_map> cache; + std::vector inputVars; + CollectInputs(outputNode, inputVars, ctx, cache); + + std::vector outputVars; + outputNode->GenerateHLSL(inputVars, outputVars, ctx); + + return {ctx.declarations, ctx.bodyCode}; + } + + // ---- Serialization ---- + + void ShaderGraph::SaveJson(JsonOutputArchive& archive) const + { + archive.StartObject(); + + // Save nodes + archive.Key("nodes"); + archive.StartArray(); + for (const auto& [id, node] : nodes) { + archive.StartObject(); + node->SaveJson(archive); + archive.EndObject(); + } + archive.EndArray(); + + // Save connections + archive.Key("connections"); + archive.StartArray(); + for (const auto& conn : connections) { + archive.StartObject(); + archive.SaveValueObject("srcNode", conn.src.nodeId); + archive.Key("srcPin"); archive.SaveValue(conn.src.pinIndex); + archive.SaveValueObject("dstNode", conn.dst.nodeId); + archive.Key("dstPin"); archive.SaveValue(conn.dst.pinIndex); + archive.EndObject(); + } + archive.EndArray(); + + archive.EndObject(); + } + + void ShaderGraph::LoadJson(JsonInputArchive& archive) + { + nodes.clear(); + connections.clear(); + + uint32_t nodeCount = archive.StartArray("nodes"); + for (uint32_t i = 0; i < nodeCount; ++i) { + std::string typeName; + archive.LoadKeyValue("type", typeName); + + auto node = CreateNodeByType(typeName); + if (node) { + node->LoadJson(archive); + nodes[node->GetId()] = node; + } + archive.NextArrayElement(); + } + archive.End(); + + uint32_t connCount = archive.StartArray("connections"); + for (uint32_t i = 0; i < connCount; ++i) { + SGConnection conn; + archive.LoadKeyValue("srcNode", conn.src.nodeId); + archive.LoadKeyValue("srcPin", conn.src.pinIndex); + archive.LoadKeyValue("dstNode", conn.dst.nodeId); + archive.LoadKeyValue("dstPin", conn.dst.pinIndex); + connections.push_back(conn); + archive.NextArrayElement(); + } + archive.End(); + } + + // ---- Default node type registration ---- + + namespace { + struct SGNodeTypeRegistrar { + SGNodeTypeRegistrar() + { + // Math nodes + ShaderGraph::RegisterNodeType("Add", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("Subtract", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("Multiply", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("Divide", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("Lerp", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("Clamp", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("Saturate", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("Abs", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("Power", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("Dot", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("Cross", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("Normalize", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("ComponentMask", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("Append", []() -> SGNodePtr { return std::make_shared(); }); + + // Input nodes + ShaderGraph::RegisterNodeType("TexCoord", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("VertexColor", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("WorldPosition", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("WorldNormal", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("Time", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("ConstantFloat", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("ConstantVec2", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("ConstantVec3", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("ConstantVec4", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("ScalarParam", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("VectorParam", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("TextureParam", []() -> SGNodePtr { return std::make_shared(); }); + ShaderGraph::RegisterNodeType("TextureSample", []() -> SGNodePtr { return std::make_shared(); }); + + // Output + ShaderGraph::RegisterNodeType("MaterialOutput", []() -> SGNodePtr { return std::make_shared(); }); + } + }; + + static SGNodeTypeRegistrar sRegistrar; + } + +} // namespace sky::sg diff --git a/engine/render/shader/src/shadergraph/ShaderGraphInputNodes.cpp b/engine/render/shader/src/shadergraph/ShaderGraphInputNodes.cpp new file mode 100644 index 00000000..8eebf82c --- /dev/null +++ b/engine/render/shader/src/shadergraph/ShaderGraphInputNodes.cpp @@ -0,0 +1,436 @@ +// +// Created by blues on 2026/3/10. +// + +#include +#include +#include + +namespace sky::sg { + + // ---- SGTexCoordNode ---- + + SGTexCoordNode::SGTexCoordNode(uint32_t idx) : uvIndex(idx) + { + name = "TexCoord"; + outputPins.push_back({"UV", SGDataType::FLOAT2, SGPinDirection::OUTPUT}); + } + + void SGTexCoordNode::SetUVIndex(uint32_t idx) + { + uvIndex = idx; + } + + void SGTexCoordNode::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + archive.LoadKeyValue("uvIndex", uvIndex); + } + + void SGTexCoordNode::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.Key("uvIndex"); archive.SaveValue(uvIndex); + } + + void SGTexCoordNode::GenerateHLSL(const std::vector& /*inputVars*/, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + ctx.bodyCode += " float2 " + varName + " = input.UV" + (uvIndex > 0 ? std::to_string(uvIndex) : "") + ";\n"; + outputVars.push_back(varName); + } + + // ---- SGVertexColorNode ---- + + SGVertexColorNode::SGVertexColorNode() + { + name = "VertexColor"; + outputPins.push_back({"Color", SGDataType::FLOAT4, SGPinDirection::OUTPUT}); + } + + void SGVertexColorNode::GenerateHLSL(const std::vector& /*inputVars*/, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + ctx.bodyCode += " float4 " + varName + " = input.Color;\n"; + outputVars.push_back(varName); + } + + // ---- SGWorldPositionNode ---- + + SGWorldPositionNode::SGWorldPositionNode() + { + name = "WorldPosition"; + outputPins.push_back({"Position", SGDataType::FLOAT3, SGPinDirection::OUTPUT}); + } + + void SGWorldPositionNode::GenerateHLSL(const std::vector& /*inputVars*/, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + ctx.bodyCode += " float3 " + varName + " = input.WorldPos;\n"; + outputVars.push_back(varName); + } + + // ---- SGWorldNormalNode ---- + + SGWorldNormalNode::SGWorldNormalNode() + { + name = "WorldNormal"; + outputPins.push_back({"Normal", SGDataType::FLOAT3, SGPinDirection::OUTPUT}); + } + + void SGWorldNormalNode::GenerateHLSL(const std::vector& /*inputVars*/, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + ctx.bodyCode += " float3 " + varName + " = input.WorldNormal;\n"; + outputVars.push_back(varName); + } + + // ---- SGTimeNode ---- + + SGTimeNode::SGTimeNode() + { + name = "Time"; + outputPins.push_back({"Time", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + outputPins.push_back({"SinTime", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + outputPins.push_back({"CosTime", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + } + + void SGTimeNode::GenerateHLSL(const std::vector& /*inputVars*/, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string t = ctx.NextVarName(); + std::string st = ctx.NextVarName(); + std::string ct = ctx.NextVarName(); + ctx.bodyCode += " float " + t + " = _Time.y;\n"; + ctx.bodyCode += " float " + st + " = sin(_Time.y);\n"; + ctx.bodyCode += " float " + ct + " = cos(_Time.y);\n"; + outputVars.push_back(t); + outputVars.push_back(st); + outputVars.push_back(ct); + } + + // ---- SGConstantFloatNode ---- + + SGConstantFloatNode::SGConstantFloatNode(float v) : value(v) + { + name = "ConstantFloat"; + outputPins.push_back({"Value", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + } + + void SGConstantFloatNode::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + double v = 0.0; + archive.LoadKeyValue("value", v); + value = static_cast(v); + } + + void SGConstantFloatNode::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.Key("value"); archive.SaveValue(static_cast(value)); + } + + void SGConstantFloatNode::GenerateHLSL(const std::vector& /*inputVars*/, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::ostringstream ss; + ss << " float " << varName << " = " << value << ";\n"; + ctx.bodyCode += ss.str(); + outputVars.push_back(varName); + } + + // ---- SGConstantVec2Node ---- + + SGConstantVec2Node::SGConstantVec2Node(float x, float y) : value({x, y}) + { + name = "ConstantVec2"; + outputPins.push_back({"Value", SGDataType::FLOAT2, SGPinDirection::OUTPUT}); + } + + void SGConstantVec2Node::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + double x = 0.0, y = 0.0; + archive.LoadKeyValue("x", x); + archive.LoadKeyValue("y", y); + value[0] = static_cast(x); + value[1] = static_cast(y); + } + + void SGConstantVec2Node::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.Key("x"); archive.SaveValue(static_cast(value[0])); + archive.Key("y"); archive.SaveValue(static_cast(value[1])); + } + + void SGConstantVec2Node::GenerateHLSL(const std::vector& /*inputVars*/, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::ostringstream ss; + ss << " float2 " << varName << " = float2(" << value[0] << ", " << value[1] << ");\n"; + ctx.bodyCode += ss.str(); + outputVars.push_back(varName); + } + + // ---- SGConstantVec3Node ---- + + SGConstantVec3Node::SGConstantVec3Node(float x, float y, float z) : value({x, y, z}) + { + name = "ConstantVec3"; + outputPins.push_back({"Value", SGDataType::FLOAT3, SGPinDirection::OUTPUT}); + } + + void SGConstantVec3Node::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + double x = 0.0, y = 0.0, z = 0.0; + archive.LoadKeyValue("x", x); + archive.LoadKeyValue("y", y); + archive.LoadKeyValue("z", z); + value[0] = static_cast(x); + value[1] = static_cast(y); + value[2] = static_cast(z); + } + + void SGConstantVec3Node::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.Key("x"); archive.SaveValue(static_cast(value[0])); + archive.Key("y"); archive.SaveValue(static_cast(value[1])); + archive.Key("z"); archive.SaveValue(static_cast(value[2])); + } + + void SGConstantVec3Node::GenerateHLSL(const std::vector& /*inputVars*/, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::ostringstream ss; + ss << " float3 " << varName << " = float3(" << value[0] << ", " << value[1] << ", " << value[2] << ");\n"; + ctx.bodyCode += ss.str(); + outputVars.push_back(varName); + } + + // ---- SGConstantVec4Node ---- + + SGConstantVec4Node::SGConstantVec4Node(float x, float y, float z, float w) : value({x, y, z, w}) + { + name = "ConstantVec4"; + outputPins.push_back({"Value", SGDataType::FLOAT4, SGPinDirection::OUTPUT}); + } + + void SGConstantVec4Node::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + double x = 0.0, y = 0.0, z = 0.0, w = 0.0; + archive.LoadKeyValue("x", x); + archive.LoadKeyValue("y", y); + archive.LoadKeyValue("z", z); + archive.LoadKeyValue("w", w); + value[0] = static_cast(x); + value[1] = static_cast(y); + value[2] = static_cast(z); + value[3] = static_cast(w); + } + + void SGConstantVec4Node::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.Key("x"); archive.SaveValue(static_cast(value[0])); + archive.Key("y"); archive.SaveValue(static_cast(value[1])); + archive.Key("z"); archive.SaveValue(static_cast(value[2])); + archive.Key("w"); archive.SaveValue(static_cast(value[3])); + } + + void SGConstantVec4Node::GenerateHLSL(const std::vector& /*inputVars*/, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::ostringstream ss; + ss << " float4 " << varName << " = float4(" + << value[0] << ", " << value[1] << ", " << value[2] << ", " << value[3] << ");\n"; + ctx.bodyCode += ss.str(); + outputVars.push_back(varName); + } + + // ---- SGScalarParamNode ---- + + SGScalarParamNode::SGScalarParamNode(const std::string& pName, float defVal) + : paramName(pName), defaultValue(defVal) + { + name = "ScalarParam"; + outputPins.push_back({"Value", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + } + + void SGScalarParamNode::SetParamName(const std::string& n) + { + paramName = n; + } + + void SGScalarParamNode::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + archive.LoadKeyValue("paramName", paramName); + double dv = 0.0; + archive.LoadKeyValue("defaultValue", dv); + defaultValue = static_cast(dv); + } + + void SGScalarParamNode::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.SaveValueObject("paramName", paramName); + archive.Key("defaultValue"); archive.SaveValue(static_cast(defaultValue)); + } + + void SGScalarParamNode::GenerateHLSL(const std::vector& /*inputVars*/, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + // Emit a global cbuffer member declaration (once per unique param) + std::string decl = "float " + paramName + ";\n"; + if (ctx.declarations.find(decl) == std::string::npos) { + ctx.declarations += decl; + } + outputVars.push_back(paramName); + } + + // ---- SGVectorParamNode ---- + + SGVectorParamNode::SGVectorParamNode(const std::string& pName) : paramName(pName) + { + name = "VectorParam"; + outputPins.push_back({"Value", SGDataType::FLOAT4, SGPinDirection::OUTPUT}); + outputPins.push_back({"RGB", SGDataType::FLOAT3, SGPinDirection::OUTPUT}); + outputPins.push_back({"R", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + outputPins.push_back({"G", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + outputPins.push_back({"B", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + outputPins.push_back({"A", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + } + + void SGVectorParamNode::SetParamName(const std::string& n) + { + paramName = n; + } + + void SGVectorParamNode::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + archive.LoadKeyValue("paramName", paramName); + } + + void SGVectorParamNode::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.SaveValueObject("paramName", paramName); + } + + void SGVectorParamNode::GenerateHLSL(const std::vector& /*inputVars*/, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string decl = "float4 " + paramName + ";\n"; + if (ctx.declarations.find(decl) == std::string::npos) { + ctx.declarations += decl; + } + outputVars.push_back(paramName); + outputVars.push_back(paramName + ".rgb"); + outputVars.push_back(paramName + ".r"); + outputVars.push_back(paramName + ".g"); + outputVars.push_back(paramName + ".b"); + outputVars.push_back(paramName + ".a"); + } + + // ---- SGTextureParamNode ---- + + SGTextureParamNode::SGTextureParamNode(const std::string& pName) : paramName(pName) + { + name = "TextureParam"; + outputPins.push_back({"Texture", SGDataType::TEXTURE2D, SGPinDirection::OUTPUT}); + outputPins.push_back({"Sampler", SGDataType::SAMPLER_STATE, SGPinDirection::OUTPUT}); + } + + void SGTextureParamNode::SetParamName(const std::string& n) + { + paramName = n; + } + + void SGTextureParamNode::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + archive.LoadKeyValue("paramName", paramName); + } + + void SGTextureParamNode::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.SaveValueObject("paramName", paramName); + } + + void SGTextureParamNode::GenerateHLSL(const std::vector& /*inputVars*/, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string texDecl = "Texture2D " + paramName + ";\n"; + std::string samplerDecl = "SamplerState " + paramName + "Sampler;\n"; + if (ctx.declarations.find(texDecl) == std::string::npos) { + ctx.declarations += texDecl; + } + if (ctx.declarations.find(samplerDecl) == std::string::npos) { + ctx.declarations += samplerDecl; + } + outputVars.push_back(paramName); + outputVars.push_back(paramName + "Sampler"); + } + + // ---- SGTextureSampleNode ---- + + SGTextureSampleNode::SGTextureSampleNode() + { + name = "TextureSample"; + inputPins.push_back({"Texture", SGDataType::TEXTURE2D, SGPinDirection::INPUT}); + inputPins.push_back({"Sampler", SGDataType::SAMPLER_STATE, SGPinDirection::INPUT}); + inputPins.push_back({"UV", SGDataType::FLOAT2, SGPinDirection::INPUT}); + outputPins.push_back({"RGBA", SGDataType::FLOAT4, SGPinDirection::OUTPUT}); + outputPins.push_back({"RGB", SGDataType::FLOAT3, SGPinDirection::OUTPUT}); + outputPins.push_back({"R", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + outputPins.push_back({"G", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + outputPins.push_back({"B", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + outputPins.push_back({"A", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + } + + void SGTextureSampleNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::string tex = (inputVars.size() > 0 && !inputVars[0].empty()) ? inputVars[0] : "_DefaultTex"; + std::string sampler = (inputVars.size() > 1 && !inputVars[1].empty()) ? inputVars[1] : "_DefaultSampler"; + std::string uv = (inputVars.size() > 2 && !inputVars[2].empty()) ? inputVars[2] : "float2(0,0)"; + + ctx.bodyCode += " float4 " + varName + " = " + tex + ".Sample(" + sampler + ", " + uv + ");\n"; + outputVars.push_back(varName); + outputVars.push_back(varName + ".rgb"); + outputVars.push_back(varName + ".r"); + outputVars.push_back(varName + ".g"); + outputVars.push_back(varName + ".b"); + outputVars.push_back(varName + ".a"); + } + +} // namespace sky::sg diff --git a/engine/render/shader/src/shadergraph/ShaderGraphMathNodes.cpp b/engine/render/shader/src/shadergraph/ShaderGraphMathNodes.cpp new file mode 100644 index 00000000..b4553a08 --- /dev/null +++ b/engine/render/shader/src/shadergraph/ShaderGraphMathNodes.cpp @@ -0,0 +1,531 @@ +// +// Created by blues on 2026/3/10. +// + +#include +#include +#include + +namespace sky::sg { + + // ---- SGBinaryMathNode ---- + + SGBinaryMathNode::SGBinaryMathNode(SGDataType type) : dataType(type) + { + SetDataType(type); + } + + void SGBinaryMathNode::SetDataType(SGDataType type) + { + dataType = type; + inputPins.clear(); + outputPins.clear(); + inputPins.push_back({"A", type, SGPinDirection::INPUT}); + inputPins.push_back({"B", type, SGPinDirection::INPUT}); + outputPins.push_back({"Output", type, SGPinDirection::OUTPUT}); + } + + void SGBinaryMathNode::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + uint8_t dt = 0; + archive.LoadKeyValue("dataType", dt); + SetDataType(static_cast(dt)); + } + + void SGBinaryMathNode::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.Key("dataType"); archive.SaveValue(static_cast(dataType)); + } + + void SGBinaryMathNode::GenerateBinary(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx, + const char* op) const + { + std::string varName = ctx.NextVarName(); + std::string hlslType = SGDataTypeToHLSL(dataType); + + std::string a = (inputVars.size() > 0 && !inputVars[0].empty()) ? inputVars[0] : "0.0"; + std::string b = (inputVars.size() > 1 && !inputVars[1].empty()) ? inputVars[1] : "0.0"; + + ctx.bodyCode += " " + hlslType + " " + varName + " = " + a + " " + op + " " + b + ";\n"; + outputVars.push_back(varName); + } + + // ---- SGAddNode ---- + + SGAddNode::SGAddNode(SGDataType type) : SGBinaryMathNode(type) { name = "Add"; } + + void SGAddNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + GenerateBinary(inputVars, outputVars, ctx, "+"); + } + + // ---- SGSubtractNode ---- + + SGSubtractNode::SGSubtractNode(SGDataType type) : SGBinaryMathNode(type) { name = "Subtract"; } + + void SGSubtractNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + GenerateBinary(inputVars, outputVars, ctx, "-"); + } + + // ---- SGMultiplyNode ---- + + SGMultiplyNode::SGMultiplyNode(SGDataType type) : SGBinaryMathNode(type) { name = "Multiply"; } + + void SGMultiplyNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + GenerateBinary(inputVars, outputVars, ctx, "*"); + } + + // ---- SGDivideNode ---- + + SGDivideNode::SGDivideNode(SGDataType type) : SGBinaryMathNode(type) { name = "Divide"; } + + void SGDivideNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + GenerateBinary(inputVars, outputVars, ctx, "/"); + } + + // ---- SGLerpNode ---- + + SGLerpNode::SGLerpNode(SGDataType type) : dataType(type) + { + name = "Lerp"; + SetDataType(type); + } + + void SGLerpNode::SetDataType(SGDataType type) + { + dataType = type; + inputPins.clear(); + outputPins.clear(); + inputPins.push_back({"A", type, SGPinDirection::INPUT}); + inputPins.push_back({"B", type, SGPinDirection::INPUT}); + inputPins.push_back({"Alpha", SGDataType::FLOAT, SGPinDirection::INPUT}); + outputPins.push_back({"Output", type, SGPinDirection::OUTPUT}); + } + + void SGLerpNode::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + uint8_t dt = 0; + archive.LoadKeyValue("dataType", dt); + SetDataType(static_cast(dt)); + } + + void SGLerpNode::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.Key("dataType"); archive.SaveValue(static_cast(dataType)); + } + + void SGLerpNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::string hlslType = SGDataTypeToHLSL(dataType); + std::string a = (inputVars.size() > 0 && !inputVars[0].empty()) ? inputVars[0] : "0.0"; + std::string b = (inputVars.size() > 1 && !inputVars[1].empty()) ? inputVars[1] : "0.0"; + std::string alpha = (inputVars.size() > 2 && !inputVars[2].empty()) ? inputVars[2] : "0.5"; + + ctx.bodyCode += " " + hlslType + " " + varName + " = lerp(" + a + ", " + b + ", " + alpha + ");\n"; + outputVars.push_back(varName); + } + + // ---- SGClampNode ---- + + SGClampNode::SGClampNode(SGDataType type) : dataType(type) + { + name = "Clamp"; + SetDataType(type); + } + + void SGClampNode::SetDataType(SGDataType type) + { + dataType = type; + inputPins.clear(); + outputPins.clear(); + inputPins.push_back({"Value", type, SGPinDirection::INPUT}); + inputPins.push_back({"Min", type, SGPinDirection::INPUT}); + inputPins.push_back({"Max", type, SGPinDirection::INPUT}); + outputPins.push_back({"Output", type, SGPinDirection::OUTPUT}); + } + + void SGClampNode::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + uint8_t dt = 0; + archive.LoadKeyValue("dataType", dt); + SetDataType(static_cast(dt)); + } + + void SGClampNode::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.Key("dataType"); archive.SaveValue(static_cast(dataType)); + } + + void SGClampNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::string hlslType = SGDataTypeToHLSL(dataType); + std::string val = (inputVars.size() > 0 && !inputVars[0].empty()) ? inputVars[0] : "0.0"; + std::string mn = (inputVars.size() > 1 && !inputVars[1].empty()) ? inputVars[1] : "0.0"; + std::string mx = (inputVars.size() > 2 && !inputVars[2].empty()) ? inputVars[2] : "1.0"; + + ctx.bodyCode += " " + hlslType + " " + varName + " = clamp(" + val + ", " + mn + ", " + mx + ");\n"; + outputVars.push_back(varName); + } + + // ---- SGSaturateNode ---- + + SGSaturateNode::SGSaturateNode(SGDataType type) : dataType(type) + { + name = "Saturate"; + SetDataType(type); + } + + void SGSaturateNode::SetDataType(SGDataType type) + { + dataType = type; + inputPins.clear(); + outputPins.clear(); + inputPins.push_back({"Value", type, SGPinDirection::INPUT}); + outputPins.push_back({"Output", type, SGPinDirection::OUTPUT}); + } + + void SGSaturateNode::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + uint8_t dt = 0; + archive.LoadKeyValue("dataType", dt); + SetDataType(static_cast(dt)); + } + + void SGSaturateNode::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.Key("dataType"); archive.SaveValue(static_cast(dataType)); + } + + void SGSaturateNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::string hlslType = SGDataTypeToHLSL(dataType); + std::string val = (inputVars.size() > 0 && !inputVars[0].empty()) ? inputVars[0] : "0.0"; + + ctx.bodyCode += " " + hlslType + " " + varName + " = saturate(" + val + ");\n"; + outputVars.push_back(varName); + } + + // ---- SGAbsNode ---- + + SGAbsNode::SGAbsNode(SGDataType type) : dataType(type) + { + name = "Abs"; + SetDataType(type); + } + + void SGAbsNode::SetDataType(SGDataType type) + { + dataType = type; + inputPins.clear(); + outputPins.clear(); + inputPins.push_back({"Value", type, SGPinDirection::INPUT}); + outputPins.push_back({"Output", type, SGPinDirection::OUTPUT}); + } + + void SGAbsNode::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + uint8_t dt = 0; + archive.LoadKeyValue("dataType", dt); + SetDataType(static_cast(dt)); + } + + void SGAbsNode::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.Key("dataType"); archive.SaveValue(static_cast(dataType)); + } + + void SGAbsNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::string hlslType = SGDataTypeToHLSL(dataType); + std::string val = (inputVars.size() > 0 && !inputVars[0].empty()) ? inputVars[0] : "0.0"; + + ctx.bodyCode += " " + hlslType + " " + varName + " = abs(" + val + ");\n"; + outputVars.push_back(varName); + } + + // ---- SGPowerNode ---- + + SGPowerNode::SGPowerNode() + { + name = "Power"; + inputPins.push_back({"Base", SGDataType::FLOAT, SGPinDirection::INPUT}); + inputPins.push_back({"Exp", SGDataType::FLOAT, SGPinDirection::INPUT}); + outputPins.push_back({"Output", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + } + + void SGPowerNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::string base = (inputVars.size() > 0 && !inputVars[0].empty()) ? inputVars[0] : "0.0"; + std::string exp = (inputVars.size() > 1 && !inputVars[1].empty()) ? inputVars[1] : "1.0"; + + ctx.bodyCode += " float " + varName + " = pow(" + base + ", " + exp + ");\n"; + outputVars.push_back(varName); + } + + // ---- SGDotNode ---- + + SGDotNode::SGDotNode(SGDataType type) : dataType(type) + { + name = "Dot"; + SetDataType(type); + } + + void SGDotNode::SetDataType(SGDataType type) + { + dataType = type; + inputPins.clear(); + outputPins.clear(); + inputPins.push_back({"A", type, SGPinDirection::INPUT}); + inputPins.push_back({"B", type, SGPinDirection::INPUT}); + outputPins.push_back({"Output", SGDataType::FLOAT, SGPinDirection::OUTPUT}); + } + + void SGDotNode::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + uint8_t dt = 0; + archive.LoadKeyValue("dataType", dt); + SetDataType(static_cast(dt)); + } + + void SGDotNode::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.Key("dataType"); archive.SaveValue(static_cast(dataType)); + } + + void SGDotNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::string a = (inputVars.size() > 0 && !inputVars[0].empty()) ? inputVars[0] : "0.0"; + std::string b = (inputVars.size() > 1 && !inputVars[1].empty()) ? inputVars[1] : "0.0"; + + ctx.bodyCode += " float " + varName + " = dot(" + a + ", " + b + ");\n"; + outputVars.push_back(varName); + } + + // ---- SGCrossNode ---- + + SGCrossNode::SGCrossNode() + { + name = "Cross"; + inputPins.push_back({"A", SGDataType::FLOAT3, SGPinDirection::INPUT}); + inputPins.push_back({"B", SGDataType::FLOAT3, SGPinDirection::INPUT}); + outputPins.push_back({"Output", SGDataType::FLOAT3, SGPinDirection::OUTPUT}); + } + + void SGCrossNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::string a = (inputVars.size() > 0 && !inputVars[0].empty()) ? inputVars[0] : "float3(0,0,0)"; + std::string b = (inputVars.size() > 1 && !inputVars[1].empty()) ? inputVars[1] : "float3(0,0,0)"; + + ctx.bodyCode += " float3 " + varName + " = cross(" + a + ", " + b + ");\n"; + outputVars.push_back(varName); + } + + // ---- SGNormalizeNode ---- + + SGNormalizeNode::SGNormalizeNode() + { + name = "Normalize"; + inputPins.push_back({"V", SGDataType::FLOAT3, SGPinDirection::INPUT}); + outputPins.push_back({"Output", SGDataType::FLOAT3, SGPinDirection::OUTPUT}); + } + + void SGNormalizeNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::string v = (inputVars.size() > 0 && !inputVars[0].empty()) ? inputVars[0] : "float3(0,0,1)"; + + ctx.bodyCode += " float3 " + varName + " = normalize(" + v + ");\n"; + outputVars.push_back(varName); + } + + // ---- SGComponentMaskNode ---- + + SGComponentMaskNode::SGComponentMaskNode(SGDataType inType) : inputType(inType) + { + name = "ComponentMask"; + inputPins.push_back({"In", inType, SGPinDirection::INPUT}); + UpdateOutputPin(); + } + + void SGComponentMaskNode::SetChannels(bool r, bool g, bool b, bool a) + { + maskR = r; maskG = g; maskB = b; maskA = a; + UpdateOutputPin(); + } + + void SGComponentMaskNode::UpdateOutputPin() + { + uint8_t count = (maskR ? 1 : 0) + (maskG ? 1 : 0) + (maskB ? 1 : 0) + (maskA ? 1 : 0); + outputPins.clear(); + SGDataType outType = SGDataType::FLOAT; + switch (count) { + case 2: outType = SGDataType::FLOAT2; break; + case 3: outType = SGDataType::FLOAT3; break; + case 4: outType = SGDataType::FLOAT4; break; + default: outType = SGDataType::FLOAT; break; + } + outputPins.push_back({"Output", outType, SGPinDirection::OUTPUT}); + } + + void SGComponentMaskNode::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + bool r = true, g = true, b = true, a = false; + archive.LoadKeyValue("maskR", r); + archive.LoadKeyValue("maskG", g); + archive.LoadKeyValue("maskB", b); + archive.LoadKeyValue("maskA", a); + uint8_t dt = 0; + archive.LoadKeyValue("inputType", dt); + inputType = static_cast(dt); + inputPins[0].type = inputType; + SetChannels(r, g, b, a); + } + + void SGComponentMaskNode::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.Key("maskR"); archive.SaveValue(maskR); + archive.Key("maskG"); archive.SaveValue(maskG); + archive.Key("maskB"); archive.SaveValue(maskB); + archive.Key("maskA"); archive.SaveValue(maskA); + archive.Key("inputType"); archive.SaveValue(static_cast(inputType)); + } + + void SGComponentMaskNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::string inVal = (inputVars.size() > 0 && !inputVars[0].empty()) ? inputVars[0] : "0.0"; + + std::string swizzle; + if (maskR) swizzle += 'r'; + if (maskG) swizzle += 'g'; + if (maskB) swizzle += 'b'; + if (maskA) swizzle += 'a'; + + if (swizzle.empty()) swizzle = "r"; + + uint8_t count = static_cast(swizzle.size()); + SGDataType outType = SGDataType::FLOAT; + switch (count) { + case 2: outType = SGDataType::FLOAT2; break; + case 3: outType = SGDataType::FLOAT3; break; + case 4: outType = SGDataType::FLOAT4; break; + default: outType = SGDataType::FLOAT; break; + } + + ctx.bodyCode += " " + SGDataTypeToHLSL(outType) + " " + varName + " = " + inVal + "." + swizzle + ";\n"; + outputVars.push_back(varName); + } + + // ---- SGAppendNode ---- + + SGAppendNode::SGAppendNode(SGDataType aT, SGDataType bT) : aType(aT), bType(bT) + { + name = "Append"; + inputPins.push_back({"A", aT, SGPinDirection::INPUT}); + inputPins.push_back({"B", bT, SGPinDirection::INPUT}); + + uint8_t totalComps = SGDataTypeComponents(aT) + SGDataTypeComponents(bT); + SGDataType outType = SGDataType::FLOAT; + switch (totalComps) { + case 2: outType = SGDataType::FLOAT2; break; + case 3: outType = SGDataType::FLOAT3; break; + case 4: outType = SGDataType::FLOAT4; break; + default: outType = SGDataType::FLOAT4; break; + } + outputPins.push_back({"Output", outType, SGPinDirection::OUTPUT}); + } + + void SGAppendNode::LoadJson(JsonInputArchive& archive) + { + SGNode::LoadJson(archive); + uint8_t at = 0, bt = 0; + archive.LoadKeyValue("aType", at); + archive.LoadKeyValue("bType", bt); + aType = static_cast(at); + bType = static_cast(bt); + inputPins[0].type = aType; + inputPins[1].type = bType; + } + + void SGAppendNode::SaveJson(JsonOutputArchive& archive) const + { + SGNode::SaveJson(archive); + archive.Key("aType"); archive.SaveValue(static_cast(aType)); + archive.Key("bType"); archive.SaveValue(static_cast(bType)); + } + + void SGAppendNode::GenerateHLSL(const std::vector& inputVars, + std::vector& outputVars, + SGCodeGenContext& ctx) const + { + std::string varName = ctx.NextVarName(); + std::string a = (inputVars.size() > 0 && !inputVars[0].empty()) ? inputVars[0] : "0.0"; + std::string b = (inputVars.size() > 1 && !inputVars[1].empty()) ? inputVars[1] : "0.0"; + + uint8_t totalComps = SGDataTypeComponents(aType) + SGDataTypeComponents(bType); + SGDataType outType = SGDataType::FLOAT; + switch (totalComps) { + case 2: outType = SGDataType::FLOAT2; break; + case 3: outType = SGDataType::FLOAT3; break; + case 4: outType = SGDataType::FLOAT4; break; + default: outType = SGDataType::FLOAT4; break; + } + + ctx.bodyCode += " " + SGDataTypeToHLSL(outType) + " " + varName + + " = " + SGDataTypeToHLSL(outType) + "(" + a + ", " + b + ");\n"; + outputVars.push_back(varName); + } + +} // namespace sky::sg diff --git a/engine/render/shader/src/shadergraph/ShaderGraphNode.cpp b/engine/render/shader/src/shadergraph/ShaderGraphNode.cpp new file mode 100644 index 00000000..535eb52e --- /dev/null +++ b/engine/render/shader/src/shadergraph/ShaderGraphNode.cpp @@ -0,0 +1,31 @@ +// +// Created by blues on 2026/3/10. +// + +#include +#include + +namespace sky::sg { + + SGNode::SGNode() : id(Uuid::Create()) + { + } + + void SGNode::LoadJson(JsonInputArchive& archive) + { + archive.LoadKeyValue("id", id); + archive.LoadKeyValue("name", name); + archive.LoadKeyValue("posX", posX); + archive.LoadKeyValue("posY", posY); + } + + void SGNode::SaveJson(JsonOutputArchive& archive) const + { + archive.SaveValueObject("type", GetTypeName()); + archive.SaveValueObject("id", id); + archive.SaveValueObject("name", name); + archive.Key("posX"); archive.SaveValue(posX); + archive.Key("posY"); archive.SaveValue(posY); + } + +} // namespace sky::sg diff --git a/engine/render/shader/src/shadergraph/ShaderGraphOutputNode.cpp b/engine/render/shader/src/shadergraph/ShaderGraphOutputNode.cpp new file mode 100644 index 00000000..7ae9fdff --- /dev/null +++ b/engine/render/shader/src/shadergraph/ShaderGraphOutputNode.cpp @@ -0,0 +1,39 @@ +// +// Created by blues on 2026/3/10. +// + +#include +#include + +namespace sky::sg { + + SGMaterialOutputNode::SGMaterialOutputNode() + { + name = "MaterialOutput"; + // Input pins match PBR material slots (following MaterialSlot enum order) + inputPins.push_back({"BaseColor", SGDataType::FLOAT3, SGPinDirection::INPUT}); + inputPins.push_back({"Metallic", SGDataType::FLOAT, SGPinDirection::INPUT}); + inputPins.push_back({"Roughness", SGDataType::FLOAT, SGPinDirection::INPUT}); + inputPins.push_back({"Normal", SGDataType::FLOAT3, SGPinDirection::INPUT}); + inputPins.push_back({"Emissive", SGDataType::FLOAT3, SGPinDirection::INPUT}); + inputPins.push_back({"Opacity", SGDataType::FLOAT, SGPinDirection::INPUT}); + inputPins.push_back({"OpacityMask", SGDataType::FLOAT, SGPinDirection::INPUT}); + // No output pins – this is the terminal node + } + + void SGMaterialOutputNode::GenerateHLSL(const std::vector& inputVars, + std::vector& /*outputVars*/, + SGCodeGenContext& ctx) const + { + static const char* SLOTS[] = { + "BaseColor", "Metallic", "Roughness", "Normal", "Emissive", "Opacity", "OpacityMask" + }; + + for (size_t i = 0; i < inputPins.size(); ++i) { + if (i < inputVars.size() && !inputVars[i].empty()) { + ctx.bodyCode += " surface." + std::string(SLOTS[i]) + " = " + inputVars[i] + ";\n"; + } + } + } + +} // namespace sky::sg