forked from vespa-engine/sample-apps
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathencode_query.py
More file actions
82 lines (66 loc) · 3.09 KB
/
Copy pathencode_query.py
File metadata and controls
82 lines (66 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Helper: emit the JSON body to send to /search/ for each rank profile.
Four modes:

  --cosine   For the `cosine_baseline` rank profile. Sends only the query
             text - Vespa runs the passage embedder server-side via embed()
             to produce the query vector.
  --rerank   For the `hypencoder_rerank` rank profile. Sends the query text
             (server-side embed() for the cosine first phase) AND the
             tokenized query (for the ONNX q-net second phase).
  --lexical  For the `hypencoder_lexical_rerank` rank profile. BM25
             first phase against the text field, hypencoder q-net on the
             top candidates. Sends a weakAnd `userInput(@q)` YQL query
             (with the raw text in `q`) plus the tokenized query; no
             vector.
  (default)  For the `hypencoder_onnx` rank profile (rank-all). Sends only
             token IDs + attention mask. ~830 byte payload.

Every mode writes the JSON body to stdout; save it to a file and pass that
file to `vespa query --file`:

    python encode_query.py "tallest mountain in the world" > q.json
    vespa query --file q.json
"""
import argparse
import json
import sys
import numpy as np
from transformers import AutoTokenizer
# Fixed token length: every tokenized query is padded/truncated to this size.
MAX_SEQ = 64


def _token_inputs(query: str, checkpoint: str, max_seq: int) -> dict:
    """Tokenize ``query`` and return the two tensor inputs for the q-net profiles.

    The query is padded/truncated to exactly ``max_seq`` tokens. Token IDs and
    the attention mask are converted to float32 Python lists and wrapped in an
    outer list (a leading dimension of size 1).

    Returns:
        A dict with the ``input.query(input_ids)`` and
        ``input.query(attention_mask)`` request parameters, in that order.
    """
    # NOTE(review): presumably mirrors query(input_ids)/query(attention_mask)
    # tensor types declared in the Vespa schema (not visible here) - confirm.
    tok = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
    enc = tok([query], padding="max_length", truncation=True,
              max_length=max_seq, return_tensors="np")
    return {
        "input.query(input_ids)": [enc["input_ids"][0].astype(np.float32).tolist()],
        "input.query(attention_mask)": [enc["attention_mask"][0].astype(np.float32).tolist()],
    }


def main(argv: "list[str] | None" = None) -> None:
    """Parse arguments and print the Vespa /search/ request body as JSON.

    At most one of --cosine / --rerank / --lexical may be given. With none,
    the body targets the rank-all ``hypencoder_onnx`` profile. The JSON body
    is written to stdout, followed by a newline.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``, for tests
            and programmatic use. ``None`` (the default) keeps the CLI behavior.

    Raises:
        SystemExit: Via argparse (status 2) on invalid arguments, including a
            blank query or a negative --hits.
    """
    ap = argparse.ArgumentParser(
        # Show the mode descriptions from the module docstring in --help.
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("query", help="The query text.")
    ap.add_argument("--checkpoint", default="jfkback/hypencoder.2_layer")
    mode = ap.add_mutually_exclusive_group()
    mode.add_argument("--rerank", action="store_true",
                      help="Emit a body for hypencoder_rerank.")
    mode.add_argument("--cosine", action="store_true",
                      help="Emit a body for cosine_baseline.")
    mode.add_argument("--lexical", action="store_true",
                      help="Emit a body for hypencoder_lexical_rerank.")
    ap.add_argument("--hits", type=int, default=10)
    args = ap.parse_args(argv)

    # Reject inputs that would only yield a useless or rejected request.
    if not args.query.strip():
        ap.error("query must not be empty")
    if args.hits < 0:
        ap.error("--hits must be non-negative")

    body = {
        "yql": "select id, text from doc where true",
        "hits": args.hits,
        "timeout": "30s",
    }
    if args.cosine:
        # Text only: Vespa embeds @q server-side for the cosine profile.
        body["ranking.profile"] = "cosine_baseline"
        body["input.query(q_vec)"] = "embed(passage_embedder, @q)"
        body["q"] = args.query
    else:
        # Every hypencoder profile needs the tokenized query for the q-net.
        body.update(_token_inputs(args.query, args.checkpoint, MAX_SEQ))
        if args.rerank:
            body["ranking.profile"] = "hypencoder_rerank"
            body["input.query(q_vec)"] = "embed(passage_embedder, @q)"
            body["q"] = args.query
        elif args.lexical:
            # Lexical retrieval: replace the match-all YQL with weakAnd over @q.
            body["yql"] = ('select id, text from doc where '
                           '{grammar: "weakAnd", defaultIndex: "text"}userInput(@q)')
            body["q"] = args.query
            body["ranking.profile"] = "hypencoder_lexical_rerank"
        else:
            body["ranking.profile"] = "hypencoder_onnx"
    json.dump(body, sys.stdout)
    # Terminate the output line so the saved file / terminal output is clean.
    sys.stdout.write("\n")


if __name__ == "__main__":
    main()