google/jax 10058

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

percyliang/sempre 761

Semantic Parser with Execution

google-research/dex-lang 674

Research language for array processing in the Haskell/ML family

percyliang/refdb 24

Stores paper references, outputs to bib/html, does basic sanity checking on bib entries

froystig/rust 0

a safe, concurrent, practical language

froystig/vowpal_wabbit 0

Vowpal Wabbit is a machine learning system which pushes the frontier of machine learning with techniques such as online, hashing, allreduce, reductions, learning2search, active, and interactive learning.

issue comment google/jax

factor named_call primitive into jax core

As I understand it, there might be differences between what we'd upstream to JAX and what's currently in Haiku, namely due to Haiku's state threading (?). The Flax version is roughly what I had in mind originally.

froystig

comment created time in 17 hours

Pull request review comment google/jax

Preserve name scope of named_call primitive in jax2tf

 def process_call(self, call_primitive: core.Primitive, f: lu.WrappedFun,
     assert call_primitive.multiple_results
     vals: Sequence[TfVal] = [t.val for t in tracers]
     f = _interpret_subtrace(f, self.main, tuple(t.aval for t in tracers))
-    vals_out: Sequence[Tuple[TfVal, core.AbstractValue]] = f.call_wrapped(*vals)
+    if call_primitive.name == "named_call":

Upstreaming is tracked as issue #4558. I had assigned it to myself, but at the moment anyone is welcome to take it if they'd like.

qiuminxu

comment created time in 17 hours

PR merged google/jax

[jax2tf] Adding shapes as values for shape polymorphism cla: yes pull ready
+804 -268

0 comment

5 changed files

gnecula

pr closed time in 2 days

PR merged google/jax

tiny change for source sync cla: yes

tiny change for source sync

FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/4597 from gnecula:tf_poly3 8458cab005279515ea9b6b693b32a058297709a5

+1 -0

0 comment

1 changed file

copybara-service[bot]

pr closed time in 2 days

push event google/jax

George Necula

commit sha f0d2a4d239f45288ac3fc72ee6850086bc5e76d6

[jax2tf] Expanding jax2tf shape polymorphism. Allow shape variables in the primitive parameters, e.g., the shape parameter for the reshape primitive. More details are in the updated README.md.

George Necula

commit sha 36542e0499c86119940433fd0cf7d5f7099a2caa

Edits for the documentation

George Necula

commit sha 8458cab005279515ea9b6b693b32a058297709a5

Fix flakes

Roy Frostig

commit sha 3e56be2a0be8f7c1981cc1e178ed39857ea354e3

Merge pull request #4597 from gnecula:tf_poly3 PiperOrigin-RevId: 339327059

push time in 2 days


delete branch google/jax

delete branch : test_339326299

delete time in 2 days


PR opened google/jax

fix jit docstring's formatting
+1 -1

0 comment

1 changed file

pr created time in 3 days

create branch google/jax

branch : docfix-jit

created branch time in 3 days

PR opened google/jax

register filtered stack trace path exclusions from outside of `traceback_util`

Previously, in traceback_util.py, we maintained a list of include/exclude paths under which stack frames were filtered. This PR introduces a module-level function for registering a path exclusion, and uses it throughout the codebase.

Hopefully this is easier to maintain consistently. Our recent refactor and move of several files under _src inadvertently changed which frames were filtered. This PR fixes that in particular.
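A minimal, self-contained sketch of the registration pattern described above. The names `register_exclusion` and `filtered_frames` are hypothetical and are not meant to mirror the actual `traceback_util` code; the sketch only shows how a module-level registry lets each file opt its own frames out of filtered tracebacks.

```python
import os
import traceback

# Hypothetical module-level registry of path prefixes whose frames are hidden.
_exclude_paths = []


def register_exclusion(path):
    """Register a file or directory whose stack frames should be filtered out."""
    _exclude_paths.append(os.path.abspath(path))


def _is_excluded(filename):
    filename = os.path.abspath(filename)
    # Match the registered path itself, or anything underneath it as a directory.
    return any(filename == p or filename.startswith(p + os.sep)
               for p in _exclude_paths)


def filtered_frames(tb):
    """Return the traceback's frame summaries, minus those from excluded paths."""
    return [fs for fs in traceback.extract_tb(tb) if not _is_excluded(fs.filename)]


# Each module would opt itself out at import time, e.g.:
#   register_exclusion(__file__)
```

Registering from the module that owns the frames keeps each exclusion next to the code it affects, so moving a file does not silently change what gets filtered.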

+73 -35

0 comment

17 changed files

pr created time in 3 days

create branch google/jax

branch : traceback-register

created branch time in 3 days

Pull request review comment google/jax

Ensure that check_jaxpr is done with abstract values

 def write(v: Var, a: AbstractValue) -> None:
     prim = eqn.primitive
     try:
       in_avals = map(read, eqn.invars)
+      typecheck_assert(all(not isinstance(ina, ConcreteArray) for ina in in_avals),
+                       f"Equation given ConcreteArray type inputs")

Suggested change:
                       "Equation given ConcreteArray type inputs")

gnecula

comment created time in 9 days

Pull request review comment google/jax

Ensure that check_jaxpr is done with abstract values

 def write(v: Var, a: AbstractValue) -> None:
     prim = eqn.primitive
     try:
       in_avals = map(read, eqn.invars)
+      typecheck_assert(all(not isinstance(ina, ConcreteArray) for ina in in_avals),
+                       f"Primitive {prim} takes ConcreteArray abstract inputs")

Suggested change:
                       f"Equation given ConcreteArray type inputs")

Minor: the primitive name and equation are already included as context whenever typecheck_assert raises in this block (see a few lines down in except ...)

gnecula

comment created time in 9 days

issue closed google/jax

PyTreeDef equality check failing with nested namedtuples

I am using a nested pytree composed of namedtuples and dicts as the state object in an algorithm, and the PyTreeDef equality check fails when while_loop compares the before/after structure. The string representations of the two PyTreeDefs are identical. Strangely, if I replace the lowest-level nested namedtuples with tuples, then it works fine. I have extracted the code from lax_control_flow that does the comparison to inspect how the PyTreeDefs differ, but I can't see how:

        from jax.tree_util import tree_flatten
        from jax.lax.lax_control_flow import _initial_style_jaxpr, _abstractify
        from jax.util import safe_map
        init_vals, in_tree = tree_flatten((state,))
        init_avals = tuple(safe_map(_abstractify, init_vals))

        body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
            lambda state: body(state), in_tree, init_avals)

        in_tree_children = in_tree.children()
        assert len(in_tree_children) == 1
        print(str(body_tree))
        print(str(in_tree_children[0]))
        print('string equality',str(body_tree)==str(in_tree_children[0]))
        print('equality',body_tree==in_tree_children[0])

produces

PyTreeDef(namedtuple[<class 'jaxns.nested_sampling.NestedSamplerState'>], [*,*,*,*,*,PyTreeDef(dict[['_x_gamma', '_x_mu', 'x']], [*,*,*]),*,PyTreeDef(dict[['x']], [*]),*,*,*,PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.logZState'>], [PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [*,*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [*,*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [*,*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.mState'>], [PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [PyTreeDef(dict[['x']], [*]),PyTreeDef(dict[['x']], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [PyTreeDef(dict[['x']], [*]),PyTreeDef(dict[['x']], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [PyTreeDef(dict[['x']], [*]),PyTreeDef(dict[['x']], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.MState'>], [PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [PyTreeDef(dict[['x']], [*]),PyTreeDef(dict[['x']], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [PyTreeDef(dict[['x']], [*]),PyTreeDef(dict[['x']], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [PyTreeDef(dict[['x']], [*]),PyTreeDef(dict[['x']], 
[*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.HState'>], [PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [*,*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [*,*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [*,*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*])]),*,PyTreeDef(namedtuple[<class 'jaxns.likelihood_samplers.multi_ellipsoid.MultiEllipsoidSamplerState'>], [*,*,*,*,*,*]),PyTreeDef(None, [])])
PyTreeDef(namedtuple[<class 'jaxns.nested_sampling.NestedSamplerState'>], [*,*,*,*,*,PyTreeDef(dict[['_x_gamma', '_x_mu', 'x']], [*,*,*]),*,PyTreeDef(dict[['x']], [*]),*,*,*,PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.logZState'>], [PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [*,*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [*,*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [*,*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.mState'>], [PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [PyTreeDef(dict[['x']], [*]),PyTreeDef(dict[['x']], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [PyTreeDef(dict[['x']], [*]),PyTreeDef(dict[['x']], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [PyTreeDef(dict[['x']], [*]),PyTreeDef(dict[['x']], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.MState'>], [PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [PyTreeDef(dict[['x']], [*]),PyTreeDef(dict[['x']], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [PyTreeDef(dict[['x']], [*]),PyTreeDef(dict[['x']], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [PyTreeDef(dict[['x']], [*]),PyTreeDef(dict[['x']], 
[*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*])]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.HState'>], [PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [*,*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [*,*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.SignedLogParam'>], [*,*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*]),PyTreeDef(namedtuple[<class 'jaxns.param_tracking_new.LogParam'>], [*])]),*,PyTreeDef(namedtuple[<class 'jaxns.likelihood_samplers.multi_ellipsoid.MultiEllipsoidSamplerState'>], [*,*,*,*,*,*]),PyTreeDef(None, [])])
string equality True
equality False

closed time in 14 days

Joshuaalbert

issue comment google/jax

PyTreeDef equality check failing with nested namedtuples

That's correct. PyTreeDef equality requires that the types comprising the tree be equal. I'll close this since it seems resolved.
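A small reproduction of the behavior described, assuming a recent JAX release: two distinct namedtuple classes that share a name and fields (as can happen, for example, if the same class is defined twice) print identically as treedefs but do not compare equal, while plain tuples do. `PointA` and `PointB` are illustrative names.

```python
from collections import namedtuple

import jax

# Two *different* classes that happen to share a name and field list.
PointA = namedtuple("Point", ["x", "y"])
PointB = namedtuple("Point", ["x", "y"])

tree_a = jax.tree_util.tree_structure(PointA(1.0, 2.0))
tree_b = jax.tree_util.tree_structure(PointB(1.0, 2.0))

print(str(tree_a) == str(tree_b))  # True: the printed forms match
print(tree_a == tree_b)            # False: PointA is not PointB

# Plain tuples carry no class identity, so their structures compare equal.
tuples_equal = (jax.tree_util.tree_structure((1.0, 2.0)) ==
                jax.tree_util.tree_structure((3.0, 4.0)))
print(tuples_equal)  # True
```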

Joshuaalbert

comment created time in 14 days

PR opened google/jax

assume less about source locations in jaxpr_util_test
+5 -3

0 comment

1 changed file

pr created time in 16 days

create branch google/jax

branch : jaxpr-util-test

created branch time in 16 days

issue opened google/jax

factor named_call primitive into jax core

Both Flax and Haiku implement a named_call JAX primitive. How about we move this into JAX core?

This is the same as JAX's call primitive (see jax.core.call_p), except in how it compiles to XLA, where the name that it carries is included in the HLO computation name. This makes it useful for annotating profiles.
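For reference, later JAX releases expose this functionality as `jax.named_call` (alongside the `jax.named_scope` context manager). A minimal sketch assuming such a release; `dense` and `model` are illustrative names. The name is attached as metadata for tools such as profilers and does not change the numerical result.

```python
import jax
import jax.numpy as jnp


def dense(x, w):
    return jnp.tanh(x @ w)


# Wrap `dense` so its computation is staged out under the name "dense_layer".
named_dense = jax.named_call(dense, name="dense_layer")


@jax.jit
def model(x, w):
    return named_dense(x, w).sum()


x = jnp.ones((2, 3))
w = jnp.zeros((3, 4))
print(model(x, w))  # tanh(0) is 0 everywhere, so the sum is 0.0
```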

cc @LenaMartens @tomhennigan @trevorcai @levskaya @jheek @avital

created time in 17 days

issue comment google/jax

Support for variable length in jitted scans?

Are there any plans on supporting "variable-length" scans up to a max number of iterations?

We don't have this "off the shelf," but you could combine scan with cond (or select) as in #3850, to mask out final iterations up to the maximum.
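A minimal sketch of the masking idea, assuming a static upper bound `MAX_STEPS` and a hypothetical per-step update `step`. It uses select (via `jnp.where`) rather than `cond`: every one of the `MAX_STEPS` iterations still runs `step`, and masked iterations simply keep the previous carry. Because the scan length is fixed, reverse-mode differentiation through it still works.

```python
import jax
import jax.numpy as jnp
from jax import lax

MAX_STEPS = 10  # static upper bound; scan needs a fixed length


def step(x):
    # Hypothetical per-iteration update.
    return x * 2.0


@jax.jit
def masked_scan(x0, n):
    """Apply `step` n times (n may be traced) while scanning MAX_STEPS times."""
    x0 = jnp.asarray(x0, dtype=jnp.float32)

    def body(carry, i):
        # Update while i < n; afterwards pass the carry through unchanged.
        return jnp.where(i < n, step(carry), carry), None

    final, _ = lax.scan(body, x0, jnp.arange(MAX_STEPS))
    return final


result = masked_scan(1.0, 3)          # 2 ** 3 = 8.0
slope = jax.grad(masked_scan)(1.0, 3)  # d(8 * x0) / d x0 = 8.0
```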

Regarding efficiency: with or without omnistaging, you can make sure a cond carries out the computations you'd expect by passing it explicit operands rather than relying on closure capture, e.g.:

lax.cond(x > 0, lambda yz: f(yz[0]), lambda yz: g(yz[1]), (y, z))

But indeed, with omnistaging, the simpler expression lax.cond(x > 0, lambda: f(y), lambda: g(z)) will behave similarly.
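A small sketch contrasting the two forms above, with hypothetical branch functions `f` and `g`. In current JAX releases, where omnistaging is the only mode, both produce the same result.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(y):
    return jnp.sin(y)  # hypothetical true-branch computation


def g(z):
    return jnp.cos(z)  # hypothetical false-branch computation


@jax.jit
def explicit_operands(x, y, z):
    # Each branch receives the (y, z) tuple as its operand.
    return lax.cond(x > 0, lambda yz: f(yz[0]), lambda yz: g(yz[1]), (y, z))


@jax.jit
def closure_capture(x, y, z):
    # Branches close over y and z; no operands are passed to cond.
    return lax.cond(x > 0, lambda: f(y), lambda: g(z))
```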

Looking only at your original example: does using lax.fori_loop instead of lax.scan work? It seems to fit, since your scan takes no inputs, produces no output array, and involves no reverse-mode autodiff.
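A sketch of the suggested substitution, assuming a hypothetical loop body: when a scan takes no `xs` and its stacked outputs are discarded, only the carry matters, and `lax.fori_loop` expresses the same computation directly.

```python
import jax
import jax.numpy as jnp
from jax import lax


def body(i, x):
    # Hypothetical update; the carry x is the only loop state.
    return x + i.astype(x.dtype)


@jax.jit
def with_scan(x0):
    # The scan only iterates over indices; per-step outputs are None.
    final, _ = lax.scan(lambda c, i: (body(i, c), None), x0, jnp.arange(5))
    return final


@jax.jit
def with_fori(x0):
    return lax.fori_loop(0, 5, body, x0)


x0 = jnp.float32(0.0)
# Both compute 0 + 0 + 1 + 2 + 3 + 4 = 10.0
```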

nrontsis

comment created time in 20 days

PR opened google/jax

re-enable lax autodiff test

The xla:cpu bug was due to a change in LLVM, which has since been reverted.

+0 -3

0 comment

1 changed file

pr created time in 21 days

create branch google/jax

branch : enable-conv-test

created branch time in 21 days

PR closed google/jax

Add transparent copy of buffers in the C++ jax.jit. cla: yes

Add transparent copy of buffers in the C++ jax.jit.

If jax.jit is not committed to a device, an error is raised when several sticky buffers are on different devices.

When jax.jit is committed to a device, we will move the data unconditionally.

This should mimic the Python logic.

+111 -83

2 comments

3 changed files

copybara-service[bot]

pr closed time in 21 days

PR opened google/jax

remove an always-skipped lax_numpy test
+0 -26

0 comment

1 changed file

pr created time in 22 days

create branch google/jax

branch : remove-test

created branch time in 22 days

push event google/jax

Roy Frostig

commit sha 16cd33049bb3d591549597a1a64211e506119da7

skip test that is broken at unreleased (source) jaxlib

view details

Roy Frostig

commit sha 1fb097bded4cfc68543f95012dd586881a651d60

remove stale version guard in test

view details

push time in 22 days

PR opened google/jax

test fixes in jax_jit_test
+5 -5

0 comment

1 changed file

pr created time in 22 days

create branch google/jax

branch : test-fixes

created branch time in 22 days

issue comment google/jax

Support for variable length in jitted scans?

This is an intentional requirement of scan. The constant length is what enables support of reverse-mode autodiff that still compiles via jit. Variable-length loops are available as lax.while_loop and lax.fori_loop. These will jit and support forward-mode autodiff.
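A minimal sketch, assuming a hypothetical halving loop whose trip count depends on a runtime value: `lax.while_loop` compiles under `jit` and supports forward-mode differentiation via `jax.jvp`, though not reverse-mode `jax.grad`.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def halve_until_small(x, threshold):
    """Halve x until it is <= threshold; the trip count depends on the data."""
    def cond_fun(carry):
        value, _ = carry
        return value > threshold

    def body_fun(carry):
        value, count = carry
        return value / 2.0, count + 1

    return lax.while_loop(cond_fun, body_fun, (x, 0))


def final_value(x):
    return halve_until_small(x, 1.0)[0]


# Forward-mode derivative through the data-dependent loop.
primal, tangent = jax.jvp(final_value, (jnp.float32(10.0),), (jnp.float32(1.0),))
# Four halvings: primal == 0.625 and tangent == 1 / 16 == 0.0625
```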

nrontsis

comment created time in 22 days

PR opened google/jax

early return for non-existent path prefixes
+2 -0

0 comment

1 changed file

pr created time in 22 days

create branch google/jax

branch : tb-util-path

created branch time in 22 days

PR opened google/jax

minor change to test source sync
+1 -0

0 comment

1 changed file

pr created time in 23 days

create branch google/jax

branch : minor-source-sync

created branch time in 23 days

push event google/jax

jax authors

commit sha 6eb4f54ececbfec36b419538109bc799bc769b46

Internal change PiperOrigin-RevId: 335699277

view details

push time in 23 days

PR closed google/jax

[jax2tf] Small change to the getting started Colab cla: yes pull ready
+1 -3

0 comment

1 changed file

gnecula

pr closed time in 23 days

create barnchgoogle/jax

branch : tycheck-invar-eqn-context

created branch time in 24 days

PR opened google/jax

skip test that fails due to known xla:cpu bug
+3 -0

0 comment

1 changed file

pr created time in a month

create branch google/jax

branch : lax-ad-test-skip

created branch time in a month

PR opened google/jax

wrap long line

minor change, in order to test source sync.

+3 -1

0 comment

1 changed file

pr created time in a month

create barnchgoogle/jax

branch : long-line

created branch time in a month

PR opened google/jax

reduce test-case count of the numpy-dispatch CI check

... to match our other x64-mode CI check.

The numpy-dispatch check seems to be timing out sometimes.

+1 -1

0 comment

1 changed file

pr created time in a month

create branch google/jax

branch : numpy-dispatch-test-count

created branch time in a month

PR opened google/jax

mark lax traceables as entry points for filtered stack traces
+134 -5

0 comment

3 changed files

pr created time in a month

create branch google/jax

branch : lax-api-boundary

created branch time in a month

issue comment google/jax

Differentiation rule for 'create_token'

As of #4330, you should no longer need to make that stop_gradient call yourself.

PhilipVinc

comment created time in a month

delete branch google/jax

delete branch : create-token-stop-grad

delete time in a month

push event google/jax

Roy Frostig

commit sha 16e60360940f2baa4f3bc4f94b58680af8295198

trivial change to test source sync PiperOrigin-RevId: 332544315

view details

push time in a month

PR opened google/jax

insert a stop_gradient in lax.create_token

see #4292

The data dependency forced on the operand is false. There's no need to follow it for derivatives, and doing so can lead to unexpected errors.

+1 -1

0 comment

1 changed file

pr created time in a month

create branch google/jax

branch : create-token-stop-grad

created branch time in a month

issue comment google/jax

psum inconsistent across vmap/pmap context

related: #3970

froystig

comment created time in a month

issue opened google/jax

psum inconsistent across vmap/pmap context

In particular, the psum primitive should behave as a broadcasting sum ("allreduce").

A fix would likely involve changing this rule, plus tests for (i) intended behavior and (ii) consistent behavior across pmap and vmap.
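For illustration, the broadcasting ("allreduce") behavior described above, as it appears in recent JAX releases under `vmap` with a named axis; the axis name `"i"` is arbitrary. Under `pmap` across four devices, the same function would likewise give every device the total.

```python
import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(4.0)  # [0., 1., 2., 3.]

# psum over the named axis "i" sums across the mapped elements and hands the
# total back to each of them, so every output entry equals 0 + 1 + 2 + 3.
out = jax.vmap(lambda x: lax.psum(x, axis_name="i"), axis_name="i")(xs)
print(out)  # [6. 6. 6. 6.]
```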

cc @apaszke @mattjj

created time in a month

issue comment google/jax

Differentiation rule for 'create_token'

It's helpful to see the code example now, and in particular the token's dependence on x. If you replace the statement token = jax.lax.create_token(x) with:

token = jax.lax.create_token(jax.lax.stop_gradient(x))

I'd expect things to work as you intend. Do they?
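A minimal sketch of what `stop_gradient` does here, independent of tokens: the wrapped value is still computed and used, but differentiation treats it as a constant, so no derivative flows back along that data dependency. The function `f` is illustrative.

```python
import jax
from jax import lax


def f(x):
    # The second factor is used in the computation but treated as a
    # constant by autodiff.
    return x * lax.stop_gradient(x)


value = f(3.0)            # 9.0
slope = jax.grad(f)(3.0)  # 3.0, not 6.0: only the first factor contributes
```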

PhilipVinc

comment created time in a month

pull request comment google/jax

Implement jnp.ravel_multi_index()

I had core.concrete_or_error in mind, which is used elsewhere throughout lax_numpy. It allows tracers as long as they're concrete.

jakevdp

comment created time in a month

delete branch google/jax

delete branch : rtdfix

delete time in a month

PR closed google/jax

fix docs build by adding requirement cla: yes

This is an attempt to resolve #4202

+1 -0

2 comments

1 changed file

froystig

pr closed time in a month

pull request commentgoogle/jax

fix docs build by adding requirement

It is no longer needed!

froystig

comment created time in a month

issue closed google/jax

Broken docs check

The docs/readthedocs.org:jax check has been failing across PRs lately, e.g. here. The only error in the log is

ERROR: After October 2020 you may experience errors when installing or updating packages. This is because pip will change the way that it resolves dependency conflicts.

We recommend you use --use-feature=2020-resolver to test your packages with the new resolver before it becomes the default.

myst-parser 0.12.8 requires docutils>=0.15, but you'll have docutils 0.14 which is incompatible.

closed time in a month

JuliusKunze

issue comment google/jax

Broken docs check

I'm going to close this because it indeed seems to have been fixed.

JuliusKunze

comment created time in a month

issue comment google/jax

Differentiation rule for 'create_token'

As you say, there isn't a derivative here, so we shouldn't define a false one. JAX allows differentiation with respect to only some function inputs. At the user level, this corresponds to passing argnums to functions such as grad. Does this resolve things within your use case?
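A minimal sketch of differentiating with respect to only some inputs via `argnums`; `loss`, `params`, and `data` are illustrative names.

```python
import jax
import jax.numpy as jnp


def loss(params, data):
    # Hypothetical objective; only `params` is differentiated.
    return jnp.sum((params * data) ** 2)


# argnums=0: gradient w.r.t. the first argument only; `data` is held fixed.
grad_params = jax.grad(loss, argnums=0)

params = jnp.array([1.0, 2.0])
data = jnp.array([3.0, 4.0])
print(grad_params(params, data))  # 2 * params * data**2 = [18. 64.]
```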

PhilipVinc

comment created time in a month

issue comment google/jax

Broken docs check

This looks like it may have been fixed (by RTD?) in the past 16+ hours, as the doc builds have started passing again. @shoyer, any thoughts on whether to add the requirement anyway? PR is ready, but I can discard it too.

JuliusKunze

comment created time in 2 months

PR opened google/jax

fix docs build by adding requirement

This is an attempt to resolve #4202

+1 -0

0 comment

1 changed file

pr created time in 2 months

create branch google/jax

branch : rtdfix

created branch time in 2 months

issue comment google/jax

`jax.scipy.linalg.expm` causes an infinite loop inside two nested `fori_loop`/`scan`s.

Good point. We could return early in a jit-friendly way using lax.cond, but I'm wary of affecting performance and introducing new global flags. This might also obscure an underlying solve-related bug that will only reappear later.

From what I can tell, the "hanging loop" isn't one directly in our codebase. At some point we pass the inf/nan values over to a lower-level linalg routine. It may be the getrf in our lax.lu_p translation, which comes from lapack via jaxlib, or it may be an XLA solve op after that. It seems better to check and return early at whatever that more upstream location is.
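For concreteness, a sketch of the jit-friendly early return mentioned above, with the caveats already noted (it adds a check to every call and could mask an underlying solve bug). `guarded_expm` and the choice of a NaN-filled result for non-finite input are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.linalg import expm


@jax.jit
def guarded_expm(a):
    # Only run expm on finite input; otherwise return a NaN-filled matrix.
    return lax.cond(
        jnp.all(jnp.isfinite(a)),
        expm,
        lambda m: jnp.full_like(m, jnp.nan),
        a,
    )
```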

C-J-Cundy

comment created time in 2 months

issue comment google/jax

`jax.scipy.linalg.expm` causes an infinite loop inside two nested `fori_loop`/`scan`s.

I think this is expm hanging on invalid numeric input. The bug doesn't seem to require loop/scan. I've distilled the original example down to:

import numpy as onp
import scipy as osp
from jax import scipy as jsp

sp = jsp

def potential(x):
  W = onp.array([[0.0, x[0]], [x[1], 0.0]])
  return onp.trace(sp.linalg.expm(W * W))

def leapfrog(q):
  def scan_fun(q):
    V = potential(q)
    q -= V
    return q, (q, V)

  q, (q, V) = scan_fun(q)
  q, (q, V) = scan_fun(q)
  q, (q, V) = scan_fun(q)       # error happens here
  return q

leapfrog(onp.random.normal(size=2) * 3)

which hangs similarly. If I substitute sp = osp on line 5, I encounter:

Traceback (most recent call last):
  ...
  File ".../scipy/sparse/linalg/matfuncs.py", line 671, in _expm
    s = max(int(np.ceil(np.log2(eta_5 / theta_13))), 0)
ValueError: cannot convert float NaN to integer

The third invocation of potential passes it [-inf -inf]. Those inf values are handed to expm.

We could consider raising an error on input like this. The question is where to do so. Because expm reduces to a matrix solve, I suspect this attempts a solve on a terribly conditioned matrix. Maybe we should raise errors in our solve routines more generally.

Standard scipy's error isn't directly clear either, but at least the program halts.

Paging @shoyer who knows expm better. Thoughts?

C-J-Cundy

comment created time in 2 months

pull request comment google/jax

Add NumPy backend

This is a large PR! It looks really promising to me still. It'll take us some time to review (whether it's @mattjj or me or both of us), but we'll try to give at least some partial feedback soon.

JuliusKunze

comment created time in 2 months

Pull request review comment google/jax

single-operand cond

 def f_aug(*args):
    return _make_typed_jaxpr(f_aug, jaxpr.in_avals)
 
+def _join_cond_pe_staged_jaxpr_inputs(jaxprs, res_avals_per_jaxpr):
+  newvar = core.gensym('~')     # TODO(frostig): safer gensym

Resolved by #3211

froystig

comment created time in 2 months
