
Commit c92f03f

Support un-fused batchnorm1d/2d on XNNPACK via decomposition (#16533)
Summary: Add a new pass, DecomposeBatchNorm, which converts standalone (non-fused) batch norm operators into 1x1 depthwise convolutions. This prevents delegation graph breaks when batch norm operators can't be fused. Differential Revision: D90422630 cc @digantdesai @cbilgin
1 parent 8d29b22 commit c92f03f
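For context, this is the kind of model the pass targets: a batch norm with no preceding convolution to fold into, so FuseBatchNormPass leaves it in place. A minimal, hypothetical sketch (module name and shapes are illustrative, not from the commit):

```python
import torch


class StandaloneBN(torch.nn.Module):
    """BatchNorm2d whose input is an add, so it can't fuse into a conv."""

    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x, y):
        # The add between the inputs and the batch norm prevents conv-bn
        # fusion; previously this op would break the delegated partition.
        return self.bn(x + y)
```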

File tree

10 files changed: +913 −68 lines


backends/test/harness/tester.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -41,8 +41,12 @@ def __init__(
         example_inputs: Tuple[torch.Tensor],
         stage_classes: Dict[StageType, Callable] | None = None,
         dynamic_shapes: Optional[Tuple[Any]] = None,
+        training: bool = False,
     ):
-        module.eval()
+        if training:
+            module.train()
+        else:
+            module.eval()

         self.stage_classes = stage_classes or Tester.default_stage_classes()
         self.original_module = module
```
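With the new flag, a backend test can keep the module in train mode when constructing the tester. A hedged usage sketch (the module and inputs are illustrative; only the `training` flag is new in this commit):

```python
import torch

# Hypothetical harness usage: export the module without forcing eval mode.
tester = Tester(my_module, (torch.randn(1, 8, 4, 4),), training=True)
```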

backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -23,6 +23,7 @@
 from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
     ConvertToUpsampleBilinear2d,
 )
+from executorch.backends.xnnpack._passes.decompose_batch_norm import DecomposeBatchNorm
 from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
@@ -76,6 +77,7 @@ def __init__(
             ConvertToSDPAPass,
             ConstPropPass,
             FuseBatchNormPass,
+            DecomposeBatchNorm,
             FuseActivationPass,
             DecomposeConcatenate,
             RemoveGetItemPass,
```
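Note the registration order: DecomposeBatchNorm runs immediately after FuseBatchNormPass, so only batch norm operators that fusion could not eliminate reach the decomposition.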
backends/xnnpack/_passes/decompose_batch_norm.py

Lines changed: 296 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator

import torch
from executorch.backends.transforms.utils import create_constant_placeholder
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.utils import (
    check_or_raise,
    get_param_tensor,
    get_tensor_name,
    is_param_node,
)
from executorch.exir.backend.utils import WhyNoPartition
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.graph_signature import InputKind
from torch.fx.passes.infra.pass_base import PassResult

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class DecomposeBatchNorm(XNNPACKPass):
    """
    Decompose batch norm operators into 1x1 depthwise convolutions.
    """

    BATCH_NORM_OPS = {
        exir_ops.edge.aten.native_batch_norm.default,
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
    }

    @staticmethod
    def can_decompose_batch_norm(  # noqa: C901
        node: torch.fx.Node,
        exported_program: torch.export.ExportedProgram,
        why: WhyNoPartition | None = None,
    ) -> bool:
        """
        Determine whether the given batch norm node can be decomposed by this pass.
        """
        if (
            node.op != "call_function"
            or node.target not in DecomposeBatchNorm.BATCH_NORM_OPS
        ):
            return False

        input_meta = node.args[0].meta["val"]

        # Since we're converting to conv and XNNPACK doesn't support conv3d, we can't
        # handle BatchNorm3d. Validate the input dimension. We'll take NC, NCL, or NCHW.
        if input_meta.dim() not in (2, 3, 4):
            if why:
                why(
                    node,
                    f"Unsupported input rank {input_meta.dim()} for XNN batch norm operator.",
                )
            return False

        # The batch norm node returns a tuple of the output and other values we
        # don't care about. All users must be getitem nodes that fetch the output
        # (index 0). The partitioner should enforce this, but we'll check it here too.
        for user in node.users:
            if user.target != operator.getitem or user.args[1] != 0:
                if why:
                    why(node, "Batch norm users must only access the output tensor.")
                return False

        # The channel dimension and non-input args must be statically known.
        if not isinstance(input_meta.shape[1], int):
            if why:
                why(
                    node,
                    f"Channel dimension must be statically known, but was {input_meta.shape[1]}.",
                )
            return False

        if node.args[1] is not None and not is_param_node(
            exported_program, node.args[1]
        ):
            if why:
                why(node, "Batch norm affine weight must be static.")
            return False

        if node.args[2] is not None and not is_param_node(
            exported_program, node.args[2]
        ):
            if why:
                why(node, "Batch norm affine bias must be static.")
            return False

        if not is_param_node(exported_program, node.args[3]) or not is_param_node(
            exported_program, node.args[4]
        ):
            if why:
                why(node, "Batch norm running mean and variance must be static.")
            return False

        if isinstance(node.args[-1], torch.fx.Node):
            if why:
                why(node, "Batch norm epsilon must be static.")
            return False

        if (
            node.target == exir_ops.edge.aten.native_batch_norm.default
            and node.args[5] is not False
        ):
            if why:
                why(node, "Training batch norm is not supported.")
            return False

        return True
```
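The `why` callback lets a caller record the reason a node was rejected. A hypothetical sketch of how a partitioner-side caller might use it (the `WhyNoPartition` construction shown here is an assumption based on its use elsewhere in ExecuTorch, not part of this commit):

```python
# Hypothetical: log why a batch norm node can't be decomposed/partitioned.
why = WhyNoPartition(logger=logger)  # assumed constructor signature
if not DecomposeBatchNorm.can_decompose_batch_norm(node, exported_program, why):
    # Leave the node out of the XNNPACK partition; `why` has logged the reason.
    pass
```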
```python
    @staticmethod
    def compute_w_and_b(
        eps: float,
        running_mean: torch.Tensor,  # [C]
        running_var: torch.Tensor,  # [C]
        gamma: torch.Tensor,  # [C], learned weight
        beta: torch.Tensor,  # [C], learned bias
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the equivalent per-channel weight and bias to match the batch norm
        computation with frozen values.
        """
        # See https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html

        # Do the math in double precision and convert back to the original dtype at
        # the end. ATen kernels do this math in increased precision for float16. Note
        # that all of the parameter dtypes must match, as per the ATen behavior.

        # Also note that gamma and beta can be None if affine=False. This is
        # equivalent to gamma = 1 and beta = 0.
        gamma_f64 = gamma.double() if gamma is not None else torch.Tensor([1]).double()
        beta_f64 = beta.double() if beta is not None else torch.Tensor([0]).double()
        running_mean_f64 = running_mean.double()
        running_var_f64 = running_var.double()

        denom = torch.sqrt(running_var_f64 + torch.Tensor([eps]))
        new_weight = gamma_f64 / denom
        new_bias = -running_mean_f64 * gamma_f64 / denom + beta_f64

        return new_weight.to(running_mean.dtype), new_bias.to(running_mean.dtype)
```
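As a sanity check on the algebra, the folded weight and bias should reproduce an inference-mode batch norm up to float rounding. A small self-contained sketch, not part of the commit:

```python
import torch
import torch.nn.functional as F

C = 8
x = torch.randn(2, C, 4, 4)
mean, var = torch.randn(C), torch.rand(C) + 0.1
gamma, beta = torch.randn(C), torch.randn(C)
eps = 1e-5

# Reference: inference-mode batch norm with frozen statistics.
expected = F.batch_norm(x, mean, var, weight=gamma, bias=beta, eps=eps)

# Folded per-channel scale/shift, as computed by compute_w_and_b.
w = gamma / torch.sqrt(var + eps)
b = beta - mean * w
actual = x * w.view(1, C, 1, 1) + b.view(1, C, 1, 1)

torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)
```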
```python
    def replace_bn_node_with_conv(
        self,
        bn_node: torch.fx.Node,
        graph_module: torch.fx.GraphModule,
    ) -> torch.fx.Node:
        """
        Replace a batch norm node with NCL or NCHW input with an equivalent
        depthwise convolution.
        """
        # Compute the equivalent per-channel weights and biases.
        # Note that the batch norm node args are
        # (input, gamma, beta, running_mean, running_var, [training], momentum, eps).
        # The training arg is not present in the _no_training variant.
        weight, bias = DecomposeBatchNorm.compute_w_and_b(
            eps=bn_node.args[-1],
            running_mean=get_param_tensor(self.exported_program, bn_node.args[3]),
            running_var=get_param_tensor(self.exported_program, bn_node.args[4]),
            gamma=get_param_tensor(self.exported_program, bn_node.args[1]),
            beta=get_param_tensor(self.exported_program, bn_node.args[2]),
        )

        # Conv weights have shape [out_c, in_c/g, spatial...].
        # For depthwise conv, in_c = g. The kernel is also 1x1 (or just 1, for 1d).
        #
        # Batch norm weights have shape [in_c], so we just need to unsqueeze
        # [in_c] to [in_c, 1, 1, [1]].
        input_meta = bn_node.args[0].meta["val"]
        channel_count = input_meta.shape[1]
        spatial_dims = max(
            input_meta.dim() - 2, 1
        )  # Min of 1, since 1d input can be NC or NCL.
        new_weight_shape = [weight.shape[0], 1] + [1] * spatial_dims
        weight = weight.reshape(new_weight_shape)

        # Generate names for the new weight and bias parameters based on the original
        # batch norm gamma parameter name.
        gamma_name = get_tensor_name(self.exported_program, bn_node.args[1])
        weight_name = (gamma_name + "_decomposed_bn_weight").replace(".", "_")
        bias_name = (gamma_name + "_decomposed_bn_bias").replace(".", "_")

        # Insert the new weight and bias as constant placeholders in the graph.
        with graph_module.graph.inserting_before(bn_node.args[1]):
            weight_node = create_constant_placeholder(
                exp_program=self.exported_program,
                graph=graph_module.graph,
                kind=InputKind.PARAMETER,
                name=weight_name,
                data=weight,
            )
            bias_node = create_constant_placeholder(
                exp_program=self.exported_program,
                graph=graph_module.graph,
                kind=InputKind.PARAMETER,
                name=bias_name,
                data=bias,
            )

        with graph_module.graph.inserting_after(bn_node):
            conv_node = graph_module.graph.call_function(
                exir_ops.edge.aten.convolution.default,
                args=(
                    bn_node.args[0],  # Input
                    weight_node,  # Weight
                    bias_node,  # Bias
                    [1] * spatial_dims,  # Stride
                    [0] * spatial_dims,  # Padding
                    [1] * spatial_dims,  # Dilation
                    False,  # Transposed
                    [0] * spatial_dims,  # Output padding
                    channel_count,  # Groups (depthwise, so groups = in_channels)
                ),
            )

        # Find the getitem user nodes and replace them with the conv node.
        # The decomposition checks above enforce that the node is only used by
        # getitem[0].
        users = list(bn_node.users)
        for user in users:
            user.replace_all_uses_with(conv_node)
            graph_module.graph.erase_node(user)

        graph_module.graph.erase_node(bn_node)
        return conv_node
```
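The conv node built above is just a per-channel scale and shift in convolution form. A quick standalone check (not from the commit) that a 1x1 depthwise conv with groups = C matches the per-channel affine:

```python
import torch
import torch.nn.functional as F

C = 8
x = torch.randn(2, C, 5, 5)
scale, shift = torch.randn(C), torch.randn(C)  # as from compute_w_and_b

# Depthwise 1x1 conv: weight shape [C, 1, 1, 1], groups = C.
y_conv = F.conv2d(x, scale.reshape(C, 1, 1, 1), bias=shift, groups=C)
y_affine = x * scale.view(1, C, 1, 1) + shift.view(1, C, 1, 1)

torch.testing.assert_close(y_conv, y_affine)
```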
```python
    def decompose_node(
        self, node: torch.fx.Node, graph_module: torch.fx.GraphModule
    ) -> None:
        input_meta = node.args[0].meta["val"]

        # These conditions should be checked by the partitioner and the calling
        # code, so we should never fail these checks.
        check_or_raise(
            node.op == "call_function"
            and node.target in DecomposeBatchNorm.BATCH_NORM_OPS,
            f"Invalid batch norm operator {node.op}.",
        )

        check_or_raise(
            input_meta.dim() in (2, 3, 4),
            f"Unsupported input rank {input_meta.dim()} for XNN batch norm operator.",
        )

        channel_count = input_meta.shape[1]
        check_or_raise(
            isinstance(channel_count, int),
            f"Channel dimension must be statically known, but was {channel_count}.",
        )

        # Create the convolution node.
        conv_node = self.replace_bn_node_with_conv(node, graph_module)

        # BatchNorm1d input can be NC or NCL. Conv1d requires the L dim, so
        # unsqueeze NC -> NCL.
        if input_meta.dim() == 2:
            with graph_module.graph.inserting_before(conv_node):
                # Insert an unsqueeze node before the conv.
                unsqueeze_node = graph_module.graph.call_function(
                    exir_ops.edge.aten.unsqueeze_copy.default,
                    args=(conv_node.args[0], 2),
                )
                conv_node.args = (unsqueeze_node, *conv_node.args[1:])

            with graph_module.graph.inserting_after(conv_node):
                # Insert a squeeze node after the conv.
                squeeze_node = graph_module.graph.call_function(
                    exir_ops.edge.aten.squeeze_copy.dim, args=(conv_node, 2)
                )
                conv_node.replace_all_uses_with(squeeze_node)
                # The conv argument gets overwritten by replace_all_uses_with,
                # so restore it. Maybe there's a better solution?
                squeeze_node.args = (conv_node, *squeeze_node.args[1:])
```
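The rank-2 path mirrors a simple identity: conv1d needs an L dimension, so the input is unsqueezed at dim 2 and the result squeezed back. A standalone illustration (not from the commit):

```python
import torch
import torch.nn.functional as F

C = 4
x = torch.randn(3, C)  # NC input, as BatchNorm1d accepts
w, b = torch.randn(C), torch.randn(C)

# NC -> NCL, depthwise 1x1 conv, then NCL -> NC.
y = F.conv1d(x.unsqueeze(2), w.reshape(C, 1, 1), bias=b, groups=C).squeeze(2)
y_ref = x * w + b

torch.testing.assert_close(y, y_ref)
```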
```python
    # override
    def call(self, graph_module: torch.fx.GraphModule):
        # Find and transform all eligible batch norm nodes.
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target in self.BATCH_NORM_OPS:
                if self.can_decompose_batch_norm(node, self.exported_program):
                    self.decompose_node(node, graph_module)

        graph_module.recompile()

        # Propagate metadata and retrace the module.
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
```
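A hypothetical sketch of invoking the pass directly; in practice it runs inside the XNNPACK pass pipeline registered in `_passes/__init__.py`, and the `XNNPACKPass` constructor signature here is an assumption:

```python
# Hypothetical standalone invocation (constructor signature assumed).
bn_decompose = DecomposeBatchNorm(exported_program)
result = bn_decompose(exported_program.graph_module)
graph_module = result.graph_module
```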
