Add feature to Socket so that it can send digests of data it has sent
authorCarl Hetherington <cth@carlh.net>
Wed, 15 Apr 2020 22:38:08 +0000 (00:38 +0200)
committerCarl Hetherington <cth@carlh.net>
Thu, 16 Apr 2020 22:42:54 +0000 (00:42 +0200)
and check those digests on receive.

src/lib/dcpomatic_socket.cc
src/lib/dcpomatic_socket.h
src/lib/digester.cc
src/lib/digester.h
test/socket_test.cc [new file with mode: 0644]
test/wscript

index ca910bb79aaa3fe80dedbc979d56df853c72e5f9..a0a7a1cf3f3d877ee92aa204c2ad3c23f162a754 100644 (file)
@@ -1,5 +1,5 @@
 /*
-    Copyright (C) 2012-2015 Carl Hetherington <cth@carlh.net>
+    Copyright (C) 2012-2020 Carl Hetherington <cth@carlh.net>
 
     This file is part of DCP-o-matic.
 
 #include "dcpomatic_socket.h"
 #include "compose.hpp"
 #include "exceptions.h"
+#include "dcpomatic_assert.h"
 #include <boost/bind.hpp>
 #include <boost/lambda/lambda.hpp>
 #include <iostream>
 
 #include "i18n.h"
 
+using boost::shared_ptr;
+using boost::weak_ptr;
+
 /** @param timeout Timeout in seconds */
 Socket::Socket (int timeout)
        : _deadline (_io_service)
@@ -89,6 +93,10 @@ Socket::write (uint8_t const * data, int size)
        if (ec) {
                throw NetworkError (String::compose (_("error during async_write (%1)"), ec.value ()));
        }
+
+       if (_write_digester) {
+               _write_digester->add (data, static_cast<size_t>(size));
+       }
 }
 
 void
@@ -117,6 +125,10 @@ Socket::read (uint8_t* data, int size)
        if (ec) {
                throw NetworkError (String::compose (_("error during async_read (%1)"), ec.value ()));
        }
+
+       if (_read_digester) {
+               _read_digester->add (data, static_cast<size_t>(size));
+       }
 }
 
 uint32_t
@@ -126,3 +138,98 @@ Socket::read_uint32 ()
        read (reinterpret_cast<uint8_t *> (&v), 4);
        return ntohl (v);
 }
+
+
+void
+Socket::start_read_digest ()
+{
+       DCPOMATIC_ASSERT (!_read_digester);
+       _read_digester.reset (new Digester());
+}
+
+void
+Socket::start_write_digest ()
+{
+       DCPOMATIC_ASSERT (!_write_digester);
+       _write_digester.reset (new Digester());
+}
+
+
+Socket::ReadDigestScope::ReadDigestScope (shared_ptr<Socket> socket)
+       : _socket (socket)
+{
+       socket->start_read_digest ();
+}
+
+
+bool
+Socket::ReadDigestScope::check ()
+{
+       shared_ptr<Socket> sp = _socket.lock ();
+       if (!sp) {
+               return false;
+       }
+
+       return sp->check_read_digest ();
+}
+
+
+Socket::WriteDigestScope::WriteDigestScope (shared_ptr<Socket> socket)
+       : _socket (socket)
+{
+       socket->start_write_digest ();
+}
+
+
+Socket::WriteDigestScope::~WriteDigestScope ()
+{
+       shared_ptr<Socket> sp = _socket.lock ();
+       if (sp) {
+               try {
+                       sp->finish_write_digest ();
+               } catch (...) {
+                       /* If we can't write our digest, something bad has happened
+                        * so let's just let it happen.
+                        */
+               }
+       }
+}
+
+
+bool
+Socket::check_read_digest ()
+{
+       DCPOMATIC_ASSERT (_read_digester);
+       int const size = _read_digester->size ();
+
+       uint8_t ref[size];
+       _read_digester->get (ref);
+
+       /* Make sure _read_digester is gone before we call read() so that the digest
+        * isn't itself digested.
+        */
+       _read_digester.reset ();
+
+       uint8_t actual[size];
+       read (actual, size);
+
+       return memcmp(ref, actual, size) == 0;
+}
+
+void
+Socket::finish_write_digest ()
+{
+       DCPOMATIC_ASSERT (_write_digester);
+       int const size = _write_digester->size();
+
+       uint8_t buffer[size];
+       _write_digester->get (buffer);
+
+       /* Make sure _write_digester is gone before we call write() so that the digest
+        * isn't itself digested.
+        */
+       _write_digester.reset ();
+
+       write (buffer, size);
+}
+
index 870e7315c3126fb67b67fa3b493ce2be34ee8a56..1fa0b046f0c1c7e2490865bfaa41a99114f8dcfb 100644 (file)
 
 */
 
+#include "digester.h"
 #include <boost/asio.hpp>
 #include <boost/noncopyable.hpp>
+#include <boost/scoped_ptr.hpp>
+#include <boost/weak_ptr.hpp>
 
 /** @class Socket
  *  @brief A class to wrap a boost::asio::ip::tcp::socket with some things
@@ -46,8 +49,36 @@ public:
        void read (uint8_t* data, int size);
        uint32_t read_uint32 ();
 
+       class ReadDigestScope
+       {
+       public:
+               ReadDigestScope (boost::shared_ptr<Socket> socket);
+               bool check ();
+       private:
+               boost::weak_ptr<Socket> _socket;
+       };
+
+       /** After one of these is created everything that is sent from the socket will be
+        *  added to a digest.  When the DigestScope is destroyed the digest will be sent
+        *  from the socket.
+        */
+       class WriteDigestScope
+       {
+       public:
+               WriteDigestScope (boost::shared_ptr<Socket> socket);
+               ~WriteDigestScope ();
+       private:
+               boost::weak_ptr<Socket> _socket;
+       };
+
 private:
+       friend class DigestScope;
+
        void check ();
+       void start_read_digest ();
+       bool check_read_digest ();
+       void start_write_digest ();
+       void finish_write_digest ();
 
        Socket (Socket const &);
 
@@ -55,4 +86,6 @@ private:
        boost::asio::deadline_timer _deadline;
        boost::asio::ip::tcp::socket _socket;
        int _timeout;
+       boost::scoped_ptr<Digester> _read_digester;
+       boost::scoped_ptr<Digester> _write_digester;
 };
index 7bcc77646f8245873006fc6f23c675008179b8fd..452452ba4b1c1091622a001c639446d83f2f2989 100644 (file)
@@ -19,6 +19,7 @@
 */
 
 #include "digester.h"
+#include "dcpomatic_assert.h"
 #include <nettle/md5.h>
 #include <iomanip>
 #include <cstdio>
@@ -67,3 +68,16 @@ Digester::get () const
 
        return _digest.get ();
 }
+
+void
+Digester::get (uint8_t* buffer) const
+{
+       md5_digest (&_context, MD5_DIGEST_SIZE, buffer);
+}
+
+
+int
+Digester::size () const
+{
+       return MD5_DIGEST_SIZE;
+}
index bec1f6416c02e42096c15691ae8eb595731362ee..6cdaf2331b8bda04a7aec588c659add946f8c79c 100644 (file)
@@ -40,6 +40,10 @@ public:
 
        std::string get () const;
 
+       void get (uint8_t* buffer) const;
+
+       int size () const;
+
 private:
        mutable md5_ctx _context;
        mutable boost::optional<std::string> _digest;
diff --git a/test/socket_test.cc b/test/socket_test.cc
new file mode 100644 (file)
index 0000000..562d106
--- /dev/null
@@ -0,0 +1,167 @@
+/*
+    Copyright (C) 2020 Carl Hetherington <cth@carlh.net>
+
+    This file is part of DCP-o-matic.
+
+    DCP-o-matic is free software; you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation; either version 2 of the License, or
+    (at your option) any later version.
+
+    DCP-o-matic is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with DCP-o-matic.  If not, see <http://www.gnu.org/licenses/>.
+
+*/
+
+#include "lib/server.h"
+#include "lib/dcpomatic_socket.h"
+#include <boost/thread.hpp>
+#include <boost/test/unit_test.hpp>
+#include <boost/shared_ptr.hpp>
+#include <cstring>
+#include <iostream>
+
+using boost::shared_ptr;
+using boost::bind;
+
+#define TEST_SERVER_PORT 9142
+#define TEST_SERVER_BUFFER_LENGTH 1024
+
+
+class TestServer : public Server
+{
+public:
+       TestServer (bool digest)
+               : Server (TEST_SERVER_PORT, 30)
+               , _buffer (new uint8_t[TEST_SERVER_BUFFER_LENGTH])
+               , _size (0)
+               , _result (false)
+               , _digest (digest)
+       {
+               _thread = boost::thread(bind(&TestServer::run, this));
+       }
+
+       ~TestServer ()
+       {
+               stop ();
+               _thread.join ();
+               delete[] _buffer;
+       }
+
+       void expect (int size)
+       {
+               boost::mutex::scoped_lock lm (_mutex);
+               _size = size;
+       }
+
+       uint8_t const * buffer() const {
+               return _buffer;
+       }
+
+       void await ()
+       {
+               boost::mutex::scoped_lock lm (_mutex);
+               if (_size) {
+                       _condition.wait (lm);
+               }
+       }
+
+       bool result () const {
+               return _result;
+       }
+
+private:
+       void handle (boost::shared_ptr<Socket> socket)
+       {
+               boost::mutex::scoped_lock lm (_mutex);
+               BOOST_REQUIRE (_size);
+               if (_digest) {
+                       Socket::ReadDigestScope ds (socket);
+                       socket->read (_buffer, _size);
+                       _size = 0;
+                       _condition.notify_one ();
+                       _result = ds.check();
+               } else {
+                       socket->read (_buffer, _size);
+                       _size = 0;
+                       _condition.notify_one ();
+               }
+       }
+
+       boost::thread _thread;
+       boost::mutex _mutex;
+       boost::condition _condition;
+       uint8_t* _buffer;
+       int _size;
+       bool _result;
+       bool _digest;
+};
+
+
+void
+send (shared_ptr<Socket> socket, char const* message)
+{
+       socket->write (reinterpret_cast<uint8_t const *>(message), strlen(message) + 1);
+}
+
+/** Basic test to see if Socket can send and receive data */
+BOOST_AUTO_TEST_CASE (socket_basic_test)
+{
+       TestServer server(false);
+       server.expect (13);
+
+       shared_ptr<Socket> socket (new Socket);
+       socket->connect (boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), TEST_SERVER_PORT));
+       send (socket, "Hello world!");
+
+       server.await ();
+       BOOST_CHECK_EQUAL(strcmp(reinterpret_cast<char const *>(server.buffer()), "Hello world!"), 0);
+}
+
+
+/** Check that the socket "auto-digest" creation works */
+BOOST_AUTO_TEST_CASE (socket_digest_test1)
+{
+       TestServer server(false);
+       server.expect (13 + 16);
+
+       shared_ptr<Socket> socket(new Socket);
+       socket->connect (boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), TEST_SERVER_PORT));
+       {
+               Socket::WriteDigestScope ds(socket);
+               send (socket, "Hello world!");
+       }
+
+       server.await ();
+       BOOST_CHECK_EQUAL(strcmp(reinterpret_cast<char const *>(server.buffer()), "Hello world!"), 0);
+
+       /* printf "%s\0" "Hello world!" | md5sum" in bash */
+       char ref[] = "\x59\x86\x88\xed\x18\xc8\x71\xdd\x57\xb9\xb7\x9f\x4b\x03\x14\xcf";
+       BOOST_CHECK_EQUAL (memcmp(server.buffer() + 13, ref, 16), 0);
+}
+
+
+/** Check that the socket "auto-digest" round-trip works */
+BOOST_AUTO_TEST_CASE (socket_digest_test2)
+{
+       TestServer server(true);
+       server.expect (13);
+
+       shared_ptr<Socket> socket(new Socket);
+       socket->connect (boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), TEST_SERVER_PORT));
+       {
+               Socket::WriteDigestScope ds(socket);
+               send (socket, "Hello world!");
+       }
+
+       server.await ();
+       BOOST_CHECK_EQUAL(strcmp(reinterpret_cast<char const *>(server.buffer()), "Hello world!"), 0);
+
+       BOOST_CHECK (server.result());
+}
+
index a1a78eaea01dffb1d88de1cefddce6e877fb9364..4e5c57e0df366b5b0e9e29162da0fa57319ce5d2 100644 (file)
@@ -109,6 +109,7 @@ def build(bld):
                  silence_padding_test.cc
                  shuffler_test.cc
                  skip_frame_test.cc
+                 socket_test.cc
                  srt_subtitle_test.cc
                  ssa_subtitle_test.cc
                  stream_test.cc