@@ -2,7 +2,9 @@ package sse
22
33import (
44 "context"
5+ "errors"
56 "fmt"
7+ "io"
68 "net/http"
79 "sync"
810 "time"
@@ -135,8 +137,10 @@ type Event struct {
135137
136138// Client SSE 客户端
137139type Client struct {
138- ID string
139- Writer http.ResponseWriter
140+ ID string
141+ Writer http.ResponseWriter
142+ disconnected bool
143+ mu sync.RWMutex
140144}
141145
142146// Close 关闭客户端连接
@@ -147,13 +151,23 @@ func (c *Client) Close() {
147151
148152// Send 发送数据给客户端
149153func (c * Client ) Send (data []byte ) error {
154+ // 检查客户端是否已断开
155+ if c .IsDisconnected () {
156+ return errors .New ("客户端已断开连接" )
157+ }
158+
150159 // 使用标准SSE格式发送数据
151160 // 格式: "data: {json数据}\n\n"
152161 sseData := fmt .Sprintf ("data: %s\n \n " , string (data ))
153162
154163 // 直接写入HTTP响应流
155164 _ , err := c .Writer .Write ([]byte (sseData ))
156165 if err != nil {
166+ // 检查是否是连接断开相关的错误
167+ if isConnectionError (err ) {
168+ c .SetDisconnected (true )
169+ return fmt .Errorf ("客户端连接已断开: %v" , err )
170+ }
157171 return fmt .Errorf ("写入SSE数据失败: %v" , err )
158172 }
159173
@@ -167,6 +181,50 @@ func (c *Client) Send(data []byte) error {
167181
168182// SetDisconnected 设置客户端断开状态
169183func (c * Client ) SetDisconnected (disconnected bool ) {
170- // 这里可以实现断开状态的设置逻辑
171- // 目前是一个空实现
184+ c .mu .Lock ()
185+ defer c .mu .Unlock ()
186+ c .disconnected = disconnected
187+ }
188+
189+ // IsDisconnected 检查客户端是否已断开
190+ func (c * Client ) IsDisconnected () bool {
191+ c .mu .RLock ()
192+ defer c .mu .RUnlock ()
193+ return c .disconnected
194+ }
195+
196+ // isConnectionError 检查错误是否为连接断开相关错误
197+ func isConnectionError (err error ) bool {
198+ // 检查常见的连接断开错误
199+ if errors .Is (err , io .ErrShortWrite ) {
200+ return true
201+ }
202+
203+ // 检查错误信息中是否包含连接断开的关键字
204+ errorStr := err .Error ()
205+ connectionErrors := []string {
206+ "broken pipe" ,
207+ "connection reset" ,
208+ "use of closed network connection" ,
209+ "short write" ,
210+ "connection aborted" ,
211+ }
212+
213+ for _ , errPattern := range connectionErrors {
214+ if contains (errorStr , errPattern ) {
215+ return true
216+ }
217+ }
218+
219+ return false
220+ }
221+
222+ // contains 检查字符串是否包含子字符串(简单实现)
223+ func contains (s , substr string ) bool {
224+ for i := 0 ; i <= len (s )- len (substr ); i ++ {
225+ if s [i :i + len (substr )] == substr {
226+ return true
227+ }
228+ }
229+ return false
172230}
0 commit comments