test: switch the stroke tests to use pytest

We're already running the meson test through pytest anyway and pytest is more
powerful than unittest. So let's switch, it's just a search/replace away.

Plus, this way the approach to dynamically create the tests based on the test
logs in the user's home directory is a lot saner.

Signed-off-by: Peter Hutterer <peter.hutterer@who-t.net>
pull/248/head
Peter Hutterer 2020-01-08 09:58:27 +10:00 committed by Benjamin Tissoires
parent fcdacb7187
commit 7a881de9bb
1 changed files with 140 additions and 161 deletions

View File

@ -18,8 +18,8 @@
#
import os
import pytest
import sys
import unittest
import xdg.BaseDirectory
from pathlib import Path
import yaml
@ -34,32 +34,47 @@ logger = logging.getLogger('tuhi') # piggyback the debug messages
logger.setLevel(logging.DEBUG)
class TestLogFiles(unittest.TestCase):
'''
Special test class that loads a yaml file created by tuhi compiles
the StrokeData from it. This class autogenerates its own tests, see
the main() handling.
'''
def load_pen_data(self, filename):
def pytest_generate_tests(metafunc):
# for any test function that takes a "logfile" argument return the list
# of all current logfiles in XDG_DATA_HOME/tuhi
# This means the test gets better the more logfiles are present on the
# user's machine.
if 'logfile' in metafunc.fixturenames:
basedir = Path(xdg.BaseDirectory.xdg_data_home) / 'tuhi'
def loads_and_has_data(filename):
with open(filename) as fd:
try:
yml = yaml.load(fd, Loader=yaml.Loader)
return yml is not None
except Exception as e:
logger.error(f'Exception triggered by file {filename}')
raise e
logfiles = [f for f in basedir.glob('**/raw/log-*.yaml') if loads_and_has_data(f)]
metafunc.parametrize('logfile', logfiles)
def test_log_files(logfile):
def load_pen_data(filename):
with open(filename) as fd:
yml = yaml.load(fd, Loader=yaml.Loader)
# all recv lists that have source PEN
pendata = [d['recv'] for d in yml['data'] if 'recv' in d and 'source' in d and d['source'] == 'PEN']
return list(flatten(pendata))
def _test_file(self, fname):
data = self.load_pen_data(fname)
if not data: # Recordings without Pen data can be skipped
raise unittest.SkipTest()
StrokeFile(data)
data = load_pen_data(logfile)
if not data: # Recordings without Pen data can be skipped
pytest.skip('Recording without pen data')
StrokeFile(data)
class TestStrokeParsers(unittest.TestCase):
class TestStrokeParsers(object):
def test_identify_file_header(self):
data = [0x67, 0x82, 0x69, 0x65]
self.assertEqual(StrokeDataType.identify(data), StrokeDataType.FILE_HEADER)
assert StrokeDataType.identify(data) == StrokeDataType.FILE_HEADER
data = [0x62, 0x38, 0x62, 0x74]
self.assertEqual(StrokeDataType.identify(data), StrokeDataType.FILE_HEADER)
assert StrokeDataType.identify(data) == StrokeDataType.FILE_HEADER
others = [
# with header
@ -77,54 +92,54 @@ class TestStrokeParsers(unittest.TestCase):
[0x62, 0x38, 0x62, 0x73],
]
for data in others:
self.assertNotEqual(StrokeDataType.identify(data), StrokeDataType.FILE_HEADER, msg=data)
assert StrokeDataType.identify(data) != StrokeDataType.FILE_HEADER, data
def test_identify_stroke_header(self):
data = [0xff, 0xfa] # two bytes are enough to identify
self.assertEqual(StrokeDataType.identify(data), StrokeDataType.STROKE_HEADER)
assert StrokeDataType.identify(data) == StrokeDataType.STROKE_HEADER
data = [0x3, 0xfa] # lowest bits set, not a correct packet but identify doesn't care
self.assertEqual(StrokeDataType.identify(data), StrokeDataType.STROKE_HEADER)
assert StrokeDataType.identify(data) == StrokeDataType.STROKE_HEADER
data = [0xfc, 0xfa] # lowest bits unset, must be something else
self.assertNotEqual(StrokeDataType.identify(data), StrokeDataType.STROKE_HEADER)
assert StrokeDataType.identify(data) != StrokeDataType.STROKE_HEADER
def test_identify_stroke_point(self):
data = [0xff, 0xff, 0xff] # three bytes are enough to identify
self.assertEqual(StrokeDataType.identify(data), StrokeDataType.POINT)
assert StrokeDataType.identify(data) == StrokeDataType.POINT
data = [0xff, 0xff, 0xff, 1, 2, 3, 4, 5, 6]
self.assertEqual(StrokeDataType.identify(data), StrokeDataType.POINT)
assert StrokeDataType.identify(data) == StrokeDataType.POINT
# wrong header, but observed in the wild
data = [0xbf, 0xff, 0xff, 1, 2, 3, 4, 5, 6]
self.assertEqual(StrokeDataType.identify(data), StrokeDataType.POINT)
assert StrokeDataType.identify(data) == StrokeDataType.POINT
data = [0xfc, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff] # stroke end
self.assertNotEqual(StrokeDataType.identify(data), StrokeDataType.POINT)
assert StrokeDataType.identify(data) != StrokeDataType.POINT
data = [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff] # EOF
self.assertNotEqual(StrokeDataType.identify(data), StrokeDataType.POINT)
assert StrokeDataType.identify(data) != StrokeDataType.POINT
def test_identify_stroke_lost_point(self):
data = [0xff, 0xdd, 0xdd]
self.assertEqual(StrokeDataType.identify(data), StrokeDataType.LOST_POINT)
assert StrokeDataType.identify(data) == StrokeDataType.LOST_POINT
def test_identify_eof(self):
data = [0xff] * 9
self.assertEqual(StrokeDataType.identify(data), StrokeDataType.EOF)
assert StrokeDataType.identify(data) == StrokeDataType.EOF
def test_identify_stroke_end(self):
data = [0xfc, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff]
self.assertEqual(StrokeDataType.identify(data), StrokeDataType.STROKE_END)
assert StrokeDataType.identify(data) == StrokeDataType.STROKE_END
def test_identify_delta(self):
for i in range(256):
data = [i, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
if i & 0x3 == 0:
self.assertEqual(StrokeDataType.identify(data), StrokeDataType.DELTA, f'packet: {data}')
assert StrokeDataType.identify(data), StrokeDataType.DELTA == f'packet: {data}'
else:
self.assertNotEqual(StrokeDataType.identify(data), StrokeDataType.DELTA, f'packet: {data}')
assert StrokeDataType.identify(data), StrokeDataType.DELTA != f'packet: {data}'
def test_parse_stroke_header(self):
F_NEW_LAYER = 0x40
@ -134,81 +149,81 @@ class TestStrokeParsers(unittest.TestCase):
data = [0xff, 0xfa, flags, 0x1f, 0x73, 0x53, 0x5d, 0x2e, 0x01]
packet = StrokeHeader(data)
self.assertEqual(packet.size, 9)
self.assertEqual(packet.is_new_layer, 1)
self.assertEqual(packet.pen_id, 0)
self.assertEqual(packet.pen_type, pen_type)
self.assertEqual(packet.timestamp, 1565750047)
assert packet.size == 9
assert packet.is_new_layer == 1
assert packet.pen_id == 0
assert packet.pen_type == pen_type
assert packet.timestamp == 1565750047
# new layer off
flags = pen_type
data = [0xff, 0xfa, flags, 0x1f, 0x73, 0x53, 0x5d, 0x2e, 0x01]
packet = StrokeHeader(data)
self.assertEqual(packet.size, 9)
self.assertEqual(packet.is_new_layer, 0)
self.assertEqual(packet.pen_id, 0)
self.assertEqual(packet.pen_type, pen_type)
self.assertEqual(packet.timestamp, 1565750047)
assert packet.size == 9
assert packet.is_new_layer == 0
assert packet.pen_id == 0
assert packet.pen_type == pen_type
assert packet.timestamp == 1565750047
# pen type change
pen_type = 1
flags = F_NEW_LAYER | pen_type
data = [0xff, 0xfa, flags, 0x1f, 0x73, 0x53, 0x5d, 0x2e, 0x01]
packet = StrokeHeader(data)
self.assertEqual(packet.size, 9)
self.assertEqual(packet.is_new_layer, 1)
self.assertEqual(packet.pen_id, 0)
self.assertEqual(packet.pen_type, pen_type)
self.assertEqual(packet.timestamp, 1565750047)
assert packet.size == 9
assert packet.is_new_layer == 1
assert packet.pen_id == 0
assert packet.pen_type == pen_type
assert packet.timestamp == 1565750047
# with pen id
flags = F_NEW_LAYER | F_PEN_ID | pen_type
pen_id = [0xff, 0x0a, 0x87, 0x75, 0x80, 0x28, 0x42, 0x00, 0x10]
data = [0xff, 0xfa, flags, 0x1f, 0x73, 0x53, 0x5d, 0x2e, 0x01] + pen_id
packet = StrokeHeader(data)
self.assertEqual(packet.size, 18)
self.assertEqual(packet.is_new_layer, 1)
self.assertEqual(packet.pen_id, 0x100042288075870a)
self.assertEqual(packet.pen_type, pen_type)
self.assertEqual(packet.timestamp, 1565750047)
assert packet.size == 18
assert packet.is_new_layer == 1
assert packet.pen_id == 0x100042288075870a
assert packet.pen_type == pen_type
assert packet.timestamp == 1565750047
def test_parse_stroke_point(self):
# 0xff means 2 bytes each for abs coords
data = [0xff, 0xff, 0xff, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
packet = StrokePoint(data)
self.assertEqual(packet.size, 9)
self.assertEqual(packet.x, 0x0201)
self.assertEqual(packet.y, 0x0403)
self.assertEqual(packet.p, 0x0605)
self.assertIsNone(packet.dx)
self.assertIsNone(packet.dy)
self.assertIsNone(packet.dp)
assert packet.size == 9
assert packet.x == 0x0201
assert packet.y == 0x0403
assert packet.p == 0x0605
assert packet.dx is None
assert packet.dy is None
assert packet.dp is None
# 0xbf means: 1 byte for pressure delta, i.e. the 0x6 is skipped
data = [0xbf, 0xff, 0xff, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
packet = StrokePoint(data)
self.assertEqual(packet.size, 8)
self.assertEqual(packet.x, 0x0201)
self.assertEqual(packet.y, 0x0403)
self.assertIsNone(packet.p)
self.assertIsNone(packet.dx)
self.assertIsNone(packet.dy)
self.assertEqual(packet.dp, 0x5)
assert packet.size == 8
assert packet.x == 0x0201
assert packet.y == 0x0403
assert packet.p is None
assert packet.dx is None
assert packet.dy is None
assert packet.dp == 0x5
def test_parse_lost_point(self):
data = [0xff, 0xdd, 0xdd, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
packet = StrokeLostPoint(data)
self.assertEqual(packet.size, 9)
self.assertEqual(packet.nlost, 0x0201)
assert packet.size == 9
assert packet.nlost == 0x0201
def test_parse_eof(self):
data = [0xff] * 9
packet = StrokeEOF(data)
self.assertEqual(packet.size, 9)
assert packet.size == 9
data = [0xfc] + [0xff] * 6
packet = StrokeEOF(data)
self.assertEqual(packet.size, 7)
assert packet.size == 7
def test_parse_delta(self):
x_delta = 0b00001000 # noqa
@ -221,93 +236,93 @@ class TestStrokeParsers(unittest.TestCase):
flags = x_delta
data = [flags, 1]
packet = StrokeDelta(data)
self.assertEqual(packet.size, len(data))
self.assertEqual(packet.dx, 1)
self.assertIsNone(packet.dy)
self.assertIsNone(packet.dp)
self.assertIsNone(packet.x)
self.assertIsNone(packet.y)
self.assertIsNone(packet.p)
assert packet.size == len(data)
assert packet.dx == 1
assert packet.dy is None
assert packet.dp is None
assert packet.x is None
assert packet.y is None
assert packet.p is None
flags = y_delta
data = [flags, 2]
packet = StrokeDelta(data)
self.assertEqual(packet.size, len(data))
self.assertIsNone(packet.dx)
self.assertEqual(packet.dy, 2)
self.assertIsNone(packet.dp)
self.assertIsNone(packet.x)
self.assertIsNone(packet.y)
self.assertIsNone(packet.p)
assert packet.size == len(data)
assert packet.dx is None
assert packet.dy == 2
assert packet.dp is None
assert packet.x is None
assert packet.y is None
assert packet.p is None
flags = p_delta
data = [flags, 3]
packet = StrokeDelta(data)
self.assertEqual(packet.size, len(data))
self.assertIsNone(packet.dx)
self.assertIsNone(packet.dy)
self.assertEqual(packet.dp, 3)
self.assertIsNone(packet.x)
self.assertIsNone(packet.y)
self.assertIsNone(packet.p)
assert packet.size == len(data)
assert packet.dx is None
assert packet.dy is None
assert packet.dp == 3
assert packet.x is None
assert packet.y is None
assert packet.p is None
flags = x_delta | p_delta
data = [flags, 3, 5]
packet = StrokeDelta(data)
self.assertEqual(packet.size, len(data))
self.assertEqual(packet.dx, 3)
self.assertIsNone(packet.dy)
self.assertEqual(packet.dp, 5)
self.assertIsNone(packet.x)
self.assertIsNone(packet.y)
self.assertIsNone(packet.p)
assert packet.size == len(data)
assert packet.dx == 3
assert packet.dy is None
assert packet.dp == 5
assert packet.x is None
assert packet.y is None
assert packet.p is None
flags = x_delta | y_delta | p_delta
data = [flags, 3, 5, 7]
packet = StrokeDelta(data)
self.assertEqual(packet.size, len(data))
self.assertEqual(packet.dx, 3)
self.assertEqual(packet.dy, 5)
self.assertEqual(packet.dp, 7)
self.assertIsNone(packet.x)
self.assertIsNone(packet.y)
self.assertIsNone(packet.p)
assert packet.size == len(data)
assert packet.dx == 3
assert packet.dy == 5
assert packet.dp == 7
assert packet.x is None
assert packet.y is None
assert packet.p is None
flags = x_abs | y_abs | p_abs
data = [flags, 1, 2, 3, 4, 5, 6]
packet = StrokeDelta(data)
self.assertEqual(packet.size, len(data))
self.assertEqual(packet.x, 0x0201)
self.assertEqual(packet.y, 0x0403)
self.assertEqual(packet.p, 0x0605)
self.assertIsNone(packet.dx)
self.assertIsNone(packet.dy)
self.assertIsNone(packet.dp)
assert packet.size == len(data)
assert packet.x == 0x0201
assert packet.y == 0x0403
assert packet.p == 0x0605
assert packet.dx is None
assert packet.dy is None
assert packet.dp is None
flags = y_abs
data = [flags, 2, 3]
packet = StrokeDelta(data)
self.assertEqual(packet.size, len(data))
self.assertIsNone(packet.x)
self.assertEqual(packet.y, 0x0302)
self.assertIsNone(packet.p)
self.assertIsNone(packet.dx)
self.assertIsNone(packet.dy)
self.assertIsNone(packet.dp)
assert packet.size == len(data)
assert packet.x is None
assert packet.y == 0x0302
assert packet.p is None
assert packet.dx is None
assert packet.dy is None
assert packet.dp is None
flags = x_abs | y_delta | p_delta
data = [flags, 2, 3, 4, 5]
packet = StrokeDelta(data)
self.assertEqual(packet.size, len(data))
self.assertEqual(packet.x, 0x0302)
self.assertIsNone(packet.y)
self.assertIsNone(packet.p)
self.assertIsNone(packet.dx)
self.assertEqual(packet.dy, 4)
self.assertEqual(packet.dp, 5)
assert packet.size == len(data)
assert packet.x == 0x0302
assert packet.y is None
assert packet.p is None
assert packet.dx is None
assert packet.dy == 4
assert packet.dp == 5
class TestStrokes(unittest.TestCase):
class TestStrokes(object):
def test_single_stroke(self):
data = '''
67 82 69 65 22 73 53 5d 00 00 02 00 00 00 00 00 ff fa c3 1f
@ -428,39 +443,3 @@ class TestStrokes(unittest.TestCase):
p = Protocol(ProtocolVersion.INTUOS_PRO, None, None)
p.parse_pen_data(b)
# How does this work?
# The test generater spits out a simple test function that just calls the
# real test function
#
# Then we search for all yaml files with logs we have and generate a unique
# test name (based on the timestamp) for that file. Result: we have
# a unittests function for each log file found in the directory.
def generator(logfile):
def test(self):
self._test_file(logfile)
return test
def search_for_tests():
basedir = Path(xdg.BaseDirectory.xdg_data_home) / 'tuhi'
for logfile in basedir.glob('**/raw/log-*.yaml'):
with open(logfile) as fd:
try:
yml = yaml.load(fd, Loader=yaml.Loader)
if not yml:
continue
timestamp = yml['time']
test_name = f'test_log_{timestamp}'
test = generator(logfile)
setattr(TestLogFiles, test_name, test)
except Exception as e:
logger.error(f'Exception triggered by file {logfile}')
raise e
search_for_tests()
if __name__ == '__main__':
unittest.main()