Files
pgdump/dumper.go

287 lines
7.4 KiB
Go
Raw Permalink Normal View History

2026-04-10 22:02:47 +02:00
package pgdump
import (
"context"
"database/sql"
"encoding/csv"
"fmt"
"io"
"os"
"path"
"slices"
"strings"
"sync"
"time"
_ "github.com/lib/pq"
"golang.org/x/sync/errgroup"
)
// Dumper exports a PostgreSQL database, either as a single SQL script
// (DumpDatabase / DumpDatabaseToWriter) or as per-table CSV files
// (DumpDBToCSV).
type Dumper struct {
	// ConnectionString is passed to sql.Open with the "postgres" driver.
	ConnectionString string
	// Parallels is the number of tables processed concurrently.
	Parallels int
	// DumpVersion is the go-pgdump version handed to the header/footer
	// templates through DumpInfo.
	DumpVersion string
}

// NewDumper returns a Dumper for connectionString that works on up to
// threads tables at a time. A threads value below 1 falls back to 50.
func NewDumper(connectionString string, threads int) *Dumper {
	const (
		// Release number of go-pgdump, rendered into the dump template.
		libraryVersion = "1.1.0"
		// Concurrency used when the caller does not ask for any.
		fallbackParallels = 50
	)

	parallels := threads
	if parallels < 1 {
		parallels = fallbackParallels
	}

	return &Dumper{
		ConnectionString: connectionString,
		Parallels:        parallels,
		DumpVersion:      libraryVersion,
	}
}
// DumpDatabaseToWriter writes an SQL dump of the database at
// d.ConnectionString to writer: a header, one script per table selected by
// opts (as produced by scriptTable), and a footer.
//
// Tables are scripted concurrently in batches of d.Parallels (50 when
// Parallels is not positive). Within each batch the scripts are written in
// the order getTables returned them, so the output is deterministic. The
// first error from scripting a table or writing to writer aborts the dump
// and is returned; a failing table is never silently left out.
func (d *Dumper) DumpDatabaseToWriter(writer io.Writer, opts *TableOptions) error {
	db, err := sql.Open("postgres", d.ConnectionString)
	if err != nil {
		return err
	}
	defer db.Close()

	// slices.Chunk panics for a size below 1, which a Dumper constructed
	// without NewDumper would otherwise hit.
	parallels := d.Parallels
	if parallels <= 0 {
		parallels = 50
	}

	// Template variables
	info := DumpInfo{
		DumpVersion:   d.DumpVersion,
		ServerVersion: getServerVersion(db),
		CompleteTime:  time.Now().Format("2006-01-02 15:04:05 -0700 MST"),
		ThreadsNumber: parallels,
	}
	if err := writeHeader(writer, info); err != nil {
		return err
	}

	tables, err := getTables(db, opts)
	if err != nil {
		return err
	}

	for chunk := range slices.Chunk(tables, parallels) {
		// Each goroutine owns exactly one index, so no locking is needed.
		scripts := make([]string, len(chunk))
		errs := make([]error, len(chunk))

		var wg sync.WaitGroup
		wg.Add(len(chunk))
		for i, table := range chunk {
			go func(i int, table string) {
				defer wg.Done()
				scripts[i], errs[i] = scriptTable(db, table)
			}(i, table)
		}
		wg.Wait()

		// Emit in table order and stop at the first failure.
		for i := range chunk {
			if errs[i] != nil {
				return errs[i]
			}
			if _, err := io.WriteString(writer, scripts[i]); err != nil {
				return err
			}
		}
	}

	return writeFooter(writer, info)
}
// DumpDatabase creates (or truncates) outputFile and writes the SQL dump
// produced by DumpDatabaseToWriter into it.
//
// The error from closing the file is returned when the dump itself
// succeeded. Close can surface a deferred write failure (for example on
// network filesystems), which would otherwise leave a truncated dump that
// looks successful.
func (d *Dumper) DumpDatabase(outputFile string, opts *TableOptions) (err error) {
	file, err := os.Create(outputFile)
	if err != nil {
		return err
	}
	defer func() {
		// Keep the dump error if there was one; otherwise report Close.
		if cerr := file.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	return d.DumpDatabaseToWriter(file, opts)
}
// DumpDBToCSV writes a summary file containing only the dump header and
// footer to outputDIR/outputFile. It also writes one "<table>.csv" file into
// outputDIR for every table selected by opts.
//
// Up to d.Parallels tables (50 when Parallels is not positive) are exported
// concurrently. Once any export fails, no further tables are started and
// the first error is returned. CSV files that were already written are left
// in place. Close errors on every written file are reported.
func (d *Dumper) DumpDBToCSV(outputDIR, outputFile string, opts *TableOptions) (err error) {
	db, err := sql.Open("postgres", d.ConnectionString)
	if err != nil {
		return err
	}
	defer db.Close()

	file, err := os.Create(path.Join(outputDIR, outputFile))
	if err != nil {
		return err
	}
	defer func() {
		if cerr := file.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	// A zero limit would block errgroup.Go forever, so fall back to the
	// NewDumper default for hand-built Dumpers.
	parallels := d.Parallels
	if parallels <= 0 {
		parallels = 50
	}

	// Template variables
	info := DumpInfo{
		DumpVersion:   d.DumpVersion,
		ServerVersion: getServerVersion(db),
		CompleteTime:  time.Now().Format("2006-01-02 15:04:05 -0700 MST"),
		ThreadsNumber: parallels,
	}
	if err := writeHeader(file, info); err != nil {
		return err
	}
	if err := writeFooter(file, info); err != nil {
		return err
	}

	tables, err := getTables(db, opts)
	if err != nil {
		return err
	}

	// One bounded group instead of fixed-size chunks: a slow table no longer
	// holds back the rest of its chunk.
	g, ctx := errgroup.WithContext(context.Background())
	g.SetLimit(parallels)
	for _, table := range tables {
		g.Go(func() error {
			// Do not start new exports once another table has failed.
			if err := ctx.Err(); err != nil {
				return err
			}

			records, err := getTableDataAsCSV(db, table)
			if err != nil {
				return fmt.Errorf("reading table %s: %w", table, err)
			}

			f, err := os.Create(path.Join(outputDIR, table+".csv"))
			if err != nil {
				return err
			}
			// WriteAll flushes the writer and reports any buffered write error.
			if err := csv.NewWriter(f).WriteAll(records); err != nil {
				f.Close() // the write error is the one worth reporting
				return fmt.Errorf("writing %s.csv: %w", table, err)
			}
			return f.Close()
		})
	}
	return g.Wait()
}
// scriptTable builds the SQL script for one table. The script contains, in
// order: the CREATE TABLE statement, the associated sequence statements, the
// primary-key constraints and the table data as produced by
// getTableDataCopyFormat. Each section is followed by a blank line.
//
// Errors from the helpers are wrapped with %w, so callers can still inspect
// the underlying cause with errors.Is / errors.As.
func scriptTable(db *sql.DB, tableName string) (string, error) {
	var b strings.Builder

	// Script CREATE TABLE statement
	createStmt, err := getCreateTableStatement(db, tableName)
	if err != nil {
		return "", fmt.Errorf("error creating table statement for %s: %w", tableName, err)
	}
	b.WriteString(createStmt)
	b.WriteString("\n\n")

	// Script associated sequences (if any)
	seqStmts, err := scriptSequences(db, tableName)
	if err != nil {
		return "", fmt.Errorf("error scripting sequences for table %s: %w", tableName, err)
	}
	b.WriteString(seqStmts)
	b.WriteString("\n\n")

	// Script primary keys
	pkStmt, err := scriptPrimaryKeys(db, tableName)
	if err != nil {
		return "", fmt.Errorf("error scripting primary keys for table %s: %w", tableName, err)
	}
	b.WriteString(pkStmt)
	b.WriteString("\n\n")

	// Dump table data
	copyStmt, err := getTableDataCopyFormat(db, tableName)
	if err != nil {
		return "", fmt.Errorf("error generating COPY statement for table %s: %w", tableName, err)
	}
	b.WriteString(copyStmt)
	b.WriteString("\n\n")

	return b.String(), nil
}
// scriptSequences returns, for every sequence in pg_depend with deptype 'a'
// (as serial columns use) on a column of the table tableName in the public
// schema, a "CREATE SEQUENCE" statement followed by an "ALTER TABLE ... SET
// DEFAULT nextval(...)" statement. Each statement is on its own line.
//
// Schema, sequence, table and column names are escaped with quote_ident (and
// quote_literal for the regclass string). This keeps mixed-case or unusual
// names valid SQL; ordinary lower-case names come out exactly as before.
//
// NOTE(review): the emitted CREATE SEQUENCE carries no START/setval, so the
// sequence's current value is not restored. Confirm whether that is handled
// elsewhere, or rows reloaded with explicit ids may collide with new nextval
// values.
func scriptSequences(db *sql.DB, tableName string) (string, error) {
	var sequencesSQL strings.Builder

	// s/sn: the sequence and its schema; t/tn: the owning table and its schema.
	query := `
SELECT 'CREATE SEQUENCE ' || quote_ident(sn.nspname) || '.' || quote_ident(s.relname) || ';' AS seq_creation,
       'ALTER TABLE ' || quote_ident(tn.nspname) || '.' || quote_ident(t.relname) ||
       ' ALTER COLUMN ' || quote_ident(a.attname) ||
       ' SET DEFAULT nextval(' || quote_literal(quote_ident(sn.nspname) || '.' || quote_ident(s.relname)) || '::regclass);' AS col_default
FROM pg_class s
JOIN pg_namespace sn ON sn.oid = s.relnamespace
JOIN pg_depend d ON d.objid = s.oid AND d.deptype = 'a' AND d.classid = 'pg_class'::regclass
JOIN pg_attrdef ad ON ad.adrelid = d.refobjid AND ad.adnum = d.refobjsubid
JOIN pg_attribute a ON a.attrelid = d.refobjid AND a.attnum = d.refobjsubid
JOIN pg_class t ON t.oid = d.refobjid AND t.relkind = 'r'
JOIN pg_namespace tn ON tn.oid = t.relnamespace
WHERE s.relkind = 'S' AND t.relname = $1 AND tn.nspname = 'public'
ORDER BY a.attnum;
`
	rows, err := db.Query(query, tableName)
	if err != nil {
		return "", fmt.Errorf("error querying sequences for table %s: %w", tableName, err)
	}
	defer rows.Close()

	for rows.Next() {
		var seqCreation, colDefault string
		if err := rows.Scan(&seqCreation, &colDefault); err != nil {
			return "", fmt.Errorf("error scanning sequence information: %w", err)
		}
		sequencesSQL.WriteString(seqCreation + "\n" + colDefault + "\n")
	}
	if err := rows.Err(); err != nil {
		return "", fmt.Errorf("error iterating over sequences: %w", err)
	}
	return sequencesSQL.String(), nil
}
// scriptPrimaryKeys returns one "ALTER TABLE <table> ADD CONSTRAINT <name>
// <definition>;" line for each primary-key constraint on the table tableName
// in the public schema. The table name goes through escapeReservedName. The
// constraint name is escaped with quote_ident, so mixed-case names survive a
// restore. The definition comes from pg_get_constraintdef.
func scriptPrimaryKeys(db *sql.DB, tableName string) (string, error) {
	var pksSQL strings.Builder

	// Query to find primary key constraints for the specified table.
	query := `
SELECT quote_ident(con.conname) AS constraint_name,
       pg_get_constraintdef(con.oid) AS constraint_def
FROM pg_constraint con
JOIN pg_class rel ON rel.oid = con.conrelid
JOIN pg_namespace nsp ON nsp.oid = con.connamespace
WHERE con.contype = 'p'
  AND rel.relname = $1
  AND nsp.nspname = 'public';
`
	rows, err := db.Query(query, tableName)
	if err != nil {
		return "", fmt.Errorf("error querying primary keys for table %s: %w", tableName, err)
	}
	defer rows.Close()

	for rows.Next() {
		var constraintName, constraintDef string
		if err := rows.Scan(&constraintName, &constraintDef); err != nil {
			return "", fmt.Errorf("error scanning primary key information: %w", err)
		}
		fmt.Fprintf(&pksSQL,
			"ALTER TABLE %s ADD CONSTRAINT %s %s;\n",
			escapeReservedName(tableName),
			constraintName,
			constraintDef,
		)
	}
	if err := rows.Err(); err != nil {
		return "", fmt.Errorf("error iterating over primary keys: %w", err)
	}
	return pksSQL.String(), nil
}