Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions config/gateway/gateway-plugin/gateway-plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ spec:
- path:
type: PathPrefix
value: /v1/models
- path:
type: PathPrefix
value: /view
backendRefs:
- name: aibrix-metadata-service
port: 8090
Expand Down
3 changes: 3 additions & 0 deletions dist/chart/templates/gateway-plugin/httproute.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ spec:
- path:
type: PathPrefix
value: /v1/models
- path:
type: PathPrefix
value: /view
backendRefs:
- name: aibrix-metadata-service
port: 8090
97 changes: 97 additions & 0 deletions pkg/metadata/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"path/filepath"
"strings"
"time"

"github.com/gorilla/mux"
"github.com/redis/go-redis/v9"
"github.com/vllm-project/aibrix/pkg/cache"
"github.com/vllm-project/aibrix/pkg/types"
"github.com/vllm-project/aibrix/pkg/utils"
"k8s.io/klog/v2"
)
Expand Down Expand Up @@ -52,6 +57,9 @@ func NewHTTPServer(addr string, redis *redis.Client) *http.Server {
r.HandleFunc("/DeleteUser", server.deleteUser).Methods("POST")
// OpenAI API related handlers
r.HandleFunc("/v1/models", server.models).Methods("GET")

r.HandleFunc("/view", server.view).Methods("POST") // Changed from GET to POST

// Health related handlers
r.HandleFunc("/healthz", server.healthz).Methods("GET")
r.HandleFunc("/readyz", server.readyz).Methods("GET")
Expand All @@ -75,6 +83,90 @@ func (s *httpServer) models(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%s", string(jsonBytes))
}

func (s *httpServer) view(w http.ResponseWriter, r *http.Request) {
var req ViewRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
return
}

// Validate required parameters
if req.RequestID == "" {
klog.ErrorS(errors.New("missing requestID"), "missing required parameter: requestID")
http.Error(w, "missing required parameter: requestID", http.StatusBadRequest)
return
}

// Get RequestStore struct from Redis
data, err := s.redisClient.Get(r.Context(), req.RequestID).Result()
if err != nil {
klog.ErrorS(err, "Failed to get request data")
http.Error(w, fmt.Sprintf("Failed to get request data: %v", err), http.StatusInternalServerError)
return
}

// Deserialize JSON data into RequestStore struct
var requestStore types.RequestStore
if err = json.Unmarshal([]byte(data), &requestStore); err != nil {
klog.ErrorS(err, "Failed to deserialize RequestStore")
http.Error(w, fmt.Sprintf("Failed to deserialize RequestStore: %v", err), http.StatusInternalServerError)
return
}

// Use the deserialized struct
reqPath := requestStore.Path
if !strings.HasPrefix(reqPath, "/") {
reqPath = "/" + reqPath
}

// Construct the download URL
downloadURL := fmt.Sprintf("http://%s:%s/view%s", requestStore.IP, requestStore.Port, reqPath)
klog.Infof("Attempting to download file from: %s", downloadURL)

// Make request to download the model
client := &http.Client{
Timeout: 30 * time.Second,
}
resp, err := client.Get(downloadURL)
if err != nil {
klog.Errorf("Failed to download file from %s: %v", downloadURL, err)
http.Error(w, fmt.Sprintf("failed to download file: %v", err), http.StatusInternalServerError)
return
}
defer func() {
_ = resp.Body.Close()
}()

if resp.StatusCode != http.StatusOK {
klog.Errorf("Download failed with status %d from %s", resp.StatusCode, downloadURL)
http.Error(w, fmt.Sprintf("download failed with status: %d", resp.StatusCode), resp.StatusCode)
return
}

// Set appropriate headers for file download
w.Header().Set("Content-Type", resp.Header.Get("Content-Type"))
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
w.Header().Set("Content-Length", contentLength)
}

// Extract filename from path for Content-Disposition
filename := filepath.Base(requestStore.Path)
if filename == "." || filename == "/" {
filename = "model_file"
}
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename))

// Stream the file content to the response
_, err = io.Copy(w, resp.Body)
if err != nil {
klog.Errorf("Failed to stream file: %v", err)
// Note: Can't send HTTP error here as headers are already written
return
}

klog.Infof("Successfully downloaded file from %s", downloadURL)
}

func (s *httpServer) createUser(w http.ResponseWriter, r *http.Request) {
var u utils.User

Expand Down Expand Up @@ -199,3 +291,8 @@ func (s *httpServer) readyz(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "ready")
}

// Add this struct definition near the top of the file after the httpServer struct
type ViewRequest struct {
RequestID string `json:"request-id"`
}
8 changes: 4 additions & 4 deletions pkg/plugins/gateway/gateway_req_body.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e
var term int64 // Identify the trace window

routingCtx, _ := ctx.(*types.RoutingContext)
requestPath := routingCtx.ReqPath
routingAlgorithm := routingCtx.Algorithm

body := req.Request.(*extProcPb.ProcessingRequest_RequestBody)
model, message, stream, errRes := validateRequestBody(requestID, requestPath, body.RequestBody.GetBody(), user)
model, message, stream, errRes := validateRequestBody(routingCtx, body.RequestBody.GetBody(), user)
if errRes != nil {
return errRes, model, routingCtx, stream, term
}

requestPath := routingCtx.ReqPath
routingAlgorithm := routingCtx.Algorithm
routingCtx.Model = model
routingCtx.Message = message
routingCtx.ReqBody = body.RequestBody.GetBody()
Expand Down
78 changes: 47 additions & 31 deletions pkg/plugins/gateway/gateway_rsp_body.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"io"
"net/http"
"strings"
"time"

"github.com/openai/openai-go"
"github.com/openai/openai-go/packages/ssestream"
Expand Down Expand Up @@ -52,6 +53,8 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *
var processingRes *extProcPb.ProcessingResponse
var promptTokens, completionTokens, totalTokens int64
var headers []*configPb.HeaderValueOption
headers = buildEnvoyProxyHeaders(headers, HeaderRequestID, requestID)

complete := hasCompleted
routerCtx, _ := ctx.(*types.RoutingContext)

Expand Down Expand Up @@ -98,6 +101,12 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *
if processingRes != nil {
return processingRes, complete
}
} else if !isLanguageRequest(routerCtx.ReqPath) {
processingRes = s.processNonLanguangeResponse(ctx, b)
if processingRes != nil {
return processingRes, true
}
totalTokens = 1
}
}

Expand All @@ -116,41 +125,14 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *
}}},
err.Error()), complete
}

headers = append(headers,
&configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: HeaderUpdateRPM,
RawValue: []byte(fmt.Sprintf("%d", rpm)),
},
},
&configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: HeaderUpdateTPM,
RawValue: []byte(fmt.Sprintf("%d", tpm)),
},
},
)
headers = buildEnvoyProxyHeaders(headers, HeaderUpdateRPM, fmt.Sprintf("%d", rpm),
HeaderUpdateTPM, fmt.Sprintf("%d", tpm))
requestEnd = fmt.Sprintf(requestEnd+"rpm: %d, tpm: %d, ", rpm, tpm)
}

if routerCtx != nil && routerCtx.HasRouted() {
targetPodIP := routerCtx.TargetAddress()
headers = append(headers,
&configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: HeaderTargetPod,
RawValue: []byte(targetPodIP),
},
},
&configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: HeaderRequestID,
RawValue: []byte(requestID),
},
},
)
requestEnd = fmt.Sprintf(requestEnd+"targetPod: %s", targetPodIP)
headers = buildEnvoyProxyHeaders(headers, HeaderTargetPod, routerCtx.TargetAddress())
requestEnd = fmt.Sprintf(requestEnd+"targetPod: %s", routerCtx.TargetAddress())
}

klog.Infof("request end, requestID: %s - %s", requestID, requestEnd)
Expand Down Expand Up @@ -244,3 +226,37 @@ func processLanguageResponse(requestID string, b *extProcPb.ProcessingRequest_Re
}
return
}

// nolint:nakedret
func (s *Server) processNonLanguangeResponse(ctx context.Context, b *extProcPb.ProcessingRequest_ResponseBody) (processingRes *extProcPb.ProcessingResponse) {
routerCtx, _ := ctx.(*types.RoutingContext)
if !routerCtx.SaveToRemoteStorage {
return
}

var jsonMap map[string]interface{}
if err := json.Unmarshal(b.ResponseBody.GetBody(), &jsonMap); err != nil {
return buildErrorResponse(envoyTypePb.StatusCode_InternalServerError,
err.Error(), HeaderErrorResponseUnmarshal, "true")
}

storagePath, ok := jsonMap["output"].(string)
if !ok {
klog.ErrorS(ErrorUnknownResponse, "path not found in response", "requestID", routerCtx.RequestID, "responseBody", string(b.ResponseBody.GetBody()))
return buildErrorResponse(envoyTypePb.StatusCode_InternalServerError, "path not found in response")
}

if err := s.writeStorageRequest(ctx, routerCtx.RequestID, types.RequestStore{
RequestID: routerCtx.RequestID,
Status: true,
IP: strings.Split(routerCtx.TargetAddress(), ":")[0],
Port: "8080",
Path: storagePath,
UpdateTime: time.Now(),
}); err != nil {
klog.ErrorS(err, "error to store request response in redis", "requestID", routerCtx.RequestID)
return buildErrorResponse(envoyTypePb.StatusCode_InternalServerError, err.Error())
}

return
}
1 change: 1 addition & 0 deletions pkg/plugins/gateway/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ const (
HeaderWentIntoReqHeaders = "x-went-into-req-headers"
HeaderTargetPod = "target-pod"
HeaderRoutingStrategy = "routing-strategy"
HeaderContentLength = "content-length"
HeaderRequestID = "request-id"
HeaderModel = "model"

Expand Down
32 changes: 31 additions & 1 deletion pkg/plugins/gateway/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@ limitations under the License.
package gateway

import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/openai/openai-go"
"github.com/openai/openai-go/packages/param"
routing "github.com/vllm-project/aibrix/pkg/plugins/gateway/algorithms"
"github.com/vllm-project/aibrix/pkg/types"
"github.com/vllm-project/aibrix/pkg/utils"
"k8s.io/klog/v2"
)
Expand All @@ -40,7 +44,10 @@ const (

// validateRequestBody validates input by unmarshaling request body into respective openai-golang struct based on requestpath.
// nolint:nakedret
func validateRequestBody(requestID, requestPath string, requestBody []byte, user utils.User) (model, message string, stream bool, errRes *extProcPb.ProcessingResponse) {
func validateRequestBody(routingCtx *types.RoutingContext, requestBody []byte, user utils.User) (model, message string, stream bool, errRes *extProcPb.ProcessingResponse) {
requestID := routingCtx.RequestID
requestPath := routingCtx.ReqPath

var streamOptions openai.ChatCompletionStreamOptionsParam
var jsonMap map[string]json.RawMessage
if err := json.Unmarshal(requestBody, &jsonMap); err != nil {
Expand Down Expand Up @@ -108,6 +115,14 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user
return
}
model = imageGenerationObj.Model
_, ok := jsonMap["save_disk_path"]
if ok {
routingCtx.SaveToRemoteStorage = true
if routingCtx.Algorithm == routing.RouterNotSet {
routingCtx.Algorithm = routing.RouterRandom
}
}

default:
errRes = buildErrorResponse(envoyTypePb.StatusCode_NotImplemented, "unknown request path", HeaderErrorRequestBodyProcessing, "true")
return
Expand Down Expand Up @@ -373,3 +388,18 @@ func validateTokenInputs(tokenArrays [][]int64) error {

return nil
}

func (s *Server) writeStorageRequest(ctx context.Context, requestID string, requestStore types.RequestStore) error {
data, err := json.Marshal(requestStore)
if err != nil {
return err
}

// TODO: make storage path expiration configurable, default 1 day
status := s.redisClient.Set(ctx, requestID, data, 24*time.Hour)
if err := status.Err(); err != nil {
return err
}

return nil
}
5 changes: 3 additions & 2 deletions pkg/plugins/gateway/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/openai/openai-go"
"github.com/stretchr/testify/assert"
"github.com/vllm-project/aibrix/pkg/types"
"github.com/vllm-project/aibrix/pkg/utils"
)

Expand Down Expand Up @@ -138,7 +139,7 @@ func Test_ValidateRequestBody(t *testing.T) {
}

for _, tt := range testCases {
model, messages, stream, errRes := validateRequestBody("1", tt.requestPath, tt.requestBody, tt.user)
model, messages, stream, errRes := validateRequestBody(&types.RoutingContext{RequestID: "1", ReqPath: tt.requestPath}, tt.requestBody, tt.user)

if tt.statusCode == 200 {
assert.Equal(t, (*extProcPb.ProcessingResponse)(nil), errRes, tt.message)
Expand Down Expand Up @@ -312,7 +313,7 @@ func Test_ValidateRequestBody_Embeddings(t *testing.T) {
}

for _, tt := range testCases {
model, messages, stream, errRes := validateRequestBody("test-request-id", tt.requestPath, tt.requestBody, tt.user)
model, messages, stream, errRes := validateRequestBody(&types.RoutingContext{RequestID: "test-request-id", ReqPath: tt.requestPath}, tt.requestBody, tt.user)
t.Log(tt.message)
if tt.statusCode == 200 {
assert.Equal(t, (*extProcPb.ProcessingResponse)(nil), errRes, tt.message)
Expand Down
Loading