diff options
author | Germán Fuentes Capella <47056480+gfcapella@users.noreply.github.com> | 2021-01-18 02:14:59 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-01-17 20:14:59 -0500 |
commit | b6d3e7f4b0f5b20e6b08b4755a94ac7eab8981d0 (patch) | |
tree | 8d27b438c6083644d3bdb36261b5a664daa64ee4 | |
parent | 8756d3ec17702f93b2d5e50e9ea6762f30bec167 (diff) | |
download | starlark-go-b6d3e7f4b0f5b20e6b08b4755a94ac7eab8981d0.tar.gz |
Fix: Immutability of Int is broken by BigInt method (#332) (#333)
* Fix: Immutability of Int is broken by BigInt method (#332)
The interface of Int assumes the value is immutable but a reference to
its content is leaked through the BigInt() method, which is mutable.
This fix resolves the leak by:
- making a copy of the *big.Int in MakeBigInt()
- returning a copy of the *big.Int content in BigInt()
- creating an internal method bigInt() for internal use that still
returns reference to the internal big.Int
Fixes: #332
* Fix: Immutability of Int is broken by BigInt method (#332)
Addressing review comments regarding documentation
Fixes: #332
-rw-r--r-- | starlark/int.go | 41 | ||||
-rw-r--r-- | starlark/int_test.go | 29 |
2 files changed, 55 insertions, 15 deletions
diff --git a/starlark/int.go b/starlark/int.go index c13c8dd..9ee46f9 100644 --- a/starlark/int.go +++ b/starlark/int.go @@ -44,12 +44,13 @@ func MakeUint64(x uint64) Int { } // MakeBigInt returns a Starlark int for the specified big.Int. -// The caller must not subsequently modify x. +// The new Int value will contain a copy of x. The caller is safe to modify x. func MakeBigInt(x *big.Int) Int { if n := x.BitLen(); n < 32 || n == 32 && x.Int64() == math.MinInt32 { return makeSmallInt(x.Int64()) } - return makeBigInt(x) + z := new(big.Int).Set(x) + return makeBigInt(z) } var ( @@ -86,11 +87,21 @@ func (i Int) Int64() (_ int64, ok bool) { return iSmall, true } -// BigInt returns the value as a big.Int. -// The returned variable must not be modified by the client. +// BigInt returns a new big.Int with the same value as the Int. func (i Int) BigInt() *big.Int { iSmall, iBig := i.get() if iBig != nil { + return new(big.Int).Set(iBig) + } + return big.NewInt(iSmall) +} + +// bigInt returns the value as a big.Int. +// It differs from BigInt in that this method returns the actual +// reference and any modification will change the state of i. +func (i Int) bigInt() *big.Int { + iSmall, iBig := i.get() + if iBig != nil { return iBig } return big.NewInt(iSmall) @@ -179,7 +190,7 @@ func (x Int) CompareSameType(op syntax.Token, v Value, depth int) (bool, error) xSmall, xBig := x.get() ySmall, yBig := y.get() if xBig != nil || yBig != nil { - return threeway(op, x.BigInt().Cmp(y.BigInt())), nil + return threeway(op, x.bigInt().Cmp(y.bigInt())), nil } return threeway(op, signum64(xSmall-ySmall)), nil } @@ -216,7 +227,7 @@ func (x Int) Add(y Int) Int { xSmall, xBig := x.get() ySmall, yBig := y.get() if xBig != nil || yBig != nil { - return MakeBigInt(new(big.Int).Add(x.BigInt(), y.BigInt())) + return MakeBigInt(new(big.Int).Add(x.bigInt(), y.bigInt())) } return MakeInt64(xSmall + ySmall) } @@ -224,7 +235,7 @@ func (x Int) Sub(y Int) Int { xSmall, xBig := x.get() ySmall, yBig := y.get() if xBig != nil || yBig != nil { - return MakeBigInt(new(big.Int).Sub(x.BigInt(), y.BigInt())) + return MakeBigInt(new(big.Int).Sub(x.bigInt(), y.bigInt())) } return MakeInt64(xSmall - ySmall) } @@ -232,7 +243,7 @@ func (x Int) Mul(y Int) Int { xSmall, xBig := x.get() ySmall, yBig := y.get() if xBig != nil || yBig != nil { - return MakeBigInt(new(big.Int).Mul(x.BigInt(), y.BigInt())) + return MakeBigInt(new(big.Int).Mul(x.bigInt(), y.bigInt())) } return MakeInt64(xSmall * ySmall) } @@ -240,7 +251,7 @@ func (x Int) Or(y Int) Int { xSmall, xBig := x.get() ySmall, yBig := y.get() if xBig != nil || yBig != nil { - return MakeBigInt(new(big.Int).Or(x.BigInt(), y.BigInt())) + return MakeBigInt(new(big.Int).Or(x.bigInt(), y.bigInt())) } return makeSmallInt(xSmall | ySmall) } @@ -248,7 +259,7 @@ func (x Int) And(y Int) Int { xSmall, xBig := x.get() ySmall, yBig := y.get() if xBig != nil || yBig != nil { - return MakeBigInt(new(big.Int).And(x.BigInt(), y.BigInt())) + return MakeBigInt(new(big.Int).And(x.bigInt(), y.bigInt())) } return makeSmallInt(xSmall & ySmall) } @@ -256,7 +267,7 @@ func (x Int) Xor(y Int) Int { xSmall, xBig := x.get() ySmall, yBig := y.get() if xBig != nil || yBig != nil { - return MakeBigInt(new(big.Int).Xor(x.BigInt(), y.BigInt())) + return MakeBigInt(new(big.Int).Xor(x.bigInt(), y.bigInt())) } return makeSmallInt(xSmall ^ ySmall) } @@ -267,8 +278,8 @@ func (x Int) Not() Int { } return makeSmallInt(^xSmall) } -func (x Int) Lsh(y uint) Int { return MakeBigInt(new(big.Int).Lsh(x.BigInt(), y)) } -func (x Int) Rsh(y uint) Int { return MakeBigInt(new(big.Int).Rsh(x.BigInt(), y)) } +func (x Int) Lsh(y uint) Int { return MakeBigInt(new(big.Int).Lsh(x.bigInt(), y)) } +func (x Int) Rsh(y uint) Int { return MakeBigInt(new(big.Int).Rsh(x.bigInt(), y)) } // Precondition: y is nonzero. func (x Int) Div(y Int) Int { @@ -276,7 +287,7 @@ func (x Int) Div(y Int) Int { ySmall, yBig := y.get() // http://python-history.blogspot.com/2010/08/why-pythons-integer-division-floors.html if xBig != nil || yBig != nil { - xb, yb := x.BigInt(), y.BigInt() + xb, yb := x.bigInt(), y.bigInt() var quo, rem big.Int quo.QuoRem(xb, yb, &rem) @@ -298,7 +309,7 @@ func (x Int) Mod(y Int) Int { xSmall, xBig := x.get() ySmall, yBig := y.get() if xBig != nil || yBig != nil { - xb, yb := x.BigInt(), y.BigInt() + xb, yb := x.bigInt(), y.bigInt() var quo, rem big.Int quo.QuoRem(xb, yb, &rem) diff --git a/starlark/int_test.go b/starlark/int_test.go index 6725616..ad1bf92 100644 --- a/starlark/int_test.go +++ b/starlark/int_test.go @@ -71,3 +71,32 @@ func TestIntOpts(t *testing.T) { } } } + +func TestImmutabilityMakeBigInt(t *testing.T) { + // use max int64 for the test + expect := int64(^uint64(0) >> 1) + + mutint := big.NewInt(expect) + value := MakeBigInt(mutint) + mutint.Set(big.NewInt(1)) + + got, _ := value.Int64() + if got != expect { + t.Errorf("expected %d, got %d", expect, got) + } +} + +func TestImmutabilityBigInt(t *testing.T) { + // use 1 and max int64 for the test + for _, expect := range []int64{1, int64(^uint64(0) >> 1)} { + value := MakeBigInt(big.NewInt(expect)) + + bigint := value.BigInt() + bigint.Set(big.NewInt(2)) + + got, _ := value.Int64() + if got != expect { + t.Errorf("expected %d, got %d", expect, got) + } + } +} |