273 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			273 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package numeric
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"math/big"
 | |
| 	"strings"
 | |
| 	"unicode/utf8"
 | |
| 
 | |
| 	"github.com/jackc/pgtype"
 | |
| )
 | |
| 
 | |
| type Numeric struct {
 | |
| 	pgtype.Numeric
 | |
| }
 | |
| 
 | |
| func Zero() Numeric {
 | |
| 	return Numeric{pgtype.Numeric{Exp: 0, Int: big.NewInt(0), Status: pgtype.Present, NaN: false}}
 | |
| }
 | |
| 
 | |
| func FromInt64(value int64) Numeric {
 | |
| 	return Numeric{Numeric: pgtype.Numeric{Int: big.NewInt(value), Status: pgtype.Present}}
 | |
| }
 | |
| 
 | |
| func FromInt64WithExp(value int64, exp int32) Numeric {
 | |
| 	return Numeric{Numeric: pgtype.Numeric{Int: big.NewInt(value), Exp: exp, Status: pgtype.Present}}
 | |
| }
 | |
| 
 | |
| func (n *Numeric) SetZero() {
 | |
| 	n.Exp = 0
 | |
| 	n.Int = big.NewInt(0)
 | |
| 	n.Status = pgtype.Present
 | |
| 	n.NaN = false
 | |
| }
 | |
| 
 | |
| func (n Numeric) GetFloat64() float64 {
 | |
| 	if n.Status != pgtype.Present {
 | |
| 		return 0
 | |
| 	}
 | |
| 	var balance float64
 | |
| 	err := n.AssignTo(&balance)
 | |
| 	if err != nil {
 | |
| 		panic(err)
 | |
| 	}
 | |
| 	return balance
 | |
| }
 | |
| 
 | |
| func (n Numeric) IsPositive() bool {
 | |
| 	if n.Status != pgtype.Present {
 | |
| 		return true
 | |
| 	}
 | |
| 	float := n.GetFloat64()
 | |
| 	return float >= 0
 | |
| }
 | |
| 
 | |
| func (n Numeric) IsZero() bool {
 | |
| 	if n.Status != pgtype.Present {
 | |
| 		return true
 | |
| 	}
 | |
| 	float := n.GetFloat64()
 | |
| 	return float == 0
 | |
| }
 | |
| 
 | |
| func (n *Numeric) MatchExpI(exp int32) {
 | |
| 	diffExp := n.Exp - exp
 | |
| 	factor := big.NewInt(0).Exp(big.NewInt(10), big.NewInt(int64(diffExp)), nil) //nolint:gomnd
 | |
| 	n.Exp = exp
 | |
| 	n.Int = big.NewInt(0).Mul(n.Int, factor)
 | |
| }
 | |
| 
 | |
| func (n Numeric) MatchExp(exp int32) Numeric {
 | |
| 	diffExp := n.Exp - exp
 | |
| 	factor := big.NewInt(0).Exp(big.NewInt(10), big.NewInt(int64(diffExp)), nil) //nolint:gomnd
 | |
| 	return Numeric{pgtype.Numeric{
 | |
| 		Exp:    exp,
 | |
| 		Int:    big.NewInt(0).Mul(n.Int, factor),
 | |
| 		Status: n.Status,
 | |
| 		NaN:    n.NaN,
 | |
| 	}}
 | |
| }
 | |
| 
 | |
| func (n *Numeric) SubI(other Numeric) *Numeric {
 | |
| 	right := other
 | |
| 	if n.Exp < other.Exp {
 | |
| 		right = other.MatchExp(n.Exp)
 | |
| 	} else if n.Exp > other.Exp {
 | |
| 		n.MatchExpI(other.Exp)
 | |
| 	}
 | |
| 
 | |
| 	if n.Exp == right.Exp {
 | |
| 		n.Int = big.NewInt(0).Sub(n.Int, right.Int)
 | |
| 		return n
 | |
| 	}
 | |
| 
 | |
| 	panic("Cannot subtract with different exponents")
 | |
| }
 | |
| 
 | |
| func (n Numeric) Sub(other Numeric) Numeric {
 | |
| 	left := n
 | |
| 	right := other
 | |
| 	if n.Exp < other.Exp {
 | |
| 		right = other.MatchExp(n.Exp)
 | |
| 	} else if n.Exp > other.Exp {
 | |
| 		left = n.MatchExp(other.Exp)
 | |
| 	}
 | |
| 
 | |
| 	if left.Exp == right.Exp {
 | |
| 		return Numeric{pgtype.Numeric{
 | |
| 			Exp: left.Exp,
 | |
| 			Int: big.NewInt(0).Sub(left.Int, right.Int),
 | |
| 		}}
 | |
| 	}
 | |
| 
 | |
| 	panic("Cannot subtract with different exponents")
 | |
| }
 | |
| 
 | |
| func (n Numeric) Neg() Numeric {
 | |
| 	return Numeric{pgtype.Numeric{Exp: n.Exp, Int: big.NewInt(-1 * n.Int.Int64()), Status: n.Status}}
 | |
| }
 | |
| 
 | |
| func (n Numeric) Add(other Numeric) Numeric {
 | |
| 	left := n
 | |
| 	right := other
 | |
| 	if n.Exp < other.Exp {
 | |
| 		right = other.MatchExp(n.Exp)
 | |
| 	} else if n.Exp > other.Exp {
 | |
| 		left = n.MatchExp(other.Exp)
 | |
| 	}
 | |
| 
 | |
| 	if left.Exp == right.Exp {
 | |
| 		return Numeric{pgtype.Numeric{
 | |
| 			Exp: left.Exp,
 | |
| 			Int: big.NewInt(0).Add(left.Int, right.Int),
 | |
| 		}}
 | |
| 	}
 | |
| 
 | |
| 	panic("Cannot add with different exponents")
 | |
| }
 | |
| 
 | |
| func (n *Numeric) AddI(other Numeric) *Numeric {
 | |
| 	right := other
 | |
| 	if n.Exp < other.Exp {
 | |
| 		right = other.MatchExp(n.Exp)
 | |
| 	} else if n.Exp > other.Exp {
 | |
| 		n.MatchExpI(other.Exp)
 | |
| 	}
 | |
| 
 | |
| 	if n.Exp == right.Exp {
 | |
| 		n.Int = big.NewInt(0).Add(n.Int, right.Int)
 | |
| 		return n
 | |
| 	}
 | |
| 
 | |
| 	panic("Cannot add with different exponents")
 | |
| }
 | |
| 
 | |
| func (n Numeric) String() string {
 | |
| 	if n.Int == nil || n.Int.Int64() == 0 {
 | |
| 		return "0"
 | |
| 	}
 | |
| 
 | |
| 	s := fmt.Sprintf("%d", n.Int)
 | |
| 	bytes := []byte(s)
 | |
| 
 | |
| 	exp := n.Exp
 | |
| 	for exp > 0 {
 | |
| 		bytes = append(bytes, byte('0'))
 | |
| 		exp--
 | |
| 	}
 | |
| 
 | |
| 	if exp == 0 {
 | |
| 		return string(bytes)
 | |
| 	}
 | |
| 
 | |
| 	length := int32(len(bytes))
 | |
| 	var bytesWithSeparator []byte
 | |
| 
 | |
| 	exp = -exp
 | |
| 	for length <= exp {
 | |
| 		if n.Int.Int64() < 0 {
 | |
| 			bytes = append([]byte{bytes[0], byte('0')}, bytes[1:]...)
 | |
| 		} else {
 | |
| 			bytes = append([]byte{byte('0')}, bytes...)
 | |
| 		}
 | |
| 		length++
 | |
| 	}
 | |
| 
 | |
| 	split := length - exp
 | |
| 	bytesWithSeparator = append(bytesWithSeparator, bytes[:split]...)
 | |
| 	if split == 1 && n.Int.Int64() < 0 {
 | |
| 		bytesWithSeparator = append(bytesWithSeparator, byte('0'))
 | |
| 	}
 | |
| 	bytesWithSeparator = append(bytesWithSeparator, byte('.'))
 | |
| 	bytesWithSeparator = append(bytesWithSeparator, bytes[split:]...)
 | |
| 	return string(bytesWithSeparator)
 | |
| }
 | |
| 
 | |
| func (n Numeric) MarshalJSON() ([]byte, error) {
 | |
| 	if n.Int == nil || n.Int.Int64() == 0 {
 | |
| 		return []byte("0"), nil
 | |
| 	}
 | |
| 
 | |
| 	s := fmt.Sprintf("%d", n.Int)
 | |
| 	bytes := []byte(s)
 | |
| 
 | |
| 	exp := n.Exp
 | |
| 	for exp > 0 {
 | |
| 		bytes = append(bytes, byte('0'))
 | |
| 		exp--
 | |
| 	}
 | |
| 
 | |
| 	if exp == 0 {
 | |
| 		return bytes, nil
 | |
| 	}
 | |
| 
 | |
| 	length := int32(len(bytes))
 | |
| 	var bytesWithSeparator []byte
 | |
| 
 | |
| 	exp = -exp
 | |
| 	for length <= exp {
 | |
| 		if n.Int.Int64() < 0 {
 | |
| 			bytes = append([]byte{bytes[0], byte('0')}, bytes[1:]...)
 | |
| 		} else {
 | |
| 			bytes = append([]byte{byte('0')}, bytes...)
 | |
| 		}
 | |
| 		length++
 | |
| 	}
 | |
| 
 | |
| 	split := length - exp
 | |
| 	bytesWithSeparator = append(bytesWithSeparator, bytes[:split]...)
 | |
| 	if split == 1 && n.Int.Int64() < 0 {
 | |
| 		bytesWithSeparator = append(bytesWithSeparator, byte('0'))
 | |
| 	}
 | |
| 	bytesWithSeparator = append(bytesWithSeparator, byte('.'))
 | |
| 	bytesWithSeparator = append(bytesWithSeparator, bytes[split:]...)
 | |
| 	return bytesWithSeparator, nil
 | |
| }
 | |
| 
 | |
| func MustParse(text string) Numeric {
 | |
| 	num, err := Parse(text)
 | |
| 	if err != nil {
 | |
| 		panic(err)
 | |
| 	}
 | |
| 
 | |
| 	return num
 | |
| }
 | |
| 
 | |
| func Parse(text string) (Numeric, error) {
 | |
| 	// Unify decimal separator
 | |
| 	text = strings.Replace(text, ",", ".", 1)
 | |
| 
 | |
| 	num := Numeric{}
 | |
| 	err := num.Set(text)
 | |
| 	if err != nil {
 | |
| 		return num, fmt.Errorf("parse numeric %s: %w", text, err)
 | |
| 	}
 | |
| 
 | |
| 	return num, nil
 | |
| }
 | |
| 
 | |
| func ParseCurrency(text string) (Numeric, error) {
 | |
| 	// Remove trailing currency
 | |
| 	text = trimLastChar(text)
 | |
| 
 | |
| 	return Parse(text)
 | |
| }
 | |
| 
 | |
// trimLastChar returns s without its final UTF-8 encoded rune.
// An empty string is returned unchanged. So is a string whose last byte is
// not valid UTF-8: it decodes as RuneError with width 1 and is kept. A
// properly encoded U+FFFD (width 3) is removed like any other rune.
func trimLastChar(s string) string {
	last, width := utf8.DecodeLastRuneInString(s)
	if last == utf8.RuneError && width <= 1 {
		return s
	}
	return s[:len(s)-width]
}
 |