From ea6d198bff1ef206ce8ba22e4c72b3663e9e1cd3 Mon Sep 17 00:00:00 2001 From: Jan Bader Date: Wed, 23 Feb 2022 21:52:36 +0000 Subject: [PATCH] Add some unit-tests for numeric --- postgres/numeric/numeric.go | 32 ++++++++++++-- postgres/numeric/numeric_test.go | 75 ++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 3 deletions(-) create mode 100644 postgres/numeric/numeric_test.go diff --git a/postgres/numeric/numeric.go b/postgres/numeric/numeric.go index 83dbd3b..149c615 100644 --- a/postgres/numeric/numeric.go +++ b/postgres/numeric/numeric.go @@ -17,6 +17,19 @@ func Zero() Numeric { return Numeric{pgtype.Numeric{Exp: 0, Int: big.NewInt(0), Status: pgtype.Present, NaN: false}} } +func FromInt64(value int64) Numeric { + num := Numeric{} + num.Set(value) + return num +} + +func FromInt64WithExp(value int64, exp int32) Numeric { + num := Numeric{} + num.Set(value) + num.Exp = exp + return num +} + func (n Numeric) GetFloat64() float64 { if n.Status != pgtype.Present { return 0 @@ -168,10 +181,16 @@ func (n Numeric) MarshalJSON() ([]byte, error) { return bytesWithSeparator, nil } -func Parse(text string) (Numeric, error) { - // Remove trailing currency - text = trimLastChar(text) +func MustParse(text string) Numeric { + num, err := Parse(text) + if err != nil { + panic(err) + } + return num +} + +func Parse(text string) (Numeric, error) { // Unify decimal separator text = strings.Replace(text, ",", ".", 1) @@ -184,6 +203,13 @@ func Parse(text string) (Numeric, error) { return num, nil } +func ParseCurrency(text string) (Numeric, error) { + // Remove trailing currency + text = trimLastChar(text) + + return Parse(text) +} + func trimLastChar(s string) string { r, size := utf8.DecodeLastRuneInString(s) if r == utf8.RuneError && (size == 0 || size == 1) { diff --git a/postgres/numeric/numeric_test.go b/postgres/numeric/numeric_test.go new file mode 100644 index 0000000..50329f0 --- /dev/null +++ b/postgres/numeric/numeric_test.go @@ -0,0 +1,75 @@ +package numeric_test + +import ( + "testing" + + "git.javil.eu/jacob1123/budgeteer/postgres/numeric" +) + +type TestCaseMarshalJSON struct { + Value numeric.Numeric + Result string +} + +func TestMarshalJSON(t *testing.T) { + tests := []TestCaseMarshalJSON{ + {numeric.Zero(), `0`}, + {numeric.MustParse("1.23"), "1.23"}, + {numeric.MustParse("1,24"), "1.24"}, + {numeric.MustParse("123456789.12345"), "123456789.12345"}, + } + for _, test := range tests { + t.Run(test.Result, func(t *testing.T) { + z := test.Value + result, err := z.MarshalJSON() + if err != nil { + t.Error(err) + return + } + + if string(result) != test.Result { + t.Errorf("Expected %s, got %s", test.Result, string(result)) + return + } + }) + } +} + +type TestCaseParse struct { + Result numeric.Numeric + Value string +} + +func TestParse(t *testing.T) { + tests := []TestCaseParse{ + {numeric.Zero(), `0`}, + {numeric.FromInt64(1), `1`}, + {numeric.FromInt64WithExp(1, 1), `10`}, + {numeric.FromInt64WithExp(1, 2), `100`}, + {numeric.MustParse("1.23"), "1.23"}, + {numeric.MustParse("1,24"), "1.24"}, + {numeric.MustParse("123456789.12345"), "123456789.12345"}, + } + for _, test := range tests { + t.Run(test.Value, func(t *testing.T) { + result, err := numeric.Parse(test.Value) + if err != nil { + t.Error(err) + return + } + + if test.Result.Int.Int64() != result.Int.Int64() { + t.Errorf("Expected int %d, got %d", test.Result.Int, result.Int) + return + } + + if test.Result.Exp != result.Exp { + t.Errorf("Expected exp %d, got %d", test.Result.Exp, result.Exp) + return + } + // if string(result) != test.Result { + // return + //} + }) + } +}