[torchlib] Simplify aten.unfold #1998

Open · justinchuby opened this issue Jan 6, 2025 · 1 comment

Labels: contribution welcome (We welcome code contributions for this) · enhancement (New feature or request) · module: torchlib (Related to the torch/aten function lib in development)

justinchuby (Collaborator) commented:

The current translation of aten.unfold is overly complicated: it lowers to a Loop with SequenceInsert/ConcatFromSequence where a simpler, loop-free graph would do. Repro:

import torch

class STFTModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._window = torch.hann_window(window_length=320)

    def forward(self, signals: torch.Tensor) -> torch.Tensor:
        # Defining window within forward causes it to break
        window = torch.hann_window(window_length=320)
        x = signals.stft(
            n_fft=512,
            hop_length=160,
            win_length=320,
            return_complex=True,
            window=window, # using self._window avoids the issue
            pad_mode="constant",
        )

        return x


m = STFTModel()

# Shape [B, T] audio signals
input_signals = torch.randn([2, 16000]).cpu()

args = (input_signals,)
ep = torch.onnx.export( # note that torch.export.export works
    m,
    args,
    dynamo=True
)
ep.optimize()
print(ep)

yields:

ONNXProgram(
    model=
        <
            ir_version=10,
            opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18, 'pkg.onnxscript.torch_lib': 1},
            producer_name='pytorch',
            producer_version='2.6.0a0+git5ef0de7',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"signals"<FLOAT,[2,16000]>
            ),
            outputs=(
                %"transpose"<FLOAT,[2,257,101,2]>
            ),
        ) {
             0 |  # node_Constant_26
                  %"hann_window"<FLOAT,[320]> ⬅️ ::Constant() {value=Tensor<FLOAT,[320]>(array([0.00000000e+00, 9.63797720e-05, 3.85481922e-04, 8.67194962e-04,
                         1.54133327e-03, 2.40763673e-03, 3.46577144e-03, 4.71532997e-03,
                         6.15583034e-03, 7.78671540e-03, 9.60735977e-03, 1.16170608e-02,
                         1.38150407e-02, 1.62004549e-02, 1.87723842e-02, 2.15298310e-02,
                         2.44717449e-02, 2.75969803e-02, 3.09043285e-02, 3.43925320e-02,
                         3.80602330e-02, 4.19060290e-02, 4.59284224e-02, 5.01258597e-02,
                         5.44967428e-02, 5.90393767e-02, 6.37520030e-02, 6.86328188e-02,
                         7.36799315e-02, 7.88913816e-02, 8.42651874e-02, 8.97992775e-02,
                         9.54915062e-02, 1.01339675e-01, 1.07341543e-01, 1.13494776e-01,
                         1.19797014e-01, 1.26245826e-01, 1.32838756e-01, 1.39573202e-01,
                         1.46446630e-01, 1.53456300e-01, 1.60599634e-01, 1.67873785e-01,
                         1.75276011e-01, 1.82803348e-01, 1.90453023e-01, 1.98222041e-01,
                         2.06107393e-01, 2.14106023e-01, 2.22214893e-01, 2.30430856e-01,
                         2.38750741e-01, 2.47171283e-01, 2.55689442e-01, 2.64301658e-01,
                         2.73004800e-01, 2.81795382e-01, 2.90670127e-01, 2.99625605e-01,
                         3.08658302e-01, 3.17764759e-01, 3.26941460e-01, 3.36184919e-01,
                         3.45491499e-01, 3.54857683e-01, 3.64279807e-01, 3.73754203e-01,
                         3.83277327e-01, 3.92845452e-01, 4.02454913e-01, 4.12101895e-01,
                         4.21782762e-01, 4.31493878e-01, 4.41231340e-01, 4.50991452e-01,
                         4.60770488e-01, 4.70564574e-01, 4.80370134e-01, 4.90183175e-01,
                         4.99999970e-01, 5.09816945e-01, 5.19629836e-01, 5.29435456e-01,
                         5.39229631e-01, 5.49008548e-01, 5.58768749e-01, 5.68506241e-01,
                         5.78217328e-01, 5.87898135e-01, 5.97545147e-01, 6.07154608e-01,
                         6.16722703e-01, 6.26245797e-01, 6.35720253e-01, 6.45142436e-01,
                         6.54508531e-01, 6.63815141e-01, 6.73058629e-01, 6.82235301e-01,
                         6.91341758e-01, 7.00374484e-01, 7.09329903e-01, 7.18204677e-01,
                         7.26995349e-01, 7.35698462e-01, 7.44310617e-01, 7.52828777e-01,
                         7.61249363e-01, 7.69569159e-01, 7.77785063e-01, 7.85893977e-01,
                         7.93892741e-01, 8.01777959e-01, 8.09546947e-01, 8.17196667e-01,
                         8.24724019e-01, 8.32126200e-01, 8.39400351e-01, 8.46543729e-01,
                         8.53553355e-01, 8.60426784e-01, 8.67161214e-01, 8.73754144e-01,
                         8.80203009e-01, 8.86505187e-01, 8.92658472e-01, 8.98660302e-01,
                         9.04508531e-01, 9.10200715e-01, 9.15734828e-01, 9.21108603e-01,
                         9.26320136e-01, 9.31367278e-01, 9.36248004e-01, 9.40960646e-01,
                         9.45503235e-01, 9.49874103e-01, 9.54071581e-01, 9.58094001e-01,
                         9.61939812e-01, 9.65607405e-01, 9.69095647e-01, 9.72402990e-01,
                         9.75528300e-01, 9.78470147e-01, 9.81227636e-01, 9.83799577e-01,
                         9.86184955e-01, 9.88382876e-01, 9.90392625e-01, 9.92213309e-01,
                         9.93844092e-01, 9.95284617e-01, 9.96534228e-01, 9.97592330e-01,
                         9.98458624e-01, 9.99132812e-01, 9.99614537e-01, 9.99903560e-01,
                         1.00000000e+00, 9.99903560e-01, 9.99614537e-01, 9.99132812e-01,
                         9.98458624e-01, 9.97592330e-01, 9.96534228e-01, 9.95284617e-01,
                         9.93844092e-01, 9.92213309e-01, 9.90392625e-01, 9.88382876e-01,
                         9.86184955e-01, 9.83799577e-01, 9.81227636e-01, 9.78470147e-01,
                         9.75528181e-01, 9.72402990e-01, 9.69095647e-01, 9.65607405e-01,
                         9.61939692e-01, 9.58093882e-01, 9.54071581e-01, 9.49874103e-01,
                         9.45503235e-01, 9.40960646e-01, 9.36248004e-01, 9.31367159e-01,
                         9.26320016e-01, 9.21108484e-01, 9.15734708e-01, 9.10200715e-01,
                         9.04508412e-01, 8.98660302e-01, 8.92658472e-01, 8.86505187e-01,
                         8.80202889e-01, 8.73754025e-01, 8.67161214e-01, 8.60426903e-01,
                         8.53553355e-01, 8.46543610e-01, 8.39400351e-01, 8.32126200e-01,
                         8.24724019e-01, 8.17196667e-01, 8.09546947e-01, 8.01777959e-01,
                         7.93892503e-01, 7.85893977e-01, 7.77785063e-01, 7.69568920e-01,
                         7.61249363e-01, 7.52828658e-01, 7.44310498e-01, 7.35698342e-01,
                         7.26995111e-01, 7.18204558e-01, 7.09329903e-01, 7.00374365e-01,
                         6.91341758e-01, 6.82235181e-01, 6.73058391e-01, 6.63814962e-01,
                         6.54508293e-01, 6.45142198e-01, 6.35720253e-01, 6.26245677e-01,
                         6.16722703e-01, 6.07154548e-01, 5.97544968e-01, 5.87898016e-01,
                         5.78217208e-01, 5.68506122e-01, 5.58768749e-01, 5.49008489e-01,
                         5.39229512e-01, 5.29435217e-01, 5.19629717e-01, 5.09816945e-01,
                         4.99999970e-01, 4.90183085e-01, 4.80370134e-01, 4.70564514e-01,
                         4.60770488e-01, 4.50991303e-01, 4.41231340e-01, 4.31493938e-01,
                         4.21782762e-01, 4.12101716e-01, 4.02454823e-01, 3.92845303e-01,
                         3.83277327e-01, 3.73754293e-01, 3.64279747e-01, 3.54857743e-01,
                         3.45491499e-01, 3.36184770e-01, 3.26941401e-01, 3.17764580e-01,
                         3.08658242e-01, 2.99625605e-01, 2.90670067e-01, 2.81795442e-01,
                         2.73004681e-01, 2.64301449e-01, 2.55689323e-01, 2.47171342e-01,
                         2.38750651e-01, 2.30430856e-01, 2.22214773e-01, 2.14105994e-01,
                         2.06107259e-01, 1.98221818e-01, 1.90453097e-01, 1.82803318e-01,
                         1.75275862e-01, 1.67873755e-01, 1.60599515e-01, 1.53456271e-01,
                         1.46446496e-01, 1.39573157e-01, 1.32838801e-01, 1.26245812e-01,
                         1.19796887e-01, 1.13494739e-01, 1.07341409e-01, 1.01339616e-01,
                         9.54915285e-02, 8.97992253e-02, 8.42652023e-02, 7.88913369e-02,
                         7.36798048e-02, 6.86327517e-02, 6.37518838e-02, 5.90393171e-02,
                         5.44967428e-02, 5.01258075e-02, 4.59284149e-02, 4.19059731e-02,
                         3.80601399e-02, 3.43925729e-02, 3.09043229e-02, 2.75969319e-02,
                         2.44717356e-02, 2.15297900e-02, 1.87723711e-02, 1.62004139e-02,
                         1.38150305e-02, 1.16170747e-02, 9.60734952e-03, 7.78668514e-03,
                         6.15582103e-03, 4.71530575e-03, 3.46576399e-03, 2.40764185e-03,
                         1.54132769e-03, 8.67197465e-04, 3.85478779e-04, 9.63757848e-05],
                        dtype=float32), name='hann_window')}
             1 |  # node_Constant_28
                  %"val_2"<INT64,[3]> ⬅️ ::Constant() {value=Tensor<INT64,[3]>(array([    1,     2, 16000]), name='val_2')}
             2 |  # node_Reshape_4
                  %"view"<FLOAT,[1,2,16000]> ⬅️ ::Reshape(%"signals", %"val_2") {allowzero=0}
             3 |  # node_Constant_44
                  %"onnx_padding"<INT64,[6]> ⬅️ ::Constant() {value=Tensor<INT64,[6]>(array([  0,   0, 256,   0,   0, 256]), name='onnx_padding')}
             4 |  # node_Constant_46
                  %"value_cast"<FLOAT,[]> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(0., dtype=float32), name='value_cast')}
             5 |  # n20
                  %"constant_pad_nd"<FLOAT,[1,2,16512]> ⬅️ ::Pad(%"view", %"onnx_padding", %"value_cast")
             6 |  # node_Constant_48
                  %"val_5"<INT64,[2]> ⬅️ ::Constant() {value=Tensor<INT64,[2]>(array([    2, 16512]), name='val_5')}
             7 |  # node_Reshape_9
                  %"view_1"<FLOAT,[2,16512]> ⬅️ ::Reshape(%"constant_pad_nd", %"val_5") {allowzero=0}
             8 |  # node_Constant_64
                  %"onnx_padding_2"<INT64,[2]> ⬅️ ::Constant() {value=Tensor<INT64,[2]>(array([96, 96]), name='onnx_padding_2')}
             9 |  # node_Constant_66
                  %"value_cast_2"<FLOAT,[]> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(0., dtype=float32), name='value_cast_2')}
            10 |  # n20_2
                  %"constant_pad_nd_1"<FLOAT,[512]> ⬅️ ::Pad(%"hann_window", %"onnx_padding_2", %"value_cast_2")
            11 |  # node_Constant_67
                  %"dims"<INT64,[1]> ⬅️ ::Constant() {value=Tensor<INT64,[1]>(array([1]), name='dims')}
            12 |  # n3_3
                  %"seq_result"<Sequence(Tensor(FLOAT)),?> ⬅️ ::SequenceEmpty()
            13 |  # n4_3
                  %"i"<INT64,[]> ⬅️ ::Constant() {value_int=0}
            14 |  # node_Constant_70
                  %"cond"<BOOL,[]> ⬅️ ::Constant() {value=Tensor<BOOL,[]>(array(True), name='cond')}
            15 |  # n8_4
                  %"i_8"<?,?>, %"seq_result_9"<?,?> ⬅️ ::Loop(None, %"cond", %"i", %"seq_result") {body=
                      graph(
                          name=loop_body,
                          inputs=(
                              %"infinite_loop"<INT64,[]>,
                              %"cond"<BOOL,[]>,
                              %"i_1"<?,?>,
                              %"seq_result_2"<?,?>
                          ),
                          outputs=(
                              %"cond_out"<?,?>,
                              %"i_5"<?,?>,
                              %"seq_result_4"<?,?>
                          ),
                      ) {
                           0 |  # n0_6
                                %"step"<INT64,[]> ⬅️ ::Constant() {value_int=160}
                           1 |  # n1_6
                                %"step_cast"<?,?> ⬅️ ::CastLike(%"step", %"i_1")
                           2 |  # n2_4
                                %"tmp_3_2"<?,?> ⬅️ ::Mul(%"i_1", %"step_cast")
                           3 |  # n3_4
                                %"int64_m1_1d"<INT64,[1]> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[1]>(name='int64_m1_1d')}
                           4 |  # n4_4
                                %"starts"<?,?> ⬅️ ::Reshape(%"tmp_3_2", %"int64_m1_1d")
                           5 |  # n5_4
                                %"size"<INT64,[]> ⬅️ ::Constant() {value_int=512}
                           6 |  # n6_4
                                %"size_cast"<?,?> ⬅️ ::CastLike(%"size", %"starts")
                           7 |  # n7_4
                                %"ends_4"<?,?> ⬅️ ::Add(%"starts", %"size_cast")
                           8 |  # n8_3
                                %"slice_result"<?,?> ⬅️ ::Slice(%"view_1", %"starts", %"ends_4", %"dims")
                           9 |  # n9_3
                                %"slice_result_float32"<FLOAT,?> ⬅️ ::Cast(%"slice_result") {to=1}
                          10 |  # n10_3
                                %"seq_result_4"<?,?> ⬅️ ::SequenceInsert(%"seq_result_2", %"slice_result_float32")
                          11 |  # n11_3
                                %"int64_1"<INT64,[]> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='int64_1')}
                          12 |  # n12_3
                                %"int64_1_cast"<?,?> ⬅️ ::CastLike(%"int64_1", %"i_1")
                          13 |  # n13_3
                                %"i_5"<?,?> ⬅️ ::Add(%"i_1", %"int64_1_cast")
                          14 |  # n14_3
                                %"target_end_6"<INT64,[]> ⬅️ ::Constant() {value_int=101}
                          15 |  # n15_3
                                %"target_end_6_cast"<?,?> ⬅️ ::CastLike(%"target_end_6", %"i_5")
                          16 |  # n16_3
                                %"cond_7"<?,?> ⬅️ ::Less(%"i_5", %"target_end_6_cast")
                          17 |  # n17_3
                                %"cond_out"<?,?> ⬅️ ::Identity(%"cond_7")
                          return %"cond_out"<?,?>, %"i_5"<?,?>, %"seq_result_4"<?,?>
                      }}
            16 |  # n9_4
                  %"concat_result"<?,?> ⬅️ ::ConcatFromSequence(%"seq_result_9") {axis=1, new_axis=1}
            17 |  # n10_4
                  %"result"<?,?> ⬅️ ::Transpose(%"concat_result") {perm=[0, 1, 2]}
            18 |  # node_Cast_71
                  %"unfold"<FLOAT,[2,101,512]> ⬅️ ::Cast(%"result") {to=1}
            19 |  # node_Mul_13
                  %"mul"<FLOAT,[2,101,512]> ⬅️ ::Mul(%"unfold", %"constant_pad_nd_1")
            20 |  # node_Constant_14
                  %"val_7"<INT64,[1]> ⬅️ ::Constant() {value=Tensor<INT64,[1]>(array([-1]), name=None)}
            21 |  # node_Unsqueeze_15
                  %"val_8"<FLOAT,[2,101,512,1]> ⬅️ ::Unsqueeze(%"mul", %"val_7")
            22 |  # node_Constant_16
                  %"val_9"<INT64,[1]> ⬅️ ::Constant() {value=Tensor<INT64,[1]>(array([0]), name=None)}
            23 |  # node_Unsqueeze_17
                  %"val_10"<FLOAT,[1,2,101,512,1]> ⬅️ ::Unsqueeze(%"val_8", %"val_9")
            24 |  # node_DFT_18
                  %"val_11"<FLOAT,[1,2,101,257,2]> ⬅️ ::DFT(%"val_10") {axis=3, inverse=False, onesided=True}
            25 |  # node_Squeeze_19
                  %"_fft_r2c"<FLOAT,[2,101,257,2]> ⬅️ ::Squeeze(%"val_11", %"val_9")
            26 |  # node_Transpose_25
                  %"transpose"<FLOAT,[2,257,101,2]> ⬅️ ::Transpose(%"_fft_r2c") {perm=[0, 2, 1, 3]}
            return %"transpose"<FLOAT,[2,257,101,2]>
        }


    ,
    exported_program=
        ExportedProgram:
            class GraphModule(torch.nn.Module):
                def forward(self, signals: "f32[2, 16000]"):
                     # File: /home/justinchu/dev/pytorch/testtest.py:10 in forward, code: window = torch.hann_window(window_length=320)
                    hann_window: "f32[320]" = torch.ops.aten.hann_window.default(320, device = device(type='cpu'), pin_memory = False)
            
                     # File: /home/justinchu/dev/pytorch/testtest.py:11 in forward, code: x = signals.stft(
                    view: "f32[1, 2, 16000]" = torch.ops.aten.view.default(signals, [1, 2, 16000]);  signals = None
                    constant_pad_nd: "f32[1, 2, 16512]" = torch.ops.aten.constant_pad_nd.default(view, [256, 256], 0.0);  view = None
                    view_1: "f32[2, 16512]" = torch.ops.aten.view.default(constant_pad_nd, [2, 16512]);  constant_pad_nd = None
                    constant_pad_nd_1: "f32[512]" = torch.ops.aten.constant_pad_nd.default(hann_window, [96, 96]);  hann_window = None
                    unfold: "f32[2, 101, 512]" = torch.ops.aten.unfold.default(view_1, -1, 512, 160);  view_1 = None
                    mul: "f32[2, 101, 512]" = torch.ops.aten.mul.Tensor(unfold, constant_pad_nd_1);  unfold = constant_pad_nd_1 = None
                    _fft_r2c: "c64[2, 101, 257]" = torch.ops.aten._fft_r2c.default(mul, [2], 0, True);  mul = None
                    transpose: "c64[2, 257, 101]" = torch.ops.aten.transpose.int(_fft_r2c, 1, 2);  _fft_r2c = None
                    return (transpose,)
            
        Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='signals'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='transpose'), target=None)])
        Range constraints: {}

)
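
A loop-free formulation should be possible: compute all window start indices up front and gather them in one shot. Below is a minimal sketch in PyTorch (the helper name unfold_via_gather is made up for illustration, not torchlib's API); the same index math would map to ONNX as Range → Mul/Add → Gather plus a Reshape/Transpose, so the Loop/SequenceEmpty/SequenceInsert/ConcatFromSequence chain in the graph above would disappear.

import torch

def unfold_via_gather(x: torch.Tensor, dimension: int, size: int, step: int) -> torch.Tensor:
    """Loop-free equivalent of aten.unfold(x, dimension, size, step) (illustrative sketch)."""
    dimension = dimension % x.ndim
    dim_size = x.shape[dimension]
    n_windows = (dim_size - size) // step + 1
    # indices[i, j] = i * step + j, shape [n_windows, size]
    starts = torch.arange(n_windows) * step
    offsets = torch.arange(size)
    indices = starts.unsqueeze(1) + offsets.unsqueeze(0)
    # One gather along `dimension` replaces the per-window Slice inside the Loop.
    gathered = torch.index_select(x, dimension, indices.reshape(-1))
    new_shape = list(x.shape[:dimension]) + [n_windows, size] + list(x.shape[dimension + 1:])
    gathered = gathered.reshape(new_shape)
    # aten.unfold keeps the window axis at `dimension` and appends the size axis last.
    return torch.movedim(gathered, dimension + 1, -1)

# Matches the built-in op on the shapes from the repro above.
x = torch.randn(2, 16512)
assert torch.equal(unfold_via_gather(x, 1, 512, 160), x.unfold(1, 512, 160))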