-
Notifications
You must be signed in to change notification settings - Fork 945
initial commit diagPart operator #1427
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the really long wait. I left comments related to the changed behavior of tf.diagPart. Also if you can sync with master that would be good.
Heads up, we are in the process of switching to a mono-repo, so you can either:
- try to finish this PR in the next few days
- or, you can reopen the PR in the tensorflow/tfjs repo once we move tfjs-core there.
Reviewed 1 of 7 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @dsmilkov and @kedevked)
src/kernels/backend_cpu.ts, line 2541 at r1 (raw file):
diagPart(x: Tensor): Tensor { const xVals = x.dataSync(); const buffer = ops.buffer([Math.sqrt(x.size)], x.dtype);
The size of the output is not necessarily the sqrt(size) since the last output dim is the minimum of the last two input dims: See https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/linalg/diag_part
src/ops/diagpart.ts, line 13 at r1 (raw file):
* input. * * Assume the input has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output
This is an outdated behavior of diag_part. See https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/linalg/diag_part for the new definition. For an input tensor of rank k, the output is rank k-1 (in other words, all the dimensions except the last two are treated as a batch)
Uh oh!
There was an error while loading. Please reload this page.
FEATURE
add operator diagPart
tensorflow/tfjs#655
Description
For repository owners only:
Please remember to apply all applicable tags to your pull request.
Tags: FEATURE, BREAKING, BUG, PERF, DEV, DOC, SECURITY
For more info see: https://github.com/tensorflow/tfjs/blob/master/DEVELOPMENT.md
This change is Reviewable