diff --git a/dag.go b/dag.go index c75918b..fea0f94 100644 --- a/dag.go +++ b/dag.go @@ -690,10 +690,10 @@ func (d *DAG) getDescendants(vHash interface{}) map[interface{}]struct{} { if children, ok := d.outboundEdge[vHash]; ok { // for each child use a goroutine to collect its descendants - //var waitGroup sync.WaitGroup - //waitGroup.Add(len(children)) + // var waitGroup sync.WaitGroup + // waitGroup.Add(len(children)) for child := range children { - //go func(child interface{}, mu *sync.Mutex, cache map[interface{}]bool) { + // go func(child interface{}, mu *sync.Mutex, cache map[interface{}]bool) { childDescendants := d.getDescendants(child) mu.Lock() for descendant := range childDescendants { @@ -701,10 +701,10 @@ func (d *DAG) getDescendants(vHash interface{}) map[interface{}]struct{} { } cache[child] = struct{}{} mu.Unlock() - //waitGroup.Done() - //}(child, &mu, cache) + // waitGroup.Done() + // }(child, &mu, cache) } - //waitGroup.Wait() + // waitGroup.Wait() } // remember the collected descendents @@ -943,8 +943,16 @@ func (d *DAG) DescendantsFlow(startID string, inputs []FlowResult, callback Flow return []FlowResult{}, errPar } + parentsInFlow := map[string]bool{} + for parentID, _ := range parents { + _, ok := flowIDs[parentID] + if ok || parentID == startID { + parentsInFlow[parentID] = true + } + } + // Create a buffered input channel that has capacity for all parent results. - inputChannels[id] = make(chan FlowResult, len(parents)) + inputChannels[id] = make(chan FlowResult, len(parentsInFlow)) if d.isLeaf(id) { leafCount += 1 diff --git a/dag_test.go b/dag_test.go index 7fd3a1c..f383279 100644 --- a/dag_test.go +++ b/dag_test.go @@ -2,10 +2,11 @@ package dag import ( "fmt" - "github.com/go-test/deep" "sort" "strconv" "testing" + + "github.com/go-test/deep" ) type iVertex struct{ value int } @@ -1331,6 +1332,85 @@ func TestDAG_DescendantsFlowOneNode(t *testing.T) { } } +// TestDAG_DescendantsFlowMultipleRoots validate if running DescendantsFlow on dag with multiple roots completes and return with right results +func TestDAG_DescendantsFlowMultipleRoots(t *testing.T) { + // Initialize a new graph. + d := NewDAG() + + // Init vertices. + v0, _ := d.AddVertex(0) + v1, _ := d.AddVertex(1) + v2, _ := d.AddVertex(2) + v3, _ := d.AddVertex(3) + v4, _ := d.AddVertex(4) + v5, _ := d.AddVertex(5) + v6, _ := d.AddVertex(6) + v7, _ := d.AddVertex(7) + v8, _ := d.AddVertex(8) + + // Add the above vertices and connect them. + _ = d.AddEdge(v0, v2) + _ = d.AddEdge(v1, v2) + _ = d.AddEdge(v1, v5) + _ = d.AddEdge(v2, v3) + _ = d.AddEdge(v2, v4) + _ = d.AddEdge(v3, v6) + _ = d.AddEdge(v5, v7) + _ = d.AddEdge(v4, v8) + _ = d.AddEdge(v7, v8) + + // 0 1 + // |───┌─────────| + // │ │ + // ┌─── 2 ──┐ │ + // │ │ │ + // 3 4 5 + // │ │ │ + // 6 │ 7 + // └──┬──┘ + // 8 + + // The callback function adds its own value (ID) to the sum of parent results. + flowCallback := func(d *DAG, id string, parentResults []FlowResult) (interface{}, error) { + + v, _ := d.GetVertex(id) + result, _ := v.(int) + var parents []int + for _, r := range parentResults { + p, _ := d.GetVertex(r.ID) + parents = append(parents, p.(int)) + result += r.Result.(int) + } + sort.Ints(parents) + fmt.Printf("%v based on: %+v returns: %d\n", v, parents, result) + return result, nil + } + + res, _ := d.DescendantsFlow(v0, nil, flowCallback) + validResults := map[string]int{ + v6: 11, + v8: 14, + } + + // check result only has 2 items from node v6 and v8 + if len(res) != 2 { + t.Errorf("DescendantsFlow() result count mismatch | got = %d, want 2", len(res)) + } + + // check and confirm only v6 and v8 are in the results with the expected value + for _, r := range res { + v, ok := validResults[r.ID] + if !ok { + t.Errorf("DescendantsFlow() vertex should not be part of result | id = %s", r.ID) + continue + } + result := r.Result.(int) + if v != result { + t.Errorf("DescendantsFlow() vertex result mismatch | got = %d, want %d", v, result) + } + } +} + func largeAux(d *DAG, level int, branches int, parent iVertex) (int, int) { var vertexCount int var edgeCount int