WIP: Vision model #10

Draft
james wants to merge 1 commit from vision into main
7 changed files with 8047 additions and 15869 deletions

.gitignore vendored (1 line changed)

@@ -3,6 +3,7 @@ config.py
# Unsloth
_unsloth_sentencepiece_temp/
unsloth_compiled_cache/
# ---> Python
# Byte-compiled / optimized / DLL files

data/booru/.gitignore vendored (new file, 3 lines)

@@ -0,0 +1,3 @@
*
!README.md
!.gitignore

data/booru/README.md (new file, 1 line)

@@ -0,0 +1 @@
Place booru images here, with filenames of the form "12345 - Tag1 Tag_2 Tag3.jpg"

View File

@@ -1,5 +1,5 @@
{
"name": "discord",
"name": "data",
"lockfileVersion": 3,
"requires": true,
"packages": {

data/proc_booru.py (new file, 177 lines)

@@ -0,0 +1,177 @@
"""
proc_booru.py
This script assumes you have a folder called 'booru/' in the current directory,
containing a bunch of images following the Shimmie Booru naming scheme, i.e.
'12345 - Tag1 Tag2 Tag3.jpg'.
"""
import json
import os
from pathlib import Path
from typing import List, Tuple
import re
from unsloth import FastVisionModel
import pandas as pd
import numpy as np
import PIL.Image  # import the submodule explicitly so PIL.Image.open() is available
import torch
import io
import tqdm
# names of real-life people tagged in images
NAMES = set(["James", "Vincent", "Myles", "Sam", "Jake", "Nicolai", "David", "ren", "Nazar"])
# irrelevant tags that should just be removed
IRRELEVANT = set(["_", "Myles'", "Vinny's", "Jake's", "tagme", "Nguyen"])

def parse_filename(filename: str) -> Tuple[str, List[str]]:
    """
    Parse a filename of format '12345 - Tag1 Tag2 Tag3.jpg' into ID and tags.
    Returns tuple of (id, [tags])
    """
    # Remove file extension
    name = os.path.splitext(filename)[0]
    # Split into ID and tags
    match = re.match(r'(\d+)\s*-\s*(.*)', name)
    if not match:
        raise ValueError(f"Invalid filename format: {filename}")
    image_id = match.group(1)
    tags = match.group(2).strip().split()
    # remove irrelevant tags
    irrelevant_overlap = IRRELEVANT.intersection(tags)
    if len(irrelevant_overlap) > 0:
        for tag in irrelevant_overlap:
            tags.remove(tag)
    # remove ambiguous situations with people's names, since the model won't know what they look like
    names_overlap = NAMES.intersection(tags)
    if len(names_overlap) > 1:
        for name in names_overlap:
            tags.remove(name)
    return image_id, tags

def create_prompt(tags: List[str]) -> List[dict]:
    """
    Create a chat-format prompt asking the model to caption an image based on its tags.
    """
    tags_str = ', '.join(tags)
    return [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant. You must write a caption describing the following image, given a list of tags describing the image. Your response must contain absolutely nothing apart from a caption. Keep it as concise as you possibly can, at a hard maximum of two sentences. Avoid describing any small details, simply focus on the main subject of the image. Responses that are simply a repeat of the input are strictly forbidden. Your responses should be said with certainty.\n\nExample:\n```\nTags: 1991_Honda_Civic, Cisco_Parking_Lot, Grayscale, Milpitas, UnionPay\nThe image depicts a black and white photograph of a 1991 Honda Civic sedan parked in a Cisco parking lot in Milpitas, with a partial UnionPay advertisement visible.\n```\n\nExample:\n```\nTags: 2015_Honda_CB300F, Encinal_Canyon_Road, Malibu\nThe image features a 2015 Honda CB300F motorcycle parked on the side of Encinal Canyon Road in Malibu.\n```"
                },
                {"type": "image"},
                {
                    "type": "text",
                    "text": f"Tags: {tags_str}"
                },
            ]
        }
    ]

def load_image_as_bytes(image_path: Path) -> bytes:
    """
    Load an image file and return it as bytes.
    """
    with PIL.Image.open(image_path) as img:
        # Convert to RGB if necessary
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # Save to bytes
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format='JPEG')
        return img_byte_arr.getvalue()

def main():
    model, tokenizer = FastVisionModel.from_pretrained(
        "unsloth/Llama-3.2-11B-Vision-Instruct",
        load_in_4bit = True,  # Use 4bit to reduce memory use. False for 16bit LoRA.
        use_gradient_checkpointing = "unsloth",  # True or "unsloth" for long context
    )
    FastVisionModel.for_inference(model)
    # Process all images in the booru directory
    booru_dir = Path('booru')
    if not booru_dir.exists():
        raise FileNotFoundError("booru directory not found")
    # Create list to store data
    data = []
    # Get all image files
    image_files = [f for f in os.listdir(booru_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    # Process each file
    for filename in tqdm.tqdm(image_files):
        try:
            filepath = booru_dir / filename
            # Parse filename
            image_id, tags = parse_filename(filename)
            # Create prompt
            prompt = create_prompt(tags)
            input_text = tokenizer.apply_chat_template(prompt, add_generation_prompt=True)
            image = PIL.Image.open(filepath)
            inputs = tokenizer(
                image,
                input_text,
                add_special_tokens = False,
                return_tensors = "pt",
            ).to("cuda")
            # Generate a caption with the vision model
            outputs = model.generate(**inputs, max_new_tokens=128,
                                     use_cache=True, temperature=1.5, min_p=0.1)
            generated_text = tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
            generated_text = generated_text.partition('\n')[0]
            # Load image as bytes
            image_bytes = load_image_as_bytes(filepath)
            data_dict = {
                'image_id': image_id,
                'filename': filename,
                'tags': tags,
                'tags_string': ' '.join(tags),
                'summary': generated_text,
            }
            print(data_dict)
            data_dict['image_data'] = image_bytes
            # Store data
            data.append(data_dict)
            # Print progress
            print(f"Processed: {filename}")
        except ValueError as e:
            print(f"Error processing {filename}: {e}")
        except Exception as e:
            print(f"Unexpected error processing {filename}: {e}")
    # Convert to DataFrame
    df = pd.DataFrame(data)
    # Save to Parquet
    output_path = 'image_summaries.parquet'
    df.to_parquet(output_path, compression='snappy')
    print(f"\nSaved dataset to {output_path}")
    # Print summary statistics
    print("\nDataset Summary:")
    print(f"Total images processed: {len(df)}")
    print(f"Unique tags: {len(set(' '.join(df['tags_string']).split()))}")
    print(f"Average summary length: {df['summary'].str.len().mean():.1f} characters")


if __name__ == "__main__":
    main()
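
As a usage sketch (assuming only pandas and Pillow; not something the script itself needs), the image_summaries.parquet file written above can be read back for training or spot-checking like this. Column names match the ones proc_booru.py writes; the file path is whatever to_parquet() produced in the working directory.

# Illustrative only: load the dataset proc_booru.py wrote and decode one image.
import io

import pandas as pd
import PIL.Image

df = pd.read_parquet("image_summaries.parquet")

# Columns written by proc_booru.py: image_id, filename, tags, tags_string,
# summary, and image_data (re-encoded JPEG bytes).
print(df[["image_id", "tags_string", "summary"]].head())

# Decode the stored JPEG bytes of the first row back into a PIL image.
row = df.iloc[0]
image = PIL.Image.open(io.BytesIO(row["image_data"]))
print(row["summary"], image.size)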

File diff suppressed because it is too large

File diff suppressed because it is too large