diff options
Diffstat (limited to 'ssl/test/runner/handshake_messages.go')
-rw-r--r-- | ssl/test/runner/handshake_messages.go | 55 |
1 files changed, 37 insertions, 18 deletions
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go index 1114a6f..12a9f3d 100644 --- a/ssl/test/runner/handshake_messages.go +++ b/ssl/test/runner/handshake_messages.go @@ -23,7 +23,7 @@ type clientHelloMsg struct { ticketSupported bool sessionTicket []uint8 signatureAndHashes []signatureAndHash - secureRenegotiation bool + secureRenegotiation []byte alpnProtocols []string duplicateExtension bool channelIDSupported bool @@ -53,7 +53,8 @@ func (m *clientHelloMsg) equal(i interface{}) bool { m.ticketSupported == m1.ticketSupported && bytes.Equal(m.sessionTicket, m1.sessionTicket) && eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) && - m.secureRenegotiation == m1.secureRenegotiation && + bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && + (m.secureRenegotiation == nil) == (m1.secureRenegotiation == nil) && eqStrings(m.alpnProtocols, m1.alpnProtocols) && m.duplicateExtension == m1.duplicateExtension && m.channelIDSupported == m1.channelIDSupported && @@ -99,8 +100,8 @@ func (m *clientHelloMsg) marshal() []byte { extensionsLength += 2 + 2*len(m.signatureAndHashes) numExtensions++ } - if m.secureRenegotiation { - extensionsLength += 1 + if m.secureRenegotiation != nil { + extensionsLength += 1 + len(m.secureRenegotiation) numExtensions++ } if m.duplicateExtension { @@ -279,12 +280,15 @@ func (m *clientHelloMsg) marshal() []byte { z = z[2:] } } - if m.secureRenegotiation { + if m.secureRenegotiation != nil { z[0] = byte(extensionRenegotiationInfo >> 8) z[1] = byte(extensionRenegotiationInfo & 0xff) z[2] = 0 - z[3] = 1 + z[3] = byte(1 + len(m.secureRenegotiation)) + z[4] = byte(len(m.secureRenegotiation)) z = z[5:] + copy(z, m.secureRenegotiation) + z = z[len(m.secureRenegotiation):] } if len(m.alpnProtocols) > 0 { z[0] = byte(extensionALPN >> 8) @@ -374,7 +378,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { for i := 0; i < numCipherSuites; i++ { m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i]) if m.cipherSuites[i] == scsvRenegotiation { - m.secureRenegotiation = true + m.secureRenegotiation = []byte{} } } data = data[2+cipherSuiteLen:] @@ -501,11 +505,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.signatureAndHashes[i].signature = d[1] d = d[2:] } - case extensionRenegotiationInfo + 1: - if length != 1 || data[0] != 0 { + case extensionRenegotiationInfo: + if length < 1 || length != int(data[0])+1 { return false } - m.secureRenegotiation = true + m.secureRenegotiation = data[1:length] case extensionALPN: if length < 2 { return false @@ -553,7 +557,7 @@ type serverHelloMsg struct { nextProtos []string ocspStapling bool ticketSupported bool - secureRenegotiation bool + secureRenegotiation []byte alpnProtocol string duplicateExtension bool channelIDRequested bool @@ -577,7 +581,8 @@ func (m *serverHelloMsg) equal(i interface{}) bool { eqStrings(m.nextProtos, m1.nextProtos) && m.ocspStapling == m1.ocspStapling && m.ticketSupported == m1.ticketSupported && - m.secureRenegotiation == m1.secureRenegotiation && + bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && + (m.secureRenegotiation == nil) == (m1.secureRenegotiation == nil) && m.alpnProtocol == m1.alpnProtocol && m.duplicateExtension == m1.duplicateExtension && m.channelIDRequested == m1.channelIDRequested && @@ -608,8 +613,8 @@ func (m *serverHelloMsg) marshal() []byte { if m.ticketSupported { numExtensions++ } - if m.secureRenegotiation { - extensionsLength += 1 + if m.secureRenegotiation != nil { + extensionsLength += 1 + len(m.secureRenegotiation) numExtensions++ } if m.duplicateExtension { @@ -689,12 +694,15 @@ func (m *serverHelloMsg) marshal() []byte { z[1] = byte(extensionSessionTicket) z = z[4:] } - if m.secureRenegotiation { + if m.secureRenegotiation != nil { z[0] = byte(extensionRenegotiationInfo >> 8) z[1] = byte(extensionRenegotiationInfo & 0xff) z[2] = 0 - z[3] = 1 + z[3] = byte(1 + len(m.secureRenegotiation)) + z[4] = byte(len(m.secureRenegotiation)) z = z[5:] + copy(z, m.secureRenegotiation) + z = z[len(m.secureRenegotiation):] } if alpnLen := len(m.alpnProtocol); alpnLen > 0 { z[0] = byte(extensionALPN >> 8) @@ -808,10 +816,10 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { } m.ticketSupported = true case extensionRenegotiationInfo: - if length != 1 || data[0] != 0 { + if length < 1 || length != int(data[0])+1 { return false } - m.secureRenegotiation = true + m.secureRenegotiation = data[1:length] case extensionALPN: d := data[:length] if len(d) < 3 { @@ -1667,6 +1675,17 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { return true } +type helloRequestMsg struct { +} + +func (*helloRequestMsg) marshal() []byte { + return []byte{typeHelloRequest, 0, 0, 0} +} + +func (*helloRequestMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + func eqUint16s(x, y []uint16) bool { if len(x) != len(y) { return false |