2022-09-26业务场景00
请注意,本文编写于 549 天前,最后修改于 549 天前,其中某些信息可能已经过时。

一、WebSocket 协议

WebSocket 是个好东西,为我们提供了便捷且实时的通讯能力。然而,对于 WebSocket 客户端的鉴权,协议的 RFC 是这么说的:

This protocol doesn’t prescribe any particular way that servers can authenticate clients during the WebSocket handshake. The WebSocket server can use any client authentication mechanism available to a generic HTTP server, such as cookies, HTTP authentication, or TLS authentication.

也就是说,鉴权需要自己动手

二、协议原理

WebSocket 是独立的、创建在 TCP 上的协议。

为了创建 Websocket 连接,需要通过浏览器发出请求,之后服务器进行回应,这个过程通常称为“握手”。

实现步骤

  • 1)发起请求的浏览器端,发出协商报文
GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Origin: http://example.com
Sec-WebSocket-Protocol: chat, superchat
Sec-WebSocket-Version: 13
  • 2)服务端响应 101 状态码(即切换到 socket 通讯方式),其报文:
HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
Sec-WebSocket-Protocol: chat
  • 3)协议切换完成,双方使用 Socket 通讯

三、鉴权授权实现方案

通过对协议实现的解读可知:在 HTTP 切换到 Socket 之前,没有什么好的机会进行鉴权,因为在这个时间节点,报文(或者说请求的 Headers)必须遵守协议规范。但这不妨碍我们在协议切换完成后,进行鉴权授权:

3.1 鉴权

  1. 在连接建立时,检查连接的 HTTP 请求头信息(比如 cookies 中关于用户的身份信息)
  2. 在每次接收到信息时,检查连接是否已经授权过,及授权是否过期
  3. 以上两点,只要答案为否,则服务端主动关闭 socket 连接

3.2 授权

服务端在连接建立时,颁发一个 ticket 给 peer 端,这个 ticket 可以包含但不限于:

  • peer 端的 uniqueId(可以是 ip,userid,deviceid…任一种具备唯一性的键)
  • 过期时间的 timestamp
  • token:由以上信息生成的哈希值,最好能加盐

3.3 安全性补充说明

比如,这一套机制如何防范重放攻击?个人认为可以从以下几点出发:

  • 可以用这里提到的 expires,保证过期,如果你愿意,甚至可以每次下发消息时都发送一个新的 Ticket,只要上传消息对不上这个 Ticket,就断开,这样非 Original Peer 是没法重放的
  • 可以结合 redis,实现 ratelimit,防止高频刷接口,可以参考 express-rate-limit
  • 为防止中间人,最好使用 wss(TLS)

3.4 代码实现

WebSocket 连接处理,基于 Node.js 的 ws 实现

import url from "url";
import WebSocket from "ws";
import debug from "debug";
import moment from "moment";
import { Ticket } from "../models";

const debugInfo = debug("server:global");

// server 可以是http server 实例
const wss = new WebSocket.Server({ server });
wss.on("connection", async (ws) => {
  const location = url.parse(ws.upgradeReq.url, true);
  const cookie = ws.upgradeReq.cookie;
  debugInfo("ws request from:" + location, "cookies:", cookie);
  // issue & send ticket to the peer
  if (!checkIdentify(ws)) {
    terminate(ws);
  } else {
    const ticket = issueTicket(ws);
    await ticket.save();
    ws.send(ticket.pojo());
    ws.on("message", (message) => {
      if (!checkTicket(ws, message)) {
        terminate(ws);
      }
      debugInfo("received:%s", message);
    });
  }
});
function issueTicket(ws) {
  const uniqueId = ws.upgradeReq.connection.remoteAddress;
  return new Ticket(uniqueId);
}
async function checkTicket(ws, message) {
  const uniqueId = ws.upgrade.connection.remoteAddress;
  const record = await Ticket.get(uniqueId);
  const token = message.token;
  return (
    record &&
    record.expires &&
    record.token &&
    record.token === token &&
    moment(record.expires) >= moment()
  );
}
// 身份检查,可填入具体检查逻辑
function checkIdentity(ws) {
  return true;
}
function terminate(ws) {
  ws.send("BYE!");
  ws.close();
}

授权用到的 Ticket(这里存储用到的是 knex + postgreSQL)

import shortid from 'shortid';
import {utils} from '../components';
import {db} from './database';

export default class Ticket{
    constructor(uniqueId,expiresMinutes = 30){
        const now = new Date();
        this.unique_id = uniqueId;
        this.token = Ticket.generateToken(uniqueId,now);
        this.crated = now;
        this.expires = moment(now).add(expiresMinutes,'minute');
    }
    pojo(){
        return {
            ...this;
        }
    }
    async save(){
        return await db.from('tickets').insert(this.pojo()).returning('id');
    }
    static async get(uniqueId){
        const result = await db
            .from('tickets')
            .select('id','unique_id','token','expires','created')
            .where('unique_id',unique_id);
        const tickets = JSON.parse(JSON.stringify(result[0]));
        return tickets;
    }
    static generateToken(uniqueId,now){
        const part1 = uniqueId;
        const part2 = now.getTime().toString();
        const part3 = shortid.generate();
        return utils.sha1(`${part1}:${part2}:${part3}`);
    }
}

utils 的哈希方法:

import crypto from "crypto";

export default {
  sha1(str) {
    const shaAlog = crypto.createHash("sha1");
    shaAlog.update(str);
    return shaAlog.digest("hex");
  },
};

本文作者:前端小毛

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!