Files

159 lines
6.1 KiB
Plaintext

#version 450
/*
* VKZip GPU Compression Shader
*
* Each workgroup compresses one independent block using a simplified
* LZ77 variant optimized for GPU parallelism:
*
* Algorithm:
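* 1. Invocation 0 of each workgroup scans its block sequentially, probing a
*    4096-entry hash table keyed on the next three bytes; the remaining
*    invocations currently only help reset that table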
* 2. Matches are encoded as (distance, length) pairs
* 3. Non-matching bytes are stored as literals
* 4. Output format per token:
* - Literal: [0x00] [byte]
* - Match: [0x01] [length: u16] [distance: u16]
*
* This is a simplified approach focusing on parallel match finding.
* The compression ratio won't match gzip/zstd, but the speed
* advantage from GPU parallelism makes up for it on large files.
*/
// Workgroup size: 256 invocations along X. Each workgroup handles the block
// selected by gl_WorkGroupID.x in main().
layout(local_size_x = 256) in;
// ── Push constants ─────────────────────────────────────────────────
// Per-dispatch parameters delivered through push constants.
layout(push_constant) uniform PushConstants {
uint block_count; // Total number of blocks; workgroups with index >= this exit immediately in main()
uint block_size; // Size of each block (e.g., 65536). NOTE(review): not read by main() as written — per-block sizes come from the metadata buffer
uint max_match_len; // Upper bound on the length of a single match
uint window_size; // Largest backward distance a match may reference
} params;
// ── Buffers ────────────────────────────────────────────────────────
// Input: raw uncompressed data (all blocks concatenated)
// Bytes are packed four per uint, least-significant byte first; read_byte()
// performs the unpacking.
layout(std430, set = 0, binding = 0) readonly buffer InputBuffer {
uint data[];
} input_buf;
// Output: compressed data (pre-allocated with worst-case size)
// Bytes are packed four per uint, least-significant byte first.
// Deliberately NOT qualified `writeonly`: write_byte() stores bytes with
// atomicOr(), which is a read-modify-write and therefore needs read access
// to this memory.
layout(std430, set = 0, binding = 1) buffer OutputBuffer {
uint data[];
} output_buf;
// Block metadata: [block_idx] = { input_offset, input_size, output_offset, output_size }
// One uvec4 per block. main() reads x/y/z and writes w; offsets are byte
// offsets as consumed by read_byte()/write_byte().
layout(std430, set = 0, binding = 2) buffer MetadataBuffer {
uvec4 blocks[]; // x = input byte offset, y = input size in bytes, z = output byte offset, w = compressed size in bytes (written by the shader)
} meta;
// ── Shared memory for workgroup ────────────────────────────────────
// Indexed by hash3(); each slot holds the most recent block-relative input
// position stored for that hash, or 0xFFFFFFFF when empty.
shared uint s_hash_table[4096]; // Hash table for match finding
// Zeroed by main() but not otherwise read or written there.
shared uint s_output_pos; // Current output position (atomic)
// ── Helper: read a byte from packed uint buffer ────────────────────
// Fetch one byte from the packed input buffer.
//   base_offset : byte offset of the block inside input_buf
//   byte_idx    : byte index relative to base_offset
// Returns the byte value in the low 8 bits (bytes are stored least-significant
// first within each uint).
uint read_byte(uint base_offset, uint byte_idx) {
    uint addr  = base_offset + byte_idx;   // absolute byte address
    uint shift = (addr & 3u) << 3;         // bit offset of the byte inside its word
    return (input_buf.data[addr >> 2] >> shift) & 0xFFu;
}
// ── Helper: write a byte to packed uint buffer ─────────────────────
// Store one byte into the packed output buffer.
//   base_offset : byte offset of the block inside output_buf
//   byte_idx    : byte index relative to base_offset
//   value       : only the low 8 bits are stored
// The byte is merged into its containing word with atomicOr, so neighbouring
// bytes of the same word are left intact. Because OR can only set bits, this
// assumes the destination byte is zero beforehand — TODO confirm the host
// clears the output buffer before dispatch.
void write_byte(uint base_offset, uint byte_idx, uint value) {
    uint addr    = base_offset + byte_idx;              // absolute byte address
    uint shifted = (value & 0xFFu) << ((addr & 3u) << 3);
    atomicOr(output_buf.data[addr >> 2], shifted);
}
// ── Hash function for string matching ──────────────────────────────
// Hash the three input bytes starting at `pos` into a 12-bit table index.
//   base_offset : byte offset of the block inside input_buf
//   pos         : block-relative position of the first byte; the caller must
//                 ensure pos + 2 is still inside the block
// Returns a value in [0, 4095].
//
// All 24 input bits are packed into one key and mixed with a multiplicative
// hash (odd constant close to 2^32 / golden ratio); the top 12 bits of the
// product are kept. Simply masking a shifted XOR with 0xFFF would discard the
// first byte and the upper half of the second byte entirely.
uint hash3(uint base_offset, uint pos) {
    uint b0 = read_byte(base_offset, pos);
    uint b1 = read_byte(base_offset, pos + 1u);
    uint b2 = read_byte(base_offset, pos + 2u);
    uint key = b0 | (b1 << 8) | (b2 << 16);
    return (key * 2654435761u) >> 20;
}
// Entry point: one workgroup compresses one block.
//   1. All invocations cooperatively reset the shared hash table.
//   2. Invocation 0 runs a greedy, sequential LZ77 pass over the block,
//      emitting 2-byte literal tokens  [0x00][byte]  and 5-byte match tokens
//      [0x01][len_lo][len_hi][dist_lo][dist_hi].
//   3. Invocation 0 stores the compressed size in meta.blocks[block_idx].w.
void main() {
    uint block_idx = gl_WorkGroupID.x;
    uint thread_id = gl_LocalInvocationID.x;

    // block_idx is the same for every invocation of the workgroup, so this
    // early-out is uniform and cannot strand invocations at the barrier below.
    if (block_idx >= params.block_count) return;

    uint in_offset  = meta.blocks[block_idx].x;
    uint in_size    = meta.blocks[block_idx].y;
    uint out_offset = meta.blocks[block_idx].z;

    // Reset the hash table, striding by the workgroup size.
    // 0xFFFFFFFF marks an empty slot.
    for (uint i = thread_id; i < uint(s_hash_table.length()); i += gl_WorkGroupSize.x) {
        s_hash_table[i] = 0xFFFFFFFFu;
    }
    if (thread_id == 0u) {
        s_output_pos = 0u;
    }

    // Make every invocation's table writes visible before invocation 0 reads them.
    memoryBarrierShared();
    barrier();

    // Sequential compression on invocation 0 only; the other invocations are
    // idle from here on.
    if (thread_id == 0u) {
        // Length and distance are stored in 16-bit token fields. Clamp the
        // configured limits so oversized push-constant values can never
        // produce a value that would be truncated when written.
        uint max_token_len  = min(params.max_match_len, 0xFFFFu);
        uint max_token_dist = min(params.window_size, 0xFFFFu);

        uint pos     = 0u;   // next input byte to encode (block-relative)
        uint out_pos = 0u;   // next output byte to write (block-relative)

        while (pos < in_size) {
            uint best_len  = 0u;
            uint best_dist = 0u;

            // hash3() reads three bytes, so at least 3 must remain.
            if (pos + 2u < in_size) {
                uint h = hash3(in_offset, pos);
                // Single candidate per bucket: the last position stored for this hash.
                uint match_pos = s_hash_table[h];

                if (match_pos != 0xFFFFFFFFu && pos > match_pos) {
                    uint dist = pos - match_pos;   // > 0 because pos > match_pos
                    if (dist <= max_token_dist) {
                        // Count matching bytes. The candidate may run past
                        // `pos`, i.e. a match can overlap the bytes it produces.
                        uint len = 0u;
                        uint max_len = min(max_token_len, in_size - pos);
                        while (len < max_len &&
                               read_byte(in_offset, match_pos + len) ==
                               read_byte(in_offset, pos + len)) {
                            len++;
                        }
                        // A match token costs 5 bytes; shorter matches are not used.
                        if (len >= 3u) {
                            best_len  = len;
                            best_dist = dist;
                        }
                    }
                }

                // Remember the current position for future lookups.
                s_hash_table[h] = pos;
            }

            if (best_len >= 3u) {
                // Match token, both fields little-endian.
                write_byte(out_offset, out_pos++, 0x01u);
                write_byte(out_offset, out_pos++, best_len & 0xFFu);
                write_byte(out_offset, out_pos++, (best_len >> 8) & 0xFFu);
                write_byte(out_offset, out_pos++, best_dist & 0xFFu);
                write_byte(out_offset, out_pos++, (best_dist >> 8) & 0xFFu);
                pos += best_len;
            } else {
                // Literal token.
                uint b = read_byte(in_offset, pos);
                write_byte(out_offset, out_pos++, 0x00u);
                write_byte(out_offset, out_pos++, b);
                pos++;
            }
        }

        // Publish the compressed size of this block.
        meta.blocks[block_idx].w = out_pos;
    }
}