summaryrefslogtreecommitdiff
path: root/ssl/test/runner/handshake_messages.go
diff options
context:
space:
mode:
Diffstat (limited to 'ssl/test/runner/handshake_messages.go')
-rw-r--r--ssl/test/runner/handshake_messages.go195
1 files changed, 191 insertions, 4 deletions
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 7fe8bf5..136360d 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -24,7 +24,10 @@ type clientHelloMsg struct {
sessionTicket []uint8
signatureAndHashes []signatureAndHash
secureRenegotiation bool
+ alpnProtocols []string
duplicateExtension bool
+ channelIDSupported bool
+ npnLast bool
}
func (m *clientHelloMsg) equal(i interface{}) bool {
@@ -49,7 +52,11 @@ func (m *clientHelloMsg) equal(i interface{}) bool {
m.ticketSupported == m1.ticketSupported &&
bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
- m.secureRenegotiation == m1.secureRenegotiation
+ m.secureRenegotiation == m1.secureRenegotiation &&
+ eqStrings(m.alpnProtocols, m1.alpnProtocols) &&
+ m.duplicateExtension == m1.duplicateExtension &&
+ m.channelIDSupported == m1.channelIDSupported &&
+ m.npnLast == m1.npnLast
}
func (m *clientHelloMsg) marshal() []byte {
@@ -97,6 +104,20 @@ func (m *clientHelloMsg) marshal() []byte {
if m.duplicateExtension {
numExtensions += 2
}
+ if m.channelIDSupported {
+ numExtensions++
+ }
+ if len(m.alpnProtocols) > 0 {
+ extensionsLength += 2
+ for _, s := range m.alpnProtocols {
+ if l := len(s); l == 0 || l > 255 {
+ panic("invalid ALPN protocol")
+ }
+ extensionsLength++
+ extensionsLength += len(s)
+ }
+ numExtensions++
+ }
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
@@ -141,7 +162,7 @@ func (m *clientHelloMsg) marshal() []byte {
z[1] = 0xff
z = z[4:]
}
- if m.nextProtoNeg {
+ if m.nextProtoNeg && !m.npnLast {
z[0] = byte(extensionNextProtoNeg >> 8)
z[1] = byte(extensionNextProtoNeg & 0xff)
// The length is always 0
@@ -260,6 +281,38 @@ func (m *clientHelloMsg) marshal() []byte {
z[3] = 1
z = z[5:]
}
+ if len(m.alpnProtocols) > 0 {
+ z[0] = byte(extensionALPN >> 8)
+ z[1] = byte(extensionALPN & 0xff)
+ lengths := z[2:]
+ z = z[6:]
+
+ stringsLength := 0
+ for _, s := range m.alpnProtocols {
+ l := len(s)
+ z[0] = byte(l)
+ copy(z[1:], s)
+ z = z[1+l:]
+ stringsLength += 1 + l
+ }
+
+ lengths[2] = byte(stringsLength >> 8)
+ lengths[3] = byte(stringsLength)
+ stringsLength += 2
+ lengths[0] = byte(stringsLength >> 8)
+ lengths[1] = byte(stringsLength)
+ }
+ if m.channelIDSupported {
+ z[0] = byte(extensionChannelID >> 8)
+ z[1] = byte(extensionChannelID & 0xff)
+ z = z[4:]
+ }
+ if m.nextProtoNeg && m.npnLast {
+ z[0] = byte(extensionNextProtoNeg >> 8)
+ z[1] = byte(extensionNextProtoNeg & 0xff)
+ // The length is always 0
+ z = z[4:]
+ }
if m.duplicateExtension {
// Add a duplicate bogus extension at the beginning and end.
z[0] = 0xff
@@ -331,6 +384,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.ticketSupported = false
m.sessionTicket = nil
m.signatureAndHashes = nil
+ m.alpnProtocols = nil
if len(data) == 0 {
// ClientHello is optionally followed by extension data
@@ -440,6 +494,29 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
return false
}
m.secureRenegotiation = true
+ case extensionALPN:
+ if length < 2 {
+ return false
+ }
+ l := int(data[0])<<8 | int(data[1])
+ if l != length-2 {
+ return false
+ }
+ d := data[2:length]
+ for len(d) != 0 {
+ stringLen := int(d[0])
+ d = d[1:]
+ if stringLen == 0 || stringLen > len(d) {
+ return false
+ }
+ m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
+ d = d[stringLen:]
+ }
+ case extensionChannelID:
+ if length > 0 {
+ return false
+ }
+ m.channelIDSupported = true
}
data = data[length:]
}
@@ -460,7 +537,9 @@ type serverHelloMsg struct {
ocspStapling bool
ticketSupported bool
secureRenegotiation bool
+ alpnProtocol string
duplicateExtension bool
+ channelIDRequested bool
}
func (m *serverHelloMsg) equal(i interface{}) bool {
@@ -480,7 +559,10 @@ func (m *serverHelloMsg) equal(i interface{}) bool {
eqStrings(m.nextProtos, m1.nextProtos) &&
m.ocspStapling == m1.ocspStapling &&
m.ticketSupported == m1.ticketSupported &&
- m.secureRenegotiation == m1.secureRenegotiation
+ m.secureRenegotiation == m1.secureRenegotiation &&
+ m.alpnProtocol == m1.alpnProtocol &&
+ m.duplicateExtension == m1.duplicateExtension &&
+ m.channelIDRequested == m1.channelIDRequested
}
func (m *serverHelloMsg) marshal() []byte {
@@ -514,6 +596,17 @@ func (m *serverHelloMsg) marshal() []byte {
if m.duplicateExtension {
numExtensions += 2
}
+ if m.channelIDRequested {
+ numExtensions++
+ }
+ if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
+ if alpnLen >= 256 {
+ panic("invalid ALPN protocol")
+ }
+ extensionsLength += 2 + 1 + alpnLen
+ numExtensions++
+ }
+
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
@@ -581,6 +674,25 @@ func (m *serverHelloMsg) marshal() []byte {
z[3] = 1
z = z[5:]
}
+ if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
+ z[0] = byte(extensionALPN >> 8)
+ z[1] = byte(extensionALPN & 0xff)
+ l := 2 + 1 + alpnLen
+ z[2] = byte(l >> 8)
+ z[3] = byte(l)
+ l -= 2
+ z[4] = byte(l >> 8)
+ z[5] = byte(l)
+ l -= 1
+ z[6] = byte(l)
+ copy(z[7:], []byte(m.alpnProtocol))
+ z = z[7+alpnLen:]
+ }
+ if m.channelIDRequested {
+ z[0] = byte(extensionChannelID >> 8)
+ z[1] = byte(extensionChannelID & 0xff)
+ z = z[4:]
+ }
if m.duplicateExtension {
// Add a duplicate bogus extension at the beginning and end.
z[0] = 0xff
@@ -617,6 +729,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
m.nextProtos = nil
m.ocspStapling = false
m.ticketSupported = false
+ m.alpnProtocol = ""
if len(data) == 0 {
// ServerHello is optionally followed by extension data
@@ -671,6 +784,27 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
return false
}
m.secureRenegotiation = true
+ case extensionALPN:
+ d := data[:length]
+ if len(d) < 3 {
+ return false
+ }
+ l := int(d[0])<<8 | int(d[1])
+ if l != len(d)-2 {
+ return false
+ }
+ d = d[2:]
+ l = int(d[0])
+ if l != len(d)-1 {
+ return false
+ }
+ d = d[1:]
+ m.alpnProtocol = string(d)
+ case extensionChannelID:
+ if length > 0 {
+ return false
+ }
+ m.channelIDRequested = true
}
data = data[length:]
}
@@ -1407,7 +1541,8 @@ func (m *helloVerifyRequestMsg) equal(i interface{}) bool {
return false
}
- return m.vers == m1.vers &&
+ return bytes.Equal(m.raw, m1.raw) &&
+ m.vers == m1.vers &&
bytes.Equal(m.cookie, m1.cookie)
}
@@ -1447,6 +1582,58 @@ func (m *helloVerifyRequestMsg) unmarshal(data []byte) bool {
return true
}
+type encryptedExtensionsMsg struct {
+ raw []byte
+ channelID []byte
+}
+
+func (m *encryptedExtensionsMsg) equal(i interface{}) bool {
+ m1, ok := i.(*encryptedExtensionsMsg)
+ if !ok {
+ return false
+ }
+
+ return bytes.Equal(m.raw, m1.raw) &&
+ bytes.Equal(m.channelID, m1.channelID)
+}
+
+func (m *encryptedExtensionsMsg) marshal() []byte {
+ if m.raw != nil {
+ return m.raw
+ }
+
+ length := 2 + 2 + len(m.channelID)
+
+ x := make([]byte, 4+length)
+ x[0] = typeEncryptedExtensions
+ x[1] = uint8(length >> 16)
+ x[2] = uint8(length >> 8)
+ x[3] = uint8(length)
+ x[4] = uint8(extensionChannelID >> 8)
+ x[5] = uint8(extensionChannelID & 0xff)
+ x[6] = uint8(len(m.channelID) >> 8)
+ x[7] = uint8(len(m.channelID) & 0xff)
+ copy(x[8:], m.channelID)
+
+ return x
+}
+
+func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
+ if len(data) != 4+2+2+128 {
+ return false
+ }
+ m.raw = data
+ if (uint16(data[4])<<8)|uint16(data[5]) != extensionChannelID {
+ return false
+ }
+ if int(data[6])<<8|int(data[7]) != 128 {
+ return false
+ }
+ m.channelID = data[4+2+2:]
+
+ return true
+}
+
func eqUint16s(x, y []uint16) bool {
if len(x) != len(y) {
return false