diff options
Diffstat (limited to 'ssl/test/runner/handshake_messages.go')
-rw-r--r-- | ssl/test/runner/handshake_messages.go | 195 |
1 files changed, 191 insertions, 4 deletions
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go index 7fe8bf5..136360d 100644 --- a/ssl/test/runner/handshake_messages.go +++ b/ssl/test/runner/handshake_messages.go @@ -24,7 +24,10 @@ type clientHelloMsg struct { sessionTicket []uint8 signatureAndHashes []signatureAndHash secureRenegotiation bool + alpnProtocols []string duplicateExtension bool + channelIDSupported bool + npnLast bool } func (m *clientHelloMsg) equal(i interface{}) bool { @@ -49,7 +52,11 @@ func (m *clientHelloMsg) equal(i interface{}) bool { m.ticketSupported == m1.ticketSupported && bytes.Equal(m.sessionTicket, m1.sessionTicket) && eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) && - m.secureRenegotiation == m1.secureRenegotiation + m.secureRenegotiation == m1.secureRenegotiation && + eqStrings(m.alpnProtocols, m1.alpnProtocols) && + m.duplicateExtension == m1.duplicateExtension && + m.channelIDSupported == m1.channelIDSupported && + m.npnLast == m1.npnLast } func (m *clientHelloMsg) marshal() []byte { @@ -97,6 +104,20 @@ func (m *clientHelloMsg) marshal() []byte { if m.duplicateExtension { numExtensions += 2 } + if m.channelIDSupported { + numExtensions++ + } + if len(m.alpnProtocols) > 0 { + extensionsLength += 2 + for _, s := range m.alpnProtocols { + if l := len(s); l == 0 || l > 255 { + panic("invalid ALPN protocol") + } + extensionsLength++ + extensionsLength += len(s) + } + numExtensions++ + } if numExtensions > 0 { extensionsLength += 4 * numExtensions length += 2 + extensionsLength @@ -141,7 +162,7 @@ func (m *clientHelloMsg) marshal() []byte { z[1] = 0xff z = z[4:] } - if m.nextProtoNeg { + if m.nextProtoNeg && !m.npnLast { z[0] = byte(extensionNextProtoNeg >> 8) z[1] = byte(extensionNextProtoNeg & 0xff) // The length is always 0 @@ -260,6 +281,38 @@ func (m *clientHelloMsg) marshal() []byte { z[3] = 1 z = z[5:] } + if len(m.alpnProtocols) > 0 { + z[0] = byte(extensionALPN >> 8) + z[1] = byte(extensionALPN & 0xff) + lengths := z[2:] + z = z[6:] + + stringsLength := 0 + for _, s := range m.alpnProtocols { + l := len(s) + z[0] = byte(l) + copy(z[1:], s) + z = z[1+l:] + stringsLength += 1 + l + } + + lengths[2] = byte(stringsLength >> 8) + lengths[3] = byte(stringsLength) + stringsLength += 2 + lengths[0] = byte(stringsLength >> 8) + lengths[1] = byte(stringsLength) + } + if m.channelIDSupported { + z[0] = byte(extensionChannelID >> 8) + z[1] = byte(extensionChannelID & 0xff) + z = z[4:] + } + if m.nextProtoNeg && m.npnLast { + z[0] = byte(extensionNextProtoNeg >> 8) + z[1] = byte(extensionNextProtoNeg & 0xff) + // The length is always 0 + z = z[4:] + } if m.duplicateExtension { // Add a duplicate bogus extension at the beginning and end. z[0] = 0xff @@ -331,6 +384,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.ticketSupported = false m.sessionTicket = nil m.signatureAndHashes = nil + m.alpnProtocols = nil if len(data) == 0 { // ClientHello is optionally followed by extension data @@ -440,6 +494,29 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return false } m.secureRenegotiation = true + case extensionALPN: + if length < 2 { + return false + } + l := int(data[0])<<8 | int(data[1]) + if l != length-2 { + return false + } + d := data[2:length] + for len(d) != 0 { + stringLen := int(d[0]) + d = d[1:] + if stringLen == 0 || stringLen > len(d) { + return false + } + m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen])) + d = d[stringLen:] + } + case extensionChannelID: + if length > 0 { + return false + } + m.channelIDSupported = true } data = data[length:] } @@ -460,7 +537,9 @@ type serverHelloMsg struct { ocspStapling bool ticketSupported bool secureRenegotiation bool + alpnProtocol string duplicateExtension bool + channelIDRequested bool } func (m *serverHelloMsg) equal(i interface{}) bool { @@ -480,7 +559,10 @@ func (m *serverHelloMsg) equal(i interface{}) bool { eqStrings(m.nextProtos, m1.nextProtos) && m.ocspStapling == m1.ocspStapling && m.ticketSupported == m1.ticketSupported && - m.secureRenegotiation == m1.secureRenegotiation + m.secureRenegotiation == m1.secureRenegotiation && + m.alpnProtocol == m1.alpnProtocol && + m.duplicateExtension == m1.duplicateExtension && + m.channelIDRequested == m1.channelIDRequested } func (m *serverHelloMsg) marshal() []byte { @@ -514,6 +596,17 @@ func (m *serverHelloMsg) marshal() []byte { if m.duplicateExtension { numExtensions += 2 } + if m.channelIDRequested { + numExtensions++ + } + if alpnLen := len(m.alpnProtocol); alpnLen > 0 { + if alpnLen >= 256 { + panic("invalid ALPN protocol") + } + extensionsLength += 2 + 1 + alpnLen + numExtensions++ + } + if numExtensions > 0 { extensionsLength += 4 * numExtensions length += 2 + extensionsLength @@ -581,6 +674,25 @@ func (m *serverHelloMsg) marshal() []byte { z[3] = 1 z = z[5:] } + if alpnLen := len(m.alpnProtocol); alpnLen > 0 { + z[0] = byte(extensionALPN >> 8) + z[1] = byte(extensionALPN & 0xff) + l := 2 + 1 + alpnLen + z[2] = byte(l >> 8) + z[3] = byte(l) + l -= 2 + z[4] = byte(l >> 8) + z[5] = byte(l) + l -= 1 + z[6] = byte(l) + copy(z[7:], []byte(m.alpnProtocol)) + z = z[7+alpnLen:] + } + if m.channelIDRequested { + z[0] = byte(extensionChannelID >> 8) + z[1] = byte(extensionChannelID & 0xff) + z = z[4:] + } if m.duplicateExtension { // Add a duplicate bogus extension at the beginning and end. z[0] = 0xff @@ -617,6 +729,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { m.nextProtos = nil m.ocspStapling = false m.ticketSupported = false + m.alpnProtocol = "" if len(data) == 0 { // ServerHello is optionally followed by extension data @@ -671,6 +784,27 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { return false } m.secureRenegotiation = true + case extensionALPN: + d := data[:length] + if len(d) < 3 { + return false + } + l := int(d[0])<<8 | int(d[1]) + if l != len(d)-2 { + return false + } + d = d[2:] + l = int(d[0]) + if l != len(d)-1 { + return false + } + d = d[1:] + m.alpnProtocol = string(d) + case extensionChannelID: + if length > 0 { + return false + } + m.channelIDRequested = true } data = data[length:] } @@ -1407,7 +1541,8 @@ func (m *helloVerifyRequestMsg) equal(i interface{}) bool { return false } - return m.vers == m1.vers && + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && bytes.Equal(m.cookie, m1.cookie) } @@ -1447,6 +1582,58 @@ func (m *helloVerifyRequestMsg) unmarshal(data []byte) bool { return true } +type encryptedExtensionsMsg struct { + raw []byte + channelID []byte +} + +func (m *encryptedExtensionsMsg) equal(i interface{}) bool { + m1, ok := i.(*encryptedExtensionsMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.channelID, m1.channelID) +} + +func (m *encryptedExtensionsMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + length := 2 + 2 + len(m.channelID) + + x := make([]byte, 4+length) + x[0] = typeEncryptedExtensions + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[4] = uint8(extensionChannelID >> 8) + x[5] = uint8(extensionChannelID & 0xff) + x[6] = uint8(len(m.channelID) >> 8) + x[7] = uint8(len(m.channelID) & 0xff) + copy(x[8:], m.channelID) + + return x +} + +func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { + if len(data) != 4+2+2+128 { + return false + } + m.raw = data + if (uint16(data[4])<<8)|uint16(data[5]) != extensionChannelID { + return false + } + if int(data[6])<<8|int(data[7]) != 128 { + return false + } + m.channelID = data[4+2+2:] + + return true +} + func eqUint16s(x, y []uint16) bool { if len(x) != len(y) { return false |