Kuwahara Filter

By Antonio Cheong.

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "numpy",
#     "pillow",
#     "pyopencl",
# ]
# ///
import pyopencl as cl
import numpy as np
from PIL import Image


_KUWAHARA_KERNEL_SRC = """
// Square of the positive part of t; used for every polynomial sector weight.
float pos_sq(float t)
{
    const float z = fmax(t, 0.0f);
    return z * z;
}

__kernel void kuwahara_filter(
    __global const float *input,
    __global float *output,
    const int width,
    const int height,
    const int kernel_radius,
    const float sharpness,
    const float hardness,
    const float zero_crossing,
    const float zeta)
{
    const int x = get_global_id(0);
    const int y = get_global_id(1);

    // The launch grid is rounded up to whole work-groups, so some work-items
    // lie outside the image. They must not touch memory: their linear index
    // would alias the next row or run past the end of the buffers.
    if (x >= width || y >= height) {
        return;
    }

    const int idx = (y * width + x) * 3;

    // Pixels closer than kernel_radius to the edge are copied unchanged.
    if (x < kernel_radius || x >= width - kernel_radius ||
        y < kernel_radius || y >= height - kernel_radius) {
        output[idx] = input[idx];
        output[idx + 1] = input[idx + 1];
        output[idx + 2] = input[idx + 2];
        return;
    }

    // Per-sector accumulators: m[k].xyz = weighted colour sum,
    // m[k].w = weight sum, s[k] = weighted sum of squared colour.
    float4 m[8];
    float3 s[8];
    for (int k = 0; k < 8; ++k) {
        m[k] = (float4)(0.0f);
        s[k] = (float3)(0.0f);
    }

    const float sin_zero_cross = sin(zero_crossing);
    const float eta = (zeta + cos(zero_crossing)) / (sin_zero_cross * sin_zero_cross);

    for (int ky = -kernel_radius; ky <= kernel_radius; ++ky) {
        for (int kx = -kernel_radius; kx <= kernel_radius; ++kx) {
            // Offset normalised to [-1, 1] in each axis.
            float2 v = (float2)((float)kx / kernel_radius, (float)ky / kernel_radius);

            // Interior pixels never leave the image, but clamping keeps the
            // read in bounds regardless.
            const int px = clamp(x + kx, 0, width - 1);
            const int py = clamp(y + ky, 0, height - 1);
            const int pidx = (py * width + px) * 3;

            float3 c = (float3)(input[pidx], input[pidx + 1], input[pidx + 2]);
            c = clamp(c, 0.0f, 1.0f);

            float w[8];

            // Axis-aligned sectors 0, 2, 4, 6.
            float vxx = zeta - eta * v.x * v.x;
            float vyy = zeta - eta * v.y * v.y;
            w[0] = pos_sq(v.y + vxx);
            w[2] = pos_sq(-v.x + vyy);
            w[4] = pos_sq(-v.y + vxx);
            w[6] = pos_sq(v.x + vyy);

            // Diagonal sectors 1, 3, 5, 7: the same polynomials evaluated in a
            // frame rotated by 45 degrees (0.7071... = sqrt(2) / 2).
            v = (float2)(v.x - v.y, v.x + v.y) * 0.707106781f;
            vxx = zeta - eta * v.x * v.x;
            vyy = zeta - eta * v.y * v.y;
            w[1] = pos_sq(v.y + vxx);
            w[3] = pos_sq(-v.x + vyy);
            w[5] = pos_sq(-v.y + vxx);
            w[7] = pos_sq(v.x + vyy);

            const float sum = w[0] + w[2] + w[4] + w[6] + w[1] + w[3] + w[5] + w[7];

            // With extreme zero_crossing/zeta every weight can be zero; then
            // w[k] * (g / 0) would be NaN and poison all eight sectors.
            if (sum <= 0.0f) {
                continue;
            }

            // Gaussian falloff; the 45-degree rotation preserves dot(v, v).
            const float g = exp(-3.125f * dot(v, v)) / sum;

            for (int k = 0; k < 8; ++k) {
                const float wk = w[k] * g;
                m[k].xyz += c * wk;
                m[k].w += wk;
                s[k] += c * c * wk;
            }
        }
    }

    // Blend sector means, down-weighting sectors with high colour variance.
    float4 output_pixel = (float4)(0.0f);
    for (int k = 0; k < 8; ++k) {
        if (m[k].w > 0.0f) {
            const float3 mean = m[k].xyz / m[k].w;
            const float3 variance = fabs(s[k] / m[k].w - mean * mean);
            const float sigma2 = variance.x + variance.y + variance.z;
            const float wk = 1.0f / (1.0f + pow(hardness * 1000.0f * sigma2, 0.5f * sharpness));

            output_pixel.xyz += mean * wk;
            output_pixel.w += wk;
        }
    }

    if (output_pixel.w > 0.0f) {
        const float3 result = clamp(output_pixel.xyz / output_pixel.w, 0.0f, 1.0f);
        output[idx] = result.x;
        output[idx + 1] = result.y;
        output[idx + 2] = result.z;
    } else {
        output[idx] = input[idx];
        output[idx + 1] = input[idx + 1];
        output[idx + 2] = input[idx + 2];
    }
}
"""


def _to_uint8(pixels):
    """Convert float RGB values in [0, 1] to uint8, rounding to the nearest level.

    Values outside [0, 1] are clipped first. Rounding (rather than truncating)
    keeps untouched pixels byte-identical after the ``/ 255`` -> ``* 255``
    round trip.
    """
    return np.rint(np.clip(pixels, 0.0, 1.0) * 255.0).astype(np.uint8)


def generalized_kuwahara_gpu_opencl(
    image_path,
    output_path,
    kernel_size=8,
    sharpness=8.0,
    hardness=8.0,
    zero_crossing=0.58,
    zeta=None,
    passes=1,
):
    """Apply a generalized Kuwahara filter to an image on an OpenCL device.

    The image is loaded as RGB, filtered ``passes`` times by the
    ``kuwahara_filter`` kernel, and saved to ``output_path``. Pixels closer
    than ``kernel_size // 2`` to the image edge are copied through unfiltered.

    Each of the eight sectors around a pixel gets the blend weight
    ``1 / (1 + (hardness * 1000 * variance) ** (sharpness / 2))``, where
    ``variance`` is the sum of the sector's per-channel colour variances.

    Args:
        image_path: Path of the input image (any mode Pillow can convert to RGB).
        output_path: Path the filtered RGB image is written to.
        kernel_size: Window diameter in pixels. The kernel radius is
            ``kernel_size // 2`` and must be at least 1.
        sharpness: Exponent (halved) in the sector blend weight above.
        hardness: Multiplier (times 1000) on sector variance in the blend weight.
        zero_crossing: Angle in radians whose sine and cosine shape the
            polynomial sector weights.
        zeta: Offset of the polynomial sector weights. Defaults to
            ``2 / kernel_radius``.
        passes: Number of filter applications; each pass reads the previous
            pass's output. ``passes <= 0`` writes the input unchanged.

    Raises:
        ValueError: If ``kernel_size // 2`` is less than 1.
    """
    kernel_radius = kernel_size // 2
    if kernel_radius < 1:
        # A zero radius divides by zero in the default zeta and in the kernel's
        # offset normalisation.
        raise ValueError(f"kernel_size must be at least 2, got {kernel_size!r}")
    if zeta is None:
        zeta = 2.0 / kernel_radius

    with Image.open(image_path) as img:
        img_array = np.array(img.convert("RGB"), dtype=np.float32) / 255.0
    height, width, _ = img_array.shape

    ctx = cl.create_some_context(False)
    queue = cl.CommandQueue(ctx)

    mf = cl.mem_flags
    # The buffers swap roles between passes, so both must be writable; a
    # kernel writing to a READ_ONLY buffer is undefined behaviour.
    src_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=img_array)
    dst_buf = cl.Buffer(ctx, mf.READ_WRITE, img_array.nbytes)

    # Fetch the kernel once: every attribute access on a Program creates a
    # new Kernel object.
    kuwahara = cl.Program(ctx, _KUWAHARA_KERNEL_SRC).build().kuwahara_filter

    # Round the grid up to whole 16x16 work-groups. The kernel discards the
    # padding work-items that fall outside the image.
    local_size = (16, 16)
    global_size = (
        ((width + local_size[0] - 1) // local_size[0]) * local_size[0],
        ((height + local_size[1] - 1) // local_size[1]) * local_size[1],
    )

    for _ in range(passes):
        kuwahara(
            queue,
            global_size,
            local_size,
            src_buf,
            dst_buf,
            np.int32(width),
            np.int32(height),
            np.int32(kernel_radius),
            np.float32(sharpness),
            np.float32(hardness),
            np.float32(zero_crossing),
            np.float32(zeta),
        )
        # This pass's output becomes the next pass's input.
        src_buf, dst_buf = dst_buf, src_buf

    result = np.empty_like(img_array)
    cl.enqueue_copy(queue, result, src_buf)  # blocking by default
    Image.fromarray(_to_uint8(result)).save(output_path)


if __name__ == "__main__":
    # Demo: one pass of a 15-pixel-wide filter from ./input.jpg to ./output.jpg.
    generalized_kuwahara_gpu_opencl(
        image_path="input.jpg",
        output_path="output.jpg",
        kernel_size=15,
        sharpness=8.0,
        hardness=8.0,
        zero_crossing=0.58,
        passes=1,
    )

Derived from https://github.com/GarrettGunnell/Post-Processing/