Artem Kastrov 1 month ago
Parent commit: 8b5e72db77
15 changed files with 1803 additions and 100 deletions
  1. .gitignore (+1 -1)
  2. Pipfile (+3 -0)
  3. Pipfile.lock (+1371 -13)
  4. config.py (+6 -4)
  5. datasets/.gitignore (+0 -2)
  6. docker-compose.traefik.yml (+0 -26)
  7. docker-compose.yml (+22 -16)
  8. requirement.txt (+9 -0)
  9. src/app.py (+20 -38)
  10. src/chroma_manager.py (+59 -0)
  11. src/formatter.py (+79 -0)
  12. src/inference.py (+48 -0)
  13. src/main.py (+108 -0)
  14. src/ollama.py (+44 -0)
  15. src/rag.py (+33 -0)

+ 1 - 1
.gitignore

@@ -7,10 +7,10 @@
 /.vscode
 /.venv
 /.git
-.DS_Store
 .idea
 .devbox
 .project
 .settings
 Thumbs.db
+**/.DS_Store
 **/__pycache__

+ 3 - 0
Pipfile

@@ -8,6 +8,9 @@ fastapi = "*"
 uvicorn = "*"
 ollama = "*"
 load-dotenv = "*"
+datasets = "*"
+langchain = {extras = ["anthropic"], version = "*"}
+langchain-ollama = "*"
 
 [dev-packages]
 

File diff suppressed because it is too large
+ 1371 - 13
Pipfile.lock


+ 6 - 4
config.py

@@ -1,10 +1,12 @@
-from dotenv import load_dotenv
 import os
 
+from dotenv import load_dotenv
+
 load_dotenv()
 
-# CHROMA_HOST = os.getenv("CHROMA_HOST")
-# CHROMA_PORT = os.getenv("CHROMA_PORT")
+CHROMA_HOST = os.getenv("CHROMA_HOST")
+CHROMA_PORT = os.getenv("CHROMA_PORT")
 
 OLLAMA_URL = os.getenv("OLLAMA_URL")
-OLLAMA_MODEL = os.getenv("OLLAMA_MODEL")
+OLLAMA_MODEL = os.getenv("OLLAMA_MODEL")
+OLLAMA_API_KEY = os.getenv("OLLAMA_API_KEY")

+ 0 - 2
datasets/.gitignore

@@ -1,2 +0,0 @@
-*
-!.gitignore

+ 0 - 26
docker-compose.traefik.yml

@@ -1,26 +0,0 @@
-services:
-  application:
-    command: uvicorn src.app:app --reload --host 0.0.0.0 --port 8000
-    labels:
-      - "traefik.enable=true"
-      - "traefik.http.routers.rag-http.entrypoints=web"
-      - "traefik.http.routers.rag-http.rule=Host(`rag.localhost`)"
-      - "traefik.http.routers.rag-http.service=rag"
-      - "traefik.http.services.rag.loadbalancer.server.port=8000"
-      - "traefik.docker.network=proxy"
-    healthcheck:
-      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
-      interval: 30s
-      timeout: 10s
-      retries: 3
-      start_period: 15s
-    networks:
-      default:
-      proxy:
-        aliases:
-          - rag
-
-networks:
-  proxy:
-    name: proxy
-    external: true

+ 22 - 16
docker-compose.yml

@@ -1,27 +1,28 @@
 services:
-  application:
+  inference:
     build:
       context: .
       dockerfile: .docker/python/Dockerfile
+    command: uvicorn src.app:app --reload --host 0.0.0.0 --port 8000
     user: "${APP_UID:-1000}:${APP_GID:-1000}"
     volumes:
       - ./:/app:cached
 
-  ollama:
-    build:
-      context: .docker/ollama
-      dockerfile: Dockerfile
-    volumes:
-      - ollama-data:/root/.ollama
-    environment:
-      OLLAMA_MODELS: ${OLLAMA_MODELS}
-      OLLAMA_MODEL: ${OLLAMA_MODEL}
-    healthcheck:
-      test: ["CMD", "ollama", "ps"]
-      interval: 15s
-      retries: 5
-      start_period: 5s
-      timeout: 3s
+  # ollama:
+  #   build:
+  #     context: .docker/ollama
+  #     dockerfile: Dockerfile
+  #   volumes:
+  #     - ollama-data:/root/.ollama
+  #   environment:
+  #     OLLAMA_MODELS: ${OLLAMA_MODELS}
+  #     OLLAMA_MODEL: ${OLLAMA_MODEL}
+  #   healthcheck:
+  #     test: ["CMD", "ollama", "ps"]
+  #     interval: 15s
+  #     retries: 5
+  #     start_period: 5s
+  #     timeout: 3s
 
   chroma:
     image: chromadb/chroma
@@ -31,3 +32,8 @@ services:
 volumes:
   ollama-data: {}
   chroma-data: {}
+
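+# NOTE: "rag" is declared as an external network, so it must already exist;
+# create it once with `docker network create rag` before `docker compose up`.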
+networks:
+  default:
+    name: rag
+    external: true

+ 9 - 0
requirement.txt

@@ -0,0 +1,9 @@
+langchain
+langchain-ollama
+langchain-huggingface
+langchain-chroma
+sentence-transformers
+tiktoken
+langfuse
+datasets
+python-dotenv

+ 20 - 38
src/app.py

@@ -1,51 +1,33 @@
-from fastapi import Body, FastAPI
-from ollama import ChatResponse, Client
-
-from config import OLLAMA_MODEL, OLLAMA_URL
-
-client = Client(host=f"{OLLAMA_URL}")
-# ollama.create(model='example', from_='gemma3', system="You are Mario from Super Mario Bros.") // TODO: Для асистента?
-
-
-def message(text: str) -> ChatResponse:
-    return chat(
-        [
-            {
-                "role": "system",
-                "content": "Отвечай строго в формате Markdown. Не нужно пихать везде большие заголовки! Пиши как обычный человек, но красиво оформляй ответ",
-            },
-            {"role": "user", "content": text},
-        ]
-    )
-
-
-def chat(messages: list) -> ChatResponse:
-    return client.chat(model=f"{OLLAMA_MODEL}", messages=messages, think=True)
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
 
+from .formatter import StreamFormatter
+from .inference import Inference
 
+formatter = StreamFormatter()
+inference = Inference()
 app = FastAPI()
 
 
-@app.post("/chat")
-def _(prompt: str = Body(..., embed=True)):
-    print(prompt)
-    response = message(prompt)
-
-    return response
+@app.post("/answer")
+async def _(request: Request):
+    fields = await request.json()
 
+    message = fields.get("message")
+    # history = fields.get("history", [])
 
-@app.post("/generate")
-def generate(prompt: str = Body(..., embed=True)):
-    response = client.generate(
-        model=f"{OLLAMA_MODEL}",
-        prompt=prompt,
-        think=True,
-        options={"temperature": 0.15},
+    return StreamingResponse(
+        formatter.format(inference.answer(message)), media_type="text/plain"
     )
 
-    return response
+
+@app.post("/generate")
+async def _(request: Request):
+    fields = await request.json()
+    message = fields.get("message")
+    return await inference.generate(message)
 
 
 @app.get("/health")
-def health():
+def _():
     return {"status": "ok"}
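
A minimal client sketch for the streaming /answer route (not part of this commit). It assumes the API is reachable at http://localhost:8000 and that the requests package is available; the response body is plain text interleaved with the <|inference_rag_think|> / <|inference_rag_content|> markers emitted by StreamFormatter.

# hypothetical client for POST /answer
import requests

with requests.post(
    "http://localhost:8000/answer",
    json={"message": "What is the final boss in Terraria?"},
    stream=True,
) as response:
    response.raise_for_status()
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)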

+ 59 - 0
src/chroma_manager.py

@@ -0,0 +1,59 @@
+from chromadb.config import Settings
+from langchain_chroma import Chroma
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_text_splitters import TokenTextSplitter
+
+from datasets import load_dataset
+
+# absolute import, matching src/inference.py and src/main.py; a two-level relative
+# import (`from ..config import ...`) would reach past the top-level package and fail
+from config import CHROMA_HOST, CHROMA_PORT
+
+
+class ChromaManager:
+    def __init__(
+        self,
+        embeddings: str = "jinaai/jina-embeddings-v4",
+        collection_name: str = "terraria",
+        batch_size: int = 2500,
+        host: str = CHROMA_HOST,  # pyright: ignore[reportArgumentType]
+        port: int = CHROMA_PORT,  # pyright: ignore[reportArgumentType]
+        dataset: str = "lparkourer10/terraria-wiki",
+    ):
+        self.splitter = TokenTextSplitter(chunk_size=128, chunk_overlap=32)
+
+        self.embeddings = HuggingFaceEmbeddings(model=embeddings)
+        self.collection_name = collection_name
+        self.batch_size = batch_size
+
+        self.settings = Settings(
+            chroma_api_impl="chromadb.api.fastapi.FastAPI",
+            chroma_server_host=host,
+            chroma_server_http_port=str(port),
+        )
+
+        self.vectordb = Chroma(
+            embedding_function=self.embeddings,
+            collection_name=self.collection_name,
+            client_settings=self.settings,
+        )
+
+        if dataset and self.is_empty():
+            self.insert(self.load(dataset))
+
+    def is_empty(self) -> bool:
+        data = self.vectordb.get(include=["metadatas"])["metadatas"]
+        return len(data) == 0
+
+    def insert(self, dataset) -> None:
+        for i in range(0, len(dataset), self.batch_size):
+            batch = dataset[i : i + self.batch_size]
+            documents = self.splitter.create_documents(
+                texts=batch["question"],
+                metadatas=[{"answer": a} for a in batch["answer"]],
+            )
+            self.vectordb.add_documents(documents)
+
+    def load(self, name: str, split: str = "train"):
+        return load_dataset(name, split=split)
+
+    def retriever(self, count: int = 5):
+        return self.vectordb.as_retriever(search_kwargs={"k": count})
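
A usage sketch for ChromaManager (not part of this commit): on first run the constructor downloads the lparkourer10/terraria-wiki dataset, chunks the questions and stores each answer as chunk metadata, so a retriever query returns question chunks whose metadata carries the answer text. It assumes a Chroma server reachable at CHROMA_HOST:CHROMA_PORT, network access for the embedding model and dataset, and that it is run from the repository root; the query string is only an example.

# hypothetical usage, run from the repository root
from src.chroma_manager import ChromaManager

manager = ChromaManager()  # populates the "terraria" collection if it is empty
docs = manager.retriever(count=3).invoke("How do I summon the Wall of Flesh?")
for doc in docs:
    print(doc.page_content, "->", doc.metadata["answer"])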

+ 79 - 0
src/formatter.py

@@ -0,0 +1,79 @@
+import json
+from typing import Any, AsyncIterator, List
+
+
+class StreamFormatter:
+    name = "inference_rag_"
+
+    def __init__(
+        self,
+        keys: List[str] = None,  # pyright: ignore[reportArgumentType]
+    ) -> None:
+        self.name = ""
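+        # NOTE: wrap(), fields() and format() read the class attribute
+        # StreamFormatter.name ("inference_rag_") for the tag prefix, so this
+        # instance attribute does not affect the emitted markers.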
+        self.keys = keys or [
+            "model",
+            "done",
+            "done_reason",
+            "total_duration",
+            "load_duration",
+            "prompt_eval_count",
+            "prompt_eval_duration",
+            "eval_count",
+            "eval_duration",
+        ]
+
+    @staticmethod
+    def wrap(tag: str, content: str) -> str:
+        if content is None:
+            content = ""
+        return (
+            f"<|{StreamFormatter.name}{tag}|>{content}<|{StreamFormatter.name}{tag}|>\n"
+        )
+
+    @staticmethod
+    def fields(obj, keys: List[str]) -> str:
+        fields = {key: getattr(obj, key, None) for key in keys}
+        return f"<|{StreamFormatter.name}fields|>{json.dumps(fields, ensure_ascii=False)}<|{StreamFormatter.name}fields|>\n"
+
+    async def format(self, stream: AsyncIterator[Any]) -> AsyncIterator[str]:
+        last = None
+        state = None
+
+        transitions = {
+            ("think", None): lambda: f"<|{StreamFormatter.name}think|>\n",
+            ("content", None): lambda: f"<|{StreamFormatter.name}content|>\n",
+            (
+                "think",
+                "content",
+            ): lambda: f"<|{StreamFormatter.name}content|>\n<|{StreamFormatter.name}think|>\n",
+            (
+                "content",
+                "think",
+            ): lambda: f"<|{StreamFormatter.name}think|>\n<|{StreamFormatter.name}content|>\n",
+        }
+
+        async for chunk in stream:
+            last = chunk
+            thinking = chunk.message.thinking
+            content = chunk.message.content
+
+            target = (
+                "think"
+                if thinking is not None
+                else "content"
+                if content is not None
+                else None
+            )
+            if target and target != state:
+                yield transitions.get((target, state), lambda: "")()
+                state = target
+
+            yield thinking or content or ""
+
+        if state == "think":
+            yield f"<|{StreamFormatter.name}think|>\n"
+        elif state == "content":
+            yield f"<|{StreamFormatter.name}content|>\n"
+
+        if last:
+            yield self.fields(last, self.keys)
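
A small offline sketch (not part of this commit) of the wire format StreamFormatter produces. Only .message.thinking and .message.content are read from each chunk, so SimpleNamespace stand-ins replace real ollama chat chunks here; the summary fields simply serialize as null.

# hypothetical demo, run from the repository root
import asyncio
from types import SimpleNamespace

from src.formatter import StreamFormatter


async def fake_stream():
    yield SimpleNamespace(message=SimpleNamespace(thinking="checking the wiki...", content=None))
    yield SimpleNamespace(message=SimpleNamespace(thinking=None, content="The final boss is the Moon Lord."))


async def demo():
    async for piece in StreamFormatter().format(fake_stream()):
        print(repr(piece))


asyncio.run(demo())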

+ 48 - 0
src/inference.py

@@ -0,0 +1,48 @@
+from config import OLLAMA_API_KEY, OLLAMA_MODEL, OLLAMA_URL
+from ollama import AsyncClient, GenerateResponse
+
+
+class Inference:
+    def __init__(
+        self,
+        model: str = OLLAMA_MODEL,  # pyright: ignore[reportArgumentType]
+        host: str = OLLAMA_URL,  # pyright: ignore[reportArgumentType]
+        api_key: str = OLLAMA_API_KEY,  # pyright: ignore[reportArgumentType]
+    ):
+        self.model = model
+        self.client = AsyncClient(
+            host=host, headers={"Authorization": f"Bearer {api_key}"}
+        )
+
+    async def generate(self, message: str) -> GenerateResponse:
+        return await self.client.generate(
+            model=self.model, prompt=message, think=False, options={"temperature": 0.15}
+        )
+
+    async def stream(self, messages: list):
+        stream = await self.client.chat(
+            model=self.model,
+            messages=messages,
+            think="low",
+            stream=True,
+            options={"temperature": 0.15},
+        )
+
+        async for chunk in stream:
+            yield chunk
+
+    async def answer(self, text: str):
+        messages = [
+            {
+                "role": "system",
+                "content": (
+                    "Используй формат Markdown для ответов, но не злоупотребляй им. "
+                    "Применяй его когда нужно веделить важные моменты, заголовки, таблицы, списки и тд. "
+                    "В обычном тексте используй простые форматы, такие как жирный и курсив. "
+                ),
+            },
+            {"role": "user", "content": text},
+        ]
+
+        async for chunk in self.stream(messages):
+            yield chunk
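
A direct-use sketch for the non-streaming path (not part of this commit). It assumes OLLAMA_URL, OLLAMA_MODEL and OLLAMA_API_KEY are set in .env, that an Ollama server is reachable, and that it is run from the repository root so config resolves.

# hypothetical direct call to Inference.generate()
import asyncio

from src.inference import Inference


async def demo():
    result = await Inference().generate("Name three Terraria bosses.")
    print(result.response)  # GenerateResponse carries the generated text in .response


asyncio.run(demo())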

+ 108 - 0
src/main.py

@@ -0,0 +1,108 @@
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.prompts import (
+    ChatPromptTemplate,
+    MessagesPlaceholder,
+    PromptTemplate,
+)
+from langchain_ollama import OllamaLLM
+
+from chroma_manager import ChromaManager
+from config import OLLAMA_URL
+
+# `langfuse_handler` is referenced in ask() below but was not defined in this file;
+# the import is an assumption based on the unpinned `langfuse` dependency in
+# requirement.txt (v3 SDK; in the v2 SDK the handler lives at langfuse.callback).
+# The handler reads LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY / LANGFUSE_HOST from
+# the environment.
+from langfuse.langchain import CallbackHandler
+
+langfuse_handler = CallbackHandler()
+
+print("Инициализация LLM...")
+llm = OllamaLLM(
+    model="llama3.1:8b",
+    base_url=f"{OLLAMA_URL}",
+    temperature=0.15,
+    num_predict=1024,
+    reasoning=False,
+)
+
+print("Инициализация Chroma...")
+retriever = ChromaManager().retriever()
+
+contextualize_q_system_prompt = (
+    "Given a chat history and the latest user question "
+    "which might reference context in the chat history, "
+    "formulate a standalone question which can be understood "
+    "without the chat history. Do NOT answer the question, just "
+    "reformulate it if needed and otherwise return it as is."
+)
+
+contextualize_q_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", contextualize_q_system_prompt),
+        MessagesPlaceholder("chat_history"),
+        ("human", "{input}"),
+    ]
+)
+
+history_aware_retriever = create_history_aware_retriever(
+    llm, retriever, contextualize_q_prompt
+)
+
+qa_system_prompt = (
+    "You are an assistant for question-answering tasks. Use "
+    "the following pieces of retrieved context to answer the "
+    "question. If you don't know the answer, just say that you "
+    "don't know. Use three sentences maximum and keep the answer "
+    "concise."
+    "{context}"
+)
+
+qa_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", qa_system_prompt),
+        MessagesPlaceholder("chat_history"),
+        ("human", "{input}"),
+    ]
+)
+
+question_answer_chain = create_stuff_documents_chain(
+    llm, qa_prompt, document_prompt=PromptTemplate.from_template("{answer}")
+)
+rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
+chat_history = []
+
+
+def ask(question: str):
+    print("=" * 100)
+    print("Вопрос пользователя:", question)
+
+    result = {"question": question, "answer": ""}
+
+    print("=" * 100)
+    print("Ответ модели:")
+    for chunk in rag_chain.stream(
+        {"input": question, "chat_history": chat_history},
+        config={"callbacks": [langfuse_handler]},
+    ):
+        if "answer" in chunk:
+            print(chunk["answer"], end="", flush=True)
+            result["answer"] += chunk["answer"]
+    print()
+
+    chat_history.append(("human", result["question"]))
+    chat_history.append(("ai", result["answer"]))
+
+
+def main():
+    questions = [
+        # 'Какие есть боссы в Террарии?',
+        # 'Какой финальный босс?',
+        # 'И как его победить?',
+        # 'Какую броню на него использовать?',
+        "What bosses are there in Terraria?",
+        "What is the final boss?",
+        "And how to defeat it?",
+        "What armor should be used against it?",
+    ]
+
+    for question in questions:
+        ask(question)
+
+
+if __name__ == "__main__":
+    main()

+ 44 - 0
src/ollama.py

@@ -0,0 +1,44 @@
+from config import OLLAMA_API_KEY, OLLAMA_MODEL, OLLAMA_URL
+from ollama import AsyncClient, GenerateResponse
+
+client = AsyncClient(
+    host=f"{OLLAMA_URL}", headers={"Authorization": f"Bearer {OLLAMA_API_KEY}"}
+)
+# ollama.create(model='example', from_='gemma3', system="You are Mario from Super Mario Bros.")  # TODO: for an assistant persona?
+
+
+async def generate(message: str) -> GenerateResponse:
+    return await client.generate(
+        model=f"{OLLAMA_MODEL}",
+        prompt=message,
+        think=False,
+        options={"temperature": 0.15},
+    )
+
+
+async def request(messages: list):
+    stream = await client.chat(
+        model=f"{OLLAMA_MODEL}",
+        messages=messages,
+        think=True,
+        stream=True,
+        options={"temperature": 0.15},
+    )
+
+    async for chunk in stream:
+        print(chunk)
+        content = chunk["message"]["content"]
+        yield content
+
+
+async def answer(text: str):
+    messages = [
+        {
+            "role": "system",
+            "content": "Отвечай строго в формате Markdown. Не нужно пихать везде большие заголовки! Пиши как обычный человек, но красиво оформляй ответ",
+        },
+        {"role": "user", "content": text},
+    ]
+
+    async for chunk in request(messages):
+        yield chunk

+ 33 - 0
src/rag.py

@@ -0,0 +1,33 @@
+from langchain.agents import create_agent
+from langchain_ollama import ChatOllama
+
+llm = ChatOllama(
+    model="llama3.1",
+    base_url="http://localhost:11434",  # Ollama server URL
+    client_kwargs={  # for the synchronous client
+        "timeout": 30,
+        "headers": {"Authorization": "Bearer …"},
+    },
+    async_client_kwargs={  # for the asynchronous client
+        "timeout": 60,
+        "headers": {"Authorization": "Bearer …"},
+    },
+)
+
+
+def get_weather(city: str) -> str:
+    """Get weather for a given city."""
+    return f"It's always sunny in {city}!"
+
+
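+# note: the agent below uses an Anthropic model id rather than the ChatOllama
+# instance above, so running it needs the langchain[anthropic] extra from the
+# Pipfile and an ANTHROPIC_API_KEY in the environment.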
+agent = create_agent(
+    model="claude-sonnet-4-5-20250929",
+    tools=[get_weather],
+    system_prompt="You are a helpful assistant",
+)
+
+# Run the agent
+result = agent.invoke(
+    {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
+)
+print(result)

Some files were not shown because too many files changed in this diff