diff --git a/postgres/ynab-import.go b/postgres/ynab-import.go index 82a5025..4879026 100644 --- a/postgres/ynab-import.go +++ b/postgres/ynab-import.go @@ -181,13 +181,14 @@ func (ynab *YNABImport) ImportTransactions(context context.Context, r io.Reader) } payeeName := record[3] - // Transaction is a transfer to - var shouldReturn bool - var returnValue error - openTransfers, shouldReturn, returnValue = ynab.ImportTransaction( - payeeName, context, transaction, accountName, openTransfers, account, amount) - if shouldReturn { - return returnValue + // Transaction is a transfer + if strings.HasPrefix(payeeName, "Transfer : ") { + err = ynab.ImportTransferTransaction(payeeName, context, transaction, accountName, &openTransfers, account, amount) + } else { + err = ynab.ImportRegularTransaction(context, payeeName, transaction) + } + if err != nil { + return err } count++ @@ -207,40 +208,25 @@ func (ynab *YNABImport) ImportTransactions(context context.Context, r io.Reader) return nil } -func (ynab *YNABImport) ImportTransaction(payeeName string, context context.Context, transaction CreateTransactionParams, accountName string, openTransfers []Transfer, account *Account, amount Numeric) ([]Transfer, bool, error) { - if strings.HasPrefix(payeeName, "Transfer : ") { - openTransfers, shouldReturn, returnValue, returnValue1, returnValue2 := ynab.ImportTransferTransaction(payeeName, context, transaction, accountName, openTransfers, account, amount) - if shouldReturn { - return returnValue, returnValue1, returnValue2 - } - } else { - shouldReturn, returnValue, returnValue1, returnValue2 := ynab.ImportRegularTransaction(context, payeeName, transaction) - if shouldReturn { - return returnValue, returnValue1, returnValue2 - } - } - return openTransfers, false, nil -} - -func (ynab *YNABImport) ImportRegularTransaction(context context.Context, payeeName string, transaction CreateTransactionParams) (bool, []Transfer, bool, error) { +func (ynab *YNABImport) ImportRegularTransaction(context context.Context, payeeName string, transaction CreateTransactionParams) error { payeeID, err := ynab.GetPayee(context, payeeName) if err != nil { - return true, nil, true, fmt.Errorf("get payee %s: %w", payeeName, err) + return fmt.Errorf("get payee %s: %w", payeeName, err) } transaction.PayeeID = payeeID _, err = ynab.queries.CreateTransaction(context, transaction) if err != nil { - return true, nil, true, fmt.Errorf("save transaction %v: %w", transaction, err) + return fmt.Errorf("save transaction %v: %w", transaction, err) } - return false, nil, false, nil + return nil } -func (ynab *YNABImport) ImportTransferTransaction(payeeName string, context context.Context, transaction CreateTransactionParams, accountName string, openTransfers []Transfer, account *Account, amount Numeric) ([]Transfer, bool, []Transfer, bool, error) { +func (ynab *YNABImport) ImportTransferTransaction(payeeName string, context context.Context, transaction CreateTransactionParams, accountName string, openTransfers *[]Transfer, account *Account, amount Numeric) error { transferToAccountName := payeeName[11:] transferToAccount, err := ynab.GetAccount(context, transferToAccountName) if err != nil { - return nil, true, nil, true, fmt.Errorf("get transfer account %s: %w", transferToAccountName, err) + return fmt.Errorf("get transfer account %s: %w", transferToAccountName, err) } transfer := Transfer{ @@ -251,7 +237,7 @@ func (ynab *YNABImport) ImportTransferTransaction(payeeName string, context cont } found := false - for i, openTransfer := range openTransfers { + for i, openTransfer := range *openTransfers { if openTransfer.TransferToAccount.ID != transfer.AccountID { continue } @@ -263,8 +249,9 @@ func (ynab *YNABImport) ImportTransferTransaction(payeeName string, context cont } fmt.Printf("Matched transfers from %s to %s over %f\n", account.Name, transferToAccount.Name, amount.GetFloat64()) - openTransfers[i] = openTransfers[len(openTransfers)-1] - openTransfers = openTransfers[:len(openTransfers)-1] + transfers := *openTransfers + transfers[i] = transfers[len(transfers)-1] + *openTransfers = transfers[:len(transfers)-1] found = true groupID := uuid.New() @@ -273,19 +260,19 @@ func (ynab *YNABImport) ImportTransferTransaction(payeeName string, context cont _, err = ynab.queries.CreateTransaction(context, transfer.CreateTransactionParams) if err != nil { - return nil, true, nil, true, fmt.Errorf("save transaction %v: %w", transfer.CreateTransactionParams, err) + return fmt.Errorf("save transaction %v: %w", transfer.CreateTransactionParams, err) } _, err = ynab.queries.CreateTransaction(context, openTransfer.CreateTransactionParams) if err != nil { - return nil, true, nil, true, fmt.Errorf("save transaction %v: %w", openTransfer.CreateTransactionParams, err) + return fmt.Errorf("save transaction %v: %w", openTransfer.CreateTransactionParams, err) } break } if !found { - openTransfers = append(openTransfers, transfer) + *openTransfers = append(*openTransfers, transfer) } - return openTransfers, false, nil, false, nil + return nil } func trimLastChar(s string) string { @@ -366,31 +353,22 @@ func (ynab *YNABImport) GetCategory(context context.Context, group string, name } } - for _, categoryGroup := range ynab.categoryGroups { - if categoryGroup.Name == group { - createCategory := CreateCategoryParams{Name: name, CategoryGroupID: categoryGroup.ID} - category, err := ynab.queries.CreateCategory(context, createCategory) - if err != nil { - return uuid.NullUUID{}, err - } - - getCategory := GetCategoriesRow{ - ID: category.ID, - CategoryGroupID: category.CategoryGroupID, - Name: category.Name, - Group: categoryGroup.Name, - } - ynab.categories = append(ynab.categories, getCategory) - return uuid.NullUUID{UUID: category.ID, Valid: true}, nil + var categoryGroup *CategoryGroup + for _, existingGroup := range ynab.categoryGroups { + if existingGroup.Name == group { + categoryGroup = &existingGroup } } - newGroup := CreateCategoryGroupParams{Name: group, BudgetID: ynab.budgetID} - categoryGroup, err := ynab.queries.CreateCategoryGroup(context, newGroup) - if err != nil { - return uuid.NullUUID{}, err + if categoryGroup == nil { + newGroup := CreateCategoryGroupParams{Name: group, BudgetID: ynab.budgetID} + newCategoryGroup, err := ynab.queries.CreateCategoryGroup(context, newGroup) + if err != nil { + return uuid.NullUUID{}, err + } + ynab.categoryGroups = append(ynab.categoryGroups, newCategoryGroup) + categoryGroup = &newCategoryGroup } - ynab.categoryGroups = append(ynab.categoryGroups, categoryGroup) newCategory := CreateCategoryParams{Name: name, CategoryGroupID: categoryGroup.ID} category, err := ynab.queries.CreateCategory(context, newCategory) diff --git a/server/account_test.go b/server/account_test.go index 8ea523a..cd4efcb 100644 --- a/server/account_test.go +++ b/server/account_test.go @@ -38,7 +38,10 @@ func TestListTimezonesHandler(t *testing.T) { t.Run("RegisterUser", func(t *testing.T) { t.Parallel() - context.Request, err = http.NewRequest(http.MethodPost, "/api/v1/user/register", strings.NewReader(`{"password":"pass","email":"info@example.com","name":"Test"}`)) + context.Request, err = http.NewRequest( + http.MethodPost, + "/api/v1/user/register", + strings.NewReader(`{"password":"pass","email":"info@example.com","name":"Test"}`)) if err != nil { t.Errorf("error creating request: %s", err) return