Updated
This commit is contained in:
286
dumper.go
Normal file
286
dumper.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package pgdump
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type Dumper struct {
|
||||
ConnectionString string
|
||||
Parallels int
|
||||
DumpVersion string
|
||||
}
|
||||
|
||||
func NewDumper(connectionString string, threads int) *Dumper {
|
||||
// Version number of go-pgdump, used in the template after a dump
|
||||
dumpVersion := "1.1.0"
|
||||
|
||||
// set a default value for Parallels if it is zero or less
|
||||
if threads <= 0 {
|
||||
threads = 50
|
||||
}
|
||||
return &Dumper{ConnectionString: connectionString, Parallels: threads, DumpVersion: dumpVersion}
|
||||
}
|
||||
|
||||
func (d *Dumper) DumpDatabaseToWriter(writer io.Writer, opts *TableOptions) error {
|
||||
db, err := sql.Open("postgres", d.ConnectionString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Template variables
|
||||
info := DumpInfo{
|
||||
DumpVersion: d.DumpVersion,
|
||||
ServerVersion: getServerVersion(db),
|
||||
CompleteTime: time.Now().Format("2006-01-02 15:04:05 -0700 MST"),
|
||||
ThreadsNumber: d.Parallels,
|
||||
}
|
||||
|
||||
if err := writeHeader(writer, info); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tables, err := getTables(db, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
mx sync.Mutex
|
||||
)
|
||||
|
||||
chunks := slices.Chunk(tables, d.Parallels)
|
||||
for chunk := range chunks {
|
||||
wg.Add(len(chunk))
|
||||
for _, table := range chunk {
|
||||
//we can add the switch here for export and add a go func here.
|
||||
go func(table string) {
|
||||
defer wg.Done()
|
||||
str, err := scriptTable(db, table)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mx.Lock()
|
||||
io.WriteString(writer, str)
|
||||
mx.Unlock()
|
||||
}(table)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
if err := writeFooter(writer, info); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Dumper) DumpDatabase(outputFile string, opts *TableOptions) error {
|
||||
file, err := os.Create(outputFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
return d.DumpDatabaseToWriter(file, opts)
|
||||
}
|
||||
|
||||
func (d *Dumper) DumpDBToCSV(outputDIR, outputFile string, opts *TableOptions) error {
|
||||
db, err := sql.Open("postgres", d.ConnectionString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
file, err := os.Create(path.Join(outputDIR, outputFile))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Template variables
|
||||
info := DumpInfo{
|
||||
DumpVersion: d.DumpVersion,
|
||||
ServerVersion: getServerVersion(db),
|
||||
CompleteTime: time.Now().Format("2006-01-02 15:04:05 -0700 MST"),
|
||||
ThreadsNumber: d.Parallels,
|
||||
}
|
||||
|
||||
if err := writeHeader(file, info); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeFooter(file, info); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tablename, err := getTables(db, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chunks := slices.Chunk(tablename, d.Parallels)
|
||||
g, _ := errgroup.WithContext(context.Background())
|
||||
for chunk := range chunks {
|
||||
g.SetLimit(len(chunk))
|
||||
for _, table := range chunk {
|
||||
table := table // capture the current value of table for use in goroutine
|
||||
g.Go(func() error {
|
||||
records, err := getTableDataAsCSV(db, table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Correctly open (or create) the file for writing
|
||||
f, err := os.Create(path.Join(outputDIR, table+".csv"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
csvWriter := csv.NewWriter(f)
|
||||
if err := csvWriter.WriteAll(records); err != nil {
|
||||
return err
|
||||
}
|
||||
csvWriter.Flush()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func scriptTable(db *sql.DB, tableName string) (string, error) {
|
||||
var buffer string
|
||||
// Script CREATE TABLE statement
|
||||
createStmt, err := getCreateTableStatement(db, tableName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating table statement for %s: %v", tableName, err)
|
||||
}
|
||||
buffer = buffer + createStmt + "\n\n"
|
||||
|
||||
// Script associated sequences (if any)
|
||||
seqStmts, err := scriptSequences(db, tableName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error scripting sequences for table %s: %v", tableName, err)
|
||||
}
|
||||
buffer = buffer + seqStmts + "\n\n"
|
||||
|
||||
// Script primary keys
|
||||
pkStmt, err := scriptPrimaryKeys(db, tableName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error scripting primary keys for table %s: %v", tableName, err)
|
||||
}
|
||||
buffer = buffer + pkStmt + "\n\n"
|
||||
|
||||
// Dump table data
|
||||
copyStmt, err := getTableDataCopyFormat(db, tableName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error generating COPY statement for table %s: %v", tableName, err)
|
||||
}
|
||||
buffer = buffer + copyStmt + "\n\n"
|
||||
|
||||
return buffer, nil
|
||||
}
|
||||
|
||||
func scriptSequences(db *sql.DB, tableName string) (string, error) {
|
||||
var sequencesSQL strings.Builder
|
||||
|
||||
// Query to identify sequences linked to the table's columns and fetch sequence definitions
|
||||
query := `
|
||||
SELECT 'CREATE SEQUENCE ' || n.nspname || '.' || c.relname || ';' as seq_creation,
|
||||
pg_get_serial_sequence(quote_ident(n.nspname) || '.' || quote_ident(t.relname), quote_ident(a.attname)) as seq_owned,
|
||||
'ALTER TABLE ' || quote_ident(n.nspname) || '.' || quote_ident(t.relname) ||
|
||||
' ALTER COLUMN ' || quote_ident(a.attname) ||
|
||||
' SET DEFAULT nextval(''' || n.nspname || '.' || c.relname || '''::regclass);' as col_default
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON c.relnamespace = n.oid
|
||||
JOIN pg_depend d ON d.objid = c.oid AND d.deptype = 'a' AND d.classid = 'pg_class'::regclass
|
||||
JOIN pg_attrdef ad ON ad.adrelid = d.refobjid AND ad.adnum = d.refobjsubid
|
||||
JOIN pg_attribute a ON a.attrelid = d.refobjid AND a.attnum = d.refobjsubid
|
||||
JOIN pg_class t ON t.oid = d.refobjid AND t.relkind = 'r'
|
||||
WHERE c.relkind = 'S' AND t.relname = $1 AND n.nspname = 'public';
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, tableName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error querying sequences for table %s: %v", tableName, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var seqCreation, seqOwned, colDefault string
|
||||
if err := rows.Scan(&seqCreation, &seqOwned, &colDefault); err != nil {
|
||||
return "", fmt.Errorf("error scanning sequence information: %v", err)
|
||||
}
|
||||
|
||||
// Here we directly use the sequence creation script.
|
||||
// The seqOwned might not be necessary if we're focusing on creation and default value setting.
|
||||
sequencesSQL.WriteString(seqCreation + "\n" + colDefault + "\n")
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return "", fmt.Errorf("error iterating over sequences: %v", err)
|
||||
}
|
||||
|
||||
return sequencesSQL.String(), nil
|
||||
}
|
||||
|
||||
func scriptPrimaryKeys(db *sql.DB, tableName string) (string, error) {
|
||||
var pksSQL strings.Builder
|
||||
|
||||
// Query to find primary key constraints for the specified table.
|
||||
query := `
|
||||
SELECT con.conname AS constraint_name,
|
||||
pg_get_constraintdef(con.oid) AS constraint_def
|
||||
FROM pg_constraint con
|
||||
JOIN pg_class rel ON rel.oid = con.conrelid
|
||||
JOIN pg_namespace nsp ON nsp.oid = connamespace
|
||||
WHERE con.contype = 'p'
|
||||
AND rel.relname = $1
|
||||
AND nsp.nspname = 'public';
|
||||
`
|
||||
rows, err := db.Query(query, tableName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error querying primary keys for table %s: %v", tableName, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Iterate through each primary key constraint found and script it.
|
||||
for rows.Next() {
|
||||
var constraintName, constraintDef string
|
||||
if err := rows.Scan(&constraintName, &constraintDef); err != nil {
|
||||
return "", fmt.Errorf("error scanning primary key information: %v", err)
|
||||
}
|
||||
|
||||
// Construct the ALTER TABLE statement to add the primary key constraint.
|
||||
pksSQL.WriteString(fmt.Sprintf(
|
||||
"ALTER TABLE %s ADD CONSTRAINT %s %s;\n",
|
||||
escapeReservedName(tableName),
|
||||
constraintName,
|
||||
constraintDef,
|
||||
))
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return "", fmt.Errorf("error iterating over primary keys: %v", err)
|
||||
}
|
||||
|
||||
return pksSQL.String(), nil
|
||||
}
|
||||
Reference in New Issue
Block a user