Vision training playground
This commit is contained in:
parent
07920c88ec
commit
247477edc8
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,6 +3,7 @@ config.py
|
|||||||
|
|
||||||
# Unsloth
|
# Unsloth
|
||||||
_unsloth_sentencepiece_temp/
|
_unsloth_sentencepiece_temp/
|
||||||
|
unsloth_compiled_cache/
|
||||||
|
|
||||||
# ---> Python
|
# ---> Python
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
|
3
data/booru/.gitignore
vendored
Normal file
3
data/booru/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
*
|
||||||
|
!README.md
|
||||||
|
!.gitignore
|
1
data/booru/README.md
Normal file
1
data/booru/README.md
Normal file
@ -0,0 +1 @@
|
|||||||
|
Place booru images here, with filenames of the form "12345 - Tag1 Tag_2 Tag3.jpg"
|
2
data/package-lock.json
generated
2
data/package-lock.json
generated
@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"name": "discord",
|
"name": "data",
|
||||||
"lockfileVersion": 3,
|
"lockfileVersion": 3,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
|
177
data/proc_booru.py
Normal file
177
data/proc_booru.py
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
"""
|
||||||
|
proc_booru.py
|
||||||
|
This script assumes you have a folder called 'booru/' in the current directory,
|
||||||
|
containing a bunch of images following the Shimmie Booru naming scheme, i.e.
|
||||||
|
'12345 - Tag1 Tag2 Tag3.jpg'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
import re
|
||||||
|
from unsloth import FastVisionModel
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
import io
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
# Names of real-life people that appear as tags on images. parse_filename
# intersects a file's tags with this set to detect person-name tags.
NAMES = {"James", "Vincent", "Myles", "Sam", "Jake", "Nicolai", "David", "ren", "Nazar"}

# Tags that carry no useful information for captioning; parse_filename
# intersects a file's tags with this set to drop them.
IRRELEVANT = {"_", "Myles'", "Vinny's", "Jake's", "tagme", "Nguyen"}
|
||||||
|
|
||||||
|
def parse_filename(filename: str) -> Tuple[str, List[str]]:
    """
    Parse a filename of format '12345 - Tag1 Tag2 Tag3.jpg' into ID and tags.

    The extension is stripped, the leading digits become the ID, and the
    remainder after the dash is split on whitespace into tags. Two filters
    are then applied, preserving the order of the surviving tags:

    * every tag listed in IRRELEVANT is dropped;
    * if more than one distinct tag from NAMES is present, all of those name
      tags are dropped (a single name tag is kept).

    Returns tuple of (id, [tags]), where id is the digit string as written
    in the filename.

    Raises ValueError if the filename does not start with '<digits> -'.
    """
    # Remove file extension
    stem = os.path.splitext(filename)[0]

    # Split into ID and tags
    match = re.match(r'(\d+)\s*-\s*(.*)', stem)
    if not match:
        raise ValueError(f"Invalid filename format: {filename}")

    image_id = match.group(1)
    # str.split() with no argument already discards surrounding whitespace.
    tags = match.group(2).split()

    # Remove irrelevant tags. Filtering (rather than list.remove) drops every
    # occurrence, so a tag repeated in the filename cannot slip through.
    tags = [tag for tag in tags if tag not in IRRELEVANT]

    # Remove ambiguous situations with people's names, since the model won't
    # know what they look like: with several names present it cannot tell
    # who is who, so all of them are dropped.
    names_present = NAMES.intersection(tags)
    if len(names_present) > 1:
        tags = [tag for tag in tags if tag not in names_present]

    return image_id, tags
|
||||||
|
|
||||||
|
def create_prompt(tags: List[str]) -> List[dict]:
    """
    Create a chat prompt asking the model to caption an image from its tags.

    The result is a list holding a single "user" message whose content has
    three parts, in order: the fixed captioning instructions (with two
    worked examples), an image placeholder, and a "Tags: ..." line built
    from `tags` joined with ', '.

    Returns the message list (not a string); main passes it to
    tokenizer.apply_chat_template.
    """
    tags_str = ', '.join(tags)
    return [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant. You must write a caption describing the following image, given a list of tags describing the image. Your response must contain absolutely nothing apart from a caption. Keep it as concise as you possibly can, at a hard maximum of two sentences. Avoid describing any small details, simply focus on the main subject of the image. Responses that are simply a repeat of the input are strictly forbidden. Your responses should be said with certainty.\n\nExample:\n```\nTags: 1991_Honda_Civic, Cisco_Parking_Lot, Grayscale, Milpitas, UnionPay\nThe image depicts a black and white photograph of a 1991 Honda Civic sedan parked in a Cisco parking lot in Milpitas, with a partial UnionPay advertisement visible.\n```\n\nExample:\n```\nTags: 2015_Honda_CB300F, Encinal_Canyon_Road, Malibu\nThe image features a 2015 Honda CB300F motorcycle parked on the side of Encinal Canyon Road in Malibu.\n```"
                },
                # Placeholder marking where the image goes in the message.
                {"type": "image"},
                {
                    "type": "text",
                    "text": f"Tags: {tags_str}"
                },
            ]
        }
    ]
|
||||||
|
|
||||||
|
def load_image_as_bytes(image_path: Path) -> bytes:
    """
    Load an image file and return it re-encoded as JPEG bytes.

    The image is converted to RGB first when it is in any other mode, then
    saved into an in-memory buffer with format='JPEG'. The returned bytes
    are therefore a JPEG encoding regardless of the source file's format.
    """
    # A bare `import PIL` does not load the Image submodule, so `PIL.Image`
    # only resolves if some other library happened to import it first.
    # Importing it here makes the function independent of import order.
    from PIL import Image

    with Image.open(image_path) as img:
        # Convert to RGB if necessary
        rgb_img = img if img.mode == 'RGB' else img.convert('RGB')

        # Save to bytes
        buffer = io.BytesIO()
        rgb_img.save(buffer, format='JPEG')

    return buffer.getvalue()
|
||||||
|
|
||||||
|
def main():
    """
    Caption every image in 'booru/' and save the results to a Parquet file.

    Steps:
      1. Load the Llama 3.2 vision model through FastVisionModel and switch
         it to inference mode.
      2. For each .jpg/.jpeg/.png file in 'booru/', parse the ID and tags
         from the filename, build the chat prompt, and generate a caption.
      3. Collect ID, filename, tags, caption and JPEG bytes per image into a
         DataFrame and write it to 'image_summaries.parquet'.
      4. Print summary statistics.

    Files that fail to parse or process are reported and skipped, so one bad
    image does not abort the run.

    Raises FileNotFoundError if the 'booru' directory does not exist.
    """
    # A bare `import PIL` does not load the Image submodule; import it
    # explicitly so opening images does not depend on import order.
    from PIL import Image

    model, tokenizer = FastVisionModel.from_pretrained(
        "unsloth/Llama-3.2-11B-Vision-Instruct",
        load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
        use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
    )
    FastVisionModel.for_inference(model)

    # Process all images in the booru directory
    booru_dir = Path('booru')
    if not booru_dir.exists():
        raise FileNotFoundError("booru directory not found")

    # Create lists to store data
    data = []

    # Get all image files. os.listdir order is arbitrary, so sort to make
    # the processing order (and the output row order) reproducible.
    image_files = sorted(
        f for f in os.listdir(booru_dir)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )

    # Process each file
    for filename in tqdm.tqdm(image_files):
        try:
            filepath = booru_dir / filename

            # Parse filename
            image_id, tags = parse_filename(filename)

            # Create prompt
            prompt = create_prompt(tags)
            input_text = tokenizer.apply_chat_template(prompt, add_generation_prompt=True)

            # Open the image only for as long as it takes to build the model
            # inputs, so the file handle is released on every iteration.
            with Image.open(filepath) as image:
                inputs = tokenizer(
                    image,
                    input_text,
                    add_special_tokens = False,
                    return_tensors = "pt",
                ).to("cuda")

            # Generate the caption with the loaded model
            outputs = model.generate(**inputs, max_new_tokens=128,
                                     use_cache=True, temperature=1.5, min_p=0.1)

            # Decode only the newly generated tokens (skip the prompt), and
            # keep just the first line of the response.
            prompt_length = len(inputs["input_ids"][0])
            generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            generated_text = generated_text.partition('\n')[0]

            # Load image as bytes
            image_bytes = load_image_as_bytes(filepath)

            data_dict = {
                'image_id': image_id,
                'filename': filename,
                'tags': tags,
                'tags_string': ' '.join(tags),
                'summary': generated_text,
            }

            # Print before attaching the image bytes so the log stays readable.
            print(data_dict)

            data_dict['image_data'] = image_bytes

            # Store data
            data.append(data_dict)

            # Print progress
            print(f"Processed: {filename}")

        except ValueError as e:
            print(f"Error processing {filename}: {e}")
        except Exception as e:
            # Deliberate best-effort: report the failure and continue with
            # the remaining images.
            print(f"Unexpected error processing {filename}: {e}")

    # Convert to DataFrame
    df = pd.DataFrame(data)

    # Save to Parquet
    output_path = 'image_summaries.parquet'
    df.to_parquet(output_path, compression='snappy')
    print(f"\nSaved dataset to {output_path}")

    # Print summary statistics
    print("\nDataset Summary:")
    print(f"Total images processed: {len(df)}")
    if df.empty:
        # With no rows the DataFrame has no columns, so the per-column
        # statistics below would raise KeyError.
        print("No images were processed; skipping tag and summary statistics.")
        return
    print(f"Unique tags: {len(set(' '.join(df['tags_string']).split()))}")
    print(f"Average summary length: {df['summary'].str.len().mean():.1f} characters")


if __name__ == "__main__":
    main()
|
7864
mikuai-vision-training-notebook.ipynb
Normal file
7864
mikuai-vision-training-notebook.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
15868
train_unsloth.ipynb
15868
train_unsloth.ipynb
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user