"""ToDoWrite Model Schema Validator and Database Initializer.
This module provides programmatic access to the ToDoWrite model schemas,
allowing the library to validate data, initialize databases from schemas,
and ensure consistency between models and database structure.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import TypedDict
from sqlalchemy import (
Engine,
create_engine,
text,
)
from sqlalchemy.orm import sessionmaker
from .exceptions import ToDoWriteError
from .models import Base
# Type definitions for schema structure
class FieldSchema(TypedDict):
"""Schema definition for a model field."""
type: str
nullable: bool
default: str | int | float | bool | None
max_length: int | None
primary_key: bool
unique: bool
class ModelSchema(TypedDict):
"""Schema definition for a model."""
table_name: str
fields: dict[str, FieldSchema]
required_fields: list[str]
relationships: dict[str, object] # Complex relationship structure
class AssociationTableSchema(TypedDict):
"""Schema definition for an association table."""
table_name: str
columns: dict[str, FieldSchema]
source_model: str
target_model: str
class ToDoWriteCompleteSchema(TypedDict):
"""Complete ToDoWrite schema structure."""
models: dict[str, ModelSchema]
association_tables: dict[str, AssociationTableSchema]
generated_at: str
description: str
class SchemaValidationError(ToDoWriteError):
"""Raised when schema validation fails."""
pass
class DatabaseInitializationError(ToDoWriteError):
"""Raised when database initialization from schema fails."""
pass
class ToDoWriteSchemaValidator:
"""
ToDoWrite Model Schema Validator and Database Manager.
Provides programmatic access to:
1. Validate data against schemas
2. Initialize databases from schema definitions
3. Ensure model-schema consistency
4. Import schemas into properly typed database tables
"""
def __init__(self, schema_path: Path | None = None) -> None:
"""Initialize schema validator with schema file path."""
if schema_path is None:
# Default to the generated ToDoWrite Models schema
schema_path = (
Path(__file__).parent
/ "schemas"
/ "todowrite_models.schema.json"
)
self.schema_path = schema_path
self.schema: ToDoWriteCompleteSchema = self._load_schema()
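    # Illustrative usage sketch (the explicit path below is hypothetical; by
    # default the validator loads the bundled todowrite_models.schema.json):
    #
    #     validator = ToDoWriteSchemaValidator()
    #     custom = ToDoWriteSchemaValidator(Path("/tmp/my_models.schema.json"))
    #     print(custom.schema_path)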
def _load_schema(self) -> ToDoWriteCompleteSchema:
"""Load the ToDoWrite model schema from JSON file."""
        try:
            with open(self.schema_path, encoding="utf-8") as f:
                schema: ToDoWriteCompleteSchema = json.load(f)
            return schema
        except FileNotFoundError as e:
            raise SchemaValidationError(
                f"Schema file not found: {self.schema_path}"
            ) from e
        except json.JSONDecodeError as e:
            raise SchemaValidationError(
                f"Invalid JSON in schema file: {e}"
            ) from e
def get_model_schema(self, model_name: str) -> ModelSchema:
"""Get the schema definition for a specific model."""
if model_name not in self.schema.get("models", {}):
available = list(self.schema.get("models", {}).keys())
raise SchemaValidationError(
f"Model '{model_name}' not found. Available models: {available}"
)
return self.schema["models"][model_name]
def get_association_table_schema(
self, table_name: str
) -> AssociationTableSchema:
"""Get the schema definition for an association table."""
if table_name not in self.schema.get("association_tables", {}):
available = list(self.schema.get("association_tables", {}).keys())
raise SchemaValidationError(
f"Association table '{table_name}' not found. Available: {available}"
)
return self.schema["association_tables"][table_name]
def validate_model_data(
self, model_name: str, data: dict[str, object]
) -> bool:
"""Validate data against a specific model schema."""
model_schema = self.get_model_schema(model_name)
errors: list[str] = []
# Check required fields
required_fields = model_schema.get("required_fields", [])
for field in required_fields:
if field not in data:
errors.append(f"Missing required field: {field}")
# Check field types and constraints
fields = model_schema.get("fields", {})
for field_name, field_value in data.items():
if field_name in fields:
field_schema = fields[field_name]
                # Type validation. None is handled by the nullable check
                # below, and bool is tested before int because bool is a
                # subclass of int in Python.
                expected_type = field_schema.get("type")
                if field_value is not None:
                    if expected_type == "string" and not isinstance(
                        field_value, str
                    ):
                        errors.append(
                            f"Field '{field_name}' should be string, "
                            f"got {type(field_value).__name__}"
                        )
                    elif expected_type == "boolean" and not isinstance(
                        field_value, bool
                    ):
                        errors.append(
                            f"Field '{field_name}' should be boolean, "
                            f"got {type(field_value).__name__}"
                        )
                    elif expected_type == "integer" and (
                        isinstance(field_value, bool)
                        or not isinstance(field_value, int)
                    ):
                        errors.append(
                            f"Field '{field_name}' should be integer, "
                            f"got {type(field_value).__name__}"
                        )
# Nullable validation
nullable = field_schema.get("nullable", True)
if not nullable and field_value is None:
errors.append(f"Field '{field_name}' cannot be null")
if errors:
raise SchemaValidationError(
f"Data validation failed for {model_name}: {errors}"
)
return True
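    # Validation example (sketch; the "Goal" model and its "title" field are
    # assumed to exist in the loaded schema):
    #
    #     validator = ToDoWriteSchemaValidator()
    #     try:
    #         validator.validate_model_data("Goal", {"title": "Ship v1.0"})
    #     except SchemaValidationError as exc:
    #         print(f"Invalid data: {exc}")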
def initialize_database_from_schema(
self, engine: Engine, drop_existing: bool = False
) -> None:
"""
Initialize database with all tables defined in the schema.
Args:
engine: SQLAlchemy engine to use for database operations
drop_existing: Whether to drop existing tables first
"""
try:
# Drop existing tables if requested
if drop_existing:
Base.metadata.drop_all(engine)
# Create all tables from SQLAlchemy models
Base.metadata.create_all(engine)
# Apply CASCADE DELETE constraints for proper hierarchy cleanup
self._apply_cascade_constraints(engine)
# Verify all expected tables exist
self._verify_database_structure(engine)
        except Exception as e:
            raise DatabaseInitializationError(
                f"Database initialization failed: {e}"
            ) from e
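    # Initialization example (sketch using an in-memory SQLite engine, which
    # needs no external database server):
    #
    #     engine = create_engine("sqlite:///:memory:")
    #     ToDoWriteSchemaValidator().initialize_database_from_schema(
    #         engine, drop_existing=True
    #     )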
def _apply_cascade_constraints(self, engine: Engine) -> None:
"""Apply CASCADE DELETE constraints to ensure proper hierarchy cleanup.
This method applies CASCADE DELETE constraints after the tables are created,
ensuring that when Goal(id).delete() is called, all associated entities
are automatically deleted, preventing orphaned data.
"""
try:
cascade_sql_path = (
Path(__file__).parent.parent
/ "database"
/ "cascade_constraints.sql"
)
if not cascade_sql_path.exists():
# Skip cascade constraints if SQL file doesn't exist
return
with open(cascade_sql_path) as f:
cascade_sql = f.read()
with engine.connect() as conn:
# Apply cascade constraints in a transaction
with conn.begin():
# Split SQL into individual statements and execute them
statements = [
stmt.strip()
for stmt in cascade_sql.split(";")
if stmt.strip() and not stmt.strip().startswith("--")
]
for statement in statements:
if statement.strip():
conn.execute(text(statement))
except Exception as e:
# Log warning but don't fail initialization if cascade constraints fail
print(f"Warning: Failed to apply CASCADE constraints: {e}")
print(
"Database tables created successfully but cascade constraints may need manual application."
)
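    # The statements in cascade_constraints.sql are split on ";" and executed
    # one by one inside a single transaction. A statement of roughly this
    # shape would be compatible with that loop (hypothetical table and
    # constraint names; note that SQLite cannot add foreign keys via
    # ALTER TABLE, so the file is expected to be dialect-appropriate):
    #
    #     ALTER TABLE tasks
    #         ADD CONSTRAINT fk_tasks_goal
    #         FOREIGN KEY (goal_id) REFERENCES goals (id)
    #         ON DELETE CASCADE;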
def _verify_database_structure(self, engine: Engine) -> None:
"""Verify that all expected tables and columns exist in the database."""
with engine.connect() as conn:
# Check model tables
            for model_schema in self.schema.get("models", {}).values():
table_name = model_schema["table_name"]
# Check table exists - handle different database engines
if engine.dialect.name == "sqlite":
table_query = "SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name"
else: # PostgreSQL and others
table_query = "SELECT table_name FROM information_schema.tables WHERE table_name=:table_name AND table_schema='public'"
result = conn.execute(
text(table_query), {"table_name": table_name}
)
if not result.fetchone():
raise DatabaseInitializationError(
f"Table '{table_name}' not created"
)
# Check columns
columns = model_schema.get("fields", {}).keys()
for column in columns:
# Skip auto-generated columns
if column in ["id", "created_at", "updated_at"]:
continue
if engine.dialect.name == "sqlite":
result = conn.execute(
text(f"PRAGMA table_info({table_name})")
)
table_columns = [row[1] for row in result.fetchall()]
else: # PostgreSQL and others
result = conn.execute(
text("""
SELECT column_name FROM information_schema.columns
WHERE table_name = :table_name AND table_schema = 'public'
"""),
{"table_name": table_name},
)
table_columns = [row[0] for row in result.fetchall()]
if column not in table_columns:
raise DatabaseInitializationError(
f"Column '{column}' not found in table '{table_name}'"
)
# Check association tables
for table_name in self.schema.get("association_tables", {}):
if engine.dialect.name == "sqlite":
table_query = "SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name"
else: # PostgreSQL and others
table_query = "SELECT table_name FROM information_schema.tables WHERE table_name=:table_name AND table_schema='public'"
result = conn.execute(
text(table_query), {"table_name": table_name}
)
if not result.fetchone():
raise DatabaseInitializationError(
f"Association table '{table_name}' not created"
)
def get_all_model_schemas(self) -> dict[str, ModelSchema]:
"""Get all model schemas."""
return self.schema.get("models", {})
def get_all_association_table_schemas(
self,
) -> dict[str, AssociationTableSchema]:
"""Get all association table schemas."""
return self.schema.get("association_tables", {})
def get_model_relationships(self, model_name: str) -> dict[str, object]:
"""Get relationship information for a model."""
model_schema = self.get_model_schema(model_name)
return model_schema.get("relationships", {})
def get_associated_models(self, model_name: str) -> list[str]:
"""Get list of models that this model has relationships with."""
relationships = self.get_model_relationships(model_name)
        return [
            rel["target"]
            for rel in relationships.values()
            if isinstance(rel, dict) and rel.get("target")
        ]
def get_schema_summary(self) -> dict[str, object]:
"""Get a summary of the schema structure."""
models = list(self.schema.get("models", {}).keys())
association_tables = list(
self.schema.get("association_tables", {}).keys()
)
return {
"total_models": len(models),
"total_association_tables": len(association_tables),
"models": models,
"association_tables": association_tables,
"generated_at": self.schema.get("generated_at"),
"description": self.schema.get("description"),
}
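    # The summary is a plain dict, for example (hypothetical values):
    #
    #     {
    #         "total_models": 4,
    #         "total_association_tables": 1,
    #         "models": ["Goal", "Task", ...],
    #         "association_tables": ["goal_tasks"],
    #         "generated_at": "2024-01-01T00:00:00Z",
    #         "description": "ToDoWrite model schemas",
    #     }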
class DatabaseSchemaInitializer:
"""
Helper class for database initialization using schema definitions.
Provides methods to create, drop, and verify database structure
based on the ToDoWrite model schema.
"""
def __init__(
self, validator: ToDoWriteSchemaValidator | None = None
) -> None:
"""Initialize with schema validator."""
self.validator = validator or ToDoWriteSchemaValidator()
def create_database(
self, database_url: str, drop_existing: bool = False
) -> Engine:
"""
Create a new database initialized with the ToDoWrite model schema.
Args:
database_url: SQLAlchemy database URL
drop_existing: Whether to drop existing database first
Returns:
SQLAlchemy engine for the created database
"""
engine = create_engine(database_url)
try:
self.validator.initialize_database_from_schema(
engine, drop_existing
)
# Create session factory
Session = sessionmaker(bind=engine)
# Test the database
with Session() as session:
session.execute(text("SELECT 1"))
return engine
        except Exception as e:
            raise DatabaseInitializationError(
                f"Failed to create database: {e}"
            ) from e
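    # Creation example (sketch; the SQLite file path is hypothetical):
    #
    #     initializer = DatabaseSchemaInitializer()
    #     engine = initializer.create_database(
    #         "sqlite:///todowrite.db", drop_existing=False
    #     )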
def verify_database_structure(self, database_url: str) -> bool:
"""Verify that database matches the schema structure."""
engine = create_engine(database_url)
try:
self.validator._verify_database_structure(engine)
return True
except DatabaseInitializationError:
return False
def get_database_status(self, database_url: str) -> dict[str, object]:
"""Get status information about the database."""
engine = create_engine(database_url)
try:
schema_summary = self.validator.get_schema_summary()
with engine.connect() as conn:
# Count records in each table
table_counts = {}
for (
_model_name,
model_schema,
) in self.validator.get_all_model_schemas().items():
table_name = model_schema["table_name"]
try:
result = conn.execute(
text(f"SELECT COUNT(*) FROM {table_name}")
)
count = result.scalar()
table_counts[table_name] = count
                except Exception:
                    # Count is best-effort; treat unreadable tables as empty
                    table_counts[table_name] = 0
return {
"database_url": database_url,
"schema_matches": self.verify_database_structure(
database_url
),
"table_counts": table_counts,
"schema_summary": schema_summary,
}
except Exception as e:
return {
"database_url": database_url,
"schema_matches": False,
"error": str(e),
"schema_summary": self.validator.get_schema_summary(),
}
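    # Status example (sketch; the URL is hypothetical and the keys mirror the
    # dicts built above):
    #
    #     status = DatabaseSchemaInitializer().get_database_status(
    #         "sqlite:///todowrite.db"
    #     )
    #     if status.get("schema_matches"):
    #         print(status["table_counts"])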
# Global validator instance for easy access
_default_validator: ToDoWriteSchemaValidator | None = None
def get_schema_validator() -> ToDoWriteSchemaValidator:
"""Get the default schema validator instance."""
global _default_validator
if _default_validator is None:
_default_validator = ToDoWriteSchemaValidator()
return _default_validator
def validate_model_data(model_name: str, data: dict[str, object]) -> bool:
"""Validate data against model schema using default validator."""
validator = get_schema_validator()
return validator.validate_model_data(model_name, data)
def initialize_database(
database_url: str, drop_existing: bool = False
) -> Engine:
"""Initialize database with schema using default validator."""
initializer = DatabaseSchemaInitializer()
return initializer.create_database(database_url, drop_existing)
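# Minimal end-to-end sketch of the module-level helpers. This block only runs
# when the file is executed directly; the in-memory database URL and the
# "Goal" payload below are illustrative assumptions, not fixtures shipped
# with the library.
if __name__ == "__main__":
    demo_engine = initialize_database("sqlite:///:memory:", drop_existing=True)
    print(get_schema_validator().get_schema_summary())
    try:
        validate_model_data("Goal", {"title": "Demo goal"})
        print("Sample Goal payload passed schema validation")
    except SchemaValidationError as exc:
        print(f"Sample Goal payload rejected: {exc}")
    demo_engine.dispose()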