# Author: Hubert Kario, (c) 2015 # Released under Gnu GPL v2.0, see LICENSE file for details try: import unittest2 as unittest except ImportError: import unittest try: import mock from mock import call except ImportError: import unittest.mock as mock from unittest.mock import call import socket import os from tlsfuzzer.messages import ClientHelloGenerator, ClientKeyExchangeGenerator,\ ChangeCipherSpecGenerator, FinishedGenerator, \ RenegotiationInfoExtension, ResetHandshakeHashes, SetMaxRecordSize, \ pad_handshake, truncate_handshake, Close, fuzz_message, \ RawMessageGenerator, split_message, PopMessageFromList, \ FlushMessageList, fuzz_mac, fuzz_padding, ApplicationDataGenerator, \ CertificateGenerator, CertificateVerifyGenerator, CertificateRequest, \ ResetRenegotiationInfo, fuzz_plaintext, Connect, \ ClientMasterKeyGenerator, TCPBufferingEnable, TCPBufferingDisable, \ TCPBufferingFlush, fuzz_encrypted_message, fuzz_pkcs1_padding, \ CollectNonces, AlertGenerator, PlaintextMessageGenerator, \ SetPaddingCallback, replace_plaintext, ch_cookie_handler, \ ch_key_share_handler, SetRecordVersion, CopyVariables, \ ResetWriteConnectionState, HeartbeatGenerator, Certificate, \ KeyUpdateGenerator, ClearContext, RawSocketWriteGenerator from tlsfuzzer.helpers import psk_ext_gen, psk_ext_updater, \ psk_session_ext_gen, AutoEmptyExtension from tlsfuzzer.runner import ConnectionState import tlslite.messages as messages import tlslite.messagesocket as messagesocket import tlslite.extensions as extensions from tlslite.handshakehashes import HandshakeHashes from tlslite.keyexchange import KeyExchange import tlslite.utils.keyfactory as keyfactory from tlslite.utils.cryptomath import bytesToNumber, numberToByteArray, \ secureHash from tlsfuzzer.utils.ordered_dict import OrderedDict import tlslite.constants as constants import tlslite.defragmenter as defragmenter from tlslite.utils.codec import Parser from tests.mocksock import MockSocket from tlslite.utils.keyfactory import generateRSAKey from tlslite.utils.cryptomath import numberToByteArray from tlslite.utils.python_rsakey import Python_RSAKey from tlslite.utils.python_ecdsakey import Python_ECDSAKey from tlslite.utils.python_eddsakey import Python_EdDSAKey from tlslite.x509 import X509 from tlslite.x509certchain import X509CertChain from ecdsa import SigningKey, Ed25519 rsa_pss_cert = X509CertChain([X509().parse( "-----BEGIN CERTIFICATE-----\n" "MIIDWzCCAhKgAwIBAgIJAM94DjB2Qf+GMD4GCSqGSIb3DQEBCjAxoA0wCwYJYIZI\n" "AWUDBAIBoRowGAYJKoZIhvcNAQEIMAsGCWCGSAFlAwQCAaIEAgIA3jAUMRIwEAYD\n" "VQQDDAlsb2NhbGhvc3QwHhcNMTcwMzA4MTMzMzU4WhcNMTcwNDA3MTMzMzU4WjAU\n" "MRIwEAYDVQQDDAlsb2NhbGhvc3QwggEgMAsGCSqGSIb3DQEBCgOCAQ8AMIIBCgKC\n" "AQEApq5FnZRNmtJy+WjNRB8w0ol2+IEcATrUkZpz7HNqq8+EL/GN21m35luz5fcA\n" "8ifkm4pKxfwtxY7u912tAfsEavUr/UoyLScXQevhr6SLXR8UO3XE6ne1F88eZNpg\n" "KVpocDVKRLjIqKHSSbBNkG76mGKFYyNZXm3qRxaUSHVLnN/opOGfgVK8Vbj0v2EH\n" "5L5dGLQLV2ZfQ75I4OGrg8wse5fhgt2oczhhe72sEuMDc08WHGRp9744mMQvVYrt\n" "pWUHkP5tdD4i7/zYf3nig21/G56JY20JpW+1J5fEn+PEqiJsv5prNJuPp/zdjD3u\n" "ImqfXZzXUeRIQr9l1qXI6JPInwIDAQABo1AwTjAdBgNVHQ4EFgQUcTYhLu7pODIv\n" "B6KhR6eyFBB5wacwHwYDVR0jBBgwFoAUcTYhLu7pODIvB6KhR6eyFBB5wacwDAYD\n" "VR0TBAUwAwEB/zA+BgkqhkiG9w0BAQowMaANMAsGCWCGSAFlAwQCAaEaMBgGCSqG\n" "SIb3DQEBCDALBglghkgBZQMEAgGiBAICAN4DggEBAKMgweHM6WTwlWEQHLG5K+7B\n" "hrAUEAsuK8F7sKGKzLEFzYdzZpkJw8LahE4dFayjx/7MD4rZ5IiHQhJcGCdHIVVv\n" "ocunlEUTgiKkMxTw4JxqSq0snvNBie04vnn+zUjD7FrctTUutzlH1yKftwbJpGk6\n" "CrTW6ctFTAIDwZHd+WX4RPewGY0LTfC+RjcMwWZBmbfVLxuJs0sidSUoNW6GgGE1\n" "DIDVeW2yKGeNhjK/3aDzfQWbz1J64aRfccVzXYMPsoABnNJnJgRETh1/Ci0sQ9Vd\n" "1OR6iS4hl88/1d7utc00MyFVk1sUIGf54EeCvrNB4bhKtawEJk8Q8AGIRhs93sk=\n" "-----END CERTIFICATE-----\n")]) ecdsa_cert = X509CertChain([X509().parse( "-----BEGIN CERTIFICATE-----\n" "MIIBbTCCARSgAwIBAgIJAPM58cskyK+yMAkGByqGSM49BAEwFDESMBAGA1UEAwwJ\n" "bG9jYWxob3N0MB4XDTE3MTAyMzExNDI0MVoXDTE3MTEyMjExNDI0MVowFDESMBAG\n" "A1UEAwwJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEyDRjEAJe\n" "3F5T62MyZbhjoJnPLGL2nrTthLFymBupZ2IbnWYnqVWDkT/L6i8sQhf2zCLrlSjj\n" "1kn7ERqPx/KZyqNQME4wHQYDVR0OBBYEFPfFTUg9o3t6ehLsschSnC8Te8oaMB8G\n" "A1UdIwQYMBaAFPfFTUg9o3t6ehLsschSnC8Te8oaMAwGA1UdEwQFMAMBAf8wCQYH\n" "KoZIzj0EAQNIADBFAiA6p0YM5ZzfW+klHPRU2r13/IfKgeRfDR3dtBngmPvxUgIh\n" "APTeSDeJvYWVBLzyrKTeSerNDKKHU2Rt7sufipv76+7s\n" "-----END CERTIFICATE-----\n" )]) ecdsa_key = keyfactory.parsePEMKey( "-----BEGIN PRIVATE KEY-----\n" "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgCOZr0Ovs0eCmh+XM\n" "QWDYVpsQ+sJdjiq/itp/kYnWNSahRANCAATINGMQAl7cXlPrYzJluGOgmc8sYvae\n" "tO2EsXKYG6lnYhudZiepVYORP8vqLyxCF/bMIuuVKOPWSfsRGo/H8pnK\n" "-----END PRIVATE KEY-----\n", private=True) class TestClose(unittest.TestCase): def test___init__(self): close = Close() self.assertIsNotNone(close) def test_process(self): state = ConnectionState() state.msg_sock = mock.MagicMock() close = Close() close.process(state) state.msg_sock.sock.close.called_once_with() class TestTCPBufferingEnable(unittest.TestCase): def test___init__(self): node = TCPBufferingEnable() self.assertIsNotNone(node) self.assertTrue(node.is_command()) self.assertFalse(node.is_expect()) self.assertFalse(node.is_generator()) @mock.patch('socket.socket') def test_generate(self, raw_sock): state = ConnectionState() conn = Connect('localhost', 4433) conn.process(state) self.assertFalse(state.msg_sock.sock.buffer_writes) node = TCPBufferingEnable() node.process(state) self.assertTrue(state.msg_sock.sock.buffer_writes) class TestTCPBufferingDisable(unittest.TestCase): def test___init__(self): node = TCPBufferingDisable() self.assertIsNotNone(node) self.assertTrue(node.is_command()) self.assertFalse(node.is_expect()) self.assertFalse(node.is_generator()) @mock.patch('socket.socket') def test_generate(self, raw_sock): state = ConnectionState() conn = Connect('localhost', 4433) conn.process(state) self.assertFalse(state.msg_sock.sock.buffer_writes) node = TCPBufferingEnable() node.process(state) self.assertTrue(state.msg_sock.sock.buffer_writes) node = TCPBufferingDisable() node.process(state) self.assertFalse(state.msg_sock.sock.buffer_writes) class TestTCPBufferingFlush(unittest.TestCase): def test___init__(self): node = TCPBufferingFlush() self.assertIsNotNone(node) self.assertTrue(node.is_command()) self.assertFalse(node.is_expect()) self.assertFalse(node.is_generator()) @mock.patch('socket.socket') def test_generate(self, raw_sock): state = ConnectionState() conn = Connect('localhost', 4433) conn.process(state) node = TCPBufferingEnable() node.process(state) node = RawMessageGenerator(12, bytearray(b'\xff')) msg = node.generate(state) state.msg_sock.sendMessageBlocking(msg) raw_sock.return_value.send.assert_not_called() raw_sock.return_value.sendall.assert_not_called() flush = TCPBufferingFlush() flush.process(state) raw_sock.return_value.sendall.assert_called_once_with( bytearray(b'\x0c\x03\x00\x00\x01\xff')) class TestResetWriteConnectionState(unittest.TestCase): def test__init__(self): node = ResetWriteConnectionState() self.assertIsNotNone(node) self.assertTrue(node.is_command()) self.assertFalse(node.is_expect()) self.assertFalse(node.is_generator()) def test_process(self): state = ConnectionState() socket = MockSocket(bytearray()) defragger = defragmenter.Defragmenter() defragger.add_static_size(constants.ContentType.alert, 2) defragger.add_static_size(constants.ContentType.change_cipher_spec, 1) defragger.add_dynamic_size(constants.ContentType.handshake, 1, 3) state.msg_sock = messagesocket.MessageSocket(socket, defragger) state.msg_sock.version = (3, 3) state.msg_sock.encryptThenMAC = True state.msg_sock.calcPendingStates( constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, bytearray(48), bytearray(32), bytearray(32), None) state.msg_sock.changeWriteState() state.msg_sock.changeReadState() self.assertIsNotNone(state.msg_sock._writeState.macContext) self.assertIsNotNone(state.msg_sock._writeState.encContext) self.assertTrue(state.msg_sock._writeState.encryptThenMAC) node = ResetWriteConnectionState() node.process(state) self.assertIsNotNone(state.msg_sock._readState.macContext) self.assertIsNotNone(state.msg_sock._readState.encContext) self.assertTrue(state.msg_sock._readState.encryptThenMAC) self.assertIsNone(state.msg_sock._writeState.macContext) self.assertIsNone(state.msg_sock._writeState.encContext) self.assertIsNone(state.msg_sock._writeState.fixedNonce) self.assertEqual(state.msg_sock._writeState.seqnum, 0) self.assertFalse(state.msg_sock._writeState.encryptThenMAC) class TestCollectNonces(unittest.TestCase): def test__init__(self): nonces = [] node = CollectNonces(nonces) self.assertTrue(node.is_command()) def test_process(self): state = ConnectionState() sock = MockSocket(bytearray()) defragger = defragmenter.Defragmenter() defragger.add_static_size(constants.ContentType.alert, 2) defragger.add_static_size(constants.ContentType.change_cipher_spec, 1) defragger.add_dynamic_size(constants.ContentType.handshake, 1, 3) state.msg_sock = messagesocket.MessageSocket(sock, defragger) state.msg_sock.version = (3, 3) state.msg_sock.calcPendingStates(constants.CipherSuite. TLS_RSA_WITH_AES_128_GCM_SHA256, bytearray(48), bytearray(32), bytearray(32), None) state.msg_sock.changeWriteState() nonces = [] node = CollectNonces(nonces) node.process(state) node = ApplicationDataGenerator(b'some text') msg = node.generate(state) state.msg_sock.sendMessageBlocking(msg) self.assertEqual(nonces, [bytearray(b'\xa9\xfc\x88\x1d' b'\x00\x00\x00\x00\x00\x00\x00\x00')]) class TestCopyVariables(unittest.TestCase): def test___init__(self): node = CopyVariables({}) self.assertIsNotNone(node) self.assertTrue(node.is_command()) self.assertFalse(node.is_expect()) self.assertFalse(node.is_generator()) def test_process_client_hello_random(self): state = ConnectionState() state.client_random = bytearray(b'test') arr = [] node = CopyVariables({'ClientHello.random': arr}) node.process(state) self.assertEqual(arr, [bytearray(b'test')]) def test_process_server_hello_random(self): state = ConnectionState() state.server_random = bytearray(b'abba') arr = [] node = CopyVariables({'ServerHello.random': arr}) node.process(state) self.assertEqual(arr, [bytearray(b'abba')]) def test_process_server_hello_session_id(self): state = ConnectionState() state.session_id = bytearray(b'some session ID') arr = [] node = CopyVariables({'ServerHello.session_id': arr}) node.process(state) self.assertEqual(arr, [bytearray(b'some session ID')]) def test_process_premaster_secret(self): state = ConnectionState() state.key['premaster_secret'] = bytearray(b'premaster secret') arr = [] node = CopyVariables({'premaster_secret': arr}) node.process(state) self.assertEqual(arr, [bytearray(b'premaster secret')]) def test_process_with_incorrect_name(self): state = ConnectionState() arr = [] node = CopyVariables({'DH shared secret': arr}) with self.assertRaises(ValueError): node.process(state) class TestConnect(unittest.TestCase): def test___init__(self): connect = Connect(1, 2) self.assertIsNotNone(connect) self.assertEqual(connect.hostname, 1) self.assertEqual(connect.port, 2) self.assertEqual(connect.version, (3, 0)) self.assertEqual(connect.timeout, 5) @mock.patch('socket.socket') def test_process(self, mock_sock): state = ConnectionState() connect = Connect("localhost", 4433) connect.process(state) self.assertEqual(state.msg_sock.version, (3, 0)) mock_sock.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM) instance = mock_sock.return_value instance.connect.assert_called_once_with(("localhost", 4433)) self.assertIs(state.msg_sock.sock.socket, instance) @mock.patch('socket.socket') def test_process_with_SSLv2(self, mock_sock): state = ConnectionState() connect = Connect(1, 2, (0, 2)) connect.process(state) self.assertEqual(state.msg_sock.version, (0, 2)) mock_sock.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM) instance = mock_sock.return_value instance.connect.assert_called_once_with((1, 2)) self.assertIs(state.msg_sock.sock.socket, instance) @mock.patch('socket.socket') def test_process_with_timeout(self, mock_sock): state = ConnectionState() connect = Connect(1, 2, timeout=10) connect.process(state) mock_sock.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM) instance = mock_sock.return_value instance.connect.assert_called_once_with((1, 2)) instance.settimeout.assert_called_once_with(10) self.assertIs(state.msg_sock.sock.socket, instance) class TestSetRecordVersion(unittest.TestCase): def test___init__(self): msg_gen = SetRecordVersion((3, 3)) self.assertIsNotNone(msg_gen) self.assertTrue(msg_gen.is_command()) self.assertFalse(msg_gen.is_expect()) self.assertFalse(msg_gen.is_generator()) def test_process(self): msg_gen = SetRecordVersion((3, 2)) state = ConnectionState() state.msg_sock = mock.MagicMock() msg_gen.process(state) self.assertEqual(state.msg_sock.version, (3, 2)) class TestRawSocketWriteGenerator(unittest.TestCase): def test___init__(self): msg_gen = RawSocketWriteGenerator(bytearray(b'some data')) self.assertIsNotNone(msg_gen) self.assertTrue(msg_gen.is_command()) self.assertFalse(msg_gen.is_expect()) self.assertFalse(msg_gen.is_generator()) def test___repr__(self): msg_gen = RawSocketWriteGenerator(bytearray(b'some data')) self.assertEqual(repr(msg_gen), "RawSocketWriteGenerator(data=bytearray(b'some data'))") def test___repr___with_description(self): msg_gen = RawSocketWriteGenerator(bytearray(b'some data'), description="STARTTLS") self.assertEqual(repr(msg_gen), "RawSocketWriteGenerator(" "data=bytearray(b'some data'), " "description='STARTTLS')") def test_process(self): state = ConnectionState() state.msg_sock = mock.MagicMock() msg_gen = RawSocketWriteGenerator(b'some data') msg_gen.process(state) self.assertTrue( state.msg_sock._recordSocket. sock.send.called_once_with(b'some data')) class TestPlaintextMessageGenerator(unittest.TestCase): def test___init__(self): msg_gen = PlaintextMessageGenerator(12, bytearray(b'\x00\x00')) self.assertIsNotNone(msg_gen) self.assertTrue(msg_gen.is_command()) self.assertFalse(msg_gen.is_expect()) self.assertFalse(msg_gen.is_generator()) def test___repr__(self): msg_gen = PlaintextMessageGenerator(12, bytearray(b'\x00\x00')) self.assertEqual(repr(msg_gen), "PlaintextMessageGenerator(content_type=12, " "data=bytearray(b'\\x00\\x00'))") def test___repr___with_description(self): msg_gen = PlaintextMessageGenerator(12, bytearray(b'\x00\x00'), description="some message") self.assertEqual(repr(msg_gen), "PlaintextMessageGenerator(content_type=12, " "data=bytearray(b'\\x00\\x00'), " "description='some message')") def test_process(self): state = ConnectionState() state.msg_sock = mock.MagicMock() msg_gen = PlaintextMessageGenerator(12, bytearray(b'\x00\x00')) msg_gen.process(state) self.assertTrue(state.msg_sock._recordSocket.send.called) class TestRawMessageGenerator(unittest.TestCase): def test___init__(self): message_gen = RawMessageGenerator(12, bytearray(b'\xff\x02')) self.assertIsNotNone(message_gen) self.assertEqual(message_gen.content_type, 12) self.assertEqual(message_gen.data, bytearray(b'\xff\x02')) def test_gen(self): message_gen = RawMessageGenerator(12, bytearray(b'\xff\x02')) message = message_gen.generate(None) self.assertIsNotNone(message) self.assertEqual(message.contentType, 12) self.assertEqual(message.write(), bytearray(b'\xff\x02')) def test___repr__(self): message_gen = RawMessageGenerator(12, bytearray(b'\xff\x02')) self.assertEqual(repr(message_gen), "RawMessageGenerator(content_type=12, "\ "data=bytearray(b'\\xff\\x02'))") def test___repr___with_description(self): message_gen = RawMessageGenerator(12, bytearray(b'\xff'), description="a broken message") self.assertEqual(repr(message_gen), "RawMessageGenerator(content_type=12, "\ "data=bytearray(b'\\xff'), description='a broken "\ "message')") class TestClientHelloGenerator(unittest.TestCase): def test___init__(self): chg = ClientHelloGenerator() self.assertIsNotNone(chg) self.assertEqual(chg.ciphers, []) def test___repr__(self): chg = ClientHelloGenerator() chg.compression = [] self.assertEqual("ClientHelloGenerator()", repr(chg)) def test___repr___with_values(self): chg = ClientHelloGenerator( [2, 3], ["ext"], (3, 3), bytearray(b'sess'), bytearray(b'random'), [0xc0], True, ["mod"]) self.assertEqual( "ClientHelloGenerator(" "ssl2=True, " "version=(3, 3), " "ciphers=[2, 3], " "random=bytearray(b'random'), " "session_id=bytearray(b'sess'), " "compression=[192], " "extensions=['ext'], " "modifiers=['mod'])", repr(chg)) def test_generate(self): state = ConnectionState() chg = ClientHelloGenerator() return_val = mock.MagicMock() return_val.write = mock.MagicMock(return_value=bytearray(10)) with mock.patch.object(messages.ClientHello, 'create', return_value=return_val) as mock_method: ch = chg.generate(state) self.assertEqual(ch, return_val) mock_method.assert_called_once_with((3, 3), bytearray(32), bytearray(0), [], extensions=None) def test_generate_extensions_with_missing_generator(self): state = ConnectionState() chg = ClientHelloGenerator(extensions={0x1234:None}) return_val = mock.MagicMock() return_val.write = mock.MagicMock(return_value=bytearray(10)) with self.assertRaises(ValueError): with mock.patch.object(messages.ClientHello, 'create', return_value=return_val) as mock_method: chg.generate(state) def test_generate_extensions_with_auto_generator(self): state = ConnectionState() chg = ClientHelloGenerator(extensions={0x1234: AutoEmptyExtension()}) return_val = mock.MagicMock() return_val.write = mock.MagicMock(return_value=bytearray(10)) with mock.patch.object(messages.ClientHello, 'create', return_value=return_val) as mock_method: ch = chg.generate(state) self.assertEqual(ch, return_val) ext = extensions.TLSExtension().create(0x1234, bytearray(0)) mock_method.assert_called_once_with((3, 3), bytearray(32), bytearray(0), [], extensions=[ext]) def test_generate_extensions_with_raw_extension(self): state = ConnectionState() ext = extensions.TLSExtension().create(extType=0x1234, data=None) chg = ClientHelloGenerator(extensions={0x1234:ext}) return_val = mock.MagicMock() return_val.write = mock.MagicMock(return_value=bytearray(10)) with mock.patch.object(messages.ClientHello, 'create', return_value=return_val) as mock_method: ch = chg.generate(state) self.assertEqual(ch, return_val) mock_method.assert_called_once_with((3, 3), bytearray(32), bytearray(0), [], extensions=[ext]) def test_generate_extensions_with_garbage_extension(self): state = ConnectionState() ext = "some weird non-extension" chg = ClientHelloGenerator(extensions={0x1234:ext}) return_val = mock.MagicMock() return_val.write = mock.MagicMock(return_value=bytearray(10)) with mock.patch.object(messages.ClientHello, 'create', return_value=return_val) as mock_method: with self.assertRaises(ValueError): ch = chg.generate(state) def test_generate_extensions_with_ext_generator(self): state = ConnectionState() ext_gen = mock.MagicMock() chg = ClientHelloGenerator(extensions={0x1234:ext_gen}) return_val = mock.MagicMock() return_val.write = mock.MagicMock(return_value=bytearray(10)) with mock.patch.object(messages.ClientHello, 'create', return_value=return_val) as mock_method: ch = chg.generate(state) self.assertEqual(ch, return_val) ext_gen.assert_called_once_with(state) mock_method.assert_called_once_with((3, 3), bytearray(32), bytearray(0), [], extensions=[ext_gen()]) def test_generate_extensions_with_automatic_status_request(self): state = ConnectionState() chg = ClientHelloGenerator(extensions= {constants.ExtensionType.status_request: None}) return_val = mock.MagicMock() return_val.write = mock.MagicMock(return_value=bytearray(10)) with mock.patch.object(messages.ClientHello, 'create', return_value=return_val) as mock_method: ch = chg.generate(state) self.assertEqual(ch, return_val) mock_method.assert_called_once_with( (3, 3), bytearray(32), bytearray(0), [], extensions=[extensions.StatusRequestExtension().create()]) def test_generate_extensions_with_automatic_encrypt_then_mac(self): state = ConnectionState() chg = ClientHelloGenerator(extensions= {constants.ExtensionType.encrypt_then_mac: None}) return_val = mock.MagicMock() return_val.write = mock.MagicMock(return_value=bytearray(10)) with mock.patch.object(messages.ClientHello, 'create', return_value=return_val) as mock_method: ch = chg.generate(state) self.assertEqual(ch, return_val) mock_method.assert_called_once_with( (3, 3), bytearray(32), bytearray(0), [], extensions=[extensions.TLSExtension(extType=22).create(b'')]) def test_generate_extensions_with_renego_info_default_generator(self): state = ConnectionState() state.key['client_verify_data'] = bytearray(b'\xab\xcd') chg = ClientHelloGenerator(extensions={constants.ExtensionType.renegotiation_info: None}) return_val = mock.MagicMock() return_val.write = mock.MagicMock(return_value=bytearray(10)) with mock.patch.object(messages.ClientHello, 'create', return_value=return_val) as mock_method: ch = chg.generate(state) self.assertEqual(ch, return_val) ext = RenegotiationInfoExtension().create(bytearray(b'\xab\xcd')) mock_method.assert_called_once_with((3, 3), bytearray(32), bytearray(0), [], extensions=[ext]) def test_generate_with_random(self): state = ConnectionState() chg = ClientHelloGenerator(random=bytearray(b'\x33'*32)) return_val = mock.MagicMock() return_val.write = mock.MagicMock(return_value=bytearray(10)) with mock.patch.object(messages.ClientHello, 'create', return_value=return_val) as mock_method: ch = chg.generate(state) self.assertEqual(ch, return_val) mock_method.assert_called_once_with((3, 3), bytearray(b'\x33'*32), bytearray(0), [], extensions=None) def test_generate_with_compression_methods(self): state = ConnectionState() chg = ClientHelloGenerator(compression=[0, 2, 3]) return_val = mock.MagicMock() return_val.write = mock.MagicMock(return_value=bytearray(10)) with mock.patch.object(messages.ClientHello, 'create', return_value=return_val) as mock_method: ch = chg.generate(state) self.assertEqual(ch, return_val) self.assertEqual(ch.compression_methods, [0, 2, 3]) mock_method.assert_called_once_with((3, 3), bytearray(32), bytearray(0), [], extensions=None) def test_session_id_with_no_settings(self): state = ConnectionState() chg = ClientHelloGenerator(version=(3, 4)) msg = chg.generate(state) self.assertEqual(msg.session_id, b'') def test_seesion_id_with_tls13_extension(self): state = ConnectionState() exts = {constants.ExtensionType.supported_versions: AutoEmptyExtension()} chg = ClientHelloGenerator(version=(3, 3), extensions=exts) msg = chg.generate(state) self.assertEqual(len(msg.session_id), 32) def test_session_id_with_explicit_id_and_tls13_extension(self): state = ConnectionState() exts = {constants.ExtensionType.supported_versions: AutoEmptyExtension()} chg = ClientHelloGenerator(version=(3, 3), extensions=exts, session_id=b'') msg = chg.generate(state) self.assertEqual(msg.session_id, b'') def test_modifiers(self): state = ConnectionState() psk_configs = [(b'test', b'secret')] exts = OrderedDict() exts[constants.ExtensionType.supported_versions] = \ AutoEmptyExtension() exts[constants.ExtensionType.pre_shared_key] = \ psk_ext_gen(psk_configs) modifiers = [psk_ext_updater(psk_configs)] chg = ClientHelloGenerator(version=(3, 3), extensions=exts, session_id=b'\xaa'*32, modifiers=modifiers) msg = chg.generate(state) self.assertEqual(len(msg.extensions), 2) self.assertIsInstance(msg.extensions[1], extensions.PreSharedKeyExtension) ext = msg.extensions[1] self.assertEqual(len(ext.binders), 1) self.assertEqual( ext.binders[0], bytearray(b"\x04!\xd0\xee\x0c\xe8\x13W\xa9\x85\xcc\xce\x07U\x87" b"\xe6\'\xfa\xec\xf4\xf7\x88\x8b\xf3\xc2\xc3^\xf3<\x8b" b"\xba%")) def test_with_session_tickets(self): state = ConnectionState() state.key['resumption master secret'] = bytearray(b'\x10' * 48) state.cipher = constants.CipherSuite.TLS_AES_256_GCM_SHA384 state.session_tickets = [ messages.NewSessionTicket().create( 101, 0, bytearray(b'some nonce'), bytearray(b'ticket identity'), [])] state.session_tickets[0].time = 12.12 exts = OrderedDict() exts[constants.ExtensionType.supported_versions] = \ AutoEmptyExtension() exts[constants.ExtensionType.pre_shared_key] = \ psk_session_ext_gen() modifiers = [psk_ext_updater()] chg = ClientHelloGenerator(version=(3, 3), extensions=exts, session_id=b'\xaa'*32, modifiers=modifiers) with mock.patch('time.time') as mthd: mthd.return_value = 1532009672.427739153 msg = chg.generate(state) self.assertEqual(len(msg.extensions), 2) self.assertIsInstance(msg.extensions[1], extensions.PreSharedKeyExtension) ext = msg.extensions[1] self.assertEqual(len(ext.binders), 1) self.assertEqual( ext.binders[0], bytearray(b'6f&\xe42^a\x13\x949\xc3Q\xb0\x1fD\xb6H\xba\xc1=D\xa6' b'\xcb1\x05\xd72\x1eKO\x8d\xf9V\xa13\xdc\x94b\xdb\xc6' b'\x92\\\xe9\xd2\xd7Tv\xb2')) class TestClientHelloExtensionGenerators(unittest.TestCase): def setUp(self): self.state = ConnectionState() exts = [extensions.CookieExtension().create(b'some payload'), extensions.HRRKeyShareExtension().create( constants.GroupName.secp256r1)] hrr = messages.ServerHello() hrr.create(version=(3, 3), random=constants.TLS_1_3_HRR, session_id=b'', cipher_suite=0x04, extensions=exts) self.state.handshake_messages.append(hrr) def test_ch_cookie_handler(self): ext = ch_cookie_handler(self.state) self.assertIsInstance(ext, extensions.CookieExtension) self.assertEqual(ext.cookie, b'some payload') def test_ch_cookie_handler_with_no_hrr(self): self.state.handshake_messages = [] with self.assertRaises(ValueError) as e: ch_cookie_handler(self.state) self.assertIn("No HRR received", str(e.exception)) def test_ch_key_share_handler(self): ext = ch_key_share_handler(self.state) self.assertIsInstance(ext, extensions.ClientKeyShareExtension) self.assertEqual(len(ext.client_shares), 1) self.assertIsInstance(ext.client_shares[0], extensions.KeyShareEntry) self.assertEqual(ext.client_shares[0].group, constants.GroupName.secp256r1) def test_ch_key_share_handler_with_no_hrr(self): self.state.handshake_messages = [] with self.assertRaises(ValueError) as e: ch_key_share_handler(self.state) self.assertIn("No HRR received", str(e.exception)) class TestClientKeyExchangeGenerator(unittest.TestCase): @classmethod def setUpClass(cls): cls.priv_key = keyfactory.generateRSAKey(1024) def test___init__(self): cke = ClientKeyExchangeGenerator() self.assertEqual(len(cke.premaster_secret), 48) def test___init___with_invalid_param(self): with self.assertRaises(ValueError): cke = ClientKeyExchangeGenerator(p_as_share=True, p_1_as_share=True) def test_generate(self): state = ConnectionState() state.get_server_public_key = lambda : self.priv_key cke = ClientKeyExchangeGenerator( cipher=constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA) ret = cke.generate(state) self.assertEqual(len(ret.encryptedPreMasterSecret), 128) decrypt = self.priv_key.decrypt(ret.encryptedPreMasterSecret) self.assertEqual(decrypt[:2], bytearray([3, 3])) self.assertEqual(decrypt[2:], bytearray([0]*46)) def test_generate_with_custom_premaster_secret(self): state = ConnectionState() state.get_server_public_key = lambda : self.priv_key cke = ClientKeyExchangeGenerator( cipher=constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, premaster_secret=bytearray([1]*10)) ret = cke.generate(state) self.assertEqual(len(ret.encryptedPreMasterSecret), 128) decrypt = self.priv_key.decrypt(ret.encryptedPreMasterSecret) self.assertEqual(decrypt[:2], bytearray([3, 3])) self.assertEqual(decrypt[2:], bytearray([1]*8)) def test_generate_with_node_reuse_and_no_encrypted_value_reuse(self): state = ConnectionState() state.get_server_public_key = lambda : self.priv_key cke = ClientKeyExchangeGenerator( cipher=constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, premaster_secret=bytearray([1]*10)) ret = cke.generate(state) self.assertEqual(len(ret.encryptedPreMasterSecret), 128) decrypt = self.priv_key.decrypt(ret.encryptedPreMasterSecret) self.assertEqual(decrypt[:2], bytearray([3, 3])) self.assertEqual(decrypt[2:], bytearray([1]*8)) ret2 = cke.generate(state) self.assertNotEqual(ret.encryptedPreMasterSecret, ret2.encryptedPreMasterSecret) def test_generate_with_node_reuse_and_encrypted_value_reuse(self): state = ConnectionState() state.get_server_public_key = lambda : self.priv_key cke = ClientKeyExchangeGenerator( cipher=constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, premaster_secret=bytearray([1]*10), reuse_encrypted_premaster=True) ret = cke.generate(state) self.assertEqual(len(ret.encryptedPreMasterSecret), 128) decrypt = self.priv_key.decrypt(ret.encryptedPreMasterSecret) self.assertEqual(decrypt[:2], bytearray([3, 3])) self.assertEqual(decrypt[2:], bytearray([1]*8)) ret2 = cke.generate(state) self.assertEqual(ret.encryptedPreMasterSecret, ret2.encryptedPreMasterSecret) def test_generate_with_dhe(self): state = ConnectionState() state.key_exchange = mock.MagicMock() cke = ClientKeyExchangeGenerator( cipher=constants.CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA) ret = cke.generate(state) self.assertIs(ret, state.key_exchange.makeClientKeyExchange()) def test_generate_with_ecdhe(self): state = ConnectionState() state.key_exchange = mock.MagicMock() cke = ClientKeyExchangeGenerator( cipher=constants.CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA) ret = cke.generate(state) self.assertIs(ret, state.key_exchange.makeClientKeyExchange()) def test_generate_with_unknown_cipher(self): state = ConnectionState() cke = ClientKeyExchangeGenerator() with self.assertRaises(AssertionError): cke.generate(state) def test_generate_DHE_with_bogus_value(self): state = ConnectionState() cke = ClientKeyExchangeGenerator( cipher=constants.CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA, dh_Yc=4982) ret = cke.generate(state) self.assertEqual(ret.dh_Yc, 4982) def test_generate_ECDHE_with_bogus_value(self): state = ConnectionState() cke = ClientKeyExchangeGenerator( cipher=constants.CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, ecdh_Yc=bytearray(range(1, 24))) ret = cke.generate(state) self.assertEqual(ret.ecdh_Yc, bytearray(range(1, 24))) def test_generate_with_all_null_RSA(self): state = ConnectionState() cke = ClientKeyExchangeGenerator( cipher=constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, encrypted_premaster=bytearray(512)) ret = cke.generate(state) self.assertEqual(ret.encryptedPreMasterSecret, bytearray(512)) def test_generate_with_modulus_as_premaster(self): state = ConnectionState() state.get_server_public_key = lambda : self.priv_key cke = ClientKeyExchangeGenerator( cipher=constants.CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, modulus_as_encrypted_premaster=True) ret = cke.generate(state) self.assertEqual(ret.encryptedPreMasterSecret, numberToByteArray(self.priv_key.n)) def test_generate_with_p_as_share(self): state = ConnectionState() ske = messages.ServerKeyExchange( constants.CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA, (3, 3)) ske.createDH(21, 2, 11) state.handshake_messages.append(ske) cke = ClientKeyExchangeGenerator( cipher=constants.CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA, p_as_share=True) ret = cke.generate(state) self.assertEqual(ret.dh_Yc, 21) def test_generate_with_p_1_as_share(self): state = ConnectionState() ske = messages.ServerKeyExchange( constants.CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA, (3, 3)) ske.createDH(21, 2, 11) state.handshake_messages.append(ske) cke = ClientKeyExchangeGenerator( cipher=constants.CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA, p_1_as_share=True) ret = cke.generate(state) self.assertEqual(ret.dh_Yc, 20) def test_post_send(self): state = ConnectionState() state.get_server_public_key = lambda : self.priv_key cke = ClientKeyExchangeGenerator( constants.CipherSuite.TLS_RSA_WITH_NULL_MD5, (3, 3)) ret = cke.generate(state) cke.post_send(state) class TestChangeCipherSpecGenerator(unittest.TestCase): def test___init__(self): ccs = ChangeCipherSpecGenerator() self.assertIsNotNone(ccs) def test_generate(self): ccs = ChangeCipherSpecGenerator() ret = ccs.generate(None) self.assertIsInstance(ret, messages.ChangeCipherSpec) def test_post_send(self): ccsg = ChangeCipherSpecGenerator() ccsg.generate(None) state = ConnectionState() state.msg_sock = mock.MagicMock() ccsg.post_send(state) self.assertTrue(state.msg_sock.calcPendingStates.called) self.assertTrue(state.msg_sock.changeWriteState.called) def test_post_send_with_extended_master_secret(self): ccsg = ChangeCipherSpecGenerator() ccsg.generate(None) state = ConnectionState() state.extended_master_secret = True state.msg_sock = mock.MagicMock() with mock.patch('tlsfuzzer.messages.calcExtendedMasterSecret') as mthd: mthd.return_value = bytearray(48) ccsg.post_send(state) mthd.assert_called_once_with(state.version, state.cipher, state.key['premaster_secret'], state.handshake_hashes) self.assertTrue(state.msg_sock.calcPendingStates.called) self.assertTrue(state.msg_sock.changeWriteState.called) def test_post_send_with_fake_true(self): ccsg = ChangeCipherSpecGenerator(fake=True) ccsg.generate(None) state = ConnectionState() state.msg_sock = mock.MagicMock() with mock.patch('tlsfuzzer.messages.calcExtendedMasterSecret') as mthd: mthd.return_value = bytearray(48) ccsg.post_send(state) self.assertFalse(mthd.called) class TestClientMasterKeyGenerator(unittest.TestCase): def test___init__(self): cmk = ClientMasterKeyGenerator() def test_generate_with_no_cipher(self): cmk = ClientMasterKeyGenerator() with self.assertRaises(NotImplementedError): cmk.generate(None) def test_generate(self): cmk = ClientMasterKeyGenerator( cipher=constants.CipherSuite.SSL_CK_DES_192_EDE3_CBC_WITH_MD5) state = ConnectionState() state.msg_sock = mock.MagicMock() state.get_server_public_key = mock.MagicMock() ret = cmk.generate(state) self.assertEqual(ret.cipher, constants.CipherSuite.SSL_CK_DES_192_EDE3_CBC_WITH_MD5) self.assertEqual(ret.clear_key, bytearray(0)) self.assertEqual(ret.encrypted_key, state.get_server_public_key().encrypt()) self.assertEqual(ret.key_argument, state.msg_sock.calcSSL2PendingStates()) def test_generate_with_master_key(self): cmk = ClientMasterKeyGenerator( cipher=constants.CipherSuite.SSL_CK_DES_192_EDE3_CBC_WITH_MD5, master_key=bytearray(range(24))) state = ConnectionState() state.msg_sock = mock.MagicMock() state.get_server_public_key = mock.MagicMock() ret = cmk.generate(state) state.msg_sock.calcSSL2PendingStates.assert_called_once_with( constants.CipherSuite.SSL_CK_DES_192_EDE3_CBC_WITH_MD5, bytearray(range(24)), bytearray(0), bytearray(0), None) def test_generate_with_export_cipher(self): cmk = ClientMasterKeyGenerator( cipher=constants.CipherSuite.SSL_CK_RC2_128_CBC_EXPORT40_WITH_MD5) state = ConnectionState() state.msg_sock = mock.MagicMock() state.get_server_public_key = mock.MagicMock() ret = cmk.generate(state) self.assertEqual(len(ret.clear_key), 11) def test_generate_with_unknown_cipher(self): cmk = ClientMasterKeyGenerator(cipher=0xffffff) state = ConnectionState() with self.assertRaises(AssertionError): cmk.generate(state) def test_generate_with_des_cipher(self): cmk = ClientMasterKeyGenerator( cipher=constants.CipherSuite.SSL_CK_DES_64_CBC_WITH_MD5) state = ConnectionState() state.msg_sock = mock.MagicMock() state.get_server_public_key = mock.MagicMock() ret = cmk.generate(state) self.assertEqual(ret.encrypted_key, state.get_server_public_key().encrypt()) def test_generate_with_session_key(self): cmk = ClientMasterKeyGenerator( cipher=constants.CipherSuite.SSL_CK_DES_192_EDE3_CBC_WITH_MD5) state = ConnectionState() state.msg_sock = mock.MagicMock() state.get_server_public_key = mock.MagicMock() state.key['master_secret'] = bytearray(range(32)) ret = cmk.generate(state) state.msg_sock.calcSSL2PendingStates.assert_called_once_with( constants.CipherSuite.SSL_CK_DES_192_EDE3_CBC_WITH_MD5, bytearray(range(32)), bytearray(0), bytearray(0), None) class TestCertificateGenerator(unittest.TestCase): def test___init__(self): certg = CertificateGenerator() self.assertIsNotNone(certg) def test_generate(self): certg = CertificateGenerator() state = ConnectionState() msg = certg.generate(state) self.assertIsInstance(msg, messages.Certificate) self.assertIsNone(msg.cert_chain) self.assertEqual(msg.certificateType, constants.CertificateType.x509) def test_generate_with_context(self): sig_algs = [constants.SignatureScheme.rsa_pss_rsae_sha512] cr = messages.CertificateRequest((3, 4)).create( sig_algs=sig_algs, context=b'zesty') ctx = [] certg = CertificateGenerator(context=ctx) state = ConnectionState() state.version = (3, 3) state.handshake_messages.append(cr) ctx.append(cr) msg = certg.generate(state) self.assertIsInstance(msg, messages.Certificate) self.assertIsNone(msg.cert_chain) self.assertEqual(msg.certificateType, constants.CertificateType.x509) self.assertEqual(msg.certificate_request_context, b'zesty') class TestCertificateVerifyGenerator(unittest.TestCase): @classmethod def setUpClass(cls): cls.priv_key = generateRSAKey(1024) def test___init__(self): cert_ver_g = CertificateVerifyGenerator() self.assertIsNotNone(cert_ver_g) def test_generate_without_priv_key(self): cert_ver_g = CertificateVerifyGenerator() state = ConnectionState() with self.assertRaises(ValueError): cert_ver_g.generate(state) def test_generate_TLS_1_1(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 2) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) self.assertTrue(priv_key.verify( msg.signature, secureHash(b'', 'md5') + secureHash(b'', 'sha1'))) def test_generate_TLS_1_2(self): priv_key = generateRSAKey(2048) cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 3) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 256) # when no CR is present, select a secure algorithm by default self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_rsae_sha512) self.assertTrue(priv_key.verify( msg.signature, secureHash(b'', 'sha512'), padding="pss", hashAlg="sha512", saltLen=64)) def test_generate_TLS_1_2_with_explicit_msg_sig_alg(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator( priv_key, msg_alg=(constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [(constants.HashAlgorithm.sha256, constants.SignatureAlgorithm.rsa), (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] state.handshake_hashes.update(req.write()) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)) self.assertTrue(priv_key.verify( msg.signature, secureHash(req.write(), 'sha1'), hashAlg="sha1")) def test_generate_TLS_1_2_with_cert_request(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [(constants.HashAlgorithm.sha256, constants.SignatureAlgorithm.rsa), (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] state.handshake_hashes.update(req.write()) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha256, constants.SignatureAlgorithm.rsa)) self.assertTrue(priv_key.verify( msg.signature, secureHash(req.write(), 'sha256'), hashAlg="sha256")) def test_generate_TLS_1_3(self): # RSA-PSS+SHA512 requires a 2048 bit key priv_key = keyfactory.generateRSAKey(2048) cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 4) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 256) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_rsae_sha512) verif_bytes = KeyExchange.calcVerifyBytes( (3, 4), state.handshake_hashes, constants.SignatureScheme.rsa_pss_rsae_sha512, b'', b'', b'', # this is the PRF hash, not signature hash "sha256") self.assertTrue(priv_key.verify( msg.signature, verif_bytes, "pss", "sha512", saltLen=64)) def test_generate_TLS_1_3_with_cert_request(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 4) req = CertificateRequest((3, 4)).create([], [], [(constants.HashAlgorithm.md5, constants.SignatureAlgorithm.rsa), constants.SignatureScheme.rsa_pkcs1_sha1, constants.SignatureScheme.rsa_pkcs1_sha224, constants.SignatureScheme.rsa_pkcs1_sha256, constants.SignatureScheme.rsa_pkcs1_sha384, constants.SignatureScheme.rsa_pkcs1_sha512, constants.SignatureScheme.rsa_pss_rsae_sha256]) state.handshake_messages = [req] state.handshake_hashes.update(req.write()) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_rsae_sha256) verif_bytes = KeyExchange.calcVerifyBytes( (3, 4), state.handshake_hashes, constants.SignatureScheme.rsa_pss_rsae_sha256, b'', b'', b'', "sha256") self.assertTrue(priv_key.verify( msg.signature, verif_bytes, "pss", "sha256", saltLen=32)) def test_generate_TLS_1_3_in_PHA_(self): priv_key = self.priv_key ctx = [] cert_ver_g = CertificateVerifyGenerator(priv_key, context=ctx) state = ConnectionState() state.version = (3, 4) req = CertificateRequest((3, 4)).create([], [], [(constants.HashAlgorithm.md5, constants.SignatureAlgorithm.rsa), constants.SignatureScheme.rsa_pkcs1_sha1, constants.SignatureScheme.rsa_pkcs1_sha224, constants.SignatureScheme.rsa_pkcs1_sha256, constants.SignatureScheme.rsa_pkcs1_sha384, constants.SignatureScheme.rsa_pkcs1_sha512, constants.SignatureScheme.rsa_pss_rsae_sha256], context=b'unique PHA context, totally random') state.handshake_messages = [req] state.handshake_hashes.update(req.write()) state.key['client finished handshake hashes'] = \ HandshakeHashes() ctx.append(req) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_rsae_sha256) hh = state.key['client finished handshake hashes'].copy() hh.update(ctx[0].write()) verif_bytes = KeyExchange.calcVerifyBytes( (3, 4), hh, constants.SignatureScheme.rsa_pss_rsae_sha256, b'', b'', b'', "sha256") self.assertTrue(priv_key.verify( msg.signature, verif_bytes, "pss", "sha256", saltLen=32)) def test_generate_with_mismatched_alg(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, sig_alg=( constants.HashAlgorithm.md5, constants.SignatureAlgorithm.rsa)) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [(constants.HashAlgorithm.sha256, constants.SignatureAlgorithm.rsa), (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] state.handshake_hashes.update(req.write()) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha256, constants.SignatureAlgorithm.rsa)) self.assertFalse(priv_key.hashAndVerify( msg.signature, req.write(), "pkcs1", "sha256")) self.assertTrue(priv_key.hashAndVerify( msg.signature, req.write(), "pkcs1", "md5")) def test_generate_with_mismatched_CR_TLS1_2(self): priv_key = generateRSAKey(2048) cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [(constants.HashAlgorithm.sha256, constants.SignatureAlgorithm.ecdsa)]) state.handshake_messages = [req] state.handshake_hashes.update(req.write()) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 256) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_rsae_sha512) self.assertTrue(priv_key.hashAndVerify( msg.signature, req.write(), "pss", "sha512", 64)) def test_generate_with_mismatched_CR_TLS1_3(self): priv_key = generateRSAKey(2048) cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 4) req = CertificateRequest((3, 4)).create([], [], [(constants.HashAlgorithm.sha256, constants.SignatureAlgorithm.ecdsa)]) state.handshake_messages = [req] state.handshake_hashes.update(req.write()) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 256) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_rsae_sha512) verif_bytes = KeyExchange.calcVerifyBytes( (3, 4), state.handshake_hashes, constants.SignatureScheme.rsa_pss_rsae_sha512, b'', b'', b'', # this is the PRF hash, not signature hash "sha256") self.assertTrue(priv_key.verify( msg.signature, verif_bytes, "pss", "sha512", saltLen=64)) def test_generate_with_mismatched_CR_RSA_PSS(self): priv_key = Python_RSAKey.generate(2048, "rsa-pss") cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [(constants.HashAlgorithm.sha256, constants.SignatureAlgorithm.ecdsa)]) state.handshake_messages = [req] state.handshake_hashes.update(req.write()) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 256) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_pss_sha512) self.assertTrue(priv_key.hashAndVerify( msg.signature, req.write(), "pss", "sha512", 64)) def test_generate_with_rsa_pss_rsae_alg(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.rsa_pss_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] state.handshake_hashes.update(req.write()) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_sha256) self.assertTrue(priv_key.hashAndVerify( msg.signature, req.write(), "pss", "sha256", 32)) def test_generate_with_no_key(self): cert_ver_g = CertificateVerifyGenerator(signature=bytearray(b'xxxx')) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.rsa_pss_rsae_sha256, constants.SignatureScheme.rsa_pss_pss_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] state.handshake_hashes.update(req.write()) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(msg.signature, bytearray(b'xxxx')) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_rsae_sha256) def test_generate_with_rsa_pss_pss_alg(self): priv_key = Python_RSAKey.generate(1024, "rsa-pss") cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.rsa_pss_rsae_sha256, constants.SignatureScheme.rsa_pss_pss_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] state.handshake_hashes.update(req.write()) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_pss_sha256) self.assertTrue(priv_key.hashAndVerify( msg.signature, req.write(), "pss", "sha256", 32)) def test_generate_with_rsa_with_subs(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, padding_subs={4: 0x00}) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [(constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) dec_sig = numberToByteArray(priv_key._rawPublicKeyOp( bytesToNumber(msg.signature)), 128) self.assertEqual(dec_sig[4], 0x00) # since we're siging the same set of messages always, the hash in # signature is also the same self.assertEqual(dec_sig[-1], 0x09) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pkcs1_sha1) self.assertFalse(priv_key.hashAndVerify( msg.signature, b'', "pkcs1", "sha1")) def test_generate_with_rsa_with_large_xors(self): # the xor can make the encrypted value too large (larger than the # modulus) verify that it doesn't break encryption (but the signature # will be completely invalid) priv_key = generateRSAKey(1020) cert_ver_g = CertificateVerifyGenerator(priv_key, padding_xors={0: 0x80}) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [(constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) dec_sig = numberToByteArray(priv_key._rawPublicKeyOp( bytesToNumber(msg.signature)), 128) self.assertEqual(dec_sig[0] & 0x80, 0x00) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pkcs1_sha1) self.assertFalse(priv_key.hashAndVerify( msg.signature, b'', "pkcs1", "sha1")) def test_generate_with_rsa_with_xors(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, padding_xors={4: 0xff}) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [(constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) dec_sig = numberToByteArray(priv_key._rawPublicKeyOp( bytesToNumber(msg.signature)), 128) self.assertEqual(dec_sig[4], 0x00) # since we're siging the same set of messages always, the hash in # signature is also the same self.assertEqual(dec_sig[-1], 0x09) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pkcs1_sha1) self.assertFalse(priv_key.hashAndVerify( msg.signature, b'', "pkcs1", "sha1")) def test_generate_with_rsa_pss_with_subs(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, padding_subs={2: 0xff, -1: 0xff}) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.rsa_pss_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) dec_sig = numberToByteArray(priv_key._rawPublicKeyOp( bytesToNumber(msg.signature)), 128) self.assertEqual(dec_sig[2], 0xff) self.assertEqual(dec_sig[-1], 0xff) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_sha256) self.assertFalse(priv_key.hashAndVerify( msg.signature, b'', "pss", "sha256", 32)) def test_generate_with_rsa_pss_with_xors(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, padding_xors={-1: 0xff}) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.rsa_pss_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) dec_sig = numberToByteArray(priv_key._rawPublicKeyOp( bytesToNumber(msg.signature)), 128) self.assertEqual(dec_sig[-1], 0xbc ^ 0xff) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_sha256) self.assertFalse(priv_key.hashAndVerify( msg.signature, b'', "pss", "sha256", 32)) def test_generate_with_tls1_0_version(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, sig_version=(3, 1)) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.rsa_pss_rsae_sha256, constants.SignatureScheme.rsa_pkcs1_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) # because we are creating a TLS 1.0 signature, we want to select # a signature algorithm that matches TLS 1.0 most closely # (the last 20 bytes of a decrypted signature will match) self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)) self.assertFalse(priv_key.hashAndVerify( msg.signature, b'', "pkcs1", "sha1")) self.assertTrue(priv_key.verify( msg.signature, secureHash(b"", "md5") + secureHash(b"", "sha1"))) def test_generate_with_ssl3_0_version(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, sig_version=(3, 0)) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.rsa_pss_rsae_sha256, constants.SignatureScheme.rsa_pkcs1_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) # just like with TLS 1.0, we want to use one that uses # algorithms closest to the ones we will actually use in the signature self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)) self.assertFalse(priv_key.hashAndVerify( msg.signature, b'', "pkcs1", "sha1")) verify_bytes = KeyExchange.calcVerifyBytes( (3, 0), state.handshake_hashes, None, b'', b'', b'') self.assertTrue(priv_key.verify( msg.signature, verify_bytes)) def test_generate_with_empty_signature(self): cert_ver_g = CertificateVerifyGenerator(signature=bytearray()) state = ConnectionState() state.version = (3, 3) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(msg.signature, bytearray()) def test_generate_with_mismatched_mgf1(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, sig_version=(3, 4), mgf1_hash="sha512") state = ConnectionState() state.version = (3, 4) req = CertificateRequest((3, 4)).create([], [], [constants.SignatureScheme.rsa_pss_rsae_sha256]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_rsae_sha256) verify_bytes = KeyExchange.calcVerifyBytes( (3, 4), state.handshake_hashes, constants.SignatureScheme.rsa_pss_rsae_sha256, b'', b'', b'', "sha256") self.assertTrue(priv_key.verify( msg.signature, verify_bytes, "pss", "sha512", 32)) self.assertFalse(priv_key.hashAndVerify( msg.signature, verify_bytes, "pss", "sha256", 32)) def test_generate_with_mismatched_mgf1_and_salt_len(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, sig_version=(3, 4), mgf1_hash="sha384", rsa_pss_salt_len=48) state = ConnectionState() state.version = (3, 4) req = CertificateRequest((3, 4)).create([], [], [constants.SignatureScheme.rsa_pss_rsae_sha256]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_rsae_sha256) verify_bytes = KeyExchange.calcVerifyBytes( (3, 4), state.handshake_hashes, constants.SignatureScheme.rsa_pss_rsae_sha256, b'', b'', b'', "sha256") self.assertTrue(priv_key.verify( msg.signature, verify_bytes, "pss", "sha384", 48)) self.assertFalse(priv_key.hashAndVerify( msg.signature, verify_bytes, "pss", "sha256", 32)) def test_generate_with_key_and_certificate(self): priv_key = Python_RSAKey.generate(1024, "rsa-pss") cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 3) cert = Certificate(constants.CertificateType.x509, (3, 3))\ .create(rsa_pss_cert) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.rsa_pss_rsae_sha256, constants.SignatureScheme.rsa_pss_pss_sha256, constants.SignatureScheme.rsa_pss_pss_sha384]) state.handshake_messages = [req, cert] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(len(msg.signature), 128) # when no CR is present, select a secure algorithm by default self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.rsa_pss_pss_sha256) self.assertTrue(priv_key.verify( msg.signature, secureHash(b"", "sha256"), "pss", "sha256", 32)) class TestCertificateVerifyGeneratorECDSA(unittest.TestCase): @classmethod def setUpClass(cls): cls.priv_key = Python_ECDSAKey(None, None, "NIST256p", 12) def test_generate_with_ecdsa_and_no_cert_req_in_tls1_3(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 4) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha256, constants.SignatureAlgorithm.ecdsa)) verif_bytes = KeyExchange.calcVerifyBytes( (3, 4), state.handshake_hashes, constants.SignatureScheme.ecdsa_secp256r1_sha256, b'', b'', b'', "sha256") self.assertTrue(priv_key.verify( msg.signature, verif_bytes, "", "sha256")) def test_generate_with_ecdsa_cert_and_no_key_in_tls1_3(self): cert_ver_g = CertificateVerifyGenerator(signature=b"ABCDEF") state = ConnectionState() state.version = (3, 4) cert = messages.Certificate(constants.CertificateType.x509, (3, 4))\ .create(ecdsa_cert) state.handshake_messages = [cert] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(msg.signature, bytearray(b'ABCDEF')) self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha256, constants.SignatureAlgorithm.ecdsa)) def test_generate_with_ecdsa_cert_and_no_key_in_tls1_2(self): cert_ver_g = CertificateVerifyGenerator(signature=b"ABCDEF") state = ConnectionState() state.version = (3, 3) cert = messages.Certificate(constants.CertificateType.x509, (3, 3))\ .create(ecdsa_cert) state.handshake_messages = [cert] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(msg.signature, bytearray(b'ABCDEF')) self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha512, constants.SignatureAlgorithm.ecdsa)) def test_generate_with_ecdsa_cert_and_no_key_unmatched_CR_in_tls1_2(self): cert_ver_g = CertificateVerifyGenerator(signature=b"ABCDEF") state = ConnectionState() state.version = (3, 3) cert = messages.Certificate(constants.CertificateType.x509, (3, 3))\ .create(ecdsa_cert) cert_req = messages.CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.rsa_pss_pss_sha256]) state.handshake_messages = [cert_req, cert] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertEqual(msg.signature, bytearray(b'ABCDEF')) self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha512, constants.SignatureAlgorithm.ecdsa)) def test_generate_with_ecdsa_in_TLS1_0(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 1) req = CertificateRequest((3, 1)).create() state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) # the SignatureScheme.ecdsa_secp*_sha* ignore the curve type in # TLS 1.2, so the negotiation will always select SHA512 algorithm self.assertEqual(msg.signatureAlgorithm, None) self.assertTrue(priv_key.verify( msg.signature, secureHash(b"", "sha1"), "", "sha1")) def test_generate_with_ecdsa_with_TLS1_0_format(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, sig_version=(3, 1)) state = ConnectionState() state.version = (3, 3) req = messages.CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.ecdsa_secp256r1_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.ecdsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) # the SignatureScheme.ecdsa_secp*_sha* ignore the curve type in # TLS 1.2, so the negotiation will always select SHA512 algorithm self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.ecdsa)) self.assertTrue(priv_key.verify( msg.signature, secureHash(b"", "sha1"), "", "sha1")) def test_generate_with_ecdsa_256_alg(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.ecdsa_secp521r1_sha512, constants.SignatureScheme.ecdsa_secp384r1_sha384, constants.SignatureScheme.ecdsa_secp256r1_sha256, constants.SignatureScheme.rsa_pss_pss_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) # the SignatureScheme.ecdsa_secp*_sha* ignore the curve type in # TLS 1.2, so the negotiation will always select SHA512 algorithm self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha512, constants.SignatureAlgorithm.ecdsa)) self.assertTrue(priv_key.verify( msg.signature, secureHash(b"", "sha512")[:32], "", "sha512")) def test_generate_with_ecdsa_256_alg_and_non_matching_CR_tls1_3(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 4) req = CertificateRequest((3, 4)).create([], [], [constants.SignatureScheme.rsa_pss_pss_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha256, constants.SignatureAlgorithm.ecdsa)) verif_bytes = KeyExchange.calcVerifyBytes( (3, 4), state.handshake_hashes, constants.SignatureScheme.ecdsa_secp256r1_sha256, b'', b'', b'', "sha256") self.assertTrue(priv_key.verify( msg.signature, verif_bytes, "", "sha256")) def test_generate_with_ecdsa_256_alg_and_non_matching_CR_tls1_2(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.rsa_pss_pss_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha512, constants.SignatureAlgorithm.ecdsa)) self.assertTrue(priv_key.verify( msg.signature, secureHash(b"", "sha512")[:32], "", "sha512")) def test_generate_with_ecdsa_384_alg(self): priv_key = Python_ECDSAKey(None, None, "NIST384p", 11) cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.ecdsa_secp521r1_sha512, constants.SignatureScheme.ecdsa_secp384r1_sha384, constants.SignatureScheme.ecdsa_secp256r1_sha256, constants.SignatureScheme.rsa_pss_pss_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) # the SignatureScheme.ecdsa_secp*_sha* ignore the curve type in # TLS 1.2, so the negotiation will always select SHA512 algorithm self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha512, constants.SignatureAlgorithm.ecdsa)) self.assertTrue(priv_key.verify( msg.signature, secureHash(b"", "sha512")[:48], "", "sha512")) def test_generate_with_ecdsa_521_alg(self): priv_key = Python_ECDSAKey(None, None, "NIST521p", 10) cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.ecdsa_secp521r1_sha512, constants.SignatureScheme.ecdsa_secp384r1_sha384, constants.SignatureScheme.ecdsa_secp256r1_sha256, constants.SignatureScheme.rsa_pss_pss_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) # the SignatureScheme.ecdsa_secp*_sha* ignore the curve type in # TLS 1.2, so the negotiation will always select SHA512 algorithm self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha512, constants.SignatureAlgorithm.ecdsa)) self.assertTrue(priv_key.verify( msg.signature, secureHash(b"", "sha512"), "", "sha512")) def test_generate_with_ecdsa_256_alg_in_tls1_3(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 4) req = CertificateRequest((3, 4)).create([], [], [constants.SignatureScheme.ecdsa_secp521r1_sha512, constants.SignatureScheme.ecdsa_secp384r1_sha384, constants.SignatureScheme.ecdsa_secp256r1_sha256, constants.SignatureScheme.rsa_pss_pss_sha256]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.ecdsa_secp256r1_sha256) verif_bytes = KeyExchange.calcVerifyBytes( (3, 4), state.handshake_hashes, constants.SignatureScheme.ecdsa_secp256r1_sha256, b'', b'', b'', "sha256") self.assertTrue(priv_key.verify( msg.signature, verif_bytes, "", "sha256")) def test_generate_with_ecdsa_384_alg(self): priv_key = Python_ECDSAKey(None, None, "NIST384p", 11) cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 4) req = CertificateRequest((3, 4)).create([], [], [constants.SignatureScheme.ecdsa_secp521r1_sha512, constants.SignatureScheme.ecdsa_secp384r1_sha384, constants.SignatureScheme.ecdsa_secp256r1_sha256, constants.SignatureScheme.rsa_pss_pss_sha256]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.ecdsa_secp384r1_sha384) verif_bytes = KeyExchange.calcVerifyBytes( (3, 4), state.handshake_hashes, constants.SignatureScheme.ecdsa_secp384r1_sha384, b'', b'', b'', # PRF, not hash of signature "sha256") self.assertTrue(priv_key.verify( msg.signature, verif_bytes, "", "sha384")) def test_generate_with_ecdsa_521_alg(self): priv_key = Python_ECDSAKey(None, None, "NIST521p", 10) cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 4) req = CertificateRequest((3, 4)).create([], [], [constants.SignatureScheme.ecdsa_secp521r1_sha512, constants.SignatureScheme.ecdsa_secp384r1_sha384, constants.SignatureScheme.ecdsa_secp256r1_sha256, constants.SignatureScheme.rsa_pss_pss_sha256]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.ecdsa_secp521r1_sha512) verif_bytes = KeyExchange.calcVerifyBytes( (3, 4), state.handshake_hashes, constants.SignatureScheme.ecdsa_secp521r1_sha512, b'', b'', b'', # PRF, not hash of signature "sha256") self.assertTrue(priv_key.verify( msg.signature, verif_bytes, "", "sha512")) def test_generate_with_xors(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, padding_xors={0: 0xff}) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.ecdsa_secp521r1_sha512, constants.SignatureScheme.ecdsa_secp384r1_sha384, constants.SignatureScheme.ecdsa_secp256r1_sha256, constants.SignatureScheme.rsa_pss_pss_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) # the SignatureScheme.ecdsa_secp*_sha* ignore the curve type in # TLS 1.2, so the negotiation will always select SHA512 algorithm self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha512, constants.SignatureAlgorithm.ecdsa)) self.assertEqual(msg.signature[0] ^ 0xff, 0x30) sig = msg.signature self.assertFalse(priv_key.verify( sig, secureHash(b"", "sha512")[:32], "", "sha512")) sig[0] ^= 0xff self.assertTrue(priv_key.verify( sig, secureHash(b"", "sha512")[:32], "", "sha512")) def test_generate_with_xors_in_tls1_0(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, padding_xors={0: 0xff}) state = ConnectionState() state.version = (3, 1) req = CertificateRequest((3, 1)).create() state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) self.assertIsNone(msg.signatureAlgorithm) self.assertEqual(msg.signature[0] ^ 0xff, 0x30) sig = msg.signature self.assertFalse(priv_key.verify( sig, secureHash(b"", "sha1"), "", "sha1")) sig[0] ^= 0xff self.assertTrue(priv_key.verify( sig, secureHash(b"", "sha1"), "", "sha1")) def test_generate_with_subs(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, padding_subs={0: 0xff}) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.ecdsa_secp521r1_sha512, constants.SignatureScheme.ecdsa_secp384r1_sha384, constants.SignatureScheme.ecdsa_secp256r1_sha256, constants.SignatureScheme.rsa_pss_pss_sha256, (constants.HashAlgorithm.sha1, constants.SignatureAlgorithm.rsa)]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) # the SignatureScheme.ecdsa_secp*_sha* ignore the curve type in # TLS 1.2, so the negotiation will always select SHA512 algorithm self.assertEqual(msg.signatureAlgorithm, (constants.HashAlgorithm.sha512, constants.SignatureAlgorithm.ecdsa)) self.assertEqual(msg.signature[0], 0xff) sig = msg.signature self.assertFalse(priv_key.verify( sig, secureHash(b"", "sha512")[:32], "", "sha512")) sig[0] = 0x30 self.assertTrue(priv_key.verify( sig, secureHash(b"", "sha512")[:32], "", "sha512")) def test_generate_with_subs_in_TLS1_0(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, padding_subs={0: 0xff}) state = ConnectionState() state.version = (3, 1) req = CertificateRequest((3, 1)).create() state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) # the SignatureScheme.ecdsa_secp*_sha* ignore the curve type in # TLS 1.2, so the negotiation will always select SHA512 algorithm self.assertEqual(msg.signatureAlgorithm, None) sig = msg.signature self.assertFalse(priv_key.verify( sig, secureHash(b"", "sha1"), "", "sha1")) sig[0] = 0x30 self.assertTrue(priv_key.verify( sig, secureHash(b"", "sha1"), "", "sha1")) class TestCertificateVerifyGeneratorEdDSA(unittest.TestCase): @classmethod def setUpClass(cls): p_key = SigningKey.generate(Ed25519) cls.priv_key = Python_EdDSAKey(None, p_key) def test_generate_with_eddsa_and_no_cert_req_in_tls1_3(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key) state = ConnectionState() state.version = (3, 4) msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.ed25519) verif_bytes = KeyExchange.calcVerifyBytes( (3, 4), state.handshake_hashes, constants.SignatureScheme.ed25519, b'', b'', b'', "sha256") self.assertTrue(priv_key.hashAndVerify( msg.signature, verif_bytes, "", "sha256")) def test_generate_with_xors(self): priv_key = self.priv_key cert_ver_g = CertificateVerifyGenerator(priv_key, padding_xors={0: 0xff}) state = ConnectionState() state.version = (3, 3) req = CertificateRequest((3, 3)).create([], [], [constants.SignatureScheme.ed25519]) state.handshake_messages = [req] msg = cert_ver_g.generate(state) self.assertIsNotNone(msg) self.assertTrue(msg.signature) self.assertEqual(msg.signatureAlgorithm, constants.SignatureScheme.ed25519) sig = msg.signature self.assertFalse(priv_key.hashAndVerify( sig, b"")) sig[0] ^= 0xff self.assertTrue(priv_key.hashAndVerify( sig, b"")) class TestClearContext(unittest.TestCase): def test___init__(self): ctx = [] cc = ClearContext(ctx) self.assertIs(cc.context, ctx) self.assertTrue(cc.is_command()) self.assertFalse(cc.is_expect()) self.assertFalse(cc.is_generator()) def test_process(self): ctx = ["abc", 123] cc = ClearContext(ctx) self.assertEqual(ctx, ["abc", 123]) cc.process(None) self.assertEqual(ctx, []) class TestAlertGenerator(unittest.TestCase): def test_default_settings(self): a = AlertGenerator() self.assertIsNotNone(a) state = ConnectionState() ret = a.generate(state) self.assertEqual(ret.level, constants.AlertLevel.warning) self.assertEqual(ret.description, constants.AlertDescription.close_notify) def test___init___with_parameters(self): a = AlertGenerator(constants.AlertLevel.fatal, constants.AlertDescription.decode_error) self.assertIsNotNone(a) state = ConnectionState() ret = a.generate(state) self.assertEqual(ret.level, constants.AlertLevel.fatal) self.assertEqual(ret.description, constants.AlertDescription.decode_error) class TestFinishedGenerator(unittest.TestCase): def test___init__(self): fg = FinishedGenerator() self.assertIsNotNone(fg) def test_generate(self): fg = FinishedGenerator() state = ConnectionState() ret = fg.generate(state) self.assertIsInstance(ret, messages.Finished) def test_post_send(self): fg = FinishedGenerator() state = ConnectionState() ret = fg.generate(state) self.assertNotIn(ret, state.handshake_messages) fg.post_send(state) self.assertIn(ret, state.handshake_messages) def test_post_send_with_resumption(self): fg = FinishedGenerator() state = ConnectionState() state.resuming = True ret = fg.generate(state) fg.post_send(state) self.assertFalse(state.resuming) def test_generate_with_ssl2(self): fg = FinishedGenerator((0, 2)) state = ConnectionState() state.msg_sock = mock.MagicMock() state.session_id = bytearray(b'abba') ret = fg.generate(state) self.assertEqual(ret.verify_data, bytearray(b'abba')) state.msg_sock.changeWriteState.assert_called_once_with() state.msg_sock.changeReadState.assert_called_once_with() def test_generate_in_tls13(self): fg = FinishedGenerator((3, 4)) state = ConnectionState() state.msg_sock = mock.MagicMock() state.cipher = constants.CipherSuite.TLS_AES_128_GCM_SHA256 state.version = (3, 4) state.key['client handshake traffic secret'] = bytearray(32) ret = fg.generate(state) self.assertEqual(ret.verify_data, bytearray( b'\x14\xa5e\xa67\xfe\xa3(\xd3\xac\x95\xecX\xb7\xc0\xd4u\xef' b'\xb3V\x8f\xc7[\xcdD\xc8\xa4\x86\xcf\xd3\xc9\x0c')) state.key['handshake secret'] = bytearray(32) state.key['master secret'] = bytearray(32) fg.post_send(state) state.msg_sock.changeWriteState.assert_called_once_with() self.assertEqual(state.key['resumption master secret'], bytearray(b'\x89\xd8\x00l c$\x01\x0f\xd9j\x16\xa3\xbaV\xfesT\x8b' b'\xc6\xeb\x0f~\r\xbd\xb3R\xeb\xd5\x08\xa7\xbd')) def test_generate_in_tls13_with_pha(self): ctx = [] fg = FinishedGenerator((3, 4), context=ctx) cv = messages.CertificateVerify((3, 4)).create( b'abcdef', constants.SignatureScheme.rsa_pss_rsae_sha256) ctx.append(cv) state = ConnectionState() state.msg_sock = mock.MagicMock() state.cipher = constants.CipherSuite.TLS_AES_128_GCM_SHA256 state.version = (3, 4) state.key['client handshake traffic secret'] = bytearray(32) state.key['client application traffic secret'] = bytearray(b'x'*32) state.key['client finished handshake hashes'] = HandshakeHashes() ret = fg.generate(state) print(repr(ret.verify_data)) self.assertEqual(ret.verify_data, bytearray(b'q\xf1l\x05\x94\xb8"\xb2L7\xce\xd5\xb3\x00\xa6\r\x17*' b'\xcc\xe7\xdc\xa6\xf0c\xd7\x90I\x11}\xbfq:')) state.key['handshake secret'] = bytearray(32) state.key['master secret'] = bytearray(32) fg.post_send(state) self.assertTrue(not state.msg_sock.changeWriteState.called) self.assertNotIn('resumption master secret', state.key) def test_generate_in_tls13_with_truncation(self): fg = FinishedGenerator((3, 4), trunc_start=2, trunc_end=-2) state = ConnectionState() state.msg_sock = mock.MagicMock() state.cipher = constants.CipherSuite.TLS_AES_128_GCM_SHA256 state.version = (3, 4) state.key['client handshake traffic secret'] = bytearray(32) ret = fg.generate(state) self.assertEqual(ret.verify_data, bytearray( b'e\xa67\xfe\xa3(\xd3\xac\x95\xecX\xb7\xc0\xd4u\xef' b'\xb3V\x8f\xc7[\xcdD\xc8\xa4\x86\xcf\xd3')) def test_generate_in_tls13_with_padding(self): fg = FinishedGenerator((3, 4), pad_byte=0, pad_left=1, pad_right=1) state = ConnectionState() state.msg_sock = mock.MagicMock() state.cipher = constants.CipherSuite.TLS_AES_128_GCM_SHA256 state.version = (3, 4) state.key['client handshake traffic secret'] = bytearray(32) ret = fg.generate(state) self.assertEqual(ret.verify_data, bytearray( b'\x00\x14\xa5e\xa67\xfe\xa3(\xd3\xac\x95\xecX\xb7\xc0\xd4u\xef' b'\xb3V\x8f\xc7[\xcdD\xc8\xa4\x86\xcf\xd3\xc9\x0c\x00')) class TestKeyUpdateGenerator(unittest.TestCase): def test_default_settings(self): ku = KeyUpdateGenerator() self.assertIsNotNone(ku) state = ConnectionState() ret = ku.generate(state) self.assertEqual(ret.message_type, constants.KeyUpdateMessageType.update_not_requested) def test___init___with_parameters(self): ku = KeyUpdateGenerator( constants.KeyUpdateMessageType.update_requested) self.assertIsNotNone(ku) state = ConnectionState() ret = ku.generate(state) self.assertEqual(ret.message_type, constants.KeyUpdateMessageType.update_requested) def test_post_send(self): ku = KeyUpdateGenerator() state = ConnectionState() state.msg_sock = mock.MagicMock() state.cipher = constants.CipherSuite.TLS_AES_128_GCM_SHA256 client_app_secret = mock.Mock() server_app_secret = mock.Mock() state.key['client application traffic secret'] = client_app_secret state.key['server application traffic secret'] = server_app_secret state.msg_sock.calcTLS1_3KeyUpdate_reciever.return_value = ( bytearray(32), bytearray(32)) ku.post_send(state) state.msg_sock.calcTLS1_3KeyUpdate_reciever.assert_called_once_with( state.cipher, client_app_secret, server_app_secret) class TestResetHandshakeHashes(unittest.TestCase): def test___init__(self): node = ResetHandshakeHashes() self.assertIsNotNone(node) def test_process(self): node = ResetHandshakeHashes() state = ConnectionState() hashes = state.handshake_hashes self.assertIs(hashes, state.handshake_hashes) node.process(state) self.assertIsNot(hashes, state.handshake_hashes) class TestResetRenegotiationInfo(unittest.TestCase): def test___init__(self): node = ResetRenegotiationInfo() self.assertIsNotNone(node) def test_process(self): node = ResetRenegotiationInfo() state = ConnectionState() state.key['client_verify_data'] = bytearray(b'\xde\xad\xc0\xde') state.key['server_verify_data'] = bytearray(b'\xc0\xff\xee') node.process(state) self.assertEqual(state.key['client_verify_data'], bytearray(0)) self.assertEqual(state.key['server_verify_data'], bytearray(0)) class TestSetMaxRecordSize(unittest.TestCase): def test___init__(self): node = SetMaxRecordSize() self.assertIsNotNone(node) def test_process(self): node = SetMaxRecordSize() state = ConnectionState() state.msg_sock = mock.MagicMock() state.msg_sock.recordSize = 1024 node.process(state) self.assertEqual(2**14, state.msg_sock.recordSize) def test_process_with_size(self): node = SetMaxRecordSize(2048) state = ConnectionState() state.msg_sock = mock.MagicMock() node.process(state) self.assertEqual(2048, state.msg_sock.recordSize) class TestSetPaddingCallback(unittest.TestCase): def test___init__(self): node = SetPaddingCallback() self.assertIsNotNone(node) def test_process_fixed_len_padding(self): node = SetPaddingCallback(SetPaddingCallback.fixed_length_cb(42)) state = ConnectionState() state.msg_sock = mock.MagicMock() node.process(state) self.assertEqual(42, state.msg_sock.padding_cb(13, constants.ContentType.application_data, 2**14 - 1)) def test_process_fill_padding(self): node = SetPaddingCallback(SetPaddingCallback.fill_padding_cb) state = ConnectionState() state.msg_sock = mock.MagicMock() node.process(state) self.assertEqual(2**14 - 13 - 1, state.msg_sock.padding_cb(13, constants.ContentType.application_data, 2**14 - 1)) def test_process_custom_callback(self): def _my_cb(length, contenttype, max_padding): return 1337 node = SetPaddingCallback(_my_cb) state = ConnectionState() state.msg_sock = mock.MagicMock() node.process(state) self.assertEqual(1337, state.msg_sock.padding_cb(13, constants.ContentType.application_data, 2**14 - 1)) def test_unset_padding_callback(self): node = SetPaddingCallback(SetPaddingCallback.fixed_length_cb(16)) state = ConnectionState() state.msg_sock = mock.MagicMock() node.process(state) self.assertEqual(16, state.msg_sock.padding_cb(13, constants.ContentType.application_data, 2**14 - 1)) unset_node = SetPaddingCallback() unset_node.process(state) self.assertIsNone(state.msg_sock.padding_cb) def test_with_padding_larger_than_possible(self): node = SetPaddingCallback(SetPaddingCallback.fixed_length_cb(42)) state = ConnectionState() state.msg_sock = mock.MagicMock() node.process(state) with self.assertRaises(ValueError): state.msg_sock.padding_cb(20, constants.ContentType.application_data, 32) class TestRenegotiationInfoExtension(unittest.TestCase): def test___init__(self): ext = RenegotiationInfoExtension() self.assertIsNotNone(ext) def test_write(self): ext = RenegotiationInfoExtension() self.assertEqual(ext.write(), bytearray( b'\xff\x01' + # extension type b'\x00\x00' # overall extension length )) def test_write_with_data(self): ext = RenegotiationInfoExtension() ext.create(bytearray(b'\xab\xcd')) self.assertEqual(ext.write(), bytearray( b'\xff\x01' + # extension type b'\x00\x03' + # overall extension length b'\x02' + # payload length b'\xab\xcd' # payload )) def test_parse(self): parser = Parser(bytearray(b'\x02\xab\xcd')) ext = RenegotiationInfoExtension() ext.parse(parser) self.assertEqual(bytearray(b'\xab\xcd'), ext.renegotiated_connection) class TestHandshakePadding(unittest.TestCase): def setUp(self): self.state = ConnectionState() self.hello_gen = ClientHelloGenerator() self.vanilla_hello = self.hello_gen.generate(self.state).write() def test_no_option(self): self.assertEqual(len(self.vanilla_hello), 43) hello_gen = pad_handshake(ClientHelloGenerator()) unmodified_hello = hello_gen.generate(self.state).write() self.assertEqual(len(unmodified_hello), 43) self.assertEqual(self.vanilla_hello, unmodified_hello) def test_add_padding(self): hello_gen = pad_handshake(ClientHelloGenerator(), 1) padded_hello = hello_gen.generate(self.state).write() self.assertEqual(len(padded_hello), 44) # skip the first 4 bytes as they have different length self.assertEqual(self.vanilla_hello[4:] + bytearray(1), padded_hello[4:]) def test_add_specific_padding(self): hello_gen = pad_handshake(ClientHelloGenerator(), 2, 0xab) padded_hello = hello_gen.generate(self.state).write() self.assertEqual(len(padded_hello), 45) # skip the first 4 bytes as they have different length self.assertEqual(self.vanilla_hello[4:] + bytearray(b'\xab\xab'), padded_hello[4:]) def test_pad_with_data(self): pad = bytearray(b'\xff\x01\x00\x01\x00') hello_gen = pad_handshake(ClientHelloGenerator(), pad=pad) padded_hello = hello_gen.generate(self.state).write() self.assertEqual(len(padded_hello), len(self.vanilla_hello) + len(pad)) self.assertEqual(self.vanilla_hello[4:] + pad, padded_hello[4:]) self.assertNotEqual(self.vanilla_hello[:4], padded_hello[:4]) def test_truncate(self): hello_gen = truncate_handshake(ClientHelloGenerator(), 1) padded_hello = hello_gen.generate(self.state).write() self.assertEqual(len(padded_hello), 42) # skip the first 4 bytes as they have different length self.assertEqual(self.vanilla_hello[4:-1], padded_hello[4:]) class TestFuzzMessage(unittest.TestCase): def setUp(self): self.state = ConnectionState() self.hello_gen = ClientHelloGenerator() self.vanilla_hello = self.hello_gen.generate(self.state).write() def test_no_options(self): self.assertEqual(len(self.vanilla_hello), 43) hello_gen = fuzz_message(ClientHelloGenerator()) unmodified_hello = hello_gen.generate(self.state).write() self.assertEqual(len(unmodified_hello), 43) self.assertEqual(self.vanilla_hello, unmodified_hello) def test_substitutions(self): hello_gen = fuzz_message(ClientHelloGenerator(), substitutions={4:0xff}) modified_hello = hello_gen.generate(self.state).write() self.assertNotEqual(self.vanilla_hello, modified_hello) self.vanilla_hello[4] = 0xff self.assertEqual(self.vanilla_hello, modified_hello) def test_xors(self): hello_gen = fuzz_message(ClientHelloGenerator(), xors={4:0xff}) modified_hello = hello_gen.generate(self.state).write() self.assertNotEqual(self.vanilla_hello, modified_hello) self.vanilla_hello[4] ^= 0xff self.assertEqual(self.vanilla_hello, modified_hello) class TestFuzzMAC(unittest.TestCase): def setUp(self): self.state = ConnectionState() self.socket = MockSocket(bytearray()) defragger = defragmenter.Defragmenter() defragger.add_static_size(constants.ContentType.alert, 2) defragger.add_static_size(constants.ContentType.change_cipher_spec, 1) defragger.add_dynamic_size(constants.ContentType.handshake, 1, 3) self.state.msg_sock = messagesocket.MessageSocket(self.socket, defragger) self.state.msg_sock.version = (3, 3) self.state.msg_sock.calcPendingStates(constants.CipherSuite.\ TLS_RSA_WITH_NULL_MD5, bytearray(48), bytearray(32), bytearray(32), None) self.state.msg_sock.changeWriteState() self.expected_value = bytearray( b"\x16" # content type b"\x03\x03" # record layer protocol version b"\x00\x3b" # record layer record length b"\x01" # handshake message type b"\x00\x00\x27" # handshke protocol message length b"\x03\x03" # client hello protocol version # random b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" b"\x00" # length of session_id b"\x00\x00" # cipher_suites length b"\x01\x00" # compression_methods (0 - uncompressed) # 128 bit MD5 HMAC value b"\x1cK \xce\xb3\x1d\x94\x0b\x0f\x9a\'\x9c\x87\x1a-`" ) self.second_write = bytearray(self.expected_value[:-16]) self.second_write += bytearray( # MD5 HMAC with sequence number of "2" b"\x84\xcb\\\xf2A\x0c\xd3