'ShapedArray' object has no attribute 'val' #2

Open
liuyixin-louis opened this issue Aug 2, 2022 · 1 comment

Comments

@liuyixin-louis
liuyixin-louis commented Aug 2, 2022

Hi, nice work, and thanks for sharing the code. When I ran the code, I encountered the following error:

jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ShapedArray' object has no attribute 'val'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
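The error is raised inside `neural_tangents` while JAX traces `kernel_fn`, so the installed `jax`, `jaxlib`, and `neural-tangents` versions are relevant to diagnosing it. The helper below is a hypothetical sketch (not part of NTGA) that reports those versions using only the standard library; any package that is not installed is reported as `None`. The `installed_versions` name and the default package list are assumptions for illustration.

```python
"""Report the versions of the packages that appear in the traceback (hypothetical helper)."""
try:
    # Python 3.8+: distribution metadata from the standard library.
    from importlib.metadata import PackageNotFoundError, version
except ImportError:  # Python < 3.8, e.g. the python3.6 environment in this report
    from pkg_resources import DistributionNotFound as PackageNotFoundError
    from pkg_resources import get_distribution

    def version(name):
        """Return the installed version string of distribution `name`."""
        return get_distribution(name).version


def installed_versions(names=("jax", "jaxlib", "neural-tangents")):
    """Map each distribution name to its installed version, or None if it is absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

Including this output in a report makes it easier to tell whether the JAX release and the `neural_tangents` release are meant to be used together.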

The detailed output is below:

Loading dataset...
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Building model...
Generating NTGA....
  0%|          | 0/78 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "generate_attack.py", line 228, in <module>
    main()
  File "generate_attack.py", line 195, in main
    nb_iter=args.nb_iter, clip_min=0, clip_max=1, batch_size=args.batch_size)
  File "/home/yila22/prj/ntga/attacks/projected_gradient_descent.py", line 82, in projected_gradient_descent
    fx_train_0, fx_test_0, eps, norm, clip_min, clip_max, targeted, batch_size)
  File "/home/yila22/prj/ntga/attacks/fast_gradient_method.py", line 66, in fast_gradient_method
    targeted)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 427, in cache_miss
    donated_invars=donated_invars, inline=inline)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/xla.py", line 593, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/xla.py", line 668, in _xla_callable
    fun, abstract_args, pe.debug_info_final(fun, "jit"))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1284, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 829, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 901, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 1997, in _vjp
    flat_fun, primals_flat, reduce_axes=reduce_axes)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 115, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 102, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 505, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "generate_attack.py", line 146, in adv_loss
    ntk_train_train = kernel_fn(x_train, x_train, 'ntk')
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 427, in cache_miss
    donated_invars=donated_invars, inline=inline)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 318, in process_call
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 195, in process_call
    f, in_pvals, app, instantiate=False)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 303, in partial_eval
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1072, in process_call
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2935, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2852, in kernel_fn_x1
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2844, in kernel_fn_kernel
    return _set_shapes(init_fn, kernel, out_kernel)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2780, in _set_shapes
    shape1 = _propagate_shape(init_fn, in_kernel.shape1)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in _propagate_shape
    out_shape = tree_map(lambda x: int(x.val), out_shape)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/tree_util.py", line 168, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/tree_util.py", line 168, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in <lambda>
    out_shape = tree_map(lambda x: int(x.val), out_shape)
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ShapedArray' object has no attribute 'val'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "generate_attack.py", line 228, in <module>
    main()
  File "generate_attack.py", line 195, in main
    nb_iter=args.nb_iter, clip_min=0, clip_max=1, batch_size=args.batch_size)
  File "/home/yila22/prj/ntga/attacks/projected_gradient_descent.py", line 82, in projected_gradient_descent
    fx_train_0, fx_test_0, eps, norm, clip_min, clip_max, targeted, batch_size)
  File "/home/yila22/prj/ntga/attacks/fast_gradient_method.py", line 66, in fast_gradient_method
    targeted)
  File "generate_attack.py", line 146, in adv_loss
    ntk_train_train = kernel_fn(x_train, x_train, 'ntk')
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2935, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2852, in kernel_fn_x1
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2844, in kernel_fn_kernel
    return _set_shapes(init_fn, kernel, out_kernel)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2780, in _set_shapes
    shape1 = _propagate_shape(init_fn, in_kernel.shape1)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in _propagate_shape
    out_shape = tree_map(lambda x: int(x.val), out_shape)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in <lambda>
    out_shape = tree_map(lambda x: int(x.val), out_shape)
AttributeError: 'ShapedArray' object has no attribute 'val'
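The last frame shows where the lookup fails: `_propagate_shape` in `neural_tangents/stax.py` runs `tree_map(lambda x: int(x.val), out_shape)`, which expects every leaf of the abstractly evaluated output to carry a concrete value in a `.val` attribute. In the installed JAX, those leaves arrive as `ShapedArray` abstract values, which record shape and dtype but not the underlying value, so `.val` does not exist. This most likely indicates a mismatch between the installed `jax` and `neural_tangents` versions. The sketch below is an illustration under that assumption, not a patch for `stax.py`: it uses the public `jax.eval_shape` API to show that abstract evaluation yields static shape information through `.shape` and no `.val`. `apply_fn` is a hypothetical stand-in for a layer.

```python
import jax
import jax.numpy as jnp


def apply_fn(x):
    """Hypothetical stand-in for a layer: a dense map from the last axis to 5 units."""
    # x.shape is a static Python tuple even while the function is being traced.
    w = jnp.ones((x.shape[-1], 5), dtype=x.dtype)
    return x @ w


# Abstract input: only a shape and a dtype, no data.
abstract_x = jax.ShapeDtypeStruct((4, 3), jnp.float32)

# eval_shape traces apply_fn without executing it and returns ShapeDtypeStruct leaves.
out = jax.eval_shape(apply_fn, abstract_x)

print(out.shape, out.dtype)  # static shape and dtype are available
print(hasattr(out, "val"))   # abstract values carry no concrete payload
```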
@xrose3159

Hi, nice work, and thanks for sharing the code. When I ran the code, I encountered the following error:

jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ShapedArray' object has no attribute 'val'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

The detailed output is below:

Loading dataset...
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Building model...
Generating NTGA....
  0%|          | 0/78 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "generate_attack.py", line 228, in <module>
    main()
  File "generate_attack.py", line 195, in main
    nb_iter=args.nb_iter, clip_min=0, clip_max=1, batch_size=args.batch_size)
  File "/home/yila22/prj/ntga/attacks/projected_gradient_descent.py", line 82, in projected_gradient_descent
    fx_train_0, fx_test_0, eps, norm, clip_min, clip_max, targeted, batch_size)
  File "/home/yila22/prj/ntga/attacks/fast_gradient_method.py", line 66, in fast_gradient_method
    targeted)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 427, in cache_miss
    donated_invars=donated_invars, inline=inline)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/xla.py", line 593, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/xla.py", line 668, in _xla_callable
    fun, abstract_args, pe.debug_info_final(fun, "jit"))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1284, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 829, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 901, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 1997, in _vjp
    flat_fun, primals_flat, reduce_axes=reduce_axes)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 115, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 102, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 505, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "generate_attack.py", line 146, in adv_loss
    ntk_train_train = kernel_fn(x_train, x_train, 'ntk')
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 427, in cache_miss
    donated_invars=donated_invars, inline=inline)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 318, in process_call
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 195, in process_call
    f, in_pvals, app, instantiate=False)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 303, in partial_eval
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1072, in process_call
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2935, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2852, in kernel_fn_x1
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2844, in kernel_fn_kernel
    return _set_shapes(init_fn, kernel, out_kernel)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2780, in _set_shapes
    shape1 = _propagate_shape(init_fn, in_kernel.shape1)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in _propagate_shape
    out_shape = tree_map(lambda x: int(x.val), out_shape)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/tree_util.py", line 168, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/tree_util.py", line 168, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in <lambda>
    out_shape = tree_map(lambda x: int(x.val), out_shape)
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ShapedArray' object has no attribute 'val'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "generate_attack.py", line 228, in <module>
    main()
  File "generate_attack.py", line 195, in main
    nb_iter=args.nb_iter, clip_min=0, clip_max=1, batch_size=args.batch_size)
  File "/home/yila22/prj/ntga/attacks/projected_gradient_descent.py", line 82, in projected_gradient_descent
    fx_train_0, fx_test_0, eps, norm, clip_min, clip_max, targeted, batch_size)
  File "/home/yila22/prj/ntga/attacks/fast_gradient_method.py", line 66, in fast_gradient_method
    targeted)
  File "generate_attack.py", line 146, in adv_loss
    ntk_train_train = kernel_fn(x_train, x_train, 'ntk')
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2935, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2852, in kernel_fn_x1
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2844, in kernel_fn_kernel
    return _set_shapes(init_fn, kernel, out_kernel)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2780, in _set_shapes
    shape1 = _propagate_shape(init_fn, in_kernel.shape1)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in _propagate_shape
    out_shape = tree_map(lambda x: int(x.val), out_shape)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in <lambda>
    out_shape = tree_map(lambda x: int(x.val), out_shape)
AttributeError: 'ShapedArray' object has no attribute 'val'
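The innermost frame is `_propagate_shape` in `neural_tangents/stax.py`, which calls `int(x.val)` on each leaf of an abstractly evaluated output shape. The error means those leaves arrived as `ShapedArray`s, which carry shape and dtype information but no concrete `.val`. One plausible cause is a mismatch between the installed `jax` and `neural_tangents` versions, but this thread does not confirm it. The minimal sketch below collects the installed versions of the packages in the traceback so they can be compared against whatever versions the repository pins. The package list is an assumption based on the traceback paths, and `installed_versions` is a hypothetical helper, not part of this repo.

```python
import importlib

# Packages whose paths appear in the traceback (assumed to be the relevant ones).
PACKAGES = ("jax", "jaxlib", "neural_tangents")


def installed_versions(packages=PACKAGES):
    """Map each package name to its version string.

    None means the import failed (not installed, or e.g. a jax/jaxlib
    mismatch raised during import); "unknown" means the module imported
    but exposes no __version__ attribute.
    """
    versions = {}
    for name in packages:
        try:
            module = importlib.import_module(name)
        except Exception:  # broad on purpose: this is only a diagnostic
            versions[name] = None
            continue
        versions[name] = getattr(module, "__version__", "unknown")
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version if version is not None else 'not importable'}")
```

Posting this output alongside the traceback makes it easier for maintainers to tell whether the failure is version-related.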

I'm having the same issue! Have you solved this problem, and if so, how?
