# Source code for raiiaf.handlers.file_handler

"""RAIIAF file handler for encoding and decoding .raiiaf artifacts.

This module orchestrates packing/unpacking of latent, image, environment, and metadata
chunks, builds the header, and provides a high-level API for reading and writing
RAIIAF files.
"""

import hashlib
from ..core.constants import MAX_FILE_SIZE, MAX_CHUNK_SIZE, HEADER_SIZE, MAX_CHUNKS
from ..core.exceptions import (
    raiiafDecodeError,
    raiiafChunkError,
    raiiafCorruptHeader,
    raiiafImageError,
    raiiafEnvChunkError,
)
from ..core.header import header_init, header_parse, header_validate
from ..chunks.latent import raiiafLatent
from ..chunks.image import raiiafImage
from ..chunks.env import raiiafEnv
from ..chunks.metadata import raiiafMetadata
from typing import Optional, Dict, Any
from PIL import Image
import numpy as np
import struct
import io
import warnings
import json


class raiiafFileHandler:
    """High-level API for reading and writing RAIIAF files."""

    def __init__(self, max_file_size: Optional[int] = None, max_chunk_size: Optional[int] = None):
        """Initialize the file handler with optional size limits.

        Args:
            max_file_size (Optional[int]): Maximum allowed file size in bytes.
                Defaults to MAX_FILE_SIZE.
            max_chunk_size (Optional[int]): Maximum allowed chunk size in bytes.
                Defaults to MAX_CHUNK_SIZE.
        """
        # Per-chunk-type codec helpers used by the encoder/decoder below.
        self.latent = raiiafLatent()
        self.image = raiiafImage()
        self.metadata = raiiafMetadata()
        self.env = raiiafEnv()
        # Use `is None` (not `or`) so an explicitly passed limit of 0 is
        # honored instead of being silently replaced by the default.
        self.max_file_size = MAX_FILE_SIZE if max_file_size is None else max_file_size
        self.max_chunk_size = MAX_CHUNK_SIZE if max_chunk_size is None else max_chunk_size
        self.HEADER_SIZE = HEADER_SIZE
[docs] def validate_file_size(self, size: int, context: str = "file") -> bool: """Validate a file or chunk size. Args: size (int): Size in bytes to validate. context (str): Validation context; either "file" or "chunk". Returns: bool: True if the size is within limits. Raises: raiiafDecodeError: If size exceeds configured limits. """ if size < 0: raise raiiafDecodeError(f"Invalid {context} size: {size} (negative)") if context == "file" and size > self.max_file_size: raise raiiafDecodeError( f"File size {size:,} bytes exceeds maximum {self.max_file_size:,} bytes " f"({size / (1024**3):.2f} GB)" ) if context == "chunk" and size > self.max_chunk_size: raise raiiafChunkError( f"Chunk size {size:,} bytes exceeds maximum {self.max_chunk_size:,} bytes" ) return True
[docs] def validate_chunk_count(self, count: int) -> bool: """Validate the number of chunks. Args: count (int): Number of chunks. Returns: bool: True if the count is within limits. Raises: raiiafDecodeError: If count exceeds configured limits. """ if count < 0: raise raiiafDecodeError(f"Invalid chunk count: {count}") if count > MAX_CHUNKS: raise raiiafDecodeError(f"Chunk count {count} exceeds maximum {MAX_CHUNKS}") return True
[docs] def file_encoder( self, filename: str, latent: Dict[str, np.ndarray], chunk_records: list, model_name: str, model_version: str, prompt: str, tags: list, img_binary: bytes, should_compress: bool = True, convert_float16: bool = True, generation_settings: Optional[dict] = None, hardware_info: Optional[dict] = None, extra_image: Optional[Dict[str, Any]] = None, ): """Encode a RAIIAF file. Orchestrates packing latent, image, environment, and metadata chunks, builds the header, and writes the final file. Args: filename (str): Output .raiiaf filename. The .raiiaf extension is required. latent (Dict[str, np.ndarray]): Mapping of latent keys to arrays. chunk_records (list): Mutable list that will be appended with chunk records. model_name (str): Name of the model. model_version (str): Version of the model. prompt (str): Prompt used for generation. tags (list): Tags associated with the generation. img_binary (bytes): PNG image bytes to embed. should_compress (bool): Whether to compress chunks. Defaults to True. convert_float16 (bool): Convert latents to float16 for storage. Defaults to True. generation_settings (Optional[dict]): Generation configuration to include in metadata. hardware_info (Optional[dict]): Hardware information to include in metadata. extra_image (Optional[Dict[str, Any]]): Extra fields for the image chunk record. 
Returns: dict: A dictionary containing header bytes and chunk bytes with keys: - header (bytes) - latent_chunks (bytes) - metadata_chunk (bytes) - image_chunk (Optional[bytes]) """ # Pack the latent chunks first and update chunk_records with correct offsets if not filename.endswith(".raiiaf"): raise ValueError("Filename must have a .raiiaf extension") current_offset = self.HEADER_SIZE latent_chunks = self.latent.latent_packer( latent, file_offset=current_offset, chunk_records=chunk_records, should_compress=should_compress, convert_float16=convert_float16, ) latent_chunk = b"".join(latent_chunks) self.validate_file_size(len(latent_chunk), "chunk") current_offset += len(latent_chunk) # Pack image chunk if provided image_chunk = None if img_binary is not None: image_chunk = self.image.image_data_chunk_builder(img_binary) image_chunk_record = { "type": "DATA", "flags": "0000", "offset": current_offset, "compressed_size": len(image_chunk), "uncompressed_size": len(img_binary), "hash": hashlib.sha256(img_binary).hexdigest(), "extra": extra_image or {}, "compressed": True, } chunk_records.append(image_chunk_record) current_offset += len(image_chunk) print("ENCODER STORED FLAG:", chunk_records[-1]["flags"]) # Environment chunk env_chunk, env_raw = self.env.env_chunk_builder(self.env.env_chunk_populator()) if len(env_chunk) > 0: env_record = { "type": "ENVC", "flags": "0000", "offset": current_offset, "compressed_size": len(env_chunk), # bytes written "uncompressed_size": len(env_raw), # JSON bytes "hash": hashlib.sha256(env_raw).hexdigest(), # hash original "extra": {}, "compressed": True, } chunk_records.append(env_record) current_offset += len(env_chunk) # Build manifest with all chunks manifest = self.metadata.build_manifest( version_major=1, version_minor=0, model_name=model_name, model_version=model_version, prompt=prompt, tags=tags, chunk_records=chunk_records, generation_settings=generation_settings, hardware_info=hardware_info, ) # Ensure that manifest/metadata is 
valid self.metadata.metadata_validator(manifest) # Compress metadata compressed_metadata = self.metadata.metadata_compressor(manifest) metadata_size = len(compressed_metadata) # Calculate total file size total_file_size = ( self.HEADER_SIZE + len(latent_chunk) + (len(image_chunk) if image_chunk else 0) + len(env_chunk) + metadata_size ) # Update file_size in manifest manifest["raiiaf_metadata"]["file_info"]["file_size"] = total_file_size # Recompress metadata with the updated file_size compressed_metadata = self.metadata.metadata_compressor(manifest) metadata_size = len(compressed_metadata) total_file_size = ( self.HEADER_SIZE + len(latent_chunk) + (len(image_chunk) if image_chunk else 0) + len(env_chunk) + len(compressed_metadata) ) chunk_table_offset = ( self.HEADER_SIZE + len(latent_chunk) + (len(image_chunk) if image_chunk else 0) + len(env_chunk) ) # Final header header = header_init( version_major=1, version_minor=0, flags=0, chunk_table_offset=chunk_table_offset, chunk_table_size=metadata_size, chunk_count=len(chunk_records), file_size=total_file_size, ) with open(filename, "wb") as f: f.write(header) f.write(latent_chunk) if image_chunk is not None: f.write(image_chunk) f.write(env_chunk) f.write(compressed_metadata) return { "header": header, "latent_chunks": latent_chunk, "metadata_chunk": compressed_metadata, "image_chunk": image_chunk, }
[docs] def file_decoder(self, filename: str): """Decode a RAIIAF file. Reads the header, chunk table (metadata), and iteratively decodes the latent, image, and environment chunks. Args: filename (str): Path to the input .raiiaf file. Returns: dict: A dictionary with keys: - header (dict): Parsed header fields. - chunks (dict): Parsed chunks: 'latent' (list), 'image' (bytes), 'env' (EnvChunk). - metadata (dict): Parsed metadata manifest. Raises: raiiafCorruptHeader: If the header fails validation. raiiafChunkError: If a chunk is truncated or corrupt. raiiafEnvChunkError: If the environment chunk cannot be parsed when uncompressed. """ with open(filename, "rb") as f: header_bytes = f.read(HEADER_SIZE) header = header_parse(header_bytes) if not header_validate(header): raise raiiafCorruptHeader(message=f"Invalid header: {header}") f.seek(header["chunk_table_offset"]) metadata_compressed = f.read(header["chunk_table_size"]) metadata = self.metadata.metadata_parser(metadata_compressed) chunk_records = metadata["raiiaf_metadata"]["chunks"] chunks = {} chunks["latent"] = [] for record in chunk_records: chunk_type = record["type"] compressed = record.get("compressed", True) f.seek(record["offset"]) raw_chunk = f.read(record["compressed_size"]) if len(raw_chunk) != record["compressed_size"]: raise raiiafChunkError( f"Truncated chunk {chunk_type} at offset {record['offset']}" ) if chunk_type == "LATN": shape = tuple(record["extra"]["shape"]) if compressed: print("decoded raw_chunk len:", len(raw_chunk)) print("expected total bytes:", np.prod(shape)) chunk_obj = { "chunk": raw_chunk, "len_header": record.get("len_header"), "len_data": record.get("len_data"), } latent_array = self.latent.latent_parser(chunk_obj, shape, True) else: latent_array = self.latent.latent_parser(raw_chunk, shape, False) chunks["latent"].append(latent_array) elif chunk_type == "DATA": if compressed: parsed = self.image.image_data_chunk_parser(raw_chunk) chunks["image"] = parsed["image_data"] else: # 
DATA chunks have the same header layout even when not compressed chunk_type_b, flags_b, size = struct.unpack("<4s 4s I", raw_chunk[:12]) chunks["image"] = raw_chunk[12 : 12 + size] elif chunk_type == "ENVC": if compressed: parsed = self.env.env_chunk_parser(raw_chunk) chunks["env"] = parsed["env_chunk"] else: if len(raw_chunk) < 12: raise raiiafEnvChunkError("Truncated ENVC chunk header") chunk_type_b, flags_b, size = struct.unpack("<4s 4s I", raw_chunk[:12]) env_json_bytes = raw_chunk[12 : 12 + size] env_dict = json.loads(env_json_bytes.decode("utf-8")) chunks["env"] = env_dict try: # --- Normalize CURRENT environment (from populator) --- current_env_obj = self.env.env_chunk_populator() current_env = {} for comp in current_env_obj.components: if isinstance(comp, dict): comp_id = comp["component_id"] cononical_str = comp["cononical_str"] digest = comp["component_sha256_digest"] else: comp_id = comp.component_id cononical_str = comp.cononical_str # matches your class digest = comp.component_sha256_digest sha256 = digest.hex() if isinstance(digest, bytes) else digest current_env[comp_id] = { "cononical_str": cononical_str, "sha256": sha256, } # --- Normalize STORED environment (from file) --- stored_raw = chunks["env"] stored_env = {} # Extract components regardless of format if hasattr(stored_raw, "components"): components_list = stored_raw.components elif isinstance(stored_raw, dict) and "components" in stored_raw: components_list = stored_raw["components"] else: components_list = [] for comp in components_list: if isinstance(comp, dict): comp_id = comp["component_id"] cononical_str = comp["cononical_str"] digest = comp["component_sha256_digest"] else: comp_id = comp.component_id cononical_str = comp.cononical_str # ✅ digest = comp.component_sha256_digest sha256 = digest.hex() if isinstance(digest, bytes) else digest stored_env[comp_id] = { "cononical_str": cononical_str, "sha256": sha256, } # --- Compare environments --- all_ids = set(stored_env.keys()) | 
set(current_env.keys()) for comp_id in all_ids: stored = stored_env.get(comp_id) current = current_env.get(comp_id) if stored and current: if stored["sha256"] != current["sha256"]: warnings.warn( ( f"Environment component '{comp_id}' differs:\n" f" File: {stored['cononical_str']}\n" f" Current: {current['cononical_str']}" ), UserWarning, ) elif stored and not current: warnings.warn( f"Environment component '{comp_id}' missing in current system", UserWarning, ) elif not stored and current: warnings.warn( f"Environment component '{comp_id}' missing in file", UserWarning, ) except Exception as e: warnings.warn(f"Failed to compare environment chunks: {e}", UserWarning) else: raise ValueError("Unknown chunk type: {chunk_type}. Supported: LATN, DATA") return {"header": header, "chunks": chunks, "metadata": metadata}
[docs] @staticmethod def png_to_bytes(png_path: str) -> bytes: """Convert a PNG image to bytes. Preserves transparency by converting to RGBA. Args: png_path (str): Path to the PNG image. Returns: bytes: PNG image data in bytes. Raises: raiiafImageError: If the image cannot be read or encoded. """ try: with Image.open(png_path) as img: img = img.convert("RGBA") buf = io.BytesIO() img.save(buf, format="PNG") return buf.getvalue() except Exception as e: raise raiiafImageError(f"Failed to convert PNG to bytes: {e}") from e
[docs] @staticmethod def bytes_to_png(img_bytes: bytes) -> Image.Image: """Convert PNG bytes to a PIL Image. Args: img_bytes (bytes): PNG image bytes. Returns: PIL.Image.Image: Loaded image instance. Raises: raiiafImageError: If the bytes cannot be decoded as an image. """ try: buffer = io.BytesIO(img_bytes) img = Image.open(buffer) return img except Exception as e: raise raiiafImageError(f"Failed to convert bytes to PNG: {e}") from e