266 lines
5.4 KiB
Go
266 lines
5.4 KiB
Go
package numeric
|
|
|
|
import (
|
|
"fmt"
|
|
"math/big"
|
|
"strings"
|
|
"unicode/utf8"
|
|
|
|
"github.com/jackc/pgtype"
|
|
)
|
|
|
|
type Numeric struct {
|
|
pgtype.Numeric
|
|
}
|
|
|
|
func Zero() Numeric {
|
|
return Numeric{pgtype.Numeric{Exp: 0, Int: big.NewInt(0), Status: pgtype.Present, NaN: false}}
|
|
}
|
|
|
|
func FromInt64(value int64) Numeric {
|
|
return Numeric{Numeric: pgtype.Numeric{Int: big.NewInt(value), Status: pgtype.Present}}
|
|
}
|
|
|
|
func FromInt64WithExp(value int64, exp int32) Numeric {
|
|
return Numeric{Numeric: pgtype.Numeric{Int: big.NewInt(value), Exp: exp, Status: pgtype.Present}}
|
|
}
|
|
|
|
func (n Numeric) GetFloat64() float64 {
|
|
if n.Status != pgtype.Present {
|
|
return 0
|
|
}
|
|
var balance float64
|
|
err := n.AssignTo(&balance)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return balance
|
|
}
|
|
|
|
func (n Numeric) IsPositive() bool {
|
|
if n.Status != pgtype.Present {
|
|
return true
|
|
}
|
|
float := n.GetFloat64()
|
|
return float >= 0
|
|
}
|
|
|
|
func (n Numeric) IsZero() bool {
|
|
if n.Status != pgtype.Present {
|
|
return true
|
|
}
|
|
float := n.GetFloat64()
|
|
return float == 0
|
|
}
|
|
|
|
func (n *Numeric) MatchExpI(exp int32) {
|
|
diffExp := n.Exp - exp
|
|
factor := big.NewInt(0).Exp(big.NewInt(10), big.NewInt(int64(diffExp)), nil) //nolint:gomnd
|
|
n.Exp = exp
|
|
n.Int = big.NewInt(0).Mul(n.Int, factor)
|
|
}
|
|
|
|
func (n Numeric) MatchExp(exp int32) Numeric {
|
|
diffExp := n.Exp - exp
|
|
factor := big.NewInt(0).Exp(big.NewInt(10), big.NewInt(int64(diffExp)), nil) //nolint:gomnd
|
|
return Numeric{pgtype.Numeric{
|
|
Exp: exp,
|
|
Int: big.NewInt(0).Mul(n.Int, factor),
|
|
Status: n.Status,
|
|
NaN: n.NaN,
|
|
}}
|
|
}
|
|
|
|
func (n *Numeric) SubI(other Numeric) *Numeric {
|
|
right := other
|
|
if n.Exp < other.Exp {
|
|
right = other.MatchExp(n.Exp)
|
|
} else if n.Exp > other.Exp {
|
|
n.MatchExpI(other.Exp)
|
|
}
|
|
|
|
if n.Exp == right.Exp {
|
|
n.Int = big.NewInt(0).Sub(n.Int, right.Int)
|
|
return n
|
|
}
|
|
|
|
panic("Cannot subtract with different exponents")
|
|
}
|
|
|
|
func (n Numeric) Sub(other Numeric) Numeric {
|
|
left := n
|
|
right := other
|
|
if n.Exp < other.Exp {
|
|
right = other.MatchExp(n.Exp)
|
|
} else if n.Exp > other.Exp {
|
|
left = n.MatchExp(other.Exp)
|
|
}
|
|
|
|
if left.Exp == right.Exp {
|
|
return Numeric{pgtype.Numeric{
|
|
Exp: left.Exp,
|
|
Int: big.NewInt(0).Sub(left.Int, right.Int),
|
|
}}
|
|
}
|
|
|
|
panic("Cannot subtract with different exponents")
|
|
}
|
|
|
|
func (n Numeric) Neg() Numeric {
|
|
return Numeric{pgtype.Numeric{Exp: n.Exp, Int: big.NewInt(-1 * n.Int.Int64()), Status: n.Status}}
|
|
}
|
|
|
|
func (n Numeric) Add(other Numeric) Numeric {
|
|
left := n
|
|
right := other
|
|
if n.Exp < other.Exp {
|
|
right = other.MatchExp(n.Exp)
|
|
} else if n.Exp > other.Exp {
|
|
left = n.MatchExp(other.Exp)
|
|
}
|
|
|
|
if left.Exp == right.Exp {
|
|
return Numeric{pgtype.Numeric{
|
|
Exp: left.Exp,
|
|
Int: big.NewInt(0).Add(left.Int, right.Int),
|
|
}}
|
|
}
|
|
|
|
panic("Cannot add with different exponents")
|
|
}
|
|
|
|
func (n *Numeric) AddI(other Numeric) *Numeric {
|
|
right := other
|
|
if n.Exp < other.Exp {
|
|
right = other.MatchExp(n.Exp)
|
|
} else if n.Exp > other.Exp {
|
|
n.MatchExpI(other.Exp)
|
|
}
|
|
|
|
if n.Exp == right.Exp {
|
|
n.Int = big.NewInt(0).Add(n.Int, right.Int)
|
|
return n
|
|
}
|
|
|
|
panic("Cannot add with different exponents")
|
|
}
|
|
|
|
func (n Numeric) String() string {
|
|
if n.Int == nil || n.Int.Int64() == 0 {
|
|
return "0"
|
|
}
|
|
|
|
s := fmt.Sprintf("%d", n.Int)
|
|
bytes := []byte(s)
|
|
|
|
exp := n.Exp
|
|
for exp > 0 {
|
|
bytes = append(bytes, byte('0'))
|
|
exp--
|
|
}
|
|
|
|
if exp == 0 {
|
|
return string(bytes)
|
|
}
|
|
|
|
length := int32(len(bytes))
|
|
var bytesWithSeparator []byte
|
|
|
|
exp = -exp
|
|
for length <= exp {
|
|
if n.Int.Int64() < 0 {
|
|
bytes = append([]byte{bytes[0], byte('0')}, bytes[1:]...)
|
|
} else {
|
|
bytes = append([]byte{byte('0')}, bytes...)
|
|
}
|
|
length++
|
|
}
|
|
|
|
split := length - exp
|
|
bytesWithSeparator = append(bytesWithSeparator, bytes[:split]...)
|
|
if split == 1 && n.Int.Int64() < 0 {
|
|
bytesWithSeparator = append(bytesWithSeparator, byte('0'))
|
|
}
|
|
bytesWithSeparator = append(bytesWithSeparator, byte('.'))
|
|
bytesWithSeparator = append(bytesWithSeparator, bytes[split:]...)
|
|
return string(bytesWithSeparator)
|
|
}
|
|
|
|
func (n Numeric) MarshalJSON() ([]byte, error) {
|
|
if n.Int == nil || n.Int.Int64() == 0 {
|
|
return []byte("0"), nil
|
|
}
|
|
|
|
s := fmt.Sprintf("%d", n.Int)
|
|
bytes := []byte(s)
|
|
|
|
exp := n.Exp
|
|
for exp > 0 {
|
|
bytes = append(bytes, byte('0'))
|
|
exp--
|
|
}
|
|
|
|
if exp == 0 {
|
|
return bytes, nil
|
|
}
|
|
|
|
length := int32(len(bytes))
|
|
var bytesWithSeparator []byte
|
|
|
|
exp = -exp
|
|
for length <= exp {
|
|
if n.Int.Int64() < 0 {
|
|
bytes = append([]byte{bytes[0], byte('0')}, bytes[1:]...)
|
|
} else {
|
|
bytes = append([]byte{byte('0')}, bytes...)
|
|
}
|
|
length++
|
|
}
|
|
|
|
split := length - exp
|
|
bytesWithSeparator = append(bytesWithSeparator, bytes[:split]...)
|
|
if split == 1 && n.Int.Int64() < 0 {
|
|
bytesWithSeparator = append(bytesWithSeparator, byte('0'))
|
|
}
|
|
bytesWithSeparator = append(bytesWithSeparator, byte('.'))
|
|
bytesWithSeparator = append(bytesWithSeparator, bytes[split:]...)
|
|
return bytesWithSeparator, nil
|
|
}
|
|
|
|
func MustParse(text string) Numeric {
|
|
num, err := Parse(text)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return num
|
|
}
|
|
|
|
func Parse(text string) (Numeric, error) {
|
|
// Unify decimal separator
|
|
text = strings.Replace(text, ",", ".", 1)
|
|
|
|
num := Numeric{}
|
|
err := num.Set(text)
|
|
if err != nil {
|
|
return num, fmt.Errorf("parse numeric %s: %w", text, err)
|
|
}
|
|
|
|
return num, nil
|
|
}
|
|
|
|
func ParseCurrency(text string) (Numeric, error) {
|
|
// Remove trailing currency
|
|
text = trimLastChar(text)
|
|
|
|
return Parse(text)
|
|
}
|
|
|
|
// trimLastChar drops the final UTF-8 encoded rune of s.
// An empty string, or one ending in an invalid UTF-8 byte, is returned as is.
// A correctly encoded U+FFFD is still removed because its width is 3 bytes.
func trimLastChar(s string) string {
	last, width := utf8.DecodeLastRuneInString(s)
	if last == utf8.RuneError && width <= 1 {
		return s
	}
	return s[:len(s)-width]
}
|