diff --git a/executor.go b/executor.go index a583a89..8d5e1ce 100644 --- a/executor.go +++ b/executor.go @@ -448,6 +448,11 @@ func (e *Executor) parseTaskFlagsIntoMap(taskName string, flags []string) map[st taskFlag := e.taskFlagsRegistry[taskName][key] flagValErrMsg := fmt.Sprintf("The type for the `%s` flag is `%s`. Please use `%s`", key, taskFlag.ValueType, flagTypeToGetter[taskFlag.ValueType]) + // Bare boolean flags (e.g. --dry-run without =value) mean true. + if val == "" && taskFlag.ValueType == BoolTypeFlag { + val = "true" + } + // If no val was passed, use the default flagArg that is already in the taskFlagsMap. if val != "" { flagArg := FlagArg{ diff --git a/executor_internal_test.go b/executor_internal_test.go index b2feece..de3bdd1 100644 --- a/executor_internal_test.go +++ b/executor_internal_test.go @@ -100,7 +100,8 @@ func TestParseTaskOptionsListToMap(t *testing.T) { expectedFlagCVal *int expectedFlagDVal *float64 expectedFlagEVal *time.Duration - expectFlagFValIsNil bool + expectedFlagFVal *bool + expectFlagFNil bool invalidFlagPanicVal string }{ { @@ -146,9 +147,9 @@ func TestParseTaskOptionsListToMap(t *testing.T) { expectedFlagEVal: durationPtr(time.Duration(11 * time.Hour)), }, { - description: "Should return nil if no arg passed to variable flag and no Default defined", - flagArgs: []string{"-f"}, - expectFlagFValIsNil: true, + description: "Bare bool flag without =value should be treated as true", + flagArgs: []string{"-f"}, + expectedFlagFVal: boolPtr(true), }, { description: "Should store Default values for all flags and Value nil when no args are passed", @@ -158,7 +159,7 @@ func TestParseTaskOptionsListToMap(t *testing.T) { expectedFlagBVal: stringPtr("presidential dolphin"), expectedFlagCVal: intPtr(10), expectedFlagDVal: float64Ptr(1.72), - expectFlagFValIsNil: true, + expectFlagFNil: true, }, { description: "Should panic if there are multiple `=`s in the flag", @@ -280,9 +281,11 @@ func TestParseTaskOptionsListToMap(t *testing.T) { } } - if tc.expectFlagFValIsNil { - flagArg := flagsMap["f"] - assert.Nil(t, flagArg.BoolVal()) + if tc.expectedFlagFVal != nil { + assert.Equal(t, *tc.expectedFlagFVal, *flagsMap["f"].BoolVal()) + } + if tc.expectFlagFNil { + assert.Nil(t, flagsMap["f"].BoolVal()) } } })