protobuf具有很强的反射机制，利用这一特性可以根据message的type name自动创建message对象。
如上,是muduo书中的一张图,大致描述了protobuf的构建
结合上图，说一说我自己对protobuf的理解：
1. 每一种message类型在构建时都会生成一个Descriptor对象（与该message类型一一对应）。
2. type name与message的Descriptor进行绑定：调用DescriptorPool的成员函数FindMessageTypeByName，就可以通过type name返回对应message的Descriptor。
3. 用hash_map将Descriptor与Message（原型对象，prototype）进行绑定，通过Descriptor可以找到对应的message，这就是MessageFactory类所做的事情。
通过上面的机制，就可以根据type name创建具体的消息对象了，具体步骤如下：
1. 调用DescriptorPool::generated_pool()方法，获取DescriptorPool。
2. 调用DescriptorPool的FindMessageTypeByName()方法，通过type name获取Descriptor。
3. 调用MessageFactory::generated_factory()方法，获取MessageFactory。
4. 调用MessageFactory的GetPrototype()方法，通过第2步中获取的Descriptor找到对应的message原型（prototype）。
5. 调用该原型的New()方法，创建message的具体实例对象。
实现如下：
// Creates a new, default-valued message of the type named `typeName` via
// protobuf reflection:
//   1. look up the type's Descriptor in the generated DescriptorPool;
//   2. ask the generated MessageFactory for that type's prototype;
//   3. clone the prototype with New().
// Returns nullptr when the name is unknown or no prototype is available.
// The returned object is heap-allocated; the caller takes ownership
// (parse() stores it in a MessagePtr).
google::protobuf::Message* ProtobufCodec::createMessage(const std::string& typeName)
{
  using google::protobuf::Descriptor;
  using google::protobuf::DescriptorPool;
  using google::protobuf::Message;
  using google::protobuf::MessageFactory;

  const Descriptor* descriptor =
      DescriptorPool::generated_pool()->FindMessageTypeByName(typeName);
  if (descriptor == nullptr)
  {
    // No compiled-in message type carries this name.
    return nullptr;
  }

  const Message* prototype =
      MessageFactory::generated_factory()->GetPrototype(descriptor);
  if (prototype == nullptr)
  {
    return nullptr;
  }
  return prototype->New();
}
因为一系列因素,protobuf的默认序列化格式是没有包含消息的长度和类型的
(因为很多场景下并不需要消息的长度和类型,所以protobuf序列化并没有包含这些,若在某些需要的场景下,可以在用户层自己实现就可以)
codec是一层间接的处理层，拦截并处理TcpConnection与上层业务代码（server）之间的数据：
发送消息时，将message按照自定义的protobuf传输格式转化为Buffer对象；
接收消息时，将Buffer按照自定义的protobuf传输格式转化为message对象。
传输格式如下图所示
用C struct伪代码描述
// Wire format of one codec frame, written as C-struct pseudocode:
// typeName and protobufData are variable-length, so this does not compile.
struct ProtobufTransportFormat __attribute__ ((__packed__))
{
int32_t len; // number of bytes that follow this field
int32_t nameLen; // length of typeName, including its trailing '\0'
char typeName[nameLen]; // full protobuf type name, '\0'-terminated
char protobufData[len-nameLen-8]; // serialized message; 8 = sizeof nameLen + sizeof checkSum
int32_t checkSum; // adler32 of nameLen, typeName and protobufData, used to detect corruption
}
checksum是否必要？
该字段的作用类似于校验和，用来验证网络包在传输过程中是否损坏。
虽然TCP是可靠的传输协议，并且带有16位的校验和（以太网帧还有CRC-32校验），但这些校验的检错能力有限，网络传输仍然必须考虑数据损坏的情况；对于关键的网络应用，checksum是必不可少的。
根据规定好的protobuf传输格式，就可以编写编码函数了，过程十分简单，照着传输格式一步一步来就行。
// Serialize `message` into `buf` as one complete frame:
//   len | nameLen | typeName ('\0'-terminated) | protobufData | checkSum
// where checkSum is adler32 over nameLen, typeName and protobufData, and
// len counts every byte after the len field itself.
// Precondition: `buf` is empty (asserted below).
void ProtobufCodec::fillEmptyBuffer(Buffer* buf, const google::protobuf::Message& message)
{
// buf->retrieveAll();
assert(buf->readableBytes() == 0);
// Write the type name first; the receiver uses it to create the message via reflection.
// nameLen includes the trailing '\0', which is written to the buffer as well.
const std::string& typeName = message.GetTypeName();
int32_t nameLen = static_cast<int32_t>(typeName.size()+1);
buf->appendInt32(nameLen);
buf->append(typeName.c_str(), nameLen);
// code copied from MessageLite::SerializeToArray() and MessageLite::SerializePartialToArray().
// Debug-build check that the message is fully initialized before serializing.
GOOGLE_DCHECK(message.IsInitialized()) << InitializationErrorMessage("serialize", message);
// Newer protobuf versions expose ByteSizeLong(); older ones ByteSize().
#if GOOGLE_PROTOBUF_VERSION > 3009002
// Get the serialized size of the message.
int byte_size = google::protobuf::internal::ToIntSize(message.ByteSizeLong());
#else
int byte_size = message.ByteSize();
#endif
// Make sure buf has room for byte_size more bytes, then serialize directly
// into its writable region (no intermediate copy).
buf->ensureWritableBytes(byte_size);
uint8_t* start = reinterpret_cast<uint8_t*>(buf->beginWrite());
// Relies on the sizes cached by the size computation above.
uint8_t* end = message.SerializeWithCachedSizesToArray(start);
// The number of bytes actually written must match the computed size;
// otherwise report the inconsistency (ByteSizeConsistencyError is a helper
// not shown here).
if (end - start != byte_size)
{
#if GOOGLE_PROTOBUF_VERSION > 3009002
ByteSizeConsistencyError(byte_size, google::protobuf::internal::ToIntSize(message.ByteSizeLong()), static_cast<int>(end - start));
#else
ByteSizeConsistencyError(byte_size, message.ByteSize(), static_cast<int>(end - start));
#endif
}
// Move the serialized bytes into the readable region of buf.
buf->hasWritten(byte_size);
// Checksum everything written so far: nameLen, typeName and protobufData.
int32_t checkSum = static_cast<int32_t>(
::adler32(1,
reinterpret_cast<const Bytef*>(buf->peek()),
static_cast<int>(buf->readableBytes())));
buf->appendInt32(checkSum);
assert(buf->readableBytes() == sizeof nameLen + nameLen + byte_size + sizeof checkSum);
// Finally prepend the frame length in network byte order; it covers
// everything after the len field itself.
int32_t len = sockets::hostToNetwork32(static_cast<int32_t>(buf->readableBytes()));
buf->prepend(&len, sizeof len);
}
有了protobuf的传输格式,解析消息也十分简单
// Decode one frame body into a protobuf message.
//
// `buf` points just past the frame's 4-byte len field and `len` is that
// field's value, so the `len` bytes at `buf` are laid out as
//   nameLen (int32) | typeName (nameLen bytes, '\0'-terminated)
//   | protobufData | checkSum (int32, adler32 of everything before it)
//
// *error receives kNoError on success, otherwise the first problem found:
// kInvalidLength, kCheckSumError, kInvalidNameLen, kUnknownMessageType or
// kParseError. The returned pointer is empty unless a message object was
// created; note it is non-empty on kParseError, so callers must check *error.
MessagePtr ProtobufCodec::parse(const char* buf, int len, ErrorCode* error)
{
  MessagePtr message;

  // The body must hold at least the nameLen field and the trailing checksum.
  // Without this guard a short `len` would make the checksum read start
  // before `buf` and hand adler32 a negative length.
  if (len < 2*kHeaderLen)
  {
    *error = kInvalidLength;
    return message;
  }

  // The checksum occupies the last kHeaderLen bytes and covers all bytes before it.
  const int32_t expectedCheckSum = asInt32(buf + len - kHeaderLen);
  const int32_t checkSum = static_cast<int32_t>(
      ::adler32(1,
                reinterpret_cast<const Bytef*>(buf),
                static_cast<int>(len - kHeaderLen)));
  if (checkSum != expectedCheckSum)
  {
    *error = kCheckSumError;
    return message;
  }

  // nameLen must cover at least one character plus '\0' and must fit
  // between the nameLen field and the checksum.
  const int32_t nameLen = asInt32(buf);
  if (nameLen < 2 || nameLen > len - 2*kHeaderLen)
  {
    *error = kInvalidNameLen;
    return message;
  }

  // Drop the trailing '\0' the sender counts in nameLen.
  const std::string typeName(buf + kHeaderLen, buf + kHeaderLen + nameLen - 1);

  // Create a message of that type through protobuf reflection.
  message.reset(createMessage(typeName));
  if (!message)
  {
    *error = kUnknownMessageType;
    return message;
  }

  // Whatever lies between the type name and the checksum is the payload.
  const char* data = buf + kHeaderLen + nameLen;
  const int32_t dataLen = len - nameLen - 2*kHeaderLen;
  *error = message->ParseFromArray(data, dataLen) ? kNoError : kParseError;
  return message;
}
有了fillEmptyBuffer和parse，再对其进一步封装，便于作为muduo库的回调函数使用：
// Connection message callback: decodes as many complete frames as `buf`
// currently holds and passes each decoded message to messageCallback_.
// A frame is consumed from `buf` only after it was delivered successfully.
// On a bad length or a decode error, errorCallback_ is invoked and decoding
// stops; an incomplete trailing frame is left in `buf` for the next call.
void ProtobufCodec::onMessage(const TcpConnectionPtr& conn,
                              Buffer* buf,
                              Timestamp receiveTime)
{
  while (buf->readableBytes() >= kMinMessageLen + kHeaderLen)
  {
    // peekInt32() reads the 4-byte (32-bit) length prefix without consuming
    // it; the whole frame is consumed by retrieve() below.
    const int32_t frameLen = buf->peekInt32();
    if (frameLen < kMinMessageLen || frameLen > kMaxMessageLen)
    {
      errorCallback_(conn, buf, receiveTime, kInvalidLength);
      break;
    }

    if (buf->readableBytes() < implicit_cast<size_t>(frameLen + kHeaderLen))
    {
      // The frame has not fully arrived yet; wait for more data.
      break;
    }

    // Decode the frame body that follows the length prefix.
    ErrorCode errorCode = kNoError;
    MessagePtr message = parse(buf->peek() + kHeaderLen, frameLen, &errorCode);
    if (errorCode != kNoError || !message)
    {
      errorCallback_(conn, buf, receiveTime, errorCode);
      break;
    }

    messageCallback_(conn, message, receiveTime);
    buf->retrieve(kHeaderLen + frameLen);
  }
}
// Encodes `message` into a freshly built frame and hands it to `conn`.
void send(const muduo::net::TcpConnectionPtr& conn,
          const google::protobuf::Message& message)
{
  // FIXME: serialize to TcpConnection::outputBuffer()
  // The frame is assembled in a temporary buffer, then passed to conn->send().
  muduo::net::Buffer frame;
  fillEmptyBuffer(&frame, message);
  conn->send(&frame);
}
消息类型有很多种,服务器需要根据不同的消息类型去调用不同的回调处理函数,这就是消息分发器的用处
muduo库中的实现是通过map，将不同消息的Descriptor*与对应的回调处理函数callback进行绑定。服务器接收到消息后，调用消息的GetDescriptor()获取Descriptor*，从而找到对应的callback；若没有找到，就调用defaultCallback_。
typedef std::map<const google::protobuf::Descriptor*, std::shared_ptr<Callback> > CallbackMap;
CallbackMap callbacks_;
ProtobufMessageCallback defaultCallback_;
通过消息类型,调用回调函数
// Routes `message` to the handler registered for its concrete type.
// message->GetDescriptor() yields the Descriptor of the message's runtime
// type, the same key registerMessageCallback<T>() stores (T::descriptor()).
// Messages of unregistered types go to defaultCallback_.
void onProtobufMessage(const muduo::net::TcpConnectionPtr& conn,
                       const MessagePtr& message,
                       muduo::Timestamp receiveTime) const
{
  const CallbackMap::const_iterator handler = callbacks_.find(message->GetDescriptor());
  if (handler == callbacks_.end())
  {
    defaultCallback_(conn, message, receiveTime);
    return;
  }
  handler->second->onMessage(conn, message, receiveTime);
}
注册回调函数
// Registers `callback` as the handler for messages of concrete type T.
// The entry is keyed by T::descriptor(); registering the same T again
// replaces the previously stored handler.
template<typename T>
void registerMessageCallback(const typename CallbackT<T>::ProtobufMessageTCallback& callback)
{
  // make_shared avoids a naked `new` and allocates object + control block together.
  callbacks_[T::descriptor()] = std::make_shared<CallbackT<T> >(callback);
}
// Declare the syntax explicitly: the `required` fields below are proto2-only,
// and protoc warns when no syntax is specified.
syntax = "proto2";

package muduo;
option java_package = "muduo.codec.tests";
option java_outer_classname = "QueryProtos";

// Request sent by the query client.
message Query {
  required int64 id = 1;
  required string questioner = 2;
  repeated string question = 3;
}

// Reply sent by the query server in response to a Query.
message Answer {
  required int64 id = 1;
  required string questioner = 2;
  required string answerer = 3;
  repeated string solution = 4;
}

// Message with a single optional field; the client sends it instead of a
// Query when started with the "e" option.
message Empty {
  optional int32 id = 1;
}
#include "examples/protobuf/codec/codec.h"
#include "examples/protobuf/codec/dispatcher.h"
#include "query.pb.h"

#include "muduo/base/Logging.h"
#include "muduo/base/Mutex.h"
#include "muduo/net/EventLoop.h"
#include "muduo/net/TcpServer.h"

#include <stdio.h>   // printf
#include <stdlib.h>  // atoi
#include <unistd.h>  // getpid

using namespace muduo;
using namespace muduo::net;

// Typed message pointers handed to the per-type dispatcher callbacks.
typedef std::shared_ptr<muduo::Query> QueryPtr;
typedef std::shared_ptr<muduo::Answer> AnswerPtr;
class QueryServer : noncopyable
{
public:
QueryServer(EventLoop* loop,
const InetAddress& listenAddr)
: server_(loop, listenAddr, "QueryServer"),
dispatcher_(std::bind(&QueryServer::onUnknownMessage, this, _1, _2, _3)),
codec_(std::bind(&ProtobufDispatcher::onProtobufMessage, &dispatcher_, _1, _2, _3))
{
//为不同类型消息注册回调函数
dispatcher_.registerMessageCallback<muduo::Query>(
std::bind(&QueryServer::onQuery, this, _1, _2, _3));
dispatcher_.registerMessageCallback<muduo::Answer>(
std::bind(&QueryServer::onAnswer, this, _1, _2, _3));
//给server绑定connection回调
server_.setConnectionCallback(
std::bind(&QueryServer::onConnection, this, _1));
//给server绑定来消息回调
server_.setMessageCallback(
std::bind(&ProtobufCodec::onMessage, &codec_, _1, _2, _3));
}
void start()
{
server_.start();
}
private:
void onConnection(const TcpConnectionPtr& conn)
{
LOG_INFO << conn->peerAddress().toIpPort() << " -> "
<< conn->localAddress().toIpPort() << " is "
<< (conn->connected() ? "UP" : "DOWN");
}
void onUnknownMessage(const TcpConnectionPtr& conn,
const MessagePtr& message,
Timestamp)
{
LOG_INFO << "onUnknownMessage: " << message->GetTypeName();
conn->shutdown();
}
void onQuery(const muduo::net::TcpConnectionPtr& conn,
const QueryPtr& message,
muduo::Timestamp)
{
LOG_INFO << "onQuery:\n" << message->GetTypeName() << message->DebugString();
Answer answer;
answer.set_id(1);
answer.set_questioner("Chen Shuo");
answer.set_answerer("blog.csdn.net/Solstice");
answer.add_solution("Jump!");
answer.add_solution("Win!");
codec_.send(conn, answer);
conn->shutdown();
}
void onAnswer(const muduo::net::TcpConnectionPtr& conn,
const AnswerPtr& message,
muduo::Timestamp)
{
LOG_INFO << "onAnswer: " << message->GetTypeName();
conn->shutdown();
}
TcpServer server_;
ProtobufDispatcher dispatcher_;
ProtobufCodec codec_;
};
// Entry point: `server port`. Starts a QueryServer on the given port and
// runs the event loop; prints usage when the port is missing.
int main(int argc, char* argv[])
{
  LOG_INFO << "pid = " << getpid();
  if (argc <= 1)
  {
    printf("Usage: %s port\n", argv[0]);
    return 0;
  }

  EventLoop loop;
  const uint16_t port = static_cast<uint16_t>(atoi(argv[1]));
  InetAddress serverAddr(port);
  QueryServer server(&loop, serverAddr);
  server.start();
  loop.loop();
}
#include "examples/protobuf/codec/dispatcher.h"
#include "examples/protobuf/codec/codec.h"
#include "query.pb.h"

#include "muduo/base/Logging.h"
#include "muduo/base/Mutex.h"
#include "muduo/net/EventLoop.h"
#include "muduo/net/TcpClient.h"

#include <stdio.h>   // printf
#include <stdlib.h>  // atoi
#include <unistd.h>  // getpid

using namespace muduo;
using namespace muduo::net;

// Typed message pointers handed to the per-type dispatcher callbacks.
typedef std::shared_ptr<muduo::Empty> EmptyPtr;
typedef std::shared_ptr<muduo::Answer> AnswerPtr;

// Message sent by QueryClient::onConnection once connected; main() points it
// at a local Query or Empty before the event loop starts.
google::protobuf::Message* messageToSend = nullptr;
class QueryClient : noncopyable
{
public:
QueryClient(EventLoop* loop,
const InetAddress& serverAddr)
: loop_(loop),
client_(loop, serverAddr, "QueryClient"),
dispatcher_(std::bind(&QueryClient::onUnknownMessage, this, _1, _2, _3)),
codec_(std::bind(&ProtobufDispatcher::onProtobufMessage, &dispatcher_, _1, _2, _3))
{
dispatcher_.registerMessageCallback<muduo::Answer>(
std::bind(&QueryClient::onAnswer, this, _1, _2, _3));
dispatcher_.registerMessageCallback<muduo::Empty>(
std::bind(&QueryClient::onEmpty, this, _1, _2, _3));
client_.setConnectionCallback(
std::bind(&QueryClient::onConnection, this, _1));
client_.setMessageCallback(
std::bind(&ProtobufCodec::onMessage, &codec_, _1, _2, _3));
}
void connect()
{
client_.connect();
}
private:
void onConnection(const TcpConnectionPtr& conn)
{
LOG_INFO << conn->localAddress().toIpPort() << " -> "
<< conn->peerAddress().toIpPort() << " is "
<< (conn->connected() ? "UP" : "DOWN");
if (conn->connected())
{
codec_.send(conn, *messageToSend);
}
else
{
loop_->quit();
}
}
void onUnknownMessage(const TcpConnectionPtr&,
const MessagePtr& message,
Timestamp)
{
LOG_INFO << "onUnknownMessage: " << message->GetTypeName();
}
void onAnswer(const muduo::net::TcpConnectionPtr&,
const AnswerPtr& message,
muduo::Timestamp)
{
LOG_INFO << "onAnswer:\n" << message->GetTypeName() << message->DebugString();
}
void onEmpty(const muduo::net::TcpConnectionPtr&,
const EmptyPtr& message,
muduo::Timestamp)
{
LOG_INFO << "onEmpty: " << message->GetTypeName();
}
EventLoop* loop_;
TcpClient client_;
ProtobufDispatcher dispatcher_;
ProtobufCodec codec_;
};
int main(int argc, char* argv[])
{
LOG_INFO << "pid = " << getpid();
if (argc > 2)
{
EventLoop loop;
uint16_t port = static_cast<uint16_t>(atoi(argv[2]));
InetAddress serverAddr(argv[1], port);
muduo::Query query;
query.set_id(1);
query.set_questioner("Chen Shuo");
query.add_question("Running?");
muduo::Empty empty;
messageToSend = &query;
if (argc > 3 && argv[3][0] == 'e')
{
messageToSend = ∅
}
QueryClient client(&loop, serverAddr);
client.connect();
loop.loop();
}
else
{
printf("Usage: %s host_ip port [q|e]\n", argv[0]);
}
}