当前位置: 首页 > news >正文

netty-websocket扩展协议及token鉴权补充

文章源码:gitee
源码部分可以看上一篇文章中的源码分析netty-websocket 鉴权token及统一请求和响应头(鉴权控制器)

最近刚好没事,看到有朋友说自定义协议好搞,我就想了想,发现上面那种方式实现确实麻烦,而且兼容性还不行,后来我对照着WebSocketServerProtocolHandler试了试扩展一下,将WebSocketServerProtocolHandler中handlerAdded添加的握手逻辑换成自己的,终于测通了,我用postman测试时,请求头也可以自定义,下面上代码

1.(userEventTriggered): 鉴权成功后可以抛出自定义事件,业务channel中实现 事件监听器userEventTriggered,这样就可以在鉴权成功后,握手成功前执行某个方法,比如验证权限啥的,具体可看SecurityHandler中的例子
2. (exceptionCaught): 异常捕获
3. channel设置attr实现channel上下文的数据属性
4. …等等

扩展WebSocketProtocolHandler

这个协议有很多私有方法外部引用不了,所以只能copy一份出来,主要是把handlerAdded这个方法重写了,将原有的‘WebSocketServerProtocolHandshakeHandler’替换为‘自己的(SecurityHandler)’

package com.chat.nettywebsocket.handler.test;import io.netty.channel.*;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.Utf8FrameValidator;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.AttributeKey;import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;/***** @author qb* @date 2024/2/5 8:53* @version 1.0*/
public class CustomWebSocketServerProtocolHandler extends WebSocketServerProtocolHandler {public enum ServerHandshakeStateEvent {/*** The Handshake was completed successfully and the channel was upgraded to websockets.** @deprecated in favor of {@link WebSocketServerProtocolHandler.HandshakeComplete} class,* it provides extra information about the handshake*/@DeprecatedHANDSHAKE_COMPLETE}/*** The Handshake was completed successfully and the channel was upgraded to websockets.*/public static final class HandshakeComplete {private final String requestUri;private final HttpHeaders requestHeaders;private final String selectedSubprotocol;HandshakeComplete(String requestUri, HttpHeaders requestHeaders, String selectedSubprotocol) {this.requestUri = requestUri;this.requestHeaders = requestHeaders;this.selectedSubprotocol = selectedSubprotocol;}public String requestUri() {return requestUri;}public HttpHeaders requestHeaders() {return requestHeaders;}public String selectedSubprotocol() {return selectedSubprotocol;}}public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {super(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith);this.websocketPath = websocketPath;this.subprotocols = subprotocols;this.allowExtensions = allowExtensions;maxFramePayloadLength = maxFrameSize;this.allowMaskMismatch = allowMaskMismatch;this.checkStartsWith = checkStartsWith;}private final String websocketPath;private final String subprotocols;private final boolean allowExtensions;private final int maxFramePayloadLength;private final boolean allowMaskMismatch;private final boolean checkStartsWith;@Overridepublic void handlerAdded(ChannelHandlerContext ctx) {System.err.println("handlerAdded");ChannelPipeline cp = ctx.pipeline();if (cp.get(SecurityHandler.class) == null) {// Add the WebSocketHandshakeHandler before this one.// 增加协议实现handlerctx.pipeline().addBefore(ctx.name(), SecurityHandler.class.getName(),new SecurityHandler(websocketPath, subprotocols,allowExtensions, maxFramePayloadLength, allowMaskMismatch, checkStartsWith));}if (cp.get(Utf8FrameValidator.class) == null) {// Add the UFT8 checking before this one.ctx.pipeline().addBefore(ctx.name(), Utf8FrameValidator.class.getName(),new Utf8FrameValidator());}}private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");static WebSocketServerHandshaker getHandshaker(Channel channel) {return channel.attr(HANDSHAKER_ATTR_KEY).get();}static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) {channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker);}static ChannelHandler forbiddenHttpRequestResponder() {return new ChannelInboundHandlerAdapter() {@Overridepublic void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {if (msg instanceof FullHttpRequest) {((FullHttpRequest) msg).release();FullHttpResponse response =new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN);ctx.channel().writeAndFlush(response);} else {ctx.fireChannelRead(msg);}}};}}

SecurityHandler

复制的WebSocketServerProtocolHandshakeHandler的方法,就是改了请求头逻辑和发布事件的相关类调整

package com.chat.nettywebsocket.handler.test;import com.chat.nettywebsocket.handler.test.CustomWebSocketServerProtocolHandler;
import io.netty.channel.*;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.ssl.SslHandler;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;import static com.chat.nettywebsocket.handler.AttributeKeyUtils.SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY;
import static io.netty.handler.codec.http.HttpMethod.GET;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpUtil.isKeepAlive;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;/***** @author qb* @date 2024/2/5 8:37* @version 1.0*/
@Slf4j
@ChannelHandler.Sharable
public class SecurityHandler extends ChannelInboundHandlerAdapter {private final String websocketPath;private final String subprotocols;private final boolean allowExtensions;private final int maxFramePayloadSize;private final boolean allowMaskMismatch;private final boolean checkStartsWith;SecurityHandler(String websocketPath, String subprotocols,boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false);}SecurityHandler(String websocketPath, String subprotocols,boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {this.websocketPath = websocketPath;this.subprotocols = subprotocols;this.allowExtensions = allowExtensions;maxFramePayloadSize = maxFrameSize;this.allowMaskMismatch = allowMaskMismatch;this.checkStartsWith = checkStartsWith;}@Overridepublic void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {final FullHttpRequest req = (FullHttpRequest) msg;if (isNotWebSocketPath(req)) {ctx.fireChannelRead(msg);return;}try {if (req.method() != GET) {sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));return;}// 比如 此处极权成功就抛出成功事件SecurityCheckComplete complete = new SecurityHandler.SecurityCheckComplete(true);// 设置 channel属性,相当于channel固定的上下文属性ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);ctx.fireUserEventTriggered(complete);final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,allowExtensions, maxFramePayloadSize, allowMaskMismatch);final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);if (handshaker == null) {WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());} else {String s = req.headers().get("Sec-token");HttpHeaders httpHeaders = null;if(StringUtils.hasText(s)){httpHeaders = new DefaultHttpHeaders().add("Sec-token",s);}else {httpHeaders = new DefaultHttpHeaders();}// 设置请求头final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(),req, httpHeaders,ctx.channel().newPromise());System.err.println("handshakeFuture: "+handshakeFuture.isSuccess());handshakeFuture.addListener(new ChannelFutureListener() {@Overridepublic void operationComplete(ChannelFuture future) throws Exception {if (!future.isSuccess()) {ctx.fireExceptionCaught(future.cause());} else {// Kept for compatibilityctx.fireUserEventTriggered(CustomWebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);ctx.fireUserEventTriggered(new CustomWebSocketServerProtocolHandler.HandshakeComplete(req.uri(), req.headers(), handshaker.selectedSubprotocol()));}}});CustomWebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);ctx.pipeline().replace(this, "WS403Responder",CustomWebSocketServerProtocolHandler.forbiddenHttpRequestResponder());}} finally {req.release();}}private boolean isNotWebSocketPath(FullHttpRequest req) {return checkStartsWith ? !req.uri().startsWith(websocketPath) : !req.uri().equals(websocketPath);}private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {ChannelFuture f = ctx.channel().writeAndFlush(res);if (!isKeepAlive(req) || res.status().code() != 200) {f.addListener(ChannelFutureListener.CLOSE);}}private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {String protocol = "ws";if (cp.get(SslHandler.class) != null) {// SSL in use so use Secure WebSocketsprotocol = "wss";}String host = req.headers().get(HttpHeaderNames.HOST);return protocol + "://" + host + path;}// 自定义事件实体@Getter@AllArgsConstructorpublic static final class SecurityCheckComplete {private Boolean isLogin;}
}

ChatHandler

package com.chat.nettywebsocket.handler;import com.alibaba.fastjson.JSONObject;
import com.chat.nettywebsocket.domain.Message;
import com.chat.nettywebsocket.handler.test.CustomWebSocketServerProtocolHandler;
import com.chat.nettywebsocket.handler.test.SecurityHandler;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelId;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.AttributeKey;
import lombok.extern.slf4j.Slf4j;import java.nio.charset.StandardCharsets;/*** 自定义控制器* @author qubing* @date 2021/8/16 9:26*/
@Slf4j
public class ChatHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {/*** 为channel添加属性  将userid设置为属性,避免客户端特殊情况退出时获取不到userid*/AttributeKey<Integer> userid = AttributeKey.valueOf("userid");/*** 连接时* @param ctx 上下文* @throws Exception /*/@Overridepublic void channelActive(ChannelHandlerContext ctx) throws Exception {log.info("与客户端建立连接,通道开启!");// 添加到channelGroup通道组MyChannelHandlerPool.channelGroup.add(ctx.channel());}/*** 断开连接时* @param ctx /* @throws Exception /*/@Overridepublic void channelInactive(ChannelHandlerContext ctx) throws Exception {log.info("与客户端断开连接,通道关闭!");// 从channelGroup通道组移除
//        MyChannelHandlerPool.channelGroup.remove(ctx.channel());
//        Integer useridQuit = ctx.channel().attr(userid).get();
//        MyChannelHandlerPool.channelIdMap.remove(useridQuit);log.info("断开的用户id为");}// 监听事件@Overridepublic void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {// 自定义鉴权成功事件if (evt instanceof SecurityHandler.SecurityCheckComplete){// 鉴权成功后的逻辑log.info("鉴权成功  SecurityHandler.SecurityCheckComplete");}// 握手成功else if (evt instanceof CustomWebSocketServerProtocolHandler.HandshakeComplete) {log.info("Handshake has completed");// 握手成功后的逻辑  鉴权和不鉴权模式都绑定channellog.info("Handshake has completed after binding channel");}super.userEventTriggered(ctx, evt);}/*** 获取消息时* @param ctx /* @param msg 消息* @throws Exception /*/@Overrideprotected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception {String mssage = msg.content().toString(StandardCharsets.UTF_8);ctx.channel().writeAndFlush(mssage);System.err.println(mssage);}/*** 群发所有人*/private void sendAllMessage(String message){//收到信息后,群发给所有channelMyChannelHandlerPool.channelGroup.writeAndFlush( new TextWebSocketFrame(message));}@Overridepublic void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {log.info("exceptionCaught 异常:{}",cause.getMessage());cause.printStackTrace();Channel channel = ctx.channel();//……if(channel.isActive()){log.info("手动关闭通道");ctx.close();};}
}

AttributeKeyUtils

public class AttributeKeyUtils {/*** 为channel添加属性  将userid设置为属性,避免客户端特殊情况退出时获取不到userid*/public static final AttributeKey<String> USER_ID = AttributeKey.valueOf("userid");public static final AttributeKey<SecurityHandler.SecurityCheckComplete> SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY =AttributeKey.valueOf("SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY");}

WsServerInitializer

@Slf4j
@ChannelHandler.Sharable
public class WsServerInitializer extends ChannelInitializer<SocketChannel> {//    @Override
//    protected void initChannel(SocketChannel socketChannel) throws Exception {
//        log.info("有新的连接");
//        ChannelPipeline pipeline = socketChannel.pipeline();
//        //netty 自带的http解码器
//        pipeline.addLast(new HttpServerCodec());
//        //http聚合器
//        pipeline.addLast(new HttpObjectAggregator(8192));
//        pipeline.addLast(new ChunkedWriteHandler());
//        //压缩协议
//        pipeline.addLast(new WebSocketServerCompressionHandler());
//        //http处理器 用来握手和执行进一步操作
        pipeline.addLast(new NettyWebsocketHttpHandler(config, listener));
//
//    }@Overrideprotected void initChannel(SocketChannel ch) throws Exception {log.info("有新的连接");//获取工人所要做的工程(管道器==管道器对应的便是管道channel)ChannelPipeline pipeline = ch.pipeline();//为工人的工程按顺序添加工序/材料 (为管道器设置对应的handler也就是控制器)//1.设置心跳机制pipeline.addLast(new IdleStateHandler(5,0,0, TimeUnit.SECONDS));//2.出入站时的控制器,大部分用于针对心跳机制pipeline.addLast(new WsChannelDupleHandler());//3.加解码pipeline.addLast(new HttpServerCodec());//3.打印控制器,为工人提供明显可见的操作结果的样式pipeline.addLast("logging", new LoggingHandler(LogLevel.INFO));pipeline.addLast(new ChunkedWriteHandler());pipeline.addLast(new HttpObjectAggregator(8192));// 扩展的websocket协议pipeline.addLast(new CustomWebSocketServerProtocolHandler("/ws","websocket",true,65536 * 10,false,true));//7.自定义的handler针对业务pipeline.addLast(new ChatHandler());}
}

上截图

postman测试怎增加自定义请求头

在这里插入图片描述

点击链接查看控制台

postman链接成功
在这里插入图片描述
根据日志可以看出,链接成功并且相应和请求的头是一致的
在这里插入图片描述

发送消息

在这里插入图片描述

在这里插入图片描述

相关文章:

  • 多线程(一)
  • 【力扣】查找总价格为目标值的两个商品,双指针法
  • Mac 下JDK环境变量配置 及 JDK多版本切换
  • 吉他学习:识谱,认识节奏,视唱节奏,节拍器的使用
  • 2402d,d的静态构造器
  • 多线程基础详解(看到就是赚到)
  • 预测模型:MATLAB线性回归
  • 在 VMware 虚拟机上安装 CentOS系统 完整(全图文)教程
  • K8S之Pod常见的状态和重启策略
  • 人工智能之无约束最优化与有约束最优化
  • C# Task的使用
  • 编码技巧——基于RedisTemplate的RedisClient实现、操作Lua脚本
  • python二维数组初始化的一个极其隐蔽的bug(浅拷贝)
  • Win32 SDK Gui编程系列之--ListView自绘OwnerDraw(续)
  • 幻兽帕鲁(Palworld)允许自建私服,它是怎么挣钱的呢?
  • iOS动画编程-View动画[ 1 ] 基础View动画
  • iOS帅气加载动画、通知视图、红包助手、引导页、导航栏、朋友圈、小游戏等效果源码...
  • leetcode388. Longest Absolute File Path
  • Linux中的硬链接与软链接
  • PHP的类修饰符与访问修饰符
  • SpingCloudBus整合RabbitMQ
  • sublime配置文件
  • supervisor 永不挂掉的进程 安装以及使用
  • vue 配置sass、scss全局变量
  • VuePress 静态网站生成
  • Vue学习第二天
  • 包装类对象
  • 前端面试题总结
  • - 转 Ext2.0 form使用实例
  • Unity3D - 异步加载游戏场景与异步加载游戏资源进度条 ...
  • ​MPV,汽车产品里一个特殊品类的进化过程
  • # C++之functional库用法整理
  • #HarmonyOS:Web组件的使用
  • #pragma data_seg 共享数据区(转)
  • #我与Java虚拟机的故事#连载18:JAVA成长之路
  • ()、[]、{}、(())、[[]]等各种括号的使用
  • (09)Hive——CTE 公共表达式
  • (27)4.8 习题课
  • (C语言)fgets与fputs函数详解
  • (二)斐波那契Fabonacci函数
  • (二十五)admin-boot项目之集成消息队列Rabbitmq
  • (附源码)php新闻发布平台 毕业设计 141646
  • (六)激光线扫描-三维重建
  • (续)使用Django搭建一个完整的项目(Centos7+Nginx)
  • (转)es进行聚合操作时提示Fielddata is disabled on text fields by default
  • (转)linux自定义开机启动服务和chkconfig使用方法
  • .bat批处理(六):替换字符串中匹配的子串
  • .NET 6 在已知拓扑路径的情况下使用 Dijkstra,A*算法搜索最短路径
  • .NET6实现破解Modbus poll点表配置文件
  • .net开发时的诡异问题,button的onclick事件无效
  • [ C++ ] STL_vector -- 迭代器失效问题
  • [C++][数据结构][算法]单链式结构的深拷贝
  • [codevs 2822] 爱在心中 【tarjan 算法】
  • [C语言][PTA基础C基础题目集] strtok 函数的理解与应用
  • [Docker]四.Docker部署nodejs项目,部署Mysql,部署Redis,部署Mongodb