Refactored initCommand

This commit is contained in:
Mike Farah 2025-07-11 10:17:21 +10:00
parent 369fe56e2d
commit 9b299649f7
2 changed files with 610 additions and 54 deletions

View File

@ -18,52 +18,100 @@ func isAutomaticOutputFormat() bool {
func initCommand(cmd *cobra.Command, args []string) (string, []string, error) { func initCommand(cmd *cobra.Command, args []string) (string, []string, error) {
cmd.SilenceUsage = true cmd.SilenceUsage = true
fileInfo, _ := os.Stdout.Stat() setupColors()
if forceColor || (!forceNoColor && (fileInfo.Mode()&os.ModeCharDevice) != 0) {
colorsEnabled = true
}
expression, args, err := processArgs(args) expression, args, err := processArgs(args)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
if err := loadSplitFileExpression(); err != nil {
return "", nil, err
}
handleBackwardsCompatibility()
if err := validateCommandFlags(args); err != nil {
return "", nil, err
}
if err := configureFormats(args); err != nil {
return "", nil, err
}
configureUnwrapScalar()
return expression, args, nil
}
func setupColors() {
fileInfo, _ := os.Stdout.Stat()
if forceColor || (!forceNoColor && (fileInfo.Mode()&os.ModeCharDevice) != 0) {
colorsEnabled = true
}
}
func loadSplitFileExpression() error {
if splitFileExpFile != "" { if splitFileExpFile != "" {
splitExpressionBytes, err := os.ReadFile(splitFileExpFile) splitExpressionBytes, err := os.ReadFile(splitFileExpFile)
if err != nil { if err != nil {
return "", nil, err return err
} }
splitFileExp = string(splitExpressionBytes) splitFileExp = string(splitExpressionBytes)
} }
return nil
}
func handleBackwardsCompatibility() {
// backwards compatibility // backwards compatibility
if outputToJSON { if outputToJSON {
outputFormat = "json" outputFormat = "json"
} }
}
func validateCommandFlags(args []string) error {
if writeInplace && (len(args) == 0 || args[0] == "-") { if writeInplace && (len(args) == 0 || args[0] == "-") {
return "", nil, fmt.Errorf("write in place flag only applicable when giving an expression and at least one file") return fmt.Errorf("write in place flag only applicable when giving an expression and at least one file")
} }
if frontMatter != "" && len(args) == 0 { if frontMatter != "" && len(args) == 0 {
return "", nil, fmt.Errorf("front matter flag only applicable when giving an expression and at least one file") return fmt.Errorf("front matter flag only applicable when giving an expression and at least one file")
} }
if writeInplace && splitFileExp != "" { if writeInplace && splitFileExp != "" {
return "", nil, fmt.Errorf("write in place cannot be used with split file") return fmt.Errorf("write in place cannot be used with split file")
} }
if nullInput && len(args) > 0 { if nullInput && len(args) > 0 {
return "", nil, fmt.Errorf("cannot pass files in when using null-input flag") return fmt.Errorf("cannot pass files in when using null-input flag")
} }
return nil
}
func configureFormats(args []string) error {
inputFilename := "" inputFilename := ""
if len(args) > 0 { if len(args) > 0 {
inputFilename = args[0] inputFilename = args[0]
} }
if inputFormat == "" || inputFormat == "auto" || inputFormat == "a" {
if err := configureInputFormat(inputFilename); err != nil {
return err
}
if err := configureOutputFormat(); err != nil {
return err
}
yqlib.GetLogger().Debug("Using input format %v", inputFormat)
yqlib.GetLogger().Debug("Using output format %v", outputFormat)
return nil
}
func configureInputFormat(inputFilename string) error {
if inputFormat == "" || inputFormat == "auto" || inputFormat == "a" {
inputFormat = yqlib.FormatStringFromFilename(inputFilename) inputFormat = yqlib.FormatStringFromFilename(inputFilename)
_, err := yqlib.FormatFromString(inputFormat) _, err := yqlib.FormatFromString(inputFormat)
@ -88,24 +136,27 @@ func initCommand(cmd *cobra.Command, args []string) (string, []string, error) {
} }
outputFormat = "yaml" outputFormat = "yaml"
} }
return nil
}
func configureOutputFormat() error {
outputFormatType, err := yqlib.FormatFromString(outputFormat) outputFormatType, err := yqlib.FormatFromString(outputFormat)
if err != nil { if err != nil {
return "", nil, err return err
} }
yqlib.GetLogger().Debug("Using input format %v", inputFormat)
yqlib.GetLogger().Debug("Using output format %v", outputFormat)
if outputFormatType == yqlib.YamlFormat || if outputFormatType == yqlib.YamlFormat ||
outputFormatType == yqlib.PropertiesFormat { outputFormatType == yqlib.PropertiesFormat {
unwrapScalar = true unwrapScalar = true
} }
return nil
}
func configureUnwrapScalar() {
if unwrapScalarFlag.IsExplicitlySet() { if unwrapScalarFlag.IsExplicitlySet() {
unwrapScalar = unwrapScalarFlag.IsSet() unwrapScalar = unwrapScalarFlag.IsSet()
} }
return expression, args, nil
} }
func configureDecoder(evaluateTogether bool) (yqlib.Decoder, error) { func configureDecoder(evaluateTogether bool) (yqlib.Decoder, error) {

View File

@ -514,43 +514,6 @@ func TestInitCommand(t *testing.T) {
expectError: true, expectError: true,
errorContains: "write in place flag only applicable when giving an expression and at least one file", errorContains: "write in place flag only applicable when giving an expression and at least one file",
}, },
{
name: "write inplace with dash",
args: []string{"-"},
writeInplace: true,
frontMatter: "",
nullInput: false,
expectError: true,
errorContains: "write in place flag only applicable when giving an expression and at least one file",
},
{
name: "front matter with no args",
args: []string{},
writeInplace: false,
frontMatter: "extract",
nullInput: false,
expectError: true,
errorContains: "front matter flag only applicable when giving an expression and at least one file",
},
{
name: "write inplace with split file",
args: []string{tempFile.Name()},
writeInplace: true,
frontMatter: "",
nullInput: false,
splitFileExp: ".a.b",
expectError: true,
errorContains: "write in place cannot be used with split file",
},
{
name: "null input with args",
args: []string{tempFile.Name()},
writeInplace: false,
frontMatter: "",
nullInput: true,
expectError: true,
errorContains: "cannot pass files in when using null-input flag",
},
{ {
name: "split file expression from file", name: "split file expression from file",
args: []string{tempFile.Name()}, args: []string{tempFile.Name()},
@ -916,6 +879,32 @@ func TestConfigureEncoderWithPropertiesFormat(t *testing.T) {
} }
} }
// Mock boolFlag for testing
type mockBoolFlag struct {
explicitlySet bool
value bool
}
func (f *mockBoolFlag) IsExplicitlySet() bool {
return f.explicitlySet
}
func (f *mockBoolFlag) IsSet() bool {
return f.value
}
func (f *mockBoolFlag) String() string {
return "mock"
}
func (f *mockBoolFlag) Set(_ string) error {
return nil
}
func (f *mockBoolFlag) Type() string {
return "bool"
}
// Helper function to compare string slices // Helper function to compare string slices
func stringsEqual(a, b []string) bool { func stringsEqual(a, b []string) bool {
if len(a) != len(b) { if len(a) != len(b) {
@ -928,3 +917,519 @@ func stringsEqual(a, b []string) bool {
} }
return true return true
} }
func TestSetupColors(t *testing.T) {
tests := []struct {
name string
forceColor bool
forceNoColor bool
expectColors bool
}{
{
name: "force color enabled",
forceColor: true,
forceNoColor: false,
expectColors: true,
},
{
name: "force no color enabled",
forceColor: false,
forceNoColor: true,
expectColors: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Save original values
originalForceColor := forceColor
originalForceNoColor := forceNoColor
originalColorsEnabled := colorsEnabled
defer func() {
forceColor = originalForceColor
forceNoColor = originalForceNoColor
colorsEnabled = originalColorsEnabled
}()
forceColor = tt.forceColor
forceNoColor = tt.forceNoColor
colorsEnabled = false // Reset to test the setting
setupColors()
if colorsEnabled != tt.expectColors {
t.Errorf("setupColors() colorsEnabled = %v, want %v", colorsEnabled, tt.expectColors)
}
})
}
}
func TestLoadSplitFileExpression(t *testing.T) {
// Create a temporary file with expression content
tempFile, err := os.CreateTemp("", "split")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tempFile.Name())
if _, err = tempFile.WriteString(".a.b"); err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
tempFile.Close()
tests := []struct {
name string
splitFileExpFile string
expectError bool
expectContent string
}{
{
name: "load from file",
splitFileExpFile: tempFile.Name(),
expectError: false,
expectContent: ".a.b",
},
{
name: "no file specified",
splitFileExpFile: "",
expectError: false,
expectContent: "",
},
{
name: "non-existent file",
splitFileExpFile: "/path/that/does/not/exist",
expectError: true,
expectContent: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Save original value
originalSplitFileExpFile := splitFileExpFile
originalSplitFileExp := splitFileExp
defer func() {
splitFileExpFile = originalSplitFileExpFile
splitFileExp = originalSplitFileExp
}()
splitFileExpFile = tt.splitFileExpFile
splitFileExp = ""
err := loadSplitFileExpression()
if tt.expectError {
if err == nil {
t.Errorf("loadSplitFileExpression() expected error but got none")
}
return
}
if err != nil {
t.Errorf("loadSplitFileExpression() unexpected error: %v", err)
return
}
if splitFileExp != tt.expectContent {
t.Errorf("loadSplitFileExpression() splitFileExp = %v, want %v", splitFileExp, tt.expectContent)
}
})
}
}
func TestHandleBackwardsCompatibility(t *testing.T) {
tests := []struct {
name string
outputToJSON bool
initialFormat string
expectFormat string
}{
{
name: "outputToJSON true",
outputToJSON: true,
initialFormat: "yaml",
expectFormat: "json",
},
{
name: "outputToJSON false",
outputToJSON: false,
initialFormat: "yaml",
expectFormat: "yaml",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Save original value
originalOutputToJSON := outputToJSON
originalOutputFormat := outputFormat
defer func() {
outputToJSON = originalOutputToJSON
outputFormat = originalOutputFormat
}()
outputToJSON = tt.outputToJSON
outputFormat = tt.initialFormat
handleBackwardsCompatibility()
if outputFormat != tt.expectFormat {
t.Errorf("handleBackwardsCompatibility() outputFormat = %v, want %v", outputFormat, tt.expectFormat)
}
})
}
}
func TestValidateCommandFlags(t *testing.T) {
tests := []struct {
name string
args []string
writeInplace bool
frontMatter string
splitFileExp string
nullInput bool
expectError bool
errorContains string
}{
{
name: "valid flags",
args: []string{"file.yaml"},
writeInplace: false,
frontMatter: "",
splitFileExp: "",
nullInput: false,
expectError: false,
},
{
name: "write inplace with no args",
args: []string{},
writeInplace: true,
frontMatter: "",
splitFileExp: "",
nullInput: false,
expectError: true,
errorContains: "write in place flag only applicable when giving an expression and at least one file",
},
{
name: "write inplace with dash",
args: []string{"-"},
writeInplace: true,
frontMatter: "",
splitFileExp: "",
nullInput: false,
expectError: true,
errorContains: "write in place flag only applicable when giving an expression and at least one file",
},
{
name: "front matter with no args",
args: []string{},
writeInplace: false,
frontMatter: "extract",
splitFileExp: "",
nullInput: false,
expectError: true,
errorContains: "front matter flag only applicable when giving an expression and at least one file",
},
{
name: "write inplace with split file",
args: []string{"file.yaml"},
writeInplace: true,
frontMatter: "",
splitFileExp: ".a.b",
nullInput: false,
expectError: true,
errorContains: "write in place cannot be used with split file",
},
{
name: "null input with args",
args: []string{"file.yaml"},
writeInplace: false,
frontMatter: "",
splitFileExp: "",
nullInput: true,
expectError: true,
errorContains: "cannot pass files in when using null-input flag",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Save original values
originalWriteInplace := writeInplace
originalFrontMatter := frontMatter
originalSplitFileExp := splitFileExp
originalNullInput := nullInput
defer func() {
writeInplace = originalWriteInplace
frontMatter = originalFrontMatter
splitFileExp = originalSplitFileExp
nullInput = originalNullInput
}()
writeInplace = tt.writeInplace
frontMatter = tt.frontMatter
splitFileExp = tt.splitFileExp
nullInput = tt.nullInput
err := validateCommandFlags(tt.args)
if tt.expectError {
if err == nil {
t.Errorf("validateCommandFlags() expected error but got none")
return
}
if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) {
t.Errorf("validateCommandFlags() error '%v' does not contain '%v'", err.Error(), tt.errorContains)
}
return
}
if err != nil {
t.Errorf("validateCommandFlags() unexpected error: %v", err)
}
})
}
}
func TestConfigureFormats(t *testing.T) {
tests := []struct {
name string
args []string
inputFormat string
outputFormat string
expectError bool
}{
{
name: "valid formats",
args: []string{"file.yaml"},
inputFormat: "auto",
outputFormat: "auto",
expectError: false,
},
{
name: "invalid output format",
args: []string{"file.yaml"},
inputFormat: "auto",
outputFormat: "invalid",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Save original values
originalInputFormat := inputFormat
originalOutputFormat := outputFormat
defer func() {
inputFormat = originalInputFormat
outputFormat = originalOutputFormat
}()
inputFormat = tt.inputFormat
outputFormat = tt.outputFormat
err := configureFormats(tt.args)
if tt.expectError {
if err == nil {
t.Errorf("configureFormats() expected error but got none")
}
return
}
if err != nil {
t.Errorf("configureFormats() unexpected error: %v", err)
}
})
}
}
func TestConfigureInputFormat(t *testing.T) {
tests := []struct {
name string
inputFilename string
inputFormat string
outputFormat string
expectInput string
expectOutput string
}{
{
name: "auto format with yaml file",
inputFilename: "file.yaml",
inputFormat: "auto",
outputFormat: "auto",
expectInput: "yaml",
expectOutput: "yaml",
},
{
name: "auto format with json file",
inputFilename: "file.json",
inputFormat: "auto",
outputFormat: "auto",
expectInput: "json",
expectOutput: "json",
},
{
name: "auto format with unknown file",
inputFilename: "file.unknown",
inputFormat: "auto",
outputFormat: "auto",
expectInput: "yaml",
expectOutput: "yaml",
},
{
name: "explicit format",
inputFilename: "file.yaml",
inputFormat: "json",
outputFormat: "auto",
expectInput: "json",
expectOutput: "yaml", // backwards compatibility
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Save original values
originalInputFormat := inputFormat
originalOutputFormat := outputFormat
defer func() {
inputFormat = originalInputFormat
outputFormat = originalOutputFormat
}()
inputFormat = tt.inputFormat
outputFormat = tt.outputFormat
err := configureInputFormat(tt.inputFilename)
if err != nil {
t.Errorf("configureInputFormat() unexpected error: %v", err)
return
}
if inputFormat != tt.expectInput {
t.Errorf("configureInputFormat() inputFormat = %v, want %v", inputFormat, tt.expectInput)
}
if outputFormat != tt.expectOutput {
t.Errorf("configureInputFormat() outputFormat = %v, want %v", outputFormat, tt.expectOutput)
}
})
}
}
func TestConfigureOutputFormat(t *testing.T) {
tests := []struct {
name string
outputFormat string
expectError bool
expectUnwrap bool
}{
{
name: "yaml format",
outputFormat: "yaml",
expectError: false,
expectUnwrap: true,
},
{
name: "properties format",
outputFormat: "properties",
expectError: false,
expectUnwrap: true,
},
{
name: "json format",
outputFormat: "json",
expectError: false,
expectUnwrap: false,
},
{
name: "invalid format",
outputFormat: "invalid",
expectError: true,
expectUnwrap: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Save original values
originalOutputFormat := outputFormat
originalUnwrapScalar := unwrapScalar
defer func() {
outputFormat = originalOutputFormat
unwrapScalar = originalUnwrapScalar
}()
outputFormat = tt.outputFormat
unwrapScalar = false // Reset to test the setting
err := configureOutputFormat()
if tt.expectError {
if err == nil {
t.Errorf("configureOutputFormat() expected error but got none")
}
return
}
if err != nil {
t.Errorf("configureOutputFormat() unexpected error: %v", err)
return
}
if unwrapScalar != tt.expectUnwrap {
t.Errorf("configureOutputFormat() unwrapScalar = %v, want %v", unwrapScalar, tt.expectUnwrap)
}
})
}
}
func TestConfigureUnwrapScalar(t *testing.T) {
tests := []struct {
name string
explicitlySet bool
flagValue bool
initialUnwrap bool
expectUnwrap bool
}{
{
name: "flag not explicitly set",
explicitlySet: false,
flagValue: true,
initialUnwrap: true,
expectUnwrap: true, // Should remain unchanged
},
{
name: "flag explicitly set to true",
explicitlySet: true,
flagValue: true,
initialUnwrap: false,
expectUnwrap: true,
},
{
name: "flag explicitly set to false",
explicitlySet: true,
flagValue: false,
initialUnwrap: true,
expectUnwrap: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Save original value
originalUnwrapScalar := unwrapScalar
originalUnwrapScalarFlag := unwrapScalarFlag
defer func() {
unwrapScalar = originalUnwrapScalar
unwrapScalarFlag = originalUnwrapScalarFlag
}()
unwrapScalar = tt.initialUnwrap
unwrapScalarFlag = &mockBoolFlag{
explicitlySet: tt.explicitlySet,
value: tt.flagValue,
}
configureUnwrapScalar()
if unwrapScalar != tt.expectUnwrap {
t.Errorf("configureUnwrapScalar() unwrapScalar = %v, want %v", unwrapScalar, tt.expectUnwrap)
}
})
}
}