diff --git a/http/accounts.go b/http/accounts.go index 53934d7..1a77b12 100644 --- a/http/accounts.go +++ b/http/accounts.go @@ -9,7 +9,7 @@ import ( ) type AccountData struct { - Accounts []postgres.GetAccountsRow + Accounts []postgres.GetAccountsWithBalanceRow } func (h *Handler) accounts(c *gin.Context) { @@ -20,7 +20,7 @@ func (h *Handler) accounts(c *gin.Context) { return } - accounts, err := h.Service.DB.GetAccounts(c.Request.Context(), budgetUUID) + accounts, err := h.Service.DB.GetAccountsWithBalance(c.Request.Context(), budgetUUID) if err != nil { c.AbortWithError(http.StatusInternalServerError, err) return diff --git a/http/http.go b/http/http.go index b4ddae8..57f8cb1 100644 --- a/http/http.go +++ b/http/http.go @@ -12,7 +12,6 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" - "github.com/jackc/pgtype" ) // Handler handles incoming requests @@ -160,7 +159,7 @@ func (h *Handler) newTransaction(c *gin.Context) { new := postgres.CreateTransactionParams{ Memo: transactionMemo, Date: transactionDateValue, - Amount: pgtype.Numeric{}, + Amount: postgres.Numeric{}, AccountID: transactionAccountID, } _, err = h.Service.DB.CreateTransaction(c.Request.Context(), new) diff --git a/http/ynab-import.go b/http/ynab-import.go index 9d9cbb4..ce91460 100644 --- a/http/ynab-import.go +++ b/http/ynab-import.go @@ -11,7 +11,6 @@ import ( "git.javil.eu/jacob1123/budgeteer/postgres" "github.com/google/uuid" - "github.com/jackc/pgtype" ) type YNABImport struct { @@ -115,12 +114,12 @@ func trimLastChar(s string) string { return s[:len(s)-size] } -func GetAmount(inflow string, outflow string) (pgtype.Numeric, error) { +func GetAmount(inflow string, outflow string) (postgres.Numeric, error) { // Remove trailing currency inflow = strings.Replace(trimLastChar(inflow), ",", ".", 1) outflow = strings.Replace(trimLastChar(outflow), ",", ".", 1) - num := pgtype.Numeric{} + num := postgres.Numeric{} err := num.Set(inflow) if err != nil { return num, fmt.Errorf("Could not parse inflow %s: %w", inflow, err) diff --git a/postgres/accounts.sql.go b/postgres/accounts.sql.go index 382aecc..7751530 100644 --- a/postgres/accounts.sql.go +++ b/postgres/accounts.sql.go @@ -22,39 +22,72 @@ type CreateAccountParams struct { } func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (Account, error) { - row := q.db.QueryRow(ctx, createAccount, arg.Name, arg.BudgetID) + row := q.db.QueryRowContext(ctx, createAccount, arg.Name, arg.BudgetID) var i Account err := row.Scan(&i.ID, &i.BudgetID, &i.Name) return i, err } const getAccounts = `-- name: GetAccounts :many -SELECT accounts.id, accounts.name, SUM(transactions.amount) as balance FROM accounts +SELECT accounts.id, accounts.budget_id, accounts.name FROM accounts WHERE accounts.budget_id = $1 -AND transactions.date < NOW() -GROUP BY accounts.id, accounts.name ` -type GetAccountsRow struct { - ID uuid.UUID - Name string - Balance int64 -} - -func (q *Queries) GetAccounts(ctx context.Context, budgetID uuid.UUID) ([]GetAccountsRow, error) { - rows, err := q.db.Query(ctx, getAccounts, budgetID) +func (q *Queries) GetAccounts(ctx context.Context, budgetID uuid.UUID) ([]Account, error) { + rows, err := q.db.QueryContext(ctx, getAccounts, budgetID) if err != nil { return nil, err } defer rows.Close() - var items []GetAccountsRow + var items []Account for rows.Next() { - var i GetAccountsRow - if err := rows.Scan(&i.ID, &i.Name, &i.Balance); err != nil { + var i Account + if err := rows.Scan(&i.ID, &i.BudgetID, &i.Name); err != nil { return nil, err } items = append(items, i) } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getAccountsWithBalance = `-- name: GetAccountsWithBalance :many +SELECT accounts.id, accounts.name, SUM(transactions.amount)::decimal(12,2) as balance +FROM accounts +LEFT JOIN transactions ON transactions.account_id = accounts.id +WHERE accounts.budget_id = $1 +AND transactions.date < NOW() +GROUP BY accounts.id, accounts.name +` + +type GetAccountsWithBalanceRow struct { + ID uuid.UUID + Name string + Balance Numeric +} + +func (q *Queries) GetAccountsWithBalance(ctx context.Context, budgetID uuid.UUID) ([]GetAccountsWithBalanceRow, error) { + rows, err := q.db.QueryContext(ctx, getAccountsWithBalance, budgetID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetAccountsWithBalanceRow + for rows.Next() { + var i GetAccountsWithBalanceRow + if err := rows.Scan(&i.ID, &i.Name, &i.Balance); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } diff --git a/postgres/budgets.sql.go b/postgres/budgets.sql.go index 328941d..e4827e1 100644 --- a/postgres/budgets.sql.go +++ b/postgres/budgets.sql.go @@ -17,7 +17,7 @@ RETURNING id, name, last_modification ` func (q *Queries) CreateBudget(ctx context.Context, name string) (Budget, error) { - row := q.db.QueryRow(ctx, createBudget, name) + row := q.db.QueryRowContext(ctx, createBudget, name) var i Budget err := row.Scan(&i.ID, &i.Name, &i.LastModification) return i, err @@ -29,7 +29,7 @@ WHERE id = $1 ` func (q *Queries) GetBudget(ctx context.Context, id uuid.UUID) (Budget, error) { - row := q.db.QueryRow(ctx, getBudget, id) + row := q.db.QueryRowContext(ctx, getBudget, id) var i Budget err := row.Scan(&i.ID, &i.Name, &i.LastModification) return i, err @@ -42,7 +42,7 @@ WHERE user_budgets.user_id = $1 ` func (q *Queries) GetBudgetsForUser(ctx context.Context, userID uuid.UUID) ([]Budget, error) { - rows, err := q.db.Query(ctx, getBudgetsForUser, userID) + rows, err := q.db.QueryContext(ctx, getBudgetsForUser, userID) if err != nil { return nil, err } @@ -55,6 +55,9 @@ func (q *Queries) GetBudgetsForUser(ctx context.Context, userID uuid.UUID) ([]Bu } items = append(items, i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } diff --git a/postgres/conn.go b/postgres/conn.go index 291d49d..36acf9a 100644 --- a/postgres/conn.go +++ b/postgres/conn.go @@ -1,12 +1,10 @@ package postgres import ( - "context" "database/sql" "embed" "fmt" - "github.com/jackc/pgx/v4" _ "github.com/jackc/pgx/v4/stdlib" "github.com/pressly/goose/v3" ) @@ -27,12 +25,7 @@ func Connect(server string, user string, password string, database string) (*Que return nil, nil, err } - connPG, err := pgx.Connect(context.Background(), connString) - if err != nil { - return nil, nil, err - } - - return New(connPG), conn, nil + return New(conn), conn, nil } func (tx Transaction) GetAmount() float64 { @@ -62,3 +55,17 @@ func (tx GetTransactionsForBudgetRow) GetPositive() bool { amount := tx.GetAmount() return amount >= 0 } + +func (tx GetAccountsWithBalanceRow) GetBalance() float64 { + var balance float64 + err := tx.Balance.AssignTo(&balance) + if err != nil { + panic(err) + } + return balance +} + +func (tx GetAccountsWithBalanceRow) GetPositive() bool { + balance := tx.GetBalance() + return balance >= 0 +} diff --git a/postgres/db.go b/postgres/db.go index f2900b8..8d02508 100644 --- a/postgres/db.go +++ b/postgres/db.go @@ -4,15 +4,14 @@ package postgres import ( "context" - - "github.com/jackc/pgconn" - "github.com/jackc/pgx/v4" + "database/sql" ) type DBTX interface { - Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) - Query(context.Context, string, ...interface{}) (pgx.Rows, error) - QueryRow(context.Context, string, ...interface{}) pgx.Row + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row } func New(db DBTX) *Queries { @@ -23,7 +22,7 @@ type Queries struct { db DBTX } -func (q *Queries) WithTx(tx pgx.Tx) *Queries { +func (q *Queries) WithTx(tx *sql.Tx) *Queries { return &Queries{ db: tx, } diff --git a/postgres/models.go b/postgres/models.go index 1edc459..a7070cd 100644 --- a/postgres/models.go +++ b/postgres/models.go @@ -7,7 +7,6 @@ import ( "time" "github.com/google/uuid" - "github.com/jackc/pgtype" ) type Account struct { @@ -32,7 +31,7 @@ type Transaction struct { ID uuid.UUID Date time.Time Memo string - Amount pgtype.Numeric + Amount Numeric AccountID uuid.UUID PayeeID uuid.NullUUID } diff --git a/postgres/numeric.go b/postgres/numeric.go new file mode 100644 index 0000000..acf5f03 --- /dev/null +++ b/postgres/numeric.go @@ -0,0 +1,21 @@ +package postgres + +import "github.com/jackc/pgtype" + +type Numeric struct { + pgtype.Numeric +} + +func (n Numeric) GetFloat64() float64 { + var balance float64 + err := n.AssignTo(&balance) + if err != nil { + panic(err) + } + return balance +} + +func (n Numeric) GetPositive() bool { + float := n.GetFloat64() + return float >= 0 +} diff --git a/postgres/payees.sql.go b/postgres/payees.sql.go index 7a62bde..a723f28 100644 --- a/postgres/payees.sql.go +++ b/postgres/payees.sql.go @@ -22,7 +22,7 @@ type CreatePayeeParams struct { } func (q *Queries) CreatePayee(ctx context.Context, arg CreatePayeeParams) (Payee, error) { - row := q.db.QueryRow(ctx, createPayee, arg.Name, arg.BudgetID) + row := q.db.QueryRowContext(ctx, createPayee, arg.Name, arg.BudgetID) var i Payee err := row.Scan(&i.ID, &i.BudgetID, &i.Name) return i, err @@ -34,7 +34,7 @@ WHERE payees.budget_id = $1 ` func (q *Queries) GetPayees(ctx context.Context, budgetID uuid.UUID) ([]Payee, error) { - rows, err := q.db.Query(ctx, getPayees, budgetID) + rows, err := q.db.QueryContext(ctx, getPayees, budgetID) if err != nil { return nil, err } @@ -47,6 +47,9 @@ func (q *Queries) GetPayees(ctx context.Context, budgetID uuid.UUID) ([]Payee, e } items = append(items, i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } diff --git a/postgres/queries/accounts.sql b/postgres/queries/accounts.sql index 62e502e..788ee42 100644 --- a/postgres/queries/accounts.sql +++ b/postgres/queries/accounts.sql @@ -5,7 +5,13 @@ VALUES ($1, $2) RETURNING *; -- name: GetAccounts :many -SELECT accounts.id, accounts.name, SUM(transactions.amount) as balance FROM accounts +SELECT accounts.* FROM accounts +WHERE accounts.budget_id = $1; + +-- name: GetAccountsWithBalance :many +SELECT accounts.id, accounts.name, SUM(transactions.amount)::decimal(12,2) as balance +FROM accounts +LEFT JOIN transactions ON transactions.account_id = accounts.id WHERE accounts.budget_id = $1 AND transactions.date < NOW() GROUP BY accounts.id, accounts.name; \ No newline at end of file diff --git a/postgres/transactions.sql.go b/postgres/transactions.sql.go index 26e49d3..353ba0a 100644 --- a/postgres/transactions.sql.go +++ b/postgres/transactions.sql.go @@ -8,7 +8,6 @@ import ( "time" "github.com/google/uuid" - "github.com/jackc/pgtype" ) const createTransaction = `-- name: CreateTransaction :one @@ -21,13 +20,13 @@ RETURNING id, date, memo, amount, account_id, payee_id type CreateTransactionParams struct { Date time.Time Memo string - Amount pgtype.Numeric + Amount Numeric AccountID uuid.UUID PayeeID uuid.NullUUID } func (q *Queries) CreateTransaction(ctx context.Context, arg CreateTransactionParams) (Transaction, error) { - row := q.db.QueryRow(ctx, createTransaction, + row := q.db.QueryRowContext(ctx, createTransaction, arg.Date, arg.Memo, arg.Amount, @@ -54,7 +53,7 @@ LIMIT 200 ` func (q *Queries) GetTransactionsForAccount(ctx context.Context, accountID uuid.UUID) ([]Transaction, error) { - rows, err := q.db.Query(ctx, getTransactionsForAccount, accountID) + rows, err := q.db.QueryContext(ctx, getTransactionsForAccount, accountID) if err != nil { return nil, err } @@ -74,6 +73,9 @@ func (q *Queries) GetTransactionsForAccount(ctx context.Context, accountID uuid. } items = append(items, i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } @@ -95,13 +97,13 @@ type GetTransactionsForBudgetRow struct { ID uuid.UUID Date time.Time Memo string - Amount pgtype.Numeric + Amount Numeric Account string Payee string } func (q *Queries) GetTransactionsForBudget(ctx context.Context, budgetID uuid.UUID) ([]GetTransactionsForBudgetRow, error) { - rows, err := q.db.Query(ctx, getTransactionsForBudget, budgetID) + rows, err := q.db.QueryContext(ctx, getTransactionsForBudget, budgetID) if err != nil { return nil, err } @@ -121,6 +123,9 @@ func (q *Queries) GetTransactionsForBudget(ctx context.Context, budgetID uuid.UU } items = append(items, i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } diff --git a/postgres/user_budgets.sql.go b/postgres/user_budgets.sql.go index c1e07ba..4740c0e 100644 --- a/postgres/user_budgets.sql.go +++ b/postgres/user_budgets.sql.go @@ -22,7 +22,7 @@ type LinkBudgetToUserParams struct { } func (q *Queries) LinkBudgetToUser(ctx context.Context, arg LinkBudgetToUserParams) (UserBudget, error) { - row := q.db.QueryRow(ctx, linkBudgetToUser, arg.UserID, arg.BudgetID) + row := q.db.QueryRowContext(ctx, linkBudgetToUser, arg.UserID, arg.BudgetID) var i UserBudget err := row.Scan(&i.UserID, &i.BudgetID) return i, err diff --git a/postgres/users.sql.go b/postgres/users.sql.go index aaf7a4c..75fb84a 100644 --- a/postgres/users.sql.go +++ b/postgres/users.sql.go @@ -23,7 +23,7 @@ type CreateUserParams struct { } func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { - row := q.db.QueryRow(ctx, createUser, arg.Email, arg.Name, arg.Password) + row := q.db.QueryRowContext(ctx, createUser, arg.Email, arg.Name, arg.Password) var i User err := row.Scan( &i.ID, @@ -40,7 +40,7 @@ WHERE id = $1 ` func (q *Queries) GetUser(ctx context.Context, id uuid.UUID) (User, error) { - row := q.db.QueryRow(ctx, getUser, id) + row := q.db.QueryRowContext(ctx, getUser, id) var i User err := row.Scan( &i.ID, @@ -57,7 +57,7 @@ WHERE email = $1 ` func (q *Queries) GetUserByUsername(ctx context.Context, email string) (User, error) { - row := q.db.QueryRow(ctx, getUserByUsername, email) + row := q.db.QueryRowContext(ctx, getUserByUsername, email) var i User err := row.Scan( &i.ID, diff --git a/sqlc.yaml b/sqlc.yaml index 4cb38c9..ed524bf 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -5,4 +5,9 @@ packages: engine: "postgresql" schema: "postgres/schema/" queries: "postgres/queries/" - sql_package: "pgx/v4" \ No newline at end of file +overrides: + - go_type: "git.javil.eu/jacob1123/budgeteer/postgres.Numeric" + db_type: "pg_catalog.numeric" + - go_type: "git.javil.eu/jacob1123/budgeteer/postgres.Numeric" + db_type: "pg_catalog.numeric" + nullable: true \ No newline at end of file diff --git a/web/accounts.html b/web/accounts.html index 7ff29da..830ec81 100644 --- a/web/accounts.html +++ b/web/accounts.html @@ -11,7 +11,7 @@ {{range .Accounts}}
{{.Name}} - {{.Balance}} + {{printf "%.2f" .GetBalance}}
{{end}} {{end}} \ No newline at end of file