commit 8dcba8adb27c494e68d4c74ba22ccd342cfdf66d Author: scheibling Date: Fri Apr 10 22:02:47 2026 +0200 Updated diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d8f11f4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2d6597c --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Jordan Coupal + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..20561e6 --- /dev/null +++ b/README.md @@ -0,0 +1,109 @@ +

+ + +# go-pgdump - Go PostgreSQL Dump + +[![License](https://img.shields.io/badge/license-MIT-green)](./LICENSE) +[![GitHub issues](https://img.shields.io/github/issues-raw/JCoupalK/go-pgdump)](https://github.com/JCoupalK/go-pgdump/issues) +[![GitHub go.mod Go version (branch & subdirectory of monorepo)](https://img.shields.io/github/go-mod/go-version/JCoupalK/go-pgdump/main)](./go.mod) + +Create PostgreSQL or CSV dumps in Go without the pg_dump CLI as a dependancy. + +Inspired by [go-mysqldump](https://github.com/jamf/go-mysqldump) which does that but for MySQL/MariaDB. + +Doesn't feature all of pg_dump features just yet so it is still a work in progress. + +## Simple example for a CLI tool using the library + +```go +package main + +import ( + "flag" + "fmt" + "log" + "path/filepath" + "strings" + "time" + + "github.com/JCoupalK/go-pgdump" +) + +var ( + username = flag.String("u", "", "username for PostgreSQL") + password = flag.String("p", "", "password for PostgreSQL") + hostname = flag.String("h", "", "hostname for PostgreSQL") + db = flag.String("d", "", "database name for PostgreSQL") + port = flag.Int("P", 5432, "port number for PostgreSQL") + dumpCSV = flag.Bool("csv", false, "dump to CSV") + csvTables = flag.String("tables", "", "comma-separated list of table names to dump to CSV") + outputDir = flag.String("o", "", "path to output directory") + suffix = flag.String("sx", "", "suffix of table names for dump") + prefix = flag.String("px", "", "prefix of table names for dump") + schema = flag.String("s", "", "schema filter for dump") +) + +func BackupPostgreSQL(username, password, hostname, dbname, outputDir string, port int) { + // PostgreSQL connection string + psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + hostname, port, username, password, dbname) + + // Create a new dumper instance with connection string and number of threads + dumper := pgdump.NewDumper(psqlInfo, 50) + + // Check if CSV dump is requested + if *dumpCSV { + tableList := strings.Split(*csvTables, ",") + csvFiles, err := dumper.DumpToCSV(outputDir, tableList...) + if err != nil { + log.Fatal("Error dumping to CSV:", err) + } + fmt.Println("CSV files successfully saved in:", csvFiles) + } else { + // Regular SQL dump + currentTime := time.Now() + dumpFilename := filepath.Join( + outputDir, + fmt.Sprintf("%s-%s.sql", dbname, currentTime.Format("20060102T150405")), + ) + + if err := dumper.DumpDatabase(dumpFilename, &pgdump.TableOptions{ + TableSuffix: *suffix, + TablePrefix: *prefix, + Schema: *schema, + }); err != nil { + log.Fatal("Error dumping database:", err) + } + + fmt.Println("Dump successfully saved to:", dumpFilename) + } +} + +func main() { + flag.Parse() + BackupPostgreSQL(*username, *password, *hostname, *db, *outputDir, *port) +} +``` + +### Usage for a database dump with default port + +```bash +./go-pgdump-cli -u user -p example -h localhost -d test -o test -sx example -px test -s myschema +``` + +### Usage for a CSV dump with custom port + +```bash +./go-pgdump-cli -u user -p example -h localhost -d test -P 5433 -o test -csv -tables employees,departments +``` + +See more about the CLI tool [here](https://github.com/JCoupalK/go-pgdump-cli). + +## Contributing + +Contributions are welcome. Please fork the repository and submit a pull request with your changes or improvements. + +## License + +This project is licensed under MIT - see the LICENSE file for details. diff --git a/data.go b/data.go new file mode 100644 index 0000000..1b4a6cb --- /dev/null +++ b/data.go @@ -0,0 +1,162 @@ +package pgdump + +import ( + "database/sql" + "fmt" + "strings" +) + +// options for dumping selective tables. +type TableOptions struct { + TableSuffix string + TablePrefix string + Schema string +} + +// returns a slice of table names matching options, if left blank will default to : +// +// -> no prefix or suffix +// -> public schema +func getTables(db *sql.DB, opts *TableOptions) ([]string, error) { + var ( + query string + ) + if opts != nil { + if opts.Schema == "" { + opts.Schema = "public" + } + query = fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%s' AND table_name LIKE '%s'", opts.Schema, (opts.TablePrefix + "%%" + opts.TableSuffix)) + } else { + query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'" + } + + rows, err := db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + var tables []string + for rows.Next() { + var tableName string + if err := rows.Scan(&tableName); err != nil { + return nil, err + } + if opts != nil && opts.Schema != "public" { + tables = append(tables, opts.Schema+"."+tableName) + } else { + tables = append(tables, tableName) + } + } + return tables, nil +} + +// generates the SQL for creating a table, including column definitions. +func getCreateTableStatement(db *sql.DB, tableName string) (string, error) { + query := fmt.Sprintf("SELECT column_name, data_type, character_maximum_length FROM information_schema.columns WHERE table_name = '%s'", tableName) + rows, err := db.Query(query) + if err != nil { + return "", err + } + defer rows.Close() + + var columns []string + for rows.Next() { + var columnName, dataType string + var charMaxLength *int + if err := rows.Scan(&columnName, &dataType, &charMaxLength); err != nil { + return "", err + } + columnDef := fmt.Sprintf("%s %s", columnName, dataType) + if charMaxLength != nil { + columnDef += fmt.Sprintf("(%d)", *charMaxLength) + } + columns = append(columns, columnDef) + } + + return fmt.Sprintf( + "CREATE TABLE %s (\n %s\n);", + escapeReservedName(tableName), + strings.Join(columns, ",\n "), + ), nil +} + +// generates the COPY command to import data for a table. +func getTableDataCopyFormat(db *sql.DB, tableName string) (string, error) { + query := fmt.Sprintf("SELECT * FROM %s", escapeReservedName(tableName)) + rows, err := db.Query(query) + if err != nil { + return "", err + } + defer rows.Close() + + columns, err := rows.Columns() + if err != nil { + return "", err + } + values := make([]sql.RawBytes, len(columns)) + scanArgs := make([]interface{}, len(values)) + for i := range values { + scanArgs[i] = &values[i] + } + + var output strings.Builder + output.WriteString(fmt.Sprintf( + "COPY %s (%s) FROM stdin;\n", + escapeReservedName(tableName), + strings.Join(columns, ", "), + )) + for rows.Next() { + err := rows.Scan(scanArgs...) + if err != nil { + return "", err + } + var valueStrings []string + for _, value := range values { + valueStrings = append(valueStrings, string(value)) + } + output.WriteString(strings.Join(valueStrings, "\t") + "\n") + } + output.WriteString("\\.\n") + + return output.String(), nil +} + +func getTableDataAsCSV(db *sql.DB, tableName string) ([][]string, error) { + query := fmt.Sprintf("SELECT * FROM %s", escapeReservedName(tableName)) + rows, err := db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + columns, err := rows.Columns() + if err != nil { + return nil, err + } + + output := [][]string{columns} + + values := make([]sql.RawBytes, len(columns)) + scanArgs := make([]interface{}, len(values)) + for i := range values { + scanArgs[i] = &values[i] + } + + for rows.Next() { + if err := rows.Scan(scanArgs...); err != nil { + return nil, err + } + var valueStrings []string + for _, value := range values { + if value == nil { + valueStrings = append(valueStrings, "NULL") + } else { + valueStrings = append(valueStrings, string(value)) + } + } + output = append(output, valueStrings) + } + + return output, nil +} diff --git a/dumper.go b/dumper.go new file mode 100644 index 0000000..d89dfd3 --- /dev/null +++ b/dumper.go @@ -0,0 +1,286 @@ +package pgdump + +import ( + "context" + "database/sql" + "encoding/csv" + "fmt" + "io" + "os" + "path" + "slices" + "strings" + "sync" + "time" + + _ "github.com/lib/pq" + "golang.org/x/sync/errgroup" +) + +type Dumper struct { + ConnectionString string + Parallels int + DumpVersion string +} + +func NewDumper(connectionString string, threads int) *Dumper { + // Version number of go-pgdump, used in the template after a dump + dumpVersion := "1.1.0" + + // set a default value for Parallels if it is zero or less + if threads <= 0 { + threads = 50 + } + return &Dumper{ConnectionString: connectionString, Parallels: threads, DumpVersion: dumpVersion} +} + +func (d *Dumper) DumpDatabaseToWriter(writer io.Writer, opts *TableOptions) error { + db, err := sql.Open("postgres", d.ConnectionString) + if err != nil { + return err + } + defer db.Close() + + // Template variables + info := DumpInfo{ + DumpVersion: d.DumpVersion, + ServerVersion: getServerVersion(db), + CompleteTime: time.Now().Format("2006-01-02 15:04:05 -0700 MST"), + ThreadsNumber: d.Parallels, + } + + if err := writeHeader(writer, info); err != nil { + return err + } + + tables, err := getTables(db, opts) + if err != nil { + return err + } + + var ( + wg sync.WaitGroup + mx sync.Mutex + ) + + chunks := slices.Chunk(tables, d.Parallels) + for chunk := range chunks { + wg.Add(len(chunk)) + for _, table := range chunk { + //we can add the switch here for export and add a go func here. + go func(table string) { + defer wg.Done() + str, err := scriptTable(db, table) + if err != nil { + return + } + mx.Lock() + io.WriteString(writer, str) + mx.Unlock() + }(table) + } + wg.Wait() + } + if err := writeFooter(writer, info); err != nil { + return err + } + + return nil +} + +func (d *Dumper) DumpDatabase(outputFile string, opts *TableOptions) error { + file, err := os.Create(outputFile) + if err != nil { + return err + } + defer file.Close() + + return d.DumpDatabaseToWriter(file, opts) +} + +func (d *Dumper) DumpDBToCSV(outputDIR, outputFile string, opts *TableOptions) error { + db, err := sql.Open("postgres", d.ConnectionString) + if err != nil { + return err + } + defer db.Close() + + file, err := os.Create(path.Join(outputDIR, outputFile)) + if err != nil { + return err + } + defer file.Close() + + // Template variables + info := DumpInfo{ + DumpVersion: d.DumpVersion, + ServerVersion: getServerVersion(db), + CompleteTime: time.Now().Format("2006-01-02 15:04:05 -0700 MST"), + ThreadsNumber: d.Parallels, + } + + if err := writeHeader(file, info); err != nil { + return err + } + + if err := writeFooter(file, info); err != nil { + return err + } + + tablename, err := getTables(db, opts) + if err != nil { + return err + } + + chunks := slices.Chunk(tablename, d.Parallels) + g, _ := errgroup.WithContext(context.Background()) + for chunk := range chunks { + g.SetLimit(len(chunk)) + for _, table := range chunk { + table := table // capture the current value of table for use in goroutine + g.Go(func() error { + records, err := getTableDataAsCSV(db, table) + if err != nil { + return err + } + + // Correctly open (or create) the file for writing + f, err := os.Create(path.Join(outputDIR, table+".csv")) + if err != nil { + return err + } + defer f.Close() + + csvWriter := csv.NewWriter(f) + if err := csvWriter.WriteAll(records); err != nil { + return err + } + csvWriter.Flush() + return nil + }) + } + if err := g.Wait(); err != nil { + return err + } + } + return nil +} + +func scriptTable(db *sql.DB, tableName string) (string, error) { + var buffer string + // Script CREATE TABLE statement + createStmt, err := getCreateTableStatement(db, tableName) + if err != nil { + return "", fmt.Errorf("error creating table statement for %s: %v", tableName, err) + } + buffer = buffer + createStmt + "\n\n" + + // Script associated sequences (if any) + seqStmts, err := scriptSequences(db, tableName) + if err != nil { + return "", fmt.Errorf("error scripting sequences for table %s: %v", tableName, err) + } + buffer = buffer + seqStmts + "\n\n" + + // Script primary keys + pkStmt, err := scriptPrimaryKeys(db, tableName) + if err != nil { + return "", fmt.Errorf("error scripting primary keys for table %s: %v", tableName, err) + } + buffer = buffer + pkStmt + "\n\n" + + // Dump table data + copyStmt, err := getTableDataCopyFormat(db, tableName) + if err != nil { + return "", fmt.Errorf("error generating COPY statement for table %s: %v", tableName, err) + } + buffer = buffer + copyStmt + "\n\n" + + return buffer, nil +} + +func scriptSequences(db *sql.DB, tableName string) (string, error) { + var sequencesSQL strings.Builder + + // Query to identify sequences linked to the table's columns and fetch sequence definitions + query := ` +SELECT 'CREATE SEQUENCE ' || n.nspname || '.' || c.relname || ';' as seq_creation, + pg_get_serial_sequence(quote_ident(n.nspname) || '.' || quote_ident(t.relname), quote_ident(a.attname)) as seq_owned, + 'ALTER TABLE ' || quote_ident(n.nspname) || '.' || quote_ident(t.relname) || + ' ALTER COLUMN ' || quote_ident(a.attname) || + ' SET DEFAULT nextval(''' || n.nspname || '.' || c.relname || '''::regclass);' as col_default +FROM pg_class c +JOIN pg_namespace n ON c.relnamespace = n.oid +JOIN pg_depend d ON d.objid = c.oid AND d.deptype = 'a' AND d.classid = 'pg_class'::regclass +JOIN pg_attrdef ad ON ad.adrelid = d.refobjid AND ad.adnum = d.refobjsubid +JOIN pg_attribute a ON a.attrelid = d.refobjid AND a.attnum = d.refobjsubid +JOIN pg_class t ON t.oid = d.refobjid AND t.relkind = 'r' +WHERE c.relkind = 'S' AND t.relname = $1 AND n.nspname = 'public'; +` + + rows, err := db.Query(query, tableName) + if err != nil { + return "", fmt.Errorf("error querying sequences for table %s: %v", tableName, err) + } + defer rows.Close() + + for rows.Next() { + var seqCreation, seqOwned, colDefault string + if err := rows.Scan(&seqCreation, &seqOwned, &colDefault); err != nil { + return "", fmt.Errorf("error scanning sequence information: %v", err) + } + + // Here we directly use the sequence creation script. + // The seqOwned might not be necessary if we're focusing on creation and default value setting. + sequencesSQL.WriteString(seqCreation + "\n" + colDefault + "\n") + } + + if err := rows.Err(); err != nil { + return "", fmt.Errorf("error iterating over sequences: %v", err) + } + + return sequencesSQL.String(), nil +} + +func scriptPrimaryKeys(db *sql.DB, tableName string) (string, error) { + var pksSQL strings.Builder + + // Query to find primary key constraints for the specified table. + query := ` +SELECT con.conname AS constraint_name, + pg_get_constraintdef(con.oid) AS constraint_def +FROM pg_constraint con +JOIN pg_class rel ON rel.oid = con.conrelid +JOIN pg_namespace nsp ON nsp.oid = connamespace +WHERE con.contype = 'p' +AND rel.relname = $1 +AND nsp.nspname = 'public'; +` + rows, err := db.Query(query, tableName) + if err != nil { + return "", fmt.Errorf("error querying primary keys for table %s: %v", tableName, err) + } + defer rows.Close() + + // Iterate through each primary key constraint found and script it. + for rows.Next() { + var constraintName, constraintDef string + if err := rows.Scan(&constraintName, &constraintDef); err != nil { + return "", fmt.Errorf("error scanning primary key information: %v", err) + } + + // Construct the ALTER TABLE statement to add the primary key constraint. + pksSQL.WriteString(fmt.Sprintf( + "ALTER TABLE %s ADD CONSTRAINT %s %s;\n", + escapeReservedName(tableName), + constraintName, + constraintDef, + )) + } + + if err := rows.Err(); err != nil { + return "", fmt.Errorf("error iterating over primary keys: %v", err) + } + + return pksSQL.String(), nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..351dd81 --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module git.cloudyne.io/go/pgdump + +go 1.26 + +require ( + github.com/lib/pq v1.10.9 + golang.org/x/sync v0.8.0 +) + +// Retract old unstable versions +retract ( + [v1.0.0, v1.0.9] + [v0.2.0, v0.2.9] + [v0.1.0, v0.1.9] +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6b4284a --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= diff --git a/reserved_names.go b/reserved_names.go new file mode 100644 index 0000000..d234d08 --- /dev/null +++ b/reserved_names.go @@ -0,0 +1,121 @@ +package pgdump + +import ( + "fmt" + "strings" +) + +var ( + postgresReservedNames = map[string]struct{}{ + "ALL": {}, + "ANALYSE": {}, + "ANALYZE": {}, + "AND": {}, + "ANY": {}, + "ARRAY": {}, + "AS": {}, + "ASC": {}, + "ASYMMETRIC": {}, + "AUTHORIZATION": {}, + "BINARY": {}, + "BOTH": {}, + "CASE": {}, + "CAST": {}, + "CHECK": {}, + "COLLATE": {}, + "COLLATION": {}, + "COLUMN": {}, + "CONCURRENTLY": {}, + "CONSTRAINT": {}, + "CREATE": {}, + "CROSS": {}, + "CURRENT_CATALOG": {}, + "CURRENT_DATE": {}, + "CURRENT_ROLE": {}, + "CURRENT_SCHEMA": {}, + "CURRENT_TIME": {}, + "CURRENT_TIMESTAMP": {}, + "CURRENT_USER": {}, + "DEFAULT": {}, + "DEFERRABLE": {}, + "DESC": {}, + "DISTINCT": {}, + "DO": {}, + "ELSE": {}, + "END": {}, + "EXCEPT": {}, + "FALSE": {}, + "FETCH": {}, + "FOR": {}, + "FOREIGN": {}, + "FREEZE": {}, + "FROM": {}, + "FULL": {}, + "GRANT": {}, + "GROUP": {}, + "HAVING": {}, + "ILIKE": {}, + "IN": {}, + "INITIALLY": {}, + "INNER": {}, + "INTERSECT": {}, + "INTO": {}, + "IS": {}, + "ISNULL": {}, + "JOIN": {}, + "LATERAL": {}, + "LEADING": {}, + "LEFT": {}, + "LIKE": {}, + "LIMIT": {}, + "LOCALTIME": {}, + "LOCALTIMESTAMP": {}, + "NATURAL": {}, + "NOT": {}, + "NOTNULL": {}, + "NULL": {}, + "OFFSET": {}, + "ON": {}, + "ONLY": {}, + "OR": {}, + "ORDER": {}, + "OUTER": {}, + "OVERLAPS": {}, + "PLACING": {}, + "PRIMARY": {}, + "REFERENCES": {}, + "RETURNING": {}, + "RIGHT": {}, + "SELECT": {}, + "SESSION_USER": {}, + "SIMILAR": {}, + "SOME": {}, + "SYMMETRIC": {}, + "TABLE": {}, + "TABLESAMPLE": {}, + "THEN": {}, + "TO": {}, + "TRAILING": {}, + "TRUE": {}, + "UNION": {}, + "UNIQUE": {}, + "USER": {}, + "USING": {}, + "VARIADIC": {}, + "VERBOSE": {}, + "WHEN": {}, + "WHERE": {}, + "WINDOW": {}, + "WITH": {}, + } +) + +func escapeReservedName(name string) string { + normalizedName := strings.ToUpper(name) + + if _, isReserved := postgresReservedNames[normalizedName]; isReserved { + return fmt.Sprintf("\"%s\"", name) + } + + return name +} diff --git a/template.go b/template.go new file mode 100644 index 0000000..a2970bd --- /dev/null +++ b/template.go @@ -0,0 +1,65 @@ +package pgdump + +import ( + "database/sql" + "io" + "text/template" +) + +type DumpInfo struct { + DumpVersion string + ServerVersion string + CompleteTime string + ThreadsNumber int +} + +func getServerVersion(db *sql.DB) string { + var version string + query := "SELECT version();" + row := db.QueryRow(query) + if err := row.Scan(&version); err != nil { + return "Unknown" + } + return version +} + +func writeHeader(file io.Writer, info DumpInfo) error { + const headerTemplate = `-- Go PostgreSQL Dump v{{ .DumpVersion }} +-- +-- Server version: +-- {{ .ServerVersion }} +-- Threads Used: +-- {{ .ThreadsNumber }} + +SET statement_timeout = 0; +SET lock_timeout = 0; +SET idle_in_transaction_session_timeout = 0; +SET client_encoding = 'UTF8'; +SET standard_conforming_strings = on; +SELECT pg_catalog.set_config('search_path', '', false); +SET check_function_bodies = false; +SET xmloption = content; +SET client_min_messages = warning; +SET row_security = off; + +SET default_tablespace = ''; + +SET default_table_access_method = heap; +` + tmpl, err := template.New("header").Parse(headerTemplate) + if err != nil { + return err + } + return tmpl.Execute(file, info) +} + +func writeFooter(file io.Writer, info DumpInfo) error { + const footerTemplate = `-- +-- Dump completed on {{ .CompleteTime }} +--` + tmpl, err := template.New("footer").Parse(footerTemplate) + if err != nil { + return err + } + return tmpl.Execute(file, info) +}