blob: 2b0309ed1a2bd7ac0db8a75d7a88fcd2d9a75c7b [file] [log] [blame]
/*
* libjingle
* Copyright 2004--2005, Google Inc.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
* 3. The name of the author may not be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
* WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
* EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
* OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
* ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <cassert>
#include <cstring>
#include <iostream>
#include "talk/base/natsocketfactory.h"
namespace talk_base {
// Socket that relays its traffic through a NAT server.
//
// Every outgoing datagram is sent to the NAT server address with the intended
// destination address written in front of the payload (see Encode).  Every
// incoming datagram is expected to come from the NAT server and to carry the
// real sender's address in front of the payload (see Decode).
//
// The object owns the wrapped socket and deletes it on destruction.
class NATSocket : public AsyncSocket {
 public:
  // Wraps a synchronous socket; no read/write signals are forwarded.
  NATSocket(Socket* socket, const SocketAddress& server_addr)
      : async_(false), connected_(false), server_addr_(server_addr),
        socket_(socket), buf_(0), size_(0) {
  }

  // Wraps an asynchronous socket and re-emits its read/write events as events
  // of this object.
  NATSocket(AsyncSocket* socket, const SocketAddress& server_addr)
      : async_(true), connected_(false), server_addr_(server_addr),
        socket_(socket), buf_(0), size_(0) {
    socket->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
    socket->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
  }

  virtual ~NATSocket() {
    delete socket_;
    // buf_ is allocated with new[] in Grow(), so it must be released with
    // delete[] (scalar delete on an array is undefined behaviour).
    delete[] buf_;
  }

  SocketAddress GetLocalAddress() const {
    return socket_->GetLocalAddress();
  }

  SocketAddress GetRemoteAddress() const {
    return remote_addr_;  // will be ANY if not connected
  }

  int Bind(const SocketAddress& addr) {
    return socket_->Bind(addr);
  }

  // Records |addr| as the peer for Send()/Recv() filtering.  Nothing is sent
  // to the wrapped socket here; always returns 0.
  int Connect(const SocketAddress& addr) {
    connected_ = true;
    remote_addr_ = addr;
    return 0;
  }

  // Sends |cb| bytes to the address given to Connect().  Must only be called
  // after Connect().
  int Send(const void *pv, size_t cb) {
    assert(connected_);
    return SendInternal(pv, cb, remote_addr_);
  }

  // Sends |cb| bytes to |addr|.  Must only be called while not connected.
  int SendTo(const void *pv, size_t cb, const SocketAddress& addr) {
    assert(!connected_);
    return SendInternal(pv, cb, addr);
  }

  // Builds a packet of [address header | payload] and sends it to the NAT
  // server.  Returns the wrapped socket's negative result on failure,
  // otherwise the number of payload bytes sent (header excluded).
  int SendInternal(const void *pv, size_t cb, const SocketAddress& addr) {
    size_t size = cb + addr.Size_();
    char* buf = new char[size];
    Encode(static_cast<const char*>(pv), cb, buf, size, addr);
    int result = socket_->SendTo(buf, size, server_addr_);
    delete[] buf;  // allocated with new[] above
    if (result < 0) {
      return result;
    } else {
      assert(result == static_cast<int>(size));  // TODO: This isn't fair.
      return static_cast<int>(static_cast<size_t>(result) - addr.Size_());
    }
  }

  // Same as RecvFrom() but discards the sender address.
  int Recv(void *pv, size_t cb) {
    SocketAddress addr;
    return RecvFrom(pv, cb, &addr);
  }

  // Reads one packet from the wrapped socket, strips the address header and
  // copies the payload into |pv|.  Returns the wrapped socket's negative
  // result on failure, the payload size on success, or 0 when connected and
  // the packet's real sender is not the connected peer.  |paddr| may be null.
  int RecvFrom(void *pv, size_t cb, SocketAddress *paddr) {
    // Make sure we have enough room to read the requested amount plus the
    // header address.
    SocketAddress remote_addr;
    Grow(cb + remote_addr.Size_());

    // Read the packet from the socket.
    int result = socket_->RecvFrom(buf_, size_, &remote_addr);
    if (result < 0)
      return result;
    assert(remote_addr == server_addr_);

    // TODO: we need better framing so that we know how many bytes we can
    // return before we need to read the next address. For UDP, this will be
    // fine as long as the reader always reads everything in the packet.
    assert(static_cast<size_t>(result) < size_);

    // Decode the wire packet into the actual results.
    SocketAddress real_remote_addr;
    size_t real_size = cb;
    Decode(buf_, result, pv, &real_size, &real_remote_addr);

    // Make sure this packet should be delivered before returning it.
    if (!connected_ || (real_remote_addr == remote_addr_)) {
      if (paddr)
        *paddr = real_remote_addr;
      return static_cast<int>(real_size);
    } else {
      std::cerr << "Dropping packet from unknown remote address: "
                << real_remote_addr.ToString() << std::endl;
      return 0;  // Tell the caller we didn't read anything
    }
  }

  // Forgets the connected peer and closes the wrapped socket.
  int Close() {
    connected_ = false;
    remote_addr_ = SocketAddress();
    return socket_->Close();
  }

  int Listen(int backlog) {
    assert(false);  // not yet implemented
    return 0;
  }

  Socket* Accept(SocketAddress *paddr) {
    assert(false);  // not yet implemented
    return 0;
  }

  // Returns the wrapped socket as an AsyncSocket.  Only valid when this
  // object was constructed from an AsyncSocket.
  AsyncSocket* asyncsocket() {
    assert(async_);
    return static_cast<AsyncSocket*>(socket_);
  }

  int GetError() const { return socket_->GetError(); }
  void SetError(int error) { socket_->SetError(error); }

  // Reports CS_CONNECTED after Connect() and CS_CLOSED otherwise.
  ConnState GetState() const { return connected_ ? CS_CONNECTED : CS_CLOSED; }

  virtual int EstimateMTU(uint16* mtu) { return socket_->EstimateMTU(mtu); }
  virtual int SetOption(Option opt, int value) {
    return socket_->SetOption(opt, value);
  }

  // Forwards the wrapped socket's read event as this socket's read event.
  void OnReadEvent(AsyncSocket* socket) {
    assert(socket == socket_);
    SignalReadEvent(this);
  }

  // Forwards the wrapped socket's write event as this socket's write event.
  void OnWriteEvent(AsyncSocket* socket) {
    assert(socket == socket_);
    SignalWriteEvent(this);
  }

 private:
  // Makes sure the buffer is at least the given size.  Existing contents are
  // not preserved when the buffer is reallocated.
  void Grow(size_t new_size) {
    if (size_ < new_size) {
      delete[] buf_;  // array form: buf_ always comes from new char[]
      size_ = new_size;
      buf_ = new char[size_];
    }
  }

  // Encodes the given data and intended remote address into a packet to send
  // to the NAT server.  |buf| must be exactly header size + |data_size| bytes.
  void Encode(const char* data, size_t data_size, char* buf, size_t buf_size,
              const SocketAddress& remote_addr) {
    assert(buf_size == data_size + remote_addr.Size_());
    remote_addr.Write_(buf, static_cast<int>(buf_size));
    std::memcpy(buf + remote_addr.Size_(), data, data_size);
  }

  // Decodes the given packet from the NAT server into the actual remote
  // address and data.  On entry |*buf_size| is the capacity of |buf|; on exit
  // it is the number of payload bytes copied.
  // NOTE(review): the size checks below are assert-only, so a build with
  // NDEBUG would not reject a packet shorter than the header or a payload
  // larger than |buf| -- consider a runtime check.
  void Decode(const char* data, size_t data_size, void* buf, size_t* buf_size,
              SocketAddress* remote_addr) {
    assert(data_size >= remote_addr->Size_());
    assert(data_size <= *buf_size + remote_addr->Size_());
    remote_addr->Read_(data, static_cast<int>(data_size));
    *buf_size = data_size - remote_addr->Size_();
    std::memcpy(buf, data + remote_addr->Size_(), *buf_size);
  }

  bool async_;                 // true when socket_ is really an AsyncSocket
  bool connected_;             // true between Connect() and Close()
  SocketAddress remote_addr_;  // peer set by Connect()
  SocketAddress server_addr_;  // address of the NAT server
  Socket* socket_;             // wrapped socket, owned
  char* buf_;                  // receive scratch buffer, owned (new[])
  size_t size_;                // capacity of buf_ in bytes
};
// Stores the underlying socket factory and the NAT server address; both are
// used later by CreateSocket() / CreateAsyncSocket().
NATSocketFactory::NATSocketFactory(SocketFactory* factory,
                                   const SocketAddress& nat_addr)
    : factory_(factory),
      nat_addr_(nat_addr) {}
// Creates a synchronous socket from the underlying factory and wraps it in a
// NATSocket that relays through the configured NAT server address.
Socket* NATSocketFactory::CreateSocket(int type) {
  // TCP is not yet supported.
  assert(type == SOCK_DGRAM);

  NATSocket* nat_socket =
      new NATSocket(factory_->CreateSocket(type), nat_addr_);
  return nat_socket;
}
// Creates an asynchronous socket from the underlying factory and wraps it in
// a NATSocket that relays through the configured NAT server address.
AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) {
  // TCP is not yet supported.
  assert(type == SOCK_DGRAM);

  NATSocket* nat_socket =
      new NATSocket(factory_->CreateAsyncSocket(type), nat_addr_);
  return nat_socket;
}
} // namespace talk_base