WIP: Vision model #10

.gitignore (vendored, 1 line changed)
@@ -3,6 +3,7 @@ config.py

# Unsloth
_unsloth_sentencepiece_temp/
unsloth_compiled_cache/

# ---> Python
# Byte-compiled / optimized / DLL files

data/booru/.gitignore (new file, vendored, 3 lines)
@@ -0,0 +1,3 @@
*
!README.md
!.gitignore

data/booru/README.md (new file, 1 line)
@@ -0,0 +1 @@
Place booru images here, with filenames of the form "12345 - Tag1 Tag_2 Tag3.jpg"

data/package-lock.json (generated, 2 lines changed)
@@ -1,5 +1,5 @@
 {
-  "name": "discord",
+  "name": "data",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {

data/proc_booru.py (new file, 177 lines)
@@ -0,0 +1,177 @@
"""
proc_booru.py
This script assumes you have a folder called 'booru/' in the current directory,
containing a bunch of images following the Shimmie Booru naming scheme, i.e.
'12345 - Tag1 Tag2 Tag3.jpg'.
"""

import json
import os
from pathlib import Path
from typing import List, Tuple
import re
from unsloth import FastVisionModel
import pandas as pd
import numpy as np
import PIL.Image
import torch
import io
import tqdm

# names of real-life people tagged in images
NAMES = set(["James", "Vincent", "Myles", "Sam", "Jake", "Nicolai", "David", "ren", "Nazar"])
# irrelevant tags that should just be removed
IRRELEVANT = set(["_", "Myles'", "Vinny's", "Jake's", "tagme", "Nguyen"])

def parse_filename(filename: str) -> Tuple[str, List[str]]:
    """
    Parse a filename of format '12345 - Tag1 Tag2 Tag3.jpg' into ID and tags.
    Returns tuple of (id, [tags])
    """
    # Remove file extension
    name = os.path.splitext(filename)[0]

    # Split into ID and tags
    match = re.match(r'(\d+)\s*-\s*(.*)', name)
    if not match:
        raise ValueError(f"Invalid filename format: {filename}")

    image_id = match.group(1)
    tags = match.group(2).strip().split()

    # remove irrelevant tags
    irrelevant_overlap = IRRELEVANT.intersection(tags)
    if len(irrelevant_overlap) > 0:
        for tag in irrelevant_overlap:
            tags.remove(tag)

    # remove ambiguous situations with people's names, since the model won't know what they look like
    names_overlap = NAMES.intersection(tags)
    if len(names_overlap) > 1:
        for name in names_overlap:
            tags.remove(name)

    return image_id, tags

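# Illustration of the intended behaviour (hypothetical filename, added here only
# as an example; "tagme" is dropped because it is listed in IRRELEVANT above):
#   parse_filename("67890 - 2015_Honda_CB300F Malibu tagme.jpg")
#   returns ("67890", ["2015_Honda_CB300F", "Malibu"])
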
def create_prompt(tags: List[str]) -> List[dict]:
    """
    Create a prompt for the LLM to generate a summary based on tags.
    """
    tags_str = ', '.join(tags)
    return [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant. You must write a caption describing the following image, given a list of tags describing the image. Your response must contain absolutely nothing apart from a caption. Keep it as concise as you possibly can, at a hard maximum of two sentences. Avoid describing any small details, simply focus on the main subject of the image. Responses that are simply a repeat of the input are strictly forbidden. Your responses should be said with certainty.\n\nExample:\n```\nTags: 1991_Honda_Civic, Cisco_Parking_Lot, Grayscale, Milpitas, UnionPay\nThe image depicts a black and white photograph of a 1991 Honda Civic sedan parked in a Cisco parking lot in Milpitas, with a partial UnionPay advertisement visible.\n```\n\nExample:\n```\nTags: 2015_Honda_CB300F, Encinal_Canyon_Road, Malibu\nThe image features a 2015 Honda CB300F motorcycle parked on the side of Encinal Canyon Road in Malibu.\n```"
                },
                {"type": "image"},
                {
                    "type": "text",
                    "text": f"Tags: {tags_str}"
                },
            ]
        }
    ]

def load_image_as_bytes(image_path: Path) -> bytes:
    """
    Load an image file and return it as bytes.
    """
    with PIL.Image.open(image_path) as img:
        # Convert to RGB if necessary
        if img.mode != 'RGB':
            img = img.convert('RGB')

        # Save to bytes
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format='JPEG')
        return img_byte_arr.getvalue()

def main():
    model, tokenizer = FastVisionModel.from_pretrained(
        "unsloth/Llama-3.2-11B-Vision-Instruct",
        load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
        use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
    )
    FastVisionModel.for_inference(model)

    # Process all images in the booru directory
    booru_dir = Path('booru')
    if not booru_dir.exists():
        raise FileNotFoundError("booru directory not found")

    # Create lists to store data
    data = []

    # Get all image files
    image_files = [f for f in os.listdir(booru_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

    # Process each file
    for filename in tqdm.tqdm(image_files):
        try:
            filepath = booru_dir / filename

            # Parse filename
            image_id, tags = parse_filename(filename)

            # Create prompt
            prompt = create_prompt(tags)
            input_text = tokenizer.apply_chat_template(prompt, add_generation_prompt=True)
            image = PIL.Image.open(filepath)
            inputs = tokenizer(
                image,
                input_text,
                add_special_tokens = False,
                return_tensors = "pt",
            ).to("cuda")

            # Generate a summary with the vision model
            outputs = model.generate(**inputs, max_new_tokens=128,
                   use_cache=True, temperature=1.5, min_p=0.1)
            generated_text = tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
            generated_text = generated_text.partition('\n')[0]

            # Load image as bytes
            image_bytes = load_image_as_bytes(filepath)

            data_dict = {
                'image_id': image_id,
                'filename': filename,
                'tags': tags,
                'tags_string': ' '.join(tags),
                'summary': generated_text,
            }

            print(data_dict)

            data_dict['image_data'] = image_bytes

            # Store data
            data.append(data_dict)

            # Print progress
            print(f"Processed: {filename}")

        except ValueError as e:
            print(f"Error processing {filename}: {e}")
        except Exception as e:
            print(f"Unexpected error processing {filename}: {e}")

    # Convert to DataFrame
    df = pd.DataFrame(data)

    # Save to Parquet
    output_path = 'image_summaries.parquet'
    df.to_parquet(output_path, compression='snappy')
    print(f"\nSaved dataset to {output_path}")

    # Print summary statistics
    print(f"\nDataset Summary:")
    print(f"Total images processed: {len(df)}")
    print(f"Unique tags: {len(set(' '.join(df['tags_string']).split()))}")
    print(f"Average summary length: {df['summary'].str.len().mean():.1f} characters")

if __name__ == "__main__":
    main()
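As a rough usage sketch (not part of this change), the parquet written by proc_booru.py could be read back downstream like this; the column names 'image_data' and 'summary' come from the script above, and the path assumes the default output_path:

import io

import pandas as pd
import PIL.Image

# Load the dataset produced by proc_booru.py
df = pd.read_parquet("image_summaries.parquet")

# Rebuild the first image from its stored JPEG bytes and print its caption
img = PIL.Image.open(io.BytesIO(df.loc[0, "image_data"]))
print(df.loc[0, "summary"], img.size)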

mikuai-vision-training-notebook.ipynb (new file, 7864 lines)
File diff suppressed because it is too large.

train_unsloth.ipynb (15868 lines changed)
File diff suppressed because it is too large.