from flask import Flask, request, jsonify, send_from_directory from flask_sqlalchemy import SQLAlchemy from flask_cors import CORS from transformers import AutoModel, AutoTokenizer import faiss import numpy as np import torch import pickle import re import onnxruntime as ort import base64 from PIL import Image import datetime import io import os # 初始化 Flask 应用 app = Flask(__name__) CORS(app) # 配置 SQLite 数据库 BASE_DIR = os.path.abspath(os.path.dirname(__file__)) app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///D:/STUDY/Project/jizhouyao/User/users.db' app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False print("Database URI:", app.config['SQLALCHEMY_DATABASE_URI']) # 初始化数据库 db = SQLAlchemy(app) print("Database URI:", app.config['SQLALCHEMY_DATABASE_URI']) #----------登录服务----------# # 定义用户模型 class User(db.Model): __tablename__ = 'users' id = db.Column(db.Integer, primary_key=True) username = db.Column(db.String(80), unique=True, nullable=False) password = db.Column(db.String(120), nullable=False) avatar = db.Column(db.String(200), nullable=True) # 登录接口 @app.route('/api/login', methods=['POST']) def login(): data = request.get_json() username = data.get('username') password = data.get('password') # 检查用户是否存在于数据库 user = User.query.filter_by(username=username, password=password).first() if user: # 拼接完整的 avatar URL avatar_url = f"http://{request.host}/{user.avatar}" if user.avatar else None # 返回用户的 id、username 和 avatar return jsonify({ "success": True, "message": "登录成功!", "data": { "id": user.id, "username": user.username, "avatar": avatar_url } }) else: return jsonify({"success": False, "message": "账号或密码错误!"}) @app.route('/avatar/') def get_avatar(filename): avatar_dir = os.path.join(BASE_DIR, 'User', 'avatar')# 头像存储目录 # 检查文件是否存在 file_path = os.path.join(avatar_dir, filename) if not os.path.exists(file_path): print(f"文件不存在:{file_path}") # 打印日志 return jsonify({"success": False, "message": "File not found"}), 404 # 提供头像文件 return send_from_directory(avatar_dir, filename) #----------RAG知识库查询服务----------# # 指定本地模型路径和资源路径 model_path = r"D:/STUDY/Project/jizhouyao/RAG/models/all-MiniLM-L6-v2"# 本地存储的预训练模型路径 resources_path = r"D:/STUDY/Project/jizhouyao/RAG/resources"# 文本文件存储路径 faiss_index_path = r"D:/STUDY/Project/jizhouyao/RAG/index/vector_index.faiss"# FAISS 索引的保存路径 texts_pickle_path = r"D:/STUDY/Project/jizhouyao/RAG/index/texts.pkl"# 文本数据保存路径 # 检查路径是否存在 if not os.path.exists(model_path): raise FileNotFoundError(f"模型路径不存在: {model_path}") if not os.path.exists(resources_path): raise FileNotFoundError(f"资源文件夹不存在: {resources_path}") # 加载模型和分词器 print("加载模型和分词器...") tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True) model = AutoModel.from_pretrained(model_path, local_files_only=True) # 加载 FAISS 索引和文本内容 print("加载 FAISS 索引...") index = faiss.read_index(faiss_index_path) print("加载文本内容...") with open(texts_pickle_path, 'rb') as f: all_sentences = pickle.load(f) # 拆分文件内容为句子列表 def split_content_to_sentences(all_sentences): flat_sentences = [] for entry in all_sentences: file_name = entry["name"] content = entry["content"] sentences = re.split(r'[。!?;::\n]', content) # 基于标点符号分割 sentences = [sentence.strip() for sentence in sentences if sentence.strip()] # 去除空白句子 # 将句子和元数据存入 flat_sentences for sentence in sentences: flat_sentences.append({ "file": file_name, "sentence": sentence }) return flat_sentences flat_sentences = split_content_to_sentences(all_sentences) # 查询函数 def query_faiss(query, top_k): """ 在 FAISS 索引中查询与输入 query 最相关的句子 """ # 将 query 转换为嵌入向量 inputs = tokenizer([query], return_tensors="pt", padding=True, truncation=True, max_length=512) with torch.no_grad(): outputs = model(**inputs) # 提取 CLS 向量 query_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy() query_embedding = query_embedding.reshape(1, -1).astype(np.float32) faiss.normalize_L2(query_embedding) # 查询 FAISS 索引,选取 top_k * 3 条候选句子 distances, indices = index.search(query_embedding, top_k * 3) # 收集有效结果 results = [] valid_count = 0 # 用于跟踪已收集到的有效结果数量 for i, idx in enumerate(indices[0]): if valid_count >= top_k: # 如果已经收集到足够的有效结果,则停止 break if idx != -1 and idx < len(flat_sentences): # 确保索引有效 sentence_data = flat_sentences[idx] # 从 flat_sentences 获取句子数据 sentence_text = sentence_data["sentence"] # 获取句子内容 # 过滤掉字数小于 15 的句子 if len(sentence_text) >= 15: results.append({ "rank": valid_count + 1, # 有效结果的排名 "file": sentence_data["file"], # 文件名 "sentence": sentence_text, # 具体句子 "score": float(distances[0][i]) # 匹配分数 }) valid_count += 1 else: continue else: print(f"跳过无效索引: {idx}") # 打印无效索引 print(f"返回的查询结果: {results}") return results # 查询路由 @app.route('/query', methods=['POST']) def query_endpoint(): """ 接收前端发送的用户问题,在知识库中查询,并将查询结果返回给前端 """ try: # 获取 POST 请求中的 JSON 数据 data = request.get_json() query_text = data.get("query", "") # 获取用户问题 if not query_text: return jsonify({"error": "No query provided"}), 400 # 检查是否提供了 query # 在知识库中查询 results = query_faiss(query_text, top_k=20) # 返回查询结果 return jsonify({"query": query_text, "results": results}) except Exception as e: return jsonify({"error": f"服务内部错误: {str(e)}"}), 500 #----------陶瓷识别服务----------# # 加载 ONNX 模型 model_path = r"D:/STUDY/Project/jizhouyao/image-recognition/index/mobilenet_model.onnx" try: # 使用 ONNX Runtime 加载模型 ort_session = ort.InferenceSession(model_path) print("ONNX 模型加载成功!") except Exception as e: print(f"ONNX 模型加载失败: {e}") exit(1) # 类别标签 class_labels = ["兔毫盏", "凤纹盏", "剪纸贴花盏", "木叶天目盏","梅瓶", "梅纹盏", "玳瑁釉盏", "鹧鸪斑盏", "黑釉盏"] # 图像预处理函数 def preprocess_image(image: Image.Image, input_shape: tuple): """ 将 PIL 图像处理为 ONNX 模型输入格式 :param image: PIL 图像 :param input_shape: 模型期望的输入形状 (N, C, H, W) :return: 预处理后的图像 numpy 数组 """ height, width = input_shape[2], input_shape[3] # 模型输入的高度和宽度 image = image.resize((width, height)) # 调整图片大小 image = np.array(image).astype('float32') / 255.0 # 归一化 image = np.transpose(image, (2, 0, 1)) # HWC -> CHW image = np.expand_dims(image, axis=0) # 添加 batch 维度 return image # Softmax 函数 def softmax(logits): exp_values = np.exp(logits - np.max(logits)) # 稳定的 softmax 避免溢出 probabilities = exp_values / np.sum(exp_values) return probabilities # 预测路由 @app.route('/predict', methods=['POST']) def predict_route(): try: # 获取前端传来的 Base64 图片数据 input_data = request.json.get("data") if not input_data: return jsonify({"error": "No input data provided"}), 400 # 去掉 Base64 数据前缀并解码 base64_image = input_data.split(",")[1] image_data = base64.b64decode(base64_image) # 使用 PIL 打开图片 image = Image.open(io.BytesIO(image_data)).convert("RGB") # 获取模型输入形状并预处理 input_name = ort_session.get_inputs()[0].name input_shape = ort_session.get_inputs()[0].shape # 通常是 (N, C, H, W) preprocessed_image = preprocess_image(image, input_shape) # 打印预处理后图像的形状 # 使用 ONNX 模型进行推理 outputs = ort_session.run(None, {input_name: preprocessed_image}) # 打印模型输出的所有结果 # 模型输出 logits logits = outputs[0][0] # 计算 softmax 概率分布 probabilities = softmax(logits) # 获取预测类别及置信度 predicted_class = int(np.argmax(probabilities)) # 预测类别索引 confidence = float(np.max(probabilities)) # 置信度 predicted_label = class_labels[predicted_class] # 预测类别名称 # 组合类别和对应概率 class_probabilities = [ {"class": class_labels[i], "probability": float(probabilities[i])} for i in range(len(class_labels)) ] # 返回预测结果 prediction = { "predicted_class": predicted_class, "predicted_label": predicted_label, "confidence": confidence, "class_probabilities": class_probabilities } print(jsonify({"prediction": prediction})) return jsonify({"prediction": prediction}) except Exception as e: # 捕获异常并记录日志 print(f"Error during prediction: {e}") return jsonify({"error": str(e)}), 500 #----------保存陶瓷识别记录服务----------# # 定义检测记录的数据库模型 class CeramicDetectionRecord(db.Model): __tablename__ = 'ceramic_detection_record' # 新建表的名称 id = db.Column(db.Integer, primary_key=True) # 主键 nickname = db.Column(db.String(80), nullable=False) # 用户昵称 detection_time = db.Column(db.DateTime, nullable=False) # 检测时间 result = db.Column(db.String(120), nullable=False) # 检测结果 confidence = db.Column(db.String(10), nullable=False) # 置信度 # 保存检测记录路由 @app.route('/save_detection', methods=['POST']) def save_detection(): try: # 从请求中获取 JSON 数据 data = request.get_json() # 提取记录数据 nickname = data.get('nickname') detection_time = data.get('detection_time') result = data.get('result') confidence = data.get('confidence') # 检查必填字段 if not all([nickname, detection_time, result, confidence]): return jsonify({"status": "error", "message": "缺少必要字段"}), 400 # 替换 'Z' 为 '+00:00' 以支持 ISO 格式 if detection_time.endswith("Z"): detection_time = detection_time.replace("Z", "+00:00") # 转换检测时间为 datetime 对象 try: detection_time = datetime.datetime.fromisoformat(detection_time) except ValueError as ve: print("检测时间格式错误:", detection_time) # 打印时间格式错误 return jsonify({"status": "error", "message": "检测时间格式错误"}), 400 # 创建记录对象 record = CeramicDetectionRecord( nickname=nickname, detection_time=detection_time, result=result, confidence=confidence, ) # 保存到数据库 db.session.add(record) db.session.commit() return jsonify({"status": "success", "message": "识别记录保存成功"}) except Exception as e: return jsonify({"status": "error", "message": f"发生错误: {str(e)}"}), 500 #----------渲染陶瓷识别记录服务----------# # 按 nickname 查询记录的路由 @app.route('/get_detection_records', methods=['GET']) def get_detection_records_by_nickname(): try: # 从请求中获取 nickname 参数 nickname = request.args.get('nickname') # 检查 nickname 是否提供 if not nickname: return jsonify({"status": "error", "message": "缺少参数: nickname"}), 400 # 从数据库中查询对应的记录 records = CeramicDetectionRecord.query.filter_by(nickname=nickname).all() # 格式化查询结果 results = [ { "id": record.id, "result": record.result, "confidence": record.confidence, } for record in records ] # 返回查询结果 return jsonify({"status": "success", "data": results}) except Exception as e: # 捕获错误并返回 return jsonify({"status": "error", "message": f"发生错误: {str(e)}"}), 500 # 根路由 @app.route('/') def home(): return "Welcome to the API!" # 屏蔽 favicon.ico 请求 @app.route('/favicon.ico') def favicon(): return "", 204 # 主函数 if __name__ == '__main__': app.run(debug=True, host='0.0.0.0', port=5000)