250 lines
8.1 KiB
Python
250 lines
8.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script for PunimTag
|
|
Tests core functionality including face detection, recognition, tagging, and search
|
|
"""
|
|
|
|
import os
|
|
import shutil
|
|
import tempfile
|
|
import unittest
|
|
from datetime import datetime
|
|
from punimtag import PunimTag
|
|
import numpy as np
|
|
|
|
|
|
class TestPunimTag(unittest.TestCase):
    """Unit tests for PunimTag's database, tagging, identification and search APIs.

    Every test gets its own temporary directory holding a fresh SQLite
    database file and an empty ``photos`` directory.
    """

    def setUp(self):
        """Set up an isolated test environment (temp dir, database, tagger)."""
        # Create temporary directory for test database
        self.test_dir = tempfile.mkdtemp()
        # Register removal right away: if PunimTag() below raises, tearDown is
        # not run, but cleanups registered here still are.
        self.addCleanup(shutil.rmtree, self.test_dir)

        self.db_path = os.path.join(self.test_dir, 'test.db')
        self.photos_dir = os.path.join(self.test_dir, 'photos')
        os.makedirs(self.photos_dir, exist_ok=True)

        # Initialize PunimTag with test database
        self.tagger = PunimTag(db_path=self.db_path, photos_dir=self.photos_dir)

    def tearDown(self):
        """Close the tagger's database connection.

        The temporary directory is removed by the cleanup registered in
        setUp, which runs after this method even if close() raises.
        """
        self.tagger.close()

    def test_database_creation(self):
        """Test that database tables are created correctly"""
        c = self.tagger.conn.cursor()

        # Check tables exist
        c.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = {row[0] for row in c.fetchall()}

        expected_tables = {'images', 'people', 'faces', 'tags', 'image_tags'}
        # Extra tables are allowed; on failure the message lists only the
        # missing ones.
        self.assertEqual(expected_tables - tables, set())

    def test_add_person(self):
        """Test adding people to database, including duplicate handling."""
        # Add person
        person_id = self.tagger.add_person("John Doe")
        self.assertIsNotNone(person_id)

        # Verify person exists
        c = self.tagger.conn.cursor()
        c.execute("SELECT name FROM people WHERE id = ?", (person_id,))
        result = c.fetchone()
        # Fail cleanly (not with a TypeError) if the row is missing.
        self.assertIsNotNone(result, "added person not found in people table")
        self.assertEqual(result[0], "John Doe")

        # Test duplicate handling: adding the same name returns the same id
        person_id2 = self.tagger.add_person("John Doe")
        self.assertEqual(person_id, person_id2)

    def test_add_tag(self):
        """Test tag creation with and without a category."""
        # Add tag without category
        tag_id1 = self.tagger.add_tag("vacation")
        self.assertIsNotNone(tag_id1)

        # Add tag with category
        tag_id2 = self.tagger.add_tag("beach", "location")
        self.assertIsNotNone(tag_id2)
        # Two differently named tags must get different ids.
        self.assertNotEqual(tag_id1, tag_id2)

        # Verify both tags exist
        c = self.tagger.conn.cursor()
        c.execute("SELECT name FROM tags WHERE id = ?", (tag_id1,))
        row = c.fetchone()
        self.assertIsNotNone(row, "tag 'vacation' not found in tags table")
        self.assertEqual(row[0], "vacation")

        c.execute("SELECT name, category FROM tags WHERE id = ?", (tag_id2,))
        result = c.fetchone()
        self.assertIsNotNone(result, "tag 'beach' not found in tags table")
        self.assertEqual(result[0], "beach")
        self.assertEqual(result[1], "location")

    def test_metadata_extraction(self):
        """Test metadata extraction for a missing file is handled gracefully."""
        # Use a path inside our private temp dir so it is guaranteed not to
        # exist (a bare relative name could exist in the current directory).
        missing_path = os.path.join(self.test_dir, 'nonexistent.jpg')
        try:
            metadata = self.tagger.extract_metadata(missing_path)
        except FileNotFoundError:
            # Raising is also acceptable behavior
            return

        # If it doesn't raise an exception, check default values
        self.assertIsNone(metadata['date_taken'])
        self.assertIsNone(metadata['latitude'])
        self.assertIsNone(metadata['longitude'])

    def test_face_identification(self):
        """Test face identification logic with no known faces."""
        # Seeded so the test input is reproducible across runs.
        encoding = np.random.default_rng(0).random(128)
        result = self.tagger.identify_face(encoding)
        self.assertEqual(result, (None, None))

        # Would need actual face encodings for more thorough testing

    def test_search_functionality(self):
        """Test search capabilities on an empty database."""
        # Search with no data should return empty
        results = self.tagger.search_images()
        self.assertEqual(len(results), 0)

        # Test with filters
        results = self.tagger.search_images(
            people=["John Doe"],
            tags=["vacation"],
            date_from=datetime(2023, 1, 1),
            date_to=datetime(2023, 12, 31),
        )
        self.assertEqual(len(results), 0)

    def test_unidentified_faces(self):
        """Test getting unidentified faces"""
        faces = self.tagger.get_unidentified_faces()
        self.assertEqual(len(faces), 0)  # Should be empty initially
|
|
|
|
|
|
class TestImageProcessing(unittest.TestCase):
    """Test image processing with actual images"""

    @classmethod
    def setUpClass(cls):
        """Create a shared temp directory containing a few test images.

        If Pillow is not installed, no images are written and the tests run
        against an empty photos directory.
        """
        cls.test_dir = tempfile.mkdtemp()
        cls.photos_dir = os.path.join(cls.test_dir, 'photos')
        try:
            os.makedirs(cls.photos_dir, exist_ok=True)
            cls._create_test_images()
        except BaseException:
            # tearDownClass is not called when setUpClass raises, so remove
            # the directory here to avoid leaking it, then re-raise.
            shutil.rmtree(cls.test_dir, ignore_errors=True)
            raise

    @classmethod
    def _create_test_images(cls):
        """Write simple solid-colour 100x100 JPEGs into photos_dir (needs Pillow)."""
        try:
            from PIL import Image
        except ImportError:
            print("PIL not available, skipping image creation")
            return

        for color in ('red', 'green', 'blue'):
            img = Image.new('RGB', (100, 100), color)
            img.save(os.path.join(cls.photos_dir, f'test_{color}.jpg'))

    @classmethod
    def tearDownClass(cls):
        """Clean up test images"""
        shutil.rmtree(cls.test_dir)

    def setUp(self):
        """Create a tagger backed by a fresh database for each test."""
        self.db_path = os.path.join(self.test_dir, 'test.db')
        self.tagger = PunimTag(db_path=self.db_path, photos_dir=self.photos_dir)

    def tearDown(self):
        """Close the tagger and delete the per-test database file."""
        try:
            self.tagger.close()
        finally:
            # Remove the database even if close() fails, so the next test
            # does not start from this test's data.
            if os.path.exists(self.db_path):
                os.remove(self.db_path)

    def test_process_directory(self):
        """Test processing a directory of images"""
        # Process all images
        processed = self.tagger.process_directory()

        # Should process the test images (if created)
        self.assertGreaterEqual(processed, 0)

        # Check images were added to database
        c = self.tagger.conn.cursor()
        c.execute("SELECT COUNT(*) FROM images")
        count = c.fetchone()[0]
        self.assertEqual(count, processed)
|
|
|
|
|
|
def test_with_sample_images(image_paths):
    """
    Process real image files with PunimTag and print summary statistics.

    This is a manual smoke test driven from the command line by main(). It
    uses a throw-away database file that is deleted afterwards. Per-image
    failures are reported and skipped, not raised.

    Args:
        image_paths: List of paths to test images
    """
    print("Testing PunimTag with sample images")
    print("=" * 50)

    # Create temporary database. delete=False keeps the file after the
    # with-block closes it, so PunimTag can open it by path.
    with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp:
        db_path = tmp.name

    try:
        # Initialize PunimTag
        tagger = PunimTag(db_path=db_path)
        try:
            # Process each image
            print(f"\nProcessing {len(image_paths)} images...")
            for path in image_paths:
                if os.path.exists(path):
                    print(f"Processing: {path}")
                    try:
                        image_id = tagger.process_image(path)
                        print(f"  ✓ Added to database with ID: {image_id}")
                    except Exception as e:
                        # Best-effort: report and continue with the next image.
                        print(f"  ✗ Error: {e}")
                else:
                    print(f"  ✗ File not found: {path}")

            # Show statistics
            c = tagger.conn.cursor()

            c.execute("SELECT COUNT(*) FROM images")
            image_count = c.fetchone()[0]
            print(f"\nTotal images: {image_count}")

            c.execute("SELECT COUNT(*) FROM faces")
            face_count = c.fetchone()[0]
            print(f"Total faces detected: {face_count}")

            # Get unidentified faces
            unidentified = tagger.get_unidentified_faces()
            print(f"Unidentified faces: {len(unidentified)}")
        finally:
            # Always release the connection, even on error. An open database
            # handle can also make the os.remove() below fail on Windows.
            tagger.close()

        print("\n✓ Test completed successfully!")

    finally:
        # Clean up
        if os.path.exists(db_path):
            os.remove(db_path)


# This is a CLI helper, not a pytest test: its name matches pytest's
# ``test_*`` pattern, and collecting it would error on the missing
# ``image_paths`` fixture.
test_with_sample_images.__test__ = False
|
|
|
|
|
|
def main():
    """Run the unit tests, then optionally smoke-test images given on the CLI.

    Command-line arguments are treated as image paths for
    test_with_sample_images(); they are not passed to unittest.

    Returns:
        0 if all unit tests passed, 1 otherwise (usable as an exit status).
    """
    import sys

    print("PunimTag Test Suite")
    print("=" * 50)

    # Run unit tests. argv=[''] stops unittest from parsing the image-path
    # arguments as test names; exit=False lets execution continue afterwards.
    print("\nRunning unit tests...")
    program = unittest.main(argv=[''], exit=False, verbosity=2)
    tests_passed = program.result.wasSuccessful()

    # Optional: Test with actual images
    print("\n" + "=" * 50)
    print("To test with actual images, call:")
    print("python test_punimtag.py image1.jpg image2.jpg ...")

    # Check if images were provided as arguments
    if len(sys.argv) > 1:
        image_paths = sys.argv[1:]
        test_with_sample_images(image_paths)

    return 0 if tests_passed else 1


if __name__ == "__main__":
    # Propagate the unit-test outcome as the process exit status so CI and
    # shell callers can detect failures.
    raise SystemExit(main())