Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 213 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_gemm.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

/*
* conv2d_gemm: GEMM step of im2col-backed conv2d.
*
* Reads the im2col'd input produced by conv2d_im2col.glsl as a 2D matrix
* of shape [M, K_total] (M = H_out * W_out, K_total = Kh*Kw*Cin_padded)
* and writes the conv2d output as texture3D channels-packed
* logical shape [1, C_out, H_out, W_out].
*
* The im2col input can be any of:
* - texture2d, width-packed: texel at (k4, m) holds 4 K values for row m.
* IN_STORAGE=texture2d codegen.
* - texture3d, channels-packed: texel at (ow, oh, k4) holds 4 K values
* for output spatial position (oh, ow). Used when M would exceed
* max_texture2d_dim. IN_STORAGE=texture3d codegen.
* - buffer: vec4 at offset m*K4 + k4, same K packing.
* IN_STORAGE=buffer codegen.
*
* The matmul interpretation is:
* out[m, n] = sum_k im2col[m, k] * weight[n, k] + bias[n]
* with M = H_out * W_out, K = K_total, N = C_out.
*/

#version 450 core

#define PRECISION ${PRECISION}

$if IN_STORAGE == "buffer" and DTYPE == "half":
${define_explicit_type_extensions(DTYPE)}

// VEC4_T is the input storage's natural texel type, which is also the tile type
// (the linear_fp_*_tile headers default the tile vec4 type to VEC4_T). For the
// buffer/half path this resolves to f16vec4, so the GEMM inner loop accumulates
// in true FP16 — the fma emits mad.f16 and the accumulators live in half-width
// registers. Texture-sampled half always returns vec4, so FP16 accumulation is
// naturally confined to the buffer (Mali) path; the texture variants (Adreno),
// where FP16 accumulation regresses, stay vec4 / FP32 with no extra gating.
#define VEC4_T ${texel_load_type(DTYPE, IN_STORAGE)}

// OUT_VEC4_T is the output surface type. t_out is always texture3d, whose
// imageStore ABI takes vec4 (fp32) regardless of DTYPE, so the accumulator tile
// is cast from VEC4_T to OUT_VEC4_T at store time.
#define OUT_VEC4_T ${texel_load_type(DTYPE, "texture3d")}

#define TILE_M4 ${TILE_M4}
#define TILE_K4 ${TILE_K4}
#define TILE_N4 ${TILE_N4}

#define TILE_M ${TILE_M}
#define TILE_K ${TILE_K4 * 4}
#define TILE_N ${TILE_N4 * 4}

$if IN_STORAGE == "buffer":
#define INPUT_BUFFER
$elif IN_STORAGE == "texture3d":
#define INPUT_TEXTURE3D

${define_required_extensions("texture3d", DTYPE)}
$if IN_STORAGE == "buffer":
${define_required_extensions("buffer", DTYPE)}

layout(std430) buffer;

#include "common.glslh"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
$if IN_STORAGE == "buffer":
${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer", is_scalar_array=False)}
$else:
${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE)}
${layout_declare_tensor(B, "r", "t_weight_packed", DTYPE, "texture2d")}
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "texture2d")}

${layout_declare_ubo(B, "ivec4", "out_sizes")}

// Push constants are uploaded in 16-byte chunks (one ivec4 each).
layout(push_constant) uniform restrict Block {
ivec4 gemm_dims; // (K_total, K4_total, M, _unused)
vec4 clamp_vals; // (out_min, out_max, _unused, _unused)
};

#define K_TOTAL gemm_dims.x
#define K4_TOTAL gemm_dims.y
#define M_TOTAL gemm_dims.z
#define OUT_MIN clamp_vals.x
#define OUT_MAX clamp_vals.y

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "activation_type", "0")}

#include "linear_fp_input_tile.glslh"
#include "linear_fp_packed_weight_tile_load.glslh"
#include "linear_fp_output_tile_fp_compute.glslh"

/*
* Load TILE_M rows × TILE_K4 K-tiles of the im2col'd input.
* The im2col output is a contiguous (M, K_total/4) matrix of vec4s, so the
* load is a plain 2D fetch — no spatial decomposition.
*/
void load_input_tile_with_checks(
out FPInputTile tile,
const int k4_start,
const int m_start,
const int K4,
const int M,
const int W_out) {
// W_out is only consumed by the texture3d variant below.
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
if (k4_start + k4 < K4 && m_start + m < M) {
const int row = m_start + m;
const int col = k4_start + k4;
#if defined(INPUT_BUFFER)
// Cast SSBO texel into the input tile type (f16vec4 for half, vec4 for
// float).
tile.data[m][k4] = LINEAR_FP_INPUT_TILE_VEC4_T(t_in[row * K4 + col]);
#elif defined(INPUT_TEXTURE3D)
// texture3d layout: row (the flat M index) decomposes into (ow, oh)
// and K4 is along the Z axis. texelFetch returns vec4 (fp32); cast to
// the input tile type.
tile.data[m][k4] = LINEAR_FP_INPUT_TILE_VEC4_T(
texelFetch(t_in, ivec3(row % W_out, row / W_out, col), 0));
#else
tile.data[m][k4] =
LINEAR_FP_INPUT_TILE_VEC4_T(texelFetch(t_in, ivec2(col, row), 0));
#endif
} else {
tile.data[m][k4] = LINEAR_FP_INPUT_TILE_VEC4_T(0.0);
}
}
}
}

void store_output_tile_with_checks(
const FPOutTile out_tile,
const int n4_start,
const int m_start,
const int N4,
const int M,
const int W_out) {
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
if (m_start + m < M && n4_start + n4 < N4) {
const int spatial = m_start + m;
// Cast the accumulator (f16vec4 for the buffer/half path) to the
// texture3d output surface type for the activation clamp and store.
OUT_VEC4_T texel = OUT_VEC4_T(out_tile.data[m][n4]);
if (activation_type == 1) {
texel = max(texel, OUT_VEC4_T(0.0));
} else if (activation_type == 2) {
texel = clamp(texel, OUT_VEC4_T(OUT_MIN), OUT_VEC4_T(OUT_MAX));
}
imageStore(
t_out, ivec3(spatial % W_out, spatial / W_out, n4_start + n4), texel);
}
}
}
}

void main() {
const int tile_idx_n = int(gl_GlobalInvocationID.x);
const int tile_idx_m = int(gl_GlobalInvocationID.y);

const int n4_start = tile_idx_n * TILE_N4;
const int m_start = tile_idx_m * TILE_M;

const int W_out = out_sizes.x;
const int H_out = out_sizes.y;
const int M = M_TOTAL;
const int K4 = K4_TOTAL;
const int N = out_sizes.z;
const int N4 = div_up_4(N);

if (n4_start >= N4 || m_start >= M) {
return;
}

FPOutTile out_tile;
initialize(out_tile);

FPInputTile in_tile;
FPWeightTile w_tile;

for (int k4 = 0; k4 < K4; k4 += TILE_K4) {
load_input_tile_with_checks(in_tile, k4, m_start, K4, M, W_out);
load_packed_weight_tile_with_checks(w_tile, n4_start, k4, 0, N4, K4);
fp_accumulate_with_fp_weight(out_tile, in_tile, w_tile);
}

// Apply bias. The bias texel depends only on n4, so fetch it once per n4 and
// add it to every m row rather than re-fetching inside the M loop.
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
if (n4_start + n4 < N4) {
// t_bias is an fp32 texture2d; cast its texel to the accumulator type.
const LINEAR_FP_OUTPUT_TILE_VEC4_T bias_texel =
LINEAR_FP_OUTPUT_TILE_VEC4_T(
texelFetch(t_bias, ivec2(n4_start + n4, 0), 0));
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
out_tile.data[m][n4] += bias_texel;
}
}
}

store_output_tile_with_checks(out_tile, n4_start, m_start, N4, M, W_out);
}
26 changes: 26 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_gemm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

conv2d_gemm:
parameter_names_with_default_values:
DTYPE: float
IN_STORAGE: texture2d
TILE_M4: 1
TILE_K4: 1
TILE_N4: 1
TILE_M: 4
generate_variant_forall:
combination:
parameter_names: [IN_STORAGE, DTYPE]
combos:
- parameter_values: [texture2d, float]
- parameter_values: [texture2d, half]
- parameter_values: [texture3d, float]
- parameter_values: [texture3d, half]
- parameter_values: [buffer, float]
- parameter_values: [buffer, half]
shader_variants:
- NAME: conv2d_gemm
120 changes: 120 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_im2col.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

/*
* Im2col transformation for FP32 / FP16 conv2d.
*
* The output is a 2D matrix of shape [M, K_total] where
* M = H_out * W_out (number of output spatial positions)
* K_total = Kh * Kw * align_up_4(C_in) (flattened receptive field)
*
* K layout (so a 4-tile in K — one vec4 — holds the same kernel position):
* K = (ki * Kw + kj) * Cin_padded + ci
*
* Three codegen'd storage variants of the output tensor:
* - texture2d, width-packed: texel at (k4, m) holds 4 K values for spatial
* position m. Extents = (K_total/4, M).
* - texture3d, channels-packed: texel at (ow, oh, k4) holds 4 K values
* for output spatial position (oh, ow). Extents = (W_out, H_out, K4).
* Used as a fallback when M would exceed max_texture2d_dim.
* - buffer: vec4 at offset (m * K4 + k4), same K packing.
*
* The caller picks storage per device (Mali → buffer; others → texture2d
* when its 2D extents fit, texture3d when its 3D extents fit, else buffer).
*/

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, "texture3d")}

$if OUT_STORAGE == "buffer":
#define OUTPUT_BUFFER
#define VEC4_BUF_T ${texel_load_type(DTYPE, "buffer")}
$elif OUT_STORAGE == "texture3d":
#define OUTPUT_TEXTURE3D

${define_required_extensions("texture3d", DTYPE)}
$if OUT_STORAGE == "buffer":
${define_required_extensions("buffer", DTYPE)}

layout(std430) buffer;

$if OUT_STORAGE == "buffer":
${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer", is_scalar_array=False)}
$else:
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")}

${layout_declare_ubo(B, "ivec4", "in_sizes")}

// Push constants are uploaded in 16-byte chunks (one ivec4 each) to comply
// with the per-entry size limit.
layout(push_constant) uniform restrict Block {
ivec4 kernel_stride; // (Kh, Kw, Sh, Sw)
ivec4 padding_dil; // (Ph, Pw, Dh, Dw)
ivec4 dims; // (Cin_padded, W_out, H_out, K4_total)
};

#define KERNEL_H kernel_stride.x
#define KERNEL_W kernel_stride.y
#define STRIDE_H kernel_stride.z
#define STRIDE_W kernel_stride.w
#define PADDING_H padding_dil.x
#define PADDING_W padding_dil.y
#define DILATION_H padding_dil.z
#define DILATION_W padding_dil.w
#define CIN_PADDED dims.x
#define W_OUT dims.y
#define H_OUT dims.z
#define K4_TOTAL dims.w

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
const int k4 = int(gl_GlobalInvocationID.x);
const int m = int(gl_GlobalInvocationID.y);
const int M = H_OUT * W_OUT;

if (k4 >= K4_TOTAL || m >= M) {
return;
}

const int k_start = k4 * 4;

// K = (ki * Kw + kj) * Cin_padded + ci ; since Cin_padded % 4 == 0, all 4
// K values in this texel share the same (ki, kj) and span 4 consecutive
// ci values starting at ci_start.
const int krow_idx = k_start / CIN_PADDED; // ki * Kw + kj
const int ci_start = k_start % CIN_PADDED;
const int kj = krow_idx % KERNEL_W;
const int ki = krow_idx / KERNEL_W;
const int ci_blk = ci_start >> 2; // ci_start / 4

// Decompose flat output position m back into (oh, ow).
const int ow = m % W_OUT;
const int oh = m / W_OUT;

// Compute the input spatial position for this (oh, ow, ki, kj).
const int ih = oh * STRIDE_H - PADDING_H + ki * DILATION_H;
const int iw = ow * STRIDE_W - PADDING_W + kj * DILATION_W;

VEC4_T out_texel = VEC4_T(0);
if (ih >= 0 && ih < in_sizes.y && iw >= 0 && iw < in_sizes.x) {
out_texel = texelFetch(t_in, ivec3(iw, ih, ci_blk), 0);
}

#if defined(OUTPUT_BUFFER)
t_out[m * K4_TOTAL + k4] = VEC4_BUF_T(out_texel);
#elif defined(OUTPUT_TEXTURE3D)
imageStore(t_out, ivec3(ow, oh, k4), out_texel);
#else
imageStore(t_out, ivec2(k4, m), out_texel);
#endif
}
22 changes: 22 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_im2col.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

conv2d_im2col:
parameter_names_with_default_values:
DTYPE: float
OUT_STORAGE: texture2d
generate_variant_forall:
combination:
parameter_names: [OUT_STORAGE, DTYPE]
combos:
- parameter_values: [texture2d, float]
- parameter_values: [texture2d, half]
- parameter_values: [texture3d, float]
- parameter_values: [texture3d, half]
- parameter_values: [buffer, float]
- parameter_values: [buffer, half]
shader_variants:
- NAME: conv2d_im2col
Loading
Loading