Introduce own Numeric type and revert to stdlib from pgx

This commit is contained in:
Jan Bader 2021-12-02 21:29:44 +00:00
parent 6f8a94ff5d
commit 4646356b2d
16 changed files with 136 additions and 57 deletions

View File

@ -9,7 +9,7 @@ import (
)
type AccountData struct {
Accounts []postgres.GetAccountsRow
Accounts []postgres.GetAccountsWithBalanceRow
}
func (h *Handler) accounts(c *gin.Context) {
@ -20,7 +20,7 @@ func (h *Handler) accounts(c *gin.Context) {
return
}
accounts, err := h.Service.DB.GetAccounts(c.Request.Context(), budgetUUID)
accounts, err := h.Service.DB.GetAccountsWithBalance(c.Request.Context(), budgetUUID)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return

View File

@ -12,7 +12,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/jackc/pgtype"
)
// Handler handles incoming requests
@ -160,7 +159,7 @@ func (h *Handler) newTransaction(c *gin.Context) {
new := postgres.CreateTransactionParams{
Memo: transactionMemo,
Date: transactionDateValue,
Amount: pgtype.Numeric{},
Amount: postgres.Numeric{},
AccountID: transactionAccountID,
}
_, err = h.Service.DB.CreateTransaction(c.Request.Context(), new)

View File

@ -11,7 +11,6 @@ import (
"git.javil.eu/jacob1123/budgeteer/postgres"
"github.com/google/uuid"
"github.com/jackc/pgtype"
)
type YNABImport struct {
@ -115,12 +114,12 @@ func trimLastChar(s string) string {
return s[:len(s)-size]
}
func GetAmount(inflow string, outflow string) (pgtype.Numeric, error) {
func GetAmount(inflow string, outflow string) (postgres.Numeric, error) {
// Remove trailing currency
inflow = strings.Replace(trimLastChar(inflow), ",", ".", 1)
outflow = strings.Replace(trimLastChar(outflow), ",", ".", 1)
num := pgtype.Numeric{}
num := postgres.Numeric{}
err := num.Set(inflow)
if err != nil {
return num, fmt.Errorf("Could not parse inflow %s: %w", inflow, err)

View File

@ -22,39 +22,72 @@ type CreateAccountParams struct {
}
func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (Account, error) {
row := q.db.QueryRow(ctx, createAccount, arg.Name, arg.BudgetID)
row := q.db.QueryRowContext(ctx, createAccount, arg.Name, arg.BudgetID)
var i Account
err := row.Scan(&i.ID, &i.BudgetID, &i.Name)
return i, err
}
const getAccounts = `-- name: GetAccounts :many
SELECT accounts.id, accounts.name, SUM(transactions.amount) as balance FROM accounts
SELECT accounts.id, accounts.budget_id, accounts.name FROM accounts
WHERE accounts.budget_id = $1
AND transactions.date < NOW()
GROUP BY accounts.id, accounts.name
`
type GetAccountsRow struct {
ID uuid.UUID
Name string
Balance int64
}
func (q *Queries) GetAccounts(ctx context.Context, budgetID uuid.UUID) ([]GetAccountsRow, error) {
rows, err := q.db.Query(ctx, getAccounts, budgetID)
func (q *Queries) GetAccounts(ctx context.Context, budgetID uuid.UUID) ([]Account, error) {
rows, err := q.db.QueryContext(ctx, getAccounts, budgetID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetAccountsRow
var items []Account
for rows.Next() {
var i GetAccountsRow
if err := rows.Scan(&i.ID, &i.Name, &i.Balance); err != nil {
var i Account
if err := rows.Scan(&i.ID, &i.BudgetID, &i.Name); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getAccountsWithBalance = `-- name: GetAccountsWithBalance :many
SELECT accounts.id, accounts.name, SUM(transactions.amount)::decimal(12,2) as balance
FROM accounts
LEFT JOIN transactions ON transactions.account_id = accounts.id
WHERE accounts.budget_id = $1
AND transactions.date < NOW()
GROUP BY accounts.id, accounts.name
`
type GetAccountsWithBalanceRow struct {
ID uuid.UUID
Name string
Balance Numeric
}
func (q *Queries) GetAccountsWithBalance(ctx context.Context, budgetID uuid.UUID) ([]GetAccountsWithBalanceRow, error) {
rows, err := q.db.QueryContext(ctx, getAccountsWithBalance, budgetID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetAccountsWithBalanceRow
for rows.Next() {
var i GetAccountsWithBalanceRow
if err := rows.Scan(&i.ID, &i.Name, &i.Balance); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}

View File

@ -17,7 +17,7 @@ RETURNING id, name, last_modification
`
func (q *Queries) CreateBudget(ctx context.Context, name string) (Budget, error) {
row := q.db.QueryRow(ctx, createBudget, name)
row := q.db.QueryRowContext(ctx, createBudget, name)
var i Budget
err := row.Scan(&i.ID, &i.Name, &i.LastModification)
return i, err
@ -29,7 +29,7 @@ WHERE id = $1
`
func (q *Queries) GetBudget(ctx context.Context, id uuid.UUID) (Budget, error) {
row := q.db.QueryRow(ctx, getBudget, id)
row := q.db.QueryRowContext(ctx, getBudget, id)
var i Budget
err := row.Scan(&i.ID, &i.Name, &i.LastModification)
return i, err
@ -42,7 +42,7 @@ WHERE user_budgets.user_id = $1
`
func (q *Queries) GetBudgetsForUser(ctx context.Context, userID uuid.UUID) ([]Budget, error) {
rows, err := q.db.Query(ctx, getBudgetsForUser, userID)
rows, err := q.db.QueryContext(ctx, getBudgetsForUser, userID)
if err != nil {
return nil, err
}
@ -55,6 +55,9 @@ func (q *Queries) GetBudgetsForUser(ctx context.Context, userID uuid.UUID) ([]Bu
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}

View File

@ -1,12 +1,10 @@
package postgres
import (
"context"
"database/sql"
"embed"
"fmt"
"github.com/jackc/pgx/v4"
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/pressly/goose/v3"
)
@ -27,12 +25,7 @@ func Connect(server string, user string, password string, database string) (*Que
return nil, nil, err
}
connPG, err := pgx.Connect(context.Background(), connString)
if err != nil {
return nil, nil, err
}
return New(connPG), conn, nil
return New(conn), conn, nil
}
func (tx Transaction) GetAmount() float64 {
@ -62,3 +55,17 @@ func (tx GetTransactionsForBudgetRow) GetPositive() bool {
amount := tx.GetAmount()
return amount >= 0
}
func (tx GetAccountsWithBalanceRow) GetBalance() float64 {
var balance float64
err := tx.Balance.AssignTo(&balance)
if err != nil {
panic(err)
}
return balance
}
func (tx GetAccountsWithBalanceRow) GetPositive() bool {
balance := tx.GetBalance()
return balance >= 0
}

View File

@ -4,15 +4,14 @@ package postgres
import (
"context"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"database/sql"
)
type DBTX interface {
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}
func New(db DBTX) *Queries {
@ -23,7 +22,7 @@ type Queries struct {
db DBTX
}
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
}

View File

@ -7,7 +7,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/jackc/pgtype"
)
type Account struct {
@ -32,7 +31,7 @@ type Transaction struct {
ID uuid.UUID
Date time.Time
Memo string
Amount pgtype.Numeric
Amount Numeric
AccountID uuid.UUID
PayeeID uuid.NullUUID
}

21
postgres/numeric.go Normal file
View File

@ -0,0 +1,21 @@
package postgres
import "github.com/jackc/pgtype"
type Numeric struct {
pgtype.Numeric
}
func (n Numeric) GetFloat64() float64 {
var balance float64
err := n.AssignTo(&balance)
if err != nil {
panic(err)
}
return balance
}
func (n Numeric) GetPositive() bool {
float := n.GetFloat64()
return float >= 0
}

View File

@ -22,7 +22,7 @@ type CreatePayeeParams struct {
}
func (q *Queries) CreatePayee(ctx context.Context, arg CreatePayeeParams) (Payee, error) {
row := q.db.QueryRow(ctx, createPayee, arg.Name, arg.BudgetID)
row := q.db.QueryRowContext(ctx, createPayee, arg.Name, arg.BudgetID)
var i Payee
err := row.Scan(&i.ID, &i.BudgetID, &i.Name)
return i, err
@ -34,7 +34,7 @@ WHERE payees.budget_id = $1
`
func (q *Queries) GetPayees(ctx context.Context, budgetID uuid.UUID) ([]Payee, error) {
rows, err := q.db.Query(ctx, getPayees, budgetID)
rows, err := q.db.QueryContext(ctx, getPayees, budgetID)
if err != nil {
return nil, err
}
@ -47,6 +47,9 @@ func (q *Queries) GetPayees(ctx context.Context, budgetID uuid.UUID) ([]Payee, e
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}

View File

@ -5,7 +5,13 @@ VALUES ($1, $2)
RETURNING *;
-- name: GetAccounts :many
SELECT accounts.id, accounts.name, SUM(transactions.amount) as balance FROM accounts
SELECT accounts.* FROM accounts
WHERE accounts.budget_id = $1;
-- name: GetAccountsWithBalance :many
SELECT accounts.id, accounts.name, SUM(transactions.amount)::decimal(12,2) as balance
FROM accounts
LEFT JOIN transactions ON transactions.account_id = accounts.id
WHERE accounts.budget_id = $1
AND transactions.date < NOW()
GROUP BY accounts.id, accounts.name;

View File

@ -8,7 +8,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/jackc/pgtype"
)
const createTransaction = `-- name: CreateTransaction :one
@ -21,13 +20,13 @@ RETURNING id, date, memo, amount, account_id, payee_id
type CreateTransactionParams struct {
Date time.Time
Memo string
Amount pgtype.Numeric
Amount Numeric
AccountID uuid.UUID
PayeeID uuid.NullUUID
}
func (q *Queries) CreateTransaction(ctx context.Context, arg CreateTransactionParams) (Transaction, error) {
row := q.db.QueryRow(ctx, createTransaction,
row := q.db.QueryRowContext(ctx, createTransaction,
arg.Date,
arg.Memo,
arg.Amount,
@ -54,7 +53,7 @@ LIMIT 200
`
func (q *Queries) GetTransactionsForAccount(ctx context.Context, accountID uuid.UUID) ([]Transaction, error) {
rows, err := q.db.Query(ctx, getTransactionsForAccount, accountID)
rows, err := q.db.QueryContext(ctx, getTransactionsForAccount, accountID)
if err != nil {
return nil, err
}
@ -74,6 +73,9 @@ func (q *Queries) GetTransactionsForAccount(ctx context.Context, accountID uuid.
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
@ -95,13 +97,13 @@ type GetTransactionsForBudgetRow struct {
ID uuid.UUID
Date time.Time
Memo string
Amount pgtype.Numeric
Amount Numeric
Account string
Payee string
}
func (q *Queries) GetTransactionsForBudget(ctx context.Context, budgetID uuid.UUID) ([]GetTransactionsForBudgetRow, error) {
rows, err := q.db.Query(ctx, getTransactionsForBudget, budgetID)
rows, err := q.db.QueryContext(ctx, getTransactionsForBudget, budgetID)
if err != nil {
return nil, err
}
@ -121,6 +123,9 @@ func (q *Queries) GetTransactionsForBudget(ctx context.Context, budgetID uuid.UU
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}

View File

@ -22,7 +22,7 @@ type LinkBudgetToUserParams struct {
}
func (q *Queries) LinkBudgetToUser(ctx context.Context, arg LinkBudgetToUserParams) (UserBudget, error) {
row := q.db.QueryRow(ctx, linkBudgetToUser, arg.UserID, arg.BudgetID)
row := q.db.QueryRowContext(ctx, linkBudgetToUser, arg.UserID, arg.BudgetID)
var i UserBudget
err := row.Scan(&i.UserID, &i.BudgetID)
return i, err

View File

@ -23,7 +23,7 @@ type CreateUserParams struct {
}
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) {
row := q.db.QueryRow(ctx, createUser, arg.Email, arg.Name, arg.Password)
row := q.db.QueryRowContext(ctx, createUser, arg.Email, arg.Name, arg.Password)
var i User
err := row.Scan(
&i.ID,
@ -40,7 +40,7 @@ WHERE id = $1
`
func (q *Queries) GetUser(ctx context.Context, id uuid.UUID) (User, error) {
row := q.db.QueryRow(ctx, getUser, id)
row := q.db.QueryRowContext(ctx, getUser, id)
var i User
err := row.Scan(
&i.ID,
@ -57,7 +57,7 @@ WHERE email = $1
`
func (q *Queries) GetUserByUsername(ctx context.Context, email string) (User, error) {
row := q.db.QueryRow(ctx, getUserByUsername, email)
row := q.db.QueryRowContext(ctx, getUserByUsername, email)
var i User
err := row.Scan(
&i.ID,

View File

@ -5,4 +5,9 @@ packages:
engine: "postgresql"
schema: "postgres/schema/"
queries: "postgres/queries/"
sql_package: "pgx/v4"
overrides:
- go_type: "git.javil.eu/jacob1123/budgeteer/postgres.Numeric"
db_type: "pg_catalog.numeric"
- go_type: "git.javil.eu/jacob1123/budgeteer/postgres.Numeric"
db_type: "pg_catalog.numeric"
nullable: true

View File

@ -11,7 +11,7 @@
{{range .Accounts}}
<div class="budget-item">
<a href="account/{{.ID}}">{{.Name}}</a>
<span class="time">{{.Balance}}</span>
<span class="time">{{printf "%.2f" .GetBalance}}</span>
</div>
{{end}}
{{end}}