WIP: Vision model #10
							
								
								
									
.gitignore  (+1, vendored)
@@ -3,6 +3,7 @@ config.py
 
 # Unsloth
 _unsloth_sentencepiece_temp/
+unsloth_compiled_cache/
 
 # ---> Python
 # Byte-compiled / optimized / DLL files
							
								
								
									
data/booru/.gitignore  (new file, +3, vendored)
@@ -0,0 +1,3 @@
+*
+!README.md
+!.gitignore
							
								
								
									
data/booru/README.md  (new file, +1)
@@ -0,0 +1 @@
+Place booru images here, with filenames of the form "12345 - Tag1 Tag_2 Tag3.jpg"
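For reference, a minimal sketch (not part of the diff) of how a filename in this naming scheme splits into an image ID and a tag list. It reuses the regex that parse_filename in data/proc_booru.py (below) applies, and the filename is the example from the README:

import os
import re

filename = "12345 - Tag1 Tag_2 Tag3.jpg"                  # example from the README
name = os.path.splitext(filename)[0]                      # drop the extension
match = re.match(r'(\d+)\s*-\s*(.*)', name)               # same pattern as parse_filename
image_id, tags = match.group(1), match.group(2).split()   # -> '12345', ['Tag1', 'Tag_2', 'Tag3']
print(image_id, tags)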
							
								
								
									
data/package-lock.json  (2 changes, generated)
@@ -1,5 +1,5 @@
 {
-  "name": "discord",
+  "name": "data",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
							
								
								
									
data/proc_booru.py  (new file, +177)
@@ -0,0 +1,177 @@
+"""
+proc_booru.py
+This script assumes you have a folder called 'booru/' in the current directory,
+containing images that follow the Shimmie booru naming scheme, i.e.
+'12345 - Tag1 Tag2 Tag3.jpg'.
+"""
+
+import json
+import os
+from pathlib import Path
+from typing import List, Tuple
+import re
+from unsloth import FastVisionModel
+import pandas as pd
+import numpy as np
+import PIL
+import torch
+import io
+import tqdm
+
+# names of real-life people tagged in images
+NAMES = set(["James", "Vincent", "Myles", "Sam", "Jake", "Nicolai", "David", "ren", "Nazar"])
+# irrelevant tags that should just be removed
+IRRELEVANT = set(["_", "Myles'", "Vinny's", "Jake's", "tagme", "Nguyen"])
+
+def parse_filename(filename: str) -> Tuple[str, List[str]]:
+    """
+    Parse a filename of the form '12345 - Tag1 Tag2 Tag3.jpg' into ID and tags.
+    Returns a tuple of (image_id, tags).
+    """
+    # Remove file extension
+    name = os.path.splitext(filename)[0]
+
+    # Split into ID and tags
+    match = re.match(r'(\d+)\s*-\s*(.*)', name)
+    if not match:
+        raise ValueError(f"Invalid filename format: {filename}")
+
+    image_id = match.group(1)
+    tags = match.group(2).strip().split()
+
+    # remove irrelevant tags
+    irrelevant_overlap = IRRELEVANT.intersection(tags)
+    if len(irrelevant_overlap) > 0:
+        for tag in irrelevant_overlap:
+            tags.remove(tag)
+
+    # remove people's names when more than one is tagged, since the model won't know what they look like
+    names_overlap = NAMES.intersection(tags)
+    if len(names_overlap) > 1:
+        for name in names_overlap:
+            tags.remove(name)
+
+    return image_id, tags
+
+def create_prompt(tags: List[str]) -> List[dict]:
+    """
+    Create the chat-format prompt asking the model to caption an image from its tags.
+    """
+    tags_str = ', '.join(tags)
+    return [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "You are a helpful assistant. You must write a caption describing the following image, given a list of tags describing the image. Your response must contain absolutely nothing apart from a caption. Keep it as concise as you possibly can, at a hard maximum of two sentences. Avoid describing any small details, simply focus on the main subject of the image. Responses that are simply a repeat of the input are strictly forbidden. Your responses should be said with certainty.\n\nExample:\n```\nTags: 1991_Honda_Civic, Cisco_Parking_Lot, Grayscale, Milpitas, UnionPay\nThe image depicts a black and white photograph of a 1991 Honda Civic sedan parked in a Cisco parking lot in Milpitas, with a partial UnionPay advertisement visible.\n```\n\nExample:\n```\nTags: 2015_Honda_CB300F, Encinal_Canyon_Road, Malibu\nThe image features a 2015 Honda CB300F motorcycle parked on the side of Encinal Canyon Road in Malibu.\n```"
+                },
+                {"type": "image"},
+                {
+                    "type": "text",
+                    "text": f"Tags: {tags_str}"
+                },
+            ]
+        }
+    ]
+
+def load_image_as_bytes(image_path: Path) -> bytes:
+    """
+    Load an image file and return it as JPEG bytes.
+    """
+    with PIL.Image.open(image_path) as img:
+        # Convert to RGB if necessary
+        if img.mode != 'RGB':
+            img = img.convert('RGB')
+
+        # Save to bytes
+        img_byte_arr = io.BytesIO()
+        img.save(img_byte_arr, format='JPEG')
+        return img_byte_arr.getvalue()
+
+def main():
+    model, tokenizer = FastVisionModel.from_pretrained(
+        "unsloth/Llama-3.2-11B-Vision-Instruct",
+        load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
+        use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
+    )
+    FastVisionModel.for_inference(model)
+
+    # Process all images in the booru directory
+    booru_dir = Path('booru')
+    if not booru_dir.exists():
+        raise FileNotFoundError("booru directory not found")
+
+    # List to hold one row of data per image
+    data = []
+
+    # Get all image files
+    image_files = [f for f in os.listdir(booru_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
+
+    # Process each file
+    for filename in tqdm.tqdm(image_files):
+        try:
+            filepath = booru_dir / filename
+
+            # Parse filename
+            image_id, tags = parse_filename(filename)
+
+            # Create prompt
+            prompt = create_prompt(tags)
+            input_text = tokenizer.apply_chat_template(prompt, add_generation_prompt=True)
+            image = PIL.Image.open(filepath)
+            inputs = tokenizer(
+                image,
+                input_text,
+                add_special_tokens = False,
+                return_tensors = "pt",
+            ).to("cuda")
+
+            # Generate a caption with the vision model
+            outputs = model.generate(**inputs, max_new_tokens=128,
+                   use_cache=True, temperature=1.5, min_p=0.1)
+            generated_text = tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
+            generated_text = generated_text.partition('\n')[0]
+
+            # Load image as bytes
+            image_bytes = load_image_as_bytes(filepath)
+
+            data_dict = {
+                'image_id': image_id,
+                'filename': filename,
+                'tags': tags,
+                'tags_string': ' '.join(tags),
+                'summary': generated_text,
+            }
+
+            print(data_dict)
+
+            data_dict['image_data'] = image_bytes
+
+            # Store data
+            data.append(data_dict)
+
+            # Print progress
+            print(f"Processed: {filename}")
+
+        except ValueError as e:
+            print(f"Error processing {filename}: {e}")
+        except Exception as e:
+            print(f"Unexpected error processing {filename}: {e}")
+
+    # Convert to DataFrame
+    df = pd.DataFrame(data)
+
+    # Save to Parquet
+    output_path = 'image_summaries.parquet'
+    df.to_parquet(output_path, compression='snappy')
+    print(f"\nSaved dataset to {output_path}")
+
+    # Print summary statistics
+    print(f"\nDataset Summary:")
+    print(f"Total images processed: {len(df)}")
+    print(f"Unique tags: {len(set(' '.join(df['tags_string']).split()))}")
+    print(f"Average summary length: {df['summary'].str.len().mean():.1f} characters")
+
+if __name__ == "__main__":
+    main()
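The resulting image_summaries.parquet file has one row per image with the columns built above (image_id, filename, tags, tags_string, summary, image_data). A minimal sketch (not part of the diff) of loading it back and decoding one stored image; only the file name and column names come from the script, the rest is illustrative:

import io

import pandas as pd
import PIL.Image

# Read the dataset written by data/proc_booru.py
df = pd.read_parquet("image_summaries.parquet")
print(df[["image_id", "tags_string", "summary"]].head())

# 'image_data' holds the JPEG bytes produced by load_image_as_bytes()
row = df.iloc[0]
img = PIL.Image.open(io.BytesIO(row["image_data"]))
print(row["summary"], img.size)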
							
								
								
									
mikuai-vision-training-notebook.ipynb  (new file, +7864)
File diff suppressed because it is too large.
								
								
									
train_unsloth.ipynb  (15868 changes)
File diff suppressed because it is too large.