diff --git a/README.md b/README.md index ff27ce9..bbc730a 100644 --- a/README.md +++ b/README.md @@ -135,42 +135,35 @@ In the future, additional flags may be added to allow you to include tables, add ## Plugins -If you would like to implement your own `valueFuncs`, you can do so by writing a ripoff plugin. - -Plugins are local unauthenticated TCP servers that consume and emit newline-separated JSON messages from ripoff. +If you would like to implement your own `valueFuncs`, you can do so by writing a ripoff plugin, which is a local TCP server that sends/recieves JSON. ### Writing a plugin -Plugins must listen to a local TCP port and provide a TCP stream (loop of receiving and sending messages) to clients. +Plugins must meet the following requirements: -On startup, plugins must output the string `READY` in its first line of output to indicate to ripoff that it is ready to receive TCP messges. +- Listen to a local TCP port +- Consume newline-separated JSON messages, which come in as a stream +- Output newline-separated JSON responses +- Ouput `READY` in the first line of standard output when the plugin is ready for TCP connections -Each incoming message will be a single line of JSON in the following types: +Each incoming message will be a single line of JSON of the following shapes: -#### Return a value +#### valueFunc Your plugin must process an arbitrary `valueFunc` and return a string value. You can decide how to handle functions you do not expect/provide, by either returning an empty value or disconnecting the client. +The `id` field is used to support unordered stream messages, so you can return responses at any time and in any order as long as they have the same `id` as the relevant request. + Message from ripoff: ```json -{"type": "valueFunc", "valueFunc": "someFuncName", "args": ["some", "argument", "list"]} +{"id": "some-id", "type": "valueFunc", "valueFunc": "someFuncName", "args": ["some", "argument", "list"]} ``` Response from your TCP server: ```json -{"value": "someString"} -``` - -#### Exit your process - -Ripoff will send a kill signal to your process, but if you'd like to clean up before that an exit message will be sent beforehand. - -Request message: - -```json -{"type": "exit"} +{"id": "the-same-id-from-the-request", "value": "someString"} ``` #### Example @@ -179,7 +172,9 @@ An example plugin can be found at `cmd/helloplugin/helloplugin.go`. although TCP ### Using a plugin -Plugins are defined in your ripoff files, which instruct ripoff to spawn a process to start your TCP server, then later connect to it with a single TCP stream. Here's an example from ripoff's tests: +Plugins are defined in your ripoff files, which instruct ripoff to spawn a process to start your TCP server. + +Here's an example from ripoff's tests: ```yml # A list of plugins to register with ripoff. diff --git a/cmd/helloplugin/helloplugin.go b/cmd/helloplugin/helloplugin.go index fec5cad..2135bb9 100644 --- a/cmd/helloplugin/helloplugin.go +++ b/cmd/helloplugin/helloplugin.go @@ -6,7 +6,6 @@ import ( "fmt" "log" "net" - "os" ) func main() { @@ -33,12 +32,14 @@ func main() { } type Request struct { + Id string `json:"id"` Type string `json:"type"` ValueFunc string `json:"valueFunc"` Args []string `json:"args"` } type Response struct { + Id string `json:"id"` Value string `json:"value"` } @@ -59,10 +60,6 @@ func handleConnection(conn net.Conn) { log.Println("Error parsing body:", err) return } - if r.Type == "exit" { - os.Exit(0) - return - } if len(r.Args) == 0 { log.Println("No args provided") return @@ -78,6 +75,7 @@ func handleConnection(conn net.Conn) { return } resp, err := json.Marshal(Response{ + Id: r.Id, Value: value, }) if err != nil { diff --git a/cmd/ripoff/ripoff.go b/cmd/ripoff/ripoff.go index aedbc47..08d38ae 100644 --- a/cmd/ripoff/ripoff.go +++ b/cmd/ripoff/ripoff.go @@ -1,12 +1,15 @@ package main import ( + "bufio" "context" "flag" "fmt" "log/slog" "os" "path" + "slices" + "strings" "github.com/jackc/pgx/v5" @@ -17,9 +20,54 @@ func errAttr(err error) slog.Attr { return slog.Any("error", err) } +func confirmPluginsSafe(plugins map[string]ripoff.RipoffPlugin) { + baseDir, err := os.UserHomeDir() + if err != nil { + baseDir = os.TempDir() + } + consentFilePath := path.Join(baseDir, ".ripoff-consent") + consentFile, err := os.ReadFile(consentFilePath) + if err != nil && !os.IsNotExist(err) { + slog.Error("Could not read from consent file", errAttr(err), slog.String("filepath", consentFilePath)) + } + consentFileLines := strings.Split(string(consentFile), "\n") + scanner := bufio.NewScanner(os.Stdin) + newConsentLines := []string{} + for _, plugin := range plugins { + cmdJoined := strings.Join(append([]string{plugin.Address, " -> "}, plugin.Command...), " ") + if !slices.Contains(consentFileLines, cmdJoined) { + newConsentLines = append(newConsentLines, cmdJoined) + } + } + if len(newConsentLines) > 0 { + fmt.Printf("You have not run these ripoff plugins before, please confirm that the following commands are safe to run on your machine: \n") + fmt.Println() + for _, consentLine := range newConsentLines { + fmt.Printf(" %s\n", consentLine) + } + fmt.Println() + fmt.Println("Run the above? (Y/N)") + scanner.Scan() + input := scanner.Text() + if input == "y" || input == "Y" { + consentFileLines = append(consentFileLines, newConsentLines...) + err = os.WriteFile(consentFilePath, []byte(strings.Join(consentFileLines, "\n")), 0644) + if err != nil { + slog.Error("Could not append to the consent file", errAttr(err), slog.String("filepath", consentFilePath)) + } + fmt.Println("Proceeding...") + } else { + fmt.Println("ABORT") + os.Exit(1) + } + } +} + func main() { verbosePtr := flag.Bool("v", false, "enable verbose output") softPtr := flag.Bool("s", false, "do not commit generated queries") + maxConcurrencyPtr := flag.Int("c", ripoff.DEFAULT_MAX_CONCURRENCY, "maximum number of rows to generate queries for at one time. defaults at 1000") + unsafePluginPtr := flag.Bool("u", false, "execute new plugin commands without prompting. only for use in CI or trusted environments") flag.Parse() if *verbosePtr { @@ -77,7 +125,11 @@ func main() { os.Exit(1) } - err = ripoff.RunRipoff(ctx, tx, totalRipoff) + if !*unsafePluginPtr && len(totalRipoff.Plugins) > 0 { + confirmPluginsSafe(totalRipoff.Plugins) + } + + err = ripoff.RunRipoff(ctx, tx, totalRipoff, *maxConcurrencyPtr) if err != nil { slog.Error("Could not run ripoff", errAttr(err)) os.Exit(1) diff --git a/db.go b/db.go index 24257e5..77f3e51 100644 --- a/db.go +++ b/db.go @@ -10,6 +10,7 @@ import ( "regexp" "slices" "strings" + "sync" "time" "github.com/brianvoe/gofakeit/v7" @@ -19,9 +20,11 @@ import ( "github.com/tj/go-naturaldate" ) +const DEFAULT_MAX_CONCURRENCY = 1000 + // Runs ripoff from start to finish, without committing the transaction. -func RunRipoff(ctx context.Context, tx pgx.Tx, totalRipoff RipoffFile) error { - manager, err := NewPluginManager(totalRipoff.Plugins) +func RunRipoff(ctx context.Context, tx pgx.Tx, totalRipoff RipoffFile, maxConcurrency int) error { + manager, err := NewPluginManager(ctx, totalRipoff.Plugins) if err != nil { return err } @@ -32,7 +35,7 @@ func RunRipoff(ctx context.Context, tx pgx.Tx, totalRipoff RipoffFile) error { return err } - queries, err := buildQueriesForRipoff(manager, primaryKeys, totalRipoff) + queries, err := buildQueriesForRipoff(maxConcurrency, manager, primaryKeys, totalRipoff) if err != nil { return err } @@ -163,10 +166,11 @@ func prepareValue(manager *PluginManager, rawValue string) (string, error) { return fakerResult, nil } -func buildQueryForRow(manager *PluginManager, primaryKeys PrimaryKeysResult, rowId string, row Row, dependencyGraph map[string][]string) (string, error) { +func buildQueryForRow(manager *PluginManager, primaryKeys PrimaryKeysResult, rowId string, row Row) (string, []string, error) { + dependencyResult := []string{} parts := strings.Split(rowId, ":") if len(parts) < 2 { - return "", fmt.Errorf("invalid id: %s", rowId) + return "", dependencyResult, fmt.Errorf("invalid id: %s", rowId) } table := parts[0] primaryKeysForTable, hasPrimaryKeysForTable := primaryKeys[table] @@ -210,10 +214,10 @@ func buildQueryForRow(manager *PluginManager, primaryKeys PrimaryKeysResult, row case []string: dependencies = v default: - return "", fmt.Errorf("cannot parse ~dependencies value in row %s", rowId) + return "", dependencyResult, fmt.Errorf("cannot parse ~dependencies value in row %s", rowId) } - dependencyGraph[rowId] = append(dependencyGraph[rowId], dependencies...) - dependencyGraph[rowId] = slices.Compact(dependencyGraph[rowId]) + dependencyResult = append(dependencyResult, dependencies...) + dependencyResult = slices.Compact(dependencyResult) continue } @@ -230,14 +234,14 @@ func buildQueryForRow(manager *PluginManager, primaryKeys PrimaryKeysResult, row addEdge := referenceRegex.MatchString(value) // Don't add edges to and from the same row. if addEdge && rowId != value { - dependencyGraph[rowId] = append(dependencyGraph[rowId], value) - dependencyGraph[rowId] = slices.Compact(dependencyGraph[rowId]) + dependencyResult = append(dependencyResult, value) + dependencyResult = slices.Compact(dependencyResult) } columns = append(columns, pq.QuoteIdentifier(column)) valuePrepared, err := prepareValue(manager, value) if err != nil { - return "", err + return "", dependencyResult, err } // Assume this column is the primary key. if rowId == value && onConflictColumn == "" { @@ -249,7 +253,7 @@ func buildQueryForRow(manager *PluginManager, primaryKeys PrimaryKeysResult, row } if onConflictColumn == "" { - return "", fmt.Errorf("cannot determine column to conflict with for: %s, saw %s", rowId, row) + return "", dependencyResult, fmt.Errorf("cannot determine column to conflict with for: %s, saw %s", rowId, row) } // Extremely smart query builder. @@ -263,11 +267,11 @@ func buildQueryForRow(manager *PluginManager, primaryKeys PrimaryKeysResult, row strings.Join(values, ","), onConflictColumn, strings.Join(setStatements, ","), - ), nil + ), dependencyResult, nil } // Returns a sorted array of queries to run based on a given ripoff file. -func buildQueriesForRipoff(manager *PluginManager, primaryKeys PrimaryKeysResult, totalRipoff RipoffFile) ([]string, error) { +func buildQueriesForRipoff(maxConcurrency int, manager *PluginManager, primaryKeys PrimaryKeysResult, totalRipoff RipoffFile) ([]string, error) { dependencyGraph := map[string][]string{} queries := map[string]string{} @@ -277,12 +281,37 @@ func buildQueriesForRipoff(manager *PluginManager, primaryKeys PrimaryKeysResult } // Build queries. + var wg sync.WaitGroup + semaphore := make(chan struct{}, maxConcurrency) + type rowChanItem struct { + rowId string + query string + dependencies []string + err error + } + rowChan := make(chan rowChanItem, len(totalRipoff.Rows)) for rowId, row := range totalRipoff.Rows { - query, err := buildQueryForRow(manager, primaryKeys, rowId, row, dependencyGraph) - if err != nil { - return []string{}, err + semaphore <- struct{}{} + wg.Add(1) + go func(rowId string, row Row) { + defer wg.Done() + defer func() { <-semaphore }() + query, dependencies, err := buildQueryForRow(manager, primaryKeys, rowId, row) + rowChan <- rowChanItem{rowId, query, dependencies, err} + }(rowId, row) + } + + go func() { + wg.Wait() + close(rowChan) + }() + + for rowItem := range rowChan { + if rowItem.err != nil { + return []string{}, rowItem.err } - queries[rowId] = query + dependencyGraph[rowItem.rowId] = rowItem.dependencies + queries[rowItem.rowId] = rowItem.query } // Sort and reverse the graph, so queries are in order of least (hopefully none) to most dependencies. diff --git a/db_test.go b/db_test.go index 07010f5..aa41cf2 100644 --- a/db_test.go +++ b/db_test.go @@ -23,10 +23,10 @@ func runTestData(t *testing.T, ctx context.Context, tx pgx.Tx, testDir string) { require.NoError(t, err) totalRipoff, err := RipoffFromDirectory(testDir, enums) require.NoError(t, err) - err = RunRipoff(ctx, tx, totalRipoff) + err = RunRipoff(ctx, tx, totalRipoff, DEFAULT_MAX_CONCURRENCY) require.NoError(t, err) // Run again to implicitly test upsert behavior. - err = RunRipoff(ctx, tx, totalRipoff) + err = RunRipoff(ctx, tx, totalRipoff, DEFAULT_MAX_CONCURRENCY) require.NoError(t, err) // Try to verify that the number of generated rows matches the ripoff. tableCount := map[string]int{} diff --git a/export_test.go b/export_test.go index 8955071..70c6b94 100644 --- a/export_test.go +++ b/export_test.go @@ -48,7 +48,7 @@ func runExportTestData(t *testing.T, ctx context.Context, tx pgx.Tx, testDir str _, err = tx.Exec(ctx, string(truncateFile)) require.NoError(t, err) // Run generated ripoff. - err = RunRipoff(ctx, tx, ripoffFile) + err = RunRipoff(ctx, tx, ripoffFile, DEFAULT_MAX_CONCURRENCY) require.NoError(t, err) // Try to verify that the number of generated rows matches the ripoff. tableCount := map[string]int{} diff --git a/plugins.go b/plugins.go index b218fa5..c93f0bb 100644 --- a/plugins.go +++ b/plugins.go @@ -2,32 +2,64 @@ package ripoff import ( "bufio" + "context" "encoding/json" "fmt" "log/slog" "net" "os/exec" "strings" + "sync" "syscall" "time" + + "github.com/google/uuid" ) +const PLUGIN_STARTUP_DEADLINE = 5 * time.Second +const PLUGIN_TCP_CONNECTION_DEADLINE = time.Second + +// The shape that plugins expect for requests +type Request struct { + Id string `json:"id"` + Type string `json:"type"` + ValueFunc string `json:"valueFunc"` + Args []string `json:"args"` +} + +// The shape that plugins expect for responses +type Response struct { + Id string `json:"id"` + Value string `json:"value"` +} + +// Used to communicate async responses to a goroutine that sends them syncronously to plugins +type ResponseChanMessage struct { + response Response + err error +} + +// Used by a goroutine that sends messages over a response channel +type CallChanMessage struct { + plugin RipoffPlugin + valueFunc string + args []string + responseChan chan ResponseChanMessage +} + +// Manages plugin commands and TCP connections - intended to be used as a singleton for the entire ripoff process. type PluginManager struct { valueFuncMap map[string]RipoffPlugin spawnedCommands []*exec.Cmd addressToConn map[string]net.Conn + callChan chan CallChanMessage } +// Closes all open connections and kills process group for each plugin command and its children. func (m *PluginManager) Close() { + close(m.callChan) for _, conn := range m.addressToConn { - message, _ := json.Marshal(Request{ - Type: "exit", - }) - _, err := conn.Write(append(message, '\n')) - if err != nil { - slog.Error("Could not write exit message to plugn connection", slog.Any("error", err)) - } - err = conn.Close() + err := conn.Close() if err != nil { slog.Error("Could not close plugin connection", slog.Any("error", err)) } @@ -40,21 +72,13 @@ func (m *PluginManager) Close() { } } +// Determines if a plugin provides the given valueFunc func (m *PluginManager) Supports(valueFunc string) bool { _, ok := m.valueFuncMap[valueFunc] return ok } -type Request struct { - Type string `json:"type"` - ValueFunc string `json:"valueFunc"` - Args []string `json:"args"` -} - -type Response struct { - Value string `json:"value"` -} - +// Spawns a new plugin and waits for it to be ready func spawn(command []string) (*exec.Cmd, error) { commandArgs := []string{} if len(command) > 1 { @@ -74,7 +98,7 @@ func spawn(command []string) (*exec.Cmd, error) { scanner.Scan() line := scanner.Text() // Set deadline for outputting READY message - timer := time.AfterFunc(5*time.Second, func() { + timer := time.AfterFunc(PLUGIN_STARTUP_DEADLINE, func() { err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) if err != nil { slog.Error("Could not kill plugin after READY timeout", slog.Any("error", err)) @@ -92,69 +116,102 @@ func spawn(command []string) (*exec.Cmd, error) { return cmd, nil } +// Initializes a connection to the given TCP address. func connect(address string) (net.Conn, error) { - conn, err := net.Dial("tcp", address) + conn, err := net.DialTimeout("tcp", address, PLUGIN_TCP_CONNECTION_DEADLINE) if err != nil { return nil, err } return conn, nil } +// Starts goroutines for the plugin manager, which mostly handle TCP requests and responses. +func (m *PluginManager) run(ctx context.Context) { + // A sync map used to associate arbitrary responses with stalled goroutines, based on a random ID from a request message. + var idMap sync.Map + for _, conn := range m.addressToConn { + // Watch for new responses from this plugin. + // Should hopefully be halted when the connection is closed. + go func() { + scanner := bufio.NewScanner(conn) + for scanner.Scan() { + line := scanner.Bytes() + response := Response{} + err := json.Unmarshal(line, &response) + if err != nil { + slog.Error("Unable to parse response", slog.Any("error", err)) + continue + } + responseChanMessage, ok := idMap.Load(response.Id) + if !ok { + slog.Error("No plugin channel found in map for response ID", slog.Any("line", line)) + continue + } + // The goroutine that sent the request is waiting for a response + responseChanMessage.(chan ResponseChanMessage) <- ResponseChanMessage{response: response, err: nil} + } + }() + } + for { + select { + case <-ctx.Done(): + return + // New request to send to a plugin. + case call := <-m.callChan: + conn, hasCon := m.addressToConn[call.plugin.Address] + if !hasCon { + call.responseChan <- ResponseChanMessage{err: fmt.Errorf("connection for plugin %s does not exist", strings.Join(call.plugin.Command, " "))} + return + } + // Generate a random ID to associate responses with this request. + id := uuid.New().String() + idMap.Store(id, call.responseChan) + message, err := json.Marshal(Request{ + Id: id, + Type: "valueFunc", + ValueFunc: call.valueFunc, + Args: call.args, + }) + if err != nil { + call.responseChan <- ResponseChanMessage{err: err} + return + } + _, err = conn.Write(append(message, '\n')) + if err != nil { + call.responseChan <- ResponseChanMessage{err: err} + return + } + } + } +} + +// Calls an arbitrary plugin associated with this valueFunc over TCP. func (m *PluginManager) Call(valueFunc string, args ...string) (string, error) { plugin, hasPlugin := m.valueFuncMap[valueFunc] if !hasPlugin { return "", fmt.Errorf("plugin for valueFunc %s is not defined", valueFunc) } - conn, ok := m.addressToConn[plugin.Address] - // Attempt to start process and wait for port to open - if !ok { - cmd, err := spawn(plugin.Command) - if err != nil { - return "", err - } - m.spawnedCommands = append(m.spawnedCommands, cmd) - conn, err = connect(plugin.Address) - if err != nil { - return "", err - } - m.addressToConn[plugin.Address] = conn - } - // Send message to open TCP socket - err := conn.SetReadDeadline(time.Now().Add(time.Second)) - if err != nil { - slog.Error("Could not set read deadline for plugin connection", slog.Any("error", err)) - } - scanner := bufio.NewScanner(conn) - message, err := json.Marshal(Request{ - Type: "valueFunc", - ValueFunc: valueFunc, - Args: args, - }) - if err != nil { - return "", err - } - _, err = conn.Write(append(message, '\n')) - if err != nil { - return "", err - } - if !scanner.Scan() { - return "", fmt.Errorf("plugin command '%s' failed to response to TCP message: %v", strings.Join(plugin.Command, " "), scanner.Err()) - } - line := scanner.Bytes() - response := Response{} - err = json.Unmarshal(line, &response) - if err != nil { - return "", err - } - return response.Value, nil + // Create a channel that can be used to resume this function + responseChan := make(chan ResponseChanMessage, 1) + m.callChan <- CallChanMessage{ + plugin: plugin, + valueFunc: valueFunc, + args: args, + responseChan: responseChan, + } + // Block as we wait for a response + response := <-responseChan + return response.response.Value, response.err } -func NewPluginManager(plugins map[string]RipoffPlugin) (*PluginManager, error) { +func NewPluginManager(ctx context.Context, plugins map[string]RipoffPlugin) (*PluginManager, error) { m := &PluginManager{ valueFuncMap: map[string]RipoffPlugin{}, - addressToConn: map[string]net.Conn{}, spawnedCommands: []*exec.Cmd{}, + callChan: make(chan CallChanMessage), + addressToConn: map[string]net.Conn{}, } + // Store a map of valueFuncs to plugins and also validate that there is no overlap. for pluginName, plugin := range plugins { if len(plugin.Command) == 0 { return nil, fmt.Errorf("cannot create new PluginManager - the plugin %s does not define a command", pluginName) @@ -167,5 +224,20 @@ func NewPluginManager(plugins map[string]RipoffPlugin) (*PluginManager, error) { m.valueFuncMap[valueFunc] = plugin } } + for _, plugin := range plugins { + // Startup the plugin + cmd, err := spawn(plugin.Command) + if err != nil { + return nil, err + } + m.spawnedCommands = append(m.spawnedCommands, cmd) + // Connect to the plugin's address over TCP + conn, err := connect(plugin.Address) + if err != nil { + return nil, err + } + m.addressToConn[plugin.Address] = conn + } + go m.run(ctx) return m, nil } diff --git a/testdata/import/plugins_scale/plugins_scale.yml b/testdata/import/plugins_scale/plugins_scale.yml new file mode 100644 index 0000000..11e0cfa --- /dev/null +++ b/testdata/import/plugins_scale/plugins_scale.yml @@ -0,0 +1,9 @@ +plugins: + helloplugin: + command: [go, run, cmd/helloplugin/helloplugin.go] + address: localhost:6767 + valueFuncs: [sayHello, sayGoodbye] +rows: + unused: + template: template_user.yml + numUsers: 1000 diff --git a/testdata/import/plugins_scale/schema.sql b/testdata/import/plugins_scale/schema.sql new file mode 100644 index 0000000..50f3760 --- /dev/null +++ b/testdata/import/plugins_scale/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE users ( + id UUID NOT NULL PRIMARY KEY, + email TEXT NOT NULL, + name TEXT NOT NULL +); diff --git a/testdata/import/plugins_scale/template_user.yml b/testdata/import/plugins_scale/template_user.yml new file mode 100644 index 0000000..813c7f9 --- /dev/null +++ b/testdata/import/plugins_scale/template_user.yml @@ -0,0 +1,6 @@ +rows: + {{ range $i, $v := (intSlice .numUsers) }} + users:uuid({{ $i }}): + email: email({{ $i }}) + name: sayHello({{ $i }}) + {{ end }}