test_ner.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. """
  2. NER 服务测试
  3. """
  4. import pytest
  5. from fastapi.testclient import TestClient
  6. from app.main import app
# Module-level test client wrapping the imported FastAPI ``app``; the test
# functions in this module issue their HTTP requests through it.
client = TestClient(app)
  8. def test_health_check():
  9. """测试健康检查接口"""
  10. response = client.get("/health")
  11. assert response.status_code == 200
  12. data = response.json()
  13. assert data["status"] == "ok"
  14. assert "version" in data
  15. def test_extract_entities():
  16. """测试实体提取接口"""
  17. request_data = {
  18. "documentId": "test-doc-001",
  19. "text": "2024年5月15日,成都检测公司在成都市高新区完成了环境监测项目的检测工作,使用了噪音检测设备。",
  20. "extractRelations": True
  21. }
  22. response = client.post("/ner/extract", json=request_data)
  23. assert response.status_code == 200
  24. data = response.json()
  25. assert data["success"] is True
  26. assert data["documentId"] == "test-doc-001"
  27. assert "entities" in data
  28. assert len(data["entities"]) > 0
  29. # 验证提取到的实体类型
  30. entity_types = {e["type"] for e in data["entities"]}
  31. assert "DATE" in entity_types or "ORG" in entity_types
  32. def test_extract_relations():
  33. """测试关系抽取接口"""
  34. request_data = {
  35. "documentId": "test-doc-002",
  36. "text": "成都检测公司负责环境监测项目",
  37. "entities": [
  38. {
  39. "name": "成都检测公司",
  40. "type": "ORG",
  41. "value": "成都检测公司",
  42. "position": {"charStart": 0, "charEnd": 6, "line": 1},
  43. "tempId": "e1"
  44. },
  45. {
  46. "name": "环境监测项目",
  47. "type": "PROJECT",
  48. "value": "环境监测项目",
  49. "position": {"charStart": 8, "charEnd": 14, "line": 1},
  50. "tempId": "e2"
  51. }
  52. ]
  53. }
  54. response = client.post("/ner/relations", json=request_data)
  55. assert response.status_code == 200
  56. data = response.json()
  57. assert data["success"] is True
  58. assert "relations" in data
  59. def test_empty_text():
  60. """测试空文本"""
  61. request_data = {
  62. "documentId": "test-doc-003",
  63. "text": "",
  64. "extractRelations": False
  65. }
  66. response = client.post("/ner/extract", json=request_data)
  67. assert response.status_code == 200
  68. data = response.json()
  69. assert data["success"] is True
  70. assert len(data["entities"]) == 0
  71. def test_text_too_long():
  72. """测试文本过长"""
  73. request_data = {
  74. "documentId": "test-doc-004",
  75. "text": "a" * 60000, # 超过限制
  76. "extractRelations": False
  77. }
  78. response = client.post("/ner/extract", json=request_data)
  79. assert response.status_code == 400 # 应该返回错误
# When this file is executed directly as a script, hand it to pytest in
# verbose mode ("-v").
if __name__ == "__main__":
    pytest.main([__file__, "-v"])