Remove Repository and use Database instead

This commit is contained in:
Jan Bader 2021-12-11 20:18:09 +00:00
parent d5ebf5a5cf
commit e9adc763b2
14 changed files with 42 additions and 53 deletions

View File

@ -18,20 +18,15 @@ func main() {
bv := &bcrypt.Verifier{} bv := &bcrypt.Verifier{}
q, db, err := postgres.Connect(cfg.DatabaseHost, cfg.DatabaseUser, cfg.DatabasePassword, cfg.DatabaseName) q, err := postgres.Connect(cfg.DatabaseHost, cfg.DatabaseUser, cfg.DatabasePassword, cfg.DatabaseName)
if err != nil { if err != nil {
log.Fatalf("Failed connecting to DB: %v", err) log.Fatalf("Failed connecting to DB: %v", err)
} }
us, err := postgres.NewRepository(q, db)
if err != nil {
log.Fatalf("Failed building Repository: %v", err)
}
tv := &jwt.TokenVerifier{} tv := &jwt.TokenVerifier{}
h := &http.Handler{ h := &http.Handler{
Service: us, Service: q,
TokenVerifier: tv, TokenVerifier: tv,
CredentialsVerifier: bv, CredentialsVerifier: bv,
} }

View File

@ -23,13 +23,13 @@ func (h *Handler) account(c *gin.Context) {
return return
} }
account, err := h.Service.DB.GetAccount(c.Request.Context(), accountUUID) account, err := h.Service.GetAccount(c.Request.Context(), accountUUID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusNotFound, err) c.AbortWithError(http.StatusNotFound, err)
return return
} }
transactions, err := h.Service.DB.GetTransactionsForAccount(c.Request.Context(), accountUUID) transactions, err := h.Service.GetTransactionsForAccount(c.Request.Context(), accountUUID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusNotFound, err) c.AbortWithError(http.StatusNotFound, err)
return return

View File

@ -21,11 +21,11 @@ func (h *Handler) admin(c *gin.Context) {
func (h *Handler) clearDatabase(c *gin.Context) { func (h *Handler) clearDatabase(c *gin.Context) {
d := AdminData{} d := AdminData{}
if err := goose.Down(h.Service.LegacyDB, "schema"); err != nil { if err := goose.Down(h.Service.DB, "schema"); err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
} }
if err := goose.Up(h.Service.LegacyDB, "schema"); err != nil { if err := goose.Up(h.Service.DB, "schema"); err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
} }
@ -52,7 +52,7 @@ func (h *Handler) clearBudget(c *gin.Context) {
return return
} }
rows, err := h.Service.DB.DeleteAllAssignments(c.Request.Context(), budgetUUID) rows, err := h.Service.DeleteAllAssignments(c.Request.Context(), budgetUUID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
return return
@ -60,7 +60,7 @@ func (h *Handler) clearBudget(c *gin.Context) {
fmt.Printf("Deleted %d assignments\n", rows) fmt.Printf("Deleted %d assignments\n", rows)
rows, err = h.Service.DB.DeleteAllTransactions(c.Request.Context(), budgetUUID) rows, err = h.Service.DeleteAllTransactions(c.Request.Context(), budgetUUID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
return return
@ -77,7 +77,7 @@ func (h *Handler) cleanNegativeBudget(c *gin.Context) {
return return
}*/ }*/
/*min_date, err := h.Service.DB.GetFirstActivity(c.Request.Context(), budgetUUID) /*min_date, err := h.Service.GetFirstActivity(c.Request.Context(), budgetUUID)
date := getFirstOfMonthTime(min_date) date := getFirstOfMonthTime(min_date)
for { for {
nextDate := date.AddDate(0, 1, 0) nextDate := date.AddDate(0, 1, 0)
@ -86,7 +86,7 @@ func (h *Handler) cleanNegativeBudget(c *gin.Context) {
ToDate: nextDate, ToDate: nextDate,
FromDate: date, FromDate: date,
} }
categories, err := h.Service.DB.GetCategoriesWithBalance(c.Request.Context(), params) categories, err := h.Service.GetCategoriesWithBalance(c.Request.Context(), params)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
return return
@ -104,7 +104,7 @@ func (h *Handler) cleanNegativeBudget(c *gin.Context) {
Amount: negativeAvailable, Amount: negativeAvailable,
CategoryID: category.ID, CategoryID: category.ID,
} }
h.Service.DB.CreateAssignment(c.Request.Context(), createAssignment) h.Service.CreateAssignment(c.Request.Context(), createAssignment)
} }
if nextDate.Before(time.Now()) { if nextDate.Before(time.Now()) {

View File

@ -24,13 +24,13 @@ func (h *Handler) getImportantData(c *gin.Context) {
return return
} }
budget, err := h.Service.DB.GetBudget(c.Request.Context(), budgetUUID) budget, err := h.Service.GetBudget(c.Request.Context(), budgetUUID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusNotFound, err) c.AbortWithError(http.StatusNotFound, err)
return return
} }
accounts, err := h.Service.DB.GetAccountsWithBalance(c.Request.Context(), budgetUUID) accounts, err := h.Service.GetAccountsWithBalance(c.Request.Context(), budgetUUID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
return return

View File

@ -23,7 +23,7 @@ func (h *Handler) allAccounts(c *gin.Context) {
return return
} }
transactions, err := h.Service.DB.GetTransactionsForBudget(c.Request.Context(), budgetUUID) transactions, err := h.Service.GetTransactionsForBudget(c.Request.Context(), budgetUUID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
return return

View File

@ -78,9 +78,9 @@ func (h *Handler) budgeting(c *gin.Context) {
Previous: firstOfPreviousMonth, Previous: firstOfPreviousMonth,
} }
categories, err := h.Service.DB.GetCategories(c.Request.Context(), budgetUUID) categories, err := h.Service.GetCategories(c.Request.Context(), budgetUUID)
cumultativeBalances, err := h.Service.DB.GetCumultativeBalances(c.Request.Context(), budgetUUID) cumultativeBalances, err := h.Service.GetCumultativeBalances(c.Request.Context(), budgetUUID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("load balances: %w", err)) c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("load balances: %w", err))
return return

View File

@ -10,7 +10,7 @@ import (
func (h *Handler) dashboard(c *gin.Context) { func (h *Handler) dashboard(c *gin.Context) {
userID := c.MustGet("token").(budgeteer.Token).GetID() userID := c.MustGet("token").(budgeteer.Token).GetID()
budgets, err := h.Service.DB.GetBudgetsForUser(c.Request.Context(), userID) budgets, err := h.Service.GetBudgetsForUser(c.Request.Context(), userID)
if err != nil { if err != nil {
return return
} }

View File

@ -15,7 +15,7 @@ import (
// Handler handles incoming requests // Handler handles incoming requests
type Handler struct { type Handler struct {
Service *postgres.Repository Service *postgres.Database
TokenVerifier budgeteer.TokenVerifier TokenVerifier budgeteer.TokenVerifier
CredentialsVerifier *bcrypt.Verifier CredentialsVerifier *bcrypt.Verifier
} }

View File

@ -68,7 +68,7 @@ func (h *Handler) loginPost(c *gin.Context) {
username, _ := c.GetPostForm("username") username, _ := c.GetPostForm("username")
password, _ := c.GetPostForm("password") password, _ := c.GetPostForm("password")
user, err := h.Service.DB.GetUserByUsername(c.Request.Context(), username) user, err := h.Service.GetUserByUsername(c.Request.Context(), username)
if err != nil { if err != nil {
c.AbortWithError(http.StatusUnauthorized, err) c.AbortWithError(http.StatusUnauthorized, err)
return return
@ -84,7 +84,7 @@ func (h *Handler) loginPost(c *gin.Context) {
c.AbortWithError(http.StatusUnauthorized, err) c.AbortWithError(http.StatusUnauthorized, err)
} }
go h.Service.DB.UpdateLastLogin(context.Background(), user.ID) go h.Service.UpdateLastLogin(context.Background(), user.ID)
maxAge := (int)((expiration * time.Hour).Seconds()) maxAge := (int)((expiration * time.Hour).Seconds())
c.SetCookie(authCookie, t, maxAge, "", "", false, true) c.SetCookie(authCookie, t, maxAge, "", "", false, true)
@ -98,7 +98,7 @@ func (h *Handler) registerPost(c *gin.Context) {
password, _ := c.GetPostForm("password") password, _ := c.GetPostForm("password")
name, _ := c.GetPostForm("name") name, _ := c.GetPostForm("name")
_, err := h.Service.DB.GetUserByUsername(c.Request.Context(), email) _, err := h.Service.GetUserByUsername(c.Request.Context(), email)
if err == nil { if err == nil {
c.AbortWithStatus(http.StatusUnauthorized) c.AbortWithStatus(http.StatusUnauthorized)
return return
@ -115,7 +115,7 @@ func (h *Handler) registerPost(c *gin.Context) {
Password: hash, Password: hash,
Email: email, Email: email,
} }
_, err = h.Service.DB.CreateUser(c.Request.Context(), createUser) _, err = h.Service.CreateUser(c.Request.Context(), createUser)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
} }

View File

@ -46,7 +46,7 @@ func (h *Handler) newTransaction(c *gin.Context) {
Amount: postgres.Numeric{}, Amount: postgres.Numeric{},
AccountID: transactionAccountID, AccountID: transactionAccountID,
} }
_, err = h.Service.DB.CreateTransaction(c.Request.Context(), new) _, err = h.Service.CreateTransaction(c.Request.Context(), new)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
return return

View File

@ -22,7 +22,7 @@ func (h *Handler) importYNAB(c *gin.Context) {
return return
} }
ynab, err := postgres.NewYNABImport(c.Request.Context(), h.Service.DB, budgetUUID) ynab, err := postgres.NewYNABImport(c.Request.Context(), h.Service.Queries, budgetUUID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
return return

View File

@ -2,19 +2,22 @@ package postgres
import ( import (
"context" "context"
"database/sql"
"github.com/google/uuid" "github.com/google/uuid"
) )
// NewBudget creates a budget and adds it to the current user // NewBudget creates a budget and adds it to the current user
func (s *Repository) NewBudget(context context.Context, name string, userID uuid.UUID) (*Budget, error) { func (s *Database) NewBudget(context context.Context, name string, userID uuid.UUID) (*Budget, error) {
budget, err := s.DB.CreateBudget(context, name) tx, err := s.BeginTx(context, &sql.TxOptions{})
q := s.WithTx(tx)
budget, err := q.CreateBudget(context, name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ub := LinkBudgetToUserParams{UserID: userID, BudgetID: budget.ID} ub := LinkBudgetToUserParams{UserID: userID, BudgetID: budget.ID}
_, err = s.DB.LinkBudgetToUser(context, ub) _, err = q.LinkBudgetToUser(context, ub)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -12,18 +12,26 @@ import (
//go:embed schema/*.sql //go:embed schema/*.sql
var migrations embed.FS var migrations embed.FS
type Database struct {
*Queries
*sql.DB
}
// Connect to a database // Connect to a database
func Connect(server string, user string, password string, database string) (*Queries, *sql.DB, error) { func Connect(server string, user string, password string, database string) (*Database, error) {
connString := fmt.Sprintf("postgres://%s:%s@%s/%s", user, password, server, database) connString := fmt.Sprintf("postgres://%s:%s@%s/%s", user, password, server, database)
conn, err := sql.Open("pgx", connString) conn, err := sql.Open("pgx", connString)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
goose.SetBaseFS(migrations) goose.SetBaseFS(migrations)
if err = goose.Up(conn, "schema"); err != nil { if err = goose.Up(conn, "schema"); err != nil {
return nil, nil, err return nil, err
} }
return New(conn), conn, nil return &Database{
New(conn),
conn,
}, nil
} }

View File

@ -1,17 +0,0 @@
package postgres
import "database/sql"
// Repository represents a PostgreSQL implementation of all ModelServices
type Repository struct {
DB *Queries
LegacyDB *sql.DB
}
func NewRepository(queries *Queries, db *sql.DB) (*Repository, error) {
repo := &Repository{
DB: queries,
LegacyDB: db,
}
return repo, nil
}