Compare commits
No commits in common. "vision" and "main" have entirely different histories.

.gitignore (vendored, 1 line changed)
@@ -3,7 +3,6 @@ config.py
 
 # Unsloth
 _unsloth_sentencepiece_temp/
-unsloth_compiled_cache/
 
 # ---> Python
 # Byte-compiled / optimized / DLL files

data/booru/.gitignore (vendored, 3 lines changed)
@@ -1,3 +0,0 @@
-*
-!README.md
-!.gitignore

data/booru/README.md (1 line changed)
@@ -1 +0,0 @@
-Place booru images here, with filenames of the form "12345 - Tag1 Tag_2 Tag3.jpg"

data/package-lock.json (generated, 2 lines changed)
@@ -1,5 +1,5 @@
 {
-  "name": "data",
+  "name": "discord",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {

proc_booru.py (deleted, 177 lines)
@@ -1,177 +0,0 @@
"""
proc_booru.py
This script assumes you have a folder called 'booru/' in the current directory,
containing a bunch of images following the Shimmie Booru naming scheme, i.e.
'12345 - Tag1 Tag2 Tag3.jpg'.
"""

import json
import os
from pathlib import Path
from typing import List, Tuple
import re
from unsloth import FastVisionModel
import pandas as pd
import numpy as np
import PIL.Image
import torch
import io
import tqdm

# names of real-life people tagged in images
NAMES = set(["James", "Vincent", "Myles", "Sam", "Jake", "Nicolai", "David", "ren", "Nazar"])
# irrelevant tags that should just be removed
IRRELEVANT = set(["_", "Myles'", "Vinny's", "Jake's", "tagme", "Nguyen"])

def parse_filename(filename: str) -> Tuple[str, List[str]]:
    """
    Parse a filename of format '12345 - Tag1 Tag2 Tag3.jpg' into ID and tags.
    Returns tuple of (id, [tags])
    """
    # Remove file extension
    name = os.path.splitext(filename)[0]

    # Split into ID and tags
    match = re.match(r'(\d+)\s*-\s*(.*)', name)
    if not match:
        raise ValueError(f"Invalid filename format: {filename}")

    image_id = match.group(1)
    tags = match.group(2).strip().split()

    # remove irrelevant tags
    irrelevant_overlap = IRRELEVANT.intersection(tags)
    if len(irrelevant_overlap) > 0:
        for tag in irrelevant_overlap:
            tags.remove(tag)

    # remove ambiguous situations with people's names, since the model won't know what they look like
    names_overlap = NAMES.intersection(tags)
    if len(names_overlap) > 1:
        for name in names_overlap:
            tags.remove(name)

    return image_id, tags

def create_prompt(tags: List[str]) -> List[dict]:
    """
    Create a prompt for the LLM to generate a summary based on tags.
    """
    tags_str = ', '.join(tags)
    return [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant. You must write a caption describing the following image, given a list of tags describing the image. Your response must contain absolutely nothing apart from a caption. Keep it as concise as you possibly can, at a hard maximum of two sentences. Avoid describing any small details, simply focus on the main subject of the image. Responses that are simply a repeat of the input are strictly forbidden. Your responses should be said with certainty.\n\nExample:\n```\nTags: 1991_Honda_Civic, Cisco_Parking_Lot, Grayscale, Milpitas, UnionPay\nThe image depicts a black and white photograph of a 1991 Honda Civic sedan parked in a Cisco parking lot in Milpitas, with a partial UnionPay advertisement visible.\n```\n\nExample:\n```\nTags: 2015_Honda_CB300F, Encinal_Canyon_Road, Malibu\nThe image features a 2015 Honda CB300F motorcycle parked on the side of Encinal Canyon Road in Malibu.\n```"
                },
                {"type": "image"},
                {
                    "type": "text",
                    "text": f"Tags: {tags_str}"
                },
            ]
        }
    ]

def load_image_as_bytes(image_path: Path) -> bytes:
    """
    Load an image file and return it as bytes.
    """
    with PIL.Image.open(image_path) as img:
        # Convert to RGB if necessary
        if img.mode != 'RGB':
            img = img.convert('RGB')

        # Save to bytes
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format='JPEG')
        return img_byte_arr.getvalue()

def main():
    model, tokenizer = FastVisionModel.from_pretrained(
        "unsloth/Llama-3.2-11B-Vision-Instruct",
        load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
        use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
    )
    FastVisionModel.for_inference(model)

    # Process all images in the booru directory
    booru_dir = Path('booru')
    if not booru_dir.exists():
        raise FileNotFoundError("booru directory not found")

    # Create lists to store data
    data = []

    # Get all image files
    image_files = [f for f in os.listdir(booru_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

    # Process each file
    for filename in tqdm.tqdm(image_files):
        try:
            filepath = booru_dir / filename

            # Parse filename
            image_id, tags = parse_filename(filename)

            # Create prompt
            prompt = create_prompt(tags)
            input_text = tokenizer.apply_chat_template(prompt, add_generation_prompt=True)
            image = PIL.Image.open(filepath)
            inputs = tokenizer(
                image,
                input_text,
                add_special_tokens = False,
                return_tensors = "pt",
            ).to("cuda")

            # Generate summary using the vision LLM
            outputs = model.generate(**inputs, max_new_tokens=128,
                                     use_cache=True, temperature=1.5, min_p=0.1)
            generated_text = tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
            generated_text = generated_text.partition('\n')[0]

            # Load image as bytes
            image_bytes = load_image_as_bytes(filepath)

            data_dict = {
                'image_id': image_id,
                'filename': filename,
                'tags': tags,
                'tags_string': ' '.join(tags),
                'summary': generated_text,
            }

            print(data_dict)

            data_dict['image_data'] = image_bytes

            # Store data
            data.append(data_dict)

            # Print progress
            print(f"Processed: {filename}")

        except ValueError as e:
            print(f"Error processing {filename}: {e}")
        except Exception as e:
            print(f"Unexpected error processing {filename}: {e}")

    # Convert to DataFrame
    df = pd.DataFrame(data)

    # Save to Parquet
    output_path = 'image_summaries.parquet'
    df.to_parquet(output_path, compression='snappy')
    print(f"\nSaved dataset to {output_path}")

    # Print summary statistics
    print(f"\nDataset Summary:")
    print(f"Total images processed: {len(df)}")
    print(f"Unique tags: {len(set(' '.join(df['tags_string']).split()))}")
    print(f"Average summary length: {df['summary'].str.len().mean():.1f} characters")

if __name__ == "__main__":
    main()

train_unsloth.ipynb (Normal file, 15868 lines)
File diff suppressed because it is too large.