Improve handling of context

This commit is contained in:
2021-12-07 19:08:53 +00:00
parent fbd283cd1c
commit 1d4bc158a8
8 changed files with 21 additions and 40 deletions

View File

@ -1,7 +1,6 @@
package http package http
import ( import (
"context"
"net/http" "net/http"
"git.javil.eu/jacob1123/budgeteer/postgres" "git.javil.eu/jacob1123/budgeteer/postgres"
@ -23,7 +22,7 @@ func (h *Handler) getImportantData(c *gin.Context) {
return return
} }
budget, err := h.Service.DB.GetBudget(context.Background(), budgetUUID) budget, err := h.Service.DB.GetBudget(c.Request.Context(), budgetUUID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusNotFound, err) c.AbortWithError(http.StatusNotFound, err)
return return

View File

@ -1,7 +1,6 @@
package http package http
import ( import (
"context"
"net/http" "net/http"
"git.javil.eu/jacob1123/budgeteer" "git.javil.eu/jacob1123/budgeteer"
@ -23,7 +22,7 @@ func (h *Handler) budget(c *gin.Context) {
return return
} }
transactions, err := h.Service.DB.GetTransactionsForBudget(context.Background(), budgetUUID) transactions, err := h.Service.DB.GetTransactionsForBudget(c.Request.Context(), budgetUUID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
return return
@ -45,7 +44,7 @@ func (h *Handler) newBudget(c *gin.Context) {
} }
userID := c.MustGet("token").(budgeteer.Token).GetID() userID := c.MustGet("token").(budgeteer.Token).GetID()
_, err := h.Service.NewBudget(budgetName, userID) _, err := h.Service.NewBudget(c.Request.Context(), budgetName, userID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
return return

View File

@ -1,7 +1,6 @@
package http package http
import ( import (
"context"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
@ -59,7 +58,7 @@ func (h *Handler) budgeting(c *gin.Context) {
FromDate: firstOfMonth, FromDate: firstOfMonth,
ToDate: firstOfNextMonth, ToDate: firstOfNextMonth,
} }
categories, err := h.Service.DB.GetCategoriesWithBalance(context.Background(), params) categories, err := h.Service.DB.GetCategoriesWithBalance(c.Request.Context(), params)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
return return

View File

@ -10,7 +10,7 @@ import (
func (h *Handler) dashboard(c *gin.Context) { func (h *Handler) dashboard(c *gin.Context) {
userID := c.MustGet("token").(budgeteer.Token).GetID() userID := c.MustGet("token").(budgeteer.Token).GetID()
budgets, err := h.Service.BudgetsForUser(userID) budgets, err := h.Service.DB.GetBudgetsForUser(c.Request.Context(), userID)
if err != nil { if err != nil {
return return
} }

View File

@ -68,7 +68,7 @@ func (h *Handler) loginPost(c *gin.Context) {
username, _ := c.GetPostForm("username") username, _ := c.GetPostForm("username")
password, _ := c.GetPostForm("password") password, _ := c.GetPostForm("password")
user, err := h.Service.DB.GetUserByUsername(context.Background(), username) user, err := h.Service.DB.GetUserByUsername(c.Request.Context(), username)
if err != nil { if err != nil {
c.AbortWithError(http.StatusUnauthorized, err) c.AbortWithError(http.StatusUnauthorized, err)
return return
@ -84,7 +84,8 @@ func (h *Handler) loginPost(c *gin.Context) {
c.AbortWithError(http.StatusUnauthorized, err) c.AbortWithError(http.StatusUnauthorized, err)
} }
_, _ = h.Service.DB.UpdateLastLogin(context.Background(), user.ID) go h.Service.DB.UpdateLastLogin(context.Background(), user.ID)
maxAge := (int)((expiration * time.Hour).Seconds()) maxAge := (int)((expiration * time.Hour).Seconds())
c.SetCookie(authCookie, t, maxAge, "", "", false, true) c.SetCookie(authCookie, t, maxAge, "", "", false, true)
c.JSON(http.StatusOK, map[string]string{ c.JSON(http.StatusOK, map[string]string{
@ -97,7 +98,7 @@ func (h *Handler) registerPost(c *gin.Context) {
password, _ := c.GetPostForm("password") password, _ := c.GetPostForm("password")
name, _ := c.GetPostForm("name") name, _ := c.GetPostForm("name")
_, err := h.Service.DB.GetUserByUsername(context.Background(), email) _, err := h.Service.DB.GetUserByUsername(c.Request.Context(), email)
if err == nil { if err == nil {
c.AbortWithStatus(http.StatusUnauthorized) c.AbortWithStatus(http.StatusUnauthorized)
return return
@ -114,7 +115,7 @@ func (h *Handler) registerPost(c *gin.Context) {
Password: hash, Password: hash,
Email: email, Email: email,
} }
_, err = h.Service.DB.CreateUser(context.Background(), createUser) _, err = h.Service.DB.CreateUser(c.Request.Context(), createUser)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
} }

View File

@ -22,7 +22,7 @@ func (h *Handler) importYNAB(c *gin.Context) {
return return
} }
ynab, err := postgres.NewYNABImport(h.Service.DB, budgetUUID) ynab, err := postgres.NewYNABImport(c.Request.Context(), h.Service.DB, budgetUUID)
if err != nil { if err != nil {
c.AbortWithError(http.StatusInternalServerError, err) c.AbortWithError(http.StatusInternalServerError, err)
return return

View File

@ -6,32 +6,15 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
// Budget returns a budget for a given id. // NewBudget creates a budget and adds it to the current user
func (s *Repository) Budget(id uuid.UUID) (*Budget, error) { func (s *Repository) NewBudget(context context.Context, name string, userID uuid.UUID) (*Budget, error) {
budget, err := s.DB.GetBudget(context.Background(), id) budget, err := s.DB.CreateBudget(context, name)
if err != nil {
return nil, err
}
return &budget, nil
}
func (s *Repository) BudgetsForUser(id uuid.UUID) ([]Budget, error) {
budgets, err := s.DB.GetBudgetsForUser(context.Background(), id)
if err != nil {
return nil, err
}
return budgets, nil
}
func (s *Repository) NewBudget(name string, userID uuid.UUID) (*Budget, error) {
budget, err := s.DB.CreateBudget(context.Background(), name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ub := LinkBudgetToUserParams{UserID: userID, BudgetID: budget.ID} ub := LinkBudgetToUserParams{UserID: userID, BudgetID: budget.ID}
_, err = s.DB.LinkBudgetToUser(context.Background(), ub) _, err = s.DB.LinkBudgetToUser(context, ub)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -22,29 +22,29 @@ type YNABImport struct {
budgetID uuid.UUID budgetID uuid.UUID
} }
func NewYNABImport(q *Queries, budgetID uuid.UUID) (*YNABImport, error) { func NewYNABImport(context context.Context, q *Queries, budgetID uuid.UUID) (*YNABImport, error) {
accounts, err := q.GetAccounts(context.Background(), budgetID) accounts, err := q.GetAccounts(context, budgetID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
payees, err := q.GetPayees(context.Background(), budgetID) payees, err := q.GetPayees(context, budgetID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
categories, err := q.GetCategories(context.Background(), budgetID) categories, err := q.GetCategories(context, budgetID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
categoryGroups, err := q.GetCategoryGroups(context.Background(), budgetID) categoryGroups, err := q.GetCategoryGroups(context, budgetID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &YNABImport{ return &YNABImport{
Context: context.Background(), Context: context,
accounts: accounts, accounts: accounts,
payees: payees, payees: payees,
categories: categories, categories: categories,