using System;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Runtime.InteropServices;
namespace ServerLib
{
public class HttpServer : IServer
{
    /// <summary>
    /// Read/write timeout, in milliseconds, applied to SSL connections (including the handshake).
    /// </summary>
    private const int SslTimeoutMilliseconds = 10000;

    /// <summary>
    /// IP address the server listens on.
    /// </summary>
    public string ServerIP { get; private set; }
    /// <summary>
    /// Port the server listens on.
    /// </summary>
    public int ServerPort { get; private set; }
    /// <summary>
    /// Server root directory. Falls back to the application base directory when the
    /// requested directory does not exist.
    /// </summary>
    public string ServerRoot { get; private set; }

    // Written by Stop() (typically from another thread) and read by the accept loop in Start(),
    // so the backing field is volatile to guarantee the loop observes the change.
    private volatile bool isRunning;
    /// <summary>
    /// Whether the server's accept loop is running.
    /// </summary>
    public bool IsRunning
    {
        get { return isRunning; }
        private set { isRunning = value; }
    }
    /// <summary>
    /// Protocol in use; set by <see cref="Start"/> to Https when a certificate was configured, otherwise Http.
    /// </summary>
    public Protocols Protocol { get; private set; }
    /// <summary>
    /// Listening socket; created by <see cref="Start"/>.
    /// </summary>
    private TcpListener serverListener;
    /// <summary>
    /// Optional logger; when null, log messages are discarded.
    /// </summary>
    public ILogger Logger { get; set; }
    /// <summary>
    /// SSL certificate; when non-null, every accepted connection is wrapped in TLS.
    /// </summary>
    private X509Certificate serverCertificate = null;

    /// <summary>
    /// Creates a server bound to the given address, port and root directory.
    /// </summary>
    /// <param name="ipAddress">IP address to listen on.</param>
    /// <param name="port">Port number.</param>
    /// <param name="root">Root directory; the application base directory is used if it does not exist.</param>
    private HttpServer(IPAddress ipAddress, int port, string root)
    {
        this.ServerIP = ipAddress.ToString();
        this.ServerPort = port;
        this.ServerRoot = ResolveRoot(root);
    }
    /// <summary>
    /// Creates a server bound to the given address, port and root directory.
    /// </summary>
    /// <param name="ipAddress">IP address to listen on (parsed with <see cref="IPAddress.Parse(string)"/>).</param>
    /// <param name="port">Port number.</param>
    /// <param name="root">Root directory; the application base directory is used if it does not exist.</param>
    public HttpServer(string ipAddress, int port, string root) :
        this(IPAddress.Parse(ipAddress), port, root)
    {
    }
    /// <summary>
    /// Creates a server bound to the given address and port, rooted at the application base directory.
    /// </summary>
    /// <param name="ipAddress">IP address to listen on.</param>
    /// <param name="port">Port number.</param>
    public HttpServer(string ipAddress, int port) :
        this(IPAddress.Parse(ipAddress), port, AppDomain.CurrentDomain.BaseDirectory)
    {
    }
    /// <summary>
    /// Creates a server bound to the loopback address with the given port and root directory.
    /// </summary>
    /// <param name="port">Port number.</param>
    /// <param name="root">Root directory; the application base directory is used if it does not exist.</param>
    public HttpServer(int port, string root) :
        this(IPAddress.Loopback, port, root)
    {
    }
    /// <summary>
    /// Creates a server bound to the loopback address with the given port, rooted at the application base directory.
    /// </summary>
    /// <param name="port">Port number.</param>
    public HttpServer(int port) :
        this(IPAddress.Loopback, port, AppDomain.CurrentDomain.BaseDirectory)
    {
    }
    /// <summary>
    /// Creates a server bound to the given address on port 80, rooted at the application base directory.
    /// </summary>
    /// <param name="ip">IP address to listen on.</param>
    public HttpServer(string ip) :
        this(IPAddress.Parse(ip), 80, AppDomain.CurrentDomain.BaseDirectory)
    {
    }

    #region Public methods
    /// <summary>
    /// Starts listening and runs the accept loop on the calling thread until <see cref="Stop"/> is called
    /// or accepting fails. Each accepted connection is handled on its own thread.
    /// </summary>
    public void Start()
    {
        if (IsRunning) return;
        // Keep a local reference so cleanup below always targets the listener this call created.
        TcpListener listener = new TcpListener(IPAddress.Parse(ServerIP), ServerPort);
        this.serverListener = listener;
        this.Protocol = serverCertificate == null ? Protocols.Http : Protocols.Https;
        // Start listening before flagging as running, so a bind failure (e.g. port already in use)
        // does not leave IsRunning stuck at true.
        listener.Start();
        this.IsRunning = true;
        this.Log(string.Format("Sever is running at {0}://{1}:{2}", Protocol.ToString().ToLowerInvariant(), ServerIP,
            ServerPort));
        try
        {
            while (IsRunning)
            {
                TcpClient client = listener.AcceptTcpClient();
                Thread requestThread = new Thread(() => ProcessRequest(client));
                requestThread.Start();
            }
        }
        catch (Exception e)
        {
            // Stop() closes the listener, which makes AcceptTcpClient throw; that is the normal
            // shutdown path, so only report exceptions that happen while still meant to be running.
            if (IsRunning) Log(e.Message);
        }
        finally
        {
            // Release the socket and allow a later restart even if the loop died unexpectedly.
            listener.Stop();
            if (ReferenceEquals(serverListener, listener)) IsRunning = false;
        }
    }
    /// <summary>
    /// Enables HTTPS using the certificate stored in the given file.
    /// </summary>
    /// <param name="certificate">Path to the certificate file.</param>
    /// <returns>This server, for chaining.</returns>
    public HttpServer SetSSL(string certificate)
    {
        return SetSSL(X509Certificate.CreateFromCertFile(certificate));
    }
    /// <summary>
    /// Enables HTTPS using the given certificate (takes effect on the next <see cref="Start"/>).
    /// </summary>
    /// <param name="certifiate">Server certificate.</param>
    /// <returns>This server, for chaining.</returns>
    public HttpServer SetSSL(X509Certificate certifiate)
    {
        this.serverCertificate = certifiate;
        return this;
    }
    /// <summary>
    /// Stops the accept loop and closes the listening socket. Does nothing if not running.
    /// </summary>
    public void Stop()
    {
        if (!IsRunning) return;
        IsRunning = false;
        serverListener.Stop();
    }
    /// <summary>
    /// Sets the server root directory, falling back to the application base directory
    /// when <paramref name="root"/> does not exist.
    /// </summary>
    /// <param name="root">Requested root directory.</param>
    /// <returns>This server, for chaining.</returns>
    public HttpServer SetRoot(string root)
    {
        this.ServerRoot = ResolveRoot(root);
        return this;
    }
    /// <summary>
    /// Gets the server root directory.
    /// </summary>
    /// <returns>The current root directory.</returns>
    public string GetRoot()
    {
        return this.ServerRoot;
    }
    /// <summary>
    /// Sets the port (takes effect on the next <see cref="Start"/>).
    /// </summary>
    /// <param name="port">Port number.</param>
    /// <returns>This server, for chaining.</returns>
    public HttpServer SetPort(int port)
    {
        this.ServerPort = port;
        return this;
    }
    #endregion

    #region Internal methods
    /// <summary>
    /// Returns <paramref name="root"/> if it names an existing directory, otherwise the application base directory.
    /// </summary>
    private static string ResolveRoot(string root)
    {
        return Directory.Exists(root) ? root : AppDomain.CurrentDomain.BaseDirectory;
    }
    /// <summary>
    /// Handles one client connection: optional TLS, request parsing, dispatch by HTTP method,
    /// and finally closing the connection.
    /// </summary>
    /// <param name="handler">Accepted client socket.</param>
    private void ProcessRequest(TcpClient handler)
    {
        Stream clientStream = null;
        try
        {
            clientStream = handler.GetStream();
            if (serverCertificate != null) clientStream = ProcessSSL(clientStream);
            if (clientStream == null) return;
            HttpRequest request = new HttpRequest(clientStream);
            request.Logger = Logger;
            HttpResponse response = new HttpResponse(clientStream);
            response.Logger = Logger;
            switch (request.Method)
            {
                case "GET":
                    OnGet(request, response);
                    break;
                case "POST":
                    OnPost(request, response);
                    break;
                default:
                    OnDefault(request, response);
                    break;
            }
        }
        catch (Exception e)
        {
            // This runs on a dedicated Thread: an unhandled exception here would terminate the process.
            Log(e.Message);
        }
        finally
        {
            // Release the connection once the handler has finished with it.
            if (clientStream != null) clientStream.Close();
            handler.Close();
        }
    }
    /// <summary>
    /// Wraps the client stream in an <see cref="SslStream"/> and performs the server-side TLS handshake.
    /// </summary>
    /// <param name="clientStream">Raw network stream.</param>
    /// <returns>The authenticated SSL stream, or null if the handshake failed (the raw stream is then closed).</returns>
    private Stream ProcessSSL(Stream clientStream)
    {
        try
        {
            SslStream sslStream = new SslStream(clientStream);
            // Set timeouts before the handshake so a client that never completes it cannot block this thread forever.
            sslStream.ReadTimeout = SslTimeoutMilliseconds;
            sslStream.WriteTimeout = SslTimeoutMilliseconds;
            // TLS 1.0 is deprecated and insecure; require TLS 1.2.
            sslStream.AuthenticateAsServer(serverCertificate, false, SslProtocols.Tls12, true);
            return sslStream;
        }
        catch (Exception e)
        {
            Log(e.Message);
            clientStream.Close();
        }
        return null;
    }
    /// <summary>
    /// Writes a message to <see cref="Logger"/> if one is set.
    /// </summary>
    /// <param name="message">Log message.</param>
    protected void Log(object message)
    {
        // Copy to a local so a concurrent reset of Logger to null cannot race the null check.
        ILogger logger = Logger;
        if (logger != null) logger.Log(message);
    }
    #endregion

    #region Virtual methods
    /// <summary>
    /// Handles a GET request. Override to implement; the default does nothing.
    /// </summary>
    /// <param name="request">Parsed request.</param>
    /// <param name="response">Response bound to the same client stream.</param>
    public virtual void OnGet(HttpRequest request, HttpResponse response)
    {
    }
    /// <summary>
    /// Handles a POST request. Override to implement; the default does nothing.
    /// </summary>
    /// <param name="request">Parsed request.</param>
    /// <param name="response">Response bound to the same client stream.</param>
    public virtual void OnPost(HttpRequest request, HttpResponse response)
    {
    }
    /// <summary>
    /// Handles any request whose method is neither GET nor POST. Override to implement; the default does nothing.
    /// </summary>
    /// <param name="request">Parsed request.</param>
    /// <param name="response">Response bound to the same client stream.</param>
    public virtual void OnDefault(HttpRequest request, HttpResponse response)
    {
    }
    #endregion
}
}