aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSamuel Roth <2413031+sejr@users.noreply.github.com>2021-07-12 18:10:02 -0400
committerGitHub <noreply@github.com>2021-07-12 17:10:02 -0500
commitae25fc6a8e2fa9c3a72fc325117330736fe729f6 (patch)
treedd2ea3776d06c709fcac39c14fe4762218881004
parent655bf50db9d265813a26b0fe2a5d1e80fb9e3c6b (diff)
downloadgoogle-uuid-ae25fc6a8e2fa9c3a72fc325117330736fe729f6.tar.gz
feat(uuid): Added support for NullUUID (#76)
-rw-r--r--null.go120
-rw-r--r--null_test.go238
2 files changed, 358 insertions, 0 deletions
diff --git a/null.go b/null.go
new file mode 100644
index 0000000..95bb632
--- /dev/null
+++ b/null.go
@@ -0,0 +1,120 @@
+// Copyright 2021 Google Inc. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package uuid
+
+import (
+ "bytes"
+ "database/sql/driver"
+ "encoding/json"
+ "fmt"
+)
+
+// NullUUID represents a UUID that may be null.
+// NullUUID implements the Scanner interface so
+// it can be used as a scan destination:
+//
+// var u uuid.NullUUID
+// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&u)
+// ...
+// if u.Valid {
+// // use u.UUID
+// } else {
+// // NULL value
+// }
+//
+type NullUUID struct {
+ UUID UUID
+ Valid bool // Valid is true if UUID is not NULL
+}
+
+// Scan implements the Scanner interface.
+func (nu *NullUUID) Scan(value interface{}) error {
+ if value == nil {
+ nu.UUID, nu.Valid = Nil, false
+ return nil
+ }
+
+ err := nu.UUID.Scan(value)
+ if err != nil {
+ nu.Valid = false
+ return err
+ }
+
+ nu.Valid = true
+ return nil
+}
+
+// Value implements the driver Valuer interface.
+func (nu NullUUID) Value() (driver.Value, error) {
+ if !nu.Valid {
+ return nil, nil
+ }
+ // Delegate to UUID Value function
+ return nu.UUID.Value()
+}
+
+// MarshalBinary implements encoding.BinaryMarshaler.
+func (nu NullUUID) MarshalBinary() ([]byte, error) {
+ if nu.Valid {
+ return nu.UUID[:], nil
+ }
+
+ return []byte(nil), nil
+}
+
+// UnmarshalBinary implements encoding.BinaryUnmarshaler.
+func (nu *NullUUID) UnmarshalBinary(data []byte) error {
+ if len(data) != 16 {
+ return fmt.Errorf("invalid UUID (got %d bytes)", len(data))
+ }
+ copy(nu.UUID[:], data)
+ nu.Valid = true
+ return nil
+}
+
+// MarshalText implements encoding.TextMarshaler.
+func (nu NullUUID) MarshalText() ([]byte, error) {
+ if nu.Valid {
+ return nu.UUID.MarshalText()
+ }
+
+ return []byte{110, 117, 108, 108}, nil
+}
+
+// UnmarshalText implements encoding.TextUnmarshaler.
+func (nu *NullUUID) UnmarshalText(data []byte) error {
+ id, err := ParseBytes(data)
+ if err != nil {
+ nu.Valid = false
+ return err
+ }
+ nu.UUID = id
+ nu.Valid = true
+ return nil
+}
+
+// MarshalJSON implements json.Marshaler.
+func (nu NullUUID) MarshalJSON() ([]byte, error) {
+ if nu.Valid {
+ return json.Marshal(nu.UUID)
+ }
+
+ return json.Marshal(nil)
+}
+
+// UnmarshalJSON implements json.Unmarshaler.
+func (nu *NullUUID) UnmarshalJSON(data []byte) error {
+ null := []byte{110, 117, 108, 108}
+ if bytes.Equal(data, null) {
+ return nil // valid null UUID
+ }
+
+ var u UUID
+ // tossing as we know u is valid
+ _ = json.Unmarshal(data, &u)
+ nu.Valid = true
+ nu.UUID = u
+ return nil
+}
diff --git a/null_test.go b/null_test.go
new file mode 100644
index 0000000..b1988a4
--- /dev/null
+++ b/null_test.go
@@ -0,0 +1,238 @@
+// Copyright 2021 Google Inc. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package uuid
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "testing"
+)
+
+func TestNullUUIDScan(t *testing.T) {
+ var u UUID
+ var nu NullUUID
+
+ uNilErr := u.Scan(nil)
+ nuNilErr := nu.Scan(nil)
+ if uNilErr != nil &&
+ nuNilErr != nil &&
+ uNilErr.Error() != nuNilErr.Error() {
+ t.Errorf("expected errors to be equal, got %s, %s", uNilErr, nuNilErr)
+ }
+
+ uInvalidStringErr := u.Scan("test")
+ nuInvalidStringErr := nu.Scan("test")
+ if uInvalidStringErr != nil &&
+ nuInvalidStringErr != nil &&
+ uInvalidStringErr.Error() != nuInvalidStringErr.Error() {
+ t.Errorf("expected errors to be equal, got %s, %s", uInvalidStringErr, nuInvalidStringErr)
+ }
+
+ valid := "12345678-abcd-1234-abcd-0123456789ab"
+ uValidErr := u.Scan(valid)
+ nuValidErr := nu.Scan(valid)
+ if uValidErr != nuValidErr {
+ t.Errorf("expected errors to be equal, got %s, %s", uValidErr, nuValidErr)
+ }
+}
+
+func TestNullUUIDValue(t *testing.T) {
+ var u UUID
+ var nu NullUUID
+
+ nuValue, nuErr := nu.Value()
+ if nuErr != nil {
+ t.Errorf("expected nil err, got err %s", nuErr)
+ }
+ if nuValue != nil {
+ t.Errorf("expected nil value, got non-nil %s", nuValue)
+ }
+
+ u = MustParse("12345678-abcd-1234-abcd-0123456789ab")
+ nu = NullUUID{
+ UUID: MustParse("12345678-abcd-1234-abcd-0123456789ab"),
+ Valid: true,
+ }
+
+ uValue, uErr := u.Value()
+ nuValue, nuErr = nu.Value()
+ if uErr != nil {
+ t.Errorf("expected nil err, got err %s", uErr)
+ }
+ if nuErr != nil {
+ t.Errorf("expected nil err, got err %s", nuErr)
+ }
+ if uValue != nuValue {
+ t.Errorf("expected uuid %s and nulluuid %s to be equal ", uValue, nuValue)
+ }
+}
+
+func TestNullUUIDMarshalText(t *testing.T) {
+ tests := []struct {
+ nullUUID NullUUID
+ }{
+ {
+ nullUUID: NullUUID{},
+ },
+ {
+ nullUUID: NullUUID{
+ UUID: MustParse("12345678-abcd-1234-abcd-0123456789ab"),
+ Valid: true,
+ },
+ },
+ }
+ for _, test := range tests {
+ var uText []byte
+ var uErr error
+ nuText, nuErr := test.nullUUID.MarshalText()
+ if test.nullUUID.Valid {
+ uText, uErr = test.nullUUID.UUID.MarshalText()
+ } else {
+ uText = []byte("null")
+ }
+ if nuErr != uErr {
+ t.Errorf("expected error %e, got %e", nuErr, uErr)
+ }
+ if !bytes.Equal(nuText, uText) {
+ t.Errorf("expected text data %s, got %s", string(nuText), string(uText))
+ }
+ }
+}
+
+func TestNullUUIDUnmarshalText(t *testing.T) {
+ tests := []struct {
+ nullUUID NullUUID
+ }{
+ {
+ nullUUID: NullUUID{},
+ },
+ {
+ nullUUID: NullUUID{
+ UUID: MustParse("12345678-abcd-1234-abcd-0123456789ab"),
+ Valid: true,
+ },
+ },
+ }
+ for _, test := range tests {
+ var uText []byte
+ var uErr error
+ nuText, nuErr := test.nullUUID.MarshalText()
+ if test.nullUUID.Valid {
+ uText, uErr = test.nullUUID.UUID.MarshalText()
+ } else {
+ uText = []byte("null")
+ }
+ if nuErr != uErr {
+ t.Errorf("expected error %e, got %e", nuErr, uErr)
+ }
+ if !bytes.Equal(nuText, uText) {
+ t.Errorf("expected text data %s, got %s", string(nuText), string(uText))
+ }
+ }
+}
+
+func TestNullUUIDMarshalBinary(t *testing.T) {
+ tests := []struct {
+ nullUUID NullUUID
+ }{
+ {
+ nullUUID: NullUUID{},
+ },
+ {
+ nullUUID: NullUUID{
+ UUID: MustParse("12345678-abcd-1234-abcd-0123456789ab"),
+ Valid: true,
+ },
+ },
+ }
+ for _, test := range tests {
+ var uBinary []byte
+ var uErr error
+ nuBinary, nuErr := test.nullUUID.MarshalBinary()
+ if test.nullUUID.Valid {
+ uBinary, uErr = test.nullUUID.UUID.MarshalBinary()
+ } else {
+ uBinary = []byte(nil)
+ }
+ if nuErr != uErr {
+ t.Errorf("expected error %e, got %e", nuErr, uErr)
+ }
+ if !bytes.Equal(nuBinary, uBinary) {
+ t.Errorf("expected binary data %s, got %s", string(nuBinary), string(uBinary))
+ }
+ }
+}
+
+func TestNullUUIDMarshalJSON(t *testing.T) {
+ jsonNull, _ := json.Marshal(nil)
+ jsonUUID, _ := json.Marshal(MustParse("12345678-abcd-1234-abcd-0123456789ab"))
+ tests := []struct {
+ nullUUID NullUUID
+ expected []byte
+ expectedErr error
+ }{
+ {
+ nullUUID: NullUUID{},
+ expected: jsonNull,
+ expectedErr: nil,
+ },
+ {
+ nullUUID: NullUUID{
+ UUID: MustParse(string(jsonUUID)),
+ Valid: true,
+ },
+ expected: []byte(`"12345678-abcd-1234-abcd-0123456789ab"`),
+ expectedErr: nil,
+ },
+ }
+ for _, test := range tests {
+ data, err := json.Marshal(&test.nullUUID)
+ if err != test.expectedErr {
+ t.Errorf("expected error %e, got %e", test.expectedErr, err)
+ }
+ if !bytes.Equal(data, test.expected) {
+ t.Errorf("expected json data %s, got %s", string(test.expected), string(data))
+ }
+ }
+}
+
+func TestNullUUIDUnmarshalJSON(t *testing.T) {
+ jsonNull, _ := json.Marshal(nil)
+ jsonUUID, _ := json.Marshal(MustParse("12345678-abcd-1234-abcd-0123456789ab"))
+
+ var nu NullUUID
+ err := json.Unmarshal(jsonNull, &nu)
+ if err != nil || nu.Valid {
+ t.Errorf("expected nil when unmarshaling null, got %s", err)
+ }
+ err = json.Unmarshal(jsonUUID, &nu)
+ if err != nil || !nu.Valid {
+ t.Errorf("expected nil when unmarshaling null, got %s", err)
+ }
+}
+
+func TestConformance(t *testing.T) {
+ input := []byte(`"12345678-abcd-1234-abcd-0123456789ab"`)
+ var n NullUUID
+ var u UUID
+
+ err := json.Unmarshal(input, &n)
+ fmt.Printf("Unmarshal NullUUID: %+v %v\n", n, err)
+ err = json.Unmarshal(input, &u)
+ fmt.Printf("Unmarshal UUID: %+v %v\n", u, err)
+
+ n = NullUUID{}
+ data, err := json.Marshal(&n)
+ fmt.Printf("Marshal Empty NullUUID %s %v\n", data, err)
+
+ n.Valid = true
+ n.UUID = u
+ data, err = json.Marshal(&n)
+ fmt.Printf("Marshal Filled NullUUID %s %v\n", data, err)
+
+ data, err = json.Marshal(&u)
+ fmt.Printf("Marshal UUID: %s %v\n", data, err)
+}