diff --git a/test/test_messages.py b/test/test_messages.py index 3cd0244..ea2831e 100755 --- a/test/test_messages.py +++ b/test/test_messages.py @@ -19,8 +19,8 @@ import calendar import os +import pytest import sys -import unittest import time sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/..') # noqa @@ -31,7 +31,7 @@ from tuhi.protocol import * SUCCESS = NordicData([0xb3, 0x1, 0x00]) -class TestUtils(unittest.TestCase): +class TestUtils(object): def test_hex_string(self): values = [ ([0x00, 0x12], '00 12'), @@ -44,12 +44,12 @@ class TestUtils(unittest.TestCase): ] for v in values: - self.assertEqual(as_hex_string(v[0]), v[1]) + assert as_hex_string(v[0]) == v[1] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): as_hex_string(1) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): as_hex_string('0x00') def test_protocol_version(self): @@ -64,14 +64,14 @@ class TestUtils(unittest.TestCase): ] for v in values: - self.assertEqual(ProtocolVersion.from_string(v[0]), v[1]) + assert ProtocolVersion.from_string(v[0]) == v[1] # No real reason we couldn't support those but right now they # aren't, so let's test for it. - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ProtocolVersion.from_string('Slate') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ProtocolVersion.from_string('IntuosPro') def test_little_u16(self): @@ -81,12 +81,12 @@ class TestUtils(unittest.TestCase): ] for v in values: - self.assertEqual(little_u16(v[0]), bytes(v[1])) - self.assertEqual(little_u16(v[1]), v[0]) + assert little_u16(v[0]) == bytes(v[1]) + assert little_u16(v[1]) == v[0] invalid = [0x10000, -1, [0x00, 0x00, 0x00]] for v in invalid: - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): little_u16(v) def test_little_u32(self): @@ -98,20 +98,20 @@ class TestUtils(unittest.TestCase): ] for v in values: - self.assertEqual(little_u32(v[0]), bytes(v[1])) - self.assertEqual(little_u32(v[1]), v[0]) + assert little_u32(v[0]) == bytes(v[1]) + assert little_u32(v[1]) == v[0] invalid = [0x100000000, -1, [0x00, 0x00, 0x00, 0x00, 0x00]] for v in invalid: - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): little_u32(v) -class TestProtocolAny(unittest.TestCase): +class TestProtocolAny(object): protocol_version = ProtocolVersion.ANY def test_get_protocol(self): - self.assertIsNotNone(Protocol(self.protocol_version, callback=None)) + assert Protocol(self.protocol_version, callback=None) is not None def test_has_all_messages(self): p = Protocol(self.protocol_version, callback=None) @@ -136,47 +136,47 @@ class TestProtocolAny(unittest.TestCase): def test_connect(self, cb=None): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xe6) - self.assertEqual(request.length, 6) + assert request.opcode == 0xe6 + assert request.length == 6 return SUCCESS if cb is None: cb = _cb p = Protocol(self.protocol_version, callback=cb) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): p.execute(Interactions.CONNECT) # missing argument uuid = 'abcdef123456' msg = p.execute(Interactions.CONNECT, uuid) - self.assertEqual(msg.uuid, uuid) + assert msg.uuid == uuid - with self.assertRaises(ValueError): + with pytest.raises(ValueError): p.execute(Interactions.CONNECT, 'too-long-an-id') - with self.assertRaises(binascii.Error): + with pytest.raises(binascii.Error): uuid = 'uvwxyz123456' p.execute(Interactions.CONNECT, uuid) def test_get_name(self, cb=None, name='test dev name\x0a'): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xbb) - self.assertEqual(request.length, 1) - self.assertEqual(request[0], 0x00) + assert request.opcode == 0xbb + assert request.length == 1 + assert request[0] == 0x00 return NordicData([0xbc, len(name)] + list(bytes(name, encoding='ascii'))) cb = cb or _cb p = Protocol(self.protocol_version, callback=cb) msg = p.execute(Interactions.GET_NAME) - self.assertEqual(msg.name, name) + assert msg.name == name def test_set_name(self, cb=None, name='test dev name'): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xbb) - self.assertEqual(request.length, len(name) + 1) - self.assertEqual(request[-1], 0xa) # spark needs a trailing linebreak - self.assertEqual(bytes(request[:-1]).decode('utf-8'), name) + assert request.opcode == 0xbb + assert request.length == len(name) + 1 + assert request[-1] == 0xa # spark needs a trailing linebreak + assert bytes(request[:-1]).decode('utf-8') == name return SUCCESS cb = cb or _cb @@ -186,8 +186,8 @@ class TestProtocolAny(unittest.TestCase): def test_get_time(self, cb=None, ts=time.time()): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xb6) - self.assertEqual(request.length, 1) + assert request.opcode == 0xb6 + assert request.length == 1 t = time.strftime('%y%m%d%H%M%S', time.gmtime(ts)) t = [int(i) for i in binascii.unhexlify(t)] return NordicData([0xbd, len(t)] + t) @@ -196,15 +196,15 @@ class TestProtocolAny(unittest.TestCase): p = Protocol(self.protocol_version, callback=cb) msg = p.execute(Interactions.GET_TIME) - self.assertEqual(msg.timestamp, int(ts)) + assert msg.timestamp == int(ts) def test_set_time(self, cb=None, ts=time.time()): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xb6) - self.assertEqual(request.length, 6) + assert request.opcode == 0xb6 + assert request.length == 6 str_timestamp = ''.join([f'{b:02x}' for b in request]) t = calendar.timegm(time.strptime(str_timestamp, '%y%m%d%H%M%S')) - self.assertEqual(int(t), int(ts)) + assert int(t) == int(ts) return SUCCESS cb = cb or _cb @@ -214,8 +214,8 @@ class TestProtocolAny(unittest.TestCase): def test_get_fw(self, cb=None, fw='abcdef-123456'): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xb7) - self.assertEqual(request.length, 1) + assert request.opcode == 0xb7 + assert request.length == 1 data = [int(c, 16) for c in fw.split('-')[request[0]]] return NordicData([0xb8, len(data) + 1, 0x00] + data) @@ -223,43 +223,43 @@ class TestProtocolAny(unittest.TestCase): p = Protocol(self.protocol_version, callback=cb) msg = p.execute(Interactions.GET_FIRMWARE) - self.assertEqual(msg.firmware, fw) + assert msg.firmware == fw def test_get_battery(self, cb=None, battery=(1, 78)): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xb9) - self.assertEqual(request.length, 1) + assert request.opcode == 0xb9 + assert request.length == 1 return NordicData([0xba, 2, battery[1], battery[0]]) cb = cb or _cb p = Protocol(self.protocol_version, callback=cb) msg = p.execute(Interactions.GET_BATTERY) - self.assertEqual(msg.battery_is_charging, battery[0]) - self.assertEqual(msg.battery_percent, battery[1]) + assert msg.battery_is_charging == battery[0] + assert msg.battery_percent == battery[1] def test_get_width(self, cb=None): # this is hardcoded for the spark p = Protocol(self.protocol_version, callback=None) msg = p.execute(Interactions.GET_WIDTH) - self.assertEqual(msg.width, 21000) + assert msg.width == 21000 def test_get_height(self, cb=None): # this is hardcoded for the spark p = Protocol(self.protocol_version, callback=None) msg = p.execute(Interactions.GET_HEIGHT) - self.assertEqual(msg.height, 14800) + assert msg.height == 14800 def test_get_point_size(self, cb=None): # this is hardcoded for the spark p = Protocol(self.protocol_version, callback=None) msg = p.execute(Interactions.GET_POINT_SIZE) - self.assertEqual(msg.point_size, 10) + assert msg.point_size == 10 def test_unknown_e3(self, cb=None): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xe3) - self.assertEqual(request.length, 1) + assert request.opcode == 0xe3 + assert request.length == 1 return SUCCESS cb = cb or _cb @@ -269,9 +269,9 @@ class TestProtocolAny(unittest.TestCase): def test_filetransfer_reporting_type(self, cb=None): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xec) - self.assertEqual(request.length, 6) - self.assertEqual(request, [0x06, 0x00, 0x00, 0x00, 0x00, 0x00]) + assert request.opcode == 0xec + assert request.length == 6 + assert request, [0x06, 0x00, 0x00, 0x00, 0x00 == 0x00] return SUCCESS cb = cb or _cb @@ -284,9 +284,9 @@ class TestProtocolAny(unittest.TestCase): mode = Mode.LIVE def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xb1) - self.assertEqual(request.length, 1) - self.assertEqual(request[0], mode) + assert request.opcode == 0xb1 + assert request.length == 1 + assert request[0] == mode return SUCCESS cb = cb or _cb @@ -299,9 +299,9 @@ class TestProtocolAny(unittest.TestCase): # this is a weird double call, see the protocol # We reply 0xc7 first, and then 0xcd if request is not None: - self.assertEqual(request.opcode, 0xc5) - self.assertEqual(request.length, 1) - self.assertEqual(request[0], 0x00) + assert request.opcode == 0xc5 + assert request.length == 1 + assert request[0] == 0x00 data = list(count.to_bytes(4, byteorder='big')) return NordicData([0xc7, len(data)] + data) else: @@ -313,14 +313,14 @@ class TestProtocolAny(unittest.TestCase): p = Protocol(self.protocol_version, callback=cb) msg = p.execute(Interactions.GET_STROKES) - self.assertEqual(msg.count, count) - self.assertEqual(msg.timestamp, int(ts)) + assert msg.count == count + assert msg.timestamp == int(ts) def test_available_files_count(self, cb=None, ndata=1234): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xc1) - self.assertEqual(request.length, 1) - self.assertEqual(request[0], 0x00) + assert request.opcode == 0xc1 + assert request.length == 1 + assert request[0] == 0x00 data = list(ndata.to_bytes(2, byteorder='big')) return NordicData([0xc2, len(data)] + data) @@ -328,13 +328,13 @@ class TestProtocolAny(unittest.TestCase): p = Protocol(self.protocol_version, callback=cb) msg = p.execute(Interactions.AVAILABLE_FILES_COUNT) - self.assertEqual(msg.count, ndata) + assert msg.count == ndata def test_download_oldest_file(self, cb=None): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xc3) - self.assertEqual(request.length, 1) - self.assertEqual(request[0], 0x00) + assert request.opcode == 0xc3 + assert request.length == 1 + assert request[0] == 0x00 return NordicData([0xc8, 1, 0xbe]) cb = cb or _cb @@ -344,9 +344,9 @@ class TestProtocolAny(unittest.TestCase): def test_delete_oldest_file(self, cb=None): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xca) - self.assertEqual(request.length, 1) - self.assertEqual(request[0], 0x00) + assert request.opcode == 0xca + assert request.length == 1 + assert request[0] == 0x00 # no reply cb = cb or _cb @@ -356,9 +356,9 @@ class TestProtocolAny(unittest.TestCase): def test_register_complete(self, cb=None): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xe5) - self.assertEqual(request.length, 1) - self.assertEqual(request[0], 0x00) + assert request.opcode == 0xe5 + assert request.length == 1 + assert request[0] == 0x00 return SUCCESS cb = cb or _cb @@ -368,16 +368,16 @@ class TestProtocolAny(unittest.TestCase): def test_register_press_button(self, cb=None, uuid=None): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xe3) - self.assertEqual(request.length, 1) - self.assertEqual(request[0], 0x01) + assert request.opcode == 0xe3 + assert request.length == 1 + assert request[0] == 0x01 # no reply cb = cb or _cb p = Protocol(self.protocol_version, callback=cb) msg = p.execute(Interactions.REGISTER_PRESS_BUTTON, uuid=uuid) - self.assertEqual(msg.uuid, uuid) + assert msg.uuid == uuid def test_error_invalid_state(self): def _cb(request, requires_reply=True, userdata=None, timeout=5): @@ -386,17 +386,17 @@ class TestProtocolAny(unittest.TestCase): p = Protocol(self.protocol_version, callback=_cb) # a "random" collection of requests that we want to check for - with self.assertRaises(DeviceError) as cm: + with pytest.raises(DeviceError) as cm: p.execute(Interactions.CONNECT, uuid='abcdef123456') - self.assertEqual(cm.exception.errorcode, DeviceError.ErrorCode.GENERAL_ERROR) + assert cm.value.errorcode == DeviceError.ErrorCode.GENERAL_ERROR - with self.assertRaises(DeviceError) as cm: + with pytest.raises(DeviceError) as cm: p.execute(Interactions.GET_STROKES) - self.assertEqual(cm.exception.errorcode, DeviceError.ErrorCode.GENERAL_ERROR) + assert cm.value.errorcode == DeviceError.ErrorCode.GENERAL_ERROR - with self.assertRaises(DeviceError) as cm: + with pytest.raises(DeviceError) as cm: p.execute(Interactions.SET_MODE, Mode.PAPER) - self.assertEqual(cm.exception.errorcode, DeviceError.ErrorCode.GENERAL_ERROR) + assert cm.value.errorcode == DeviceError.ErrorCode.GENERAL_ERROR class TestProtocolSpark(TestProtocolAny): @@ -404,14 +404,14 @@ class TestProtocolSpark(TestProtocolAny): def test_register_wait_for_button(self, cb=None): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertIsNone(request) + assert request is None return NordicData([0xe4, 0x00]) cb = cb or _cb p = Protocol(self.protocol_version, callback=cb) msg = p.execute(Interactions.REGISTER_WAIT_FOR_BUTTON) - self.assertEqual(msg.protocol_version, self.protocol_version) + assert msg.protocol_version == self.protocol_version class TestProtocolSlate(TestProtocolSpark): @@ -419,9 +419,9 @@ class TestProtocolSlate(TestProtocolSpark): def test_get_width(self, cb=None, width=1234): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xea) - self.assertEqual(request.length, 2) - self.assertEqual(request[0], 3) + assert request.opcode == 0xea + assert request.length == 2 + assert request[0] == 3 data = [0x03, 0x00] + list(width.to_bytes(4, byteorder='little')) return NordicData([0xeb, len(data)] + data) @@ -430,13 +430,13 @@ class TestProtocolSlate(TestProtocolSpark): p = Protocol(self.protocol_version, callback=cb) msg = p.execute(Interactions.GET_WIDTH) - self.assertEqual(msg.width, width) + assert msg.width == width def test_get_height(self, cb=None, height=4321): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xea) - self.assertEqual(request.length, 2) - self.assertEqual(request[0], 4) + assert request.opcode == 0xea + assert request.length == 2 + assert request[0] == 4 data = [0x04, 0x00] + list(height.to_bytes(4, byteorder='little')) return NordicData([0xeb, len(data)] + data) @@ -445,13 +445,13 @@ class TestProtocolSlate(TestProtocolSpark): p = Protocol(self.protocol_version, callback=cb) msg = p.execute(Interactions.GET_HEIGHT) - self.assertEqual(msg.height, height) + assert msg.height == height def test_get_strokes(self, cb=None, count=1024, ts=time.time()): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xcc) - self.assertEqual(request.length, 1) - self.assertEqual(request[0], 0x00) + assert request.opcode == 0xcc + assert request.length == 1 + assert request[0] == 0x00 c = list(count.to_bytes(4, byteorder='little')) t = time.strftime('%y%m%d%H%M%S', time.gmtime(ts)) t = [int(i) for i in binascii.unhexlify(t)] @@ -462,9 +462,9 @@ class TestProtocolSlate(TestProtocolSpark): def test_available_files_count(self, cb=None, ndata=1234): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xc1) - self.assertEqual(request.length, 1) - self.assertEqual(request[0], 0x00) + assert request.opcode == 0xc1 + assert request.length == 1 + assert request[0] == 0x00 data = list(ndata.to_bytes(2, byteorder='little')) return NordicData([0xc2, len(data)] + data) @@ -472,24 +472,24 @@ class TestProtocolSlate(TestProtocolSpark): def test_delete_oldest_file(self, cb=None): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xca) - self.assertEqual(request.length, 1) - self.assertEqual(request[0], 0x00) + assert request.opcode == 0xca + assert request.length == 1 + assert request[0] == 0x00 return SUCCESS super().test_delete_oldest_file(cb or _cb) def test_register_press_button(self, cb=None, uuid='abcdef123456'): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xe7) - self.assertEqual(request.length, 6) + assert request.opcode == 0xe7 + assert request.length == 6 # no reply super().test_register_press_button(cb or _cb, uuid) def test_register_wait_for_button(self, cb=None): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertIsNone(request) + assert request is None return NordicData([0xe4, 0x00]) super().test_register_wait_for_button(cb or _cb) @@ -500,34 +500,34 @@ class TestProtocolIntuosPro(TestProtocolSlate): def test_connect(self, cb=None): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xe6) - self.assertEqual(request.length, 6) + assert request.opcode == 0xe6 + assert request.length == 6 return NordicData([0x50, 0x06] + request) # replies with the uuid super().test_connect(cb or _cb) def test_get_name(self, cb=None, name='test dev name'): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xdb) - self.assertEqual(request.length, 1) - self.assertEqual(request[0], 0x00) + assert request.opcode == 0xdb + assert request.length == 1 + assert request[0] == 0x00 return NordicData([0xbc, len(name)] + list(bytes(name, encoding='ascii'))) super().test_get_name(cb or _cb, name=name) def test_set_name(self, cb=None, name='test dev name'): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xdb) - self.assertEqual(request.length, len(name)) - self.assertEqual(bytes(request).decode('utf-8'), name) + assert request.opcode == 0xdb + assert request.length == len(name) + assert bytes(request).decode('utf-8') == name return SUCCESS super().test_set_name(cb or _cb, name=name) def test_get_time(self, cb=None, ts=time.time()): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xd6) - self.assertEqual(request.length, 1) + assert request.opcode == 0xd6 + assert request.length == 1 t = list(int(ts).to_bytes(length=4, byteorder='little')) + [0x00, 0x00] return NordicData([0xbd, len(t)] + t) @@ -535,18 +535,18 @@ class TestProtocolIntuosPro(TestProtocolSlate): def test_set_time(self, cb=None, ts=time.time()): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xb6) - self.assertEqual(request.length, 6) + assert request.opcode == 0xb6 + assert request.length == 6 t = int.from_bytes(request[0:4], byteorder='little') - self.assertEqual(int(t), int(ts)) + assert int(t) == int(ts) return SUCCESS super().test_set_time(cb or _cb, ts=ts) def test_get_fw(self, cb=None, fw='anything-string'): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xb7) - self.assertEqual(request.length, 1) + assert request.opcode == 0xb7 + assert request.length == 1 data = bytes(fw.split('-')[request[0]].encode('utf8')) return NordicData([0xb8, len(data) + 1, 0x00] + list(data)) @@ -554,9 +554,9 @@ class TestProtocolIntuosPro(TestProtocolSlate): def test_get_strokes(self, cb=None, count=1024, ts=time.time()): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xcc) - self.assertEqual(request.length, 1) - self.assertEqual(request[0], 0x00) + assert request.opcode == 0xcc + assert request.length == 1 + assert request[0] == 0x00 c = list(count.to_bytes(4, byteorder='little')) t = list(int(ts).to_bytes(4, byteorder='little')) data = c + t @@ -566,16 +566,16 @@ class TestProtocolIntuosPro(TestProtocolSlate): def test_register_wait_for_button(self, cb=None): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertIsNone(request) + assert request is None return NordicData([0x53, 0x00]) super().test_register_wait_for_button(cb or _cb) def test_get_point_size(self, cb=None, pointsize=12): def _cb(request, requires_reply=True, userdata=None, timeout=5): - self.assertEqual(request.opcode, 0xea) - self.assertEqual(request.length, 2) - self.assertEqual(request[0], 0x14) + assert request.opcode == 0xea + assert request.length == 2 + assert request[0] == 0x14 ps = little_u32(pointsize) return NordicData([0xeb, 6, 0x14, 0x00] + list(ps)) @@ -583,8 +583,4 @@ class TestProtocolIntuosPro(TestProtocolSlate): p = Protocol(self.protocol_version, callback=cb) msg = p.execute(Interactions.GET_POINT_SIZE) - self.assertEqual(msg.point_size, pointsize - 1) - - -if __name__ == "__main__": - unittest.main(sys.argv[1:]) + assert msg.point_size == pointsize - 1