Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 15 additions & 24 deletions ynnpack/kernels/binary/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,36 +299,27 @@ const binary_kernel* get_binary_reference_kernel(ynn_binary_operator op,
return nullptr;
}

const binary_kernel* get_binary_kernel(ynn_binary_operator op, ynn_type type,
const binary_kernel* get_binary_kernel(ynn_binary_operator op, ynn_type type_a,
ynn_type type_b, ynn_type type_x,
bool is_quantized,
uint64_t supported_arch_flags) {
// TODO(vksnk): select a better kernel based on the passed size.
#define YNN_ELEMENTWISE_KERNEL(arch, name, op_type, init_params_fn, A, B, X) \
if (type == type_of<A>() && type == type_of<B>() && type == type_of<X>() && \
op == ynn_binary_##op_type && \
is_arch_supported(arch, supported_arch_flags)) { \
static binary_kernel kernel##name = {&name, nullptr}; \
YNN_LOG_INFO() << "Using binary kernel " << #name; \
return &kernel##name; \
if (!is_quantized) {
#define YNN_ELEMENTWISE_KERNEL(arch, name, op_type, init_params_fn, A, B, X) \
if (is_arch_supported(arch, supported_arch_flags) && \
op == ynn_binary_##op_type) { \
if (type_of<A>() == type_a && type_of<B>() == type_b && \
type_of<X>() == type_x) { \
static binary_kernel kernel##name = {&name, nullptr}; \
YNN_LOG_INFO() << "Using binary kernel " << #name; \
return &kernel##name; \
} \
}

#include "ynnpack/kernels/binary/kernels.inc"
#undef YNN_ELEMENTWISE_KERNEL

return get_binary_reference_kernel(op, type, false);
}

binary_kernel_fn get_binary_multiply_kernel(ynn_type type_a, ynn_type type_b,
ynn_type type_x) {
#define YNN_ELEMENTWISE_KERNEL(arch, name, op, init_params_fn, A, B, X) \
if (ynn_binary_##op == ynn_binary_multiply && is_arch_supported(arch)) { \
if (type_of<A>() == type_a && type_of<B>() == type_b && \
type_of<X>() == type_x) { \
return name; \
} \
}
#include "ynnpack/kernels/binary/kernels.inc"
#undef YNN_ELEMENTWISE_KERNEL
if (type_a == type_x && type_b == type_x) {
return get_binary_reference_kernel(op, type_x, is_quantized);
}
return nullptr;
}

Expand Down
6 changes: 2 additions & 4 deletions ynnpack/kernels/binary/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,10 @@ const binary_kernel* get_binary_reference_kernel(ynn_binary_operator op, T) {
}

const binary_kernel* get_binary_kernel(
ynn_binary_operator op, ynn_type input_type, bool quantized,
ynn_binary_operator op, ynn_type type_a, ynn_type type_b, ynn_type type_x,
bool is_quantized = false,
uint64_t supported_arch_flags = get_supported_arch_flags());

binary_kernel_fn get_binary_multiply_kernel(ynn_type type_a, ynn_type type_b,
ynn_type type_x);

} // namespace ynn

#endif // XNNPACK_YNNPACK_KERNELS_BINARY_H_
3 changes: 2 additions & 1 deletion ynnpack/subgraph/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,8 @@ ynn_status ynn_define_binary(ynn_subgraph_t subgraph, ynn_binary_operator op,
b.zero_point_id != YNN_INVALID_VALUE_ID ||
x.scale_id != YNN_INVALID_VALUE_ID ||
x.zero_point_id != YNN_INVALID_VALUE_ID;
const binary_kernel* kernel = get_binary_kernel(op, x.type, is_quantized);
const binary_kernel* kernel =
get_binary_kernel(op, a.type, b.type, x.type, is_quantized);
if (!kernel) {
YNN_LOG_ERROR() << "unsupported binary operator " << op
<< " for input types " << a.type << ", " << b.type
Expand Down
7 changes: 4 additions & 3 deletions ynnpack/subgraph/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,15 @@ bool rewrite_convert_to_multiply(ynn_subgraph& subgraph, ynn_node& node,
}
}
}
ynn::binary_kernel_fn kernel = ynn::get_binary_multiply_kernel(
input.type, subgraph.value(input.scale_id).type, output.type);
const ynn::binary_kernel* kernel =
ynn::get_binary_kernel(ynn_binary_multiply, input.type,
subgraph.value(input.scale_id).type, output.type);
if (kernel != nullptr) {
// This is a binary integer*float multiply, and we have a kernel that
// matches the types we have.
YNN_LOG_DEBUG() << "Rewriting convert to binary multiply";
ynn::define_binary(subgraph, node, node.inputs[0], input.scale_id,
node.outputs[0], ynn_binary_multiply, kernel);
node.outputs[0], ynn_binary_multiply, kernel->op);
}
return false;
}
Expand Down
Loading