Skip to main content
  • Home
  • login
  • Browse the archive

    swh mirror partner logo
swh logo
SoftwareHeritage
Software
Heritage
Mirror
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

Revision dbd56c149072e656ca8d6a43a59588f3e7513da2 authored by Hubert Kario on 29 September 2021, 13:05:34 UTC, committed by GitHub on 29 September 2021, 13:05:34 UTC
Merge pull request #777 from tlsfuzzer/descriptive-ExpectAlert
ExpectAlert - add __repr__
2 parents: 09d51ea + 34b0bc2
  • Files
  • Changes
  • dbc30ee
  • /
  • tests
  • /
  • test_tlsfuzzer_messages.py
Raw File
Cook and download a directory from the Software Heritage Vault

You have requested the cooking of the directory with identifier None into a standard tar.gz archive.

Are you sure you want to continue?

Download a directory from the Software Heritage Vault

You have requested the download of the directory with identifier None as a standard tar.gz archive.

Are you sure you want to continue?

Cook and download a revision from the Software Heritage Vault

You have requested the cooking of the history heading to revision with identifier swh:1:rev:dbd56c149072e656ca8d6a43a59588f3e7513da2 into a bare git archive.

Are you sure you want to continue?

Download a revision from the Software Heritage Vault

You have requested the download of the history heading to revision with identifier swh:1:rev:dbd56c149072e656ca8d6a43a59588f3e7513da2 as a bare git archive.

Are you sure you want to continue?

Invalid email!

The provided email is not well-formed.

Download link has expired

The requested archive is no longer available for download from the Software Heritage Vault.

Do you want to cook it again?

Permalinks

To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
Select below a type of object currently browsed in order to display its associated SWHID and permalink.

  • revision
  • content
revision badge
swh:1:rev:dbd56c149072e656ca8d6a43a59588f3e7513da2
content badge Iframe embedding
swh:1:cnt:c2f66b030806d78f5b1a9cf2f588021453668e8e
test_tlsfuzzer_messages.py
# Author: Hubert Kario, (c) 2015
# Released under Gnu GPL v2.0, see LICENSE file for details

try:
    import unittest2 as unittest
except ImportError:
    import unittest

try:
    import mock
    from mock import call
except ImportError:
    import unittest.mock as mock
    from unittest.mock import call

import socket
import os

from tlsfuzzer.messages import ClientHelloGenerator, ClientKeyExchangeGenerator,\
        ChangeCipherSpecGenerator, FinishedGenerator, \
        RenegotiationInfoExtension, ResetHandshakeHashes, SetMaxRecordSize, \
        pad_handshake, truncate_handshake, Close, fuzz_message, \
        RawMessageGenerator, split_message, PopMessageFromList, \
        FlushMessageList, fuzz_mac, fuzz_padding, ApplicationDataGenerator, \
        CertificateGenerator, CertificateVerifyGenerator, CertificateRequest, \
        ResetRenegotiationInfo, fuzz_plaintext, Connect, \
        ClientMasterKeyGenerator, TCPBufferingEnable, TCPBufferingDisable, \
        TCPBufferingFlush, fuzz_encrypted_message, fuzz_pkcs1_padding, \
        CollectNonces, AlertGenerator, PlaintextMessageGenerator, \
        SetPaddingCallback, replace_plaintext, ch_cookie_handler, \
        ch_key_share_handler, SetRecordVersion, CopyVariables, \
        ResetWriteConnectionState, HeartbeatGenerator, Certificate, \
        KeyUpdateGenerator, ClearContext, RawSocketWriteGenerator
from tlsfuzzer.helpers import psk_ext_gen, psk_ext_updater, \
        psk_session_ext_gen, AutoEmptyExtension
from tlsfuzzer.runner import ConnectionState
import tlslite.messages as messages
import tlslite.messagesocket as messagesocket
import tlslite.extensions as extensions
from tlslite.handshakehashes import HandshakeHashes
from tlslite.keyexchange import KeyExchange
import tlslite.utils.keyfactory as keyfactory
from tlslite.utils.cryptomath import bytesToNumber, numberToByteArray, \
        secureHash
from tlsfuzzer.utils.ordered_dict import OrderedDict
import tlslite.constants as constants
import tlslite.defragmenter as defragmenter
from tlslite.utils.codec import Parser
from tests.mocksock import MockSocket
from tlslite.utils.keyfactory import generateRSAKey
from tlslite.utils.cryptomath import numberToByteArray
from tlslite.utils.python_rsakey import Python_RSAKey
from tlslite.utils.python_ecdsakey import Python_ECDSAKey
from tlslite.utils.python_eddsakey import Python_EdDSAKey
from tlslite.x509 import X509
from tlslite.x509certchain import X509CertChain
from ecdsa import SigningKey, Ed25519


# Test fixture: a certificate chain holding a single RSA-PSS certificate,
# parsed once at import time.
# Decoded from the PEM below:
#   - subject and issuer are both CN=localhost (self-signed);
#   - the SubjectPublicKeyInfo algorithm and the signature algorithm are
#     both RSASSA-PSS (OID 1.2.840.113549.1.1.10), parameterised with
#     SHA-256.
# NOTE(review): the validity period (2017-03-08 to 2017-04-07) has expired.
# Presumably the tests using this fixture never check certificate dates;
# TODO confirm.
rsa_pss_cert = X509CertChain([X509().parse(
    "-----BEGIN CERTIFICATE-----\n"
    "MIIDWzCCAhKgAwIBAgIJAM94DjB2Qf+GMD4GCSqGSIb3DQEBCjAxoA0wCwYJYIZI\n"
    "AWUDBAIBoRowGAYJKoZIhvcNAQEIMAsGCWCGSAFlAwQCAaIEAgIA3jAUMRIwEAYD\n"
    "VQQDDAlsb2NhbGhvc3QwHhcNMTcwMzA4MTMzMzU4WhcNMTcwNDA3MTMzMzU4WjAU\n"
    "MRIwEAYDVQQDDAlsb2NhbGhvc3QwggEgMAsGCSqGSIb3DQEBCgOCAQ8AMIIBCgKC\n"
    "AQEApq5FnZRNmtJy+WjNRB8w0ol2+IEcATrUkZpz7HNqq8+EL/GN21m35luz5fcA\n"
    "8ifkm4pKxfwtxY7u912tAfsEavUr/UoyLScXQevhr6SLXR8UO3XE6ne1F88eZNpg\n"
    "KVpocDVKRLjIqKHSSbBNkG76mGKFYyNZXm3qRxaUSHVLnN/opOGfgVK8Vbj0v2EH\n"
    "5L5dGLQLV2ZfQ75I4OGrg8wse5fhgt2oczhhe72sEuMDc08WHGRp9744mMQvVYrt\n"
    "pWUHkP5tdD4i7/zYf3nig21/G56JY20JpW+1J5fEn+PEqiJsv5prNJuPp/zdjD3u\n"
    "ImqfXZzXUeRIQr9l1qXI6JPInwIDAQABo1AwTjAdBgNVHQ4EFgQUcTYhLu7pODIv\n"
    "B6KhR6eyFBB5wacwHwYDVR0jBBgwFoAUcTYhLu7pODIvB6KhR6eyFBB5wacwDAYD\n"
    "VR0TBAUwAwEB/zA+BgkqhkiG9w0BAQowMaANMAsGCWCGSAFlAwQCAaEaMBgGCSqG\n"
    "SIb3DQEBCDALBglghkgBZQMEAgGiBAICAN4DggEBAKMgweHM6WTwlWEQHLG5K+7B\n"
    "hrAUEAsuK8F7sKGKzLEFzYdzZpkJw8LahE4dFayjx/7MD4rZ5IiHQhJcGCdHIVVv\n"
    "ocunlEUTgiKkMxTw4JxqSq0snvNBie04vnn+zUjD7FrctTUutzlH1yKftwbJpGk6\n"
    "CrTW6ctFTAIDwZHd+WX4RPewGY0LTfC+RjcMwWZBmbfVLxuJs0sidSUoNW6GgGE1\n"
    "DIDVeW2yKGeNhjK/3aDzfQWbz1J64aRfccVzXYMPsoABnNJnJgRETh1/Ci0sQ9Vd\n"
    "1OR6iS4hl88/1d7utc00MyFVk1sUIGf54EeCvrNB4bhKtawEJk8Q8AGIRhs93sk=\n"
    "-----END CERTIFICATE-----\n")])


# Test fixture: a certificate chain holding a single ECDSA certificate,
# parsed once at import time.
# Decoded from the PEM below:
#   - subject and issuer are both CN=localhost (self-signed);
#   - the public key is id-ecPublicKey on prime256v1 (secp256r1 / P-256);
#   - the signature algorithm is ecdsa-with-SHA1.
# The matching private key is ``ecdsa_key`` (its public point is identical).
# NOTE(review): the validity period (2017-10-23 to 2017-11-22) has expired.
# Presumably date checks are not exercised; TODO confirm.
ecdsa_cert = X509CertChain([X509().parse(
    "-----BEGIN CERTIFICATE-----\n"
    "MIIBbTCCARSgAwIBAgIJAPM58cskyK+yMAkGByqGSM49BAEwFDESMBAGA1UEAwwJ\n"
    "bG9jYWxob3N0MB4XDTE3MTAyMzExNDI0MVoXDTE3MTEyMjExNDI0MVowFDESMBAG\n"
    "A1UEAwwJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEyDRjEAJe\n"
    "3F5T62MyZbhjoJnPLGL2nrTthLFymBupZ2IbnWYnqVWDkT/L6i8sQhf2zCLrlSjj\n"
    "1kn7ERqPx/KZyqNQME4wHQYDVR0OBBYEFPfFTUg9o3t6ehLsschSnC8Te8oaMB8G\n"
    "A1UdIwQYMBaAFPfFTUg9o3t6ehLsschSnC8Te8oaMAwGA1UdEwQFMAMBAf8wCQYH\n"
    "KoZIzj0EAQNIADBFAiA6p0YM5ZzfW+klHPRU2r13/IfKgeRfDR3dtBngmPvxUgIh\n"
    "APTeSDeJvYWVBLzyrKTeSerNDKKHU2Rt7sufipv76+7s\n"
    "-----END CERTIFICATE-----\n"
    )])


# Test fixture: the P-256 private key belonging to ``ecdsa_cert``.
# It is stored as an unencrypted PKCS#8 "PRIVATE KEY" PEM block and parsed
# once at import time; ``private=True`` asks tlslite for the private key
# object rather than only the public part.
ecdsa_key = keyfactory.parsePEMKey(
    "-----BEGIN PRIVATE KEY-----\n"
    "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgCOZr0Ovs0eCmh+XM\n"
    "QWDYVpsQ+sJdjiq/itp/kYnWNSahRANCAATINGMQAl7cXlPrYzJluGOgmc8sYvae\n"
    "tO2EsXKYG6lnYhudZiepVYORP8vqLyxCF/bMIuuVKOPWSfsRGo/H8pnK\n"
    "-----END PRIVATE KEY-----\n",
    private=True)


class TestClose(unittest.TestCase):
    """Tests for the Close command node."""

    def test___init__(self):
        close = Close()

        self.assertIsNotNone(close)

    def test_process(self):
        """Processing Close must close the underlying socket exactly once."""
        state = ConnectionState()
        state.msg_sock = mock.MagicMock()

        close = Close()
        close.process(state)

        # ``called_once_with`` (without the ``assert_`` prefix) is just an
        # auto-created Mock attribute: it never fails the test, and on
        # Python 3.12+ unittest.mock raises AttributeError when it is
        # accessed.  Use the real assertion method instead.
        state.msg_sock.sock.close.assert_called_once_with()

class TestTCPBufferingEnable(unittest.TestCase):
    """Tests for the TCPBufferingEnable command node."""

    def test___init__(self):
        enable_node = TCPBufferingEnable()

        self.assertIsNotNone(enable_node)
        # A pure command: it neither expects nor generates messages.
        self.assertTrue(enable_node.is_command())
        self.assertFalse(enable_node.is_expect())
        self.assertFalse(enable_node.is_generator())

    @mock.patch('socket.socket')
    def test_generate(self, raw_sock):
        conn_state = ConnectionState()
        Connect('localhost', 4433).process(conn_state)

        # A freshly connected socket does not buffer writes...
        self.assertFalse(conn_state.msg_sock.sock.buffer_writes)

        TCPBufferingEnable().process(conn_state)

        # ...until buffering is explicitly switched on.
        self.assertTrue(conn_state.msg_sock.sock.buffer_writes)

class TestTCPBufferingDisable(unittest.TestCase):
    """Tests for the TCPBufferingDisable command node."""

    def test___init__(self):
        disable_node = TCPBufferingDisable()

        self.assertIsNotNone(disable_node)
        # A pure command: it neither expects nor generates messages.
        self.assertTrue(disable_node.is_command())
        self.assertFalse(disable_node.is_expect())
        self.assertFalse(disable_node.is_generator())

    @mock.patch('socket.socket')
    def test_generate(self, raw_sock):
        conn_state = ConnectionState()
        Connect('localhost', 4433).process(conn_state)

        # Buffering starts off...
        self.assertFalse(conn_state.msg_sock.sock.buffer_writes)

        # ...can be switched on...
        TCPBufferingEnable().process(conn_state)
        self.assertTrue(conn_state.msg_sock.sock.buffer_writes)

        # ...and switched back off again.
        TCPBufferingDisable().process(conn_state)
        self.assertFalse(conn_state.msg_sock.sock.buffer_writes)


class TestTCPBufferingFlush(unittest.TestCase):
    """Tests for the TCPBufferingFlush command node."""

    def test___init__(self):
        flush_node = TCPBufferingFlush()

        self.assertIsNotNone(flush_node)
        # A pure command: it neither expects nor generates messages.
        self.assertTrue(flush_node.is_command())
        self.assertFalse(flush_node.is_expect())
        self.assertFalse(flush_node.is_generator())

    @mock.patch('socket.socket')
    def test_generate(self, raw_sock):
        conn_state = ConnectionState()
        Connect('localhost', 4433).process(conn_state)
        TCPBufferingEnable().process(conn_state)

        record = RawMessageGenerator(12, bytearray(b'\xff')).generate(
            conn_state)
        conn_state.msg_sock.sendMessageBlocking(record)

        sock_instance = raw_sock.return_value
        # While buffering is enabled nothing may reach the socket.
        sock_instance.send.assert_not_called()
        sock_instance.sendall.assert_not_called()

        TCPBufferingFlush().process(conn_state)

        # The flush pushes the whole buffered record out in a single
        # sendall(): content type 12, version (3, 0), length 1, payload 0xff.
        sock_instance.sendall.assert_called_once_with(
                bytearray(b'\x0c\x03\x00\x00\x01\xff'))


class TestResetWriteConnectionState(unittest.TestCase):
    """Tests for the ResetWriteConnectionState command node."""

    def test__init__(self):
        node = ResetWriteConnectionState()

        self.assertIsNotNone(node)
        self.assertTrue(node.is_command())
        self.assertFalse(node.is_expect())
        self.assertFalse(node.is_generator())

    def test_process(self):
        """Only the write state must be reset; the read state is kept."""
        state = ConnectionState()
        # Named ``sock`` rather than ``socket`` so it does not shadow the
        # module-level ``socket`` import.
        sock = MockSocket(bytearray())

        defragger = defragmenter.Defragmenter()
        defragger.add_static_size(constants.ContentType.alert, 2)
        defragger.add_static_size(constants.ContentType.change_cipher_spec, 1)
        defragger.add_dynamic_size(constants.ContentType.handshake, 1, 3)
        state.msg_sock = messagesocket.MessageSocket(sock,
                                                     defragger)

        # Install an encrypt-then-MAC AES-128-CBC-SHA connection state
        # (from dummy all-zero inputs) in both directions.
        state.msg_sock.version = (3, 3)
        state.msg_sock.encryptThenMAC = True
        state.msg_sock.calcPendingStates(
                constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
                bytearray(48),
                bytearray(32),
                bytearray(32),
                None)
        state.msg_sock.changeWriteState()
        state.msg_sock.changeReadState()

        self.assertIsNotNone(state.msg_sock._writeState.macContext)
        self.assertIsNotNone(state.msg_sock._writeState.encContext)
        self.assertTrue(state.msg_sock._writeState.encryptThenMAC)

        node = ResetWriteConnectionState()
        node.process(state)

        # Read direction is untouched...
        self.assertIsNotNone(state.msg_sock._readState.macContext)
        self.assertIsNotNone(state.msg_sock._readState.encContext)
        self.assertTrue(state.msg_sock._readState.encryptThenMAC)

        # ...while the write direction is back to its initial (null) state.
        self.assertIsNone(state.msg_sock._writeState.macContext)
        self.assertIsNone(state.msg_sock._writeState.encContext)
        self.assertIsNone(state.msg_sock._writeState.fixedNonce)
        self.assertEqual(state.msg_sock._writeState.seqnum, 0)
        self.assertFalse(state.msg_sock._writeState.encryptThenMAC)


class TestCollectNonces(unittest.TestCase):
    """Tests for the CollectNonces command node."""

    def test__init__(self):
        collector = CollectNonces([])

        self.assertTrue(collector.is_command())

    def test_process(self):
        conn_state = ConnectionState()

        defragger = defragmenter.Defragmenter()
        defragger.add_static_size(constants.ContentType.alert, 2)
        defragger.add_static_size(constants.ContentType.change_cipher_spec, 1)
        defragger.add_dynamic_size(constants.ContentType.handshake, 1, 3)
        conn_state.msg_sock = messagesocket.MessageSocket(
            MockSocket(bytearray()), defragger)

        # Activate an AES-128-GCM write state built from dummy zero inputs.
        conn_state.msg_sock.version = (3, 3)
        conn_state.msg_sock.calcPendingStates(
            constants.CipherSuite.TLS_RSA_WITH_AES_128_GCM_SHA256,
            bytearray(48),
            bytearray(32),
            bytearray(32),
            None)
        conn_state.msg_sock.changeWriteState()

        collected = []
        CollectNonces(collected).process(conn_state)

        app_data = ApplicationDataGenerator(b'some text').generate(conn_state)
        conn_state.msg_sock.sendMessageBlocking(app_data)

        # One record sent, so exactly one nonce was recorded: a 4-byte
        # prefix followed by eight zero bytes (presumably the sequence
        # number of the first record).
        expected_nonce = bytearray(b'\xa9\xfc\x88\x1d'
                                   b'\x00\x00\x00\x00\x00\x00\x00\x00')
        self.assertEqual(collected, [expected_nonce])


class TestCopyVariables(unittest.TestCase):
    """Tests for the CopyVariables command node.

    Each test stores a value in the connection state, asks CopyVariables to
    copy it into a caller-supplied list, and checks the list contents.
    """

    def test___init__(self):
        copier = CopyVariables({})

        self.assertIsNotNone(copier)
        self.assertTrue(copier.is_command())
        self.assertFalse(copier.is_expect())
        self.assertFalse(copier.is_generator())

    def test_process_client_hello_random(self):
        conn_state = ConnectionState()
        conn_state.client_random = bytearray(b'test')
        sink = []

        CopyVariables({'ClientHello.random': sink}).process(conn_state)

        self.assertEqual(sink, [bytearray(b'test')])

    def test_process_server_hello_random(self):
        conn_state = ConnectionState()
        conn_state.server_random = bytearray(b'abba')
        sink = []

        CopyVariables({'ServerHello.random': sink}).process(conn_state)

        self.assertEqual(sink, [bytearray(b'abba')])

    def test_process_server_hello_session_id(self):
        conn_state = ConnectionState()
        conn_state.session_id = bytearray(b'some session ID')
        sink = []

        CopyVariables({'ServerHello.session_id': sink}).process(conn_state)

        self.assertEqual(sink, [bytearray(b'some session ID')])

    def test_process_premaster_secret(self):
        conn_state = ConnectionState()
        conn_state.key['premaster_secret'] = bytearray(b'premaster secret')
        sink = []

        CopyVariables({'premaster_secret': sink}).process(conn_state)

        self.assertEqual(sink, [bytearray(b'premaster secret')])

    def test_process_with_incorrect_name(self):
        conn_state = ConnectionState()
        # Unknown variable names are rejected when the node is processed.
        copier = CopyVariables({'DH shared secret': []})

        with self.assertRaises(ValueError):
            copier.process(conn_state)


class TestConnect(unittest.TestCase):
    """Tests for the Connect command node (socket creation is mocked)."""

    def test___init__(self):
        node = Connect(1, 2)

        self.assertIsNotNone(node)
        self.assertEqual(node.hostname, 1)
        self.assertEqual(node.port, 2)
        # defaults when not given explicitly
        self.assertEqual(node.version, (3, 0))
        self.assertEqual(node.timeout, 5)

    @mock.patch('socket.socket')
    def test_process(self, mock_sock):
        conn_state = ConnectionState()
        Connect("localhost", 4433).process(conn_state)

        self.assertEqual(conn_state.msg_sock.version, (3, 0))

        mock_sock.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM)
        raw = mock_sock.return_value
        raw.connect.assert_called_once_with(("localhost", 4433))
        self.assertIs(conn_state.msg_sock.sock.socket, raw)

    @mock.patch('socket.socket')
    def test_process_with_SSLv2(self, mock_sock):
        conn_state = ConnectionState()
        Connect(1, 2, (0, 2)).process(conn_state)

        # the requested record version is applied to the message socket
        self.assertEqual(conn_state.msg_sock.version, (0, 2))

        mock_sock.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM)
        raw = mock_sock.return_value
        raw.connect.assert_called_once_with((1, 2))
        self.assertIs(conn_state.msg_sock.sock.socket, raw)

    @mock.patch('socket.socket')
    def test_process_with_timeout(self, mock_sock):
        conn_state = ConnectionState()
        Connect(1, 2, timeout=10).process(conn_state)

        mock_sock.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM)
        raw = mock_sock.return_value
        raw.connect.assert_called_once_with((1, 2))
        raw.settimeout.assert_called_once_with(10)
        self.assertIs(conn_state.msg_sock.sock.socket, raw)


class TestSetRecordVersion(unittest.TestCase):
    """Tests for the SetRecordVersion command node."""

    def test___init__(self):
        node = SetRecordVersion((3, 3))

        self.assertIsNotNone(node)
        self.assertTrue(node.is_command())
        self.assertFalse(node.is_expect())
        self.assertFalse(node.is_generator())

    def test_process(self):
        conn_state = ConnectionState()
        conn_state.msg_sock = mock.MagicMock()

        SetRecordVersion((3, 2)).process(conn_state)

        # the new version is written onto the message socket
        self.assertEqual(conn_state.msg_sock.version, (3, 2))


class TestRawSocketWriteGenerator(unittest.TestCase):
    """Tests for the RawSocketWriteGenerator command node."""

    def test___init__(self):
        msg_gen = RawSocketWriteGenerator(bytearray(b'some data'))

        self.assertIsNotNone(msg_gen)
        self.assertTrue(msg_gen.is_command())
        self.assertFalse(msg_gen.is_expect())
        self.assertFalse(msg_gen.is_generator())

    def test___repr__(self):
        msg_gen = RawSocketWriteGenerator(bytearray(b'some data'))

        self.assertEqual(repr(msg_gen),
                         "RawSocketWriteGenerator("
                         "data=bytearray(b'some data'))")

    def test___repr___with_description(self):
        msg_gen = RawSocketWriteGenerator(bytearray(b'some data'),
                                          description="STARTTLS")

        self.assertEqual(repr(msg_gen),
                         "RawSocketWriteGenerator("
                         "data=bytearray(b'some data'), "
                         "description='STARTTLS')")

    def test_process(self):
        """The payload must be written verbatim to the raw socket once."""
        state = ConnectionState()
        state.msg_sock = mock.MagicMock()

        msg_gen = RawSocketWriteGenerator(b'some data')

        msg_gen.process(state)

        # The previous ``assertTrue(...send.called_once_with(...))`` checked
        # the truthiness of an auto-created Mock attribute, so it could never
        # fail (and raises AttributeError on Python 3.12+).  Use the real
        # assertion method.
        state.msg_sock._recordSocket.sock.send.assert_called_once_with(
            b'some data')


class TestPlaintextMessageGenerator(unittest.TestCase):
    """Tests for the PlaintextMessageGenerator command node."""

    def test___init__(self):
        node = PlaintextMessageGenerator(12, bytearray(b'\x00\x00'))

        self.assertIsNotNone(node)
        # Despite the name it is classified as a command, not a generator.
        self.assertTrue(node.is_command())
        self.assertFalse(node.is_expect())
        self.assertFalse(node.is_generator())

    def test___repr__(self):
        node = PlaintextMessageGenerator(12, bytearray(b'\x00\x00'))

        expected = ("PlaintextMessageGenerator(content_type=12, "
                    "data=bytearray(b'\\x00\\x00'))")
        self.assertEqual(repr(node), expected)

    def test___repr___with_description(self):
        node = PlaintextMessageGenerator(12, bytearray(b'\x00\x00'),
                                         description="some message")

        expected = ("PlaintextMessageGenerator(content_type=12, "
                    "data=bytearray(b'\\x00\\x00'), "
                    "description='some message')")
        self.assertEqual(repr(node), expected)

    def test_process(self):
        conn_state = ConnectionState()
        conn_state.msg_sock = mock.MagicMock()

        PlaintextMessageGenerator(12, bytearray(b'\x00\x00')).process(
            conn_state)

        # the data is handed directly to the record socket
        self.assertTrue(conn_state.msg_sock._recordSocket.send.called)


class TestRawMessageGenerator(unittest.TestCase):
    """Tests for the RawMessageGenerator generator node."""

    def test___init__(self):
        generator = RawMessageGenerator(12, bytearray(b'\xff\x02'))

        self.assertIsNotNone(generator)
        self.assertEqual(generator.content_type, 12)
        self.assertEqual(generator.data, bytearray(b'\xff\x02'))

    def test_gen(self):
        # generation does not need any connection state
        produced = RawMessageGenerator(12, bytearray(b'\xff\x02')).generate(
            None)

        self.assertIsNotNone(produced)
        self.assertEqual(produced.contentType, 12)
        self.assertEqual(produced.write(), bytearray(b'\xff\x02'))

    def test___repr__(self):
        generator = RawMessageGenerator(12, bytearray(b'\xff\x02'))

        expected = ("RawMessageGenerator(content_type=12, "
                    "data=bytearray(b'\\xff\\x02'))")
        self.assertEqual(repr(generator), expected)

    def test___repr___with_description(self):
        generator = RawMessageGenerator(12, bytearray(b'\xff'),
                                        description="a broken message")

        expected = ("RawMessageGenerator(content_type=12, "
                    "data=bytearray(b'\\xff'), "
                    "description='a broken message')")
        self.assertEqual(repr(generator), expected)

class TestClientHelloGenerator(unittest.TestCase):
    """Tests for ClientHelloGenerator.

    Most tests patch ``messages.ClientHello.create`` and check the exact
    arguments the generator passes to it.
    """

    def test___init__(self):
        chg = ClientHelloGenerator()

        self.assertIsNotNone(chg)
        self.assertEqual(chg.ciphers, [])

    def test___repr__(self):
        chg = ClientHelloGenerator()
        # presumably clears a non-empty default so that repr() has no
        # non-default field to print
        chg.compression = []

        self.assertEqual("ClientHelloGenerator()", repr(chg))

    def test___repr___with_values(self):
        # positional order: ciphers, extensions, version, session_id,
        # random, compression, ssl2, modifiers
        chg = ClientHelloGenerator(
            [2, 3], ["ext"], (3, 3), bytearray(b'sess'), bytearray(b'random'),
            [0xc0], True, ["mod"])

        self.assertEqual(
            "ClientHelloGenerator("
            "ssl2=True, "
            "version=(3, 3), "
            "ciphers=[2, 3], "
            "random=bytearray(b'random'), "
            "session_id=bytearray(b'sess'), "
            "compression=[192], "
            "extensions=['ext'], "
            "modifiers=['mod'])",
            repr(chg))

    def test_generate(self):
        state = ConnectionState()
        chg = ClientHelloGenerator()

        return_val = mock.MagicMock()
        return_val.write = mock.MagicMock(return_value=bytearray(10))
        with mock.patch.object(messages.ClientHello, 'create',
                               return_value=return_val) as mock_method:
            ch = chg.generate(state)

        self.assertEqual(ch, return_val)
        mock_method.assert_called_once_with((3, 3), bytearray(32), bytearray(0),
                                            [], extensions=None)

    def test_generate_extensions_with_missing_generator(self):
        state = ConnectionState()
        # a None generator for an extension type with no automatic handler
        chg = ClientHelloGenerator(extensions={0x1234: None})

        return_val = mock.MagicMock()
        return_val.write = mock.MagicMock(return_value=bytearray(10))
        with self.assertRaises(ValueError):
            with mock.patch.object(messages.ClientHello, 'create',
                                   return_value=return_val):
                chg.generate(state)

    def test_generate_extensions_with_auto_generator(self):
        state = ConnectionState()
        chg = ClientHelloGenerator(extensions={0x1234: AutoEmptyExtension()})

        return_val = mock.MagicMock()
        return_val.write = mock.MagicMock(return_value=bytearray(10))
        with mock.patch.object(messages.ClientHello, 'create',
                               return_value=return_val) as mock_method:
            ch = chg.generate(state)

        self.assertEqual(ch, return_val)
        ext = extensions.TLSExtension().create(0x1234, bytearray(0))
        mock_method.assert_called_once_with((3, 3), bytearray(32),
                                            bytearray(0), [], extensions=[ext])

    def test_generate_extensions_with_raw_extension(self):
        state = ConnectionState()
        ext = extensions.TLSExtension().create(extType=0x1234, data=None)
        chg = ClientHelloGenerator(extensions={0x1234: ext})

        return_val = mock.MagicMock()
        return_val.write = mock.MagicMock(return_value=bytearray(10))
        with mock.patch.object(messages.ClientHello, 'create',
                               return_value=return_val) as mock_method:
            ch = chg.generate(state)

        self.assertEqual(ch, return_val)
        # ready-made extension objects are passed through unchanged
        mock_method.assert_called_once_with((3, 3), bytearray(32), bytearray(0),
                                            [],
                                            extensions=[ext])

    def test_generate_extensions_with_garbage_extension(self):
        state = ConnectionState()
        ext = "some weird non-extension"
        chg = ClientHelloGenerator(extensions={0x1234: ext})

        return_val = mock.MagicMock()
        return_val.write = mock.MagicMock(return_value=bytearray(10))
        with mock.patch.object(messages.ClientHello, 'create',
                               return_value=return_val):
            with self.assertRaises(ValueError):
                chg.generate(state)

    def test_generate_extensions_with_ext_generator(self):
        state = ConnectionState()
        ext_gen = mock.MagicMock()
        chg = ClientHelloGenerator(extensions={0x1234: ext_gen})

        return_val = mock.MagicMock()
        return_val.write = mock.MagicMock(return_value=bytearray(10))
        with mock.patch.object(messages.ClientHello, 'create',
                               return_value=return_val) as mock_method:
            ch = chg.generate(state)

        self.assertEqual(ch, return_val)

        # callables are invoked with the connection state and their result
        # is used as the extension
        ext_gen.assert_called_once_with(state)
        mock_method.assert_called_once_with((3, 3), bytearray(32), bytearray(0),
                                            [],
                                            extensions=[ext_gen()])

    def test_generate_extensions_with_automatic_status_request(self):
        state = ConnectionState()
        chg = ClientHelloGenerator(extensions=
            {constants.ExtensionType.status_request: None})

        return_val = mock.MagicMock()
        return_val.write = mock.MagicMock(return_value=bytearray(10))
        with mock.patch.object(messages.ClientHello, 'create',
                               return_value=return_val) as mock_method:
            ch = chg.generate(state)

        self.assertEqual(ch, return_val)

        mock_method.assert_called_once_with(
            (3, 3), bytearray(32), bytearray(0),
            [],
            extensions=[extensions.StatusRequestExtension().create()])

    def test_generate_extensions_with_automatic_encrypt_then_mac(self):
        state = ConnectionState()
        chg = ClientHelloGenerator(extensions=
            {constants.ExtensionType.encrypt_then_mac: None})

        return_val = mock.MagicMock()
        return_val.write = mock.MagicMock(return_value=bytearray(10))
        with mock.patch.object(messages.ClientHello, 'create',
                               return_value=return_val) as mock_method:
            ch = chg.generate(state)

        self.assertEqual(ch, return_val)

        mock_method.assert_called_once_with(
            (3, 3), bytearray(32), bytearray(0),
            [],
            extensions=[extensions.TLSExtension(extType=22).create(b'')])

    def test_generate_extensions_with_renego_info_default_generator(self):
        state = ConnectionState()
        state.key['client_verify_data'] = bytearray(b'\xab\xcd')
        chg = ClientHelloGenerator(extensions={constants.ExtensionType.renegotiation_info:
                                               None})

        return_val = mock.MagicMock()
        return_val.write = mock.MagicMock(return_value=bytearray(10))
        with mock.patch.object(messages.ClientHello, 'create',
                               return_value=return_val) as mock_method:
            ch = chg.generate(state)

        self.assertEqual(ch, return_val)

        # the automatic renegotiation_info carries client_verify_data
        ext = RenegotiationInfoExtension().create(bytearray(b'\xab\xcd'))
        mock_method.assert_called_once_with((3, 3), bytearray(32), bytearray(0),
                                            [],
                                            extensions=[ext])

    def test_generate_with_random(self):
        state = ConnectionState()
        chg = ClientHelloGenerator(random=bytearray(b'\x33'*32))

        return_val = mock.MagicMock()
        return_val.write = mock.MagicMock(return_value=bytearray(10))
        with mock.patch.object(messages.ClientHello, 'create',
                               return_value=return_val) as mock_method:
            ch = chg.generate(state)

        self.assertEqual(ch, return_val)
        mock_method.assert_called_once_with((3, 3), bytearray(b'\x33'*32),
                                            bytearray(0), [], extensions=None)

    def test_generate_with_compression_methods(self):
        state = ConnectionState()
        chg = ClientHelloGenerator(compression=[0, 2, 3])

        return_val = mock.MagicMock()
        return_val.write = mock.MagicMock(return_value=bytearray(10))
        with mock.patch.object(messages.ClientHello, 'create',
                               return_value=return_val) as mock_method:
            ch = chg.generate(state)

        self.assertEqual(ch, return_val)
        # compression methods are set on the created message, not passed
        # to create()
        self.assertEqual(ch.compression_methods, [0, 2, 3])
        mock_method.assert_called_once_with((3, 3), bytearray(32),
                                            bytearray(0), [], extensions=None)

    def test_session_id_with_no_settings(self):
        state = ConnectionState()
        chg = ClientHelloGenerator(version=(3, 4))

        msg = chg.generate(state)

        self.assertEqual(msg.session_id, b'')

    def test_seesion_id_with_tls13_extension(self):
        state = ConnectionState()
        exts = {constants.ExtensionType.supported_versions:
                AutoEmptyExtension()}
        chg = ClientHelloGenerator(version=(3, 3), extensions=exts)

        msg = chg.generate(state)

        # with supported_versions present and no explicit session_id, a
        # 32-byte session ID is generated
        self.assertEqual(len(msg.session_id), 32)

    def test_session_id_with_explicit_id_and_tls13_extension(self):
        state = ConnectionState()
        exts = {constants.ExtensionType.supported_versions:
                AutoEmptyExtension()}
        chg = ClientHelloGenerator(version=(3, 3), extensions=exts,
                                   session_id=b'')

        msg = chg.generate(state)

        # an explicitly provided session_id wins over the automatic one
        self.assertEqual(msg.session_id, b'')

    def test_modifiers(self):
        state = ConnectionState()
        psk_configs = [(b'test', b'secret')]
        exts = OrderedDict()
        exts[constants.ExtensionType.supported_versions] = \
            AutoEmptyExtension()
        exts[constants.ExtensionType.pre_shared_key] = \
            psk_ext_gen(psk_configs)
        modifiers = [psk_ext_updater(psk_configs)]
        chg = ClientHelloGenerator(version=(3, 3), extensions=exts,
                                   session_id=b'\xaa'*32,
                                   modifiers=modifiers)

        msg = chg.generate(state)

        self.assertEqual(len(msg.extensions), 2)
        self.assertIsInstance(msg.extensions[1],
                              extensions.PreSharedKeyExtension)
        ext = msg.extensions[1]
        self.assertEqual(len(ext.binders), 1)
        # known-answer binder for the fixed inputs above
        self.assertEqual(
            ext.binders[0],
            bytearray(b"\x04!\xd0\xee\x0c\xe8\x13W\xa9\x85\xcc\xce\x07U\x87"
                      b"\xe6\'\xfa\xec\xf4\xf7\x88\x8b\xf3\xc2\xc3^\xf3<\x8b"
                      b"\xba%"))

    def test_with_session_tickets(self):
        state = ConnectionState()
        state.key['resumption master secret'] = bytearray(b'\x10' * 48)
        state.cipher = constants.CipherSuite.TLS_AES_256_GCM_SHA384
        state.session_tickets = [
            messages.NewSessionTicket().create(
                101, 0, bytearray(b'some nonce'),
                bytearray(b'ticket identity'), [])]
        state.session_tickets[0].time = 12.12
        exts = OrderedDict()
        exts[constants.ExtensionType.supported_versions] = \
            AutoEmptyExtension()
        exts[constants.ExtensionType.pre_shared_key] = \
            psk_session_ext_gen()
        modifiers = [psk_ext_updater()]
        chg = ClientHelloGenerator(version=(3, 3), extensions=exts,
                                   session_id=b'\xaa'*32,
                                   modifiers=modifiers)

        # freeze the clock so the generated extension is deterministic
        with mock.patch('time.time') as mthd:
            mthd.return_value = 1532009672.427739153
            msg = chg.generate(state)

        self.assertEqual(len(msg.extensions), 2)
        self.assertIsInstance(msg.extensions[1],
                              extensions.PreSharedKeyExtension)
        ext = msg.extensions[1]
        self.assertEqual(len(ext.binders), 1)
        # known-answer binder (48 bytes) for the fixed inputs above
        self.assertEqual(
            ext.binders[0],
            bytearray(b'6f&\xe42^a\x13\x949\xc3Q\xb0\x1fD\xb6H\xba\xc1=D\xa6'
                      b'\xcb1\x05\xd72\x1eKO\x8d\xf9V\xa13\xdc\x94b\xdb\xc6'
                      b'\x92\\\xe9\xd2\xd7Tv\xb2'))


class TestClientHelloExtensionGenerators(unittest.TestCase):
    """Tests for ClientHello extension handlers that echo HRR content."""

    def setUp(self):
        # A HelloRetryRequest: a ServerHello with the special HRR random,
        # carrying a cookie and a selected key_share group.
        hrr_exts = [
            extensions.CookieExtension().create(b'some payload'),
            extensions.HRRKeyShareExtension().create(
                constants.GroupName.secp256r1),
        ]
        hrr = messages.ServerHello()
        hrr.create(version=(3, 3), random=constants.TLS_1_3_HRR,
                   session_id=b'', cipher_suite=0x04,
                   extensions=hrr_exts)

        self.state = ConnectionState()
        self.state.handshake_messages.append(hrr)

    def test_ch_cookie_handler(self):
        cookie_ext = ch_cookie_handler(self.state)

        self.assertIsInstance(cookie_ext, extensions.CookieExtension)
        self.assertEqual(cookie_ext.cookie, b'some payload')

    def test_ch_cookie_handler_with_no_hrr(self):
        self.state.handshake_messages = []

        with self.assertRaises(ValueError) as ctx:
            ch_cookie_handler(self.state)

        self.assertIn("No HRR received", str(ctx.exception))

    def test_ch_key_share_handler(self):
        ks_ext = ch_key_share_handler(self.state)

        self.assertIsInstance(ks_ext, extensions.ClientKeyShareExtension)
        shares = ks_ext.client_shares
        self.assertEqual(len(shares), 1)
        self.assertIsInstance(shares[0], extensions.KeyShareEntry)
        self.assertEqual(shares[0].group, constants.GroupName.secp256r1)

    def test_ch_key_share_handler_with_no_hrr(self):
        self.state.handshake_messages = []

        with self.assertRaises(ValueError) as ctx:
            ch_key_share_handler(self.state)

        self.assertIn("No HRR received", str(ctx.exception))


class TestClientKeyExchangeGenerator(unittest.TestCase):
    """Tests for ClientKeyExchangeGenerator across RSA, DHE and ECDHE."""

    @classmethod
    def setUpClass(cls):
        # shared 1024-bit key: the RSA-encrypted premaster is 128 bytes long
        cls.priv_key = keyfactory.generateRSAKey(1024)

    def _rsa_state(self):
        """Return a fresh ConnectionState reporting priv_key as the server
        public key."""
        state = ConnectionState()
        state.get_server_public_key = lambda: self.priv_key
        return state

    def _assert_rsa_premaster(self, cke_msg, expected_tail):
        """Check that cke_msg carries a 128-byte encrypted premaster which
        decrypts to the bytes 3, 3 followed by expected_tail."""
        self.assertEqual(len(cke_msg.encryptedPreMasterSecret), 128)
        decrypt = self.priv_key.decrypt(cke_msg.encryptedPreMasterSecret)

        self.assertEqual(decrypt[:2], bytearray([3, 3]))
        self.assertEqual(decrypt[2:], expected_tail)

    @staticmethod
    def _dhe_state():
        """Return a ConnectionState that already holds a DHE
        ServerKeyExchange created with createDH(21, 2, 11)."""
        state = ConnectionState()
        ske = messages.ServerKeyExchange(
                constants.CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
                (3, 3))
        ske.createDH(21, 2, 11)
        state.handshake_messages.append(ske)
        return state

    def test___init__(self):
        cke = ClientKeyExchangeGenerator()

        self.assertEqual(len(cke.premaster_secret), 48)

    def test___init___with_invalid_param(self):
        # p and p-1 cannot both be requested as the client share
        with self.assertRaises(ValueError):
            ClientKeyExchangeGenerator(p_as_share=True, p_1_as_share=True)

    def test_generate(self):
        state = self._rsa_state()
        cke = ClientKeyExchangeGenerator(
                cipher=constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA)

        ret = cke.generate(state)

        self._assert_rsa_premaster(ret, bytearray([0]*46))

    def test_generate_with_custom_premaster_secret(self):
        state = self._rsa_state()
        cke = ClientKeyExchangeGenerator(
                cipher=constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
                premaster_secret=bytearray([1]*10))

        ret = cke.generate(state)

        # the first two bytes of the custom premaster end up as 3, 3
        self._assert_rsa_premaster(ret, bytearray([1]*8))

    def test_generate_with_node_reuse_and_no_encrypted_value_reuse(self):
        state = self._rsa_state()
        cke = ClientKeyExchangeGenerator(
                cipher=constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
                premaster_secret=bytearray([1]*10))

        ret = cke.generate(state)
        self._assert_rsa_premaster(ret, bytearray([1]*8))

        # every generate() re-encrypts, so the ciphertexts must differ
        ret2 = cke.generate(state)
        self.assertNotEqual(ret.encryptedPreMasterSecret,
                            ret2.encryptedPreMasterSecret)

    def test_generate_with_node_reuse_and_encrypted_value_reuse(self):
        state = self._rsa_state()
        cke = ClientKeyExchangeGenerator(
                cipher=constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
                premaster_secret=bytearray([1]*10),
                reuse_encrypted_premaster=True)

        ret = cke.generate(state)
        self._assert_rsa_premaster(ret, bytearray([1]*8))

        # with reuse_encrypted_premaster the first ciphertext is resent
        ret2 = cke.generate(state)
        self.assertEqual(ret.encryptedPreMasterSecret,
                         ret2.encryptedPreMasterSecret)

    def test_generate_with_dhe(self):
        state = ConnectionState()
        state.key_exchange = mock.MagicMock()

        cke = ClientKeyExchangeGenerator(
                cipher=constants.CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA)

        ret = cke.generate(state)

        self.assertIs(ret, state.key_exchange.makeClientKeyExchange())

    def test_generate_with_ecdhe(self):
        state = ConnectionState()
        state.key_exchange = mock.MagicMock()

        cke = ClientKeyExchangeGenerator(
                cipher=constants.CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA)

        ret = cke.generate(state)

        self.assertIs(ret, state.key_exchange.makeClientKeyExchange())

    def test_generate_with_unknown_cipher(self):
        state = ConnectionState()
        cke = ClientKeyExchangeGenerator()
        with self.assertRaises(AssertionError):
            cke.generate(state)

    def test_generate_DHE_with_bogus_value(self):
        state = ConnectionState()
        cke = ClientKeyExchangeGenerator(
                cipher=constants.CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
                dh_Yc=4982)

        ret = cke.generate(state)
        self.assertEqual(ret.dh_Yc, 4982)

    def test_generate_ECDHE_with_bogus_value(self):
        state = ConnectionState()
        cke = ClientKeyExchangeGenerator(
                cipher=constants.CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
                ecdh_Yc=bytearray(range(1, 24)))

        ret = cke.generate(state)
        self.assertEqual(ret.ecdh_Yc, bytearray(range(1, 24)))

    def test_generate_with_all_null_RSA(self):
        state = ConnectionState()
        cke = ClientKeyExchangeGenerator(
                cipher=constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
                encrypted_premaster=bytearray(512))

        ret = cke.generate(state)
        self.assertEqual(ret.encryptedPreMasterSecret, bytearray(512))

    def test_generate_with_modulus_as_premaster(self):
        state = self._rsa_state()
        cke = ClientKeyExchangeGenerator(
                cipher=constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
                modulus_as_encrypted_premaster=True)

        ret = cke.generate(state)
        self.assertEqual(ret.encryptedPreMasterSecret,
                         numberToByteArray(self.priv_key.n))

    def test_generate_with_p_as_share(self):
        state = self._dhe_state()
        cke = ClientKeyExchangeGenerator(
                cipher=constants.CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
                p_as_share=True)

        ret = cke.generate(state)
        self.assertEqual(ret.dh_Yc, 21)

    def test_generate_with_p_1_as_share(self):
        state = self._dhe_state()
        cke = ClientKeyExchangeGenerator(
                cipher=constants.CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
                p_1_as_share=True)

        ret = cke.generate(state)
        self.assertEqual(ret.dh_Yc, 20)

    def test_post_send(self):
        state = self._rsa_state()
        cke = ClientKeyExchangeGenerator(
                constants.CipherSuite.TLS_RSA_WITH_NULL_MD5,
                (3, 3))
        cke.generate(state)

        # smoke test: post_send() must not raise after generate()
        cke.post_send(state)

class TestChangeCipherSpecGenerator(unittest.TestCase):
    """Tests for ChangeCipherSpecGenerator."""

    def test___init__(self):
        generator = ChangeCipherSpecGenerator()

        self.assertIsNotNone(generator)

    def test_generate(self):
        msg = ChangeCipherSpecGenerator().generate(None)

        self.assertIsInstance(msg, messages.ChangeCipherSpec)

    def test_post_send(self):
        generator = ChangeCipherSpecGenerator()
        generator.generate(None)
        state = ConnectionState()
        state.msg_sock = mock.MagicMock()

        generator.post_send(state)

        sock = state.msg_sock
        self.assertTrue(sock.calcPendingStates.called)
        self.assertTrue(sock.changeWriteState.called)

    def test_post_send_with_extended_master_secret(self):
        generator = ChangeCipherSpecGenerator()
        generator.generate(None)
        state = ConnectionState()
        state.extended_master_secret = True
        state.msg_sock = mock.MagicMock()

        ems_target = 'tlsfuzzer.messages.calcExtendedMasterSecret'
        with mock.patch(ems_target) as calc_ems:
            calc_ems.return_value = bytearray(48)
            generator.post_send(state)

        calc_ems.assert_called_once_with(state.version, state.cipher,
                                         state.key['premaster_secret'],
                                         state.handshake_hashes)
        sock = state.msg_sock
        self.assertTrue(sock.calcPendingStates.called)
        self.assertTrue(sock.changeWriteState.called)

    def test_post_send_with_fake_true(self):
        # a fake CCS must not derive any keys
        generator = ChangeCipherSpecGenerator(fake=True)
        generator.generate(None)
        state = ConnectionState()
        state.msg_sock = mock.MagicMock()

        ems_target = 'tlsfuzzer.messages.calcExtendedMasterSecret'
        with mock.patch(ems_target) as calc_ems:
            calc_ems.return_value = bytearray(48)
            generator.post_send(state)

        self.assertFalse(calc_ems.called)


class TestClientMasterKeyGenerator(unittest.TestCase):
    """Tests for ClientMasterKeyGenerator (SSLv2 CLIENT-MASTER-KEY)."""

    @staticmethod
    def _mocked_state():
        """Return a ConnectionState whose record socket and server public
        key are MagicMocks."""
        state = ConnectionState()
        state.msg_sock = mock.MagicMock()
        state.get_server_public_key = mock.MagicMock()
        return state

    def test___init__(self):
        cmk = ClientMasterKeyGenerator()

        self.assertIsNotNone(cmk)

    def test_generate_with_no_cipher(self):
        cmk = ClientMasterKeyGenerator()

        with self.assertRaises(NotImplementedError):
            cmk.generate(None)

    def test_generate(self):
        cmk = ClientMasterKeyGenerator(
                cipher=constants.CipherSuite.SSL_CK_DES_192_EDE3_CBC_WITH_MD5)
        state = self._mocked_state()

        ret = cmk.generate(state)
        self.assertEqual(ret.cipher,
                         constants.CipherSuite.SSL_CK_DES_192_EDE3_CBC_WITH_MD5)
        self.assertEqual(ret.clear_key, bytearray(0))
        # MagicMock hands back the same child mock on every call, so these
        # compare against the values the generator obtained from the mocks
        self.assertEqual(ret.encrypted_key,
                         state.get_server_public_key().encrypt())
        self.assertEqual(ret.key_argument,
                         state.msg_sock.calcSSL2PendingStates())

    def test_generate_with_master_key(self):
        cmk = ClientMasterKeyGenerator(
                cipher=constants.CipherSuite.SSL_CK_DES_192_EDE3_CBC_WITH_MD5,
                master_key=bytearray(range(24)))
        state = self._mocked_state()

        cmk.generate(state)

        state.msg_sock.calcSSL2PendingStates.assert_called_once_with(
                constants.CipherSuite.SSL_CK_DES_192_EDE3_CBC_WITH_MD5,
                bytearray(range(24)),
                bytearray(0),
                bytearray(0),
                None)

    def test_generate_with_export_cipher(self):
        cmk = ClientMasterKeyGenerator(
                cipher=constants.CipherSuite.SSL_CK_RC2_128_CBC_EXPORT40_WITH_MD5)
        state = self._mocked_state()

        ret = cmk.generate(state)

        # export ciphers send 11 bytes of the key in the clear
        self.assertEqual(len(ret.clear_key), 11)

    def test_generate_with_unknown_cipher(self):
        cmk = ClientMasterKeyGenerator(cipher=0xffffff)
        state = ConnectionState()

        with self.assertRaises(AssertionError):
            cmk.generate(state)

    def test_generate_with_des_cipher(self):
        cmk = ClientMasterKeyGenerator(
                cipher=constants.CipherSuite.SSL_CK_DES_64_CBC_WITH_MD5)
        state = self._mocked_state()

        ret = cmk.generate(state)

        self.assertEqual(ret.encrypted_key,
                         state.get_server_public_key().encrypt())

    def test_generate_with_session_key(self):
        cmk = ClientMasterKeyGenerator(
                cipher=constants.CipherSuite.SSL_CK_DES_192_EDE3_CBC_WITH_MD5)
        state = self._mocked_state()
        state.key['master_secret'] = bytearray(range(32))

        cmk.generate(state)

        # the master secret already present in the state is reused
        state.msg_sock.calcSSL2PendingStates.assert_called_once_with(
                constants.CipherSuite.SSL_CK_DES_192_EDE3_CBC_WITH_MD5,
                bytearray(range(32)),
                bytearray(0),
                bytearray(0),
                None)


class TestCertificateGenerator(unittest.TestCase):
    """Tests for CertificateGenerator."""

    def _check_empty_x509(self, msg):
        """Assert msg is an x509 Certificate message without a chain."""
        self.assertIsInstance(msg, messages.Certificate)
        self.assertIsNone(msg.cert_chain)
        self.assertEqual(msg.certificateType,
                         constants.CertificateType.x509)

    def test___init__(self):
        generator = CertificateGenerator()

        self.assertIsNotNone(generator)

    def test_generate(self):
        generator = CertificateGenerator()
        conn = ConnectionState()

        self._check_empty_x509(generator.generate(conn))

    def test_generate_with_context(self):
        cert_req = messages.CertificateRequest((3, 4)).create(
            sig_algs=[constants.SignatureScheme.rsa_pss_rsae_sha512],
            context=b'zesty')
        shared_ctx = []
        generator = CertificateGenerator(context=shared_ctx)
        conn = ConnectionState()
        conn.version = (3, 3)

        conn.handshake_messages.append(cert_req)
        # the request is also handed to the generator via the context list
        shared_ctx.append(cert_req)

        msg = generator.generate(conn)

        self._check_empty_x509(msg)
        self.assertEqual(msg.certificate_request_context, b'zesty')


class TestCertificateVerifyGenerator(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """Create one 1024-bit RSA key shared by the tests in this class."""
        shared_key = generateRSAKey(1024)
        cls.priv_key = shared_key

    def test___init__(self):
        """The generator can be built with no arguments."""
        self.assertIsNotNone(CertificateVerifyGenerator())

    def test_generate_without_priv_key(self):
        """Without a private key, generate() raises ValueError."""
        generator = CertificateVerifyGenerator()
        conn = ConnectionState()

        with self.assertRaises(ValueError):
            generator.generate(conn)

    def test_generate_TLS_1_1(self):
        key = self.priv_key
        generator = CertificateVerifyGenerator(key)
        conn = ConnectionState()
        conn.version = (3, 2)

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        # TLS 1.1 signs MD5 || SHA-1 of the (here empty) transcript
        digest = secureHash(b'', 'md5') + secureHash(b'', 'sha1')
        self.assertTrue(key.verify(cert_verify.signature, digest))

    def test_generate_TLS_1_2(self):
        # RSA-PSS with SHA-512 needs a 2048-bit key
        key = generateRSAKey(2048)
        generator = CertificateVerifyGenerator(key)
        conn = ConnectionState()
        conn.version = (3, 3)

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 256)
        # with no CertificateRequest a secure default algorithm is chosen
        self.assertEqual(cert_verify.signatureAlgorithm,
                         constants.SignatureScheme.rsa_pss_rsae_sha512)
        is_valid = key.verify(cert_verify.signature,
                              secureHash(b'', 'sha512'),
                              padding="pss",
                              hashAlg="sha512",
                              saltLen=64)
        self.assertTrue(is_valid)

    def test_generate_TLS_1_2_with_explicit_msg_sig_alg(self):
        key = self.priv_key
        sha1_rsa = (constants.HashAlgorithm.sha1,
                    constants.SignatureAlgorithm.rsa)
        sha256_rsa = (constants.HashAlgorithm.sha256,
                      constants.SignatureAlgorithm.rsa)
        # msg_alg overrides the algorithm picked from the request
        generator = CertificateVerifyGenerator(key, msg_alg=sha1_rsa)
        conn = ConnectionState()
        conn.version = (3, 3)
        cert_req = CertificateRequest((3, 3)).create(
            [], [], [sha256_rsa, sha1_rsa])
        conn.handshake_messages = [cert_req]
        conn.handshake_hashes.update(cert_req.write())

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        self.assertEqual(cert_verify.signatureAlgorithm, sha1_rsa)
        self.assertTrue(key.verify(cert_verify.signature,
                                   secureHash(cert_req.write(), 'sha1'),
                                   hashAlg="sha1"))

    def test_generate_TLS_1_2_with_cert_request(self):
        key = self.priv_key
        generator = CertificateVerifyGenerator(key)
        conn = ConnectionState()
        conn.version = (3, 3)
        sha256_rsa = (constants.HashAlgorithm.sha256,
                      constants.SignatureAlgorithm.rsa)
        sha1_rsa = (constants.HashAlgorithm.sha1,
                    constants.SignatureAlgorithm.rsa)
        cert_req = CertificateRequest((3, 3)).create(
            [], [], [sha256_rsa, sha1_rsa])
        conn.handshake_messages = [cert_req]
        conn.handshake_hashes.update(cert_req.write())

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        # SHA-256 with RSA, the first entry of the request, is selected
        self.assertEqual(cert_verify.signatureAlgorithm, sha256_rsa)
        self.assertTrue(key.verify(cert_verify.signature,
                                   secureHash(cert_req.write(), 'sha256'),
                                   hashAlg="sha256"))

    def test_generate_TLS_1_3(self):
        # RSA-PSS+SHA512 requires a 2048 bit key
        key = keyfactory.generateRSAKey(2048)
        generator = CertificateVerifyGenerator(key)
        conn = ConnectionState()
        conn.version = (3, 4)

        cert_verify = generator.generate(conn)

        scheme = constants.SignatureScheme.rsa_pss_rsae_sha512
        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 256)
        self.assertEqual(cert_verify.signatureAlgorithm, scheme)

        # the trailing "sha256" is the PRF hash, not the signature hash
        signed_data = KeyExchange.calcVerifyBytes(
                (3, 4), conn.handshake_hashes, scheme,
                b'', b'', b'', "sha256")

        self.assertTrue(key.verify(cert_verify.signature, signed_data,
                                   "pss", "sha512", saltLen=64))

    def test_generate_TLS_1_3_with_cert_request(self):
        key = self.priv_key
        generator = CertificateVerifyGenerator(key)
        conn = ConnectionState()
        conn.version = (3, 4)
        pss_sha256 = constants.SignatureScheme.rsa_pss_rsae_sha256
        # the generator is expected to skip the MD5 and PKCS#1 v1.5 entries
        # and pick the RSA-PSS scheme at the end of the list
        offered = [
            (constants.HashAlgorithm.md5, constants.SignatureAlgorithm.rsa),
            constants.SignatureScheme.rsa_pkcs1_sha1,
            constants.SignatureScheme.rsa_pkcs1_sha224,
            constants.SignatureScheme.rsa_pkcs1_sha256,
            constants.SignatureScheme.rsa_pkcs1_sha384,
            constants.SignatureScheme.rsa_pkcs1_sha512,
            pss_sha256,
        ]
        cert_req = CertificateRequest((3, 4)).create([], [], offered)
        conn.handshake_messages = [cert_req]
        conn.handshake_hashes.update(cert_req.write())

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        self.assertEqual(cert_verify.signatureAlgorithm, pss_sha256)

        signed_data = KeyExchange.calcVerifyBytes(
                (3, 4), conn.handshake_hashes, pss_sha256,
                b'', b'', b'', "sha256")

        self.assertTrue(key.verify(cert_verify.signature, signed_data,
                                   "pss", "sha256", saltLen=32))

    def test_generate_TLS_1_3_in_PHA_(self):
        key = self.priv_key
        shared_ctx = []
        generator = CertificateVerifyGenerator(key, context=shared_ctx)
        conn = ConnectionState()
        conn.version = (3, 4)
        pss_sha256 = constants.SignatureScheme.rsa_pss_rsae_sha256
        offered = [
            (constants.HashAlgorithm.md5, constants.SignatureAlgorithm.rsa),
            constants.SignatureScheme.rsa_pkcs1_sha1,
            constants.SignatureScheme.rsa_pkcs1_sha224,
            constants.SignatureScheme.rsa_pkcs1_sha256,
            constants.SignatureScheme.rsa_pkcs1_sha384,
            constants.SignatureScheme.rsa_pkcs1_sha512,
            pss_sha256,
        ]
        cert_req = CertificateRequest((3, 4)).create(
            [], [], offered,
            context=b'unique PHA context, totally random')
        conn.handshake_messages = [cert_req]
        conn.handshake_hashes.update(cert_req.write())
        conn.key['client finished handshake hashes'] = HandshakeHashes()
        shared_ctx.append(cert_req)

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        self.assertEqual(cert_verify.signatureAlgorithm, pss_sha256)

        # post-handshake auth: the signed transcript is the one saved at
        # client Finished, extended with the CertificateRequest from ctx
        transcript = conn.key['client finished handshake hashes'].copy()
        transcript.update(shared_ctx[0].write())

        signed_data = KeyExchange.calcVerifyBytes(
                (3, 4), transcript, pss_sha256,
                b'', b'', b'', "sha256")

        self.assertTrue(key.verify(cert_verify.signature, signed_data,
                                   "pss", "sha256", saltLen=32))

    def test_generate_with_mismatched_alg(self):
        key = self.priv_key
        md5_rsa = (constants.HashAlgorithm.md5,
                   constants.SignatureAlgorithm.rsa)
        sha256_rsa = (constants.HashAlgorithm.sha256,
                      constants.SignatureAlgorithm.rsa)
        sha1_rsa = (constants.HashAlgorithm.sha1,
                    constants.SignatureAlgorithm.rsa)
        # sig_alg forces the hash actually used to sign
        generator = CertificateVerifyGenerator(key, sig_alg=md5_rsa)
        conn = ConnectionState()
        conn.version = (3, 3)
        cert_req = CertificateRequest((3, 3)).create(
            [], [], [sha256_rsa, sha1_rsa])
        conn.handshake_messages = [cert_req]
        conn.handshake_hashes.update(cert_req.write())

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        # the advertised algorithm still follows the request...
        self.assertEqual(cert_verify.signatureAlgorithm, sha256_rsa)

        # ...but the signature itself was made over an MD5 digest
        transcript = cert_req.write()
        self.assertFalse(key.hashAndVerify(cert_verify.signature,
                                           transcript, "pkcs1", "sha256"))
        self.assertTrue(key.hashAndVerify(cert_verify.signature,
                                          transcript, "pkcs1", "md5"))

    def test_generate_with_mismatched_CR_TLS1_2(self):
        key = generateRSAKey(2048)
        generator = CertificateVerifyGenerator(key)
        conn = ConnectionState()
        conn.version = (3, 3)
        # the server only asks for ECDSA, which an RSA key cannot provide
        ecdsa_only = [(constants.HashAlgorithm.sha256,
                       constants.SignatureAlgorithm.ecdsa)]
        cert_req = CertificateRequest((3, 3)).create([], [], ecdsa_only)
        conn.handshake_messages = [cert_req]
        conn.handshake_hashes.update(cert_req.write())

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 256)
        self.assertEqual(cert_verify.signatureAlgorithm,
                         constants.SignatureScheme.rsa_pss_rsae_sha512)
        self.assertTrue(key.hashAndVerify(cert_verify.signature,
                                          cert_req.write(),
                                          "pss", "sha512", 64))

    def test_generate_with_mismatched_CR_TLS1_3(self):
        key = generateRSAKey(2048)
        generator = CertificateVerifyGenerator(key)
        conn = ConnectionState()
        conn.version = (3, 4)
        ecdsa_only = [(constants.HashAlgorithm.sha256,
                       constants.SignatureAlgorithm.ecdsa)]
        cert_req = CertificateRequest((3, 4)).create([], [], ecdsa_only)
        conn.handshake_messages = [cert_req]
        conn.handshake_hashes.update(cert_req.write())

        cert_verify = generator.generate(conn)

        scheme = constants.SignatureScheme.rsa_pss_rsae_sha512
        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 256)
        self.assertEqual(cert_verify.signatureAlgorithm, scheme)

        # the trailing "sha256" is the PRF hash, not the signature hash
        signed_data = KeyExchange.calcVerifyBytes(
                (3, 4), conn.handshake_hashes, scheme,
                b'', b'', b'', "sha256")

        self.assertTrue(key.verify(cert_verify.signature, signed_data,
                                   "pss", "sha512", saltLen=64))

    def test_generate_with_mismatched_CR_RSA_PSS(self):
        # an RSA-PSS-only key must fall back to an rsa_pss_pss_* scheme
        key = Python_RSAKey.generate(2048, "rsa-pss")
        generator = CertificateVerifyGenerator(key)
        conn = ConnectionState()
        conn.version = (3, 3)
        ecdsa_only = [(constants.HashAlgorithm.sha256,
                       constants.SignatureAlgorithm.ecdsa)]
        cert_req = CertificateRequest((3, 3)).create([], [], ecdsa_only)
        conn.handshake_messages = [cert_req]
        conn.handshake_hashes.update(cert_req.write())

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 256)
        self.assertEqual(cert_verify.signatureAlgorithm,
                         constants.SignatureScheme.rsa_pss_pss_sha512)
        self.assertTrue(key.hashAndVerify(cert_verify.signature,
                                          cert_req.write(),
                                          "pss", "sha512", 64))

    def test_generate_with_rsa_pss_rsae_alg(self):
        key = self.priv_key
        generator = CertificateVerifyGenerator(key)
        conn = ConnectionState()
        conn.version = (3, 3)
        pss_sha256 = constants.SignatureScheme.rsa_pss_sha256
        offered = [pss_sha256,
                   (constants.HashAlgorithm.sha1,
                    constants.SignatureAlgorithm.rsa)]
        cert_req = CertificateRequest((3, 3)).create([], [], offered)
        conn.handshake_messages = [cert_req]
        conn.handshake_hashes.update(cert_req.write())

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        self.assertEqual(cert_verify.signatureAlgorithm, pss_sha256)
        self.assertTrue(key.hashAndVerify(cert_verify.signature,
                                          cert_req.write(),
                                          "pss", "sha256", 32))

    def test_generate_with_no_key(self):
        # a preset signature lets the generator work without a private key
        generator = CertificateVerifyGenerator(signature=bytearray(b'xxxx'))
        conn = ConnectionState()
        conn.version = (3, 3)
        offered = [constants.SignatureScheme.rsa_pss_rsae_sha256,
                   constants.SignatureScheme.rsa_pss_pss_sha256,
                   (constants.HashAlgorithm.sha1,
                    constants.SignatureAlgorithm.rsa)]
        cert_req = CertificateRequest((3, 3)).create([], [], offered)
        conn.handshake_messages = [cert_req]
        conn.handshake_hashes.update(cert_req.write())

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(cert_verify.signature, bytearray(b'xxxx'))
        self.assertEqual(cert_verify.signatureAlgorithm,
                         constants.SignatureScheme.rsa_pss_rsae_sha256)

    def test_generate_with_rsa_pss_pss_alg(self):
        key = Python_RSAKey.generate(1024, "rsa-pss")
        generator = CertificateVerifyGenerator(key)
        conn = ConnectionState()
        conn.version = (3, 3)
        pss_pss_sha256 = constants.SignatureScheme.rsa_pss_pss_sha256
        offered = [constants.SignatureScheme.rsa_pss_rsae_sha256,
                   pss_pss_sha256,
                   (constants.HashAlgorithm.sha1,
                    constants.SignatureAlgorithm.rsa)]
        cert_req = CertificateRequest((3, 3)).create([], [], offered)
        conn.handshake_messages = [cert_req]
        conn.handshake_hashes.update(cert_req.write())

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        # an RSA-PSS key picks the rsa_pss_pss_* entry, not the rsae one
        self.assertEqual(cert_verify.signatureAlgorithm, pss_pss_sha256)
        self.assertTrue(key.hashAndVerify(cert_verify.signature,
                                          cert_req.write(),
                                          "pss", "sha256", 32))

    def test_generate_with_rsa_with_subs(self):
        key = self.priv_key
        # overwrite byte 4 of the PKCS#1 padding with zero
        generator = CertificateVerifyGenerator(key, padding_subs={4: 0x00})
        conn = ConnectionState()
        conn.version = (3, 3)
        cert_req = CertificateRequest((3, 3)).create(
            [], [], [(constants.HashAlgorithm.sha1,
                      constants.SignatureAlgorithm.rsa)])
        conn.handshake_messages = [cert_req]

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        raw = key._rawPublicKeyOp(bytesToNumber(cert_verify.signature))
        decoded = numberToByteArray(raw, 128)
        self.assertEqual(decoded[4], 0x00)
        # the same (empty) transcript is signed on every run, so the last
        # byte of the embedded digest is always the same
        self.assertEqual(decoded[-1], 0x09)
        self.assertEqual(cert_verify.signatureAlgorithm,
                         constants.SignatureScheme.rsa_pkcs1_sha1)

        self.assertFalse(key.hashAndVerify(cert_verify.signature, b'',
                                           "pkcs1", "sha1"))

    def test_generate_with_rsa_with_large_xors(self):
        # XOR-ing the top bit can push the padded value above the modulus;
        # signing must still succeed, even though the signature is invalid
        key = generateRSAKey(1020)
        generator = CertificateVerifyGenerator(key, padding_xors={0: 0x80})
        conn = ConnectionState()
        conn.version = (3, 3)
        cert_req = CertificateRequest((3, 3)).create(
            [], [], [(constants.HashAlgorithm.sha1,
                      constants.SignatureAlgorithm.rsa)])
        conn.handshake_messages = [cert_req]

        cert_verify = generator.generate(conn)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        raw = key._rawPublicKeyOp(bytesToNumber(cert_verify.signature))
        decoded = numberToByteArray(raw, 128)
        self.assertEqual(decoded[0] & 0x80, 0x00)
        self.assertEqual(cert_verify.signatureAlgorithm,
                         constants.SignatureScheme.rsa_pkcs1_sha1)

        self.assertFalse(key.hashAndVerify(cert_verify.signature, b'',
                                           "pkcs1", "sha1"))

    def test_generate_with_rsa_with_xors(self):
        # XOR one byte of the PKCS#1 v1.5 padding and check the change is
        # visible after a raw public-key operation on the signature
        key = self.priv_key
        generator = CertificateVerifyGenerator(key, padding_xors={4: 0xff})
        offered = [(constants.HashAlgorithm.sha1,
                    constants.SignatureAlgorithm.rsa)]
        state = ConnectionState()
        state.version = (3, 3)
        state.handshake_messages = [
            CertificateRequest((3, 3)).create([], [], offered)]

        cert_verify = generator.generate(state)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        decrypted = numberToByteArray(
            key._rawPublicKeyOp(bytesToNumber(cert_verify.signature)), 128)
        # byte 4 was XORed with 0xff, the test expects it to read 0x00
        self.assertEqual(decrypted[4], 0x00)
        # the same set of messages is signed every time, so the hash
        # embedded in the signature (and its last byte) is constant
        self.assertEqual(decrypted[-1], 0x09)
        self.assertEqual(cert_verify.signatureAlgorithm,
                         constants.SignatureScheme.rsa_pkcs1_sha1)

        self.assertFalse(key.hashAndVerify(
            cert_verify.signature, b'', "pkcs1", "sha1"))

    def test_generate_with_rsa_pss_with_subs(self):
        # overwrite two bytes of the PSS-encoded message and check that the
        # substitutions are visible after a raw public-key operation
        key = self.priv_key
        generator = CertificateVerifyGenerator(
            key, padding_subs={2: 0xff, -1: 0xff})
        offered = [constants.SignatureScheme.rsa_pss_sha256,
                   (constants.HashAlgorithm.sha1,
                    constants.SignatureAlgorithm.rsa)]
        state = ConnectionState()
        state.version = (3, 3)
        state.handshake_messages = [
            CertificateRequest((3, 3)).create([], [], offered)]

        cert_verify = generator.generate(state)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        decrypted = numberToByteArray(
            key._rawPublicKeyOp(bytesToNumber(cert_verify.signature)), 128)
        self.assertEqual(decrypted[2], 0xff)
        self.assertEqual(decrypted[-1], 0xff)
        self.assertEqual(cert_verify.signatureAlgorithm,
                         constants.SignatureScheme.rsa_pss_sha256)

        self.assertFalse(key.hashAndVerify(
            cert_verify.signature, b'', "pss", "sha256", 32))

    def test_generate_with_rsa_pss_with_xors(self):
        # XOR the last byte of the PSS-encoded message and check the result
        # after a raw public-key operation on the signature
        key = self.priv_key
        generator = CertificateVerifyGenerator(key, padding_xors={-1: 0xff})
        offered = [constants.SignatureScheme.rsa_pss_sha256,
                   (constants.HashAlgorithm.sha1,
                    constants.SignatureAlgorithm.rsa)]
        state = ConnectionState()
        state.version = (3, 3)
        state.handshake_messages = [
            CertificateRequest((3, 3)).create([], [], offered)]

        cert_verify = generator.generate(state)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        decrypted = numberToByteArray(
            key._rawPublicKeyOp(bytesToNumber(cert_verify.signature)), 128)
        # a PSS-encoded message ends with the 0xbc trailer byte, which was
        # XORed with 0xff here
        self.assertEqual(decrypted[-1], 0xbc ^ 0xff)
        self.assertEqual(cert_verify.signatureAlgorithm,
                         constants.SignatureScheme.rsa_pss_sha256)

        self.assertFalse(key.hashAndVerify(
            cert_verify.signature, b'', "pss", "sha256", 32))

    def test_generate_with_tls1_0_version(self):
        # produce a TLS 1.0 style signature inside a TLS 1.2 connection
        key = self.priv_key
        generator = CertificateVerifyGenerator(key, sig_version=(3, 1))
        offered = [constants.SignatureScheme.rsa_pss_rsae_sha256,
                   constants.SignatureScheme.rsa_pkcs1_sha256,
                   (constants.HashAlgorithm.sha1,
                    constants.SignatureAlgorithm.rsa)]
        state = ConnectionState()
        state.version = (3, 3)
        state.handshake_messages = [
            CertificateRequest((3, 3)).create([], [], offered)]

        cert_verify = generator.generate(state)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        # the advertised algorithm should be the one closest to what a
        # TLS 1.0 signature really is: rsa with sha1 (the last 20 bytes of
        # the decrypted signature then match)
        self.assertEqual(cert_verify.signatureAlgorithm,
                         (constants.HashAlgorithm.sha1,
                          constants.SignatureAlgorithm.rsa))

        # not a valid PKCS#1 sha1 signature...
        self.assertFalse(key.hashAndVerify(
            cert_verify.signature, b'', "pkcs1", "sha1"))
        # ...but a valid raw signature over MD5 || SHA-1 of the empty input
        md5_sha1 = secureHash(b"", "md5") + secureHash(b"", "sha1")
        self.assertTrue(key.verify(cert_verify.signature, md5_sha1))

    def test_generate_with_ssl3_0_version(self):
        # produce an SSL 3.0 style signature inside a TLS 1.2 connection
        key = self.priv_key
        generator = CertificateVerifyGenerator(key, sig_version=(3, 0))
        offered = [constants.SignatureScheme.rsa_pss_rsae_sha256,
                   constants.SignatureScheme.rsa_pkcs1_sha256,
                   (constants.HashAlgorithm.sha1,
                    constants.SignatureAlgorithm.rsa)]
        state = ConnectionState()
        state.version = (3, 3)
        state.handshake_messages = [
            CertificateRequest((3, 3)).create([], [], offered)]

        cert_verify = generator.generate(state)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        # as with TLS 1.0, advertise the algorithm closest to the ones the
        # signature actually uses
        self.assertEqual(cert_verify.signatureAlgorithm,
                         (constants.HashAlgorithm.sha1,
                          constants.SignatureAlgorithm.rsa))

        self.assertFalse(key.hashAndVerify(
            cert_verify.signature, b'', "pkcs1", "sha1"))
        ssl3_bytes = KeyExchange.calcVerifyBytes(
            (3, 0), state.handshake_hashes, None, b'', b'', b'')
        self.assertTrue(key.verify(cert_verify.signature, ssl3_bytes))

    def test_generate_with_empty_signature(self):
        # an explicitly provided (here: empty) signature needs no key and is
        # placed in the message as-is
        generator = CertificateVerifyGenerator(signature=bytearray())
        state = ConnectionState()
        state.version = (3, 3)

        cert_verify = generator.generate(state)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(cert_verify.signature, bytearray())

    def test_generate_with_mismatched_mgf1(self):
        # sign with an MGF1 hash (sha512) different from the one implied by
        # the advertised signature scheme (sha256)
        key = self.priv_key
        generator = CertificateVerifyGenerator(key, sig_version=(3, 4),
                                               mgf1_hash="sha512")
        scheme = constants.SignatureScheme.rsa_pss_rsae_sha256
        state = ConnectionState()
        state.version = (3, 4)
        state.handshake_messages = [
            CertificateRequest((3, 4)).create([], [], [scheme])]

        cert_verify = generator.generate(state)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        self.assertEqual(cert_verify.signatureAlgorithm, scheme)

        signed_bytes = KeyExchange.calcVerifyBytes(
            (3, 4), state.handshake_hashes, scheme, b'', b'', b'', "sha256")

        # accepted when verified with the hash the generator was given...
        self.assertTrue(key.verify(
            cert_verify.signature, signed_bytes, "pss", "sha512", 32))
        # ...rejected with the one the advertised scheme implies
        self.assertFalse(key.hashAndVerify(
            cert_verify.signature, signed_bytes, "pss", "sha256", 32))

    def test_generate_with_mismatched_mgf1_and_salt_len(self):
        # sign with both an MGF1 hash (sha384) and a salt length (48) that
        # differ from the advertised rsa_pss_rsae_sha256 scheme
        key = self.priv_key
        generator = CertificateVerifyGenerator(key, sig_version=(3, 4),
                                               mgf1_hash="sha384",
                                               rsa_pss_salt_len=48)
        scheme = constants.SignatureScheme.rsa_pss_rsae_sha256
        state = ConnectionState()
        state.version = (3, 4)
        state.handshake_messages = [
            CertificateRequest((3, 4)).create([], [], [scheme])]

        cert_verify = generator.generate(state)

        self.assertIsNotNone(cert_verify)
        self.assertEqual(len(cert_verify.signature), 128)
        self.assertEqual(cert_verify.signatureAlgorithm, scheme)

        signed_bytes = KeyExchange.calcVerifyBytes(
            (3, 4), state.handshake_hashes, scheme, b'', b'', b'', "sha256")

        # accepted with the configured hash and salt length...
        self.assertTrue(key.verify(
            cert_verify.signature, signed_bytes, "pss", "sha384", 48))
        # ...rejected with the parameters the advertised scheme implies
        self.assertFalse(key.hashAndVerify(
            cert_verify.signature, signed_bytes, "pss", "sha256", 32))

    def test_generate_with_key_and_certificate(self):
        """Select a scheme matching an RSA-PSS key and certificate.

        The CertificateRequest lists rsa_pss_rsae_sha256 first, but the
        client key is generated as an "rsa-pss" key and the Certificate
        message carries ``rsa_pss_cert``, so the generator is expected to
        pick the first rsa_pss_pss_* scheme offered instead.
        """
        priv_key = Python_RSAKey.generate(1024, "rsa-pss")
        cert_ver_g = CertificateVerifyGenerator(priv_key)
        state = ConnectionState()
        state.version = (3, 3)
        cert = Certificate(constants.CertificateType.x509, (3, 3))\
            .create(rsa_pss_cert)
        req = CertificateRequest((3, 3)).create([], [],
            [constants.SignatureScheme.rsa_pss_rsae_sha256,
             constants.SignatureScheme.rsa_pss_pss_sha256,
             constants.SignatureScheme.rsa_pss_pss_sha384])

        state.handshake_messages = [req, cert]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertEqual(len(msg.signature), 128)
        # a CR is present: the rsae scheme listed first is expected to be
        # skipped for this RSA-PSS key/certificate, and of the
        # rsa_pss_pss_* schemes the sha256 one is listed first
        self.assertEqual(msg.signatureAlgorithm,
                         constants.SignatureScheme.rsa_pss_pss_sha256)

        self.assertTrue(priv_key.verify(
            msg.signature,
            secureHash(b"", "sha256"),
            "pss", "sha256", 32))


class TestCertificateVerifyGeneratorECDSA(unittest.TestCase):
    """CertificateVerifyGenerator tests using ECDSA client keys.

    Every test method needs a unique name: a second ``def`` with the same
    name silently replaces the first one in the class body, and the
    replaced test is then never collected or run.
    """

    @classmethod
    def setUpClass(cls):
        cls.priv_key = Python_ECDSAKey(None, None, "NIST256p", 12)

    def test_generate_with_ecdsa_and_no_cert_req_in_tls1_3(self):
        priv_key = self.priv_key
        cert_ver_g = CertificateVerifyGenerator(priv_key)
        state = ConnectionState()
        state.version = (3, 4)

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        self.assertEqual(msg.signatureAlgorithm,
                         (constants.HashAlgorithm.sha256,
                          constants.SignatureAlgorithm.ecdsa))

        verif_bytes = KeyExchange.calcVerifyBytes(
                (3, 4),
                state.handshake_hashes,
                constants.SignatureScheme.ecdsa_secp256r1_sha256,
                b'',
                b'',
                b'',
                "sha256")

        self.assertTrue(priv_key.verify(
            msg.signature, verif_bytes,
            "", "sha256"))

    def test_generate_with_ecdsa_cert_and_no_key_in_tls1_3(self):
        cert_ver_g = CertificateVerifyGenerator(signature=b"ABCDEF")
        state = ConnectionState()
        state.version = (3, 4)
        cert = messages.Certificate(constants.CertificateType.x509, (3, 4))\
            .create(ecdsa_cert)
        state.handshake_messages = [cert]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertEqual(msg.signature, bytearray(b'ABCDEF'))
        self.assertEqual(msg.signatureAlgorithm,
                         (constants.HashAlgorithm.sha256,
                          constants.SignatureAlgorithm.ecdsa))

    def test_generate_with_ecdsa_cert_and_no_key_in_tls1_2(self):
        cert_ver_g = CertificateVerifyGenerator(signature=b"ABCDEF")
        state = ConnectionState()
        state.version = (3, 3)
        cert = messages.Certificate(constants.CertificateType.x509, (3, 3))\
            .create(ecdsa_cert)
        state.handshake_messages = [cert]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertEqual(msg.signature, bytearray(b'ABCDEF'))
        self.assertEqual(msg.signatureAlgorithm,
                         (constants.HashAlgorithm.sha512,
                          constants.SignatureAlgorithm.ecdsa))

    def test_generate_with_ecdsa_cert_and_no_key_unmatched_CR_in_tls1_2(self):
        cert_ver_g = CertificateVerifyGenerator(signature=b"ABCDEF")
        state = ConnectionState()
        state.version = (3, 3)
        cert = messages.Certificate(constants.CertificateType.x509, (3, 3))\
            .create(ecdsa_cert)
        cert_req = messages.CertificateRequest((3, 3)).create([], [],
            [constants.SignatureScheme.rsa_pss_pss_sha256])
        state.handshake_messages = [cert_req, cert]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertEqual(msg.signature, bytearray(b'ABCDEF'))
        self.assertEqual(msg.signatureAlgorithm,
                         (constants.HashAlgorithm.sha512,
                          constants.SignatureAlgorithm.ecdsa))

    def test_generate_with_ecdsa_in_TLS1_0(self):
        priv_key = self.priv_key
        cert_ver_g = CertificateVerifyGenerator(priv_key)
        state = ConnectionState()
        state.version = (3, 1)
        req = CertificateRequest((3, 1)).create()
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        # TLS 1.0 CertificateVerify messages carry no signature algorithm;
        # the ECDSA signature is made over a SHA-1 hash
        self.assertEqual(msg.signatureAlgorithm, None)

        self.assertTrue(priv_key.verify(
            msg.signature, secureHash(b"", "sha1"),
            "", "sha1"))

    def test_generate_with_ecdsa_with_TLS1_0_format(self):
        priv_key = self.priv_key
        cert_ver_g = CertificateVerifyGenerator(priv_key, sig_version=(3, 1))
        state = ConnectionState()
        state.version = (3, 3)
        req = messages.CertificateRequest((3, 3)).create([], [],
            [constants.SignatureScheme.ecdsa_secp256r1_sha256,
             (constants.HashAlgorithm.sha1,
              constants.SignatureAlgorithm.ecdsa)])
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        # a TLS 1.0 format signature uses SHA-1, so the SHA-1 ECDSA entry
        # of the CR is selected even though ecdsa_secp256r1_sha256 comes
        # first
        self.assertEqual(msg.signatureAlgorithm,
                         (constants.HashAlgorithm.sha1,
                          constants.SignatureAlgorithm.ecdsa))

        self.assertTrue(priv_key.verify(
            msg.signature, secureHash(b"", "sha1"),
            "", "sha1"))

    def test_generate_with_ecdsa_256_alg(self):
        priv_key = self.priv_key
        cert_ver_g = CertificateVerifyGenerator(priv_key)
        state = ConnectionState()
        state.version = (3, 3)
        req = CertificateRequest((3, 3)).create([], [],
            [constants.SignatureScheme.ecdsa_secp521r1_sha512,
             constants.SignatureScheme.ecdsa_secp384r1_sha384,
             constants.SignatureScheme.ecdsa_secp256r1_sha256,
             constants.SignatureScheme.rsa_pss_pss_sha256,
             (constants.HashAlgorithm.sha1,
              constants.SignatureAlgorithm.rsa)])
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        # the SignatureScheme.ecdsa_secp*_sha* ignore the curve type in
        # TLS 1.2, so the negotiation will always select SHA512 algorithm
        self.assertEqual(msg.signatureAlgorithm,
                         (constants.HashAlgorithm.sha512,
                          constants.SignatureAlgorithm.ecdsa))

        self.assertTrue(priv_key.verify(
            msg.signature, secureHash(b"", "sha512")[:32],
            "", "sha512"))

    def test_generate_with_ecdsa_256_alg_and_non_matching_CR_tls1_3(self):
        priv_key = self.priv_key
        cert_ver_g = CertificateVerifyGenerator(priv_key)
        state = ConnectionState()
        state.version = (3, 4)
        req = CertificateRequest((3, 4)).create([], [],
            [constants.SignatureScheme.rsa_pss_pss_sha256,
             (constants.HashAlgorithm.sha1,
              constants.SignatureAlgorithm.rsa)])
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        self.assertEqual(msg.signatureAlgorithm,
                         (constants.HashAlgorithm.sha256,
                          constants.SignatureAlgorithm.ecdsa))

        verif_bytes = KeyExchange.calcVerifyBytes(
                (3, 4),
                state.handshake_hashes,
                constants.SignatureScheme.ecdsa_secp256r1_sha256,
                b'',
                b'',
                b'',
                "sha256")

        self.assertTrue(priv_key.verify(
            msg.signature, verif_bytes,
            "", "sha256"))

    def test_generate_with_ecdsa_256_alg_and_non_matching_CR_tls1_2(self):
        priv_key = self.priv_key
        cert_ver_g = CertificateVerifyGenerator(priv_key)
        state = ConnectionState()
        state.version = (3, 3)
        req = CertificateRequest((3, 3)).create([], [],
            [constants.SignatureScheme.rsa_pss_pss_sha256,
             (constants.HashAlgorithm.sha1,
              constants.SignatureAlgorithm.rsa)])
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        self.assertEqual(msg.signatureAlgorithm,
                         (constants.HashAlgorithm.sha512,
                          constants.SignatureAlgorithm.ecdsa))

        self.assertTrue(priv_key.verify(
            msg.signature, secureHash(b"", "sha512")[:32],
            "", "sha512"))

    def test_generate_with_ecdsa_384_alg(self):
        priv_key = Python_ECDSAKey(None, None, "NIST384p", 11)
        cert_ver_g = CertificateVerifyGenerator(priv_key)
        state = ConnectionState()
        state.version = (3, 3)
        req = CertificateRequest((3, 3)).create([], [],
            [constants.SignatureScheme.ecdsa_secp521r1_sha512,
             constants.SignatureScheme.ecdsa_secp384r1_sha384,
             constants.SignatureScheme.ecdsa_secp256r1_sha256,
             constants.SignatureScheme.rsa_pss_pss_sha256,
             (constants.HashAlgorithm.sha1,
              constants.SignatureAlgorithm.rsa)])
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        # the SignatureScheme.ecdsa_secp*_sha* ignore the curve type in
        # TLS 1.2, so the negotiation will always select SHA512 algorithm
        self.assertEqual(msg.signatureAlgorithm,
                         (constants.HashAlgorithm.sha512,
                          constants.SignatureAlgorithm.ecdsa))

        self.assertTrue(priv_key.verify(
            msg.signature, secureHash(b"", "sha512")[:48],
            "", "sha512"))

    def test_generate_with_ecdsa_521_alg(self):
        priv_key = Python_ECDSAKey(None, None, "NIST521p", 10)
        cert_ver_g = CertificateVerifyGenerator(priv_key)
        state = ConnectionState()
        state.version = (3, 3)
        req = CertificateRequest((3, 3)).create([], [],
            [constants.SignatureScheme.ecdsa_secp521r1_sha512,
             constants.SignatureScheme.ecdsa_secp384r1_sha384,
             constants.SignatureScheme.ecdsa_secp256r1_sha256,
             constants.SignatureScheme.rsa_pss_pss_sha256,
             (constants.HashAlgorithm.sha1,
              constants.SignatureAlgorithm.rsa)])
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        # the SignatureScheme.ecdsa_secp*_sha* ignore the curve type in
        # TLS 1.2, so the negotiation will always select SHA512 algorithm
        self.assertEqual(msg.signatureAlgorithm,
                         (constants.HashAlgorithm.sha512,
                          constants.SignatureAlgorithm.ecdsa))

        self.assertTrue(priv_key.verify(
            msg.signature, secureHash(b"", "sha512"),
            "", "sha512"))

    def test_generate_with_ecdsa_256_alg_in_tls1_3(self):
        priv_key = self.priv_key
        cert_ver_g = CertificateVerifyGenerator(priv_key)
        state = ConnectionState()
        state.version = (3, 4)
        req = CertificateRequest((3, 4)).create([], [],
            [constants.SignatureScheme.ecdsa_secp521r1_sha512,
             constants.SignatureScheme.ecdsa_secp384r1_sha384,
             constants.SignatureScheme.ecdsa_secp256r1_sha256,
             constants.SignatureScheme.rsa_pss_pss_sha256])
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        self.assertEqual(msg.signatureAlgorithm,
                         constants.SignatureScheme.ecdsa_secp256r1_sha256)

        verif_bytes = KeyExchange.calcVerifyBytes(
                (3, 4),
                state.handshake_hashes,
                constants.SignatureScheme.ecdsa_secp256r1_sha256,
                b'',
                b'',
                b'',
                "sha256")

        self.assertTrue(priv_key.verify(
            msg.signature, verif_bytes,
            "", "sha256"))

    def test_generate_with_ecdsa_384_alg_in_tls1_3(self):
        # was also named test_generate_with_ecdsa_384_alg, which shadowed
        # the TLS 1.2 variant above so that it never ran
        priv_key = Python_ECDSAKey(None, None, "NIST384p", 11)
        cert_ver_g = CertificateVerifyGenerator(priv_key)
        state = ConnectionState()
        state.version = (3, 4)
        req = CertificateRequest((3, 4)).create([], [],
            [constants.SignatureScheme.ecdsa_secp521r1_sha512,
             constants.SignatureScheme.ecdsa_secp384r1_sha384,
             constants.SignatureScheme.ecdsa_secp256r1_sha256,
             constants.SignatureScheme.rsa_pss_pss_sha256])
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        self.assertEqual(msg.signatureAlgorithm,
                         constants.SignatureScheme.ecdsa_secp384r1_sha384)

        verif_bytes = KeyExchange.calcVerifyBytes(
                (3, 4),
                state.handshake_hashes,
                constants.SignatureScheme.ecdsa_secp384r1_sha384,
                b'',
                b'',
                b'',
                # PRF, not hash of signature
                "sha256")

        self.assertTrue(priv_key.verify(
            msg.signature, verif_bytes,
            "", "sha384"))

    def test_generate_with_ecdsa_521_alg_in_tls1_3(self):
        # was also named test_generate_with_ecdsa_521_alg, which shadowed
        # the TLS 1.2 variant above so that it never ran
        priv_key = Python_ECDSAKey(None, None, "NIST521p", 10)
        cert_ver_g = CertificateVerifyGenerator(priv_key)
        state = ConnectionState()
        state.version = (3, 4)
        req = CertificateRequest((3, 4)).create([], [],
            [constants.SignatureScheme.ecdsa_secp521r1_sha512,
             constants.SignatureScheme.ecdsa_secp384r1_sha384,
             constants.SignatureScheme.ecdsa_secp256r1_sha256,
             constants.SignatureScheme.rsa_pss_pss_sha256])
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        self.assertEqual(msg.signatureAlgorithm,
                         constants.SignatureScheme.ecdsa_secp521r1_sha512)

        verif_bytes = KeyExchange.calcVerifyBytes(
                (3, 4),
                state.handshake_hashes,
                constants.SignatureScheme.ecdsa_secp521r1_sha512,
                b'',
                b'',
                b'',
                # PRF, not hash of signature
                "sha256")

        self.assertTrue(priv_key.verify(
            msg.signature, verif_bytes,
            "", "sha512"))

    def test_generate_with_xors(self):
        priv_key = self.priv_key
        cert_ver_g = CertificateVerifyGenerator(priv_key,
                                                padding_xors={0: 0xff})
        state = ConnectionState()
        state.version = (3, 3)
        req = CertificateRequest((3, 3)).create([], [],
            [constants.SignatureScheme.ecdsa_secp521r1_sha512,
             constants.SignatureScheme.ecdsa_secp384r1_sha384,
             constants.SignatureScheme.ecdsa_secp256r1_sha256,
             constants.SignatureScheme.rsa_pss_pss_sha256,
             (constants.HashAlgorithm.sha1,
              constants.SignatureAlgorithm.rsa)])
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        # the SignatureScheme.ecdsa_secp*_sha* ignore the curve type in
        # TLS 1.2, so the negotiation will always select SHA512 algorithm
        self.assertEqual(msg.signatureAlgorithm,
                         (constants.HashAlgorithm.sha512,
                          constants.SignatureAlgorithm.ecdsa))
        # the DER-encoded signature starts with 0x30 (SEQUENCE) before the
        # XOR is applied
        self.assertEqual(msg.signature[0] ^ 0xff, 0x30)

        sig = msg.signature
        self.assertFalse(priv_key.verify(
            sig, secureHash(b"", "sha512")[:32],
            "", "sha512"))

        # undoing the XOR makes the signature valid again
        sig[0] ^= 0xff
        self.assertTrue(priv_key.verify(
            sig, secureHash(b"", "sha512")[:32],
            "", "sha512"))

    def test_generate_with_xors_in_tls1_0(self):
        priv_key = self.priv_key
        cert_ver_g = CertificateVerifyGenerator(priv_key,
                                                padding_xors={0: 0xff})
        state = ConnectionState()
        state.version = (3, 1)
        req = CertificateRequest((3, 1)).create()
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        # no signature algorithm is sent in TLS 1.0
        self.assertIsNone(msg.signatureAlgorithm)
        self.assertEqual(msg.signature[0] ^ 0xff, 0x30)

        sig = msg.signature
        self.assertFalse(priv_key.verify(
            sig, secureHash(b"", "sha1"),
            "", "sha1"))

        sig[0] ^= 0xff
        self.assertTrue(priv_key.verify(
            sig, secureHash(b"", "sha1"),
            "", "sha1"))

    def test_generate_with_subs(self):
        priv_key = self.priv_key
        cert_ver_g = CertificateVerifyGenerator(priv_key,
                                                padding_subs={0: 0xff})
        state = ConnectionState()
        state.version = (3, 3)
        req = CertificateRequest((3, 3)).create([], [],
            [constants.SignatureScheme.ecdsa_secp521r1_sha512,
             constants.SignatureScheme.ecdsa_secp384r1_sha384,
             constants.SignatureScheme.ecdsa_secp256r1_sha256,
             constants.SignatureScheme.rsa_pss_pss_sha256,
             (constants.HashAlgorithm.sha1,
              constants.SignatureAlgorithm.rsa)])
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        # the SignatureScheme.ecdsa_secp*_sha* ignore the curve type in
        # TLS 1.2, so the negotiation will always select SHA512 algorithm
        self.assertEqual(msg.signatureAlgorithm,
                         (constants.HashAlgorithm.sha512,
                          constants.SignatureAlgorithm.ecdsa))
        self.assertEqual(msg.signature[0], 0xff)

        sig = msg.signature
        self.assertFalse(priv_key.verify(
            sig, secureHash(b"", "sha512")[:32],
            "", "sha512"))

        # restoring the DER SEQUENCE tag makes the signature valid again
        sig[0] = 0x30
        self.assertTrue(priv_key.verify(
            sig, secureHash(b"", "sha512")[:32],
            "", "sha512"))

    def test_generate_with_subs_in_TLS1_0(self):
        priv_key = self.priv_key
        cert_ver_g = CertificateVerifyGenerator(priv_key,
                                                padding_subs={0: 0xff})
        state = ConnectionState()
        state.version = (3, 1)
        req = CertificateRequest((3, 1)).create()
        state.handshake_messages = [req]

        msg = cert_ver_g.generate(state)

        self.assertIsNotNone(msg)
        self.assertTrue(msg.signature)
        # no signature algorithm is sent in TLS 1.0
        self.assertEqual(msg.signatureAlgorithm, None)

        sig = msg.signature
        self.assertFalse(priv_key.verify(
            sig, secureHash(b"", "sha1"),
            "", "sha1"))

        sig[0] = 0x30
        self.assertTrue(priv_key.verify(
            sig, secureHash(b"", "sha1"),
            "", "sha1"))

class TestCertificateVerifyGeneratorEdDSA(unittest.TestCase):
    """CertificateVerifyGenerator tests using an Ed25519 client key."""

    @classmethod
    def setUpClass(cls):
        cls.priv_key = Python_EdDSAKey(None, SigningKey.generate(Ed25519))

    def test_generate_with_eddsa_and_no_cert_req_in_tls1_3(self):
        key = self.priv_key
        generator = CertificateVerifyGenerator(key)
        state = ConnectionState()
        state.version = (3, 4)

        cert_verify = generator.generate(state)

        self.assertIsNotNone(cert_verify)
        self.assertTrue(cert_verify.signature)
        self.assertEqual(cert_verify.signatureAlgorithm,
                         constants.SignatureScheme.ed25519)

        tls13_bytes = KeyExchange.calcVerifyBytes(
            (3, 4), state.handshake_hashes,
            constants.SignatureScheme.ed25519,
            b'', b'', b'', "sha256")

        self.assertTrue(key.hashAndVerify(
            cert_verify.signature, tls13_bytes, "", "sha256"))

    def test_generate_with_xors(self):
        key = self.priv_key
        generator = CertificateVerifyGenerator(key, padding_xors={0: 0xff})
        offered = [constants.SignatureScheme.ed25519]
        state = ConnectionState()
        state.version = (3, 3)
        state.handshake_messages = [
            CertificateRequest((3, 3)).create([], [], offered)]

        cert_verify = generator.generate(state)

        self.assertIsNotNone(cert_verify)
        self.assertTrue(cert_verify.signature)
        self.assertEqual(cert_verify.signatureAlgorithm,
                         constants.SignatureScheme.ed25519)

        # the XORed signature must be rejected...
        signature = cert_verify.signature
        self.assertFalse(key.hashAndVerify(signature, b""))

        # ...and accepted again once the XOR is undone in place
        signature[0] ^= 0xff
        self.assertTrue(key.hashAndVerify(signature, b""))


class TestClearContext(unittest.TestCase):
    """Tests for the ClearContext command node."""

    def test___init__(self):
        context = []
        node = ClearContext(context)

        # the very same list object is kept, and the node is a command only
        self.assertIs(node.context, context)
        self.assertTrue(node.is_command())
        self.assertFalse(node.is_expect())
        self.assertFalse(node.is_generator())

    def test_process(self):
        context = ["abc", 123]
        node = ClearContext(context)
        # constructing the node leaves the list untouched
        self.assertEqual(context, ["abc", 123])

        node.process(None)

        # processing empties the caller's list in place
        self.assertEqual(context, [])


class TestAlertGenerator(unittest.TestCase):
    """Tests for AlertGenerator."""

    def test_default_settings(self):
        gen = AlertGenerator()
        self.assertIsNotNone(gen)

        alert = gen.generate(ConnectionState())

        # with no arguments a warning-level close_notify is produced
        self.assertEqual(alert.level, constants.AlertLevel.warning)
        self.assertEqual(alert.description,
                         constants.AlertDescription.close_notify)

    def test___init___with_parameters(self):
        level = constants.AlertLevel.fatal
        description = constants.AlertDescription.decode_error
        gen = AlertGenerator(level, description)
        self.assertIsNotNone(gen)

        alert = gen.generate(ConnectionState())

        self.assertEqual(alert.level, level)
        self.assertEqual(alert.description, description)


class TestFinishedGenerator(unittest.TestCase):
    """Tests for FinishedGenerator."""

    @staticmethod
    def _tls13_state():
        """Return a TLS 1.3 ConnectionState for Finished generation.

        The state uses TLS_AES_128_GCM_SHA256, a mocked message socket and
        an all-zero client handshake traffic secret.
        """
        state = ConnectionState()
        state.msg_sock = mock.MagicMock()
        state.cipher = constants.CipherSuite.TLS_AES_128_GCM_SHA256
        state.version = (3, 4)
        state.key['client handshake traffic secret'] = bytearray(32)
        return state

    def test___init__(self):
        fg = FinishedGenerator()

        self.assertIsNotNone(fg)

    def test_generate(self):
        fg = FinishedGenerator()
        state = ConnectionState()

        ret = fg.generate(state)

        self.assertIsInstance(ret, messages.Finished)

    def test_post_send(self):
        fg = FinishedGenerator()
        state = ConnectionState()

        ret = fg.generate(state)

        self.assertNotIn(ret, state.handshake_messages)

        fg.post_send(state)

        self.assertIn(ret, state.handshake_messages)

    def test_post_send_with_resumption(self):
        fg = FinishedGenerator()
        state = ConnectionState()
        state.resuming = True

        # the generated message is not inspected here; generate() is
        # called first to mirror the normal generate/send/post_send order
        fg.generate(state)
        fg.post_send(state)

        self.assertFalse(state.resuming)

    def test_generate_with_ssl2(self):
        fg = FinishedGenerator((0, 2))
        state = ConnectionState()
        state.msg_sock = mock.MagicMock()
        state.session_id = bytearray(b'abba')

        ret = fg.generate(state)

        self.assertEqual(ret.verify_data, bytearray(b'abba'))
        state.msg_sock.changeWriteState.assert_called_once_with()
        state.msg_sock.changeReadState.assert_called_once_with()

    def test_generate_in_tls13(self):
        fg = FinishedGenerator((3, 4))
        state = self._tls13_state()

        ret = fg.generate(state)

        self.assertEqual(ret.verify_data, bytearray(
            b'\x14\xa5e\xa67\xfe\xa3(\xd3\xac\x95\xecX\xb7\xc0\xd4u\xef'
            b'\xb3V\x8f\xc7[\xcdD\xc8\xa4\x86\xcf\xd3\xc9\x0c'))

        state.key['handshake secret'] = bytearray(32)
        state.key['master secret'] = bytearray(32)

        fg.post_send(state)

        state.msg_sock.changeWriteState.assert_called_once_with()

        self.assertEqual(
            state.key['resumption master secret'],
            bytearray(b'\x89\xd8\x00l c$\x01\x0f\xd9j\x16\xa3\xbaV\xfesT\x8b'
                      b'\xc6\xeb\x0f~\r\xbd\xb3R\xeb\xd5\x08\xa7\xbd'))

    def test_generate_in_tls13_with_pha(self):
        ctx = []

        fg = FinishedGenerator((3, 4), context=ctx)

        cv = messages.CertificateVerify((3, 4)).create(
            b'abcdef', constants.SignatureScheme.rsa_pss_rsae_sha256)
        ctx.append(cv)

        state = self._tls13_state()
        state.key['client application traffic secret'] = bytearray(b'x'*32)
        state.key['client finished handshake hashes'] = HandshakeHashes()

        ret = fg.generate(state)

        self.assertEqual(
            ret.verify_data,
            bytearray(b'q\xf1l\x05\x94\xb8"\xb2L7\xce\xd5\xb3\x00\xa6\r\x17*'
                      b'\xcc\xe7\xdc\xa6\xf0c\xd7\x90I\x11}\xbfq:'))
        state.key['handshake secret'] = bytearray(32)
        state.key['master secret'] = bytearray(32)

        fg.post_send(state)

        # with a CertificateVerify in the context the write state must not
        # be switched...
        self.assertFalse(state.msg_sock.changeWriteState.called)
        # ...and no resumption master secret may be derived
        self.assertNotIn('resumption master secret', state.key)

    def test_generate_in_tls13_with_truncation(self):
        fg = FinishedGenerator((3, 4), trunc_start=2, trunc_end=-2)
        state = self._tls13_state()

        ret = fg.generate(state)

        # same verify_data as in test_generate_in_tls13 minus the first
        # and last two bytes
        self.assertEqual(ret.verify_data, bytearray(
            b'e\xa67\xfe\xa3(\xd3\xac\x95\xecX\xb7\xc0\xd4u\xef'
            b'\xb3V\x8f\xc7[\xcdD\xc8\xa4\x86\xcf\xd3'))

    def test_generate_in_tls13_with_padding(self):
        fg = FinishedGenerator((3, 4), pad_byte=0, pad_left=1, pad_right=1)
        state = self._tls13_state()

        ret = fg.generate(state)

        # same verify_data as in test_generate_in_tls13 with one zero byte
        # added on each side
        self.assertEqual(ret.verify_data, bytearray(
            b'\x00\x14\xa5e\xa67\xfe\xa3(\xd3\xac\x95\xecX\xb7\xc0\xd4u\xef'
            b'\xb3V\x8f\xc7[\xcdD\xc8\xa4\x86\xcf\xd3\xc9\x0c\x00'))


class TestKeyUpdateGenerator(unittest.TestCase):
    """Tests for KeyUpdateGenerator."""

    def test_default_settings(self):
        gen = KeyUpdateGenerator()
        self.assertIsNotNone(gen)

        key_update = gen.generate(ConnectionState())

        # by default update_not_requested is sent
        self.assertEqual(key_update.message_type,
                         constants.KeyUpdateMessageType.update_not_requested)

    def test___init___with_parameters(self):
        requested = constants.KeyUpdateMessageType.update_requested
        gen = KeyUpdateGenerator(requested)
        self.assertIsNotNone(gen)

        key_update = gen.generate(ConnectionState())

        self.assertEqual(key_update.message_type, requested)

    def test_post_send(self):
        gen = KeyUpdateGenerator()

        client_secret = mock.Mock()
        server_secret = mock.Mock()
        state = ConnectionState()
        state.msg_sock = mock.MagicMock()
        state.cipher = constants.CipherSuite.TLS_AES_128_GCM_SHA256
        state.key['client application traffic secret'] = client_secret
        state.key['server application traffic secret'] = server_secret
        state.msg_sock.calcTLS1_3KeyUpdate_reciever.return_value = (
            bytearray(32), bytearray(32))

        gen.post_send(state)

        # post_send must call the socket helper with the current secrets
        state.msg_sock.calcTLS1_3KeyUpdate_reciever.assert_called_once_with(
            state.cipher, client_secret, server_secret)


class TestResetHandshakeHashes(unittest.TestCase):
    """Tests for the ResetHandshakeHashes node."""

    def test___init__(self):
        self.assertIsNotNone(ResetHandshakeHashes())

    def test_process(self):
        node = ResetHandshakeHashes()
        state = ConnectionState()
        old_hashes = state.handshake_hashes

        self.assertIs(old_hashes, state.handshake_hashes)

        node.process(state)

        # a different object must have replaced the previous hashes
        self.assertIsNot(old_hashes, state.handshake_hashes)

class TestResetRenegotiationInfo(unittest.TestCase):
    """Tests for the ResetRenegotiationInfo node."""

    def test___init__(self):
        self.assertIsNotNone(ResetRenegotiationInfo())

    def test_process(self):
        node = ResetRenegotiationInfo()
        state = ConnectionState()
        state.key['client_verify_data'] = bytearray(b'\xde\xad\xc0\xde')
        state.key['server_verify_data'] = bytearray(b'\xc0\xff\xee')

        node.process(state)

        # both verify_data values must be reset to empty
        for name in ('client_verify_data', 'server_verify_data'):
            self.assertEqual(state.key[name], bytearray(0))

class TestSetMaxRecordSize(unittest.TestCase):
    """Tests for the SetMaxRecordSize node."""

    def test___init__(self):
        self.assertIsNotNone(SetMaxRecordSize())

    def test_process(self):
        node = SetMaxRecordSize()
        state = ConnectionState()
        state.msg_sock = mock.MagicMock()
        state.msg_sock.recordSize = 1024

        node.process(state)

        # without an explicit size the record size is set to 2**14
        self.assertEqual(2**14, state.msg_sock.recordSize)

    def test_process_with_size(self):
        node = SetMaxRecordSize(2048)
        state = ConnectionState()
        state.msg_sock = mock.MagicMock()

        node.process(state)

        self.assertEqual(2048, state.msg_sock.recordSize)


class TestSetPaddingCallback(unittest.TestCase):
    """Tests for SetPaddingCallback and its bundled callback factories."""

    @staticmethod
    def _state_after(node):
        """Process *node* on a fresh state with a mocked message socket."""
        state = ConnectionState()
        state.msg_sock = mock.MagicMock()
        node.process(state)
        return state

    def test___init__(self):
        self.assertIsNotNone(SetPaddingCallback())

    def test_process_fixed_len_padding(self):
        state = self._state_after(
            SetPaddingCallback(SetPaddingCallback.fixed_length_cb(42)))

        padding = state.msg_sock.padding_cb(
            13, constants.ContentType.application_data, 2**14 - 1)
        self.assertEqual(42, padding)

    def test_process_fill_padding(self):
        state = self._state_after(
            SetPaddingCallback(SetPaddingCallback.fill_padding_cb))

        padding = state.msg_sock.padding_cb(
            13, constants.ContentType.application_data, 2**14 - 1)
        # fills all the remaining space: max_padding (2**14 - 1) minus
        # the 13 data bytes
        self.assertEqual(2**14 - 13 - 1, padding)

    def test_process_custom_callback(self):

        def _my_cb(length, contenttype, max_padding):
            return 1337

        state = self._state_after(SetPaddingCallback(_my_cb))

        padding = state.msg_sock.padding_cb(
            13, constants.ContentType.application_data, 2**14 - 1)
        self.assertEqual(1337, padding)

    def test_unset_padding_callback(self):
        state = self._state_after(
            SetPaddingCallback(SetPaddingCallback.fixed_length_cb(16)))

        padding = state.msg_sock.padding_cb(
            13, constants.ContentType.application_data, 2**14 - 1)
        self.assertEqual(16, padding)

        # a node created without a callback removes the installed one
        SetPaddingCallback().process(state)

        self.assertIsNone(state.msg_sock.padding_cb)

    def test_with_padding_larger_than_possible(self):
        state = self._state_after(
            SetPaddingCallback(SetPaddingCallback.fixed_length_cb(42)))

        # 42 bytes of padding do not fit when max_padding is 32
        with self.assertRaises(ValueError):
            state.msg_sock.padding_cb(
                20, constants.ContentType.application_data, 32)


class TestRenegotiationInfoExtension(unittest.TestCase):
    """Tests for RenegotiationInfoExtension serialisation and parsing."""

    def test___init__(self):
        self.assertIsNotNone(RenegotiationInfoExtension())

    def test_write(self):
        expected = bytearray(
            b'\xff\x01'         # extension type
            b'\x00\x00')        # overall extension length

        self.assertEqual(RenegotiationInfoExtension().write(), expected)

    def test_write_with_data(self):
        ext = RenegotiationInfoExtension()
        ext.create(bytearray(b'\xab\xcd'))

        expected = bytearray(
            b'\xff\x01'         # extension type
            b'\x00\x03'         # overall extension length
            b'\x02'             # payload length
            b'\xab\xcd')        # payload
        self.assertEqual(ext.write(), expected)

    def test_parse(self):
        parser = Parser(bytearray(b'\x02\xab\xcd'))
        ext = RenegotiationInfoExtension()

        ext.parse(parser)

        self.assertEqual(bytearray(b'\xab\xcd'), ext.renegotiated_connection)

class TestHandshakePadding(unittest.TestCase):
    """Tests for the pad_handshake and truncate_handshake wrappers."""

    def setUp(self):
        self.state = ConnectionState()
        self.hello_gen = ClientHelloGenerator()

        self.vanilla_hello = self.hello_gen.generate(self.state).write()

    def _encode(self, generator):
        """Return the bytes of the message *generator* creates."""
        return generator.generate(self.state).write()

    def test_no_option(self):
        self.assertEqual(len(self.vanilla_hello), 43)

        unmodified_hello = self._encode(pad_handshake(ClientHelloGenerator()))

        self.assertEqual(len(unmodified_hello), 43)
        self.assertEqual(self.vanilla_hello, unmodified_hello)

    def test_add_padding(self):
        padded_hello = self._encode(pad_handshake(ClientHelloGenerator(), 1))

        self.assertEqual(len(padded_hello), 44)

        # the first 4 bytes (handshake type and length) differ; the rest
        # is the original body followed by one zero byte
        self.assertEqual(self.vanilla_hello[4:] + bytearray(1),
                         padded_hello[4:])

    def test_add_specific_padding(self):
        padded_hello = self._encode(
            pad_handshake(ClientHelloGenerator(), 2, 0xab))

        self.assertEqual(len(padded_hello), 45)

        # the first 4 bytes (handshake type and length) differ; the rest
        # is the original body followed by the requested padding bytes
        self.assertEqual(self.vanilla_hello[4:] + bytearray(b'\xab\xab'),
                         padded_hello[4:])

    def test_pad_with_data(self):
        pad = bytearray(b'\xff\x01\x00\x01\x00')

        padded_hello = self._encode(
            pad_handshake(ClientHelloGenerator(), pad=pad))

        self.assertEqual(len(padded_hello), len(self.vanilla_hello) + len(pad))
        self.assertEqual(self.vanilla_hello[4:] + pad, padded_hello[4:])
        # the header has to reflect the new length
        self.assertNotEqual(self.vanilla_hello[:4], padded_hello[:4])

    def test_truncate(self):
        truncated_hello = self._encode(
            truncate_handshake(ClientHelloGenerator(), 1))

        self.assertEqual(len(truncated_hello), 42)

        # the first 4 bytes (handshake type and length) differ; the rest
        # is the original body without its last byte
        self.assertEqual(self.vanilla_hello[4:-1], truncated_hello[4:])

class TestFuzzMessage(unittest.TestCase):
    """Tests for the fuzz_message wrapper."""

    def setUp(self):
        self.state = ConnectionState()
        self.hello_gen = ClientHelloGenerator()

        self.vanilla_hello = self.hello_gen.generate(self.state).write()

    def _encode(self, generator):
        """Return the bytes of the message *generator* creates."""
        return generator.generate(self.state).write()

    def test_no_options(self):
        self.assertEqual(len(self.vanilla_hello), 43)

        unmodified_hello = self._encode(fuzz_message(ClientHelloGenerator()))

        self.assertEqual(len(unmodified_hello), 43)
        self.assertEqual(self.vanilla_hello, unmodified_hello)

    def test_substitutions(self):
        modified_hello = self._encode(
            fuzz_message(ClientHelloGenerator(), substitutions={4: 0xff}))

        self.assertNotEqual(self.vanilla_hello, modified_hello)

        # apply the same substitution by hand to the reference encoding
        self.vanilla_hello[4] = 0xff
        self.assertEqual(self.vanilla_hello, modified_hello)

    def test_xors(self):
        modified_hello = self._encode(
            fuzz_message(ClientHelloGenerator(), xors={4: 0xff}))

        self.assertNotEqual(self.vanilla_hello, modified_hello)

        # apply the same XOR by hand to the reference encoding
        self.vanilla_hello[4] ^= 0xff
        self.assertEqual(self.vanilla_hello, modified_hello)

class TestFuzzMAC(unittest.TestCase):
    """Tests for fuzz_mac on a TLS 1.2 TLS_RSA_WITH_NULL_MD5 connection.

    With the NULL cipher the record payload is the ClientHello followed by
    the 16-byte MD5 HMAC, so MAC modifications are directly visible in the
    bytes written to the socket.
    """

    def setUp(self):
        self.state = ConnectionState()

        self.socket = MockSocket(bytearray())

        defragger = defragmenter.Defragmenter()
        defragger.add_static_size(constants.ContentType.alert, 2)
        defragger.add_static_size(constants.ContentType.change_cipher_spec, 1)
        defragger.add_dynamic_size(constants.ContentType.handshake, 1, 3)
        self.state.msg_sock = messagesocket.MessageSocket(self.socket,
                                                          defragger)

        self.state.msg_sock.version = (3, 3)
        self.state.msg_sock.calcPendingStates(
            constants.CipherSuite.TLS_RSA_WITH_NULL_MD5,
            bytearray(48),
            bytearray(32),
            bytearray(32),
            None)
        self.state.msg_sock.changeWriteState()

        self.expected_value = bytearray(
            b"\x16"         # content type
            b"\x03\x03"     # record layer protocol version
            b"\x00\x3b"     # record layer record length
            b"\x01"         # handshake message type
            b"\x00\x00\x27" # handshake protocol message length
            b"\x03\x03"     # client hello protocol version
            # random
            b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
            b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
            b"\x00"         # length of session_id
            b"\x00\x00"     # cipher_suites length
            b"\x01\x00"     # compression_methods (0 - uncompressed)
            # 128 bit MD5 HMAC value
            b"\x1cK \xce\xb3\x1d\x94\x0b\x0f\x9a\'\x9c\x87\x1a-`"
            )

        self.second_write = bytearray(self.expected_value[:-16])
        self.second_write += bytearray(
                # MD5 HMAC with sequence number of "2"
                b"\x84\xcb\\\xf2A\x0c\xd3<u\xf3\xce\x8dk\xa0\xd8/")

    def test_no_options(self):
        hello_gen = fuzz_mac(ClientHelloGenerator())

        unmodified_hello = hello_gen.generate(self.state)
        self.assertEqual(len(unmodified_hello.write()), 43)

        self.state.msg_sock.sendMessageBlocking(unmodified_hello)

        self.assertEqual(len(self.socket.sent), 1)
        self.assertEqual(self.socket.sent[0], self.expected_value)

    def test_xor_last_byte(self):
        hello_gen = fuzz_mac(ClientHelloGenerator(), xors={-1: 0xff})

        modified_hello = hello_gen.generate(self.state)
        self.assertEqual(len(modified_hello.write()), 43)

        self.state.msg_sock.sendMessageBlocking(modified_hello)

        self.assertEqual(len(self.socket.sent), 1)
        self.expected_value[-1] ^= 0xff
        self.assertEqual(self.socket.sent[0], self.expected_value)

    def test_xor_first_byte(self):
        hello_gen = fuzz_mac(ClientHelloGenerator(), xors={0: 0xff})

        modified_hello = hello_gen.generate(self.state)
        self.assertEqual(len(modified_hello.write()), 43)

        self.state.msg_sock.sendMessageBlocking(modified_hello)

        self.assertEqual(len(self.socket.sent), 1)
        # MD5 is 16 bytes long
        self.expected_value[-16] ^= 0xff
        self.assertEqual(self.socket.sent[0], self.expected_value)

    def test_substitute_first_byte(self):
        # index 0 of the MAC is the 16th byte from the end of the record
        hello_gen = fuzz_mac(ClientHelloGenerator(), substitutions={0: 0xff})

        modified_hello = hello_gen.generate(self.state)
        self.assertEqual(len(modified_hello.write()), 43)

        self.state.msg_sock.sendMessageBlocking(modified_hello)

        self.assertEqual(len(self.socket.sent), 1)
        # MD5 is 16 bytes long
        self.expected_value[-16] = 0xff
        self.assertEqual(self.socket.sent[0], self.expected_value)

    def test_substitute_last_byte(self):
        hello_gen = fuzz_mac(ClientHelloGenerator(), substitutions={-1: 0xff})

        modified_hello = hello_gen.generate(self.state)
        self.assertEqual(len(modified_hello.write()), 43)

        self.state.msg_sock.sendMessageBlocking(modified_hello)

        self.assertEqual(len(self.socket.sent), 1)
        # the last byte of the MAC is the last byte of the record
        self.expected_value[-1] = 0xff
        self.assertEqual(self.socket.sent[0], self.expected_value)

    def test_post_send_no_options(self):
        hello_gen = fuzz_mac(ClientHelloGenerator())

        unmodified_hello = hello_gen.generate(self.state)
        self.assertEqual(len(unmodified_hello.write()), 43)

        self.state.msg_sock.sendMessageBlocking(unmodified_hello)

        self.assertEqual(len(self.socket.sent), 1)
        self.assertEqual(self.socket.sent[0], self.expected_value)

        hello_gen.post_send(self.state)

        self.state.msg_sock.sendMessageBlocking(unmodified_hello)

        self.assertEqual(len(self.socket.sent), 2)
        self.assertEqual(self.socket.sent[1], self.second_write)

    def test_post_send_xor_last_byte(self):
        hello_gen = fuzz_mac(ClientHelloGenerator(), xors={-1: 0xff})

        modified_hello = hello_gen.generate(self.state)
        self.assertEqual(len(modified_hello.write()), 43)

        self.state.msg_sock.sendMessageBlocking(modified_hello)

        self.assertEqual(len(self.socket.sent), 1)
        self.expected_value[-1] ^= 0xff
        self.assertEqual(self.socket.sent[0], self.expected_value)

        hello_gen.post_send(self.state)

        self.state.msg_sock.sendMessageBlocking(modified_hello)

        # after post_send the MAC of the next record is no longer modified
        self.assertEqual(len(self.socket.sent), 2)
        self.assertEqual(self.socket.sent[1], self.second_write)

class TestFuzzEncryptedMessage(unittest.TestCase):
    """Tests for fuzz_encrypted_message on a TLS 1.0 encrypt-then-MAC
    TLS_RSA_WITH_AES_128_CBC_SHA connection with all-zero keys."""

    def setUp(self):
        # replace os.urandom with an all-zero byte source so that any
        # randomness drawn while encrypting is fixed and the expected
        # ciphertext below is reproducible
        patcher = mock.patch.object(os,
                                    'urandom',
                                    lambda x: bytearray(x))
        patcher.start()

        self.addCleanup(patcher.stop)
        self.state = ConnectionState()

        self.socket = MockSocket(bytearray())

        defragger = defragmenter.Defragmenter()
        defragger.add_static_size(constants.ContentType.alert, 2)
        defragger.add_static_size(constants.ContentType.change_cipher_spec, 1)
        defragger.add_dynamic_size(constants.ContentType.handshake, 1, 3)
        self.state.msg_sock = messagesocket.MessageSocket(self.socket,
                                                          defragger)

        self.state.msg_sock.version = (3, 1)
        self.state.msg_sock.encryptThenMAC = True
        self.state.msg_sock.calcPendingStates(
                constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
                bytearray(48), # master secret
                bytearray(32), # client random
                bytearray(32), # server random
                None)

        self.state.msg_sock.changeWriteState()
        self.expected_value = bytearray(
            b'\x17' +           # application data
            b'\x03\x01' +       # TLS version
            b'\x00\x24' +       # length - 1 block + 20 bytes of MAC
            b'\xc7\xd6\xaf:.MY\x80W\x81\xd2|5A#\xd5' +
            b'X\xcd\xdc\'o\xb3I\xdd-\xfc\tneq~\x0f' +
            b'd\xdb\xbdw'
            )

    def test_no_changes(self):
        node = ApplicationDataGenerator(bytearray(b'test'))
        node = fuzz_encrypted_message(node)
        msg = node.generate(self.state)
        self.state.msg_sock.sendMessageBlocking(msg)
        self.assertEqual(len(self.socket.sent), 1)
        self.assertEqual(self.socket.sent[0], self.expected_value)

    def test_xor_last_byte(self):
        node = ApplicationDataGenerator(bytearray(b'test'))
        node = fuzz_encrypted_message(node, xors={-1: 0xff})
        msg = node.generate(self.state)
        self.state.msg_sock.sendMessageBlocking(msg)
        # the last byte of the record (part of the MAC) is flipped
        self.expected_value[-1] ^= 0xff
        self.assertEqual(len(self.socket.sent), 1)
        self.assertEqual(self.socket.sent[0], self.expected_value)


class TestFuzzPadding(unittest.TestCase):
    """Tests for fuzz_padding on an SSL 3.0 TLS_RSA_WITH_AES_128_CBC_SHA
    connection with all-zero keys."""

    def setUp(self):
        self.state = ConnectionState()
        self.socket = MockSocket(bytearray())

        defragger = defragmenter.Defragmenter()
        defragger.add_static_size(constants.ContentType.alert, 2)
        defragger.add_static_size(constants.ContentType.change_cipher_spec, 1)
        defragger.add_dynamic_size(constants.ContentType.handshake, 1, 3)
        self.state.msg_sock = messagesocket.MessageSocket(self.socket,
                                                          defragger)
        self.state.msg_sock.version = (3, 0)
        self.state.msg_sock.calcPendingStates(
            constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
            bytearray(48),
            bytearray(32),
            bytearray(32),
            None)
        self.state.msg_sock.changeWriteState()

    @staticmethod
    def _record_len(payload_len, padding_len):
        """Return the expected on-the-wire size of one CBC record."""
        return (1 +             # record layer type
                2 +             # protocol version
                2 +             # record payload length field
                payload_len +   # length of the plaintext message
                160 // 8 +      # length of HMAC
                1 +             # size of length tag of padding
                padding_len)    # padding bytes

    def _send_fuzzed_hello(self, hello_gen):
        """Generate a ClientHello with *hello_gen*, check its size, send it."""
        hello = hello_gen.generate(self.state)
        self.assertEqual(len(hello.write()), 43)
        self.state.msg_sock.sendMessageBlocking(hello)

    def test_no_options(self):
        self._send_fuzzed_hello(fuzz_padding(ClientHelloGenerator()))

        self.assertEqual(len(self.socket.sent), 1)
        self.assertEqual(len(self.socket.sent[0]), self._record_len(43, 0))

    def test_min_length(self):
        self._send_fuzzed_hello(fuzz_padding(ClientHelloGenerator(),
                                             min_length=0))

        self.assertEqual(len(self.socket.sent), 1)
        self.assertEqual(len(self.socket.sent[0]), self._record_len(43, 0))

    def test_min_length_with_high_value(self):
        self._send_fuzzed_hello(fuzz_padding(ClientHelloGenerator(),
                                             min_length=200))

        self.assertEqual(len(self.socket.sent), 1)
        # 208 is the smallest padding of at least 200 bytes that keeps
        # 43 + 20 + 1 + padding a multiple of the 16-byte AES block
        self.assertEqual(len(self.socket.sent[0]), self._record_len(43, 208))

    def test_min_length_with_post_send(self):
        hello_gen = fuzz_padding(ClientHelloGenerator(), min_length=200)
        self._send_fuzzed_hello(hello_gen)

        self.assertEqual(len(self.socket.sent), 1)
        self.assertEqual(len(self.socket.sent[0]), self._record_len(43, 208))

        # once post_send ran, the next record gets no extra padding
        hello_gen.post_send(self.state)

        clean_hello = ClientHelloGenerator().generate(self.state)
        self.state.msg_sock.sendMessageBlocking(clean_hello)

        self.assertEqual(len(self.socket.sent), 2)
        self.assertEqual(len(self.socket.sent[1]), self._record_len(43, 0))

    def test_min_length_with_invalid_length(self):
        with self.assertRaises(ValueError):
            fuzz_padding(ClientHelloGenerator(), min_length=257)

    def test_min_length_with_length_too_big_for_data(self):
        data_gen = fuzz_padding(ApplicationDataGenerator(b"text"),
                                min_length=254)

        data_msg = data_gen.generate(self.state)
        self.assertEqual(len(data_msg.write()), 4)

        with self.assertRaises(ValueError):
            self.state.msg_sock.sendMessageBlocking(data_msg)

    def test_xors(self):
        # record as it is sent with no modifications
        unchanged = bytearray(
                b'\x17\x03\x00\x000' # record layer header
                b'\xa1\xbb\x9f&Z\x1cb\xb3\xf3U\x11\xbb\xf4\xd6\x91\xf3'
                b'\xa8\xf2"\xb8\xa9@]\x16,\xc9\x17Wh\x17\x1e\xb5'
                b'\x9f\xcdm\x9a\xf0!\xe65\xea\xa8\xeb|(\xd8\xd2\x02')
        data_gen = fuzz_padding(ApplicationDataGenerator(b"text"),
                                min_length=16,
                                xors={-2: 0xff})

        data_msg = data_gen.generate(self.state)
        self.state.msg_sock.sendMessageBlocking(data_msg)

        self.assertEqual(len(self.socket.sent), 1)
        self.assertEqual(len(self.socket.sent[0]), self._record_len(4, 23))
        # only the last CBC block differs from the unmodified record
        last_block = bytearray(
                b'\\Y\x90j\x8a\xe7\x82\xf3=\xceE\xe3\x0f\x85\x82\t')
        self.assertEqual(self.socket.sent[0], unchanged[:-16] + last_block)

class TestFuzzPlaintext(unittest.TestCase):
    """Tests for fuzz_plaintext on an SSL 3.0 TLS_RSA_WITH_AES_128_CBC_SHA
    connection with all-zero keys."""

    def setUp(self):
        self.state = ConnectionState()
        self.socket = MockSocket(bytearray())

        defragger = defragmenter.Defragmenter()
        defragger.add_static_size(constants.ContentType.alert, 2)
        defragger.add_static_size(constants.ContentType.change_cipher_spec, 1)
        defragger.add_dynamic_size(constants.ContentType.handshake, 1, 3)
        self.state.msg_sock = messagesocket.MessageSocket(self.socket,
                                                          defragger)
        self.state.msg_sock.version = (3, 0)
        self.state.msg_sock.calcPendingStates(
            constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
            bytearray(48),
            bytearray(32),
            bytearray(32),
            None)
        self.state.msg_sock.changeWriteState()

    @staticmethod
    def _unchanged_record():
        """Return the record sent for b"text" with no modifications."""
        return bytearray(
                b'\x17\x03\x00\x00 ' # record layer header
                b'\xa1\xbb\x9f&Z\x1cb\xb3\xf3U\x11\xbb\xf4\xd6\x91\xf3'
                b'\xbf4\xd0\x86\x99\xb9\xd9Z\xc4_\x8db\xa7\xda\x1a\xea')

    def _send(self, data_gen):
        """Generate and send a message, check the record size, return it."""
        data_msg = data_gen.generate(self.state)
        self.state.msg_sock.sendMessageBlocking(data_msg)
        self.assertEqual(len(self.socket.sent), 1)
        self.assertEqual(len(self.socket.sent[0]),
                         1 +        # record layer type
                         2 +        # protocol version
                         2 +        # record payload length field
                         4 +        # length of Application Data
                         160 // 8 + # length of HMAC
                         1 +        # size of length tag of padding (0)
                         7)         # minimal length of padding
        return self.socket.sent[0]

    def test_xors(self):
        sent = self._send(fuzz_plaintext(ApplicationDataGenerator(b"text"),
                                         xors={-2: 0xff}))

        last_block = bytearray(
                b'\x14\xa3\x14\xd2V+\x90\x08t\x81A%\xe5\xd5\xf4\x10')
        self.assertEqual(sent, self._unchanged_record()[:-16] + last_block)

    def test_substitutions(self):
        sent = self._send(fuzz_plaintext(ApplicationDataGenerator(b"text"),
                                         substitutions={0: 0xff}))

        # the substitution hits the first CBC block, so the ciphertext of
        # every following block changes too
        expected = bytearray(
                b'\xc0\\ba\x7f}Q\xe0\xa6\xc27P\xd7U\xdf\xf9'
                b'n\x97\xdf_\xe2\xef,X\x9b\rv[\x1c\x83\x1e\xbd')
        self.assertEqual(sent[:5], self._unchanged_record()[:5])
        self.assertEqual(sent[5:], expected)


class TestSplitMessage(unittest.TestCase):
    """Tests for split_message(), which fragments a generated message."""

    def test_split_to_two(self):
        state = ConnectionState()
        # Encoding of the same ClientHello without splitting; used below to
        # check that the fragments together are as long as the original.
        vanilla_hello = ClientHelloGenerator().generate(state).write()
        fragments = []

        msg = ClientHelloGenerator()
        post_send = msg.post_send
        hello_gen = split_message(msg, fragments, 30)

        # nothing is queued until the wrapped generator is used
        self.assertEqual(fragments, [])

        first_part = hello_gen.generate(state).write()

        self.assertEqual(len(first_part), 30)
        # the remaining data fragment, followed by the original post_send
        self.assertEqual(len(fragments), 2)
        self.assertEqual(len(fragments[0].write()), 13)
        self.assertEqual(fragments[1], post_send)
        # splitting must neither drop nor add bytes
        self.assertEqual(len(first_part) + len(fragments[0].write()),
                         len(vanilla_hello))

    def test_split_of_zero_length(self):
        # 0 length messages are intentionally unhandled
        fragments = []
        msg_gen = split_message(RawMessageGenerator(20, bytearray(0)),
                                fragments, 30)

        state = ConnectionState()
        with self.assertRaises(IndexError):
            msg_gen.generate(state)


class TestPopMessageFromList(unittest.TestCase):
    """Tests for PopMessageFromList, which hands out one queued message."""

    def test_with_message_list(self):
        queue = []
        gen = PopMessageFromList(queue)

        # the list is filled only after the generator was created, so the
        # generator has to read it at generation time
        queue.extend([messages.Message(20, bytearray(b'\x20\x30')),
                      messages.Message(21, bytearray(b'\x30\x20'))])

        popped = gen.generate(None)

        self.assertEqual(popped.contentType, 20)
        self.assertEqual(popped.write(), bytearray(b'\x20\x30'))
        # only the first entry was consumed
        self.assertEqual(len(queue), 1)

    def test_with_post_send_call(self):
        queue = []
        gen = PopMessageFromList(queue)

        callback = mock.MagicMock()
        queue.extend([messages.Message(20, bytearray(b'\x20\x20')),
                      callback])

        popped = gen.generate(None)

        self.assertEqual(popped.contentType, 20)
        self.assertEqual(popped.write(), bytearray(b'\x20\x20'))
        # the callback queued after the message is still pending
        self.assertEqual(len(queue), 1)

        fake_state = mock.Mock()
        gen.post_send(fake_state)

        # post_send runs the queued callback with the state and drops it
        callback.assert_called_once_with(fake_state)
        self.assertEqual(queue, [])


class TestFlushMessageList(unittest.TestCase):
    """Tests for FlushMessageList, which joins queued messages into one."""

    def test_with_message_list(self):
        queue = []
        gen = FlushMessageList(queue)

        # creating the generator leaves the list untouched
        self.assertEqual(queue, [])

        queue.extend([messages.Message(20, bytearray(b'\x20\x30')),
                      messages.Message(20, bytearray(b'\x60\x70'))])

        merged = gen.generate(None)

        self.assertEqual(merged.contentType, 20)
        # payloads are concatenated in queue order
        self.assertEqual(merged.write(), bytearray(b'\x20\x30\x60\x70'))
        self.assertEqual(queue, [])

    def test_with_different_message_types(self):
        # messages with different content types cannot be joined
        gen = FlushMessageList([messages.Message(20, bytearray(b'\x20')),
                                messages.Message(30, bytearray(b'\x10'))])

        with self.assertRaises(AssertionError):
            gen.generate(None)

    def test_with_post_send_call(self):
        queue = []
        gen = FlushMessageList(queue)

        callback = mock.MagicMock()
        queue.extend([messages.Message(20, bytearray(b'\x20\x20')),
                      messages.Message(20, bytearray(b'\x30\x03')),
                      callback])

        merged = gen.generate(None)

        self.assertEqual(merged.contentType, 20)
        self.assertEqual(merged.write(), bytearray(b'\x20\x20\x30\x03'))
        # both data messages were consumed, the callback is still queued
        self.assertEqual(len(queue), 1)

        fake_state = mock.Mock()
        gen.post_send(fake_state)

        # post_send runs the queued callback with the state and drops it
        callback.assert_called_once_with(fake_state)
        self.assertEqual(queue, [])


class TestFuzzPKCS1Padding(unittest.TestCase):
    """Tests for fuzz_pkcs1_padding(), which patches a key's padding."""

    # 1024-bit key -> 128-byte block; after the two leading bytes, the
    # separator byte and the 16-byte payload, 109 padding bytes remain
    PADDING_LEN = 109

    @classmethod
    def setUpClass(cls):
        cls.key = generateRSAKey(1024)
        # remember the unpatched method so each test can restore it
        cls.key.old_addPKCS1Padding = cls.key._addPKCS1Padding

    def setUp(self):
        self.key._addPKCS1Padding = self.key.old_addPKCS1Padding

    def _assert_padded(self, payload, block_type, expected):
        """Pad *payload* with the (patched) key and compare to *expected*."""
        padded = self.key._addPKCS1Padding(payload, block_type)
        self.assertEqual(len(padded), len(expected))
        self.assertEqual(padded, expected)

    def test_with_no_substitutions(self):
        fuzz_pkcs1_padding(self.key)
        self._assert_padded(
            bytearray(range(0, 16)), 1,
            bytearray([0, 1] + [0xff] * self.PADDING_LEN + [0] +
                      list(range(0, 16))))

    def test_with_substitutions(self):
        fuzz_pkcs1_padding(self.key, substitutions={1: 2})
        self._assert_padded(
            bytearray(range(0, 16)), 1,
            bytearray([0, 2] + [0xff] * self.PADDING_LEN + [0] +
                      list(range(0, 16))))

    def test_with_xors(self):
        # index -1 addresses the separator byte right before the payload
        fuzz_pkcs1_padding(self.key, xors={-1: 0x0f})
        self._assert_padded(
            bytearray(range(1, 17)), 1,
            bytearray([0, 1] + [0xff] * self.PADDING_LEN + [0x0f] +
                      list(range(1, 17))))

    def test_with_padding_byte(self):
        fuzz_pkcs1_padding(self.key, substitutions={1: 2}, padding_byte=0xaa)
        self._assert_padded(
            bytearray(range(5, 21)), 2,
            bytearray([0, 2] + [0xaa] * self.PADDING_LEN + [0x00] +
                      list(range(5, 21))))


class TestReplacePlaintext(unittest.TestCase):
    """Tests for replace_plaintext() over a TLS_RSA_WITH_AES_128_CBC_SHA
    write state."""

    def setUp(self):
        # fresh connection with an active AES-128-CBC-SHA write state,
        # derived from all-zero secrets so the ciphertext is reproducible
        self.state = ConnectionState()
        self.socket = MockSocket(bytearray())
        self.state.msg_sock = messagesocket.MessageSocket(
            self.socket, defragmenter.Defragmenter())
        self.state.msg_sock.version = (3, 3)
        self.state.msg_sock.calcPendingStates(
            constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
            bytearray(48),
            bytearray(32),
            bytearray(32),
            None)
        self.state.msg_sock.changeWriteState()

    def test_replace(self):
        gen = replace_plaintext(ApplicationDataGenerator(bytearray(b"text")),
                                bytearray(b'\x00' * 16))

        self.state.msg_sock.sendMessageBlocking(gen.generate(self.state))

        sent = self.socket.sent
        self.assertEqual(len(sent), 1)
        header_len = 1 + 2 + 2  # type, protocol version, payload length
        self.assertEqual(len(sent[0]), header_len + 16)  # 16 bytes of data

        expected = bytearray(b'\x17\x03\x03\x00\x10'
                             b'H&\x1f\xc1\x9c\xde"\x92\xdd\xe4|\xfco)R\xd6')
        # a ciphertext shorter than the MAC alone proves that the whole
        # plaintext (including MAC and padding) was replaced
        self.assertEqual(sent[0], expected)

    def test_replace_with_replacement_not_multiple_of_cipher_block_size(self):
        gen = replace_plaintext(ApplicationDataGenerator(b"text"),
                                b'\x00' * 8)

        data_msg = gen.generate(self.state)

        # 8 bytes cannot be encrypted by a 16-byte block cipher as-is
        with self.assertRaises(ValueError):
            self.state.msg_sock.sendMessageBlocking(data_msg)


class TestHeartbeatGenerator(unittest.TestCase):
    """Tests for HeartbeatGenerator."""

    def test___init__(self):
        gen = HeartbeatGenerator(bytearray(b'payload'))

        self.assertIsNotNone(gen)
        self.assertEqual(gen.payload, bytearray(b'payload'))
        # a request is created unless told otherwise
        self.assertEqual(gen.message_type,
                         constants.HeartbeatMessageType.heartbeat_request)
        # without an explicit length, 16 bytes of padding are prepared
        self.assertEqual(len(gen.padding), 16)

    def test_generate(self):
        gen = HeartbeatGenerator(bytearray(b'heartbeat test'))

        heartbeat = gen.generate(None)

        self.assertIsNotNone(heartbeat)
        self.assertEqual(heartbeat.payload, bytearray(b'heartbeat test'))
        # the generated message carries the generator's padding
        self.assertEqual(heartbeat.padding, gen.padding)
        self.assertEqual(heartbeat.message_type,
                         constants.HeartbeatMessageType.heartbeat_request)

    def test_generate_with_small_padding(self):
        gen = HeartbeatGenerator(bytearray(b''))
        # overriding padding after construction is honoured
        gen.padding = bytearray(b'\x00')

        self.assertEqual(gen.generate(None).padding, bytearray(b'\x00'))

    def test_generate_with_no_padding(self):
        gen = HeartbeatGenerator(bytearray(b''), padding_length=0)

        self.assertEqual(gen.generate(None).padding, bytearray(b''))
The diff you're trying to view is too large. Only the first 1000 changed files have been loaded.
Showing with 0 additions and 0 deletions (0 / 0 diffs computed)
swh spinner

Computing file changes ...

ENEA — Copyright (C), ENEA. License: GNU AGPLv3+.
Legal notes  ::  JavaScript license information ::  Web API

back to top