@@ -11,6 +11,7 @@ import (
11
11
"time"
12
12
13
13
"github.com/google/uuid"
14
+ "nhooyr.io/websocket"
14
15
15
16
"cdr.dev/slog"
16
17
@@ -98,12 +99,28 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
98
99
return
99
100
}
100
101
102
+ api .websocketWaitMutex .Lock ()
103
+ api .websocketWaitGroup .Add (1 )
104
+ api .websocketWaitMutex .Unlock ()
105
+ defer api .websocketWaitGroup .Done ()
106
+ conn , err := websocket .Accept (rw , r , nil )
107
+ if err != nil {
108
+ httpapi .Write (rw , http .StatusBadRequest , httpapi.Response {
109
+ Message : "Failed to accept websocket." ,
110
+ Detail : err .Error (),
111
+ })
112
+ return
113
+ }
114
+
115
+ ctx , wsNetConn := websocketNetConn (r .Context (), conn , websocket .MessageText )
116
+ defer wsNetConn .Close () // Also closes conn.
117
+
101
118
bufferedLogs := make (chan database.ProvisionerJobLog , 128 )
102
119
closeSubscribe , err := api .Pubsub .Subscribe (provisionerJobLogsChannel (job .ID ), func (ctx context.Context , message []byte ) {
103
120
var logs []database.ProvisionerJobLog
104
121
err := json .Unmarshal (message , & logs )
105
122
if err != nil {
106
- api .Logger .Warn (r . Context () , fmt .Sprintf ("invalid provisioner job log on channel %q: %s" , provisionerJobLogsChannel (job .ID ), err .Error ()))
123
+ api .Logger .Warn (ctx , fmt .Sprintf ("invalid provisioner job log on channel %q: %s" , provisionerJobLogsChannel (job .ID ), err .Error ()))
107
124
return
108
125
}
109
126
@@ -113,7 +130,7 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
113
130
default :
114
131
// If this overflows users could miss logs streaming. This can happen
115
132
// if a database request takes a long amount of time, and we get a lot of logs.
116
- api .Logger .Warn (r . Context () , "provisioner job log overflowing channel" )
133
+ api .Logger .Warn (ctx , "provisioner job log overflowing channel" )
117
134
}
118
135
}
119
136
})
@@ -126,7 +143,7 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
126
143
}
127
144
defer closeSubscribe ()
128
145
129
- provisionerJobLogs , err := api .Database .GetProvisionerLogsByIDBetween (r . Context () , database.GetProvisionerLogsByIDBetweenParams {
146
+ provisionerJobLogs , err := api .Database .GetProvisionerLogsByIDBetween (ctx , database.GetProvisionerLogsByIDBetweenParams {
130
147
JobID : job .ID ,
131
148
CreatedAfter : after ,
132
149
CreatedBefore : before ,
@@ -142,17 +159,8 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
142
159
return
143
160
}
144
161
145
- // "follow" uses the ndjson format to stream data.
146
- // See: https://canjs.com/doc/can-ndjson-stream.html
147
- rw .Header ().Set ("Content-Type" , "application/stream+json" )
148
- rw .WriteHeader (http .StatusOK )
149
- if flusher , ok := rw .(http.Flusher ); ok {
150
- flusher .Flush ()
151
- }
152
-
153
162
// The Go stdlib JSON encoder appends a newline character after message write.
154
- encoder := json .NewEncoder (rw )
155
-
163
+ encoder := json .NewEncoder (wsNetConn )
156
164
for _ , provisionerJobLog := range provisionerJobLogs {
157
165
err = encoder .Encode (convertProvisionerJobLog (provisionerJobLog ))
158
166
if err != nil {
@@ -171,9 +179,6 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
171
179
if err != nil {
172
180
return
173
181
}
174
- if flusher , ok := rw .(http.Flusher ); ok {
175
- flusher .Flush ()
176
- }
177
182
case <- ticker .C :
178
183
job , err := api .Database .GetProvisionerJobByID (r .Context (), job .ID )
179
184
if err != nil {
0 commit comments