Vision training playground
This commit is contained in:
parent
07920c88ec
commit
247477edc8
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,6 +3,7 @@ config.py
|
||||
|
||||
# Unsloth
|
||||
_unsloth_sentencepiece_temp/
|
||||
unsloth_compiled_cache/
|
||||
|
||||
# ---> Python
|
||||
# Byte-compiled / optimized / DLL files
|
||||
|
3
data/booru/.gitignore
vendored
Normal file
3
data/booru/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
*
|
||||
!README.md
|
||||
!.gitignore
|
1
data/booru/README.md
Normal file
1
data/booru/README.md
Normal file
@ -0,0 +1 @@
|
||||
Place booru images here, with filenames of the form "12345 - Tag1 Tag_2 Tag3.jpg"
|
2
data/package-lock.json
generated
2
data/package-lock.json
generated
@ -1,5 +1,5 @@
|
||||
{
|
||||
"name": "discord",
|
||||
"name": "data",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
|
177
data/proc_booru.py
Normal file
177
data/proc_booru.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""
|
||||
proc_booru.py
|
||||
This script assumes you have a folder called 'booru/' in the current directory,
|
||||
containing a bunch of images following the Shimmie Booru naming scheme, i.e.
|
||||
'12345 - Tag1 Tag2 Tag3.jpg'.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
import re
|
||||
from unsloth import FastVisionModel
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
import io
|
||||
import tqdm
|
||||
|
||||
# Names of real-life people tagged in images.
NAMES = {"James", "Vincent", "Myles", "Sam", "Jake", "Nicolai", "David", "ren", "Nazar"}
# Irrelevant tags that should just be removed.
IRRELEVANT = {"_", "Myles'", "Vinny's", "Jake's", "tagme", "Nguyen"}


def parse_filename(filename: str) -> Tuple[str, List[str]]:
    """
    Parse a filename of format '12345 - Tag1 Tag2 Tag3.jpg' into ID and tags.

    Tags listed in IRRELEVANT are always dropped. If more than one tag from
    NAMES is present, every name tag is dropped as well; a single name tag is
    kept as-is. The relative order of the remaining tags is preserved.

    Args:
        filename: File name (not a full path), including its extension.

    Returns:
        Tuple of (id, [tags]), where id is the leading run of digits as a
        string.

    Raises:
        ValueError: If the name does not start with '<digits> - '.
    """
    # Remove file extension
    stem = os.path.splitext(filename)[0]

    # Split into ID and tags
    match = re.match(r'(\d+)\s*-\s*(.*)', stem)
    if not match:
        raise ValueError(f"Invalid filename format: {filename}")

    image_id = match.group(1)

    # Filter rather than list.remove(): remove() only deletes the first
    # occurrence, so a duplicated irrelevant tag would otherwise survive.
    tags = [tag for tag in match.group(2).split() if tag not in IRRELEVANT]

    # Remove ambiguous situations with people's names, since the model won't
    # know what they look like. A single name is unambiguous and is kept.
    names_present = NAMES.intersection(tags)
    if len(names_present) > 1:
        tags = [tag for tag in tags if tag not in names_present]

    return image_id, tags
|
||||
|
||||
def create_prompt(tags: List[str]) -> List[dict]:
    """
    Create the chat messages asking the model to caption an image from its tags.

    Args:
        tags: Tags describing the image; they are joined with ', '.

    Returns:
        A list holding a single "user" message. Its content is, in order: a
        text part with the instructions and two worked examples, an image
        placeholder, and a text part of the form 'Tags: <tags>'.
    """
    tags_str = ', '.join(tags)
    return [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant. You must write a caption describing the following image, given a list of tags describing the image. Your response must contain absolutely nothing apart from a caption. Keep it as concise as you possibly can, at a hard maximum of two sentences. Avoid describing any small details, simply focus on the main subject of the image. Responses that are simply a repeat of the input are strictly forbidden. Your responses should be said with certainty.\n\nExample:\n```\nTags: 1991_Honda_Civic, Cisco_Parking_Lot, Grayscale, Milpitas, UnionPay\nThe image depicts a black and white photograph of a 1991 Honda Civic sedan parked in a Cisco parking lot in Milpitas, with a partial UnionPay advertisement visible.\n```\n\nExample:\n```\nTags: 2015_Honda_CB300F, Encinal_Canyon_Road, Malibu\nThe image features a 2015 Honda CB300F motorcycle parked on the side of Encinal Canyon Road in Malibu.\n```"
                },
                {"type": "image"},
                {
                    "type": "text",
                    "text": f"Tags: {tags_str}"
                },
            ]
        }
    ]
|
||||
|
||||
def load_image_as_bytes(image_path: Path) -> bytes:
    """
    Load an image file and return it re-encoded as JPEG bytes.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The image encoded as JPEG, converted to RGB first when its mode is
        anything else.
    """
    # A bare `import PIL` does not load the `PIL.Image` submodule, so
    # `PIL.Image` only resolves if some other package imported it first.
    # Import it explicitly here so this function does not depend on that.
    from PIL import Image

    with Image.open(image_path) as img:
        # Convert to RGB if necessary
        rgb = img if img.mode == 'RGB' else img.convert('RGB')

        # Save to bytes
        buffer = io.BytesIO()
        rgb.save(buffer, format='JPEG')

    return buffer.getvalue()
|
||||
|
||||
def main():
    """
    Caption every image in 'booru/' and save the results as a Parquet dataset.

    For each '.jpg', '.jpeg' or '.png' file, the tags are parsed from the file
    name, the vision model is asked for a caption, and a row containing the
    id, file name, tags, caption and JPEG bytes is collected. Files that fail
    are reported and skipped. The rows are written to
    'image_summaries.parquet' in the current directory.

    Raises:
        FileNotFoundError: If the 'booru' directory does not exist.
    """
    # A bare `import PIL` does not load the `PIL.Image` submodule; import it
    # explicitly so opening images does not rely on another package doing so.
    from PIL import Image

    # Validate the input directory before the (expensive) model load, so a
    # missing folder fails immediately.
    booru_dir = Path('booru')
    if not booru_dir.exists():
        raise FileNotFoundError("booru directory not found")

    # Get all image files
    image_files = [f for f in os.listdir(booru_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    if not image_files:
        print("No images found in booru directory; nothing to do.")
        return

    model, tokenizer = FastVisionModel.from_pretrained(
        "unsloth/Llama-3.2-11B-Vision-Instruct",
        load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
        use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
    )
    FastVisionModel.for_inference(model)

    # Rows of the output dataset
    data = []

    # Process each file
    for filename in tqdm.tqdm(image_files):
        try:
            filepath = booru_dir / filename

            # Parse filename
            image_id, tags = parse_filename(filename)

            # Create prompt
            prompt = create_prompt(tags)
            input_text = tokenizer.apply_chat_template(prompt, add_generation_prompt=True)

            # Close the file handle as soon as the inputs have been built;
            # otherwise one handle per image stays open.
            with Image.open(filepath) as image:
                inputs = tokenizer(
                    image,
                    input_text,
                    add_special_tokens = False,
                    return_tensors = "pt",
                ).to("cuda")

            # Generate summary with the vision-language model
            outputs = model.generate(**inputs, max_new_tokens=128,
                                     use_cache=True, temperature=1.5, min_p=0.1)
            # Decode only the newly generated tokens, skipping the prompt.
            prompt_length = len(inputs["input_ids"][0])
            generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            # Keep only the first line of the response.
            generated_text = generated_text.partition('\n')[0]

            # Load image as bytes
            image_bytes = load_image_as_bytes(filepath)

            data_dict = {
                'image_id': image_id,
                'filename': filename,
                'tags': tags,
                'tags_string': ' '.join(tags),
                'summary': generated_text,
            }

            # Print the row before attaching the raw image bytes.
            print(data_dict)

            data_dict['image_data'] = image_bytes

            # Store data
            data.append(data_dict)

            # Print progress
            print(f"Processed: {filename}")

        except ValueError as e:
            print(f"Error processing {filename}: {e}")
        except Exception as e:
            # Best effort: one bad file must not abort the whole run.
            print(f"Unexpected error processing {filename}: {e}")

    # With no rows the DataFrame has no columns, and the summary statistics
    # below would raise KeyError; stop here instead.
    if not data:
        print("No images were processed successfully; no dataset written.")
        return

    # Convert to DataFrame
    df = pd.DataFrame(data)

    # Save to Parquet
    output_path = 'image_summaries.parquet'
    df.to_parquet(output_path, compression='snappy')
    print(f"\nSaved dataset to {output_path}")

    # Print summary statistics
    print("\nDataset Summary:")
    print(f"Total images processed: {len(df)}")
    print(f"Unique tags: {len(set(' '.join(df['tags_string']).split()))}")
    print(f"Average summary length: {df['summary'].str.len().mean():.1f} characters")


if __name__ == "__main__":
    main()
|
7864
mikuai-vision-training-notebook.ipynb
Normal file
7864
mikuai-vision-training-notebook.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
15868
train_unsloth.ipynb
15868
train_unsloth.ipynb
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user