server.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. from flask import Flask, request, jsonify, send_from_directory
  2. from flask_sqlalchemy import SQLAlchemy
  3. from flask_cors import CORS
  4. from transformers import AutoModel, AutoTokenizer
  5. import faiss
  6. import numpy as np
  7. import torch
  8. import pickle
  9. import re
  10. import onnxruntime as ort
  11. import base64
  12. from PIL import Image
  13. import datetime
  14. import io
  15. import os
  16. # 初始化 Flask 应用
  17. app = Flask(__name__)
  18. CORS(app)
  19. # 配置 SQLite 数据库
  20. BASE_DIR = os.path.abspath(os.path.dirname(__file__))
  21. app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///D:/STUDY/Project/jizhouyao/User/users.db'
  22. app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
  23. print("Database URI:", app.config['SQLALCHEMY_DATABASE_URI'])
  24. # 初始化数据库
  25. db = SQLAlchemy(app)
  26. print("Database URI:", app.config['SQLALCHEMY_DATABASE_URI'])
  27. #----------登录服务----------#
  28. # 定义用户模型
  29. class User(db.Model):
  30. __tablename__ = 'users'
  31. id = db.Column(db.Integer, primary_key=True)
  32. username = db.Column(db.String(80), unique=True, nullable=False)
  33. password = db.Column(db.String(120), nullable=False)
  34. avatar = db.Column(db.String(200), nullable=True)
  35. # 登录接口
  36. @app.route('/api/login', methods=['POST'])
  37. def login():
  38. data = request.get_json()
  39. username = data.get('username')
  40. password = data.get('password')
  41. # 检查用户是否存在于数据库
  42. user = User.query.filter_by(username=username, password=password).first()
  43. if user:
  44. # 拼接完整的 avatar URL
  45. avatar_url = f"http://{request.host}/{user.avatar}" if user.avatar else None
  46. # 返回用户的 id、username 和 avatar
  47. return jsonify({
  48. "success": True,
  49. "message": "登录成功!",
  50. "data": {
  51. "id": user.id,
  52. "username": user.username,
  53. "avatar": avatar_url
  54. }
  55. })
  56. else:
  57. return jsonify({"success": False, "message": "账号或密码错误!"})
  58. @app.route('/avatar/<filename>')
  59. def get_avatar(filename):
  60. avatar_dir = os.path.join(BASE_DIR, 'User', 'avatar')# 头像存储目录
  61. # 检查文件是否存在
  62. file_path = os.path.join(avatar_dir, filename)
  63. if not os.path.exists(file_path):
  64. print(f"文件不存在:{file_path}") # 打印日志
  65. return jsonify({"success": False, "message": "File not found"}), 404
  66. # 提供头像文件
  67. return send_from_directory(avatar_dir, filename)
  68. #----------RAG知识库查询服务----------#
  69. # 指定本地模型路径和资源路径
  70. model_path = r"D:/STUDY/Project/jizhouyao/RAG/models/all-MiniLM-L6-v2"# 本地存储的预训练模型路径
  71. resources_path = r"D:/STUDY/Project/jizhouyao/RAG/resources"# 文本文件存储路径
  72. faiss_index_path = r"D:/STUDY/Project/jizhouyao/RAG/index/vector_index.faiss"# FAISS 索引的保存路径
  73. texts_pickle_path = r"D:/STUDY/Project/jizhouyao/RAG/index/texts.pkl"# 文本数据保存路径
  74. # 检查路径是否存在
  75. if not os.path.exists(model_path):
  76. raise FileNotFoundError(f"模型路径不存在: {model_path}")
  77. if not os.path.exists(resources_path):
  78. raise FileNotFoundError(f"资源文件夹不存在: {resources_path}")
  79. # 加载模型和分词器
  80. print("加载模型和分词器...")
  81. tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
  82. model = AutoModel.from_pretrained(model_path, local_files_only=True)
  83. # 加载 FAISS 索引和文本内容
  84. print("加载 FAISS 索引...")
  85. index = faiss.read_index(faiss_index_path)
  86. print("加载文本内容...")
  87. with open(texts_pickle_path, 'rb') as f:
  88. all_sentences = pickle.load(f)
  89. # 拆分文件内容为句子列表
  90. def split_content_to_sentences(all_sentences):
  91. flat_sentences = []
  92. for entry in all_sentences:
  93. file_name = entry["name"]
  94. content = entry["content"]
  95. sentences = re.split(r'[。!?;::\n]', content) # 基于标点符号分割
  96. sentences = [sentence.strip() for sentence in sentences if sentence.strip()] # 去除空白句子
  97. # 将句子和元数据存入 flat_sentences
  98. for sentence in sentences:
  99. flat_sentences.append({
  100. "file": file_name,
  101. "sentence": sentence
  102. })
  103. return flat_sentences
  104. flat_sentences = split_content_to_sentences(all_sentences)
  105. # 查询函数
  106. def query_faiss(query, top_k):
  107. """
  108. 在 FAISS 索引中查询与输入 query 最相关的句子
  109. """
  110. # 将 query 转换为嵌入向量
  111. inputs = tokenizer([query], return_tensors="pt", padding=True, truncation=True, max_length=512)
  112. with torch.no_grad():
  113. outputs = model(**inputs)
  114. # 提取 CLS 向量
  115. query_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
  116. query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
  117. faiss.normalize_L2(query_embedding)
  118. # 查询 FAISS 索引,选取 top_k * 3 条候选句子
  119. distances, indices = index.search(query_embedding, top_k * 3)
  120. # 收集有效结果
  121. results = []
  122. valid_count = 0 # 用于跟踪已收集到的有效结果数量
  123. for i, idx in enumerate(indices[0]):
  124. if valid_count >= top_k: # 如果已经收集到足够的有效结果,则停止
  125. break
  126. if idx != -1 and idx < len(flat_sentences): # 确保索引有效
  127. sentence_data = flat_sentences[idx] # 从 flat_sentences 获取句子数据
  128. sentence_text = sentence_data["sentence"] # 获取句子内容
  129. # 过滤掉字数小于 15 的句子
  130. if len(sentence_text) >= 15:
  131. results.append({
  132. "rank": valid_count + 1, # 有效结果的排名
  133. "file": sentence_data["file"], # 文件名
  134. "sentence": sentence_text, # 具体句子
  135. "score": float(distances[0][i]) # 匹配分数
  136. })
  137. valid_count += 1
  138. else:
  139. continue
  140. else:
  141. print(f"跳过无效索引: {idx}") # 打印无效索引
  142. print(f"返回的查询结果: {results}")
  143. return results
  144. # 查询路由
  145. @app.route('/query', methods=['POST'])
  146. def query_endpoint():
  147. """
  148. 接收前端发送的用户问题,在知识库中查询,并将查询结果返回给前端
  149. """
  150. try:
  151. # 获取 POST 请求中的 JSON 数据
  152. data = request.get_json()
  153. query_text = data.get("query", "") # 获取用户问题
  154. if not query_text:
  155. return jsonify({"error": "No query provided"}), 400 # 检查是否提供了 query
  156. # 在知识库中查询
  157. results = query_faiss(query_text, top_k=20)
  158. # 返回查询结果
  159. return jsonify({"query": query_text, "results": results})
  160. except Exception as e:
  161. return jsonify({"error": f"服务内部错误: {str(e)}"}), 500
  162. #----------陶瓷识别服务----------#
  163. # 加载 ONNX 模型
  164. model_path = r"D:/STUDY/Project/jizhouyao/image-recognition/index/mobilenet_model.onnx"
  165. try:
  166. # 使用 ONNX Runtime 加载模型
  167. ort_session = ort.InferenceSession(model_path)
  168. print("ONNX 模型加载成功!")
  169. except Exception as e:
  170. print(f"ONNX 模型加载失败: {e}")
  171. exit(1)
  172. # 类别标签
  173. class_labels = ["兔毫盏", "凤纹盏", "剪纸贴花盏", "木叶天目盏","梅瓶", "梅纹盏", "玳瑁釉盏", "鹧鸪斑盏", "黑釉盏"]
  174. # 图像预处理函数
  175. def preprocess_image(image: Image.Image, input_shape: tuple):
  176. """
  177. 将 PIL 图像处理为 ONNX 模型输入格式
  178. :param image: PIL 图像
  179. :param input_shape: 模型期望的输入形状 (N, C, H, W)
  180. :return: 预处理后的图像 numpy 数组
  181. """
  182. height, width = input_shape[2], input_shape[3] # 模型输入的高度和宽度
  183. image = image.resize((width, height)) # 调整图片大小
  184. image = np.array(image).astype('float32') / 255.0 # 归一化
  185. image = np.transpose(image, (2, 0, 1)) # HWC -> CHW
  186. image = np.expand_dims(image, axis=0) # 添加 batch 维度
  187. return image
  188. # Softmax 函数
  189. def softmax(logits):
  190. exp_values = np.exp(logits - np.max(logits)) # 稳定的 softmax 避免溢出
  191. probabilities = exp_values / np.sum(exp_values)
  192. return probabilities
  193. # 预测路由
  194. @app.route('/predict', methods=['POST'])
  195. def predict_route():
  196. try:
  197. # 获取前端传来的 Base64 图片数据
  198. input_data = request.json.get("data")
  199. if not input_data:
  200. return jsonify({"error": "No input data provided"}), 400
  201. # 去掉 Base64 数据前缀并解码
  202. base64_image = input_data.split(",")[1]
  203. image_data = base64.b64decode(base64_image)
  204. # 使用 PIL 打开图片
  205. image = Image.open(io.BytesIO(image_data)).convert("RGB")
  206. # 获取模型输入形状并预处理
  207. input_name = ort_session.get_inputs()[0].name
  208. input_shape = ort_session.get_inputs()[0].shape # 通常是 (N, C, H, W)
  209. preprocessed_image = preprocess_image(image, input_shape)
  210. # 打印预处理后图像的形状
  211. # 使用 ONNX 模型进行推理
  212. outputs = ort_session.run(None, {input_name: preprocessed_image})
  213. # 打印模型输出的所有结果
  214. # 模型输出 logits
  215. logits = outputs[0][0]
  216. # 计算 softmax 概率分布
  217. probabilities = softmax(logits)
  218. # 获取预测类别及置信度
  219. predicted_class = int(np.argmax(probabilities)) # 预测类别索引
  220. confidence = float(np.max(probabilities)) # 置信度
  221. predicted_label = class_labels[predicted_class] # 预测类别名称
  222. # 组合类别和对应概率
  223. class_probabilities = [
  224. {"class": class_labels[i], "probability": float(probabilities[i])}
  225. for i in range(len(class_labels))
  226. ]
  227. # 返回预测结果
  228. prediction = {
  229. "predicted_class": predicted_class,
  230. "predicted_label": predicted_label,
  231. "confidence": confidence,
  232. "class_probabilities": class_probabilities
  233. }
  234. print(jsonify({"prediction": prediction}))
  235. return jsonify({"prediction": prediction})
  236. except Exception as e:
  237. # 捕获异常并记录日志
  238. print(f"Error during prediction: {e}")
  239. return jsonify({"error": str(e)}), 500
  240. #----------保存陶瓷识别记录服务----------#
  241. # 定义检测记录的数据库模型
  242. class CeramicDetectionRecord(db.Model):
  243. __tablename__ = 'ceramic_detection_record' # 新建表的名称
  244. id = db.Column(db.Integer, primary_key=True) # 主键
  245. nickname = db.Column(db.String(80), nullable=False) # 用户昵称
  246. detection_time = db.Column(db.DateTime, nullable=False) # 检测时间
  247. result = db.Column(db.String(120), nullable=False) # 检测结果
  248. confidence = db.Column(db.String(10), nullable=False) # 置信度
  249. # 保存检测记录路由
  250. @app.route('/save_detection', methods=['POST'])
  251. def save_detection():
  252. try:
  253. # 从请求中获取 JSON 数据
  254. data = request.get_json()
  255. # 提取记录数据
  256. nickname = data.get('nickname')
  257. detection_time = data.get('detection_time')
  258. result = data.get('result')
  259. confidence = data.get('confidence')
  260. # 检查必填字段
  261. if not all([nickname, detection_time, result, confidence]):
  262. return jsonify({"status": "error", "message": "缺少必要字段"}), 400
  263. # 替换 'Z' 为 '+00:00' 以支持 ISO 格式
  264. if detection_time.endswith("Z"):
  265. detection_time = detection_time.replace("Z", "+00:00")
  266. # 转换检测时间为 datetime 对象
  267. try:
  268. detection_time = datetime.datetime.fromisoformat(detection_time)
  269. except ValueError as ve:
  270. print("检测时间格式错误:", detection_time) # 打印时间格式错误
  271. return jsonify({"status": "error", "message": "检测时间格式错误"}), 400
  272. # 创建记录对象
  273. record = CeramicDetectionRecord(
  274. nickname=nickname,
  275. detection_time=detection_time,
  276. result=result,
  277. confidence=confidence,
  278. )
  279. # 保存到数据库
  280. db.session.add(record)
  281. db.session.commit()
  282. return jsonify({"status": "success", "message": "识别记录保存成功"})
  283. except Exception as e:
  284. return jsonify({"status": "error", "message": f"发生错误: {str(e)}"}), 500
  285. #----------渲染陶瓷识别记录服务----------#
  286. # 按 nickname 查询记录的路由
  287. @app.route('/get_detection_records', methods=['GET'])
  288. def get_detection_records_by_nickname():
  289. try:
  290. # 从请求中获取 nickname 参数
  291. nickname = request.args.get('nickname')
  292. # 检查 nickname 是否提供
  293. if not nickname:
  294. return jsonify({"status": "error", "message": "缺少参数: nickname"}), 400
  295. # 从数据库中查询对应的记录
  296. records = CeramicDetectionRecord.query.filter_by(nickname=nickname).all()
  297. # 格式化查询结果
  298. results = [
  299. {
  300. "id": record.id,
  301. "result": record.result,
  302. "confidence": record.confidence,
  303. }
  304. for record in records
  305. ]
  306. # 返回查询结果
  307. return jsonify({"status": "success", "data": results})
  308. except Exception as e:
  309. # 捕获错误并返回
  310. return jsonify({"status": "error", "message": f"发生错误: {str(e)}"}), 500
  311. # 根路由
  312. @app.route('/')
  313. def home():
  314. return "Welcome to the API!"
  315. # 屏蔽 favicon.ico 请求
  316. @app.route('/favicon.ico')
  317. def favicon():
  318. return "", 204
  319. # 主函数
  320. if __name__ == '__main__':
  321. app.run(debug=True, host='0.0.0.0', port=5000)