diff --git a/README.md b/README.md index fa52169..ed59902 100644 --- a/README.md +++ b/README.md @@ -182,3 +182,5 @@ Mostly self explanatory. Paste in the path to the directory you want processed For instance, if you are writing webscrapers, make sure to collect metadata from the webpage as you go rather than blindly just download each image. Perhaps you might include the website address or full URI of the webpage, the `` tag from the webpage, or the `alt-text` field. Save this information with each image, or in a database. Then feed into the VLM with a `hint_source`. New hint sources are very easy for an amateur Python programmer to write, or you can have an LLM write for you. See [HINTSOURCES.md](HINTSOURCES.md) for more information. + +- **Batch concurrency**: If you use a VLM host that support batch concurrency such as llama.cpp (via -np n arg) you can potentially increase speed. This is not supported by LM Studio. Example command: `llama-server -np 4 -c 32768 --mmproj "mmproj-Qwen3-VL-32B-Instruct-F16.gguf" --model "Qwen3-VL-32B-Instruct-Q4_K_M.gguf" -dev cuda0 --top-k 30 --top-p 0.95 --min-p 0.05 --temp 0.5` would launch Qwen3VL 32B with four concurrent processes (-np 4) each with 8192 tokens (32768/4) of context for each of the 4 slots. This requires additional processing power and an increase of total context size (`-np 4 -c 32768` instead of `-np 1 -c 8192` as an example), but may increase total token generation speeds by utilizing batch processing. _This feature does not utilize the OpenAI jsonl batch API suitable for commercial APIs to save on costs, but should work to speed up rates._ diff --git a/caption.yaml b/caption.yaml index 7b1e215..9ebbba9 100644 --- a/caption.yaml +++ b/caption.yaml @@ -7,6 +7,7 @@ api_key_env_vars: model: llama-4-scout-17b-16e-instruct api_key: '' max_tokens: 16384 +concurrent_batch_size: 1 system_prompt: You are to analyze an image and provide information based on what is visible in the image. Do not embellish, and avoid langauge like 'showcases' or 'features,' preferring to focus on factual diff --git a/caption_openai.py b/caption_openai.py index 0a8989e..2912a20 100644 --- a/caption_openai.py +++ b/caption_openai.py @@ -142,11 +142,39 @@ async def process_image(client: openai.AsyncOpenAI, image_path, conf) -> Tuple[s messages = remove_base64_image(messages) return final_summary_response, json.dumps(messages, indent=2), prompt_tokens_usage, completion_tokens_usage +async def process_batch(client: openai.AsyncOpenAI, image_paths: list, conf) -> Tuple[int, int]: + """Process a batch of images concurrently and return aggregated token usage.""" + tasks = [process_image(client, image_path, conf) for image_path in image_paths] + + # Process all images in the batch concurrently + results = await asyncio.gather(*tasks, return_exceptions=True) + + batch_prompt_tokens = 0 + batch_completion_tokens = 0 + + for i, result in enumerate(results): + if isinstance(result, Exception): + print(filter_ascii(f"Error processing {image_paths[i]}: {result}")) + else: + caption_text, chat_history, prompt_token_usage, completion_token_usage = result # type: ignore + + await save_caption(file_path=image_paths[i], caption_text=caption_text, debug_info=chat_history) + + batch_prompt_tokens += prompt_token_usage + batch_completion_tokens += completion_token_usage + + print(filter_ascii(f" --> Processed {image_paths[i]}")) + print(f" --> prompt_token_usage: {prompt_token_usage}, completion_token_usage: {completion_token_usage}") + + return batch_prompt_tokens, batch_completion_tokens + async def main(): import hints.registration as registration registration._validate_hint_sources() conf = OmegaConf.load("caption.yaml") + + concurrent_batch_size = conf.concurrent_batch_size if conf.get("global_metadata_file"): # type: ignore async with aiofiles.open(conf.global_metadata_file) as f: @@ -154,6 +182,7 @@ async def main(): conf.system_prompt = f"{global_metadata}\n{conf.system_prompt}" print(filter_ascii(f" -> SYSTEM PROMPT:\n{conf.system_prompt}\n")) + print(filter_ascii(f" -> CONCURRENT BATCH SIZE: {concurrent_batch_size}\n")) api_key = resolve_api_key(conf) @@ -162,32 +191,31 @@ async def main(): aggregated_prompt_token_usage = 0 aggregated_completion_token_usage = 0 + # Collect images in batches for concurrent processing + batch = [] async for image_path in image_walk(conf.base_directory, recursive=conf.recursive, skip_if_txt_exists=conf.skip_if_txt_exists): current_task = asyncio.current_task() if current_task is not None and current_task.cancelled(): print("Captioning task was cancelled by user") return - print(filter_ascii(f"\nProcessing {image_path}")) - try: - start_time = time.perf_counter() - caption_text, chat_history, prompt_token_usage, completion_token_usage = await process_image(client, image_path, conf) - total_time = (time.perf_counter() - start_time) - except openai.APIConnectionError as e: - print(f"{e}\nAPI Error. Check that your service is running and caption.yaml has the correct base_url") - except asyncio.CancelledError: - print("Captioning task was cancelled during image processing") - raise + batch.append(image_path) - aggregated_prompt_token_usage += prompt_token_usage - aggregated_completion_token_usage += aggregated_completion_token_usage + if len(batch) >= concurrent_batch_size: + print(filter_ascii(f"\nProcessing batch of {len(batch)} images:")) + start_time = time.perf_counter() + + batch_prompt_tokens, batch_completion_tokens = await process_batch(client, batch, conf) + + batch_time = (time.perf_counter() - start_time) + print(filter_ascii(f" --> Batch completed in {batch_time:.2f}s, {batch_time/concurrent_batch_size:.2f}s per image")) + + aggregated_prompt_token_usage += batch_prompt_tokens + aggregated_completion_token_usage += batch_completion_tokens + batch = [] - print(filter_ascii(f" --> Took {total_time:.2f}s, Final caption:\n{caption_text}")) - print(f" --> prompt_token_usage: {prompt_token_usage}, completion_token_usage: {completion_token_usage}") - await save_caption(file_path=image_path, caption_text=caption_text, debug_info=chat_history) print(F" -> JOB COMPLETE.") - # not working? print(f"aggregated_prompt_token_usage: {aggregated_prompt_token_usage}, aggregated_completion_token_usage: {aggregated_completion_token_usage}") if __name__ == "__main__": diff --git a/ui/src/App.js b/ui/src/App.js index 5d017df..13bb3aa 100644 --- a/ui/src/App.js +++ b/ui/src/App.js @@ -18,7 +18,8 @@ function App() { recursive: false, hint_sources: [], global_metadata_file: '', - skip_if_txt_exists: false + skip_if_txt_exists: false, + concurrent_batch_size: 1 }); const [configLoading, setConfigLoading] = useState(false); const [configError, setConfigError] = useState(''); @@ -106,7 +107,8 @@ function App() { recursive: data.config.recursive || false, hint_sources: data.config.hint_sources || [], global_metadata_file: data.config.global_metadata_file || '', - skip_if_txt_exists: data.config.skip_if_txt_exists || false + skip_if_txt_exists: data.config.skip_if_txt_exists || false, + concurrent_batch_size: data.config.concurrent_batch_size || 1 }; setConfig(newConfig); if (newConfig.base_url) { @@ -147,7 +149,8 @@ function App() { recursive: config.recursive, hint_sources: config.hint_sources, global_metadata_file: config.global_metadata_file, - skip_if_txt_exists: config.skip_if_txt_exists + skip_if_txt_exists: config.skip_if_txt_exists, + concurrent_batch_size: config.concurrent_batch_size } }), }); diff --git a/ui/src/components/ConfigForm.js b/ui/src/components/ConfigForm.js index a1ab068..5ec40bf 100644 --- a/ui/src/components/ConfigForm.js +++ b/ui/src/components/ConfigForm.js @@ -105,6 +105,13 @@ const ConfigForm = ({ onConfigChange('retry_rules', newRetryRules); }; + const handleConcurrentBatchSizeChange = (e) => { + const value = parseInt(e.target.value); + if (value >= 1 && value <= 16) { + onConfigChange('concurrent_batch_size', value); + } + }; + if (configLoading) return <p>Loading configuration...</p>; return ( @@ -127,7 +134,7 @@ const ConfigForm = ({ onChange={(e) => onConfigChange('base_url', e.target.value)} placeholder="e.g., http://localhost:1234/v1" /> - <span className="description-text">Copy from LM Studio developer tab.</span> + <span className="description-text">Copy from LM Studio developer tab or llama.cpp console output. Make sure /v1 at end is present, ex. http://127.0.0.1:8080/v1</span> </div> <div> @@ -161,6 +168,20 @@ const ConfigForm = ({ </div> </div> + <div className="form-group"> + <label htmlFor="concurrent_batch_size">Concurrent Batch Size</label> + <input + type="number" + id="concurrent_batch_size" + min="1" + max="16" + value={config.concurrent_batch_size || 4} + onChange={handleConcurrentBatchSizeChange} + style={{ width: '100px' }} + /> + <span className="description-text">Batch concurrency if using API with support (i.e. "llama-server -np n"), otherwise leave 1</span> + </div> + <div className="form-group side-by-side api-key-directory"> <div> <label htmlFor="api_key">API Key</label>