Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,5 @@ Mostly self explanatory. Paste in the path to the directory you want processed
For instance, if you are writing web scrapers, make sure to collect metadata from the webpage as you go rather than just blindly downloading each image. Perhaps you might include the website address or full URI of the webpage, the `<title>` tag from the webpage, or the `alt-text` field. Save this information with each image, or in a database, then feed it into the VLM with a `hint_source`. New hint sources are very easy for an amateur Python programmer to write, or you can have an LLM write one for you.

See [HINTSOURCES.md](HINTSOURCES.md) for more information.

- **Batch concurrency**: If you use a VLM host that supports batch concurrency, such as llama.cpp (via the `-np n` argument), you can potentially increase speed. This is not supported by LM Studio. Example command: `llama-server -np 4 -c 32768 --mmproj "mmproj-Qwen3-VL-32B-Instruct-F16.gguf" --model "Qwen3-VL-32B-Instruct-Q4_K_M.gguf" -dev cuda0 --top-k 30 --top-p 0.95 --min-p 0.05 --temp 0.5` would launch Qwen3VL 32B with four concurrent slots (`-np 4`), each with 8192 tokens of context (32768/4). This requires additional processing power and a larger total context size (for example, `-np 4 -c 32768` instead of `-np 1 -c 8192`), but may increase total token generation speed by utilizing batch processing. _This feature does not use the OpenAI JSONL batch API that commercial APIs offer to save on costs, but it should still speed up processing._
1 change: 1 addition & 0 deletions caption.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ api_key_env_vars:
model: llama-4-scout-17b-16e-instruct
api_key: ''
max_tokens: 16384
concurrent_batch_size: 1
system_prompt: You are to analyze an image
and provide information based on what is visible in the image. Do not embellish,
and avoid language like 'showcases' or 'features,' preferring to focus on factual
Expand Down
60 changes: 44 additions & 16 deletions caption_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,18 +142,47 @@ async def process_image(client: openai.AsyncOpenAI, image_path, conf) -> Tuple[s
messages = remove_base64_image(messages)
return final_summary_response, json.dumps(messages, indent=2), prompt_tokens_usage, completion_tokens_usage

async def process_batch(client: openai.AsyncOpenAI, image_paths: list, conf) -> Tuple[int, int]:
    """Caption a batch of images concurrently and return aggregated token usage.

    One ``process_image`` call is started per path and all of them are awaited
    together. A failure for one image is reported and skipped so that the
    remaining images in the batch are still saved.

    Args:
        client: OpenAI-compatible async client passed through to ``process_image``.
        image_paths: Paths of the images that make up this batch.
        conf: Loaded configuration object passed through to ``process_image``.

    Returns:
        A ``(prompt_tokens, completion_tokens)`` tuple summed over the images
        in the batch that were processed successfully.
    """
    tasks = [process_image(client, image_path, conf) for image_path in image_paths]

    # return_exceptions=True keeps a single failing image from aborting the
    # whole batch; failures come back as exception objects in `results`.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    batch_prompt_tokens = 0
    batch_completion_tokens = 0

    for image_path, result in zip(image_paths, results):
        # Check BaseException rather than Exception: a cancelled child task is
        # returned as asyncio.CancelledError, which has derived from
        # BaseException since Python 3.8 and would otherwise fall through to
        # the tuple unpacking below and raise a TypeError.
        if isinstance(result, BaseException):
            print(filter_ascii(f"Error processing {image_path}: {result}"))
            continue

        caption_text, chat_history, prompt_token_usage, completion_token_usage = result

        await save_caption(file_path=image_path, caption_text=caption_text, debug_info=chat_history)

        batch_prompt_tokens += prompt_token_usage
        batch_completion_tokens += completion_token_usage

        print(filter_ascii(f" --> Processed {image_path}"))
        print(f" --> prompt_token_usage: {prompt_token_usage}, completion_token_usage: {completion_token_usage}")

    return batch_prompt_tokens, batch_completion_tokens

async def main():
import hints.registration as registration
registration._validate_hint_sources()

conf = OmegaConf.load("caption.yaml")

concurrent_batch_size = conf.concurrent_batch_size

if conf.get("global_metadata_file"): # type: ignore
async with aiofiles.open(conf.global_metadata_file) as f:
global_metadata = await f.read()
conf.system_prompt = f"{global_metadata}\n{conf.system_prompt}"

print(filter_ascii(f" -> SYSTEM PROMPT:\n{conf.system_prompt}\n"))
print(filter_ascii(f" -> CONCURRENT BATCH SIZE: {concurrent_batch_size}\n"))

api_key = resolve_api_key(conf)

Expand All @@ -162,32 +191,31 @@ async def main():
aggregated_prompt_token_usage = 0
aggregated_completion_token_usage = 0

# Collect images in batches for concurrent processing
batch = []
async for image_path in image_walk(conf.base_directory, recursive=conf.recursive, skip_if_txt_exists=conf.skip_if_txt_exists):
current_task = asyncio.current_task()
if current_task is not None and current_task.cancelled():
print("Captioning task was cancelled by user")
return

print(filter_ascii(f"\nProcessing {image_path}"))
try:
start_time = time.perf_counter()
caption_text, chat_history, prompt_token_usage, completion_token_usage = await process_image(client, image_path, conf)
total_time = (time.perf_counter() - start_time)
except openai.APIConnectionError as e:
print(f"{e}\nAPI Error. Check that your service is running and caption.yaml has the correct base_url")
except asyncio.CancelledError:
print("Captioning task was cancelled during image processing")
raise
batch.append(image_path)

aggregated_prompt_token_usage += prompt_token_usage
aggregated_completion_token_usage += aggregated_completion_token_usage
if len(batch) >= concurrent_batch_size:
print(filter_ascii(f"\nProcessing batch of {len(batch)} images:"))
start_time = time.perf_counter()

batch_prompt_tokens, batch_completion_tokens = await process_batch(client, batch, conf)

batch_time = (time.perf_counter() - start_time)
print(filter_ascii(f" --> Batch completed in {batch_time:.2f}s, {batch_time/concurrent_batch_size:.2f}s per image"))

aggregated_prompt_token_usage += batch_prompt_tokens
aggregated_completion_token_usage += batch_completion_tokens
batch = []

print(filter_ascii(f" --> Took {total_time:.2f}s, Final caption:\n{caption_text}"))
print(f" --> prompt_token_usage: {prompt_token_usage}, completion_token_usage: {completion_token_usage}")
await save_caption(file_path=image_path, caption_text=caption_text, debug_info=chat_history)

print(F" -> JOB COMPLETE.")
# not working?
print(f"aggregated_prompt_token_usage: {aggregated_prompt_token_usage}, aggregated_completion_token_usage: {aggregated_completion_token_usage}")

if __name__ == "__main__":
Expand Down
9 changes: 6 additions & 3 deletions ui/src/App.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ function App() {
recursive: false,
hint_sources: [],
global_metadata_file: '',
skip_if_txt_exists: false
skip_if_txt_exists: false,
concurrent_batch_size: 1
});
const [configLoading, setConfigLoading] = useState(false);
const [configError, setConfigError] = useState('');
Expand Down Expand Up @@ -106,7 +107,8 @@ function App() {
recursive: data.config.recursive || false,
hint_sources: data.config.hint_sources || [],
global_metadata_file: data.config.global_metadata_file || '',
skip_if_txt_exists: data.config.skip_if_txt_exists || false
skip_if_txt_exists: data.config.skip_if_txt_exists || false,
concurrent_batch_size: data.config.concurrent_batch_size || 1
};
setConfig(newConfig);
if (newConfig.base_url) {
Expand Down Expand Up @@ -147,7 +149,8 @@ function App() {
recursive: config.recursive,
hint_sources: config.hint_sources,
global_metadata_file: config.global_metadata_file,
skip_if_txt_exists: config.skip_if_txt_exists
skip_if_txt_exists: config.skip_if_txt_exists,
concurrent_batch_size: config.concurrent_batch_size
}
}),
});
Expand Down
23 changes: 22 additions & 1 deletion ui/src/components/ConfigForm.js
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ const ConfigForm = ({
onConfigChange('retry_rules', newRetryRules);
};

// Validate the "Concurrent Batch Size" input and forward accepted values
// to the parent through onConfigChange.
const handleConcurrentBatchSizeChange = (e) => {
  // Always pass an explicit radix so the text is parsed as base 10
  // (without it, a "0x"-prefixed string would be read as hexadecimal).
  const value = parseInt(e.target.value, 10);
  // NaN (e.g. an emptied field) fails both comparisons and is ignored, as is
  // anything outside the 1-16 range enforced by the input's min/max.
  if (value >= 1 && value <= 16) {
    onConfigChange('concurrent_batch_size', value);
  }
};

if (configLoading) return <p>Loading configuration...</p>;

return (
Expand All @@ -127,7 +134,7 @@ const ConfigForm = ({
onChange={(e) => onConfigChange('base_url', e.target.value)}
placeholder="e.g., http://localhost:1234/v1"
/>
<span className="description-text">Copy from LM Studio developer tab.</span>
<span className="description-text">Copy from LM Studio developer tab or llama.cpp console output. Make sure /v1 at end is present, ex. http://127.0.0.1:8080/v1</span>
</div>

<div>
Expand Down Expand Up @@ -161,6 +168,20 @@ const ConfigForm = ({
</div>
</div>

<div className="form-group">
<label htmlFor="concurrent_batch_size">Concurrent Batch Size</label>
<input
type="number"
id="concurrent_batch_size"
min="1"
max="16"
value={config.concurrent_batch_size || 4}
onChange={handleConcurrentBatchSizeChange}
style={{ width: '100px' }}
/>
<span className="description-text">Batch concurrency if using API with support (i.e. "llama-server -np n"), otherwise leave 1</span>
</div>

<div className="form-group side-by-side api-key-directory">
<div>
<label htmlFor="api_key">API Key</label>
Expand Down