Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

arm64: Implement 8bpc cdef_dist_kernel #3292

Merged
merged 3 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ fn build_asm_files() {
let asm_files = &[
"src/arm/64/cdef.S",
"src/arm/64/cdef16.S",
"src/arm/64/cdef_dist.S",
"src/arm/64/mc.S",
"src/arm/64/mc16.S",
"src/arm/64/itx.S",
Expand Down
137 changes: 137 additions & 0 deletions src/arm/64/cdef_dist.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/* Copyright (c) 2023, The rav1e contributors. All rights reserved
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/

#include "src/arm/asm.S"
#include "util.S"

// Accumulate the running statistics for eight 8-bit pixels of src (v1) and
// eight 8-bit pixels of dst (v2). The accumulators hold per-lane partial
// sums; CDEF_DIST_REFINE reduces them to scalars afterwards.
//
// Register usage:
// v0: tmp register
// v1: src input
// v2: dst input
// v3 = sum(src_{i,j})
// v4 = sum(src_{i,j}^2)
// v5 = sum(dst_{i,j})
// v6 = sum(dst_{i,j}^2)
// v7 = sum(src_{i,j} * dst_{i,j})
// v16: zero register
//
// Every uabal/uabal2 below takes the zero register as its second source, so
// the absolute difference |x - 0| equals x and the instruction behaves as a
// widening add into the accumulator.
.macro CDEF_DIST_W8
uabal v3.8h, v1.8b, v16.8b // sum pixel values
umull v0.8h, v1.8b, v1.8b // square
uabal v4.4s, v0.4h, v16.4h // accumulate
uabal2 v4.4s, v0.8h, v16.8h
uabal v5.8h, v2.8b, v16.8b // same as above, but for dst
umull v0.8h, v2.8b, v2.8b
uabal v6.4s, v0.4h, v16.4h
uabal2 v6.4s, v0.8h, v16.8h
umull v0.8h, v1.8b, v2.8b // src_{i,j} * dst_{i,j}
uabal v7.4s, v0.4h, v16.4h
uabal2 v7.4s, v0.8h, v16.8h
.endm

// Reduce the accumulators built by CDEF_DIST_W8 to three scalars and store
// them through x4 as three consecutive 32-bit words: svar, dvar, sse.
//
// The shift argument selects the block size: the squared sums are divided
// (with rounding) by N = 2^(6 - shift), which equals the pixel count of the
// kernels that pass shift = 0 (8x8), 1 (4x8 and 8x4) and 2 (4x4).
.macro CDEF_DIST_REFINE shift=0
addv h3, v3.8h
umull v3.4s, v3.4h, v3.4h
urshr v3.4s, v3.4s, #(6-\shift) // s3: sum(src_{i,j})^2 / N
addv s4, v4.4s // s4: sum(src_{i,j}^2)
addv h5, v5.8h
umull v5.4s, v5.4h, v5.4h
urshr v5.4s, v5.4s, #(6-\shift) // s5: sum(dst_{i,j})^2 / N
addv s6, v6.4s // s6: sum(dst_{i,j}^2)
addv s7, v7.4s
// sse = sum(src^2) + sum(dst^2) - 2 * sum(src * dst)
add v0.4s, v4.4s, v6.4s
sub v0.4s, v0.4s, v7.4s
sub v0.4s, v0.4s, v7.4s // s0: sse
// Saturating subtract: the variances clamp at zero instead of wrapping.
uqsub v4.4s, v4.4s, v3.4s // s4: svar
uqsub v6.4s, v6.4s, v5.4s // s6: dvar
.if \shift != 0
// Scale both variances by 2^shift.
// NOTE(review): presumably this normalises the smaller blocks to the
// 64-pixel scale of the 8x8 kernel; confirm against the Rust reference.
shl v4.4s, v4.4s, #\shift
shl v6.4s, v6.4s, #\shift
.endif
// Output layout: [x4] = svar, [x4 + 4] = dvar, [x4 + 8] = sse.
str s4, [x4]
str s6, [x4, #4]
str s0, [x4, #8]
.endm

// Load one 8-pixel (8-byte) row from src (x0) into v1 and from dst (x2) into
// v2, then advance both pointers by their strides (x1 and x3).
//
// The loads are 64-bit (d registers) on purpose: CDEF_DIST_W8 consumes only
// the .8b lanes of v1 and v2, and a 128-bit q load would read 8 bytes past
// the end of each row, which can cross the end of the plane buffer.
.macro LOAD_ROW
ldr d1, [x0]
ldr d2, [x2]
add x0, x0, x1
add x2, x2, x3
.endm

// Load two consecutive 4-pixel (4-byte) rows from src (x0) and from dst (x2)
// and pack each pair into the low 8 bytes of v1 and v2 respectively, so that
// CDEF_DIST_W8 can process two rows at once. Both pointers advance by two
// strides. Clobbers v0 and v17.
.macro LOAD_ROWS
ldr s1, [x0]
ldr s2, [x2]
ldr s0, [x0, x1]
ldr s17, [x2, x3]
add x0, x0, x1, lsl 1
add x2, x2, x3, lsl 1
// Place the second row (v0 / v17) in the 32-bit lane after the first row.
zip1 v1.2s, v1.2s, v0.2s
zip1 v2.2s, v2.2s, v17.2s
.endm

// Kernel prologue: clear the accumulators v3-v7 and the zero register v16,
// then set the loop counter w5.
//
// For width 4 the counter is height / 2 because the 4-wide kernels fetch two
// rows per iteration (LOAD_ROWS); otherwise it is the row count itself.
.macro CDEF_DIST_INIT width, height
.irp i, v3.8h, v4.8h, v5.8h, v6.8h, v7.8h, v16.8h
movi \i, #0
.endr
.if \width == 4
mov w5, #(\height / 2)
.else
mov w5, #\height
.endif
.endm

// Arguments, shared by every cdef_dist_kernel_*_neon function in this file:
// x0: src: *const u8,
// x1: src_stride: isize,
// x2: dst: *const u8,
// x3: dst_stride: isize,
// x4: ret_ptr: *mut u32,
//
// ret_ptr receives three 32-bit words: svar, dvar, sse (see
// CDEF_DIST_REFINE). w5 is used as the loop counter.

// 4x4 block, 8-bit pixels: two iterations of two rows each, then reduce with
// shift = 2 (16 pixels).
function cdef_dist_kernel_4x4_neon, export=1
CDEF_DIST_INIT 4, 4
L(cdk_4x4):
LOAD_ROWS
CDEF_DIST_W8
subs w5, w5, #1
bne L(cdk_4x4)
CDEF_DIST_REFINE 2
ret
endfunc

// 4x8 block (width 4, height 8), 8-bit pixels: four iterations of two rows
// each, then reduce with shift = 1 (32 pixels).
// Arguments: x0 = src, x1 = src_stride, x2 = dst, x3 = dst_stride,
// x4 = ret_ptr (receives svar, dvar, sse).
function cdef_dist_kernel_4x8_neon, export=1
CDEF_DIST_INIT 4, 8
L(cdk_4x8):
LOAD_ROWS
CDEF_DIST_W8
subs w5, w5, #1
bne L(cdk_4x8)
CDEF_DIST_REFINE 1
ret
endfunc

// 8x4 block (width 8, height 4), 8-bit pixels: one row per iteration for
// four iterations, then reduce with shift = 1 (32 pixels).
// Arguments: x0 = src, x1 = src_stride, x2 = dst, x3 = dst_stride,
// x4 = ret_ptr (receives svar, dvar, sse).
function cdef_dist_kernel_8x4_neon, export=1
CDEF_DIST_INIT 8, 4
L(cdk_8x4):
LOAD_ROW
CDEF_DIST_W8
subs w5, w5, #1
bne L(cdk_8x4)
CDEF_DIST_REFINE 1
ret
endfunc

// 8x8 block, 8-bit pixels: one row per iteration for eight iterations, then
// reduce with the default shift = 0 (64 pixels).
// Arguments: x0 = src, x1 = src_stride, x2 = dst, x3 = dst_stride,
// x4 = ret_ptr (receives svar, dvar, sse).
function cdef_dist_kernel_8x8_neon, export=1
CDEF_DIST_INIT 8, 8
L(cdk_8x8):
LOAD_ROW
CDEF_DIST_W8
subs w5, w5, #1
bne L(cdk_8x8)
CDEF_DIST_REFINE
ret
endfunc
129 changes: 129 additions & 0 deletions src/asm/aarch64/dist/cdef_dist.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright (c) 2023, The rav1e contributors. All rights reserved
//
// This source code is subject to the terms of the BSD 2 Clause License and
// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
// was not distributed with this source code in the LICENSE file, you can
// obtain it at www.aomedia.org/license/software. If the Alliance for Open
// Media Patent License 1.0 was not distributed with this source code in the
// PATENTS file, you can obtain it at www.aomedia.org/license/patent.

use crate::activity::apply_ssim_boost;
use crate::cpu_features::CpuFeatureLevel;
use crate::dist::*;
use crate::tiling::PlaneRegion;
use crate::util::Pixel;
use crate::util::PixelType;

/// Signature shared by the assembly CDEF distortion kernels.
///
/// A kernel reads one block from `src` and one from `dst` (each with its own
/// stride) and writes its results through `ret_ptr`; the caller in this file
/// passes a buffer of three `u32` values and reads them back as
/// `(svar, dvar, sse)`.
type CdefDistKernelFn = unsafe extern fn(
  src: *const u8,
  src_stride: isize,
  dst: *const u8,
  dst_stride: isize,
  ret_ptr: *mut u32,
);

// 8-bit NEON kernels, one per supported block size (width x height). Each
// declaration matches `CdefDistKernelFn`.
// NOTE(review): these are presumably the `cdef_dist_kernel_*_neon` functions
// from `src/arm/64/cdef_dist.S`, exported with a `rav1e_` prefix — confirm
// against the `function` macro in `src/arm/asm.S`.
extern {
  /// Kernel for 4x4 blocks.
  fn rav1e_cdef_dist_kernel_4x4_neon(
    src: *const u8, src_stride: isize, dst: *const u8, dst_stride: isize,
    ret_ptr: *mut u32,
  );

  /// Kernel for 4x8 blocks.
  fn rav1e_cdef_dist_kernel_4x8_neon(
    src: *const u8, src_stride: isize, dst: *const u8, dst_stride: isize,
    ret_ptr: *mut u32,
  );

  /// Kernel for 8x4 blocks.
  fn rav1e_cdef_dist_kernel_8x4_neon(
    src: *const u8, src_stride: isize, dst: *const u8, dst_stride: isize,
    ret_ptr: *mut u32,
  );

  /// Kernel for 8x8 blocks.
  fn rav1e_cdef_dist_kernel_8x8_neon(
    src: *const u8, src_stride: isize, dst: *const u8, dst_stride: isize,
    ret_ptr: *mut u32,
  );
}

/// Computes the CDEF distortion of a `w`x`h` block between `src` and `dst`.
///
/// For 8-bit pixels an assembly kernel is used when one is registered for the
/// given CPU level and block size; it yields `(svar, dvar, sse)`, which are
/// combined with `apply_ssim_boost`. In every other case (16-bit pixels, or
/// no kernel for this size) the Rust implementation is called instead.
///
/// # Panics
///
/// - If in `check_asm` mode, panics on mismatch between native and ASM results.
/// - In debug builds, if either plane is subsampled or if `w` or `h`
///   exceeds 8.
// `dist` needs a name because the `check_asm` assertion sits between its
// computation and the return.
#[allow(clippy::let_and_return)]
pub fn cdef_dist_kernel<T: Pixel>(
  src: &PlaneRegion<'_, T>, dst: &PlaneRegion<'_, T>, w: usize, h: usize,
  bit_depth: usize, cpu: CpuFeatureLevel,
) -> u32 {
  // Both planes must be non-subsampled.
  debug_assert!(src.plane_cfg.xdec == 0);
  debug_assert!(src.plane_cfg.ydec == 0);
  debug_assert!(dst.plane_cfg.xdec == 0);
  debug_assert!(dst.plane_cfg.ydec == 0);

  // Limit kernel to 8x8
  debug_assert!(w <= 8);
  debug_assert!(h <= 8);

  let rust_fallback =
    || rust::cdef_dist_kernel(dst, src, w, h, bit_depth, cpu);

  #[cfg(feature = "check_asm")]
  let ref_dist = rust_fallback();

  // Assembly kernels exist for 8-bit pixels only; the table is not consulted
  // for 16-bit pixels.
  let kernel = match T::type_enum() {
    PixelType::U8 => {
      CDEF_DIST_KERNEL_FNS[cpu.as_index()][kernel_fn_index(w, h)]
    }
    PixelType::U16 => None,
  };
  let func = match kernel {
    Some(func) => func,
    None => return rust_fallback(),
  };

  // Filled by the kernel as [svar, dvar, sse].
  let mut stats = [0u32; 3];
  // SAFETY: `func` is the assembly kernel registered for exactly this block
  // size, and `stats` provides the three `u32` slots it writes.
  // NOTE(review): relies on both regions holding at least `w`x`h` readable
  // pixels at their strides — confirm callers guarantee this.
  unsafe {
    func(
      src.data_ptr() as *const _,
      T::to_asm_stride(src.plane_cfg.stride),
      dst.data_ptr() as *const _,
      T::to_asm_stride(dst.plane_cfg.stride),
      stats.as_mut_ptr(),
    );
  }
  let [svar, dvar, sse] = stats;

  let dist = apply_ssim_boost(sse, svar, dvar, bit_depth);
  #[cfg(feature = "check_asm")]
  assert_eq!(
    dist, ref_dist,
    "CDEF Distortion {}x{}: Assembly doesn't match reference code.",
    w, h
  );

  dist
}

/// Number of slots in a kernel table: one per (width, height) pair with both
/// dimensions in `1..=8`. Most slots stay empty.
const CDEF_DIST_KERNEL_FNS_LENGTH: usize = 8 * 8;

/// Maps a block size to its slot in a kernel table.
///
/// The width selects a group of eight slots and the height the slot within
/// that group, so every `w` and `h` in `1..=8` gets a distinct index below 64.
const fn kernel_fn_index(w: usize, h: usize) -> usize {
  let group_base = (w - 1) << 3;
  let slot_in_group = h - 1;
  group_base | slot_in_group
}

/// Kernel table for NEON: only the 4x4, 4x8, 8x4 and 8x8 slots are filled,
/// every other block size is `None`.
static CDEF_DIST_KERNEL_FNS_NEON: [Option<CdefDistKernelFn>;
  CDEF_DIST_KERNEL_FNS_LENGTH] = {
  // Start from an all-empty table and fill in the sizes that have a kernel.
  let mut table: [Option<CdefDistKernelFn>; CDEF_DIST_KERNEL_FNS_LENGTH] =
    [None; CDEF_DIST_KERNEL_FNS_LENGTH];

  table[kernel_fn_index(4, 4)] = Some(rav1e_cdef_dist_kernel_4x4_neon);
  table[kernel_fn_index(4, 8)] = Some(rav1e_cdef_dist_kernel_4x8_neon);
  table[kernel_fn_index(8, 4)] = Some(rav1e_cdef_dist_kernel_8x4_neon);
  table[kernel_fn_index(8, 8)] = Some(rav1e_cdef_dist_kernel_8x8_neon);

  table
};

// Per-CPU-level dispatch table, generated by the project's
// `cpu_function_lookup_table!` macro. `cdef_dist_kernel` reads it as
// `CDEF_DIST_KERNEL_FNS[cpu.as_index()][kernel_fn_index(w, h)]`.
// NOTE(review): presumably the `NEON` entry resolves to
// `CDEF_DIST_KERNEL_FNS_NEON` and CPU levels without an entry get `default`
// (all `None`) — confirm against the macro definition.
cpu_function_lookup_table!(
CDEF_DIST_KERNEL_FNS:
[[Option<CdefDistKernelFn>; CDEF_DIST_KERNEL_FNS_LENGTH]],
default: [None; CDEF_DIST_KERNEL_FNS_LENGTH],
[NEON]
);
3 changes: 3 additions & 0 deletions src/asm/aarch64/dist.rs → src/asm/aarch64/dist/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
// Media Patent License 1.0 was not distributed with this source code in the
// PATENTS file, you can obtain it at www.aomedia.org/license/patent.

pub use self::cdef_dist::*;
use crate::cpu_features::CpuFeatureLevel;
use crate::dist::*;
use crate::partition::BlockSize;
use crate::tiling::*;
use crate::util::*;

mod cdef_dist;

type SadFn = unsafe extern fn(
src: *const u8,
src_stride: isize,
Expand Down
Loading