# Flask backend exposing login, avatar, RAG knowledge-base query, ceramic recognition and detection-record endpoints.
import base64
import datetime
import hmac
import io
import os
import pickle
import re

import faiss
import numpy as np
import onnxruntime as ort
import torch
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from flask_sqlalchemy import SQLAlchemy
from PIL import Image
from transformers import AutoModel, AutoTokenizer
# Create the Flask application and allow cross-origin requests.
app = Flask(__name__)
CORS(app)

# Directory containing this file; used elsewhere to locate bundled assets.
BASE_DIR = os.path.abspath(os.path.dirname(__file__))

# Configure the SQLite database.
# NOTE(review): the URI is an absolute, machine-specific path — confirm it
# matches the host this service is deployed on.
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///D:/STUDY/Project/jizhouyao/User/users.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
print("Database URI:", app.config['SQLALCHEMY_DATABASE_URI'])

# Bind SQLAlchemy to the application.
db = SQLAlchemy(app)
# ---------- Login service ---------- #
class User(db.Model):
    """Account row stored in the ``users`` table."""

    __tablename__ = 'users'

    # Integer primary key.
    id = db.Column(db.Integer, primary_key=True)
    # Login name; required and unique.
    username = db.Column(db.String(80), unique=True, nullable=False)
    # NOTE(review): the login endpoint compares this column with the submitted
    # password verbatim, so it appears to hold plaintext — consider hashing.
    password = db.Column(db.String(120), nullable=False)
    # Optional avatar path; the login endpoint appends it to the request host.
    avatar = db.Column(db.String(200), nullable=True)
# Login endpoint
@app.route('/api/login', methods=['POST'])
def login():
    """Authenticate a user from a JSON body holding ``username`` and ``password``.

    Returns:
        A JSON response. On success it carries ``success: True`` plus a
        ``data`` object with the user's ``id``, ``username`` and absolute
        ``avatar`` URL (``None`` when no avatar is stored). Wrong or missing
        credentials yield ``success: False`` with HTTP 200. A body that is
        not a JSON object yields ``success: False`` with HTTP 400.
    """
    # silent=True returns None for a missing/malformed body instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"success": False, "message": "请求格式错误!"}), 400

    username = data.get('username')
    password = data.get('password')

    user = None
    if isinstance(username, str) and isinstance(password, str):
        # username is declared unique, so at most one row can match.
        candidate = User.query.filter_by(username=username).first()
        # Constant-time comparison; bytes are used because compare_digest
        # only accepts ASCII-only str arguments.
        if candidate is not None and hmac.compare_digest(
            candidate.password.encode('utf-8'), password.encode('utf-8')
        ):
            user = candidate

    if user is None:
        return jsonify({"success": False, "message": "账号或密码错误!"})

    # Build the absolute avatar URL from the host the client used.
    avatar_url = f"http://{request.host}/{user.avatar}" if user.avatar else None
    return jsonify({
        "success": True,
        "message": "登录成功!",
        "data": {
            "id": user.id,
            "username": user.username,
            "avatar": avatar_url,
        },
    })
@app.route('/avatar/<filename>')
def get_avatar(filename):
    """Serve an avatar image from ``<BASE_DIR>/User/avatar``.

    Args:
        filename: File name taken from the URL path.

    Returns:
        The file response, or a JSON payload with HTTP 404 when the name does
        not refer to a regular file directly inside the avatar directory.
    """
    avatar_dir = os.path.join(BASE_DIR, 'User', 'avatar')  # avatar storage directory
    file_path = os.path.join(avatar_dir, filename)

    # Names carrying path components are rejected so the existence probe
    # cannot look outside avatar_dir; isfile() (rather than exists()) keeps a
    # directory from being treated as a hit.
    if os.path.basename(filename) != filename or not os.path.isfile(file_path):
        print(f"文件不存在:{file_path}")  # log the miss
        return jsonify({"success": False, "message": "File not found"}), 404

    return send_from_directory(avatar_dir, filename)
# ---------- RAG knowledge-base query service ---------- #
# Local model and resource locations.
model_path = r"D:/STUDY/Project/jizhouyao/RAG/models/all-MiniLM-L6-v2"  # local pretrained model
resources_path = r"D:/STUDY/Project/jizhouyao/RAG/resources"  # folder holding the text files
faiss_index_path = r"D:/STUDY/Project/jizhouyao/RAG/index/vector_index.faiss"  # saved FAISS index
texts_pickle_path = r"D:/STUDY/Project/jizhouyao/RAG/index/texts.pkl"  # saved text data

# Fail fast, with a uniform error, if any required resource is missing.
if not os.path.exists(model_path):
    raise FileNotFoundError(f"模型路径不存在: {model_path}")
if not os.path.exists(resources_path):
    raise FileNotFoundError(f"资源文件夹不存在: {resources_path}")
if not os.path.isfile(faiss_index_path):
    raise FileNotFoundError(f"FAISS 索引文件不存在: {faiss_index_path}")
if not os.path.isfile(texts_pickle_path):
    raise FileNotFoundError(f"文本数据文件不存在: {texts_pickle_path}")

# Load the tokenizer and model strictly from disk (no network access).
print("加载模型和分词器...")
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
model = AutoModel.from_pretrained(model_path, local_files_only=True)

# Load the FAISS index and the text records.
print("加载 FAISS 索引...")
index = faiss.read_index(faiss_index_path)
print("加载文本内容...")
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load a texts.pkl produced by a trusted build step.
with open(texts_pickle_path, 'rb') as f:
    all_sentences = pickle.load(f)
# Sentence delimiters. Keep this pattern unchanged: FAISS result ids are used
# as positions in the list built by split_content_to_sentences, so a different
# split would shift every position.
_SENTENCE_DELIMITERS = re.compile(r'[。!?;::\n]')


# Split file contents into a flat list of sentences
def split_content_to_sentences(all_sentences):
    """Flatten file records into one list of sentence records.

    Args:
        all_sentences: Iterable of dicts with the keys ``"name"`` (file name)
            and ``"content"`` (file text).

    Returns:
        A list of ``{"file": name, "sentence": text}`` dicts in input order.
        Fragments that are empty after stripping are dropped; entries whose
        content is ``None`` or empty contribute nothing.
    """
    flat_sentences = []
    for entry in all_sentences:
        file_name = entry["name"]
        content = entry["content"] or ""  # tolerate None content
        flat_sentences.extend(
            {"file": file_name, "sentence": fragment.strip()}
            for fragment in _SENTENCE_DELIMITERS.split(content)
            if fragment.strip()
        )
    return flat_sentences
# Flattened sentence list; query_faiss() looks FAISS result ids up in it.
flat_sentences = split_content_to_sentences(all_sentences)


# Query function
def query_faiss(query, top_k):
    """Return up to ``top_k`` knowledge-base sentences most related to ``query``.

    Args:
        query: The user's question text.
        top_k: Maximum number of results; values below 1 give an empty list.

    Returns:
        A list of dicts with ``rank`` (1-based), ``file``, ``sentence`` and
        ``score`` (the value FAISS returned for that hit), in FAISS result
        order. Sentences shorter than 15 characters are skipped.
    """
    if top_k <= 0:
        return []

    # Embed the query.
    inputs = tokenizer([query], return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)

    # Use the first-token (CLS) vector as the embedding, L2-normalised.
    query_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
    query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
    faiss.normalize_L2(query_embedding)

    # Over-fetch 3x candidates because short sentences are filtered out below.
    distances, indices = index.search(query_embedding, top_k * 3)

    results = []
    for score, idx in zip(distances[0], indices[0]):
        if len(results) >= top_k:  # enough valid hits collected
            break
        # Ids outside the list (including the -1 padding id) are skipped.
        if not 0 <= idx < len(flat_sentences):
            print(f"跳过无效索引: {idx}")  # log the invalid id
            continue

        sentence_data = flat_sentences[idx]
        sentence_text = sentence_data["sentence"]
        # Drop sentences shorter than 15 characters.
        if len(sentence_text) < 15:
            continue

        results.append({
            "rank": len(results) + 1,        # rank among valid hits
            "file": sentence_data["file"],   # source file name
            "sentence": sentence_text,       # matched sentence
            "score": float(score),           # match score
        })

    print(f"返回的查询结果: {results}")
    return results
# Query route
@app.route('/query', methods=['POST'])
def query_endpoint():
    """Look up the user's question in the knowledge base.

    Expects a JSON body with a non-empty string ``query``.

    Returns:
        ``{"query": ..., "results": [...]}`` on success; ``{"error": ...}``
        with HTTP 400 when no usable query is supplied, or HTTP 500 when the
        lookup itself fails.
    """
    # silent=True turns a missing/malformed body into None instead of raising.
    data = request.get_json(silent=True)
    query_text = data.get("query", "") if isinstance(data, dict) else ""
    if not query_text or not isinstance(query_text, str):
        return jsonify({"error": "No query provided"}), 400

    try:
        results = query_faiss(query_text, top_k=20)
    except Exception as e:
        # Top-level boundary: report the failure to the client as JSON.
        return jsonify({"error": f"服务内部错误: {str(e)}"}), 500

    return jsonify({"query": query_text, "results": results})
# ---------- Ceramic recognition service ---------- #
# Path of the ONNX classifier.
# NOTE: this rebinds the module-level name ``model_path`` that was used
# earlier in the file for the embedding-model directory.
model_path = r"D:/STUDY/Project/jizhouyao/image-recognition/index/mobilenet_model.onnx"
try:
    # Load the model with ONNX Runtime.
    ort_session = ort.InferenceSession(model_path)
    print("ONNX 模型加载成功!")
except Exception as e:
    print(f"ONNX 模型加载失败: {e}")
    # Raise SystemExit directly: the exit() helper is injected by the site
    # module and is not guaranteed to exist in every interpreter setup.
    raise SystemExit(1) from e

# Class labels, in the order of the model's output positions.
class_labels = ["兔毫盏", "凤纹盏", "剪纸贴花盏", "木叶天目盏", "梅瓶", "梅纹盏", "玳瑁釉盏", "鹧鸪斑盏", "黑釉盏"]
# Image preprocessing
def preprocess_image(image: Image.Image, input_shape: tuple):
    """Convert a PIL image into the array layout the ONNX model takes.

    Args:
        image: PIL image; it is converted to RGB so the result always has
            three channels.
        input_shape: Model input shape, read as ``(N, C, H, W)``.

    Returns:
        A float32 array of shape ``(1, 3, H, W)`` with values scaled to
        ``[0, 1]``.

    Raises:
        ValueError: If the height/width entries of ``input_shape`` are not
            integers (e.g. symbolic or dynamic ONNX dimensions).
    """
    try:
        height, width = int(input_shape[2]), int(input_shape[3])
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Model input shape has no static height/width: {input_shape}"
        ) from exc

    resized = image.convert("RGB").resize((width, height))
    pixels = np.array(resized).astype('float32') / 255.0  # scale to [0, 1]
    pixels = np.transpose(pixels, (2, 0, 1))               # HWC -> CHW
    return np.expand_dims(pixels, axis=0)                  # add batch axis
# Softmax
def softmax(logits, axis=None):
    """Return the softmax of ``logits``.

    Args:
        logits: Array-like of scores.
        axis: Axis along which to normalise. ``None`` (the default)
            normalises over all elements, so the whole result sums to 1.

    Returns:
        An array shaped like ``logits`` holding the probabilities.
    """
    values = np.asarray(logits)
    # Subtracting the maximum keeps exp() from overflowing; the result is unchanged.
    exp_values = np.exp(values - np.max(values, axis=axis, keepdims=True))
    return exp_values / np.sum(exp_values, axis=axis, keepdims=True)
# Prediction route
@app.route('/predict', methods=['POST'])
def predict_route():
    """Classify a Base64-encoded image sent by the front end.

    Expects a JSON body whose ``data`` field is either a data URL
    (``data:image/...;base64,<payload>``) or bare Base64 text.

    Returns:
        ``{"prediction": {...}}`` with the predicted class index, its label,
        the confidence and the per-class probabilities; ``{"error": ...}``
        with HTTP 400 for missing or undecodable input, or HTTP 500 when
        inference fails.
    """
    try:
        # silent=True turns a missing/malformed body into None instead of raising.
        payload = request.get_json(silent=True)
        input_data = payload.get("data") if isinstance(payload, dict) else None
        if not input_data or not isinstance(input_data, str):
            return jsonify({"error": "No input data provided"}), 400

        # Drop an optional data-URL prefix; bare Base64 passes through as is.
        base64_image = input_data.split(",", 1)[-1]
        try:
            image_data = base64.b64decode(base64_image)
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
        except (ValueError, OSError) as e:
            # Bad Base64 raises binascii.Error (a ValueError). Unreadable image
            # bytes are expected to surface as OSError subclasses from PIL —
            # TODO confirm for the installed Pillow version.
            return jsonify({"error": f"Invalid image data: {e}"}), 400

        # Read the model's input name/shape and preprocess accordingly.
        model_input = ort_session.get_inputs()[0]
        preprocessed_image = preprocess_image(image, model_input.shape)

        # Run inference; the first output's first row holds the logits.
        outputs = ort_session.run(None, {model_input.name: preprocessed_image})
        logits = outputs[0][0]
        probabilities = softmax(logits)

        predicted_class = int(np.argmax(probabilities))      # class index
        confidence = float(probabilities[predicted_class])   # its probability
        predicted_label = class_labels[predicted_class]      # its label

        # Pair every label with its probability.
        class_probabilities = [
            {"class": label, "probability": float(probability)}
            for label, probability in zip(class_labels, probabilities)
        ]

        prediction = {
            "predicted_class": predicted_class,
            "predicted_label": predicted_label,
            "confidence": confidence,
            "class_probabilities": class_probabilities,
        }
        # Log the payload itself; printing the Response object only shows its repr.
        print({"prediction": prediction})
        return jsonify({"prediction": prediction})
    except Exception as e:
        # Top-level boundary: log and report the failure as JSON.
        print(f"Error during prediction: {e}")
        return jsonify({"error": str(e)}), 500
# ---------- Ceramic detection record storage ---------- #
class CeramicDetectionRecord(db.Model):
    """One saved recognition result, stored in ``ceramic_detection_record``."""

    __tablename__ = 'ceramic_detection_record'  # table name

    # Integer primary key.
    id = db.Column(db.Integer, primary_key=True)
    # Nickname of the user the record belongs to.
    nickname = db.Column(db.String(80), nullable=False)
    # When the detection was made.
    detection_time = db.Column(db.DateTime, nullable=False)
    # Detection result text.
    result = db.Column(db.String(120), nullable=False)
    # Confidence, kept as text of at most 10 characters.
    confidence = db.Column(db.String(10), nullable=False)
# Save-detection-record route
@app.route('/save_detection', methods=['POST'])
def save_detection():
    """Persist one recognition record sent as JSON.

    Expects the fields ``nickname``, ``detection_time`` (ISO 8601 text, a
    trailing ``Z`` is accepted), ``result`` and ``confidence``.

    Returns:
        ``{"status": "success", ...}`` on success; ``{"status": "error", ...}``
        with HTTP 400 for missing fields or an unparsable time, or HTTP 500
        when saving fails.
    """
    try:
        # silent=True turns a missing/malformed body into None instead of raising.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        nickname = data.get('nickname')
        detection_time = data.get('detection_time')
        result = data.get('result')
        confidence = data.get('confidence')

        # confidence is only "missing" when absent or empty: a numeric 0 is a
        # legitimate value and must not be rejected by a truthiness test.
        if not all([nickname, detection_time, result]) or confidence in (None, ""):
            return jsonify({"status": "error", "message": "缺少必要字段"}), 400

        # fromisoformat() may not accept the "Z" suffix; swap only the
        # trailing one for an explicit UTC offset.
        if isinstance(detection_time, str) and detection_time.endswith("Z"):
            detection_time = detection_time[:-1] + "+00:00"

        try:
            detection_time = datetime.datetime.fromisoformat(detection_time)
        except (TypeError, ValueError):
            # TypeError covers non-string input, ValueError malformed text.
            print("检测时间格式错误:", detection_time)
            return jsonify({"status": "error", "message": "检测时间格式错误"}), 400

        record = CeramicDetectionRecord(
            nickname=nickname,
            detection_time=detection_time,
            result=result,
            confidence=confidence,
        )
        db.session.add(record)
        db.session.commit()
        return jsonify({"status": "success", "message": "识别记录保存成功"})
    except Exception as e:
        # Discard any half-finished transaction before reporting the failure.
        db.session.rollback()
        return jsonify({"status": "error", "message": f"发生错误: {str(e)}"}), 500
# ---------- Ceramic detection record listing ---------- #
# Route that looks records up by nickname
@app.route('/get_detection_records', methods=['GET'])
def get_detection_records_by_nickname():
    """Return the saved detection records of one nickname.

    Reads the ``nickname`` query parameter.

    Returns:
        ``{"status": "success", "data": [...]}`` where each item carries
        ``id``, ``result`` and ``confidence``; ``{"status": "error", ...}``
        with HTTP 400 when ``nickname`` is absent, or HTTP 500 on failure.
    """
    try:
        nickname = request.args.get('nickname')
        if not nickname:
            return jsonify({"status": "error", "message": "缺少参数: nickname"}), 400

        # Fetch every row stored for this nickname.
        matching_rows = CeramicDetectionRecord.query.filter_by(nickname=nickname).all()

        # Shape each row into a JSON-friendly dict.
        payload = []
        for row in matching_rows:
            payload.append({
                "id": row.id,
                "result": row.result,
                "confidence": row.confidence,
            })

        return jsonify({"status": "success", "data": payload})
    except Exception as e:
        # Report any failure to the client as JSON.
        return jsonify({"status": "error", "message": f"发生错误: {str(e)}"}), 500
# Root route
@app.route('/')
def home():
    """Return a plain-text greeting for the API root."""
    greeting = "Welcome to the API!"
    return greeting
# Silence favicon.ico requests
@app.route('/favicon.ico')
def favicon():
    """Answer favicon requests with an empty body and HTTP 204."""
    empty_body = ""
    return empty_body, 204
# Entry point
if __name__ == '__main__':
    # The server listens on every interface, and Flask's debug mode enables an
    # interactive debugger that can run arbitrary code. Debug is therefore
    # opt-in through the FLASK_DEBUG environment variable, not always on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug_enabled, host='0.0.0.0', port=5000)