diff --git a/yq.go b/yq.go index 88cb4917..c573b2ee 100644 --- a/yq.go +++ b/yq.go @@ -221,9 +221,9 @@ func read(args []string) (interface{}, error) { path = args[1] } - if err := readData(args[0], &parsedData); err != nil { + if err := readData(args[0], docIndex, &parsedData); err != nil { var generalData interface{} - if err = readData(args[0], &generalData); err != nil { + if err = readData(args[0], docIndex, &generalData); err != nil { return nil, err } item := yaml.MapItem{Key: "thing", Value: generalData} @@ -233,7 +233,7 @@ func read(args []string) (interface{}, error) { if parsedData != nil && parsedData[0].Key == nil { var interfaceData []map[interface{}]interface{} - if err := readData(args[0], &interfaceData); err == nil { + if err := readData(args[0], docIndex, &interfaceData); err == nil { var listMap []yaml.MapSlice for _, item := range interfaceData { listMap = append(listMap, mapToMapSlice(item)) @@ -299,7 +299,7 @@ func newProperty(cmd *cobra.Command, args []string) error { func newYaml(args []string) (interface{}, error) { var writeCommands yaml.MapSlice if writeScript != "" { - if err := readData(writeScript, &writeCommands); err != nil { + if err := readData(writeScript, 0, &writeCommands); err != nil { return nil, err } } else if len(args) < 2 { @@ -365,9 +365,9 @@ func deleteYaml(args []string) (interface{}, error) { deletePath = args[1] - if err := readData(args[0], &parsedData); err != nil { + if err := readData(args[0], 0, &parsedData); err != nil { var generalData interface{} - if err = readData(args[0], &generalData); err != nil { + if err = readData(args[0], 0, &generalData); err != nil { return nil, err } item := yaml.MapItem{Key: "thing", Value: generalData} @@ -396,7 +396,7 @@ func mergeYaml(args []string) (interface{}, error) { for _, f := range args { var parsedData map[interface{}]interface{} - if err := readData(f, &parsedData); err != nil { + if err := readData(f, 0, &parsedData); err != nil { return nil, err } if err := merge(&updatedData, parsedData, overwriteFlag); err != nil { @@ -428,7 +428,7 @@ func updateYaml(args []string) (interface{}, error) { var writeCommands yaml.MapSlice var prependCommand = "" if writeScript != "" { - if err := readData(writeScript, &writeCommands); err != nil { + if err := readData(writeScript, 0, &writeCommands); err != nil { return nil, err } } else if len(args) < 3 { @@ -439,9 +439,9 @@ func updateYaml(args []string) (interface{}, error) { } var parsedData yaml.MapSlice - if err := readData(args[0], &parsedData); err != nil { + if err := readData(args[0], 0, &parsedData); err != nil { var generalData interface{} - if err = readData(args[0], &generalData); err != nil { + if err = readData(args[0], 0, &generalData); err != nil { return nil, err } item := yaml.MapItem{Key: "thing", Value: generalData} @@ -512,7 +512,9 @@ func safelyCloseFile(file *os.File) { } } -func readData(filename string, parsedData interface{}) error { +type yamlDecoderFn func(*yaml.Decoder) error + +func readStream(filename string, yamlDecoder yamlDecoderFn) error { if filename == "" { return errors.New("Must provide filename") } @@ -528,15 +530,19 @@ func readData(filename string, parsedData interface{}) error { defer safelyCloseFile(file) stream = file } - - var decoder = yaml.NewDecoder(stream) - // naive implementation of document indexing, decodes all the yaml documents - // before the docIndex and throws them away. - for currentIndex := 0; currentIndex < docIndex; currentIndex++ { - errorSkipping := decoder.Decode(parsedData) - if errorSkipping != nil { - return fmt.Errorf("Error processing document at index %v, %v", currentIndex, errorSkipping) - } - } - return decoder.Decode(parsedData) + return yamlDecoder(yaml.NewDecoder(stream)) +} + +func readData(filename string, indexToRead int, parsedData interface{}) error { + return readStream(filename, func(decoder *yaml.Decoder) error { + // naive implementation of document indexing, decodes all the yaml documents + // before the docIndex and throws them away. + for currentIndex := 0; currentIndex < indexToRead; currentIndex++ { + errorSkipping := decoder.Decode(parsedData) + if errorSkipping != nil { + return fmt.Errorf("Error processing document at index %v, %v", currentIndex, errorSkipping) + } + } + return decoder.Decode(parsedData) + }) }