Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -690,21 +690,21 @@ func (d *DAG) getDescendants(vHash interface{}) map[interface{}]struct{} {
if children, ok := d.outboundEdge[vHash]; ok {

// for each child use a goroutine to collect its descendants
//var waitGroup sync.WaitGroup
//waitGroup.Add(len(children))
// var waitGroup sync.WaitGroup
// waitGroup.Add(len(children))
for child := range children {
//go func(child interface{}, mu *sync.Mutex, cache map[interface{}]bool) {
// go func(child interface{}, mu *sync.Mutex, cache map[interface{}]bool) {
childDescendants := d.getDescendants(child)
mu.Lock()
for descendant := range childDescendants {
cache[descendant] = struct{}{}
}
cache[child] = struct{}{}
mu.Unlock()
//waitGroup.Done()
//}(child, &mu, cache)
// waitGroup.Done()
// }(child, &mu, cache)
}
//waitGroup.Wait()
// waitGroup.Wait()
}

// remember the collected descendents
Expand Down Expand Up @@ -943,8 +943,16 @@ func (d *DAG) DescendantsFlow(startID string, inputs []FlowResult, callback Flow
return []FlowResult{}, errPar
}

parentsInFlow := map[string]bool{}
for parentID, _ := range parents {
_, ok := flowIDs[parentID]
if ok || parentID == startID {
parentsInFlow[parentID] = true
}
}

// Create a buffered input channel that has capacity for all parent results.
inputChannels[id] = make(chan FlowResult, len(parents))
inputChannels[id] = make(chan FlowResult, len(parentsInFlow))

if d.isLeaf(id) {
leafCount += 1
Expand Down
82 changes: 81 additions & 1 deletion dag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package dag

import (
"fmt"
"github.com/go-test/deep"
"sort"
"strconv"
"testing"

"github.com/go-test/deep"
)

type iVertex struct{ value int }
Expand Down Expand Up @@ -1331,6 +1332,85 @@ func TestDAG_DescendantsFlowOneNode(t *testing.T) {
}
}

// TestDAG_DescendantsFlowMultipleRoots validate if running DescendantsFlow on dag with multiple roots completes and return with right results
func TestDAG_DescendantsFlowMultipleRoots(t *testing.T) {
// Initialize a new graph.
d := NewDAG()

// Init vertices.
v0, _ := d.AddVertex(0)
v1, _ := d.AddVertex(1)
v2, _ := d.AddVertex(2)
v3, _ := d.AddVertex(3)
v4, _ := d.AddVertex(4)
v5, _ := d.AddVertex(5)
v6, _ := d.AddVertex(6)
v7, _ := d.AddVertex(7)
v8, _ := d.AddVertex(8)

// Add the above vertices and connect them.
_ = d.AddEdge(v0, v2)
_ = d.AddEdge(v1, v2)
_ = d.AddEdge(v1, v5)
_ = d.AddEdge(v2, v3)
_ = d.AddEdge(v2, v4)
_ = d.AddEdge(v3, v6)
_ = d.AddEdge(v5, v7)
_ = d.AddEdge(v4, v8)
_ = d.AddEdge(v7, v8)

// 0 1
// |───┌─────────|
// │ │
// ┌─── 2 ──┐ │
// │ │ │
// 3 4 5
// │ │ │
// 6 │ 7
// └──┬──┘
// 8

// The callback function adds its own value (ID) to the sum of parent results.
flowCallback := func(d *DAG, id string, parentResults []FlowResult) (interface{}, error) {

v, _ := d.GetVertex(id)
result, _ := v.(int)
var parents []int
for _, r := range parentResults {
p, _ := d.GetVertex(r.ID)
parents = append(parents, p.(int))
result += r.Result.(int)
}
sort.Ints(parents)
fmt.Printf("%v based on: %+v returns: %d\n", v, parents, result)
return result, nil
}

res, _ := d.DescendantsFlow(v0, nil, flowCallback)
validResults := map[string]int{
v6: 11,
v8: 14,
}

// check result only has 2 items from node v6 and v8
if len(res) != 2 {
t.Errorf("DescendantsFlow() result count mismatch | got = %d, want 2", len(res))
}

// check and confirm only v6 and v8 are in the results with the expected value
for _, r := range res {
v, ok := validResults[r.ID]
if !ok {
t.Errorf("DescendantsFlow() vertex should not be part of result | id = %s", r.ID)
continue
}
result := r.Result.(int)
if v != result {
t.Errorf("DescendantsFlow() vertex result mismatch | got = %d, want %d", v, result)
}
}
}

func largeAux(d *DAG, level int, branches int, parent iVertex) (int, int) {
var vertexCount int
var edgeCount int
Expand Down