Skip to content

Commit 1af99b7

Browse files
Aelphyxnnpack-bot
authored andcommitted
Add avx2 kernels for statically quantized 2-bit FC.
qs8_qc2w_gemm_minmax_fp32_ukernel_1x8c8__avx2_madd/llm/M:128/N:4096/K:1024/real_time 10335342 ns 10334988 ns 1 OPS=103.89G/s qs8_qc2w_gemm_minmax_fp32_ukernel_2x8c8__avx2_madd/llm/M:128/N:4096/K:1024/real_time 8453181 ns 8451818 ns 1 OPS=127.022G/s qs8_qc2w_gemm_minmax_fp32_ukernel_3x8c8__avx2_madd/llm/M:128/N:4096/K:1024/real_time 7859592 ns 7858406 ns 1 OPS=136.615G/s qs8_qc2w_gemm_minmax_fp32_ukernel_4x8c8__avx2_madd/llm/M:128/N:4096/K:1024/real_time 7383201 ns 7382116 ns 1 OPS=145.43G/s qs8_qc2w_gemm_minmax_fp32_ukernel_5x8c8__avx2_madd/llm/M:128/N:4096/K:1024/real_time 7219861 ns 7221366 ns 1 OPS=148.721G/s qs8_qc2w_gemm_minmax_fp32_ukernel_6x8c8__avx2_madd/llm/M:128/N:4096/K:1024/real_time 7582741 ns 7581205 ns 1 OPS=141.603G/s qs8_qc2w_gemm_minmax_fp32_ukernel_1x2__scalar_lrintf/llm/M:128/N:4096/K:1024/real_time 120441433 ns 120432262 ns 1 OPS=8.91505G/s qs8_qc2w_gemm_minmax_fp32_ukernel_1x2__scalar_fmagic/llm/M:128/N:4096/K:1024/real_time 120085104 ns 120071060 ns 1 OPS=8.94151G/s qs8_qc2w_gemm_minmax_fp32_ukernel_1x4__scalar_fmagic/llm/M:128/N:4096/K:1024/real_time 111649263 ns 111650284 ns 1 OPS=9.6171G/s qs8_qc2w_gemm_minmax_fp32_ukernel_2x4__scalar_fmagic/llm/M:128/N:4096/K:1024/real_time 72375198 ns 72376231 ns 1 OPS=14.8358G/s qs8_qc2w_gemm_minmax_fp32_ukernel_3x4__scalar_fmagic/llm/M:128/N:4096/K:1024/real_time 62368117 ns 62369613 ns 1 OPS=17.2162G/s qs8_qc2w_gemm_minmax_fp32_ukernel_4x4__scalar_fmagic/llm/M:128/N:4096/K:1024/real_time 57041776 ns 57043407 ns 1 OPS=18.8238G/s PiperOrigin-RevId: 864451367
1 parent 313afbb commit 1af99b7

22 files changed

+3574
-84
lines changed

bench/qs8-qc2w-gemm-fp32.cc

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,75 @@ namespace {
117117
#endif // XNN_ENABLE_ARM_DOTPROD && (XNN_ARCH_ARM || XNN_ARCH_ARM64)
118118

119119

120+
#if XNN_ENABLE_AVX2 && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
121+
static void qs8_qc2w_gemm_minmax_fp32_ukernel_1x8c8__avx2_madd(benchmark::State& state, const char* net) {
122+
GEMMBenchmark(state,
123+
xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_1x8c8__avx2_madd,
124+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
125+
xnn_pack_qs8_to_qu8_qc2w_gemm_goi_w,
126+
/*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1,
127+
/*arch_flags=*/xnn_arch_x86_avx2);
128+
}
129+
130+
BENCHMARK_GEMM(qs8_qc2w_gemm_minmax_fp32_ukernel_1x8c8__avx2_madd)
131+
132+
static void qs8_qc2w_gemm_minmax_fp32_ukernel_2x8c8__avx2_madd(benchmark::State& state, const char* net) {
133+
GEMMBenchmark(state,
134+
xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_2x8c8__avx2_madd,
135+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
136+
xnn_pack_qs8_to_qu8_qc2w_gemm_goi_w,
137+
/*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1,
138+
/*arch_flags=*/xnn_arch_x86_avx2);
139+
}
140+
141+
BENCHMARK_GEMM(qs8_qc2w_gemm_minmax_fp32_ukernel_2x8c8__avx2_madd)
142+
143+
static void qs8_qc2w_gemm_minmax_fp32_ukernel_3x8c8__avx2_madd(benchmark::State& state, const char* net) {
144+
GEMMBenchmark(state,
145+
xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_3x8c8__avx2_madd,
146+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
147+
xnn_pack_qs8_to_qu8_qc2w_gemm_goi_w,
148+
/*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1,
149+
/*arch_flags=*/xnn_arch_x86_avx2);
150+
}
151+
152+
BENCHMARK_GEMM(qs8_qc2w_gemm_minmax_fp32_ukernel_3x8c8__avx2_madd)
153+
154+
static void qs8_qc2w_gemm_minmax_fp32_ukernel_4x8c8__avx2_madd(benchmark::State& state, const char* net) {
155+
GEMMBenchmark(state,
156+
xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_4x8c8__avx2_madd,
157+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
158+
xnn_pack_qs8_to_qu8_qc2w_gemm_goi_w,
159+
/*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1,
160+
/*arch_flags=*/xnn_arch_x86_avx2);
161+
}
162+
163+
BENCHMARK_GEMM(qs8_qc2w_gemm_minmax_fp32_ukernel_4x8c8__avx2_madd)
164+
165+
static void qs8_qc2w_gemm_minmax_fp32_ukernel_5x8c8__avx2_madd(benchmark::State& state, const char* net) {
166+
GEMMBenchmark(state,
167+
xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_5x8c8__avx2_madd,
168+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
169+
xnn_pack_qs8_to_qu8_qc2w_gemm_goi_w,
170+
/*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1,
171+
/*arch_flags=*/xnn_arch_x86_avx2);
172+
}
173+
174+
BENCHMARK_GEMM(qs8_qc2w_gemm_minmax_fp32_ukernel_5x8c8__avx2_madd)
175+
176+
static void qs8_qc2w_gemm_minmax_fp32_ukernel_6x8c8__avx2_madd(benchmark::State& state, const char* net) {
177+
GEMMBenchmark(state,
178+
xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_6x8c8__avx2_madd,
179+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
180+
xnn_pack_qs8_to_qu8_qc2w_gemm_goi_w,
181+
/*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1,
182+
/*arch_flags=*/xnn_arch_x86_avx2);
183+
}
184+
185+
BENCHMARK_GEMM(qs8_qc2w_gemm_minmax_fp32_ukernel_6x8c8__avx2_madd)
186+
#endif // XNN_ENABLE_AVX2 && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
187+
188+
120189
static void qs8_qc2w_gemm_minmax_fp32_ukernel_1x2__scalar_lrintf(benchmark::State& state, const char* net) {
121190
GEMMBenchmark(state,
122191
xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_1x2__scalar_lrintf,

cmake/gen/avx2_microkernels.cmake

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ SET(PROD_AVX2_MICROKERNEL_SRCS
4949
src/qs8-f16-vcvt/gen/qs8-f16-vcvt-avx2-u16.c
5050
src/qs8-f32-vcvt/gen/qs8-f32-vcvt-avx2-u16.c
5151
src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx2-madd.c
52+
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x8c8-minmax-avx2-madd.c
53+
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-5x8c8-minmax-avx2-madd.c
5254
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x8c8-minmax-avx2-madd-prfm.c
5355
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x8c8-minmax-avx2-madd-prfm.c
5456
src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p16c-minmax-fp32-avx2-mul32.c
@@ -340,6 +342,12 @@ SET(NON_PROD_AVX2_MICROKERNEL_SRCS
340342
src/qs8-f32-vcvt/gen/qs8-f32-vcvt-avx2-u8.c
341343
src/qs8-f32-vcvt/gen/qs8-f32-vcvt-avx2-u24.c
342344
src/qs8-f32-vcvt/gen/qs8-f32-vcvt-avx2-u32.c
345+
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-2x8c8-minmax-avx2-madd.c
346+
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-3x8c8-minmax-avx2-madd.c
347+
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x8c8-minmax-avx2-madd.c
348+
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-6x8c8-minmax-avx2-madd.c
349+
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-7x8c8-minmax-avx2-madd.c
350+
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-8x8c8-minmax-avx2-madd.c
343351
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x8c8-minmax-avx2-madd.c
344352
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x8c8-minmax-avx2-madd-prfm.c
345353
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x8c8-minmax-avx2-madd.c

cmake/gen/scalar_microkernels.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,10 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS
151151
src/qs8-packw/gen/qs8-packw-x4c8-gemm-goi-scalar.c
152152
src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-scalar.c
153153
src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c
154+
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x2-minmax-fp32-scalar-fmagic.c
154155
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x4-minmax-fp32-scalar-fmagic.c
155156
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-3x4-minmax-fp32-scalar-fmagic.c
157+
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x4-minmax-fp32-scalar-fmagic.c
156158
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4-minmax-fp32-scalar-fmagic.c
157159
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4-minmax-fp32-scalar-fmagic.c
158160
src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p1c-minmax-fp32-scalar-fmagic.c
@@ -577,10 +579,8 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS
577579
src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c
578580
src/qs8-packw/gen/qs8-packw-x64c4-gemm-gio-scalar.c
579581
src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c
580-
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x2-minmax-fp32-scalar-fmagic.c
581582
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x2-minmax-fp32-scalar-lrintf.c
582583
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-2x4-minmax-fp32-scalar-fmagic.c
583-
src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x4-minmax-fp32-scalar-fmagic.c
584584
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x2-minmax-fp32-scalar-fmagic.c
585585
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x2-minmax-fp32-scalar-lrintf.c
586586
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4-minmax-fp32-scalar-fmagic.c

gen/avx2_microkernels.bzl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ PROD_AVX2_MICROKERNEL_SRCS = [
4545
"src/qs8-f16-vcvt/gen/qs8-f16-vcvt-avx2-u16.c",
4646
"src/qs8-f32-vcvt/gen/qs8-f32-vcvt-avx2-u16.c",
4747
"src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx2-madd.c",
48+
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x8c8-minmax-avx2-madd.c",
49+
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-5x8c8-minmax-avx2-madd.c",
4850
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x8c8-minmax-avx2-madd-prfm.c",
4951
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x8c8-minmax-avx2-madd-prfm.c",
5052
"src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p16c-minmax-fp32-avx2-mul32.c",
@@ -337,6 +339,12 @@ NON_PROD_AVX2_MICROKERNEL_SRCS = [
337339
"src/qs8-f32-vcvt/gen/qs8-f32-vcvt-avx2-u8.c",
338340
"src/qs8-f32-vcvt/gen/qs8-f32-vcvt-avx2-u24.c",
339341
"src/qs8-f32-vcvt/gen/qs8-f32-vcvt-avx2-u32.c",
342+
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-2x8c8-minmax-avx2-madd.c",
343+
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-3x8c8-minmax-avx2-madd.c",
344+
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x8c8-minmax-avx2-madd.c",
345+
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-6x8c8-minmax-avx2-madd.c",
346+
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-7x8c8-minmax-avx2-madd.c",
347+
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-8x8c8-minmax-avx2-madd.c",
340348
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x8c8-minmax-avx2-madd.c",
341349
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x8c8-minmax-avx2-madd-prfm.c",
342350
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x8c8-minmax-avx2-madd.c",

gen/scalar_microkernels.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ PROD_SCALAR_MICROKERNEL_SRCS = [
147147
"src/qs8-packw/gen/qs8-packw-x4c8-gemm-goi-scalar.c",
148148
"src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-scalar.c",
149149
"src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c",
150+
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x2-minmax-fp32-scalar-fmagic.c",
150151
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x4-minmax-fp32-scalar-fmagic.c",
151152
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-3x4-minmax-fp32-scalar-fmagic.c",
153+
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x4-minmax-fp32-scalar-fmagic.c",
152154
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4-minmax-fp32-scalar-fmagic.c",
153155
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4-minmax-fp32-scalar-fmagic.c",
154156
"src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p1c-minmax-fp32-scalar-fmagic.c",
@@ -574,10 +576,8 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [
574576
"src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c",
575577
"src/qs8-packw/gen/qs8-packw-x64c4-gemm-gio-scalar.c",
576578
"src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c",
577-
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x2-minmax-fp32-scalar-fmagic.c",
578579
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x2-minmax-fp32-scalar-lrintf.c",
579580
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-2x4-minmax-fp32-scalar-fmagic.c",
580-
"src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x4-minmax-fp32-scalar-fmagic.c",
581581
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x2-minmax-fp32-scalar-fmagic.c",
582582
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x2-minmax-fp32-scalar-lrintf.c",
583583
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4-minmax-fp32-scalar-fmagic.c",

scripts/generate-qs8-gemm.sh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1911,6 +1911,15 @@ tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=6 -D DATATYPE=QS8_QC4 -D AVX=
19111911
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=7 -D DATATYPE=QS8_QC4 -D AVX=2 -D VARIANT= -D GFNI=1 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x8c8-minmax-avxvnni.c &
19121912
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=8 -D DATATYPE=QS8_QC4 -D AVX=2 -D VARIANT= -D GFNI=1 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-8x8c8-minmax-avxvnni.c &
19131913

1914+
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=1 -D DATATYPE=QS8_QC2 -D AVX=2 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-1x8c8-minmax-avx2-madd.c &
1915+
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=2 -D DATATYPE=QS8_QC2 -D AVX=2 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-2x8c8-minmax-avx2-madd.c &
1916+
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=3 -D DATATYPE=QS8_QC2 -D AVX=2 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-3x8c8-minmax-avx2-madd.c &
1917+
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=4 -D DATATYPE=QS8_QC2 -D AVX=2 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-4x8c8-minmax-avx2-madd.c &
1918+
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=5 -D DATATYPE=QS8_QC2 -D AVX=2 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-5x8c8-minmax-avx2-madd.c &
1919+
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=6 -D DATATYPE=QS8_QC2 -D AVX=2 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-6x8c8-minmax-avx2-madd.c &
1920+
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=7 -D DATATYPE=QS8_QC2 -D AVX=2 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-7x8c8-minmax-avx2-madd.c &
1921+
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=8 -D DATATYPE=QS8_QC2 -D AVX=2 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc2w-gemm/gen/qs8-qc2w-gemm-8x8c8-minmax-avx2-madd.c &
1922+
19141923
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=1 -D DATATYPE=QS8_QC4 -D AVX=2 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x8c8-minmax-avx2-madd.c &
19151924
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=2 -D DATATYPE=QS8_QC4 -D AVX=2 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x8c8-minmax-avx2-madd.c &
19161925
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=3 -D DATATYPE=QS8_QC4 -D AVX=2 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x8c8-minmax-avx2-madd.c &

src/configs/gemm-config.c

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ static void init_bf16_f32_gemm_config(void) {
414414
#endif // XNN_ENABLE_AVX512BF16
415415
}
416416
assert(bf16_f32_gemm_config.mr <= XNN_MAX_MR);
417-
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
417+
#endif // XNN_ARCH_X86_64
418418
}
419419

420420
static void init_pf32_gemm_config(void) {
@@ -4368,14 +4368,40 @@ static void init_qs8_qc2w_gemm_config(void) {
43684368
qs8_qc2w_gemm_config.mr = 3;
43694369
qs8_qc2w_gemm_config.nr = 4;
43704370
}
4371+
#elif XNN_ARCH_X86 || XNN_ARCH_X86_64
4372+
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
4373+
assert(hardware_config != NULL);
4374+
(void) hardware_config; // May be unused.
4375+
4376+
#if XNN_ENABLE_AVX2
4377+
if (hardware_config->arch_flags & xnn_arch_x86_avx2) {
4378+
qs8_qc2w_gemm_config.arch = xnn_arch_x86_avx2;
4379+
qs8_qc2w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = XNN_INIT_HMP_DQGEMM_UKERNEL(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_1x8c8__avx2_madd);
4380+
qs8_qc2w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(5)] = XNN_INIT_HMP_DQGEMM_UKERNEL(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_5x8c8__avx2_madd);
4381+
qs8_qc2w_gemm_config.init.qs8_qc8w = xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params;
4382+
qs8_qc2w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_to_qu8_qc2w_gemm_goi_w;
4383+
qs8_qc2w_gemm_config.planes = 4;
4384+
qs8_qc2w_gemm_config.mr = 5;
4385+
qs8_qc2w_gemm_config.nr = 8;
4386+
qs8_qc2w_gemm_config.log2_kr = 3;
4387+
} else
4388+
#endif
4389+
{
4390+
qs8_qc2w_gemm_config.init.qs8_qc8w = xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params;
4391+
qs8_qc2w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_1x4__scalar_fmagic);
4392+
qs8_qc2w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_4x4__scalar_fmagic);
4393+
qs8_qc2w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc2w_gemm_goi_w;
4394+
qs8_qc2w_gemm_config.planes = 4;
4395+
qs8_qc2w_gemm_config.mr = 4;
4396+
qs8_qc2w_gemm_config.nr = 4;
4397+
}
43714398
#else
43724399
qs8_qc2w_gemm_config.init.qs8_qc8w = xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params;
4373-
qs8_qc2w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_1x4__scalar_fmagic);
4374-
qs8_qc2w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(3)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_3x4__scalar_fmagic);
4400+
qs8_qc2w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = XNN_INIT_HMP_GEMM_UKERNEL(xnn_qs8_qc2w_gemm_minmax_fp32_ukernel_1x2__scalar_fmagic);
43754401
qs8_qc2w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc2w_gemm_goi_w;
43764402
qs8_qc2w_gemm_config.planes = 4;
4377-
qs8_qc2w_gemm_config.mr = 3;
4378-
qs8_qc2w_gemm_config.nr = 4;
4403+
qs8_qc2w_gemm_config.mr = 1;
4404+
qs8_qc2w_gemm_config.nr = 2;
43794405
#endif
43804406
assert(qs8_qc2w_gemm_config.mr <= XNN_MAX_MR);
43814407
}

src/qd8-f32-qc2w-gemm/gen/qd8-f32-qc2w-gemm-1x8c8-minmax-avx2-madd.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ void xnn_qd8_f32_qc2w_gemm_minmax_ukernel_1x8c8__avx2_madd(
5656
// XNN_FORCE_REALIZATION(voutput_min);
5757
// XNN_FORCE_REALIZATION(voutput_max);
5858
const __m256i vmask = _mm256_set1_epi8(0x03);
59+
XNN_FORCE_REALIZATION(vmask);
5960
do {
6061
const __m256i vksum01234567 = _mm256_load_si256(w);
6162
__m256i vsum0x01234567 = _mm256_mullo_epi32(vksum01234567, vinput_zero_point0);

0 commit comments

Comments
 (0)