Introduce own Numeric type and revert to stdlib from pgx

This commit is contained in:
Jan Bader 2021-12-02 21:29:44 +00:00
parent 6f8a94ff5d
commit 4646356b2d
16 changed files with 136 additions and 57 deletions

View File

@ -9,7 +9,7 @@ import (
) )
type AccountData struct { type AccountData struct {
Accounts []postgres.GetAccountsRow Accounts []postgres.GetAccountsWithBalanceRow
} }
func (h *Handler) accounts(c *gin.Context) { func (h *Handler) accounts(c *gin.Context) {
@ -20,7 +20,7 @@ func (h *Handler) accounts(c *gin.Context) {
return return
} }
accounts, err := h.Service.DB.GetAccounts(c.Request.Context(), budgetUUID) accounts, err := h.Service.DB.GetAccountsWithBalance(c.Request.Context(), budgetUUID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
return return

View File

@ -12,7 +12,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgtype"
) )
// Handler handles incoming requests // Handler handles incoming requests
@ -160,7 +159,7 @@ func (h *Handler) newTransaction(c *gin.Context) {
new := postgres.CreateTransactionParams{ new := postgres.CreateTransactionParams{
Memo: transactionMemo, Memo: transactionMemo,
Date: transactionDateValue, Date: transactionDateValue,
Amount: pgtype.Numeric{}, Amount: postgres.Numeric{},
AccountID: transactionAccountID, AccountID: transactionAccountID,
} }
_, err = h.Service.DB.CreateTransaction(c.Request.Context(), new) _, err = h.Service.DB.CreateTransaction(c.Request.Context(), new)

View File

@ -11,7 +11,6 @@ import (
"git.javil.eu/jacob1123/budgeteer/postgres" "git.javil.eu/jacob1123/budgeteer/postgres"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgtype"
) )
type YNABImport struct { type YNABImport struct {
@ -115,12 +114,12 @@ func trimLastChar(s string) string {
return s[:len(s)-size] return s[:len(s)-size]
} }
func GetAmount(inflow string, outflow string) (pgtype.Numeric, error) { func GetAmount(inflow string, outflow string) (postgres.Numeric, error) {
// Remove trailing currency // Remove trailing currency
inflow = strings.Replace(trimLastChar(inflow), ",", ".", 1) inflow = strings.Replace(trimLastChar(inflow), ",", ".", 1)
outflow = strings.Replace(trimLastChar(outflow), ",", ".", 1) outflow = strings.Replace(trimLastChar(outflow), ",", ".", 1)
num := pgtype.Numeric{} num := postgres.Numeric{}
err := num.Set(inflow) err := num.Set(inflow)
if err != nil { if err != nil {
return num, fmt.Errorf("Could not parse inflow %s: %w", inflow, err) return num, fmt.Errorf("Could not parse inflow %s: %w", inflow, err)

View File

@ -22,39 +22,72 @@ type CreateAccountParams struct {
} }
func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (Account, error) { func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (Account, error) {
row := q.db.QueryRow(ctx, createAccount, arg.Name, arg.BudgetID) row := q.db.QueryRowContext(ctx, createAccount, arg.Name, arg.BudgetID)
var i Account var i Account
err := row.Scan(&i.ID, &i.BudgetID, &i.Name) err := row.Scan(&i.ID, &i.BudgetID, &i.Name)
return i, err return i, err
} }
const getAccounts = `-- name: GetAccounts :many const getAccounts = `-- name: GetAccounts :many
SELECT accounts.id, accounts.name, SUM(transactions.amount) as balance FROM accounts SELECT accounts.id, accounts.budget_id, accounts.name FROM accounts
WHERE accounts.budget_id = $1 WHERE accounts.budget_id = $1
AND transactions.date < NOW()
GROUP BY accounts.id, accounts.name
` `
type GetAccountsRow struct { func (q *Queries) GetAccounts(ctx context.Context, budgetID uuid.UUID) ([]Account, error) {
ID uuid.UUID rows, err := q.db.QueryContext(ctx, getAccounts, budgetID)
Name string
Balance int64
}
func (q *Queries) GetAccounts(ctx context.Context, budgetID uuid.UUID) ([]GetAccountsRow, error) {
rows, err := q.db.Query(ctx, getAccounts, budgetID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
var items []GetAccountsRow var items []Account
for rows.Next() { for rows.Next() {
var i GetAccountsRow var i Account
if err := rows.Scan(&i.ID, &i.Name, &i.Balance); err != nil { if err := rows.Scan(&i.ID, &i.BudgetID, &i.Name); err != nil {
return nil, err return nil, err
} }
items = append(items, i) items = append(items, i)
} }
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getAccountsWithBalance = `-- name: GetAccountsWithBalance :many
SELECT accounts.id, accounts.name, SUM(transactions.amount)::decimal(12,2) as balance
FROM accounts
LEFT JOIN transactions ON transactions.account_id = accounts.id
WHERE accounts.budget_id = $1
AND transactions.date < NOW()
GROUP BY accounts.id, accounts.name
`
type GetAccountsWithBalanceRow struct {
ID uuid.UUID
Name string
Balance Numeric
}
func (q *Queries) GetAccountsWithBalance(ctx context.Context, budgetID uuid.UUID) ([]GetAccountsWithBalanceRow, error) {
rows, err := q.db.QueryContext(ctx, getAccountsWithBalance, budgetID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetAccountsWithBalanceRow
for rows.Next() {
var i GetAccountsWithBalanceRow
if err := rows.Scan(&i.ID, &i.Name, &i.Balance); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, err return nil, err
} }

View File

@ -17,7 +17,7 @@ RETURNING id, name, last_modification
` `
func (q *Queries) CreateBudget(ctx context.Context, name string) (Budget, error) { func (q *Queries) CreateBudget(ctx context.Context, name string) (Budget, error) {
row := q.db.QueryRow(ctx, createBudget, name) row := q.db.QueryRowContext(ctx, createBudget, name)
var i Budget var i Budget
err := row.Scan(&i.ID, &i.Name, &i.LastModification) err := row.Scan(&i.ID, &i.Name, &i.LastModification)
return i, err return i, err
@ -29,7 +29,7 @@ WHERE id = $1
` `
func (q *Queries) GetBudget(ctx context.Context, id uuid.UUID) (Budget, error) { func (q *Queries) GetBudget(ctx context.Context, id uuid.UUID) (Budget, error) {
row := q.db.QueryRow(ctx, getBudget, id) row := q.db.QueryRowContext(ctx, getBudget, id)
var i Budget var i Budget
err := row.Scan(&i.ID, &i.Name, &i.LastModification) err := row.Scan(&i.ID, &i.Name, &i.LastModification)
return i, err return i, err
@ -42,7 +42,7 @@ WHERE user_budgets.user_id = $1
` `
func (q *Queries) GetBudgetsForUser(ctx context.Context, userID uuid.UUID) ([]Budget, error) { func (q *Queries) GetBudgetsForUser(ctx context.Context, userID uuid.UUID) ([]Budget, error) {
rows, err := q.db.Query(ctx, getBudgetsForUser, userID) rows, err := q.db.QueryContext(ctx, getBudgetsForUser, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -55,6 +55,9 @@ func (q *Queries) GetBudgetsForUser(ctx context.Context, userID uuid.UUID) ([]Bu
} }
items = append(items, i) items = append(items, i)
} }
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, err return nil, err
} }

View File

@ -1,12 +1,10 @@
package postgres package postgres
import ( import (
"context"
"database/sql" "database/sql"
"embed" "embed"
"fmt" "fmt"
"github.com/jackc/pgx/v4"
_ "github.com/jackc/pgx/v4/stdlib" _ "github.com/jackc/pgx/v4/stdlib"
"github.com/pressly/goose/v3" "github.com/pressly/goose/v3"
) )
@ -27,12 +25,7 @@ func Connect(server string, user string, password string, database string) (*Que
return nil, nil, err return nil, nil, err
} }
connPG, err := pgx.Connect(context.Background(), connString) return New(conn), conn, nil
if err != nil {
return nil, nil, err
}
return New(connPG), conn, nil
} }
func (tx Transaction) GetAmount() float64 { func (tx Transaction) GetAmount() float64 {
@ -62,3 +55,17 @@ func (tx GetTransactionsForBudgetRow) GetPositive() bool {
amount := tx.GetAmount() amount := tx.GetAmount()
return amount >= 0 return amount >= 0
} }
func (tx GetAccountsWithBalanceRow) GetBalance() float64 {
var balance float64
err := tx.Balance.AssignTo(&balance)
if err != nil {
panic(err)
}
return balance
}
func (tx GetAccountsWithBalanceRow) GetPositive() bool {
balance := tx.GetBalance()
return balance >= 0
}

View File

@ -4,15 +4,14 @@ package postgres
import ( import (
"context" "context"
"database/sql"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
) )
type DBTX interface { type DBTX interface {
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error) PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
} }
func New(db DBTX) *Queries { func New(db DBTX) *Queries {
@ -23,7 +22,7 @@ type Queries struct {
db DBTX db DBTX
} }
func (q *Queries) WithTx(tx pgx.Tx) *Queries { func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{ return &Queries{
db: tx, db: tx,
} }

View File

@ -7,7 +7,6 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgtype"
) )
type Account struct { type Account struct {
@ -32,7 +31,7 @@ type Transaction struct {
ID uuid.UUID ID uuid.UUID
Date time.Time Date time.Time
Memo string Memo string
Amount pgtype.Numeric Amount Numeric
AccountID uuid.UUID AccountID uuid.UUID
PayeeID uuid.NullUUID PayeeID uuid.NullUUID
} }

21
postgres/numeric.go Normal file
View File

@ -0,0 +1,21 @@
package postgres
import "github.com/jackc/pgtype"
type Numeric struct {
pgtype.Numeric
}
func (n Numeric) GetFloat64() float64 {
var balance float64
err := n.AssignTo(&balance)
if err != nil {
panic(err)
}
return balance
}
func (n Numeric) GetPositive() bool {
float := n.GetFloat64()
return float >= 0
}

View File

@ -22,7 +22,7 @@ type CreatePayeeParams struct {
} }
func (q *Queries) CreatePayee(ctx context.Context, arg CreatePayeeParams) (Payee, error) { func (q *Queries) CreatePayee(ctx context.Context, arg CreatePayeeParams) (Payee, error) {
row := q.db.QueryRow(ctx, createPayee, arg.Name, arg.BudgetID) row := q.db.QueryRowContext(ctx, createPayee, arg.Name, arg.BudgetID)
var i Payee var i Payee
err := row.Scan(&i.ID, &i.BudgetID, &i.Name) err := row.Scan(&i.ID, &i.BudgetID, &i.Name)
return i, err return i, err
@ -34,7 +34,7 @@ WHERE payees.budget_id = $1
` `
func (q *Queries) GetPayees(ctx context.Context, budgetID uuid.UUID) ([]Payee, error) { func (q *Queries) GetPayees(ctx context.Context, budgetID uuid.UUID) ([]Payee, error) {
rows, err := q.db.Query(ctx, getPayees, budgetID) rows, err := q.db.QueryContext(ctx, getPayees, budgetID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -47,6 +47,9 @@ func (q *Queries) GetPayees(ctx context.Context, budgetID uuid.UUID) ([]Payee, e
} }
items = append(items, i) items = append(items, i)
} }
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, err return nil, err
} }

View File

@ -5,7 +5,13 @@ VALUES ($1, $2)
RETURNING *; RETURNING *;
-- name: GetAccounts :many -- name: GetAccounts :many
SELECT accounts.id, accounts.name, SUM(transactions.amount) as balance FROM accounts SELECT accounts.* FROM accounts
WHERE accounts.budget_id = $1;
-- name: GetAccountsWithBalance :many
SELECT accounts.id, accounts.name, SUM(transactions.amount)::decimal(12,2) as balance
FROM accounts
LEFT JOIN transactions ON transactions.account_id = accounts.id
WHERE accounts.budget_id = $1 WHERE accounts.budget_id = $1
AND transactions.date < NOW() AND transactions.date < NOW()
GROUP BY accounts.id, accounts.name; GROUP BY accounts.id, accounts.name;

View File

@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgtype"
) )
const createTransaction = `-- name: CreateTransaction :one const createTransaction = `-- name: CreateTransaction :one
@ -21,13 +20,13 @@ RETURNING id, date, memo, amount, account_id, payee_id
type CreateTransactionParams struct { type CreateTransactionParams struct {
Date time.Time Date time.Time
Memo string Memo string
Amount pgtype.Numeric Amount Numeric
AccountID uuid.UUID AccountID uuid.UUID
PayeeID uuid.NullUUID PayeeID uuid.NullUUID
} }
func (q *Queries) CreateTransaction(ctx context.Context, arg CreateTransactionParams) (Transaction, error) { func (q *Queries) CreateTransaction(ctx context.Context, arg CreateTransactionParams) (Transaction, error) {
row := q.db.QueryRow(ctx, createTransaction, row := q.db.QueryRowContext(ctx, createTransaction,
arg.Date, arg.Date,
arg.Memo, arg.Memo,
arg.Amount, arg.Amount,
@ -54,7 +53,7 @@ LIMIT 200
` `
func (q *Queries) GetTransactionsForAccount(ctx context.Context, accountID uuid.UUID) ([]Transaction, error) { func (q *Queries) GetTransactionsForAccount(ctx context.Context, accountID uuid.UUID) ([]Transaction, error) {
rows, err := q.db.Query(ctx, getTransactionsForAccount, accountID) rows, err := q.db.QueryContext(ctx, getTransactionsForAccount, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -74,6 +73,9 @@ func (q *Queries) GetTransactionsForAccount(ctx context.Context, accountID uuid.
} }
items = append(items, i) items = append(items, i)
} }
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, err return nil, err
} }
@ -95,13 +97,13 @@ type GetTransactionsForBudgetRow struct {
ID uuid.UUID ID uuid.UUID
Date time.Time Date time.Time
Memo string Memo string
Amount pgtype.Numeric Amount Numeric
Account string Account string
Payee string Payee string
} }
func (q *Queries) GetTransactionsForBudget(ctx context.Context, budgetID uuid.UUID) ([]GetTransactionsForBudgetRow, error) { func (q *Queries) GetTransactionsForBudget(ctx context.Context, budgetID uuid.UUID) ([]GetTransactionsForBudgetRow, error) {
rows, err := q.db.Query(ctx, getTransactionsForBudget, budgetID) rows, err := q.db.QueryContext(ctx, getTransactionsForBudget, budgetID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -121,6 +123,9 @@ func (q *Queries) GetTransactionsForBudget(ctx context.Context, budgetID uuid.UU
} }
items = append(items, i) items = append(items, i)
} }
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, err return nil, err
} }

View File

@ -22,7 +22,7 @@ type LinkBudgetToUserParams struct {
} }
func (q *Queries) LinkBudgetToUser(ctx context.Context, arg LinkBudgetToUserParams) (UserBudget, error) { func (q *Queries) LinkBudgetToUser(ctx context.Context, arg LinkBudgetToUserParams) (UserBudget, error) {
row := q.db.QueryRow(ctx, linkBudgetToUser, arg.UserID, arg.BudgetID) row := q.db.QueryRowContext(ctx, linkBudgetToUser, arg.UserID, arg.BudgetID)
var i UserBudget var i UserBudget
err := row.Scan(&i.UserID, &i.BudgetID) err := row.Scan(&i.UserID, &i.BudgetID)
return i, err return i, err

View File

@ -23,7 +23,7 @@ type CreateUserParams struct {
} }
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) {
row := q.db.QueryRow(ctx, createUser, arg.Email, arg.Name, arg.Password) row := q.db.QueryRowContext(ctx, createUser, arg.Email, arg.Name, arg.Password)
var i User var i User
err := row.Scan( err := row.Scan(
&i.ID, &i.ID,
@ -40,7 +40,7 @@ WHERE id = $1
` `
func (q *Queries) GetUser(ctx context.Context, id uuid.UUID) (User, error) { func (q *Queries) GetUser(ctx context.Context, id uuid.UUID) (User, error) {
row := q.db.QueryRow(ctx, getUser, id) row := q.db.QueryRowContext(ctx, getUser, id)
var i User var i User
err := row.Scan( err := row.Scan(
&i.ID, &i.ID,
@ -57,7 +57,7 @@ WHERE email = $1
` `
func (q *Queries) GetUserByUsername(ctx context.Context, email string) (User, error) { func (q *Queries) GetUserByUsername(ctx context.Context, email string) (User, error) {
row := q.db.QueryRow(ctx, getUserByUsername, email) row := q.db.QueryRowContext(ctx, getUserByUsername, email)
var i User var i User
err := row.Scan( err := row.Scan(
&i.ID, &i.ID,

View File

@ -5,4 +5,9 @@ packages:
engine: "postgresql" engine: "postgresql"
schema: "postgres/schema/" schema: "postgres/schema/"
queries: "postgres/queries/" queries: "postgres/queries/"
sql_package: "pgx/v4" overrides:
- go_type: "git.javil.eu/jacob1123/budgeteer/postgres.Numeric"
db_type: "pg_catalog.numeric"
- go_type: "git.javil.eu/jacob1123/budgeteer/postgres.Numeric"
db_type: "pg_catalog.numeric"
nullable: true

View File

@ -11,7 +11,7 @@
{{range .Accounts}} {{range .Accounts}}
<div class="budget-item"> <div class="budget-item">
<a href="account/{{.ID}}">{{.Name}}</a> <a href="account/{{.ID}}">{{.Name}}</a>
<span class="time">{{.Balance}}</span> <span class="time">{{printf "%.2f" .GetBalance}}</span>
</div> </div>
{{end}} {{end}}
{{end}} {{end}}