package postgres

import (
	"fmt"
	"math/big"

	"github.com/jackc/pgtype"
)

// numericBase is the radix used both for rendering the Int coefficient and
// for the power-of-ten scaling expressed by Exp.
const numericBase = 10

// Compile-time check that Numeric keeps satisfying fmt.Stringer.
var _ fmt.Stringer = Numeric{}

// Numeric embeds pgtype.Numeric and adds arithmetic, predicate and
// formatting helpers. The helpers interpret the value as Int * 10^Exp.
type Numeric struct {
	pgtype.Numeric
}

// NewZeroNumeric returns a present, non-NaN Numeric whose value is zero.
func NewZeroNumeric() Numeric {
	return Numeric{pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present, NaN: false}}
}

// GetFloat64 converts n to a float64 through pgtype's AssignTo.
//
// It returns 0 when n is not Present and panics when AssignTo reports an
// error.
func (n Numeric) GetFloat64() float64 {
	if n.Status != pgtype.Present {
		return 0
	}

	var f float64
	if err := n.AssignTo(&f); err != nil {
		panic(err)
	}

	return f
}

// IsPositive reports whether n is greater than or equal to zero. Values that
// are not Present count as positive.
func (n Numeric) IsPositive() bool {
	if n.Status != pgtype.Present {
		return true
	}

	// Prefer the exact sign of the coefficient: a float64 round trip loses
	// precision and panics in GetFloat64 when the conversion fails.
	if n.hasCoefficient() {
		return n.Int.Sign() >= 0
	}

	return n.GetFloat64() >= 0
}

// IsZero reports whether n equals zero. Values that are not Present count as
// zero.
func (n Numeric) IsZero() bool {
	if n.Status != pgtype.Present {
		return true
	}

	// Exact check: going through float64 would report tiny non-zero values
	// as zero once they underflow.
	if n.hasCoefficient() {
		return n.Int.Sign() == 0
	}

	return n.GetFloat64() == 0
}

// MatchExp returns a copy of n rescaled so that its exponent equals exp.
//
// Lowering the exponent multiplies the coefficient by the matching power of
// ten and is exact. Raising the exponent divides the coefficient and
// truncates toward zero, so digits below the new exponent are dropped.
// Status and NaN are copied from n; a nil coefficient stays nil.
func (n Numeric) MatchExp(exp int32) Numeric {
	out := Numeric{pgtype.Numeric{Exp: exp, Status: n.Status, NaN: n.NaN}}
	if n.Int == nil {
		return out
	}

	// Widen before subtracting so extreme exponents cannot overflow int32.
	shift := int64(n.Exp) - int64(exp)

	switch {
	case shift > 0:
		out.Int = new(big.Int).Mul(n.Int, numericPow10(shift))
	case shift < 0:
		// big.Int.Exp yields 1 for negative powers, so the downscale has to
		// be an explicit division.
		out.Int = new(big.Int).Quo(n.Int, numericPow10(-shift))
	default:
		out.Int = new(big.Int).Set(n.Int)
	}

	return out
}

// Sub returns n - other as a Present Numeric. See Add for how exponents, nil
// coefficients and NaN are handled.
func (n Numeric) Sub(other Numeric) Numeric {
	return n.combine(other, (*big.Int).Sub)
}

// Add returns n + other as a Present Numeric.
//
// The operand with the larger exponent is rescaled to the smaller one, so no
// precision is lost. An operand with a nil coefficient is treated as zero,
// and the result is NaN when either operand is NaN.
func (n Numeric) Add(other Numeric) Numeric {
	return n.combine(other, (*big.Int).Add)
}

// String renders n in plain decimal notation without an exponent, for
// example "150", "-0.05" or "1.50". A nil or zero coefficient renders as "0".
func (n Numeric) String() string {
	return string(n.appendDecimal(nil))
}

// MarshalJSON encodes n as a JSON number using the same text as String. The
// returned error is always nil.
func (n Numeric) MarshalJSON() ([]byte, error) {
	return n.appendDecimal(nil), nil
}

// hasCoefficient reports whether n has a coefficient that can be inspected
// directly, i.e. Int is set and n is not NaN.
func (n Numeric) hasCoefficient() bool {
	return n.Int != nil && !n.NaN
}

// orZero returns n unchanged when it has a coefficient, and otherwise a
// Present zero that keeps n's exponent.
func (n Numeric) orZero() Numeric {
	if n.Int != nil {
		return n
	}

	return Numeric{pgtype.Numeric{Int: new(big.Int), Exp: n.Exp, Status: pgtype.Present}}
}

// combine brings both operands to a common exponent and applies op to their
// coefficients, storing the result in a freshly allocated big.Int.
func (n Numeric) combine(other Numeric, op func(z, x, y *big.Int) *big.Int) Numeric {
	if n.NaN || other.NaN {
		return Numeric{pgtype.Numeric{Status: pgtype.Present, NaN: true}}
	}

	left, right := n.orZero(), other.orZero()

	// Always rescale toward the smaller exponent: that direction only
	// multiplies the coefficient and is therefore exact.
	switch {
	case left.Exp < right.Exp:
		right = right.MatchExp(left.Exp)
	case left.Exp > right.Exp:
		left = left.MatchExp(right.Exp)
	}

	return Numeric{pgtype.Numeric{
		Exp:    left.Exp,
		Int:    op(new(big.Int), left.Int, right.Int),
		Status: pgtype.Present,
	}}
}

// appendDecimal appends the plain decimal rendering of n to dst and returns
// the extended slice.
func (n Numeric) appendDecimal(dst []byte) []byte {
	// Sign is exact for any magnitude, unlike Int64, which is undefined for
	// coefficients that do not fit in 64 bits.
	if n.Int == nil || n.Int.Sign() == 0 {
		return append(dst, '0')
	}

	if n.Int.Sign() < 0 {
		dst = append(dst, '-')
	}

	// Work on the magnitude so the sign never takes part in digit counting.
	digits := new(big.Int).Abs(n.Int).Text(numericBase)

	if n.Exp >= 0 {
		dst = append(dst, digits...)
		for i := int32(0); i < n.Exp; i++ {
			dst = append(dst, '0')
		}

		return dst
	}

	// scale is the number of digits after the decimal point; intLen is how
	// many of the coefficient's digits fall before it.
	scale := -int64(n.Exp)
	intLen := int64(len(digits)) - scale

	if intLen <= 0 {
		// The value is below one in magnitude: emit "0." and pad with
		// leading zeros until the coefficient fills the remaining places.
		dst = append(dst, '0', '.')
		for ; intLen < 0; intLen++ {
			dst = append(dst, '0')
		}

		return append(dst, digits...)
	}

	dst = append(dst, digits[:intLen]...)
	dst = append(dst, '.')

	return append(dst, digits[intLen:]...)
}

// numericPow10 returns 10^power; power must not be negative.
func numericPow10(power int64) *big.Int {
	return new(big.Int).Exp(big.NewInt(numericBase), big.NewInt(power), nil)
}