| """ |
| Routes for the Image Similarity Search API |
| Defines every endpoint of the application and registers them via register_routes(). |
| """ |
|
|
| import uuid |
| import base64 |
| import io |
| from typing import List, Optional |
| from fastapi import APIRouter, FastAPI, File, UploadFile, Form, Query, Path |
| from pydantic import BaseModel |
| from PIL import Image |
|
|
| from services.embedding_service import ImageEmbeddingModel |
| from services.vector_db_service import VectorDatabaseClient |
|
|
|
|
class Base64ImageRequest(BaseModel):
    """JSON request body carrying one base64-encoded image.

    ``image_data`` may be plain base64 text or a data URL of the form
    ``data:<mime>;base64,<payload>``.
    """

    # Base64 text of the image, optionally preceded by a data-URL header.
    image_data: str
|
|
|
|
# Number of matches returned by the search endpoints when the client does not
# pass ``limit``; 1 keeps the endpoints' original single-best-match behaviour.
DEFAULT_SEARCH_LIMIT = 1
# Upper bound on ``limit`` so a single request cannot ask for an unbounded result set.
MAX_SEARCH_LIMIT = 100


def _decode_base64_image(image_data: str) -> Image.Image:
    """Decode a base64 string or ``data:`` URL into an RGB PIL image.

    Args:
        image_data: Raw base64 text, or a data URL such as
            ``"data:image/png;base64,<payload>"``. Everything after the first
            comma is treated as the payload.

    Returns:
        An RGB image that does not depend on the decoded byte buffer.

    Raises:
        ValueError: If the payload is not valid base64 or does not decode to
            an image Pillow can read.
    """
    # Without a comma the whole string is the payload; with one, drop the
    # data-URL header. Splitting only once keeps everything after the header.
    encoded = image_data.split(",", 1)[-1]
    try:
        # Both bad padding (binascii.Error) and non-ASCII input raise ValueError.
        image_bytes = base64.b64decode(encoded)
    except ValueError as exc:
        raise ValueError("image_data is not valid base64") from exc

    try:
        with Image.open(io.BytesIO(image_bytes)) as img:
            # convert() returns a new image, so it remains usable after the
            # context manager closes the source image.
            return img.convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # Unrecognised formats (UnidentifiedImageError) and truncated data
        # surface as OSError subclasses.
        raise ValueError("image_data is not a decodable image") from exc


def register_routes(
    app: FastAPI,
    embedding_model: ImageEmbeddingModel,
    vector_db: VectorDatabaseClient,
) -> None:
    """Register all routes with the FastAPI app.

    Every handler is a closure over ``embedding_model`` and ``vector_db``, so
    the routes added here use exactly the service instances passed in.

    Args:
        app: Application the routes are attached to.
        embedding_model: Produces embeddings from uploaded files and PIL images.
        vector_db: Stores embeddings with their metadata and runs similarity
            searches over them.
    """
    # Local import: only the handlers below need it.
    from fastapi import HTTPException

    @app.api_route("/", methods=["GET", "HEAD"])
    async def read_root():
        """Report that the API is up (answers both GET and HEAD)."""
        return {"status": "API running"}

    @app.post("/add-image/")
    async def add_image(
        file: UploadFile = File(...),
        item_name: str = Form(...),
        design_name: str = Form(...),
        item_price: float = Form(...),
    ):
        """Upload an image with product details and store its embedding"""
        embedding = await embedding_model.get_embedding_from_upload(file)
        image_id = str(uuid.uuid4())

        # Product metadata stored alongside the vector (presumably returned
        # with search hits -- confirm against VectorDatabaseClient).
        payload = {
            "filename": file.filename,
            "item_name": item_name,
            "design_name": design_name,
            "item_price": item_price,
        }
        vector_db.add_image(image_id, embedding, payload)

        return {"message": "Image added successfully", "id": image_id}

    @app.post("/add-images-from-folder/")
    async def add_images_from_folder(folder_path: str):
        """Return the embeddings computed for the images in a server-side folder."""
        # NOTE(review): despite the route name, nothing is written to vector_db
        # here -- the embeddings are only returned to the caller. Confirm whether
        # each result should be stored via vector_db.add_image.
        # NOTE(review): folder_path is a client-supplied query parameter naming a
        # path on the server; restrict it to an allowed base directory before
        # exposing this endpoint to untrusted clients.
        embeddings = embedding_model.get_embeddings_from_folder(folder_path)
        return {"embeddings": embeddings}

    @app.post("/search-by-image/")
    async def search_by_image(
        file: UploadFile = File(...),
        limit: int = Query(
            DEFAULT_SEARCH_LIMIT,
            ge=1,
            le=MAX_SEARCH_LIMIT,
            description="Maximum number of matches to return",
        ),
    ):
        """Search for similar images by uploading a file"""
        embedding = await embedding_model.get_embedding_from_upload(file)
        return vector_db.search_by_vector(embedding, limit=limit)

    @app.post("/search-by-image-scan/")
    async def search_by_image_scan(
        request: Base64ImageRequest,
        limit: int = Query(
            DEFAULT_SEARCH_LIMIT,
            ge=1,
            le=MAX_SEARCH_LIMIT,
            description="Maximum number of matches to return",
        ),
    ):
        """Search for similar images using a base64 encoded image"""
        try:
            image = _decode_base64_image(request.image_data)
        except ValueError as exc:
            # Malformed client input is a 400, not an internal server error.
            raise HTTPException(status_code=400, detail=str(exc)) from exc

        # NOTE(review): decoding and get_embedding_from_pil run synchronously
        # inside this async handler and block the event loop while they run;
        # consider run_in_threadpool if they prove slow.
        embedding = embedding_model.get_embedding_from_pil(image)
        return vector_db.search_by_vector(embedding, limit=limit)

    @app.get("/collections")
    def list_collections():
        """List all available collections in the vector database"""
        return vector_db.list_collections()