Also return sql.DB to be able to use goose

This commit is contained in:
Jan Bader 2021-12-02 20:36:22 +00:00
parent a3df95a700
commit 4011f3cace
3 changed files with 14 additions and 15 deletions

View File

@ -18,12 +18,12 @@ func main() {
bv := &bcrypt.Verifier{} bv := &bcrypt.Verifier{}
db, err := postgres.Connect(cfg.DatabaseHost, cfg.DatabaseUser, cfg.DatabasePassword, cfg.DatabaseName) q, db, err := postgres.Connect(cfg.DatabaseHost, cfg.DatabaseUser, cfg.DatabasePassword, cfg.DatabaseName)
if err != nil { if err != nil {
log.Fatalf("Failed connecting to DB: %v", err) log.Fatalf("Failed connecting to DB: %v", err)
} }
us, err := postgres.NewRepository(db) us, err := postgres.NewRepository(q, db)
if err != nil { if err != nil {
log.Fatalf("Failed building Repository: %v", err) log.Fatalf("Failed building Repository: %v", err)
} }

View File

@ -15,29 +15,24 @@ import (
var migrations embed.FS var migrations embed.FS
// Connect to a database // Connect to a database
func Connect(server string, user string, password string, database string) (*Queries, error) { func Connect(server string, user string, password string, database string) (*Queries, *sql.DB, error) {
connString := fmt.Sprintf("postgres://%s:%s@%s/%s", user, password, server, database) connString := fmt.Sprintf("postgres://%s:%s@%s/%s", user, password, server, database)
conn, err := sql.Open("pgx", connString) conn, err := sql.Open("pgx", connString)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
goose.SetBaseFS(migrations) goose.SetBaseFS(migrations)
if err = goose.Up(conn, "schema"); err != nil { if err = goose.Up(conn, "schema"); err != nil {
return nil, err return nil, nil, err
}
err = conn.Close()
if err != nil {
return nil, err
} }
connPG, err := pgx.Connect(context.Background(), connString) connPG, err := pgx.Connect(context.Background(), connString)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return New(connPG), nil return New(connPG), conn, nil
} }
func (tx Transaction) GetAmount() float64 { func (tx Transaction) GetAmount() float64 {

View File

@ -1,13 +1,17 @@
package postgres package postgres
import "database/sql"
// Repository represents a PostgreSQL implementation of all ModelServices // Repository represents a PostgreSQL implementation of all ModelServices
type Repository struct { type Repository struct {
DB *Queries DB *Queries
LegacyDB *sql.DB
} }
func NewRepository(queries *Queries) (*Repository, error) { func NewRepository(queries *Queries, db *sql.DB) (*Repository, error) {
repo := &Repository{ repo := &Repository{
DB: queries, DB: queries,
LegacyDB: db,
} }
return repo, nil return repo, nil
} }