aboutsummaryrefslogtreecommitdiff
path: root/src/umapyai_alternative/__init__.py
blob: 9e613a1f76e2ccb2cfd784c02be68861f22eae22 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import qdrant_client
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.core.storage import StorageContext
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.qdrant import QdrantVectorStore
import logging
import threading
import sys
import time


def _wait_animation(stop_event: threading.Event,
                    interval: float = 0.35) -> None:
  """Animate a trailing-dots wait indicator on stdout until told to stop.

  Cycles through "", ".", "..", "..." on the current terminal line, one
  frame per `interval` seconds, until `stop_event` is set. On exit the
  indicator is blanked out and the cursor is returned to column 0 so the
  caller can print over it.

  Intended to run in a background thread while the caller blocks on
  something slow.

  Args:
    stop_event: Set by the caller to end the animation.
    interval: Seconds to show each frame.
  """
  frames = ("", ".", "..", "...")
  max_length = max(len(frame) for frame in frames)
  index = 0

  while not stop_event.is_set():
    frame = frames[index % len(frames)]

    # Pad to a fixed width so a short frame fully overwrites a longer one.
    sys.stdout.write("\r" + frame.ljust(max_length))
    # No newline is ever written here, so a line-buffered stdout would hold
    # the frames back indefinitely; flush to make each frame visible.
    sys.stdout.flush()

    index += 1

    # Wait on the event instead of sleeping so that setting it wakes this
    # thread immediately rather than after the remainder of the interval.
    stop_event.wait(interval)

  # Erase the indicator and park the cursor at the start of the line.
  sys.stdout.write("\r" + " " * max_length + "\r")
  sys.stdout.flush()


def _stream_answer(query_engine, user_query: str) -> None:
  """Run one query and stream its answer to stdout behind a wait indicator.

  The wait animation is started before the query is issued, so it covers
  whatever blocking work `query_engine.query()` does as well as the wait
  for the first streamed token. The animation is always stopped and joined
  before this function returns or raises.

  Args:
    query_engine: Object exposing `query(str)` whose result has a
      `response_gen` iterable of text tokens (here, the streaming query
      engine built in `main`).
    user_query: Non-empty question text to send to the engine.
  """
  stop_event = threading.Event()
  animation_thread = threading.Thread(
      target=_wait_animation, args=(stop_event,), daemon=True)

  print()

  animation_thread.start()

  try:
    response_stream = query_engine.query(user_query)

    for token in response_stream.response_gen:
      # The first token ends the wait: stop the animation and let it erase
      # itself before any answer text is printed on that line.
      if not stop_event.is_set():
        stop_event.set()
        animation_thread.join()

      print(token, end="", flush=True)
  finally:
    # Covers the no-token and exception paths; set()/join() are harmless
    # if the animation was already stopped above.
    stop_event.set()
    animation_thread.join()

  print("\n")


def main():
  """Index the article directory and answer questions in an interactive loop.

  Loads documents from ./uma_articles_clean, indexes them into the
  "umamusume" collection of a Qdrant store at ./qdrant_data using a
  HuggingFace embedding model, then reads questions from stdin and streams
  answers from an Ollama-served LLM.

  The loop ends on EOF, on "quit" / "exit" / "q" (case-insensitive), or on
  Ctrl-C. Blank input is ignored. The Qdrant client is closed on every exit
  path.
  """
  logging.basicConfig(level=logging.WARNING)

  documents = SimpleDirectoryReader("./uma_articles_clean").load_data()
  client = qdrant_client.QdrantClient(path="./qdrant_data")

  try:
    vector_store = QdrantVectorStore(client=client, collection_name="umamusume")
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    llm = Ollama(model="gpt-oss:20b", request_timeout=120.0)
    embedding_model = HuggingFaceEmbedding(
        model_name="sentence-transformers/all-MiniLM-L6-v2")

    Settings.llm = llm
    Settings.embed_model = embedding_model

    # NOTE(review): this runs on every start against a persistent store;
    # confirm whether repeated runs re-insert the same documents.
    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context,
    )
    query_engine = index.as_query_engine(streaming=True)

    try:
      while True:
        try:
          user_query = input("> ").strip()
        except EOFError:
          # Finish the prompt line that EOF left open.
          print()

          break

        if not user_query:
          continue

        if user_query.lower() in {"quit", "exit", "q"}:
          break

        _stream_answer(query_engine, user_query)
    except KeyboardInterrupt:
      # Ctrl-C at the prompt or mid-answer ends the session quietly.
      print()

    print()
  finally:
    # Close the client explicitly instead of leaving the on-disk store to
    # be cleaned up implicitly at interpreter shutdown.
    client.close()