From e420b04fda0bfa015eb3926d01438e6e4833b374 Mon Sep 17 00:00:00 2001 From: Peter Hutterer Date: Wed, 14 Aug 2019 19:08:30 +1000 Subject: [PATCH] protocol: add stroke parsing This is a slightly different model as the messages, primarily because it's not quite as model-specific. So there's only one parse function and it can handle both file types that we currently support (intuos pro and the spark/slate bits). All wrapped into their own classes to make future extensions a bit easier. It's not a 1:1 implementation of the tuhi/wacom.py bits either because we now know about a few extra bits like the flags in stroke headers. Most importantly though, this can be easily tested now, with one test case searching for raw logs in $XDG_DATA_HOME/tuhi/raw and parsing all of those. So the more files are sitting in there (i.e. the more the tablet is used), the better the test suite becomes. Signed-off-by: Peter Hutterer --- .circleci/config.yml | 2 +- test/test_strokes.py | 460 +++++++++++++++++++++++++++++++++ tuhi/protocol.py | 595 +++++++++++++++++++++++++++++++++++++++++++ tuhi/util.py | 9 + tuhi/wacom.py | 11 +- 5 files changed, 1066 insertions(+), 11 deletions(-) create mode 100755 test/test_strokes.py diff --git a/.circleci/config.yml b/.circleci/config.yml index eea17ea..93dbd35 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,7 +7,7 @@ jobs: steps: - run: command: | - dnf install -y meson gettext python3-devel pygobject3-devel python3-flake8 desktop-file-utils libappstream-glib python3-pytest + dnf install -y meson gettext python3-devel pygobject3-devel python3-flake8 desktop-file-utils libappstream-glib python3-pytest python3-pyxdg python3-pyyaml - checkout - run: command: | diff --git a/test/test_strokes.py b/test/test_strokes.py new file mode 100755 index 0000000..7c74c21 --- /dev/null +++ b/test/test_strokes.py @@ -0,0 +1,460 @@ +#!/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2019 Red Hat, Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# + +import os +import sys +import unittest +import xdg.BaseDirectory +from pathlib import Path +import yaml +import logging + +sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/..') # noqa + +from tuhi.protocol import * +from tuhi.util import flatten + +logger = logging.getLogger('tuhi') # piggyback the debug messages +logger.setLevel(logging.DEBUG) + + +class TestLogFiles(unittest.TestCase): + ''' + Special test class that loads a yaml file created by tuhi compiles + the StrokeData from it. This class autogenerates its own tests, see + the main() handling. + ''' + def load_pen_data(self, filename): + with open(filename) as fd: + yml = yaml.load(fd, Loader=yaml.Loader) + # all recv lists that have source PEN + pendata = [d['recv'] for d in yml['data'] if 'recv' in d and 'source' in d and d['source'] == 'PEN'] + return list(flatten(pendata)) + + def _test_file(self, fname): + data = self.load_pen_data(fname) + if not data: # Recordings without Pen data can be skipped + raise unittest.SkipTest() + StrokeFile(data) + + +class TestStrokeParsers(unittest.TestCase): + def test_identify_file_header(self): + data = [0x67, 0x82, 0x69, 0x65] + self.assertEqual(StrokeDataType.identify(data), StrokeDataType.FILE_HEADER) + data = [0x62, 0x38, 0x62, 0x74] + self.assertEqual(StrokeDataType.identify(data), StrokeDataType.FILE_HEADER) + + others = [ + # with header + [0xff, 0x62, 0x38, 0x62, 0x74], + [0xff, 0x67, 0x82, 0x69, 0x65], + # wrong size + [0x67, 0x82, 0x69], + [0x67, 0x82], + [0x67], + [0x62, 0x38, 0x62], + [0x62, 0x38], + [0x62], + # wrong numbers + [0x67, 0x82, 0x69, 0x64], + [0x62, 0x38, 0x62, 0x73], + ] + for data in others: + self.assertNotEqual(StrokeDataType.identify(data), StrokeDataType.FILE_HEADER, msg=data) + + def test_identify_stroke_header(self): + data = [0xff, 0xfa] # two bytes are enough to identify + self.assertEqual(StrokeDataType.identify(data), StrokeDataType.STROKE_HEADER) + + data = [0x3, 0xfa] # lowest bits set, not a correct packet but identify doesn't care + self.assertEqual(StrokeDataType.identify(data), StrokeDataType.STROKE_HEADER) + + data = [0xfc, 0xfa] # lowest bits unset, must be something else + self.assertNotEqual(StrokeDataType.identify(data), StrokeDataType.STROKE_HEADER) + + def test_identify_stroke_point(self): + data = [0xff, 0xff, 0xff] # three bytes are enough to identify + self.assertEqual(StrokeDataType.identify(data), StrokeDataType.POINT) + + data = [0xff, 0xff, 0xff, 1, 2, 3, 4, 5, 6] + self.assertEqual(StrokeDataType.identify(data), StrokeDataType.POINT) + + # wrong header, but observed in the wild + data = [0xbf, 0xff, 0xff, 1, 2, 3, 4, 5, 6] + self.assertEqual(StrokeDataType.identify(data), StrokeDataType.POINT) + + data = [0xfc, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff] # stroke end + self.assertNotEqual(StrokeDataType.identify(data), StrokeDataType.POINT) + + data = [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff] # EOF + self.assertNotEqual(StrokeDataType.identify(data), StrokeDataType.POINT) + + def test_identify_stroke_lost_point(self): + data = [0xff, 0xdd, 0xdd] + self.assertEqual(StrokeDataType.identify(data), StrokeDataType.LOST_POINT) + + def test_identify_eof(self): + data = [0xff] * 9 + self.assertEqual(StrokeDataType.identify(data), StrokeDataType.EOF) + + def test_identify_stroke_end(self): + data = [0xfc, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff] + self.assertEqual(StrokeDataType.identify(data), StrokeDataType.STROKE_END) + + def test_identify_delta(self): + for i in range(256): + data = [i, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6] + if i & 0x3 == 0: + self.assertEqual(StrokeDataType.identify(data), StrokeDataType.DELTA, f'packet: {data}') + else: + self.assertNotEqual(StrokeDataType.identify(data), StrokeDataType.DELTA, f'packet: {data}') + + def test_parse_stroke_header(self): + F_NEW_LAYER = 0x40 + F_PEN_ID = 0x80 + pen_type = 3 + flags = F_NEW_LAYER | pen_type + + data = [0xff, 0xfa, flags, 0x1f, 0x73, 0x53, 0x5d, 0x2e, 0x01] + packet = StrokeHeader(data) + self.assertEqual(packet.size, 9) + self.assertEqual(packet.is_new_layer, 1) + self.assertEqual(packet.pen_id, 0) + self.assertEqual(packet.pen_type, pen_type) + self.assertEqual(packet.timestamp, 1565750047) + + # new layer off + flags = pen_type + data = [0xff, 0xfa, flags, 0x1f, 0x73, 0x53, 0x5d, 0x2e, 0x01] + packet = StrokeHeader(data) + self.assertEqual(packet.size, 9) + self.assertEqual(packet.is_new_layer, 0) + self.assertEqual(packet.pen_id, 0) + self.assertEqual(packet.pen_type, pen_type) + self.assertEqual(packet.timestamp, 1565750047) + + # pen type change + pen_type = 1 + flags = F_NEW_LAYER | pen_type + data = [0xff, 0xfa, flags, 0x1f, 0x73, 0x53, 0x5d, 0x2e, 0x01] + packet = StrokeHeader(data) + self.assertEqual(packet.size, 9) + self.assertEqual(packet.is_new_layer, 1) + self.assertEqual(packet.pen_id, 0) + self.assertEqual(packet.pen_type, pen_type) + self.assertEqual(packet.timestamp, 1565750047) + + # with pen id + flags = F_NEW_LAYER | F_PEN_ID | pen_type + pen_id = [0xff, 0x0a, 0x87, 0x75, 0x80, 0x28, 0x42, 0x00, 0x10] + data = [0xff, 0xfa, flags, 0x1f, 0x73, 0x53, 0x5d, 0x2e, 0x01] + pen_id + packet = StrokeHeader(data) + self.assertEqual(packet.size, 18) + self.assertEqual(packet.is_new_layer, 1) + self.assertEqual(packet.pen_id, 0x100042288075870a) + self.assertEqual(packet.pen_type, pen_type) + self.assertEqual(packet.timestamp, 1565750047) + + def test_parse_stroke_point(self): + # 0xff means 2 bytes each for abs coords + data = [0xff, 0xff, 0xff, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6] + packet = StrokePoint(data) + self.assertEqual(packet.size, 9) + self.assertEqual(packet.x, 0x0201) + self.assertEqual(packet.y, 0x0403) + self.assertEqual(packet.p, 0x0605) + self.assertIsNone(packet.dx) + self.assertIsNone(packet.dy) + self.assertIsNone(packet.dp) + + # 0xbf means: 1 byte for pressure delta, i.e. the 0x6 is skipped + data = [0xbf, 0xff, 0xff, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6] + packet = StrokePoint(data) + self.assertEqual(packet.size, 8) + self.assertEqual(packet.x, 0x0201) + self.assertEqual(packet.y, 0x0403) + self.assertIsNone(packet.p) + self.assertIsNone(packet.dx) + self.assertIsNone(packet.dy) + self.assertEqual(packet.dp, 0x5) + + def test_parse_lost_point(self): + data = [0xff, 0xdd, 0xdd, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6] + packet = StrokeLostPoint(data) + self.assertEqual(packet.size, 9) + self.assertEqual(packet.nlost, 0x0201) + + def test_parse_eof(self): + data = [0xff] * 9 + packet = StrokeEOF(data) + self.assertEqual(packet.size, 9) + + data = [0xfc] + [0xff] * 6 + packet = StrokeEOF(data) + self.assertEqual(packet.size, 7) + + def test_parse_delta(self): + x_delta = 0b00001000 # noqa + x_abs = 0b00001100 # noqa + y_delta = 0b00100000 # noqa + y_abs = 0b00110000 # noqa + p_delta = 0b10000000 # noqa + p_abs = 0b11000000 # noqa + + flags = x_delta + data = [flags, 1] + packet = StrokeDelta(data) + self.assertEqual(packet.size, len(data)) + self.assertEqual(packet.dx, 1) + self.assertIsNone(packet.dy) + self.assertIsNone(packet.dp) + self.assertIsNone(packet.x) + self.assertIsNone(packet.y) + self.assertIsNone(packet.p) + + flags = y_delta + data = [flags, 2] + packet = StrokeDelta(data) + self.assertEqual(packet.size, len(data)) + self.assertIsNone(packet.dx) + self.assertEqual(packet.dy, 2) + self.assertIsNone(packet.dp) + self.assertIsNone(packet.x) + self.assertIsNone(packet.y) + self.assertIsNone(packet.p) + + flags = p_delta + data = [flags, 3] + packet = StrokeDelta(data) + self.assertEqual(packet.size, len(data)) + self.assertIsNone(packet.dx) + self.assertIsNone(packet.dy) + self.assertEqual(packet.dp, 3) + self.assertIsNone(packet.x) + self.assertIsNone(packet.y) + self.assertIsNone(packet.p) + + flags = x_delta | p_delta + data = [flags, 3, 5] + packet = StrokeDelta(data) + self.assertEqual(packet.size, len(data)) + self.assertEqual(packet.dx, 3) + self.assertIsNone(packet.dy) + self.assertEqual(packet.dp, 5) + self.assertIsNone(packet.x) + self.assertIsNone(packet.y) + self.assertIsNone(packet.p) + + flags = x_delta | y_delta | p_delta + data = [flags, 3, 5, 7] + packet = StrokeDelta(data) + self.assertEqual(packet.size, len(data)) + self.assertEqual(packet.dx, 3) + self.assertEqual(packet.dy, 5) + self.assertEqual(packet.dp, 7) + self.assertIsNone(packet.x) + self.assertIsNone(packet.y) + self.assertIsNone(packet.p) + + flags = x_abs | y_abs | p_abs + data = [flags, 1, 2, 3, 4, 5, 6] + packet = StrokeDelta(data) + self.assertEqual(packet.size, len(data)) + self.assertEqual(packet.x, 0x0201) + self.assertEqual(packet.y, 0x0403) + self.assertEqual(packet.p, 0x0605) + self.assertIsNone(packet.dx) + self.assertIsNone(packet.dy) + self.assertIsNone(packet.dp) + + flags = y_abs + data = [flags, 2, 3] + packet = StrokeDelta(data) + self.assertEqual(packet.size, len(data)) + self.assertIsNone(packet.x) + self.assertEqual(packet.y, 0x0302) + self.assertIsNone(packet.p) + self.assertIsNone(packet.dx) + self.assertIsNone(packet.dy) + self.assertIsNone(packet.dp) + + flags = x_abs | y_delta | p_delta + data = [flags, 2, 3, 4, 5] + packet = StrokeDelta(data) + self.assertEqual(packet.size, len(data)) + self.assertEqual(packet.x, 0x0302) + self.assertIsNone(packet.y) + self.assertIsNone(packet.p) + self.assertIsNone(packet.dx) + self.assertEqual(packet.dy, 4) + self.assertEqual(packet.dp, 5) + + +class TestStrokes(unittest.TestCase): + def test_single_stroke(self): + data = ''' + 67 82 69 65 22 73 53 5d 00 00 02 00 00 00 00 00 ff fa c3 1f + 73 53 5d 2e 01 ff 0a 87 75 80 28 42 00 10 ff ff ff 41 3c 13 + 30 72 03 c8 01 cb 04 e8 ff 01 0f 06 e0 ff 5f 07 e0 ff 91 08 + c0 78 09 e8 02 01 2f 0a e8 fe 01 ce 0a e0 ff 55 0b e8 01 01 + dc 0b c8 fe 6f 0c a8 03 ff 75 a8 fe 02 f4 a8 02 ff 06 88 ff + ee 88 01 f1 80 fd 88 ff f4 88 01 f4 a8 fe ff f7 a8 01 01 fa + 88 ff f7 80 f4 a8 01 01 f3 88 ff 01 a8 ff ff fd a0 03 f4 80 + d3 a0 02 df a0 02 be a8 fe 02 b2 a0 02 c7 e8 ff 02 8e 0a e8 + fe 01 15 08 e8 01 03 91 04 e8 f3 26 94 01 ff fa 03 21 73 53 + 5d 46 02 ff ff ff d3 6f 5a 38 c0 03 e8 fb 2b 5c 04 a8 fa 2c + 78 a8 02 ff 4e a8 01 ff 0c a0 fd be 88 02 d6 a8 ff fe f7 a8 + ff ff d3 a0 01 fd a8 01 ff 15 88 fd 15 a0 ff d9 80 fd 88 fe + f4 a8 01 ff f1 a8 ff fe fd 80 09 a8 fe fc 03 88 01 03 a8 ff + ff 1e a8 ff ff c6 88 01 1c a8 ff ff f7 a8 02 01 ee 28 01 fe + a8 01 ff fd 88 03 f4 a8 ff ff 0b a8 01 fe ff 80 f6 80 0c a0 + ff f1 a8 ff 01 0f 88 01 02 a8 01 fe 10 a8 fe 01 05 88 01 e9 + 88 fe e2 a0 02 1e a0 02 0c a8 ff ff f4 a8 01 02 fd a8 ff fe + ee 80 ee a8 01 ff 03 28 ff fe 80 09 88 01 06 a8 ff fe eb a8 + 02 01 29 88 ff 01 a0 01 02 a8 01 ff fe a8 ff 01 f6 80 0a 88 + 02 03 08 ff 88 01 33 a8 01 ff 09 a0 01 24 a8 01 01 fd a0 ff + df a0 02 09 a0 02 0c a8 ff 02 f4 a8 ff 03 03 a8 01 02 f1 88 + fe 0f a0 01 eb 80 06 a8 ff 01 f1 28 fe 02 88 01 09 a8 ff 01 + 03 a8 ff 01 f4 a8 02 ff 0f a0 01 09 88 02 eb a8 ff ff fa a8 + 01 ff d6 a8 ff ff ee a8 ff fd c4 a0 fe dc a8 01 fd 12 a8 01 + fd c4 a8 02 fd dc a8 01 fd a6 a0 fc 94 e8 ff fc 8a 06 fc ff + ff ff ff ff ff + ''' + b = [d.strip() for d in data.split(' ') if d not in ['', '\n']] + b = [int(x, 16) for x in b] + + p = Protocol(ProtocolVersion.INTUOS_PRO, None, None) + p.parse_pen_data(b) + + def test_double_stroke(self): + data = ''' + 67 82 69 65 28 c7 53 5d 00 00 02 00 00 00 00 00 ff fa c3 26 + c7 53 5d a8 01 ff 0a 87 75 80 28 42 00 10 ff ff ff f6 29 da + 1d a4 04 e0 02 0c 06 e0 04 b7 06 e0 04 47 07 e8 01 08 2e 08 + e0 09 06 09 e0 09 ab 09 e8 01 0a 4d 0a a8 ff 0b 72 a0 0b 06 + a0 05 e5 a8 ff 05 f7 a8 02 01 15 a8 fe 01 d6 a0 ff e5 a8 01 + fe 15 a8 ff fe f7 a8 02 fd e5 a8 01 fe eb a8 01 fe ff a0 fd + ff a8 ff fd 01 a8 02 fd e6 a8 fd fd 18 a0 fd d9 a0 ff dc a8 + fd fc 04 a8 01 fe 49 28 ff fc a0 fc ff a0 fa ed a8 01 f9 12 + a8 fe f9 03 a8 01 f8 fd a8 01 fa d0 a0 fd b5 a8 01 fd d9 a0 + fd bb e8 ff f9 12 08 e8 ff cc 3b 03 ff fa 03 26 c7 53 5d 6d + 00 ff ff ff 42 2f fb 1c 65 0a e0 fd f1 0b e8 ff fd b7 0c a8 + 02 03 75 a0 04 09 a0 05 18 a8 03 0a 15 a8 01 0d eb a8 03 0e + d0 a8 ff 0f d0 a8 ff 0d 12 a0 0b e2 a8 fd 08 fa 28 ff 07 a8 + ff 05 f1 20 3a a8 01 02 0c a8 f6 03 f4 28 01 c5 a8 05 fa 0f + a8 02 c8 f7 a0 fc fd a0 fe fa a8 02 24 ee a8 ff f6 fe a8 02 + f6 f3 a8 ff f6 d9 a0 f7 ee a0 f6 d0 a8 ff f9 e2 a0 fb f1 a8 + ff fa dc e8 01 fd 2a 0d e8 ff f9 57 0a e8 fa e7 9c 04 fc ff + ff ff ff ff ff + ''' + b = [d.strip() for d in data.split(' ') if d not in ['', '\n']] + b = [int(x, 16) for x in b] + + p = Protocol(ProtocolVersion.INTUOS_PRO, None, None) + p.parse_pen_data(b) + + def test_quint_stroke(self): + data = ''' + 67 82 69 65 cc ce 53 5d 00 00 05 00 00 00 00 00 ff fa c3 c7 + ce 53 5d 8d 00 ff 0a 87 75 80 28 42 00 10 ff ff ff 95 29 a9 + 1e 23 06 e0 01 a0 07 e8 ff ff db 08 e8 01 01 e9 09 e0 ff bb + 0a e0 01 81 0b c0 1a 0c 80 7b a0 02 f1 a8 ff 03 09 a8 ff 03 + 18 a0 04 e2 a8 fd 02 03 a8 02 03 df a8 fe 03 f7 a0 03 0c a8 + ff 06 e2 a0 03 e8 a8 01 03 f4 a8 01 03 ee a8 01 01 fe a0 03 + d8 a8 01 03 c7 a0 02 fd a8 01 03 06 a8 ff 01 f7 a8 ff 02 1b + 08 ff a8 ff 03 1e a0 02 09 a0 03 15 a8 fe 05 fa a8 ff 02 18 + a0 04 10 a8 ff 01 f0 80 fa a0 ff 06 a0 fe 0f a0 fc f1 a8 01 + fc 01 28 02 fa a8 01 f8 17 a8 01 f9 e8 28 03 f9 28 01 f9 28 + ff f9 a0 f9 02 a0 fa ff a0 fb ff a8 01 fb e9 a8 01 fa c3 a8 + 01 f9 91 e0 f8 cc 0a e8 05 c4 76 06 ff fa 03 c8 ce 53 5d 10 + 02 ff ff ff bc 36 40 1e 00 08 e8 f2 26 14 09 e8 02 0d d7 09 + e8 01 0b 7c 0a a0 0b 69 a0 0a 51 a8 ff 06 f1 a8 ff 07 e8 a8 + fe 03 e5 a8 f7 35 eb 88 ff fa a8 10 99 f7 20 01 28 f8 2f a8 + ff fe d6 a8 01 fa eb 28 01 fb 28 02 fa 28 03 fc a8 02 fa eb + 28 03 fc a8 02 f8 df a8 03 fb f4 a8 01 f9 d6 a8 02 f9 b8 a8 + 02 f7 03 a0 f8 94 e8 02 f7 a3 0a e0 f9 bd 06 e8 0a eb 05 02 + ff fa 03 c8 ce 53 5d 5b 00 ff ff ff 0a 40 19 1e 5b 05 e8 fe + 0d b7 06 e8 fd 0e b3 07 e8 ff 07 88 08 e0 07 2d 09 a0 06 54 + a0 06 75 a8 ff 04 18 a0 05 e8 a8 ff 03 d0 a8 fe 01 fd a8 01 + 02 dc a8 fe 02 ee a0 ff fa 80 f4 a8 ff fe fd a8 01 ff df a8 + 01 fe df a0 fd fa a8 03 ff fd a0 fd 0c a8 01 fd f4 a8 01 fd + f1 a0 fc f4 a8 01 fe af a0 fc b8 a8 02 fb c1 e0 fb 64 0b e8 + 01 fa b3 07 e8 03 0b 18 02 ff fa 03 c9 ce 53 5d b2 02 ff ff + ff dc 46 82 1d 44 06 e8 fe 1b da 06 e8 ff 08 67 07 e8 ff 05 + ee 07 e8 01 08 81 08 a8 fe 06 7e a8 03 06 42 a8 fe 07 12 a8 + 01 04 cd a8 ff 04 d0 a8 fe 03 c4 a8 ff 01 d0 88 ff 21 a8 01 + fe fd a8 01 fd 0c a8 01 fb fd a8 02 fa d9 a8 01 fc fd a8 01 + fa e5 a0 fb fa a8 01 fb dc a8 01 f9 f1 e8 ff fa be 0a e0 fa + 41 09 e8 ff fa 3d 07 e8 ff f9 00 05 ff fa 03 c9 ce 53 5d a9 + 00 ff ff ff 0b 1d ae 23 e4 03 e8 21 fc b6 04 e8 20 fb 3a 05 + c8 06 09 06 e8 01 ff 02 07 e8 02 ff d1 07 88 02 63 a0 fe 06 + 88 03 e5 a8 04 fe d9 a8 04 ff 1e a8 06 fe f4 a8 04 ff e5 a8 + 39 f2 06 88 09 36 a8 0b fc dc a8 c9 0a 03 88 02 f4 a8 03 03 + 27 a8 05 f8 d3 a8 0a 05 12 a8 03 f9 c7 88 07 18 a8 06 07 f2 + a8 fe f8 f8 88 0b 1e a8 04 01 e0 a8 10 ff fd a8 fd fd 03 a8 + 05 03 0c a8 09 fb f6 88 06 fc a8 01 02 02 a8 ff fa df a8 07 + 03 09 a8 0a ff 09 88 02 df a8 f9 04 1b a8 04 ff f4 a8 01 08 + 21 a8 fe f8 fd a8 fd 07 03 a8 05 ff fd a8 02 ff e8 a8 05 ff + 18 a8 fd 01 03 28 01 fd a8 02 03 15 a8 f9 04 f1 a8 ff fb fa + a0 02 e2 a8 09 ff 15 a8 f9 01 09 88 fc 1b 88 01 e8 a8 fb 01 + 30 28 09 04 a8 f8 f9 06 a8 07 01 f1 a8 fb 02 0c 28 fe fe a8 + ff 01 24 a0 fe f4 a8 f5 03 f4 a8 01 0a f7 a8 fb f5 d3 a0 ff + 09 a8 05 03 03 a8 fa 02 1b a8 fc fe ee a8 fd 02 eb a0 fe f6 + a8 f8 09 22 a8 ff f9 dc 28 fd fe a8 fe 04 f1 a8 03 ff 06 a8 + f3 03 fd a8 fa 03 f1 a8 f6 ff 03 a8 f9 03 ee a8 fb 02 e5 a8 + f4 05 dc a8 fa 06 03 a8 f1 fc d6 a8 f4 08 df a8 cd 03 c1 e8 + fb 01 d1 08 e8 fd 01 58 06 e8 67 0b 8b 02 fc ff ff ff ff ff + ff + ''' + b = [d.strip() for d in data.split(' ') if d not in ['', '\n']] + b = [int(x, 16) for x in b] + + p = Protocol(ProtocolVersion.INTUOS_PRO, None, None) + p.parse_pen_data(b) + + +# How does this work? +# The test generater spits out a simple test function that just calls the +# real test function +# +# Then we search for all yaml files with logs we have and generate a unique +# test name (based on the timestamp) for that file. Result: we have +# a unittests function for each log file found in the directory. +def generator(logfile): + def test(self): + self._test_file(logfile) + return test + + +def search_for_tests(): + basedir = Path(xdg.BaseDirectory.xdg_data_home) / 'tuhi' + for idx, logfile in enumerate(basedir.glob('**/raw/log-*.yaml')): + with open(logfile) as fd: + yml = yaml.load(fd, Loader=yaml.Loader) + timestamp = yml['time'] + test_name = f'test_log_{timestamp}' + test = generator(logfile) + setattr(TestLogFiles, test_name, test) + + +search_for_tests() + +if __name__ == '__main__': + unittest.main() diff --git a/tuhi/protocol.py b/tuhi/protocol.py index c7c2018..3ae60fb 100644 --- a/tuhi/protocol.py +++ b/tuhi/protocol.py @@ -53,6 +53,11 @@ import binascii import enum import time +import logging +from collections import namedtuple +from .util import list2hex + +logger = logging.getLogger('tuhi.protocol') def little_u16(x): @@ -83,6 +88,20 @@ def little_u32(x): return int.from_bytes(x, byteorder='little') +def little_u64(x): + ''' + Convert to or from a 64-bit integer to a little-endian 4-byte array. If + passed an integer, the return value is a 8-byte array. If passed a + 4-byte array, the return value is a 64-bit integer. + ''' + if isinstance(x, int): + assert(x <= 0xffffffffffffffff and x >= 0x0000000000000000) + return x.to_bytes(8, byteorder='little') + else: + assert(len(x) == 8) + return int.from_bytes(x, byteorder='little') + + class Interactions(enum.Enum): '''All possible interactions with a device. Not all of these interactions may be available on any specific device.''' @@ -263,6 +282,18 @@ class Protocol(object): ''' return self.get(key, *args, **kwargs).execute() + def parse_pen_data(self, data): + ''' + Parse the given pen data. Returns a list of :class:`StrokeFile` objects. + ''' + files = [] + while data: + logger.debug(f'... remaining data ({len(data)}): {list2hex(data)}') + sf = StrokeFile(data) + files.append(sf) + data = data[sf.bytesize:] + return files + class NordicData(list): ''' @@ -1209,3 +1240,567 @@ class MsgRegisterWaitForButtonSlateOrIntuosPro(Msg): self.protocol_version = ProtocolVersion.INTUOS_PRO else: raise UnexpectedReply(reply) + + +class StrokeParsingError(ProtocolError): + def __init__(self, message, data=[]): + self.message = message + self.data = data + + def __repr__(self): + if self.data: + datastr = f' data: {list2hex(self.data)}' + else: + datastr = '' + return f'{self.message}{datastr}' + + def __str__(self): + return self.__repr__() + + +class StrokeDataType(enum.Enum): + UNKNOWN = enum.auto() + FILE_HEADER = enum.auto() + STROKE_HEADER = enum.auto() + STROKE_END = enum.auto() + POINT = enum.auto() + DELTA = enum.auto() + EOF = enum.auto() + LOST_POINT = enum.auto() + + @classmethod + def identify(cls, data): + ''' + Returns the identified packet type for the next packet. + ''' + header = data[0] + nbytes = bin(header).count('1') + payload = data[1:1 + nbytes] + + # Note: the order of the checks below matters + + # Known file format headers. This is just a version number, I think. + if data[0:4] == [0x67, 0x82, 0x69, 0x65] or \ + data[0:4] == [0x62, 0x38, 0x62, 0x74]: + return StrokeDataType.FILE_HEADER + + # End of stroke, but can sometimes mean end of file too + if data[0:7] == [0xfc, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff]: + return StrokeDataType.STROKE_END + + if payload == [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff]: + return StrokeDataType.EOF + + # all special headers have the lowest two bits set + if header & 0x3 == 0: + return StrokeDataType.DELTA + + if not payload: + return StrokeDataType.UNKNOWN + + if payload[0] == 0xfa or payload[0:3] == [0xff, 0xee, 0xee]: + return StrokeDataType.STROKE_HEADER + + if payload[0:2] == [0xff, 0xff]: + return StrokeDataType.POINT + + if payload[0:2] == [0xdd, 0xdd]: + return StrokeDataType.LOST_POINT + + return StrokeDataType.UNKNOWN + + +class StrokeFile(object): + ''' + Represents a single file as coming from the device. Note that pen data + received from the device may include more than one file, this object is + merely the first represented in this file. + + .. attribute:: bytesize + + The length in bytes of the data consumed. + + .. attribute:: timestamp + + Creation time of the drawing (when the button was pressed) or None where + this is not supported by the device. + + .. attribute:: strokes + + A list of strokes, each a list of Point(x, y, p) namedtuples. + Coordinates for the points are in absolute device units. + + ''' + def __init__(self, data): + self.data = data + self.file_header = StrokeFileHeader(data[:16]) + + logger.debug(self.file_header) + + self.bytesize = self.file_header.size + + offset = self.file_header.size + self.timestamp = self.file_header.timestamp + self.bytesize += self._parse_data(data[offset:]) + + def _parse_data(self, data): + # the data formats we return + Stroke = namedtuple('Stroke', ['points']) + Point = namedtuple('Point', ['x', 'y', 'p']) + + last_point = None # abs coords for most recent point + last_delta = Point(0, 0, 0) # delta accumulates + + strokes = [] # all strokes + points = [] # Points of current strokes + + consumed = 0 + + # Note about the below: this was largely reverse-engineered because + # the specs we have access to are either ambiguous or outright wrong. + # + # First byte is a bitmask that seems to indicate how many bytes. + # + # Where the header byte has the lowest two bits set, it can be + # one of several packages: + # - a StrokeHeader [0xfa] to indicate a new stroke + # - end of stroke - all payload bytes are 0xff + # - lost point [0xdd, 0xdd] - firmware couldn't record a point + # - a StrokePoint [0xff, 0xff] a fully specified point. Always + # the first after a StrokeHeader but may also appear elsewhere. + # + # Where the header byte has the lowest two bits on zero, it is + # a StrokeDelta, a variable sized payload following the header, + # values depend on the bits set in the header. + # + # In theory all the header packages should have a header of 0xff, + # but they don't. End of stroke may have 0xfc, a StrokePoint + # sometimes has 0xbf. It is unknown why. + # + # The StrokePoint is strange since if can sometimes contain deltas + # (bitmask 0xbf). So it's just a delta with an extra two bytes for + # headers, so what is the point of it? Presumably a firmware bug or + # something. + while data: + packet_type = StrokeDataType.identify(data) + logger.debug(f'Next data packet {packet_type.name}: {list2hex(data[:16])} …') + + packet = None + if packet_type == StrokeDataType.UNKNOWN: + packet = StrokePacketUnknown(data) + elif packet_type == StrokeDataType.FILE_HEADER: + # This code shouldn't be triggered, we handle the file + # header outside this function. + packet = StrokeFileHeader(data) + logger.error(f'Unexpected file header at byte {consumed}: {packet}') + break + elif packet_type == StrokeDataType.STROKE_END: + packet = StrokeEndOfStroke(data) + if points: + strokes.append(Stroke(points)) + points = [] + elif packet_type == StrokeDataType.EOF: + # EOF means pack + packet = StrokeEOF(data) + if points: + strokes.append(Stroke(points)) + points = [] + data = data[packet.size:] + consumed += packet.size + break + elif packet_type == StrokeDataType.STROKE_HEADER: + # New stroke means resetting delta and storing the last + # stroke + packet = StrokeHeader(data) + last_delta = Point(0, 0, 0) + if points: + strokes.append(Stroke(points)) + points = [] + elif packet_type == StrokeDataType.LOST_POINT: + # We don't yet handle lost points + packet = StrokeLostPoint(data) + elif (packet_type == StrokeDataType.POINT or + packet_type == StrokeDataType.DELTA): + # POINT and DELTA *should* be two different packages but + # sometimes a POINT includes a delta for a coordinate. So + # all a POINT is is a delta with an added [0xff 0xff] after + # the header byte. The StrokePoint packet hides this so we + # can process both the same way. + if packet_type == StrokeDataType.POINT: + packet = StrokePoint(data) + if last_point is None: + last_point = Point(packet.x, packet.y, packet.p) + else: + packet = StrokeDelta(data) + + # Compression algorithm in the device basically keeps a + # cumulative delta so that + # P0 = absolute x, y, z + # P1 = P0 + d1 + # P2 = P0 + 2*d1 + d2 + # P3 = P0 + 3*d1 + 2*d2 + d3 + # And we use that here by just keeping the last delta + # around, adding to it where necessary and then adding it to + # the last point we have. + # + # Whenever we get an absolute coordinate, the delta resets + # to 0. Since this is per axis, our fictional P4 may be: + # P4(x) = P0 + 4*d1 + 3*d2 + 2*d3 + d4 + # P4(y) = P0 + 4*d1 + 2*d3 ... d2 and d4 are missing (zero) + # P4(p) = P4(p) .... absolute + dx, dy, dp = last_delta + x, y, p = last_point + if packet.dx is not None: + dx += packet.dx + elif packet.x is not None: + x = packet.x + dx = 0 + + if packet.dy is not None: + dy += packet.dy + elif packet.y is not None: + y = packet.y + dy = 0 + + if packet.dp is not None: + dp += packet.dp + elif packet.p is not None: + p = packet.p + dp = 0 + + # dx,dy,dp ... are cumulative deltas for this packet + # x,y,p ... most recent known abs coordinates + # add those two together and we have the real coordinates + # and the baseline for the next point + last_delta = Point(dx, dy, dp) + current_point = Point(x, y, p) + last_point = Point(current_point.x + last_delta.x, + current_point.y + last_delta.y, + current_point.p + last_delta.p) + logger.debug(f'Calculated point: {last_point}') + points.append(last_point) + else: + # should never get here + raise StrokeParsingError(f'Failed to parse', data[:16]) + + logger.debug(f'Offset {consumed}: {packet}') + consumed += packet.size + data = data[packet.size:] + + self.strokes = strokes + return consumed + + +class StrokePacket(object): + ''' + .. attribute: size + + Size of the packet in bytes + ''' + def __init__(self): + self.size = 0 + + +class StrokePacketUnknown(StrokePacket): + def __init__(self, data): + header = data[0] + nbytes = bin(header).count('1') + self.size = 1 + nbytes + self.data = data[:self.size] + + def __repr__(self): + return f'Unknown packet: {list2hex(self.data)}' + + +class StrokeFileHeader(StrokePacket): + ''' + Each data packet has a file header consisting of 4 bytes file version + number and optionally extra data. + + .. attribute: timestamp + + The timestamp of this drawing or ``None`` where not available. + + .. attribute: nstrokes + + The count of strokes within this drawing or ``None`` where not + available. This count is inaccurate anyway, so it should only be + used for basic internal checks. + + ''' + def __init__(self, data): + key = little_u32(data[:4]) + file_formats = { + little_u32([0x67, 0x82, 0x69, 0x65]): self._parse_intuos_pro, + little_u32([0x62, 0x38, 0x62, 0x74]): self._parse_spark, + } + + self.timestamp = None + self.nstrokes = None + + try: + func = file_formats[key] + func(data) + except KeyError: + raise StrokeParsingError(f'Unknown file format:', data[:4]) + + def __repr__(self): + t = time.strftime("%y%m%d%H%M%S", time.gmtime(self.timestamp)) + return f'FileHeader: time: {t}, stroke count: {self.nstrokes}' + + def _parse_intuos_pro(self, data): + self.timestamp = int.from_bytes(data[4:8], byteorder='little') + # plus two bytes for ms, always zero + self.nstrokes = int.from_bytes(data[10:14], byteorder='little') + # plus two bytes always zero + self.size = 16 + + def _parse_spark(self, data): + self.size = 4 + + +class StrokeHeader(StrokePacket): + ''' + .. attribute:: pen_id + + The pen serial number or 0 if none is set + + .. attribute:: pen_type + + The pen type + + .. attribute:: timestamp + + The timestamp of this stroke or None if none was recorded + + .. attribute:: time_offset + + The time offset in ms since powerup or None if this stroke has an + absolute timestamp. + + .. attribute:: is_new_layer + + True if this stroke is on a new layer + ''' + def __init__(self, data): + header = data[0] + payload = data[1:] + self.size = bin(header).count('1') + 1 + if payload[0] == 0xfa: + self._parse_intuos_pro(data, header, payload) + elif payload[0:3] == [0xff, 0xee, 0xee]: + self._parse_slate(data, header, payload) + else: + raise StrokeParsingError(f'Invalid StrokeHeader, expected ff fa or ff ee.', data[:8]) + + def _parse_slate(self, data, header, payload): + self.pen_id = 0 + self.pen_type = 0 + self.is_new_layer = False + + self.timestamp = None + self.time_offset = little_u16(payload[4:6]) * 5 # in 5ms resolution + + # On the first stroke after the file header, this packet is 6 bytes + # only. Other strokes have 8 bytes but the last two bytes are always + # zero. + + def _parse_intuos_pro(self, data, header, payload): + flags = payload[1] + needs_pen_id = flags & 0x80 + self.pen_type = flags & 0x3f + self.is_new_layer = (flags & 0x40) != 0 + self.pen_id = 0 + self.timestamp = int.from_bytes(payload[2:6], byteorder='little') + self.time_offset = None + # FIXME: plus two bytes for milis + self.size = bin(header).count('1') + 1 + + # if the pen id flag is set, the pen ID comes in the next 8-byte + # packet (plus 0xff header) + if needs_pen_id: + pen_packet = data[self.size + 1:] + if not pen_packet: + raise StrokeParsingError('Missing pen ID packet') + + header = data[0] + if header != 0xff: + raise StrokeParsingError(f'Unexpected pen id packet header: {header}.', data[:9]) + + nbytes = bin(header).count('1') + self.pen_id = little_u64(pen_packet[:8]) + self.size += 1 + nbytes + + def __repr__(self): + if self.timestamp is not None: + t = time.strftime(f'%y%m%d%H%M%S', time.gmtime(self.timestamp)) + else: + t = time.strftime(f'boot+{self.time_offset/1000}s') + return f'StrokeHeader: time: {t} new layer: {self.is_new_layer}, pen type: {self.pen_type}, pen id: {self.pen_id:#x}' + + +class StrokeDelta(object): + ''' + .. attribute:: x + + The absolute x coordinate or None if this is packet contains a delta + + .. attribute:: y + + The absolute y coordinate or None if this is packet contains a delta + + .. attribute:: p + + The absolute pressure coordinate or None if this is packet contains a delta + + .. attribute:: dx + + The x delta or None if this is packet contains an absolute + coordinate + + .. attribute:: dy + + The y delta or None if this is packet contains an absolute + coordinate + + .. attribute:: dp + + The pressure delta or None if this is packet contains an absolute + coordinate + ''' + def __init__(self, data): + def extract(mask, databytes): + value = None + delta = None + size = 0 + if mask == 0: + # No data for this coordinate + pass + elif mask == 1: + # Supposedly not implemented by any device. + # + # If this would exist, it would throw off the byte count + # anyway, so this cannot ever exist without breaking + # everything. + raise NotImplementedError('This device is not supposed to be exist') + elif mask == 2: + # 8 bit delta + delta = int.from_bytes(bytes([databytes[0]]), byteorder='little', signed=True) + if delta == 0: + raise StrokeParsingError(f'StrokeDelta: invalid delta of zero', data) + assert delta != 0 + size = 1 + elif mask == 3: + # full abs coordinate + value = little_u16(databytes[:2]) + size = 2 + return value, delta, size + + if (data[0] & 0b11) != 0: + raise NotImplementedError(f'LSB two bits set in mask - this is not supposed to happen') + + xmask = (data[0] & 0b00001100) >> 2 + ymask = (data[0] & 0b00110000) >> 4 + pmask = (data[0] & 0b11000000) >> 6 + + offset = 1 + x, dx, size = extract(xmask, data[offset:]) + offset += size + y, dy, size = extract(ymask, data[offset:]) + offset += size + p, dp, size = extract(pmask, data[offset:]) + offset += size + + # Note: any of these will be None depending on the packet + self.dx = dx + self.dy = dy + self.dp = dp + self.x = x + self.y = y + self.p = p + + self.size = offset + + def __repr__(self): + def printstring(delta, abs): + return f'{delta:+5d}' if delta is not None \ + else f'{abs:5d}' if abs is not None \ + else ' ' # noqa + strx = printstring(self.dx, self.x) + stry = printstring(self.dy, self.y) + strp = printstring(self.dp, self.p) + + return f'StrokeDelta: {strx}/{stry} pressure: {strp}' + + +class StrokePoint(StrokeDelta): + ''' + A full point identified by three coordinates (x, y, pressure) in + absolute coordinates. + ''' + def __init__(self, data): + header = data[0] + payload = data[1:] + if payload[:2] != [0xff, 0xff]: + raise StrokeParsingError(f'Invalid StrokePoint, expected ff ff ff', data[:9]) + + # This is a wrapper around StrokeDelta which does the mask parsing. + # In theory the StrokePoint would be a separate packet but it + # occasionally uses a header other than 0xff. Which means the packet + # is completely useless and shouldn't exist because now it's just a + # StrokeDelta in the form of [header, 0xff, 0xff, payload] and the + # 0xff just keep the room warm. + + # StrokeDelta assumes the bottom two bits are unset + header &= ~0x3 + super().__init__([header] + payload[2:]) + self.size += 2 + + # self.x = little_u16(data[2:4]) + # self.y = little_u16(data[4:6]) + # self.pressure = little_u16(data[6:8]) + + def __repr__(self): + return f'StrokePoint: {self.x}/{self.y} pressure: {self.p}' + + +class StrokeEOF(StrokePacket): + def __init__(self, data): + header = data[0] + payload = data[1:] + nbytes = bin(header).count('1') + if payload[:nbytes] != [0xff] * nbytes: + raise StrokeParsingError(f'Invalid EOF, expected 0xff only', data[:9]) + self.size = nbytes + 1 + + +class StrokeEndOfStroke(StrokePacket): + def __init__(self, data): + header = data[0] + payload = data[1:] + nbytes = bin(header).count('1') + if payload[:nbytes] != [0xff] * nbytes: + raise StrokeParsingError(f'Invalid EndOfStroke, expected 0xff only', data[:9]) + self.size = nbytes + 1 + self.data = data[:self.size] + + def __repr__(self): + return f'EndOfStroke: {list2hex(self.data)}' + + +class StrokeLostPoint(StrokePacket): + ''' + Marker for lost points that the firmware couldn't record coordinates + for. + + .. attribute:: nlost + + The number of points not recorded. + ''' + def __init__(self, data): + header = data[0] + payload = data[1:] + if payload[:2] != [0xdd, 0xdd]: + raise StrokeParsingError(f'Invalid StrokeLostPoint, expected ff dd dd', data[:9]) + self.nlost = little_u16(payload[2:4]) + self.size = bin(header).count('1') + 1 diff --git a/tuhi/util.py b/tuhi/util.py index e8ac1ee..766d2cd 100644 --- a/tuhi/util.py +++ b/tuhi/util.py @@ -22,3 +22,12 @@ def list2hex(l, groupsize=8): slices.append(s) return ' '.join(slices) + + +def flatten(items): + '''flatten an array of mixed int and arrays into a simple array of int''' + for item in items: + if isinstance(item, int): + yield item + else: + yield from flatten(item) diff --git a/tuhi/wacom.py b/tuhi/wacom.py index 73e13c0..37ee1c2 100644 --- a/tuhi/wacom.py +++ b/tuhi/wacom.py @@ -26,7 +26,7 @@ from .drawing import Drawing from .uhid import UHIDDevice import tuhi.protocol from tuhi.protocol import NordicData, Interactions, Mode, ProtocolVersion -from .util import list2hex +from .util import list2hex, flatten logger = logging.getLogger('tuhi.wacom') @@ -90,15 +90,6 @@ wacom_live_rdesc_template = [ ] -def flatten(items): - '''flatten an array of mixed int and arrays into a simple array of int''' - for item in items: - if isinstance(item, int): - yield item - else: - yield from flatten(item) - - def signed_char_to_int(v): return int.from_bytes([v], byteorder='little', signed=True)