Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit c313b2c

Browse files
fix(reranker): tests and top_n check fix #7212 (#7284)
reranker tests and top_n check fix #7212 Signed-off-by: Mikhail Khludnev <mkhl@apache.org>
1 parent 137f163 commit c313b2c

File tree

3 files changed

+88
-34
lines changed

3 files changed

+88
-34
lines changed

‎core/http/endpoints/jina/rerank.go‎

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,22 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
3232
}
3333

3434
log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received")
35-
35+
var requestTopN int32
36+
docs := int32(len(input.Documents))
37+
if input.TopN == nil { // omit top_n to get all
38+
requestTopN = docs
39+
} else {
40+
requestTopN = int32(*input.TopN)
41+
if requestTopN < 1 {
42+
return c.JSON(http.StatusUnprocessableEntity, "top_n - should be greater than or equal to 1")
43+
}
44+
if requestTopN > docs { // make it more obvious for backends
45+
requestTopN = docs
46+
}
47+
}
3648
request := &proto.RerankRequest{
3749
Query: input.Query,
38-
TopN: int32(input.TopN),
50+
TopN: requestTopN,
3951
Documents: input.Documents,
4052
}
4153

‎core/schema/jina.go‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ type JINARerankRequest struct {
55
BasicModelRequest
66
Query string `json:"query"`
77
Documents []string `json:"documents"`
8-
TopN int `json:"top_n"`
8+
TopN *int `json:"top_n,omitempty"`
99
Backend string `json:"backend"`
1010
}
1111

‎tests/e2e-aio/e2e_test.go‎

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -286,45 +286,64 @@ var _ = Describe("E2E test", func() {
286286
Context("reranker", func() {
287287
It("correctly", func() {
288288
modelName := "jina-reranker-v1-base-en"
289-
290-
req := schema.JINARerankRequest{
291-
BasicModelRequest: schema.BasicModelRequest{
292-
Model: modelName,
293-
},
294-
Query: "Organic skincare products for sensitive skin",
295-
Documents: []string{
296-
"Eco-friendly kitchenware for modern homes",
297-
"Biodegradable cleaning supplies for eco-conscious consumers",
298-
"Organic cotton baby clothes for sensitive skin",
299-
"Natural organic skincare range for sensitive skin",
300-
"Tech gadgets for smart homes: 2024 edition",
301-
"Sustainable gardening tools and compost solutions",
302-
"Sensitive skin-friendly facial cleansers and toners",
303-
"Organic food wraps and storage solutions",
304-
"All-natural pet food for dogs with allergies",
305-
"Yoga mats made from recycled materials",
306-
},
307-
TopN: 3,
289+
const query = "Organic skincare products for sensitive skin"
290+
var documents = []string{
291+
"Eco-friendly kitchenware for modern homes",
292+
"Biodegradable cleaning supplies for eco-conscious consumers",
293+
"Organic cotton baby clothes for sensitive skin",
294+
"Natural organic skincare range for sensitive skin",
295+
"Tech gadgets for smart homes: 2024 edition",
296+
"Sustainable gardening tools and compost solutions",
297+
"Sensitive skin-friendly facial cleansers and toners",
298+
"Organic food wraps and storage solutions",
299+
"All-natural pet food for dogs with allergies",
300+
"Yoga mats made from recycled materials",
301+
}
302+
// Pick a random top_n to request; it may exceed len(documents) by one
303+
randomValue := int(GinkgoRandomSeed()) % (len(documents) + 1)
304+
requestResults := randomValue + 1 // at least 1 results
305+
// Cap expectResults by the length of documents
306+
expectResults := min(requestResults, len(documents))
307+
var maybeSkipTopN = &requestResults
308+
if requestResults >= len(documents) && int(GinkgoRandomSeed())%2 == 0 {
309+
maybeSkipTopN = nil
308310
}
309311

310-
serialized, err := json.Marshal(req)
311-
Expect(err).To(BeNil())
312-
Expect(serialized).ToNot(BeNil())
313-
314-
rerankerEndpoint := apiEndpoint + "/rerank"
315-
resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized))
316-
Expect(err).To(BeNil())
317-
Expect(resp).ToNot(BeNil())
318-
body, err := io.ReadAll(resp.Body)
319-
Expect(err).ToNot(HaveOccurred())
312+
resp, body := requestRerank(modelName, query, documents, maybeSkipTopN, apiEndpoint)
320313
Expect(resp.StatusCode).To(Equal(200), fmt.Sprintf("body: %s, response: %+v", body, resp))
321314

322315
deserializedResponse := schema.JINARerankResponse{}
323-
err = json.Unmarshal(body, &deserializedResponse)
316+
err := json.Unmarshal(body, &deserializedResponse)
324317
Expect(err).To(BeNil())
325318
Expect(deserializedResponse).ToNot(BeZero())
326319
Expect(deserializedResponse.Model).To(Equal(modelName))
327-
Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0))
320+
//Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0))
321+
Expect(len(deserializedResponse.Results)).To(Equal(expectResults))
322+
// Assert that relevance scores are in decreasing order
323+
for i := 1; i < len(deserializedResponse.Results); i++ {
324+
Expect(deserializedResponse.Results[i].RelevanceScore).To(
325+
BeNumerically("<=", deserializedResponse.Results[i-1].RelevanceScore),
326+
fmt.Sprintf("Result at index %d should have lower relevance score than previous result.", i),
327+
)
328+
}
329+
// Assert that each result's index points to the correct document
330+
for i, result := range deserializedResponse.Results {
331+
Expect(result.Index).To(
332+
And(
333+
BeNumerically(">=", 0),
334+
BeNumerically("<", len(documents)),
335+
),
336+
fmt.Sprintf("Result at position %d has index %d which should be within bounds [0, %d)", i, result.Index, len(documents)),
337+
)
338+
Expect(result.Document.Text).To(
339+
Equal(documents[result.Index]),
340+
fmt.Sprintf("Result at position %d (index %d) should have document text '%s', but got '%s'",
341+
i, result.Index, documents[result.Index], result.Document.Text),
342+
)
343+
}
344+
zeroOrNeg := int(GinkgoRandomSeed())%2 - 1 // Results in either -1 or 0
345+
resp, body = requestRerank(modelName, query, documents, &zeroOrNeg, apiEndpoint)
346+
Expect(resp.StatusCode).To(Equal(422), fmt.Sprintf("body: %s, response: %+v", body, resp))
328347
})
329348
})
330349
})
@@ -350,3 +369,26 @@ func downloadHttpFile(url string) (string, error) {
350369

351370
return tmpfile.Name(), nil
352371
}
372+
373+
func requestRerank(modelName, query string, documents []string, topN *int, apiEndpoint string) (*http.Response, []byte) {
374+
req := schema.JINARerankRequest{
375+
BasicModelRequest: schema.BasicModelRequest{
376+
Model: modelName,
377+
},
378+
Query: query,
379+
Documents: documents,
380+
TopN: topN,
381+
}
382+
383+
serialized, err := json.Marshal(req)
384+
Expect(err).To(BeNil())
385+
Expect(serialized).ToNot(BeNil())
386+
rerankerEndpoint := apiEndpoint + "/rerank"
387+
resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized))
388+
Expect(err).To(BeNil())
389+
Expect(resp).ToNot(BeNil())
390+
body, err := io.ReadAll(resp.Body)
391+
Expect(err).ToNot(HaveOccurred())
392+
393+
return resp, body
394+
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /