summaryrefslogtreecommitdiff
path: root/js/tf-core.js
diff options
context:
space:
mode:
Diffstat (limited to 'js/tf-core.js')
-rw-r--r--js/tf-core.js27144
1 files changed, 27144 insertions, 0 deletions
diff --git a/js/tf-core.js b/js/tf-core.js
new file mode 100644
index 0000000..75711cc
--- /dev/null
+++ b/js/tf-core.js
@@ -0,0 +1,27144 @@
+/**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+(function (global, factory) {
+ typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports, require('crypto')) :
+ typeof define === 'function' && define.amd ? define(['exports', 'crypto'], factory) :
+ (global = typeof globalThis !== 'undefined' ? globalThis : global || self, factory(global.tf = global.tf || {}, global.require$$0));
+}(this, (function (exports, require$$0) { 'use strict';
+
+ function _interopDefaultLegacy (e) { return e && typeof e === 'object' && 'default' in e ? e : { 'default': e }; }
+
+ var require$$0__default = /*#__PURE__*/_interopDefaultLegacy(require$$0);
+
+ /*! *****************************************************************************
+ Copyright (c) Microsoft Corporation.
+
+ Permission to use, copy, modify, and/or distribute this software for any
+ purpose with or without fee is hereby granted.
+
+ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
+ REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+ AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
+ INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+ LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
+ OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+ PERFORMANCE OF THIS SOFTWARE.
+ ***************************************************************************** */
+ /* global Reflect, Promise */
+ var extendStatics = function (d, b) {
+ extendStatics = Object.setPrototypeOf ||
+ ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) ||
+ function (d, b) { for (var p in b)
+ if (b.hasOwnProperty(p))
+ d[p] = b[p]; };
+ return extendStatics(d, b);
+ };
+ function __extends(d, b) {
+ extendStatics(d, b);
+ function __() { this.constructor = d; }
+ d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __());
+ }
+ function __awaiter(thisArg, _arguments, P, generator) {
+ function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
+ return new (P || (P = Promise))(function (resolve, reject) {
+ function fulfilled(value) { try {
+ step(generator.next(value));
+ }
+ catch (e) {
+ reject(e);
+ } }
+ function rejected(value) { try {
+ step(generator["throw"](value));
+ }
+ catch (e) {
+ reject(e);
+ } }
+ function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
+ step((generator = generator.apply(thisArg, _arguments || [])).next());
+ });
+ }
+ function __generator(thisArg, body) {
+ var _ = { label: 0, sent: function () { if (t[0] & 1)
+ throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
+ return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function () { return this; }), g;
+ function verb(n) { return function (v) { return step([n, v]); }; }
+ function step(op) {
+ if (f)
+ throw new TypeError("Generator is already executing.");
+ while (_)
+ try {
+ if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done)
+ return t;
+ if (y = 0, t)
+ op = [op[0] & 2, t.value];
+ switch (op[0]) {
+ case 0:
+ case 1:
+ t = op;
+ break;
+ case 4:
+ _.label++;
+ return { value: op[1], done: false };
+ case 5:
+ _.label++;
+ y = op[1];
+ op = [0];
+ continue;
+ case 7:
+ op = _.ops.pop();
+ _.trys.pop();
+ continue;
+ default:
+ if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) {
+ _ = 0;
+ continue;
+ }
+ if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) {
+ _.label = op[1];
+ break;
+ }
+ if (op[0] === 6 && _.label < t[1]) {
+ _.label = t[1];
+ t = op;
+ break;
+ }
+ if (t && _.label < t[2]) {
+ _.label = t[2];
+ _.ops.push(op);
+ break;
+ }
+ if (t[2])
+ _.ops.pop();
+ _.trys.pop();
+ continue;
+ }
+ op = body.call(thisArg, _);
+ }
+ catch (e) {
+ op = [6, e];
+ y = 0;
+ }
+ finally {
+ f = t = 0;
+ }
+ if (op[0] & 5)
+ throw op[1];
+ return { value: op[0] ? op[1] : void 0, done: true };
+ }
+ }
+ function __values(o) {
+ var s = typeof Symbol === "function" && Symbol.iterator, m = s && o[s], i = 0;
+ if (m)
+ return m.call(o);
+ if (o && typeof o.length === "number")
+ return {
+ next: function () {
+ if (o && i >= o.length)
+ o = void 0;
+ return { value: o && o[i++], done: !o };
+ }
+ };
+ throw new TypeError(s ? "Object is not iterable." : "Symbol.iterator is not defined.");
+ }
+ function __read(o, n) {
+ var m = typeof Symbol === "function" && o[Symbol.iterator];
+ if (!m)
+ return o;
+ var i = m.call(o), r, ar = [], e;
+ try {
+ while ((n === void 0 || n-- > 0) && !(r = i.next()).done)
+ ar.push(r.value);
+ }
+ catch (error) {
+ e = { error: error };
+ }
+ finally {
+ try {
+ if (r && !r.done && (m = i["return"]))
+ m.call(i);
+ }
+ finally {
+ if (e)
+ throw e.error;
+ }
+ }
+ return ar;
+ }
+ function __spread() {
+ for (var ar = [], i = 0; i < arguments.length; i++)
+ ar = ar.concat(__read(arguments[i]));
+ return ar;
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var EPSILON_FLOAT32 = 1e-7;
+ var EPSILON_FLOAT16 = 1e-4;
+ /** Convenient class for storing tensor-related data. */
+ var DataStorage = /** @class */ (function () {
+ function DataStorage(backend, dataMover) {
+ this.backend = backend;
+ this.dataMover = dataMover;
+ this.data = new WeakMap();
+ this.dataIdsCount = 0;
+ }
+ DataStorage.prototype.get = function (dataId) {
+ if (!this.data.has(dataId)) {
+ this.dataMover.moveData(this.backend, dataId);
+ }
+ return this.data.get(dataId);
+ };
+ DataStorage.prototype.set = function (dataId, value) {
+ this.dataIdsCount++;
+ this.data.set(dataId, value);
+ };
+ DataStorage.prototype.has = function (dataId) {
+ return this.data.has(dataId);
+ };
+ DataStorage.prototype.delete = function (dataId) {
+ this.dataIdsCount--;
+ return this.data.delete(dataId);
+ };
+ DataStorage.prototype.numDataIds = function () {
+ return this.dataIdsCount;
+ };
+ return DataStorage;
+ }());
+ /**
+ * The interface that defines the kernels that should be implemented when
+ * adding a new backend. New backends don't need to implement every one of the
+ * methods, this can be done gradually (throw an error for unimplemented
+ * methods).
+ */
+ var KernelBackend = /** @class */ (function () {
+ function KernelBackend() {
+ }
+ KernelBackend.prototype.refCount = function (dataId) {
+ return notYetImplemented('refCount');
+ };
+ KernelBackend.prototype.incRef = function (dataId) {
+ return notYetImplemented('incRef');
+ };
+ KernelBackend.prototype.timerAvailable = function () {
+ return true;
+ };
+ KernelBackend.prototype.time = function (f) {
+ return notYetImplemented('time');
+ };
+ KernelBackend.prototype.read = function (dataId) {
+ return notYetImplemented('read');
+ };
+ KernelBackend.prototype.readSync = function (dataId) {
+ return notYetImplemented('readSync');
+ };
+ KernelBackend.prototype.numDataIds = function () {
+ return notYetImplemented('numDataIds');
+ };
+ KernelBackend.prototype.disposeData = function (dataId, force) {
+ return notYetImplemented('disposeData');
+ };
+ KernelBackend.prototype.write = function (values, shape, dtype) {
+ return notYetImplemented('write');
+ };
+ KernelBackend.prototype.move = function (dataId, values, shape, dtype, refCount) {
+ return notYetImplemented('move');
+ };
+ KernelBackend.prototype.memory = function () {
+ return notYetImplemented('memory');
+ };
+ /** Returns the highest precision for floats in bits (e.g. 16 or 32) */
+ KernelBackend.prototype.floatPrecision = function () {
+ return notYetImplemented('floatPrecision');
+ };
+ /** Returns the smallest representable number. */
+ KernelBackend.prototype.epsilon = function () {
+ return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;
+ };
+ KernelBackend.prototype.dispose = function () {
+ return notYetImplemented('dispose');
+ };
+ return KernelBackend;
+ }());
+ function notYetImplemented(kernelName) {
+ throw new Error("'" + kernelName + "' not yet implemented or not found in the registry. " +
+ "This kernel may not be supported by the tfjs backend you have chosen");
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Shuffles the array in-place using Fisher-Yates algorithm.
+ *
+ * ```js
+ * const a = [1, 2, 3, 4, 5];
+ * tf.util.shuffle(a);
+ * console.log(a);
+ * ```
+ *
+ * @param array The array to shuffle in-place.
+ *
+ * @doc {heading: 'Util', namespace: 'util'}
+ */
+ // tslint:disable-next-line:no-any
+ function shuffle(array) {
+ var counter = array.length;
+ var index = 0;
+ // While there are elements in the array
+ while (counter > 0) {
+ // Pick a random index
+ index = (Math.random() * counter) | 0;
+ // Decrease counter by 1
+ counter--;
+ // And swap the last element with it
+ swap(array, counter, index);
+ }
+ }
+ /**
+ * Shuffles two arrays in-place the same way using Fisher-Yates algorithm.
+ *
+ * ```js
+ * const a = [1,2,3,4,5];
+ * const b = [11,22,33,44,55];
+ * tf.util.shuffleCombo(a, b);
+ * console.log(a, b);
+ * ```
+ *
+ * @param array The first array to shuffle in-place.
+ * @param array2 The second array to shuffle in-place with the same permutation
+ * as the first array.
+ *
+ * @doc {heading: 'Util', namespace: 'util'}
+ */
+ function shuffleCombo(
+ // tslint:disable-next-line:no-any
+ array,
+ // tslint:disable-next-line:no-any
+ array2) {
+ if (array.length !== array2.length) {
+ throw new Error("Array sizes must match to be shuffled together " +
+ ("First array length was " + array.length) +
+ ("Second array length was " + array2.length));
+ }
+ var counter = array.length;
+ var index = 0;
+ // While there are elements in the array
+ while (counter > 0) {
+ // Pick a random index
+ index = (Math.random() * counter) | 0;
+ // Decrease counter by 1
+ counter--;
+ // And swap the last element of each array with it
+ swap(array, counter, index);
+ swap(array2, counter, index);
+ }
+ }
+ /** Clamps a value to a specified range. */
+ function clamp(min, x, max) {
+ return Math.max(min, Math.min(x, max));
+ }
+ function nearestLargerEven(val) {
+ return val % 2 === 0 ? val : val + 1;
+ }
+ function swap(object, left, right) {
+ var temp = object[left];
+ object[left] = object[right];
+ object[right] = temp;
+ }
+ function sum$1(arr) {
+ var sum = 0;
+ for (var i = 0; i < arr.length; i++) {
+ sum += arr[i];
+ }
+ return sum;
+ }
+ /**
+ * Returns a sample from a uniform [a, b) distribution.
+ *
+ * @param a The minimum support (inclusive).
+ * @param b The maximum support (exclusive).
+ * @return A pseudorandom number on the half-open interval [a,b).
+ */
+ function randUniform(a, b) {
+ var r = Math.random();
+ return (b * r) + (1 - r) * a;
+ }
+ /** Returns the squared Euclidean distance between two vectors. */
+ function distSquared(a, b) {
+ var result = 0;
+ for (var i = 0; i < a.length; i++) {
+ var diff = Number(a[i]) - Number(b[i]);
+ result += diff * diff;
+ }
+ return result;
+ }
+ /**
+ * Asserts that the expression is true. Otherwise throws an error with the
+ * provided message.
+ *
+ * ```js
+ * const x = 2;
+ * tf.util.assert(x === 2, 'x is not 2');
+ * ```
+ *
+ * @param expr The expression to assert (as a boolean).
+ * @param msg A function that returns the message to report when throwing an
+ * error. We use a function for performance reasons.
+ *
+ * @doc {heading: 'Util', namespace: 'util'}
+ */
+ function assert(expr, msg) {
+ if (!expr) {
+ throw new Error(typeof msg === 'string' ? msg : msg());
+ }
+ }
+ function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) {
+ if (errorMessagePrefix === void 0) { errorMessagePrefix = ''; }
+ assert(arraysEqual(shapeA, shapeB), function () { return errorMessagePrefix + (" Shapes " + shapeA + " and " + shapeB + " must match"); });
+ }
+ function assertNonNull(a) {
+ assert(a != null, function () { return "The input to the tensor constructor must be a non-null value."; });
+ }
+ // NOTE: We explicitly type out what T extends instead of any so that
+ // util.flatten on a nested array of number doesn't try to infer T as a
+ // number[][], causing us to explicitly type util.flatten<number>().
+ /**
+ * Flattens an arbitrarily nested array.
+ *
+ * ```js
+ * const a = [[1, 2], [3, 4], [5, [6, [7]]]];
+ * const flat = tf.util.flatten(a);
+ * console.log(flat);
+ * ```
+ *
+ * @param arr The nested array to flatten.
+ * @param result The destination array which holds the elements.
+ * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
+ * to false.
+ *
+ * @doc {heading: 'Util', namespace: 'util'}
+ */
+ function flatten(arr, result, skipTypedArray) {
+ if (result === void 0) { result = []; }
+ if (skipTypedArray === void 0) { skipTypedArray = false; }
+ if (result == null) {
+ result = [];
+ }
+ if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) {
+ for (var i = 0; i < arr.length; ++i) {
+ flatten(arr[i], result, skipTypedArray);
+ }
+ }
+ else {
+ result.push(arr);
+ }
+ return result;
+ }
+ /**
+ * Returns the size (number of elements) of the tensor given its shape.
+ *
+ * ```js
+ * const shape = [3, 4, 2];
+ * const size = tf.util.sizeFromShape(shape);
+ * console.log(size);
+ * ```
+ *
+ * @doc {heading: 'Util', namespace: 'util'}
+ */
+ function sizeFromShape(shape) {
+ if (shape.length === 0) {
+ // Scalar.
+ return 1;
+ }
+ var size = shape[0];
+ for (var i = 1; i < shape.length; i++) {
+ size *= shape[i];
+ }
+ return size;
+ }
+ function isScalarShape(shape) {
+ return shape.length === 0;
+ }
+ function arraysEqual(n1, n2) {
+ if (n1 === n2) {
+ return true;
+ }
+ if (n1 == null || n2 == null) {
+ return false;
+ }
+ if (n1.length !== n2.length) {
+ return false;
+ }
+ for (var i = 0; i < n1.length; i++) {
+ if (n1[i] !== n2[i]) {
+ return false;
+ }
+ }
+ return true;
+ }
+ function isInt(a) {
+ return a % 1 === 0;
+ }
+ function tanh$1(x) {
+ // tslint:disable-next-line:no-any
+ if (Math.tanh != null) {
+ // tslint:disable-next-line:no-any
+ return Math.tanh(x);
+ }
+ if (x === Infinity) {
+ return 1;
+ }
+ else if (x === -Infinity) {
+ return -1;
+ }
+ else {
+ var e2x = Math.exp(2 * x);
+ return (e2x - 1) / (e2x + 1);
+ }
+ }
+ function sizeToSquarishShape(size) {
+ var width = Math.ceil(Math.sqrt(size));
+ return [width, Math.ceil(size / width)];
+ }
+ /**
+ * Creates a new array with randomized indicies to a given quantity.
+ *
+ * ```js
+ * const randomTen = tf.util.createShuffledIndices(10);
+ * console.log(randomTen);
+ * ```
+ *
+ * @param number Quantity of how many shuffled indicies to create.
+ *
+ * @doc {heading: 'Util', namespace: 'util'}
+ */
+ function createShuffledIndices(n) {
+ var shuffledIndices = new Uint32Array(n);
+ for (var i = 0; i < n; ++i) {
+ shuffledIndices[i] = i;
+ }
+ shuffle(shuffledIndices);
+ return shuffledIndices;
+ }
+ function rightPad(a, size) {
+ if (size <= a.length) {
+ return a;
+ }
+ return a + ' '.repeat(size - a.length);
+ }
+ function repeatedTry(checkFn, delayFn, maxCounter) {
+ if (delayFn === void 0) { delayFn = function (counter) { return 0; }; }
+ return new Promise(function (resolve, reject) {
+ var tryCount = 0;
+ var tryFn = function () {
+ if (checkFn()) {
+ resolve();
+ return;
+ }
+ tryCount++;
+ var nextBackoff = delayFn(tryCount);
+ if (maxCounter != null && tryCount >= maxCounter) {
+ reject();
+ return;
+ }
+ setTimeout(tryFn, nextBackoff);
+ };
+ tryFn();
+ });
+ }
+ /**
+ * Given the full size of the array and a shape that may contain -1 as the
+ * implicit dimension, returns the inferred shape where -1 is replaced.
+ * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3].
+ *
+ * @param shape The shape, which may contain -1 in some dimension.
+ * @param size The full size (number of elements) of the array.
+ * @return The inferred shape where -1 is replaced with the inferred size.
+ */
+ function inferFromImplicitShape(shape, size) {
+ var shapeProd = 1;
+ var implicitIdx = -1;
+ for (var i = 0; i < shape.length; ++i) {
+ if (shape[i] >= 0) {
+ shapeProd *= shape[i];
+ }
+ else if (shape[i] === -1) {
+ if (implicitIdx !== -1) {
+ throw Error("Shapes can only have 1 implicit size. " +
+ ("Found -1 at dim " + implicitIdx + " and dim " + i));
+ }
+ implicitIdx = i;
+ }
+ else if (shape[i] < 0) {
+ throw Error("Shapes can not be < 0. Found " + shape[i] + " at dim " + i);
+ }
+ }
+ if (implicitIdx === -1) {
+ if (size > 0 && size !== shapeProd) {
+ throw Error("Size(" + size + ") must match the product of shape " + shape);
+ }
+ return shape;
+ }
+ if (shapeProd === 0) {
+ throw Error("Cannot infer the missing size in [" + shape + "] when " +
+ "there are 0 elements");
+ }
+ if (size % shapeProd !== 0) {
+ throw Error("The implicit shape can't be a fractional number. " +
+ ("Got " + size + " / " + shapeProd));
+ }
+ var newShape = shape.slice();
+ newShape[implicitIdx] = size / shapeProd;
+ return newShape;
+ }
+ function parseAxisParam(axis, shape) {
+ var rank = shape.length;
+ // Normalize input
+ axis = axis == null ? shape.map(function (s, i) { return i; }) : [].concat(axis);
+ // Check for valid range
+ assert(axis.every(function (ax) { return ax >= -rank && ax < rank; }), function () { return "All values in axis param must be in range [-" + rank + ", " + rank + ") but " +
+ ("got axis " + axis); });
+ // Check for only integers
+ assert(axis.every(function (ax) { return isInt(ax); }), function () { return "All values in axis param must be integers but " +
+ ("got axis " + axis); });
+ // Handle negative axis.
+ return axis.map(function (a) { return a < 0 ? rank + a : a; });
+ }
+ /** Reduces the shape by removing all dimensions of shape 1. */
+ function squeezeShape(shape, axis) {
+ var newShape = [];
+ var keptDims = [];
+ var isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
+ var axes = (axis == null || isEmptyArray) ?
+ null :
+ parseAxisParam(axis, shape).sort();
+ var j = 0;
+ for (var i = 0; i < shape.length; ++i) {
+ if (axes != null) {
+ if (axes[j] === i && shape[i] !== 1) {
+ throw new Error("Can't squeeze axis " + i + " since its dim '" + shape[i] + "' is not 1");
+ }
+ if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
+ newShape.push(shape[i]);
+ keptDims.push(i);
+ }
+ if (axes[j] <= i) {
+ j++;
+ }
+ }
+ if (shape[i] !== 1) {
+ newShape.push(shape[i]);
+ keptDims.push(i);
+ }
+ }
+ return { newShape: newShape, keptDims: keptDims };
+ }
+ function getTypedArrayFromDType(dtype, size) {
+ var values = null;
+ if (dtype == null || dtype === 'float32') {
+ values = new Float32Array(size);
+ }
+ else if (dtype === 'int32') {
+ values = new Int32Array(size);
+ }
+ else if (dtype === 'bool') {
+ values = new Uint8Array(size);
+ }
+ else {
+ throw new Error("Unknown data type " + dtype);
+ }
+ return values;
+ }
+ function getArrayFromDType(dtype, size) {
+ var values = null;
+ if (dtype == null || dtype === 'float32') {
+ values = new Float32Array(size);
+ }
+ else if (dtype === 'int32') {
+ values = new Int32Array(size);
+ }
+ else if (dtype === 'bool') {
+ values = new Uint8Array(size);
+ }
+ else if (dtype === 'string') {
+ values = new Array(size);
+ }
+ else {
+ throw new Error("Unknown data type " + dtype);
+ }
+ return values;
+ }
+ function checkConversionForErrors(vals, dtype) {
+ for (var i = 0; i < vals.length; i++) {
+ var num = vals[i];
+ if (isNaN(num) || !isFinite(num)) {
+ throw Error("A tensor of type " + dtype + " being uploaded contains " + num + ".");
+ }
+ }
+ }
+ /** Returns true if the dtype is valid. */
+ function isValidDtype(dtype) {
+ return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' ||
+ dtype === 'int32' || dtype === 'string';
+ }
+ /**
+ * Returns true if the new type can't encode the old type without loss of
+ * precision.
+ */
+ function hasEncodingLoss(oldType, newType) {
+ if (newType === 'complex64') {
+ return false;
+ }
+ if (newType === 'float32' && oldType !== 'complex64') {
+ return false;
+ }
+ if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') {
+ return false;
+ }
+ if (newType === 'bool' && oldType === 'bool') {
+ return false;
+ }
+ return true;
+ }
+ function isTypedArray(a) {
+ return a instanceof Float32Array || a instanceof Int32Array ||
+ a instanceof Uint8Array || a instanceof Uint8ClampedArray;
+ }
+ function bytesPerElement(dtype) {
+ if (dtype === 'float32' || dtype === 'int32') {
+ return 4;
+ }
+ else if (dtype === 'complex64') {
+ return 8;
+ }
+ else if (dtype === 'bool') {
+ return 1;
+ }
+ else {
+ throw new Error("Unknown dtype " + dtype);
+ }
+ }
+ /**
+ * Returns the approximate number of bytes allocated in the string array - 2
+ * bytes per character. Computing the exact bytes for a native string in JS is
+ * not possible since it depends on the encoding of the html page that serves
+ * the website.
+ */
+ function bytesFromStringArray(arr) {
+ if (arr == null) {
+ return 0;
+ }
+ var bytes = 0;
+ arr.forEach(function (x) { return bytes += x.length; });
+ return bytes;
+ }
+ /** Returns true if the value is a string. */
+ function isString(value) {
+ return typeof value === 'string' || value instanceof String;
+ }
+ function isBoolean(value) {
+ return typeof value === 'boolean';
+ }
+ function isNumber(value) {
+ return typeof value === 'number';
+ }
+ function inferDtype(values) {
+ if (Array.isArray(values)) {
+ return inferDtype(values[0]);
+ }
+ if (values instanceof Float32Array) {
+ return 'float32';
+ }
+ else if (values instanceof Int32Array
+ || values instanceof Uint8Array
+ || values instanceof Uint8ClampedArray) {
+ return 'int32';
+ }
+ else if (isNumber(values)) {
+ return 'float32';
+ }
+ else if (isString(values)) {
+ return 'string';
+ }
+ else if (isBoolean(values)) {
+ return 'bool';
+ }
+ return 'float32';
+ }
+ function isFunction(f) {
+ return !!(f && f.constructor && f.call && f.apply);
+ }
+ function nearestDivisor(size, start) {
+ for (var i = start; i < size; ++i) {
+ if (size % i === 0) {
+ return i;
+ }
+ }
+ return size;
+ }
+ function computeStrides(shape) {
+ var rank = shape.length;
+ if (rank < 2) {
+ return [];
+ }
+ // Last dimension has implicit stride of 1, thus having D-1 (instead of D)
+ // strides.
+ var strides = new Array(rank - 1);
+ strides[rank - 2] = shape[rank - 1];
+ for (var i = rank - 3; i >= 0; --i) {
+ strides[i] = strides[i + 1] * shape[i + 1];
+ }
+ return strides;
+ }
+ function createNestedArray(offset, shape, a, isComplex) {
+ if (isComplex === void 0) { isComplex = false; }
+ var ret = new Array();
+ if (shape.length === 1) {
+ var d = shape[0] * (isComplex ? 2 : 1);
+ for (var i = 0; i < d; i++) {
+ ret[i] = a[offset + i];
+ }
+ }
+ else {
+ var d = shape[0];
+ var rest = shape.slice(1);
+ var len = rest.reduce(function (acc, c) { return acc * c; }) * (isComplex ? 2 : 1);
+ for (var i = 0; i < d; i++) {
+ ret[i] = createNestedArray(offset + i * len, rest, a, isComplex);
+ }
+ }
+ return ret;
+ }
+ // Provide a nested array of TypedArray in given shape.
+ function toNestedArray(shape, a, isComplex) {
+ if (isComplex === void 0) { isComplex = false; }
+ if (shape.length === 0) {
+ // Scalar type should return a single number.
+ return a[0];
+ }
+ var size = shape.reduce(function (acc, c) { return acc * c; }) * (isComplex ? 2 : 1);
+ if (size === 0) {
+ // A tensor with shape zero should be turned into empty list.
+ return [];
+ }
+ if (size !== a.length) {
+ throw new Error("[" + shape + "] does not match the input size " + a.length + (isComplex ? ' for a complex tensor' : '') + ".");
+ }
+ return createNestedArray(0, shape, a, isComplex);
+ }
+ function makeOnesTypedArray(size, dtype) {
+ var array = makeZerosTypedArray(size, dtype);
+ for (var i = 0; i < array.length; i++) {
+ array[i] = 1;
+ }
+ return array;
+ }
+ function makeZerosTypedArray(size, dtype) {
+ if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
+ return new Float32Array(size);
+ }
+ else if (dtype === 'int32') {
+ return new Int32Array(size);
+ }
+ else if (dtype === 'bool') {
+ return new Uint8Array(size);
+ }
+ else {
+ throw new Error("Unknown data type " + dtype);
+ }
+ }
+ /**
+ * Make nested `TypedArray` filled with zeros.
+ * @param shape The shape information for the nested array.
+ * @param dtype dtype of the array element.
+ */
+ function makeZerosNestedTypedArray(shape, dtype) {
+ var size = shape.reduce(function (prev, curr) { return prev * curr; }, 1);
+ if (dtype == null || dtype === 'float32') {
+ return toNestedArray(shape, new Float32Array(size));
+ }
+ else if (dtype === 'int32') {
+ return toNestedArray(shape, new Int32Array(size));
+ }
+ else if (dtype === 'bool') {
+ return toNestedArray(shape, new Uint8Array(size));
+ }
+ else {
+ throw new Error("Unknown data type " + dtype);
+ }
+ }
+ function assertNonNegativeIntegerDimensions(shape) {
+ shape.forEach(function (dimSize) {
+ assert(Number.isInteger(dimSize) && dimSize >= 0, function () { return "Tensor must have a shape comprised of positive integers but got " +
+ ("shape [" + shape + "]."); });
+ });
+ }
+ /**
+ * Computes flat index for a given location (multidimentionsal index) in a
+ * Tensor/multidimensional array.
+ *
+ * @param locs Location in the tensor.
+ * @param rank Rank of the tensor.
+ * @param strides Tensor strides.
+ */
+ function locToIndex(locs, rank, strides) {
+ if (rank === 0) {
+ return 0;
+ }
+ else if (rank === 1) {
+ return locs[0];
+ }
+ var index = locs[locs.length - 1];
+ for (var i = 0; i < locs.length - 1; ++i) {
+ index += strides[i] * locs[i];
+ }
+ return index;
+ }
+ /**
+ * Computes the location (multidimensional index) in a tensor/multidimentional
+ * array for a given flat index.
+ *
+ * @param index Index in flat array.
+ * @param rank Rank of tensor.
+ * @param strides Strides of tensor.
+ */
+ function indexToLoc(index, rank, strides) {
+ if (rank === 0) {
+ return [];
+ }
+ else if (rank === 1) {
+ return [index];
+ }
+ var locs = new Array(rank);
+ for (var i = 0; i < locs.length - 1; ++i) {
+ locs[i] = Math.floor(index / strides[i]);
+ index -= locs[i] * strides[i];
+ }
+ locs[locs.length - 1] = index;
+ return locs;
+ }
+ /**
+ * This method asserts whether an object is a Promise instance.
+ * @param object
+ */
+ // tslint:disable-next-line: no-any
+ function isPromise(object) {
+ // We chose to not use 'obj instanceOf Promise' for two reasons:
+ // 1. It only reliably works for es6 Promise, not other Promise
+ // implementations.
+ // 2. It doesn't work with framework that uses zone.js. zone.js monkey patch
+ // the async calls, so it is possible the obj (patched) is comparing to a
+ // pre-patched Promise.
+ return object && object.then && typeof object.then === 'function';
+ }
+
+ // Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.
+ var TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
+ /**
+ * The environment contains evaluated flags as well as the registered platform.
+ * This is always used as a global singleton and can be retrieved with
+ * `tf.env()`.
+ *
+ * @doc {heading: 'Environment'}
+ */
+ var Environment = /** @class */ (function () {
+ // tslint:disable-next-line: no-any
+ function Environment(global) {
+ this.global = global;
+ this.flags = {};
+ this.flagRegistry = {};
+ this.urlFlags = {};
+ // Jasmine spies on this in 'environment_test.ts'
+ this.getQueryParams = getQueryParams;
+ this.populateURLFlags();
+ }
+ Environment.prototype.setPlatform = function (platformName, platform) {
+ if (this.platform != null) {
+ if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
+ console.warn("Platform " + this.platformName + " has already been set. " +
+ ("Overwriting the platform with " + platform + "."));
+ }
+ }
+ this.platformName = platformName;
+ this.platform = platform;
+ };
+ Environment.prototype.registerFlag = function (flagName, evaluationFn, setHook) {
+ this.flagRegistry[flagName] = { evaluationFn: evaluationFn, setHook: setHook };
+ // Override the flag value from the URL. This has to happen here because
+ // the environment is initialized before flags get registered.
+ if (this.urlFlags[flagName] != null) {
+ var flagValue = this.urlFlags[flagName];
+ if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
+ console.warn("Setting feature override from URL " + flagName + ": " + flagValue + ".");
+ }
+ this.set(flagName, flagValue);
+ }
+ };
+ Environment.prototype.getAsync = function (flagName) {
+ return __awaiter(this, void 0, void 0, function () {
+ var _a, _b;
+ return __generator(this, function (_c) {
+ switch (_c.label) {
+ case 0:
+ if (flagName in this.flags) {
+ return [2 /*return*/, this.flags[flagName]];
+ }
+ _a = this.flags;
+ _b = flagName;
+ return [4 /*yield*/, this.evaluateFlag(flagName)];
+ case 1:
+ _a[_b] = _c.sent();
+ return [2 /*return*/, this.flags[flagName]];
+ }
+ });
+ });
+ };
+ Environment.prototype.get = function (flagName) {
+ if (flagName in this.flags) {
+ return this.flags[flagName];
+ }
+ var flagValue = this.evaluateFlag(flagName);
+ if (isPromise(flagValue)) {
+ throw new Error("Flag " + flagName + " cannot be synchronously evaluated. " +
+ "Please use getAsync() instead.");
+ }
+ this.flags[flagName] = flagValue;
+ return this.flags[flagName];
+ };
+ Environment.prototype.getNumber = function (flagName) {
+ return this.get(flagName);
+ };
+ Environment.prototype.getBool = function (flagName) {
+ return this.get(flagName);
+ };
+ Environment.prototype.getFlags = function () {
+ return this.flags;
+ };
+ Object.defineProperty(Environment.prototype, "features", {
+ // For backwards compatibility.
+ get: function () {
+ return this.flags;
+ },
+ enumerable: true,
+ configurable: true
+ });
+ Environment.prototype.set = function (flagName, value) {
+ if (this.flagRegistry[flagName] == null) {
+ throw new Error("Cannot set flag " + flagName + " as it has not been registered.");
+ }
+ this.flags[flagName] = value;
+ if (this.flagRegistry[flagName].setHook != null) {
+ this.flagRegistry[flagName].setHook(value);
+ }
+ };
+ Environment.prototype.evaluateFlag = function (flagName) {
+ if (this.flagRegistry[flagName] == null) {
+ throw new Error("Cannot evaluate flag '" + flagName + "': no evaluation function found.");
+ }
+ return this.flagRegistry[flagName].evaluationFn();
+ };
+ Environment.prototype.setFlags = function (flags) {
+ this.flags = Object.assign({}, flags);
+ };
+ Environment.prototype.reset = function () {
+ this.flags = {};
+ this.urlFlags = {};
+ this.populateURLFlags();
+ };
+ Environment.prototype.populateURLFlags = function () {
+ var _this = this;
+ if (typeof this.global === 'undefined' ||
+ typeof this.global.location === 'undefined' ||
+ typeof this.global.location.search === 'undefined') {
+ return;
+ }
+ var urlParams = this.getQueryParams(this.global.location.search);
+ if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
+ var keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
+ keyValues.forEach(function (keyValue) {
+ var _a = __read(keyValue.split(':'), 2), key = _a[0], value = _a[1];
+ _this.urlFlags[key] = parseValue(key, value);
+ });
+ }
+ };
+ return Environment;
+ }());
+ function getQueryParams(queryString) {
+ var params = {};
+ queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function (s) {
+ var t = [];
+ for (var _i = 1; _i < arguments.length; _i++) {
+ t[_i - 1] = arguments[_i];
+ }
+ decodeParam(params, t[0], t[1]);
+ return t.join('=');
+ });
+ return params;
+ }
+ function decodeParam(params, name, value) {
+ params[decodeURIComponent(name)] = decodeURIComponent(value || '');
+ }
+ function parseValue(flagName, value) {
+ value = value.toLowerCase();
+ if (value === 'true' || value === 'false') {
+ return value === 'true';
+ }
+ else if ("" + +value === value) {
+ return +value;
+ }
+ throw new Error("Could not parse value flag value " + value + " for flag " + flagName + ".");
+ }
+ /**
+ * Returns the current environment (a global singleton).
+ *
+ * The environment object contains the evaluated feature values as well as the
+ * active platform.
+ *
+ * @doc {heading: 'Environment'}
+ */
+ function env() {
+ return exports.ENV;
+ }
+ exports.ENV = null;
+ function setEnvironmentGlobal(environment) {
+ exports.ENV = environment;
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ // Note that the identifier globalNameSpace is scoped to this module, but will
+ // always resolve to the same global object regardless of how the module is
+ // resolved.
+ // tslint:disable-next-line:no-any
+ var globalNameSpace;
+ // tslint:disable-next-line:no-any
+ function getGlobalNamespace() {
+ if (globalNameSpace == null) {
+ // tslint:disable-next-line:no-any
+ var ns = void 0;
+ if (typeof (window) !== 'undefined') {
+ ns = window;
+ }
+ else if (typeof (global) !== 'undefined') {
+ ns = global;
+ }
+ else if (typeof (process) !== 'undefined') {
+ ns = process;
+ }
+ else if (typeof (self) !== 'undefined') {
+ ns = self;
+ }
+ else {
+ throw new Error('Could not find a global object');
+ }
+ globalNameSpace = ns;
+ }
+ return globalNameSpace;
+ }
+ // tslint:disable-next-line:no-any
+ function getGlobalMap() {
+ var ns = getGlobalNamespace();
+ if (ns._tfGlobals == null) {
+ ns._tfGlobals = new Map();
+ }
+ return ns._tfGlobals;
+ }
+ /**
+ * Returns a globally accessible 'singleton' object.
+ *
+ * @param key the name of the object
+ * @param init a function to initialize to initialize this object
+ * the first time it is fetched.
+ */
+ function getGlobal(key, init) {
+ var globalMap = getGlobalMap();
+ if (globalMap.has(key)) {
+ return globalMap.get(key);
+ }
+ else {
+ var singleton = init();
+ globalMap.set(key, singleton);
+ return globalMap.get(key);
+ }
+ }
+
+ var Abs = 'Abs';
+ var Acos = 'Acos';
+ var Acosh = 'Acosh';
+ var Add = 'Add';
+ var AddN = 'AddN';
+ var All = 'All';
+ var Any = 'Any';
+ var ArgMax = 'ArgMax';
+ var ArgMin = 'ArgMin';
+ var Asin = 'Asin';
+ var Asinh = 'Asinh';
+ var Atan = 'Atan';
+ var Atanh = 'Atanh';
+ var Atan2 = 'Atan2';
+ var AvgPool = 'AvgPool';
+ var AvgPoolGrad = 'AvgPoolGrad';
+ var AvgPool3D = 'AvgPool3D';
+ var AvgPool3DGrad = 'AvgPool3DGrad';
+ var BatchMatMul = 'BatchMatMul';
+ var BatchToSpaceND = 'BatchToSpaceND';
+ var Bincount = 'Bincount';
+ var BroadcastTo = 'BroadcastTo';
+ var BroadcastArgs = 'BroadcastArgs';
+ var Cast = 'Cast';
+ var Ceil = 'Ceil';
+ var ClipByValue = 'ClipByValue';
+ var Complex = 'Complex';
+ var ComplexAbs = 'ComplexAbs';
+ var Concat = 'Concat';
+ var Conv2D = 'Conv2D';
+ var Conv2DBackpropFilter = 'Conv2DBackpropFilter';
+ var Conv2DBackpropInput = 'Conv2DBackpropInput';
+ var Conv3D = 'Conv3D';
+ var Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2';
+ var Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2';
+ var Cos = 'Cos';
+ var Cosh = 'Cosh';
+ var Cumsum = 'Cumsum';
+ var CropAndResize = 'CropAndResize';
+ var DenseBincount = 'DenseBincount';
+ var DepthToSpace = 'DepthToSpace';
+ var DepthwiseConv2dNative = 'DepthwiseConv2dNative';
+ var DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter';
+ var DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput';
+ var Diag = 'Diag';
+ var Dilation2D = 'Dilation2D';
+ var Dilation2DBackpropInput = 'Dilation2DBackpropInput';
+ var Dilation2DBackpropFilter = 'Dilation2DBackpropFilter';
+ var RealDiv = 'RealDiv';
+ var Einsum = 'Einsum';
+ var Elu = 'Elu';
+ var EluGrad = 'EluGrad';
+ var Erf = 'Erf';
+ var Equal = 'Equal';
+ var Exp = 'Exp';
+ var ExpandDims = 'ExpandDims';
+ var Expm1 = 'Expm1';
+ var FFT = 'FFT';
+ var Fill = 'Fill';
+ var FlipLeftRight = 'FlipLeftRight';
+ var Floor = 'Floor';
+ var FloorDiv = 'FloorDiv';
+ var FusedBatchNorm = 'FusedBatchNorm';
+ var GatherV2 = 'GatherV2';
+ var GatherNd = 'GatherNd';
+ var Greater = 'Greater';
+ var GreaterEqual = 'GreaterEqual';
+ var Identity = 'Identity';
+ var IFFT = 'IFFT';
+ var Imag = 'Imag';
+ var IsFinite = 'IsFinite';
+ var IsInf = 'IsInf';
+ var IsNan = 'IsNan';
+ var LeakyRelu = 'LeakyRelu';
+ var Less = 'Less';
+ var LessEqual = 'LessEqual';
+ var LinSpace = 'LinSpace';
+ var Log = 'Log';
+ var Log1p = 'Log1p';
+ var LogicalAnd = 'LogicalAnd';
+ var LogicalNot = 'LogicalNot';
+ var LogicalOr = 'LogicalOr';
+ var LogSoftmax = 'LogSoftmax';
+ var LRN = 'LRN';
+ var LRNGrad = 'LRNGrad';
+ var Max = 'Max';
+ var Maximum = 'Maximum';
+ var MaxPool = 'MaxPool';
+ var MaxPoolGrad = 'MaxPoolGrad';
+ var MaxPool3D = 'MaxPool3D';
+ var MaxPool3DGrad = 'MaxPool3DGrad';
+ var MaxPoolWithArgmax = 'MaxPoolWithArgmax';
+ var Mean = 'Mean';
+ var Min = 'Min';
+ var Minimum = 'Minimum';
+ var MirrorPad = 'MirrorPad';
+ var Mod = 'Mod';
+ var Multinomial = 'Multinomial';
+ var Multiply = 'Multiply';
+ var Neg = 'Neg';
+ var NotEqual = 'NotEqual';
+ var NonMaxSuppressionV3 = 'NonMaxSuppressionV3';
+ var NonMaxSuppressionV4 = 'NonMaxSuppressionV4';
+ var NonMaxSuppressionV5 = 'NonMaxSuppressionV5';
+ var OnesLike = 'OnesLike';
+ var OneHot = 'OneHot';
+ var Pack = 'Pack';
+ var PadV2 = 'PadV2';
+ var Pool = 'Pool';
+ var Pow = 'Pow';
+ var Prelu = 'Prelu';
+ var Prod = 'Prod';
+ var Range = 'Range';
+ var Real = 'Real';
+ var Reciprocal = 'Reciprocal';
+ var Relu = 'Relu';
+ var Reshape = 'Reshape';
+ var ResizeNearestNeighbor = 'ResizeNearestNeighbor';
+ var ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad';
+ var ResizeBilinear = 'ResizeBilinear';
+ var ResizeBilinearGrad = 'ResizeBilinearGrad';
+ var Relu6 = 'Relu6';
+ var Reverse = 'Reverse';
+ var Round = 'Round';
+ var Rsqrt = 'Rsqrt';
+ var ScatterNd = 'ScatterNd';
+ var Select = 'Select';
+ var Selu = 'Selu';
+ var Slice = 'Slice';
+ var Sin = 'Sin';
+ var Sinh = 'Sinh';
+ var Sign = 'Sign';
+ var Sigmoid = 'Sigmoid';
+ var Softplus = 'Softplus';
+ var Sqrt = 'Sqrt';
+ var Sum = 'Sum';
+ var SpaceToBatchND = 'SpaceToBatchND';
+ var SplitV = 'SplitV';
+ var Softmax = 'Softmax';
+ var SparseFillEmptyRows = 'SparseFillEmptyRows';
+ var SparseReshape = 'SparseReshape';
+ var SparseSegmentMean = 'SparseSegmentMean';
+ var SparseSegmentSum = 'SparseSegmentSum';
+ var SparseToDense = 'SparseToDense';
+ var SquaredDifference = 'SquaredDifference';
+ var Square = 'Square';
+ var StridedSlice = 'StridedSlice';
+ var StringNGrams = 'StringNGrams';
+ var StringSplit = 'StringSplit';
+ var StringToHashBucketFast = 'StringToHashBucketFast';
+ var Sub = 'Sub';
+ var Tan = 'Tan';
+ var Tanh = 'Tanh';
+ var Tile = 'Tile';
+ var TopK = 'TopK';
+ var Transform = 'Transform';
+ var Transpose = 'Transpose';
+ var Unique = 'Unique';
+ var Unpack = 'Unpack';
+ var UnsortedSegmentSum = 'UnsortedSegmentSum';
+ var ZerosLike = 'ZerosLike';
+ /**
+ * TensorFlow.js-only kernels
+ */
+ var Step = 'Step';
+ var FromPixels = 'FromPixels';
+ var RotateWithOffset = 'RotateWithOffset';
+ var _FusedMatMul = '_FusedMatMul';
+ var FusedConv2D = 'FusedConv2D';
+ var FusedDepthwiseConv2D = 'FusedDepthwiseConv2D';
+
+ function warn() {
+ var msg = [];
+ for (var _i = 0; _i < arguments.length; _i++) {
+ msg[_i] = arguments[_i];
+ }
+ if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
+ console.warn.apply(console, __spread(msg));
+ }
+ }
+ function log$1() {
+ var msg = [];
+ for (var _i = 0; _i < arguments.length; _i++) {
+ msg[_i] = arguments[_i];
+ }
+ if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
+ console.log.apply(console, __spread(msg));
+ }
+ }
+
+ var kernelRegistry = getGlobal('kernelRegistry', function () { return new Map(); });
+ var gradRegistry = getGlobal('gradRegistry', function () { return new Map(); });
+ /**
+ * Returns the kernel function (code) associated with the provided names.
+ *
+ * @param kernelName The official name of the kernel.
+ * @param backendName The official name of the backend.
+ */
+ function getKernel(kernelName, backendName) {
+ var key = makeKey(kernelName, backendName);
+ return kernelRegistry.get(key);
+ }
+ /**
+ * Returns the registered gradient info associated with the provided kernel.
+ * @param kernelName The official TF kernel name.
+ */
+ function getGradient(kernelName) {
+ return gradRegistry.get(kernelName);
+ }
+ function getKernelsForBackend(backendName) {
+ var it = kernelRegistry.entries();
+ var result = [];
+ while (true) {
+ var _a = it.next(), done = _a.done, value = _a.value;
+ if (done) {
+ break;
+ }
+ var _b = __read(value, 2), key = _b[0], config = _b[1];
+ var _c = __read(key.split('_'), 1), backend = _c[0];
+ if (backend === backendName) {
+ result.push(config);
+ }
+ }
+ return result;
+ }
+ /**
+ * Registers the function (forward pass) for the kernel in a global registry.
+ *
+ * @param config A config object with the following properties:
+ * - `kernelName` The official name of the kernel.
+ * - `backendName` The official name of the backend.
+ * - `kernelFunc` The function to run during the forward pass of the kernel.
+ * - `setupFunc` Optional. Gets called once, after the backend initializes.
+ * - `disposeFunc` Optional. Gets called once, right before the backend is
+ * disposed.
+ */
+ function registerKernel(config) {
+ var kernelName = config.kernelName, backendName = config.backendName;
+ var key = makeKey(kernelName, backendName);
+ if (kernelRegistry.has(key)) {
+ warn("The kernel '" + kernelName + "' for backend " +
+ ("'" + backendName + "' is already registered"));
+ }
+ kernelRegistry.set(key, config);
+ }
+ /**
+ * Registers a gradient function for a given kernel in the global registry,
+ * to be used during the back-propagation of that kernel.
+ *
+ * @param config An object with the following properties:
+ * - `kernelName` The name of the kernel that the gradient function is for.
+ * - `gradFunc` The function to run during back-propagation.
+ */
+ function registerGradient(config) {
+ var kernelName = config.kernelName;
+ if (gradRegistry.has(kernelName)) {
+ // TODO (yassogba) after 3.0 assess whether we need to keep this gated
+ // to debug mode.
+ if (env().getBool('DEBUG')) {
+ warn("Overriding the gradient for '" + kernelName + "'");
+ }
+ }
+ gradRegistry.set(kernelName, config);
+ }
+ /**
+ * Removes the kernel function from the registry.
+ *
+ * @param kernelName The official name of the kernel.
+ * @param backendName The official name of the backend.
+ *
+ */
+ function unregisterKernel(kernelName, backendName) {
+ var key = makeKey(kernelName, backendName);
+ if (!kernelRegistry.has(key)) {
+ throw new Error("The kernel '" + kernelName + "' for backend " +
+ ("'" + backendName + "' is not registered"));
+ }
+ kernelRegistry.delete(key);
+ }
+ /** Removes the registered gradient from the global registry. */
+ function unregisterGradient(kernelName) {
+ if (!gradRegistry.has(kernelName)) {
+ throw new Error("The gradient '" + kernelName + "' for backend is not registered");
+ }
+ gradRegistry.delete(kernelName);
+ }
+ /**
+ * Finds kernels that have already been registered to a backend and re-registers
+ * them for a new backend. Useful for registering custom backends.
+ * @param registeredBackendName Already registered backend.
+ * @param newBackendName New backend.
+ */
+ function copyRegisteredKernels(registeredBackendName, newBackendName) {
+ var kernels = getKernelsForBackend(registeredBackendName);
+ kernels.forEach(function (kernelConfig) {
+ var newKernelConfig = Object.assign({}, kernelConfig, { backendName: newBackendName });
+ registerKernel(newKernelConfig);
+ });
+ }
+ function makeKey(kernelName, backendName) {
+ return backendName + "_" + kernelName;
+ }
+
+ var long = Long$1;
+ /**
+ * wasm optimizations, to do native i64 multiplication and divide
+ */
+ var wasm = null;
+ try {
+ wasm = new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([
+ 0, 97, 115, 109, 1, 0, 0, 0, 1, 13, 2, 96, 0, 1, 127, 96, 4, 127, 127, 127, 127, 1, 127, 3, 7, 6, 0, 1, 1, 1, 1, 1, 6, 6, 1, 127, 1, 65, 0, 11, 7, 50, 6, 3, 109, 117, 108, 0, 1, 5, 100, 105, 118, 95, 115, 0, 2, 5, 100, 105, 118, 95, 117, 0, 3, 5, 114, 101, 109, 95, 115, 0, 4, 5, 114, 101, 109, 95, 117, 0, 5, 8, 103, 101, 116, 95, 104, 105, 103, 104, 0, 0, 10, 191, 1, 6, 4, 0, 35, 0, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 126, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 127, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 128, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 129, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 130, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11
+ ])), {}).exports;
+ }
+ catch (e) {
+ // no wasm support :(
+ }
+ /**
+ * Constructs a 64 bit two's-complement integer, given its low and high 32 bit values as *signed* integers.
+ * See the from* functions below for more convenient ways of constructing Longs.
+ * @exports Long
+ * @class A Long class for representing a 64 bit two's-complement integer value.
+ * @param {number} low The low (signed) 32 bits of the long
+ * @param {number} high The high (signed) 32 bits of the long
+ * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
+ * @constructor
+ */
+ function Long$1(low, high, unsigned) {
+ /**
+ * The low 32 bits as a signed value.
+ * @type {number}
+ */
+ this.low = low | 0;
+ /**
+ * The high 32 bits as a signed value.
+ * @type {number}
+ */
+ this.high = high | 0;
+ /**
+ * Whether unsigned or not.
+ * @type {boolean}
+ */
+ this.unsigned = !!unsigned;
+ }
+ // The internal representation of a long is the two given signed, 32-bit values.
+ // We use 32-bit pieces because these are the size of integers on which
+ // Javascript performs bit-operations. For operations like addition and
+ // multiplication, we split each number into 16 bit pieces, which can easily be
+ // multiplied within Javascript's floating-point representation without overflow
+ // or change in sign.
+ //
+ // In the algorithms below, we frequently reduce the negative case to the
+ // positive case by negating the input(s) and then post-processing the result.
+ // Note that we must ALWAYS check specially whether those values are MIN_VALUE
+ // (-2^63) because -MIN_VALUE == MIN_VALUE (since 2^63 cannot be represented as
+ // a positive number, it overflows back into a negative). Not handling this
+ // case would often result in infinite recursion.
+ //
+ // Common constant values ZERO, ONE, NEG_ONE, etc. are defined below the from*
+ // methods on which they depend.
+ /**
+ * An indicator used to reliably determine if an object is a Long or not.
+ * @type {boolean}
+ * @const
+ * @private
+ */
+ Long$1.prototype.__isLong__;
+ Object.defineProperty(Long$1.prototype, "__isLong__", { value: true });
+ /**
+ * @function
+ * @param {*} obj Object
+ * @returns {boolean}
+ * @inner
+ */
+ function isLong(obj) {
+ return (obj && obj["__isLong__"]) === true;
+ }
+ /**
+ * Tests if the specified object is a Long.
+ * @function
+ * @param {*} obj Object
+ * @returns {boolean}
+ */
+ Long$1.isLong = isLong;
+ /**
+ * A cache of the Long representations of small integer values.
+ * @type {!Object}
+ * @inner
+ */
+ var INT_CACHE = {};
+ /**
+ * A cache of the Long representations of small unsigned integer values.
+ * @type {!Object}
+ * @inner
+ */
+ var UINT_CACHE = {};
+ /**
+ * @param {number} value
+ * @param {boolean=} unsigned
+ * @returns {!Long}
+ * @inner
+ */
+ function fromInt(value, unsigned) {
+ var obj, cachedObj, cache;
+ if (unsigned) {
+ value >>>= 0;
+ if (cache = (0 <= value && value < 256)) {
+ cachedObj = UINT_CACHE[value];
+ if (cachedObj)
+ return cachedObj;
+ }
+ obj = fromBits(value, (value | 0) < 0 ? -1 : 0, true);
+ if (cache)
+ UINT_CACHE[value] = obj;
+ return obj;
+ }
+ else {
+ value |= 0;
+ if (cache = (-128 <= value && value < 128)) {
+ cachedObj = INT_CACHE[value];
+ if (cachedObj)
+ return cachedObj;
+ }
+ obj = fromBits(value, value < 0 ? -1 : 0, false);
+ if (cache)
+ INT_CACHE[value] = obj;
+ return obj;
+ }
+ }
+ /**
+ * Returns a Long representing the given 32 bit integer value.
+ * @function
+ * @param {number} value The 32 bit integer in question
+ * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
+ * @returns {!Long} The corresponding Long value
+ */
+ Long$1.fromInt = fromInt;
+ /**
+ * @param {number} value
+ * @param {boolean=} unsigned
+ * @returns {!Long}
+ * @inner
+ */
+ function fromNumber(value, unsigned) {
+ if (isNaN(value))
+ return unsigned ? UZERO : ZERO;
+ if (unsigned) {
+ if (value < 0)
+ return UZERO;
+ if (value >= TWO_PWR_64_DBL)
+ return MAX_UNSIGNED_VALUE;
+ }
+ else {
+ if (value <= -TWO_PWR_63_DBL)
+ return MIN_VALUE;
+ if (value + 1 >= TWO_PWR_63_DBL)
+ return MAX_VALUE;
+ }
+ if (value < 0)
+ return fromNumber(-value, unsigned).neg();
+ return fromBits((value % TWO_PWR_32_DBL) | 0, (value / TWO_PWR_32_DBL) | 0, unsigned);
+ }
+ /**
+ * Returns a Long representing the given value, provided that it is a finite number. Otherwise, zero is returned.
+ * @function
+ * @param {number} value The number in question
+ * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
+ * @returns {!Long} The corresponding Long value
+ */
+ Long$1.fromNumber = fromNumber;
+ /**
+ * @param {number} lowBits
+ * @param {number} highBits
+ * @param {boolean=} unsigned
+ * @returns {!Long}
+ * @inner
+ */
+ function fromBits(lowBits, highBits, unsigned) {
+ return new Long$1(lowBits, highBits, unsigned);
+ }
+ /**
+ * Returns a Long representing the 64 bit integer that comes by concatenating the given low and high bits. Each is
+ * assumed to use 32 bits.
+ * @function
+ * @param {number} lowBits The low 32 bits
+ * @param {number} highBits The high 32 bits
+ * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
+ * @returns {!Long} The corresponding Long value
+ */
+ Long$1.fromBits = fromBits;
+ /**
+ * @function
+ * @param {number} base
+ * @param {number} exponent
+ * @returns {number}
+ * @inner
+ */
+ var pow_dbl = Math.pow; // Used 4 times (4*8 to 15+4)
+ /**
+ * @param {string} str
+ * @param {(boolean|number)=} unsigned
+ * @param {number=} radix
+ * @returns {!Long}
+ * @inner
+ */
+ function fromString(str, unsigned, radix) {
+ if (str.length === 0)
+ throw Error('empty string');
+ if (str === "NaN" || str === "Infinity" || str === "+Infinity" || str === "-Infinity")
+ return ZERO;
+ if (typeof unsigned === 'number') {
+ // For goog.math.long compatibility
+ radix = unsigned,
+ unsigned = false;
+ }
+ else {
+ unsigned = !!unsigned;
+ }
+ radix = radix || 10;
+ if (radix < 2 || 36 < radix)
+ throw RangeError('radix');
+ var p;
+ if ((p = str.indexOf('-')) > 0)
+ throw Error('interior hyphen');
+ else if (p === 0) {
+ return fromString(str.substring(1), unsigned, radix).neg();
+ }
+ // Do several (8) digits each time through the loop, so as to
+ // minimize the calls to the very expensive emulated div.
+ var radixToPower = fromNumber(pow_dbl(radix, 8));
+ var result = ZERO;
+ for (var i = 0; i < str.length; i += 8) {
+ var size = Math.min(8, str.length - i), value = parseInt(str.substring(i, i + size), radix);
+ if (size < 8) {
+ var power = fromNumber(pow_dbl(radix, size));
+ result = result.mul(power).add(fromNumber(value));
+ }
+ else {
+ result = result.mul(radixToPower);
+ result = result.add(fromNumber(value));
+ }
+ }
+ result.unsigned = unsigned;
+ return result;
+ }
+ /**
+ * Returns a Long representation of the given string, written using the specified radix.
+ * @function
+ * @param {string} str The textual representation of the Long
+ * @param {(boolean|number)=} unsigned Whether unsigned or not, defaults to signed
+ * @param {number=} radix The radix in which the text is written (2-36), defaults to 10
+ * @returns {!Long} The corresponding Long value
+ */
+ Long$1.fromString = fromString;
+ /**
+ * @function
+ * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val
+ * @param {boolean=} unsigned
+ * @returns {!Long}
+ * @inner
+ */
+ function fromValue(val, unsigned) {
+ if (typeof val === 'number')
+ return fromNumber(val, unsigned);
+ if (typeof val === 'string')
+ return fromString(val, unsigned);
+ // Throws for non-objects, converts non-instanceof Long:
+ return fromBits(val.low, val.high, typeof unsigned === 'boolean' ? unsigned : val.unsigned);
+ }
+ /**
+ * Converts the specified value to a Long using the appropriate from* function for its type.
+ * @function
+ * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val Value
+ * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
+ * @returns {!Long}
+ */
+ Long$1.fromValue = fromValue;
+ // NOTE: the compiler should inline these constant values below and then remove these variables, so there should be
+ // no runtime penalty for these.
+ /**
+ * @type {number}
+ * @const
+ * @inner
+ */
+ var TWO_PWR_16_DBL = 1 << 16;
+ /**
+ * @type {number}
+ * @const
+ * @inner
+ */
+ var TWO_PWR_24_DBL = 1 << 24;
+ /**
+ * @type {number}
+ * @const
+ * @inner
+ */
+ var TWO_PWR_32_DBL = TWO_PWR_16_DBL * TWO_PWR_16_DBL;
+ /**
+ * @type {number}
+ * @const
+ * @inner
+ */
+ var TWO_PWR_64_DBL = TWO_PWR_32_DBL * TWO_PWR_32_DBL;
+ /**
+ * @type {number}
+ * @const
+ * @inner
+ */
+ var TWO_PWR_63_DBL = TWO_PWR_64_DBL / 2;
+ /**
+ * @type {!Long}
+ * @const
+ * @inner
+ */
+ var TWO_PWR_24 = fromInt(TWO_PWR_24_DBL);
+ /**
+ * @type {!Long}
+ * @inner
+ */
+ var ZERO = fromInt(0);
+ /**
+ * Signed zero.
+ * @type {!Long}
+ */
+ Long$1.ZERO = ZERO;
+ /**
+ * @type {!Long}
+ * @inner
+ */
+ var UZERO = fromInt(0, true);
+ /**
+ * Unsigned zero.
+ * @type {!Long}
+ */
+ Long$1.UZERO = UZERO;
+ /**
+ * @type {!Long}
+ * @inner
+ */
+ var ONE = fromInt(1);
+ /**
+ * Signed one.
+ * @type {!Long}
+ */
+ Long$1.ONE = ONE;
+ /**
+ * @type {!Long}
+ * @inner
+ */
+ var UONE = fromInt(1, true);
+ /**
+ * Unsigned one.
+ * @type {!Long}
+ */
+ Long$1.UONE = UONE;
+ /**
+ * @type {!Long}
+ * @inner
+ */
+ var NEG_ONE = fromInt(-1);
+ /**
+ * Signed negative one.
+ * @type {!Long}
+ */
+ Long$1.NEG_ONE = NEG_ONE;
+ /**
+ * @type {!Long}
+ * @inner
+ */
+ var MAX_VALUE = fromBits(0xFFFFFFFF | 0, 0x7FFFFFFF | 0, false);
+ /**
+ * Maximum signed value.
+ * @type {!Long}
+ */
+ Long$1.MAX_VALUE = MAX_VALUE;
+ /**
+ * @type {!Long}
+ * @inner
+ */
+ var MAX_UNSIGNED_VALUE = fromBits(0xFFFFFFFF | 0, 0xFFFFFFFF | 0, true);
+ /**
+ * Maximum unsigned value.
+ * @type {!Long}
+ */
+ Long$1.MAX_UNSIGNED_VALUE = MAX_UNSIGNED_VALUE;
+ /**
+ * @type {!Long}
+ * @inner
+ */
+ var MIN_VALUE = fromBits(0, 0x80000000 | 0, false);
+ /**
+ * Minimum signed value.
+ * @type {!Long}
+ */
+ Long$1.MIN_VALUE = MIN_VALUE;
+ /**
+ * @alias Long.prototype
+ * @inner
+ */
+ var LongPrototype = Long$1.prototype;
+ /**
+ * Converts the Long to a 32 bit integer, assuming it is a 32 bit integer.
+ * @returns {number}
+ */
+ LongPrototype.toInt = function toInt() {
+ return this.unsigned ? this.low >>> 0 : this.low;
+ };
+ /**
+ * Converts the Long to a the nearest floating-point representation of this value (double, 53 bit mantissa).
+ * @returns {number}
+ */
+ LongPrototype.toNumber = function toNumber() {
+ if (this.unsigned)
+ return ((this.high >>> 0) * TWO_PWR_32_DBL) + (this.low >>> 0);
+ return this.high * TWO_PWR_32_DBL + (this.low >>> 0);
+ };
+ /**
+ * Converts the Long to a string written in the specified radix.
+ * @param {number=} radix Radix (2-36), defaults to 10
+ * @returns {string}
+ * @override
+ * @throws {RangeError} If `radix` is out of range
+ */
+ LongPrototype.toString = function toString(radix) {
+ radix = radix || 10;
+ if (radix < 2 || 36 < radix)
+ throw RangeError('radix');
+ if (this.isZero())
+ return '0';
+ if (this.isNegative()) { // Unsigned Longs are never negative
+ if (this.eq(MIN_VALUE)) {
+ // We need to change the Long value before it can be negated, so we remove
+ // the bottom-most digit in this base and then recurse to do the rest.
+ var radixLong = fromNumber(radix), div = this.div(radixLong), rem1 = div.mul(radixLong).sub(this);
+ return div.toString(radix) + rem1.toInt().toString(radix);
+ }
+ else
+ return '-' + this.neg().toString(radix);
+ }
+ // Do several (6) digits each time through the loop, so as to
+ // minimize the calls to the very expensive emulated div.
+ var radixToPower = fromNumber(pow_dbl(radix, 6), this.unsigned), rem = this;
+ var result = '';
+ while (true) {
+ var remDiv = rem.div(radixToPower), intval = rem.sub(remDiv.mul(radixToPower)).toInt() >>> 0, digits = intval.toString(radix);
+ rem = remDiv;
+ if (rem.isZero())
+ return digits + result;
+ else {
+ while (digits.length < 6)
+ digits = '0' + digits;
+ result = '' + digits + result;
+ }
+ }
+ };
+ /**
+ * Gets the high 32 bits as a signed integer.
+ * @returns {number} Signed high bits
+ */
+ LongPrototype.getHighBits = function getHighBits() {
+ return this.high;
+ };
+ /**
+ * Gets the high 32 bits as an unsigned integer.
+ * @returns {number} Unsigned high bits
+ */
+ LongPrototype.getHighBitsUnsigned = function getHighBitsUnsigned() {
+ return this.high >>> 0;
+ };
+ /**
+ * Gets the low 32 bits as a signed integer.
+ * @returns {number} Signed low bits
+ */
+ LongPrototype.getLowBits = function getLowBits() {
+ return this.low;
+ };
+ /**
+ * Gets the low 32 bits as an unsigned integer.
+ * @returns {number} Unsigned low bits
+ */
+ LongPrototype.getLowBitsUnsigned = function getLowBitsUnsigned() {
+ return this.low >>> 0;
+ };
+ /**
+ * Gets the number of bits needed to represent the absolute value of this Long.
+ * @returns {number}
+ */
+ LongPrototype.getNumBitsAbs = function getNumBitsAbs() {
+ if (this.isNegative()) // Unsigned Longs are never negative
+ return this.eq(MIN_VALUE) ? 64 : this.neg().getNumBitsAbs();
+ var val = this.high != 0 ? this.high : this.low;
+ for (var bit = 31; bit > 0; bit--)
+ if ((val & (1 << bit)) != 0)
+ break;
+ return this.high != 0 ? bit + 33 : bit + 1;
+ };
+ /**
+ * Tests if this Long's value equals zero.
+ * @returns {boolean}
+ */
+ LongPrototype.isZero = function isZero() {
+ return this.high === 0 && this.low === 0;
+ };
+ /**
+ * Tests if this Long's value equals zero. This is an alias of {@link Long#isZero}.
+ * @returns {boolean}
+ */
+ LongPrototype.eqz = LongPrototype.isZero;
+ /**
+ * Tests if this Long's value is negative.
+ * @returns {boolean}
+ */
+ LongPrototype.isNegative = function isNegative() {
+ return !this.unsigned && this.high < 0;
+ };
+ /**
+ * Tests if this Long's value is positive.
+ * @returns {boolean}
+ */
+ LongPrototype.isPositive = function isPositive() {
+ return this.unsigned || this.high >= 0;
+ };
+ /**
+ * Tests if this Long's value is odd.
+ * @returns {boolean}
+ */
+ LongPrototype.isOdd = function isOdd() {
+ return (this.low & 1) === 1;
+ };
+ /**
+ * Tests if this Long's value is even.
+ * @returns {boolean}
+ */
+ LongPrototype.isEven = function isEven() {
+ return (this.low & 1) === 0;
+ };
+ /**
+ * Tests if this Long's value equals the specified's.
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.equals = function equals(other) {
+ if (!isLong(other))
+ other = fromValue(other);
+ if (this.unsigned !== other.unsigned && (this.high >>> 31) === 1 && (other.high >>> 31) === 1)
+ return false;
+ return this.high === other.high && this.low === other.low;
+ };
+ /**
+ * Tests if this Long's value equals the specified's. This is an alias of {@link Long#equals}.
+ * @function
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.eq = LongPrototype.equals;
+ /**
+ * Tests if this Long's value differs from the specified's.
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.notEquals = function notEquals(other) {
+ return !this.eq(/* validates */ other);
+ };
+ /**
+ * Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}.
+ * @function
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.neq = LongPrototype.notEquals;
+ /**
+ * Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}.
+ * @function
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.ne = LongPrototype.notEquals;
+ /**
+ * Tests if this Long's value is less than the specified's.
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.lessThan = function lessThan(other) {
+ return this.comp(/* validates */ other) < 0;
+ };
+ /**
+ * Tests if this Long's value is less than the specified's. This is an alias of {@link Long#lessThan}.
+ * @function
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.lt = LongPrototype.lessThan;
+ /**
+ * Tests if this Long's value is less than or equal the specified's.
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.lessThanOrEqual = function lessThanOrEqual(other) {
+ return this.comp(/* validates */ other) <= 0;
+ };
+ /**
+ * Tests if this Long's value is less than or equal the specified's. This is an alias of {@link Long#lessThanOrEqual}.
+ * @function
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.lte = LongPrototype.lessThanOrEqual;
+ /**
+ * Tests if this Long's value is less than or equal the specified's. This is an alias of {@link Long#lessThanOrEqual}.
+ * @function
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.le = LongPrototype.lessThanOrEqual;
+ /**
+ * Tests if this Long's value is greater than the specified's.
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.greaterThan = function greaterThan(other) {
+ return this.comp(/* validates */ other) > 0;
+ };
+ /**
+ * Tests if this Long's value is greater than the specified's. This is an alias of {@link Long#greaterThan}.
+ * @function
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.gt = LongPrototype.greaterThan;
+ /**
+ * Tests if this Long's value is greater than or equal the specified's.
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.greaterThanOrEqual = function greaterThanOrEqual(other) {
+ return this.comp(/* validates */ other) >= 0;
+ };
+ /**
+ * Tests if this Long's value is greater than or equal the specified's. This is an alias of {@link Long#greaterThanOrEqual}.
+ * @function
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.gte = LongPrototype.greaterThanOrEqual;
+ /**
+ * Tests if this Long's value is greater than or equal the specified's. This is an alias of {@link Long#greaterThanOrEqual}.
+ * @function
+ * @param {!Long|number|string} other Other value
+ * @returns {boolean}
+ */
+ LongPrototype.ge = LongPrototype.greaterThanOrEqual;
+ /**
+ * Compares this Long's value with the specified's.
+ * @param {!Long|number|string} other Other value
+ * @returns {number} 0 if they are the same, 1 if the this is greater and -1
+ * if the given one is greater
+ */
+ LongPrototype.compare = function compare(other) {
+ if (!isLong(other))
+ other = fromValue(other);
+ if (this.eq(other))
+ return 0;
+ var thisNeg = this.isNegative(), otherNeg = other.isNegative();
+ if (thisNeg && !otherNeg)
+ return -1;
+ if (!thisNeg && otherNeg)
+ return 1;
+ // At this point the sign bits are the same
+ if (!this.unsigned)
+ return this.sub(other).isNegative() ? -1 : 1;
+ // Both are positive if at least one is unsigned
+ return (other.high >>> 0) > (this.high >>> 0) || (other.high === this.high && (other.low >>> 0) > (this.low >>> 0)) ? -1 : 1;
+ };
+ /**
+ * Compares this Long's value with the specified's. This is an alias of {@link Long#compare}.
+ * @function
+ * @param {!Long|number|string} other Other value
+ * @returns {number} 0 if they are the same, 1 if the this is greater and -1
+ * if the given one is greater
+ */
+ LongPrototype.comp = LongPrototype.compare;
+ /**
+ * Negates this Long's value.
+ * @returns {!Long} Negated Long
+ */
+ LongPrototype.negate = function negate() {
+ if (!this.unsigned && this.eq(MIN_VALUE))
+ return MIN_VALUE;
+ return this.not().add(ONE);
+ };
+ /**
+ * Negates this Long's value. This is an alias of {@link Long#negate}.
+ * @function
+ * @returns {!Long} Negated Long
+ */
+ LongPrototype.neg = LongPrototype.negate;
+ /**
+ * Returns the sum of this and the specified Long.
+ * @param {!Long|number|string} addend Addend
+ * @returns {!Long} Sum
+ */
+ LongPrototype.add = function add(addend) {
+ if (!isLong(addend))
+ addend = fromValue(addend);
+ // Divide each number into 4 chunks of 16 bits, and then sum the chunks.
+ var a48 = this.high >>> 16;
+ var a32 = this.high & 0xFFFF;
+ var a16 = this.low >>> 16;
+ var a00 = this.low & 0xFFFF;
+ var b48 = addend.high >>> 16;
+ var b32 = addend.high & 0xFFFF;
+ var b16 = addend.low >>> 16;
+ var b00 = addend.low & 0xFFFF;
+ var c48 = 0, c32 = 0, c16 = 0, c00 = 0;
+ c00 += a00 + b00;
+ c16 += c00 >>> 16;
+ c00 &= 0xFFFF;
+ c16 += a16 + b16;
+ c32 += c16 >>> 16;
+ c16 &= 0xFFFF;
+ c32 += a32 + b32;
+ c48 += c32 >>> 16;
+ c32 &= 0xFFFF;
+ c48 += a48 + b48;
+ c48 &= 0xFFFF;
+ return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned);
+ };
+ /**
+ * Returns the difference of this and the specified Long.
+ * @param {!Long|number|string} subtrahend Subtrahend
+ * @returns {!Long} Difference
+ */
+ LongPrototype.subtract = function subtract(subtrahend) {
+ if (!isLong(subtrahend))
+ subtrahend = fromValue(subtrahend);
+ return this.add(subtrahend.neg());
+ };
+ /**
+ * Returns the difference of this and the specified Long. This is an alias of {@link Long#subtract}.
+ * @function
+ * @param {!Long|number|string} subtrahend Subtrahend
+ * @returns {!Long} Difference
+ */
+ LongPrototype.sub = LongPrototype.subtract;
+ /**
+ * Returns the product of this and the specified Long.
+ * @param {!Long|number|string} multiplier Multiplier
+ * @returns {!Long} Product
+ */
+ LongPrototype.multiply = function multiply(multiplier) {
+ if (this.isZero())
+ return ZERO;
+ if (!isLong(multiplier))
+ multiplier = fromValue(multiplier);
+ // use wasm support if present
+ if (wasm) {
+ var low = wasm.mul(this.low, this.high, multiplier.low, multiplier.high);
+ return fromBits(low, wasm.get_high(), this.unsigned);
+ }
+ if (multiplier.isZero())
+ return ZERO;
+ if (this.eq(MIN_VALUE))
+ return multiplier.isOdd() ? MIN_VALUE : ZERO;
+ if (multiplier.eq(MIN_VALUE))
+ return this.isOdd() ? MIN_VALUE : ZERO;
+ if (this.isNegative()) {
+ if (multiplier.isNegative())
+ return this.neg().mul(multiplier.neg());
+ else
+ return this.neg().mul(multiplier).neg();
+ }
+ else if (multiplier.isNegative())
+ return this.mul(multiplier.neg()).neg();
+ // If both longs are small, use float multiplication
+ if (this.lt(TWO_PWR_24) && multiplier.lt(TWO_PWR_24))
+ return fromNumber(this.toNumber() * multiplier.toNumber(), this.unsigned);
+ // Divide each long into 4 chunks of 16 bits, and then add up 4x4 products.
+ // We can skip products that would overflow.
+ var a48 = this.high >>> 16;
+ var a32 = this.high & 0xFFFF;
+ var a16 = this.low >>> 16;
+ var a00 = this.low & 0xFFFF;
+ var b48 = multiplier.high >>> 16;
+ var b32 = multiplier.high & 0xFFFF;
+ var b16 = multiplier.low >>> 16;
+ var b00 = multiplier.low & 0xFFFF;
+ var c48 = 0, c32 = 0, c16 = 0, c00 = 0;
+ c00 += a00 * b00;
+ c16 += c00 >>> 16;
+ c00 &= 0xFFFF;
+ c16 += a16 * b00;
+ c32 += c16 >>> 16;
+ c16 &= 0xFFFF;
+ c16 += a00 * b16;
+ c32 += c16 >>> 16;
+ c16 &= 0xFFFF;
+ c32 += a32 * b00;
+ c48 += c32 >>> 16;
+ c32 &= 0xFFFF;
+ c32 += a16 * b16;
+ c48 += c32 >>> 16;
+ c32 &= 0xFFFF;
+ c32 += a00 * b32;
+ c48 += c32 >>> 16;
+ c32 &= 0xFFFF;
+ c48 += a48 * b00 + a32 * b16 + a16 * b32 + a00 * b48;
+ c48 &= 0xFFFF;
+ return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned);
+ };
+ /**
+ * Returns the product of this and the specified Long. This is an alias of {@link Long#multiply}.
+ * @function
+ * @param {!Long|number|string} multiplier Multiplier
+ * @returns {!Long} Product
+ */
+ LongPrototype.mul = LongPrototype.multiply;
+ /**
+ * Returns this Long divided by the specified. The result is signed if this Long is signed or
+ * unsigned if this Long is unsigned.
+ * @param {!Long|number|string} divisor Divisor
+ * @returns {!Long} Quotient
+ */
+ LongPrototype.divide = function divide(divisor) {
+ if (!isLong(divisor))
+ divisor = fromValue(divisor);
+ if (divisor.isZero())
+ throw Error('division by zero');
+ // use wasm support if present
+ if (wasm) {
+ // guard against signed division overflow: the largest
+ // negative number / -1 would be 1 larger than the largest
+ // positive number, due to two's complement.
+ if (!this.unsigned &&
+ this.high === -0x80000000 &&
+ divisor.low === -1 && divisor.high === -1) {
+ // be consistent with non-wasm code path
+ return this;
+ }
+ var low = (this.unsigned ? wasm.div_u : wasm.div_s)(this.low, this.high, divisor.low, divisor.high);
+ return fromBits(low, wasm.get_high(), this.unsigned);
+ }
+ if (this.isZero())
+ return this.unsigned ? UZERO : ZERO;
+ var approx, rem, res;
+ if (!this.unsigned) {
+ // This section is only relevant for signed longs and is derived from the
+ // closure library as a whole.
+ if (this.eq(MIN_VALUE)) {
+ if (divisor.eq(ONE) || divisor.eq(NEG_ONE))
+ return MIN_VALUE; // recall that -MIN_VALUE == MIN_VALUE
+ else if (divisor.eq(MIN_VALUE))
+ return ONE;
+ else {
+ // At this point, we have |other| >= 2, so |this/other| < |MIN_VALUE|.
+ var halfThis = this.shr(1);
+ approx = halfThis.div(divisor).shl(1);
+ if (approx.eq(ZERO)) {
+ return divisor.isNegative() ? ONE : NEG_ONE;
+ }
+ else {
+ rem = this.sub(divisor.mul(approx));
+ res = approx.add(rem.div(divisor));
+ return res;
+ }
+ }
+ }
+ else if (divisor.eq(MIN_VALUE))
+ return this.unsigned ? UZERO : ZERO;
+ if (this.isNegative()) {
+ if (divisor.isNegative())
+ return this.neg().div(divisor.neg());
+ return this.neg().div(divisor).neg();
+ }
+ else if (divisor.isNegative())
+ return this.div(divisor.neg()).neg();
+ res = ZERO;
+ }
+ else {
+ // The algorithm below has not been made for unsigned longs. It's therefore
+ // required to take special care of the MSB prior to running it.
+ if (!divisor.unsigned)
+ divisor = divisor.toUnsigned();
+ if (divisor.gt(this))
+ return UZERO;
+ if (divisor.gt(this.shru(1))) // 15 >>> 1 = 7 ; with divisor = 8 ; true
+ return UONE;
+ res = UZERO;
+ }
+ // Repeat the following until the remainder is less than other: find a
+ // floating-point that approximates remainder / other *from below*, add this
+ // into the result, and subtract it from the remainder. It is critical that
+ // the approximate value is less than or equal to the real value so that the
+ // remainder never becomes negative.
+ rem = this;
+ while (rem.gte(divisor)) {
+ // Approximate the result of division. This may be a little greater or
+ // smaller than the actual value.
+ approx = Math.max(1, Math.floor(rem.toNumber() / divisor.toNumber()));
+ // We will tweak the approximate result by changing it in the 48-th digit or
+ // the smallest non-fractional digit, whichever is larger.
+ var log2 = Math.ceil(Math.log(approx) / Math.LN2), delta = (log2 <= 48) ? 1 : pow_dbl(2, log2 - 48),
+ // Decrease the approximation until it is smaller than the remainder. Note
+ // that if it is too large, the product overflows and is negative.
+ approxRes = fromNumber(approx), approxRem = approxRes.mul(divisor);
+ while (approxRem.isNegative() || approxRem.gt(rem)) {
+ approx -= delta;
+ approxRes = fromNumber(approx, this.unsigned);
+ approxRem = approxRes.mul(divisor);
+ }
+ // We know the answer can't be zero... and actually, zero would cause
+ // infinite recursion since we would make no progress.
+ if (approxRes.isZero())
+ approxRes = ONE;
+ res = res.add(approxRes);
+ rem = rem.sub(approxRem);
+ }
+ return res;
+ };
+ /**
+ * Returns this Long divided by the specified. This is an alias of {@link Long#divide}.
+ * @function
+ * @param {!Long|number|string} divisor Divisor
+ * @returns {!Long} Quotient
+ */
+ LongPrototype.div = LongPrototype.divide;
+ /**
+ * Returns this Long modulo the specified.
+ * @param {!Long|number|string} divisor Divisor
+ * @returns {!Long} Remainder
+ */
+ LongPrototype.modulo = function modulo(divisor) {
+ if (!isLong(divisor))
+ divisor = fromValue(divisor);
+ // use wasm support if present
+ if (wasm) {
+ var low = (this.unsigned ? wasm.rem_u : wasm.rem_s)(this.low, this.high, divisor.low, divisor.high);
+ return fromBits(low, wasm.get_high(), this.unsigned);
+ }
+ return this.sub(this.div(divisor).mul(divisor));
+ };
+ /**
+ * Returns this Long modulo the specified. This is an alias of {@link Long#modulo}.
+ * @function
+ * @param {!Long|number|string} divisor Divisor
+ * @returns {!Long} Remainder
+ */
+ LongPrototype.mod = LongPrototype.modulo;
+ /**
+ * Returns this Long modulo the specified. This is an alias of {@link Long#modulo}.
+ * @function
+ * @param {!Long|number|string} divisor Divisor
+ * @returns {!Long} Remainder
+ */
+ LongPrototype.rem = LongPrototype.modulo;
+ /**
+ * Returns the bitwise NOT of this Long.
+ * @returns {!Long}
+ */
+ LongPrototype.not = function not() {
+ return fromBits(~this.low, ~this.high, this.unsigned);
+ };
+ /**
+ * Returns the bitwise AND of this Long and the specified.
+ * @param {!Long|number|string} other Other Long
+ * @returns {!Long}
+ */
+ LongPrototype.and = function and(other) {
+ if (!isLong(other))
+ other = fromValue(other);
+ return fromBits(this.low & other.low, this.high & other.high, this.unsigned);
+ };
+ /**
+ * Returns the bitwise OR of this Long and the specified.
+ * @param {!Long|number|string} other Other Long
+ * @returns {!Long}
+ */
+ LongPrototype.or = function or(other) {
+ if (!isLong(other))
+ other = fromValue(other);
+ return fromBits(this.low | other.low, this.high | other.high, this.unsigned);
+ };
+ /**
+ * Returns the bitwise XOR of this Long and the given one.
+ * @param {!Long|number|string} other Other Long
+ * @returns {!Long}
+ */
+ LongPrototype.xor = function xor(other) {
+ if (!isLong(other))
+ other = fromValue(other);
+ return fromBits(this.low ^ other.low, this.high ^ other.high, this.unsigned);
+ };
+ /**
+ * Returns this Long with bits shifted to the left by the given amount.
+ * @param {number|!Long} numBits Number of bits
+ * @returns {!Long} Shifted Long
+ */
+ LongPrototype.shiftLeft = function shiftLeft(numBits) {
+ if (isLong(numBits))
+ numBits = numBits.toInt();
+ if ((numBits &= 63) === 0)
+ return this;
+ else if (numBits < 32)
+ return fromBits(this.low << numBits, (this.high << numBits) | (this.low >>> (32 - numBits)), this.unsigned);
+ else
+ return fromBits(0, this.low << (numBits - 32), this.unsigned);
+ };
+ /**
+ * Returns this Long with bits shifted to the left by the given amount. This is an alias of {@link Long#shiftLeft}.
+ * @function
+ * @param {number|!Long} numBits Number of bits
+ * @returns {!Long} Shifted Long
+ */
+ LongPrototype.shl = LongPrototype.shiftLeft;
+ /**
+ * Returns this Long with bits arithmetically shifted to the right by the given amount.
+ * @param {number|!Long} numBits Number of bits
+ * @returns {!Long} Shifted Long
+ */
+ LongPrototype.shiftRight = function shiftRight(numBits) {
+ if (isLong(numBits))
+ numBits = numBits.toInt();
+ if ((numBits &= 63) === 0)
+ return this;
+ else if (numBits < 32)
+ return fromBits((this.low >>> numBits) | (this.high << (32 - numBits)), this.high >> numBits, this.unsigned);
+ else
+ return fromBits(this.high >> (numBits - 32), this.high >= 0 ? 0 : -1, this.unsigned);
+ };
+ /**
+ * Returns this Long with bits arithmetically shifted to the right by the given amount. This is an alias of {@link Long#shiftRight}.
+ * @function
+ * @param {number|!Long} numBits Number of bits
+ * @returns {!Long} Shifted Long
+ */
+ LongPrototype.shr = LongPrototype.shiftRight;
+ /**
+ * Returns this Long with bits logically shifted to the right by the given amount.
+ * @param {number|!Long} numBits Number of bits
+ * @returns {!Long} Shifted Long
+ */
+ LongPrototype.shiftRightUnsigned = function shiftRightUnsigned(numBits) {
+ if (isLong(numBits))
+ numBits = numBits.toInt();
+ numBits &= 63;
+ if (numBits === 0)
+ return this;
+ else {
+ var high = this.high;
+ if (numBits < 32) {
+ var low = this.low;
+ return fromBits((low >>> numBits) | (high << (32 - numBits)), high >>> numBits, this.unsigned);
+ }
+ else if (numBits === 32)
+ return fromBits(high, 0, this.unsigned);
+ else
+ return fromBits(high >>> (numBits - 32), 0, this.unsigned);
+ }
+ };
+ /**
+ * Returns this Long with bits logically shifted to the right by the given amount. This is an alias of {@link Long#shiftRightUnsigned}.
+ * @function
+ * @param {number|!Long} numBits Number of bits
+ * @returns {!Long} Shifted Long
+ */
+ LongPrototype.shru = LongPrototype.shiftRightUnsigned;
+ /**
+ * Returns this Long with bits logically shifted to the right by the given amount. This is an alias of {@link Long#shiftRightUnsigned}.
+ * @function
+ * @param {number|!Long} numBits Number of bits
+ * @returns {!Long} Shifted Long
+ */
+ LongPrototype.shr_u = LongPrototype.shiftRightUnsigned;
+ /**
+ * Converts this Long to signed.
+ * @returns {!Long} Signed long
+ */
+ LongPrototype.toSigned = function toSigned() {
+ if (!this.unsigned)
+ return this;
+ return fromBits(this.low, this.high, false);
+ };
+ /**
+ * Converts this Long to unsigned.
+ * @returns {!Long} Unsigned long
+ */
+ LongPrototype.toUnsigned = function toUnsigned() {
+ if (this.unsigned)
+ return this;
+ return fromBits(this.low, this.high, true);
+ };
+ /**
+ * Converts this Long to its byte representation.
+ * @param {boolean=} le Whether little or big endian, defaults to big endian
+ * @returns {!Array.<number>} Byte representation
+ */
+ LongPrototype.toBytes = function toBytes(le) {
+ return le ? this.toBytesLE() : this.toBytesBE();
+ };
+ /**
+ * Converts this Long to its little endian byte representation.
+ * @returns {!Array.<number>} Little endian byte representation
+ */
+ LongPrototype.toBytesLE = function toBytesLE() {
+ var hi = this.high, lo = this.low;
+ return [
+ lo & 0xff,
+ lo >>> 8 & 0xff,
+ lo >>> 16 & 0xff,
+ lo >>> 24,
+ hi & 0xff,
+ hi >>> 8 & 0xff,
+ hi >>> 16 & 0xff,
+ hi >>> 24
+ ];
+ };
+ /**
+ * Converts this Long to its big endian byte representation.
+ * @returns {!Array.<number>} Big endian byte representation
+ */
+ LongPrototype.toBytesBE = function toBytesBE() {
+ var hi = this.high, lo = this.low;
+ return [
+ hi >>> 24,
+ hi >>> 16 & 0xff,
+ hi >>> 8 & 0xff,
+ hi & 0xff,
+ lo >>> 24,
+ lo >>> 16 & 0xff,
+ lo >>> 8 & 0xff,
+ lo & 0xff
+ ];
+ };
+ /**
+ * Creates a Long from its byte representation.
+ * @param {!Array.<number>} bytes Byte representation
+ * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
+ * @param {boolean=} le Whether little or big endian, defaults to big endian
+ * @returns {Long} The corresponding Long value
+ */
+ Long$1.fromBytes = function fromBytes(bytes, unsigned, le) {
+ return le ? Long$1.fromBytesLE(bytes, unsigned) : Long$1.fromBytesBE(bytes, unsigned);
+ };
+ /**
+ * Creates a Long from its little endian byte representation.
+ * @param {!Array.<number>} bytes Little endian byte representation
+ * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
+ * @returns {Long} The corresponding Long value
+ */
+ Long$1.fromBytesLE = function fromBytesLE(bytes, unsigned) {
+ return new Long$1(bytes[0] |
+ bytes[1] << 8 |
+ bytes[2] << 16 |
+ bytes[3] << 24, bytes[4] |
+ bytes[5] << 8 |
+ bytes[6] << 16 |
+ bytes[7] << 24, unsigned);
+ };
+ /**
+ * Creates a Long from its big endian byte representation.
+ * @param {!Array.<number>} bytes Big endian byte representation
+ * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
+ * @returns {Long} The corresponding Long value
+ */
+ Long$1.fromBytesBE = function fromBytesBE(bytes, unsigned) {
+ return new Long$1(bytes[4] << 24 |
+ bytes[5] << 16 |
+ bytes[6] << 8 |
+ bytes[7], bytes[0] << 24 |
+ bytes[1] << 16 |
+ bytes[2] << 8 |
+ bytes[3], unsigned);
+ };
+
+ var LongExports = /*#__PURE__*/Object.assign(/*#__PURE__*/Object.create(null), long, {
+ 'default': long
+ });
+
+ // tslint:disable-next-line
+ var Long =
+ // tslint:disable-next-line
+ long || LongExports;
+ function hexToLong(hex) {
+ return Long.fromString(hex, true, 16);
+ }
+ // Some primes between 2^63 and 2^64 for various uses.
+ // Hex 0xc3a5c85c97cb3127
+ var k0 = hexToLong('c3a5c85c97cb3127');
+ // Hex 0xb492b66fbe98f273
+ var k1 = hexToLong('b492b66fbe98f273');
+ // Hex 0x9ae16a3b2f90404f
+ var k2 = hexToLong('9ae16a3b2f90404f');
+ function shiftMix(val) {
+ return val.xor(val.shru(47));
+ }
+ function fetch$2(s, offset, numBytes) {
+ var bytes = s.slice(offset, offset + numBytes);
+ return Long.fromBytes(Array.from(bytes), true, true);
+ }
+ function fetch64(s, offset) {
+ return fetch$2(s, offset, 8);
+ }
+ function fetch32(s, offset) {
+ return fetch$2(s, offset, 4);
+ }
+ function rotate64(val, shift) {
+ // Avoid shifting by 64: doing so yields an undefined result.
+ return shift === 0 ? val : val.shru(shift).or(val.shl(64 - shift));
+ }
+ function hashLen16(u, v, mul) {
+ if (mul === void 0) { mul = hexToLong('9ddfea08eb382d69'); }
+ // Murmur-inspired hashing.
+ var a = u.xor(v).mul(mul);
+ a = a.xor(a.shru(47));
+ var b = v.xor(a).mul(mul);
+ b = b.xor(b.shru(47));
+ b = b.mul(mul);
+ return b;
+ }
+ // Return a 16-byte hash for 48 bytes. Quick and dirty.
+ // Callers do best to use "random-looking" values for a and b.
+ function weakHashLen32WithSeeds(w, x, y, z, a, b) {
+ a = a.add(w);
+ b = rotate64(b.add(a).add(z), 21);
+ var c = a;
+ a = a.add(x);
+ a = a.add(y);
+ b = b.add(rotate64(a, 44));
+ return [a.add(z), b.add(c)];
+ }
+ function weakHashLen32WithSeedsStr(s, offset, a, b) {
+ return weakHashLen32WithSeeds(fetch64(s, offset), fetch64(s, offset + 8), fetch64(s, offset + 16), fetch64(s, offset + 24), a, b);
+ }
+ function hashLen0to16(s, len) {
+ if (len === void 0) { len = s.length; }
+ if (len >= 8) {
+ var mul = k2.add(len * 2);
+ var a = fetch64(s, 0).add(k2);
+ var b = fetch64(s, len - 8);
+ var c = rotate64(b, 37).mul(mul).add(a);
+ var d = rotate64(a, 25).add(b).mul(mul);
+ return hashLen16(c, d, mul);
+ }
+ if (len >= 4) {
+ var mul = k2.add(len * 2);
+ var a = fetch32(s, 0);
+ return hashLen16(a.shl(3).add(len), fetch32(s, len - 4), mul);
+ }
+ if (len > 0) {
+ var a = s[0];
+ var b = s[len >> 1];
+ var c = s[len - 1];
+ var y = a + (b << 8);
+ var z = len + (c << 2);
+ return shiftMix(k2.mul(y).xor(k0.mul(z))).mul(k2);
+ }
+ return k2;
+ }
+ function hashLen17to32(s, len) {
+ if (len === void 0) { len = s.length; }
+ var mul = k2.add(len * 2);
+ var a = fetch64(s, 0).mul(k1);
+ var b = fetch64(s, 8);
+ var c = fetch64(s, len - 8).mul(mul);
+ var d = fetch64(s, len - 16).mul(k2);
+ return hashLen16(rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d), a.add(rotate64(b.add(k2), 18)).add(c), mul);
+ }
+ function hashLen33to64(s, len) {
+ if (len === void 0) { len = s.length; }
+ var mul = k2.add(len * 2);
+ var a = fetch64(s, 0).mul(k2);
+ var b = fetch64(s, 8);
+ var c = fetch64(s, len - 8).mul(mul);
+ var d = fetch64(s, len - 16).mul(k2);
+ var y = rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d);
+ var z = hashLen16(y, a.add(rotate64(b.add(k2), 18)).add(c), mul);
+ var e = fetch64(s, 16).mul(mul);
+ var f = fetch64(s, 24);
+ var g = y.add(fetch64(s, len - 32)).mul(mul);
+ var h = z.add(fetch64(s, len - 24)).mul(mul);
+ return hashLen16(rotate64(e.add(f), 43).add(rotate64(g, 30)).add(h), e.add(rotate64(f.add(a), 18)).add(g), mul);
+ }
+ function fingerPrint64(s, len) {
+ var _a, _b;
+ if (len === void 0) { len = s.length; }
+ var seed = Long.fromNumber(81, true);
+ if (len <= 32) {
+ if (len <= 16) {
+ return hashLen0to16(s, len);
+ }
+ else {
+ return hashLen17to32(s, len);
+ }
+ }
+ else if (len <= 64) {
+ return hashLen33to64(s, len);
+ }
+ // For strings over 64 bytes we loop. Internal state consists of
+ // 56 bytes: v, w, x, y, and z.
+ var x = seed;
+ var y = seed.mul(k1).add(113);
+ var z = shiftMix(y.mul(k2).add(113)).mul(k2);
+ var v = [Long.UZERO, Long.UZERO];
+ var w = [Long.UZERO, Long.UZERO];
+ x = x.mul(k2).add(fetch64(s, 0));
+ var offset = 0;
+ // Set end so that after the loop we have 1 to 64 bytes left to process.
+ var end = ((len - 1) >> 6) * 64;
+ var last64 = end + ((len - 1) & 63) - 63;
+ do {
+ x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(k1);
+ y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(k1);
+ x = x.xor(w[1]);
+ y = y.add(v[0]).add(fetch64(s, offset + 40));
+ z = rotate64(z.add(w[0]), 33).mul(k1);
+ v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(k1), x.add(w[0]));
+ w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
+ _a = __read([x, z], 2), z = _a[0], x = _a[1];
+ offset += 64;
+ } while (offset !== end);
+ var mul = k1.add(z.and(0xff).shl(1));
+ // Point to the last 64 bytes of input.
+ offset = last64;
+ w[0] = w[0].add((len - 1) & 63);
+ v[0] = v[0].add(w[0]);
+ w[0] = w[0].add(v[0]);
+ x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(mul);
+ y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(mul);
+ x = x.xor(w[1].mul(9));
+ y = y.add(v[0].mul(9).add(fetch64(s, offset + 40)));
+ z = rotate64(z.add(w[0]), 33).mul(mul);
+ v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(mul), x.add(w[0]));
+ w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
+ _b = __read([x, z], 2), z = _b[0], x = _b[1];
+ return hashLen16(hashLen16(v[0], w[0], mul).add(shiftMix(y).mul(k0)).add(z), hashLen16(v[1], w[1], mul).add(x), mul);
+ }
+
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Create typed array for scalar value. Used for storing in `DataStorage`.
+ */
+ function createScalarValue(value, dtype) {
+ if (dtype === 'string') {
+ return encodeString(value);
+ }
+ return toTypedArray([value], dtype);
+ }
+ function noConversionNeeded(a, dtype) {
+ return (a instanceof Float32Array && dtype === 'float32') ||
+ (a instanceof Int32Array && dtype === 'int32') ||
+ (a instanceof Uint8Array && dtype === 'bool');
+ }
+ function toTypedArray(a, dtype) {
+ if (dtype === 'string') {
+ throw new Error('Cannot convert a string[] to a TypedArray');
+ }
+ if (Array.isArray(a)) {
+ a = flatten(a);
+ }
+ if (env().getBool('DEBUG')) {
+ checkConversionForErrors(a, dtype);
+ }
+ if (noConversionNeeded(a, dtype)) {
+ return a;
+ }
+ if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
+ return new Float32Array(a);
+ }
+ else if (dtype === 'int32') {
+ return new Int32Array(a);
+ }
+ else if (dtype === 'bool') {
+ var bool = new Uint8Array(a.length);
+ for (var i = 0; i < bool.length; ++i) {
+ if (Math.round(a[i]) !== 0) {
+ bool[i] = 1;
+ }
+ }
+ return bool;
+ }
+ else {
+ throw new Error("Unknown data type " + dtype);
+ }
+ }
+ /**
+ * Returns the current high-resolution time in milliseconds relative to an
+ * arbitrary time in the past. It works across different platforms (node.js,
+ * browsers).
+ *
+ * ```js
+ * console.log(tf.util.now());
+ * ```
+ *
+ * @doc {heading: 'Util', namespace: 'util'}
+ */
+ function now() {
+ return env().platform.now();
+ }
+ /**
+ * Returns a platform-specific implementation of
+ * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
+ *
+ * If `fetch` is defined on the global object (`window`, `process`, etc.),
+ * `tf.util.fetch` returns that function.
+ *
+ * If not, `tf.util.fetch` returns a platform-specific solution.
+ *
+ * ```js
+ * const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs');
+ * // handle response
+ * ```
+ *
+ * @doc {heading: 'Util'}
+ */
+ function fetch$1(path, requestInits) {
+ return env().platform.fetch(path, requestInits);
+ }
+ /**
+ * Encodes the provided string into bytes using the provided encoding scheme.
+ *
+ * @param s The string to encode.
+ * @param encoding The encoding scheme. Defaults to utf-8.
+ *
+ * @doc {heading: 'Util'}
+ */
+ function encodeString(s, encoding) {
+ if (encoding === void 0) { encoding = 'utf-8'; }
+ encoding = encoding || 'utf-8';
+ return env().platform.encode(s, encoding);
+ }
+ /**
+ * Decodes the provided bytes into a string using the provided encoding scheme.
+ * @param bytes The bytes to decode.
+ *
+ * @param encoding The encoding scheme. Defaults to utf-8.
+ *
+ * @doc {heading: 'Util'}
+ */
+ function decodeString(bytes, encoding) {
+ if (encoding === void 0) { encoding = 'utf-8'; }
+ encoding = encoding || 'utf-8';
+ return env().platform.decode(bytes, encoding);
+ }
+
+ var util = {
+ __proto__: null,
+ createScalarValue: createScalarValue,
+ toTypedArray: toTypedArray,
+ now: now,
+ fetch: fetch$1,
+ encodeString: encodeString,
+ decodeString: decodeString,
+ shuffle: shuffle,
+ shuffleCombo: shuffleCombo,
+ clamp: clamp,
+ nearestLargerEven: nearestLargerEven,
+ swap: swap,
+ sum: sum$1,
+ randUniform: randUniform,
+ distSquared: distSquared,
+ assert: assert,
+ assertShapesMatch: assertShapesMatch,
+ assertNonNull: assertNonNull,
+ flatten: flatten,
+ sizeFromShape: sizeFromShape,
+ isScalarShape: isScalarShape,
+ arraysEqual: arraysEqual,
+ isInt: isInt,
+ tanh: tanh$1,
+ sizeToSquarishShape: sizeToSquarishShape,
+ createShuffledIndices: createShuffledIndices,
+ rightPad: rightPad,
+ repeatedTry: repeatedTry,
+ inferFromImplicitShape: inferFromImplicitShape,
+ parseAxisParam: parseAxisParam,
+ squeezeShape: squeezeShape,
+ getTypedArrayFromDType: getTypedArrayFromDType,
+ getArrayFromDType: getArrayFromDType,
+ checkConversionForErrors: checkConversionForErrors,
+ isValidDtype: isValidDtype,
+ hasEncodingLoss: hasEncodingLoss,
+ isTypedArray: isTypedArray,
+ bytesPerElement: bytesPerElement,
+ bytesFromStringArray: bytesFromStringArray,
+ isString: isString,
+ isBoolean: isBoolean,
+ isNumber: isNumber,
+ inferDtype: inferDtype,
+ isFunction: isFunction,
+ nearestDivisor: nearestDivisor,
+ computeStrides: computeStrides,
+ toNestedArray: toNestedArray,
+ makeOnesTypedArray: makeOnesTypedArray,
+ makeZerosTypedArray: makeZerosTypedArray,
+ makeZerosNestedTypedArray: makeZerosNestedTypedArray,
+ assertNonNegativeIntegerDimensions: assertNonNegativeIntegerDimensions,
+ locToIndex: locToIndex,
+ indexToLoc: indexToLoc,
+ isPromise: isPromise,
+ hexToLong: hexToLong,
+ fingerPrint64: fingerPrint64
+ };
+
+ var Profiler = /** @class */ (function () {
+ function Profiler(backendTimer, logger) {
+ this.backendTimer = backendTimer;
+ this.logger = logger;
+ if (logger == null) {
+ this.logger = new Logger();
+ }
+ }
+ Profiler.prototype.profileKernel = function (kernelName, inputs, f) {
+ var e_1, _a;
+ var outputs;
+ var holdResultWrapperFn = function () {
+ outputs = f();
+ };
+ var timer;
+ var start = now();
+ if (this.backendTimer.timerAvailable()) {
+ timer = this.backendTimer.time(holdResultWrapperFn);
+ }
+ else {
+ holdResultWrapperFn();
+ try {
+ for (var outputs_1 = __values(outputs), outputs_1_1 = outputs_1.next(); !outputs_1_1.done; outputs_1_1 = outputs_1.next()) {
+ var output = outputs_1_1.value;
+ output.dataSync();
+ }
+ }
+ catch (e_1_1) { e_1 = { error: e_1_1 }; }
+ finally {
+ try {
+ if (outputs_1_1 && !outputs_1_1.done && (_a = outputs_1.return)) _a.call(outputs_1);
+ }
+ finally { if (e_1) throw e_1.error; }
+ }
+ timer = Promise.resolve({ kernelMs: now() - start });
+ }
+ if (env().getBool('CHECK_COMPUTATION_FOR_ERRORS')) {
+ var _loop_1 = function (i) {
+ var output = outputs[i];
+ // Dangling promise here because we don't want to propagate up
+ // asynchronicity.
+ output.data().then(function (tensorVals) {
+ checkComputationForErrors(tensorVals, output.dtype, kernelName);
+ });
+ };
+ for (var i = 0; i < outputs.length; i++) {
+ _loop_1(i);
+ }
+ }
+ var kernelProfile = {
+ kernelName: kernelName,
+ outputs: outputs,
+ inputs: inputs,
+ timeMs: timer.then(function (timing) { return timing.kernelMs; }),
+ extraInfo: timer.then(function (timing) { return timing.getExtraProfileInfo != null ?
+ timing.getExtraProfileInfo() :
+ ''; })
+ };
+ return kernelProfile;
+ };
+ Profiler.prototype.logKernelProfile = function (kernelProfile) {
+ var _this = this;
+ var kernelName = kernelProfile.kernelName, outputs = kernelProfile.outputs, timeMs = kernelProfile.timeMs, inputs = kernelProfile.inputs, extraInfo = kernelProfile.extraInfo;
+ outputs.forEach(function (result) {
+ Promise.all([result.data(), timeMs, extraInfo]).then(function (valueContainer) {
+ _this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]);
+ });
+ });
+ };
+ return Profiler;
+ }());
+ function checkComputationForErrors(vals, dtype, kernelName) {
+ if (dtype !== 'float32') {
+ // Only floating point computations will generate NaN values
+ return false;
+ }
+ for (var i = 0; i < vals.length; i++) {
+ var num = vals[i];
+ if (isNaN(num) || !isFinite(num)) {
+ // Throwing custom exception so behavior is testable.
+ console.warn("Found " + num + " in the result of '" + kernelName + "'");
+ return true;
+ }
+ }
+ return false;
+ }
+ var Logger = /** @class */ (function () {
+ function Logger() {
+ }
+ Logger.prototype.logKernelProfile = function (name, result, vals, timeMs, inputs, extraInfo) {
+ var time = typeof timeMs === 'number' ? rightPad(timeMs + "ms", 9) :
+ timeMs['error'];
+ var paddedName = rightPad(name, 25);
+ var rank = result.rank;
+ var size = result.size;
+ var shape = rightPad(result.shape.toString(), 14);
+ var inputShapesDescription = '';
+ for (var name_1 in inputs) {
+ var input = inputs[name_1];
+ if (input != null) {
+ // The input might be a non-tensor (e.g HTMLImageElement), in which case
+ // we claim the output shape as input shape.
+ var inputShape = input.shape || result.shape;
+ var inputRank = inputShape.length;
+ inputShapesDescription +=
+ name_1 + ": " + inputRank + "D " + (inputRank > 0 ? inputShape : '') + " ";
+ }
+ }
+ console.log("%c" + paddedName + "\t%c" + time + "\t%c" + rank + "D " + shape + "\t%c" + size + "\t%c" + inputShapesDescription + "\t%c" + extraInfo, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue');
+ };
+ return Logger;
+ }());
+
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes a list of TapeNodes that connect x to y, filtering everything else
+ * out and preserving the order of the original tape elements.
+ *
+ * @param tape The tape elements to filter.
+ * @param xs The input Tensors.
+ * @param y The output Tensor.
+ */
+ function getFilteredNodesXToY(tape, xs, y) {
+ // Forward pass to compute all the nodes and Tensors that are transitively a
+ // function of x.
+ var tensorsFromX = {};
+ var nodesFromX = {};
+ for (var i = 0; i < xs.length; i++) {
+ tensorsFromX[xs[i].id] = true;
+ }
+ for (var i = 0; i < tape.length; i++) {
+ var node = tape[i];
+ var nodeInputs = node.inputs;
+ for (var inputName in nodeInputs) {
+ var input = nodeInputs[inputName];
+ var anyInputFromX = false;
+ for (var j = 0; j < xs.length; j++) {
+ if (tensorsFromX[input.id]) {
+ node.outputs.forEach(function (output) { return tensorsFromX[output.id] = true; });
+ anyInputFromX = true;
+ nodesFromX[node.id] = true;
+ break;
+ }
+ }
+ if (anyInputFromX) {
+ break;
+ }
+ }
+ }
+ // Backward pass to find all of the nodes and Tensors that lead to y.
+ var tensorsLeadToY = {};
+ tensorsLeadToY[y.id] = true;
+ var nodesToY = {};
+ for (var i = tape.length - 1; i >= 0; i--) {
+ var node = tape[i];
+ var nodeInputs = node.inputs;
+ // If any of the outputs lead to y, mark all of the inputs as leading to y.
+ for (var j = 0; j < node.outputs.length; j++) {
+ if (tensorsLeadToY[node.outputs[j].id]) {
+ for (var inputName in nodeInputs) {
+ tensorsLeadToY[nodeInputs[inputName].id] = true;
+ nodesToY[node.id] = true;
+ }
+ break;
+ }
+ }
+ }
+ // Return the paths that come from x and lead to y.
+ var filteredTape = [];
+ for (var i = 0; i < tape.length; i++) {
+ var node = tape[i];
+ if (nodesFromX[node.id] && nodesToY[node.id]) {
+ // Prune the inputs from the node that aren't a function of x.
+ var prunedInputs = {};
+ for (var inputName in node.inputs) {
+ var nodeInput = node.inputs[inputName];
+ if (tensorsFromX[nodeInput.id]) {
+ prunedInputs[inputName] = nodeInput;
+ }
+ }
+ // Copy the node and overwrite inputsAndArgs to the pruned version.
+ var prunedNode = Object.assign({}, node);
+ prunedNode.inputs = prunedInputs;
+ prunedNode.outputs = node.outputs;
+ filteredTape.push(prunedNode);
+ }
+ }
+ return filteredTape;
+ }
+ /**
+ * Backpropagate gradients through the filtered TapeNodes.
+ *
+ * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map
+ * is mutated by this method.
+ * @param filteredTape The filtered TapeNodes to backprop through.
+ */
+ function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy, add) {
+ var _loop_1 = function (i) {
+ var node = filteredTape[i];
+ var dys = [];
+ node.outputs.forEach(function (o) {
+ var gradTensor = tensorAccumulatedGradientMap[o.id];
+ if (gradTensor != null) {
+ dys.push(gradTensor);
+ }
+ else {
+ // This particular output is not in the back-propagation subgraph, so it
+ // does not affect the final output, thus we put null for its dy.
+ dys.push(null);
+ }
+ });
+ if (node.gradient == null) {
+ throw new Error("Cannot compute gradient: gradient function not found " +
+ ("for " + node.kernelName + "."));
+ }
+ // Backprop dy through this node and accumulate gradients over the inputs.
+ var inputGradients = node.gradient(dys);
+ var _loop_2 = function (inputName) {
+ if (!(inputName in inputGradients)) {
+ throw new Error("Cannot backprop through input " + inputName + ". " +
+ ("Available gradients found: " + Object.keys(inputGradients) + "."));
+ }
+ // Call the gradient function.
+ var dx = tidy(function () { return inputGradients[inputName](); });
+ if (dx.dtype !== 'float32') {
+ throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " +
+ (inputName + " must have 'float32' dtype, but has '" + dx.dtype + "'"));
+ }
+ var x = node.inputs[inputName];
+ if (!arraysEqual(dx.shape, x.shape)) {
+ throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " +
+ ("'" + inputName + "' has shape '" + dx.shape + "', which does not match ") +
+ ("the shape of the input '" + x.shape + "'"));
+ }
+ if (tensorAccumulatedGradientMap[x.id] == null) {
+ tensorAccumulatedGradientMap[x.id] = dx;
+ }
+ else {
+ var curGradient = tensorAccumulatedGradientMap[x.id];
+ tensorAccumulatedGradientMap[x.id] = add(curGradient, dx);
+ curGradient.dispose();
+ }
+ };
+ for (var inputName in node.inputs) {
+ _loop_2(inputName);
+ }
+ };
+ // Walk the tape backward and keep a map of Tensor to its gradient.
+ for (var i = filteredTape.length - 1; i >= 0; i--) {
+ _loop_1(i);
+ }
+ }
+
+ // Maximum number of values before we decide to show ellipsis.
+ var FORMAT_LIMIT_NUM_VALS = 20;
+ // Number of first and last values to show when displaying a, b,...,y, z.
+ var FORMAT_NUM_FIRST_LAST_VALS = 3;
+ // Number of significant digits to show.
+ var FORMAT_NUM_SIG_DIGITS = 7;
+ function tensorToString(vals, shape, dtype, verbose) {
+ var strides = computeStrides(shape);
+ var padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);
+ var rank = shape.length;
+ var valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol);
+ var lines = ['Tensor'];
+ if (verbose) {
+ lines.push(" dtype: " + dtype);
+ lines.push(" rank: " + rank);
+ lines.push(" shape: [" + shape + "]");
+ lines.push(" values:");
+ }
+ lines.push(valsLines.map(function (l) { return ' ' + l; }).join('\n'));
+ return lines.join('\n');
+ }
+ function computeMaxSizePerColumn(vals, shape, dtype, strides) {
+ var n = sizeFromShape(shape);
+ var numCols = strides[strides.length - 1];
+ var padPerCol = new Array(numCols).fill(0);
+ var rank = shape.length;
+ var valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals;
+ if (rank > 1) {
+ for (var row = 0; row < n / numCols; row++) {
+ var offset = row * numCols;
+ for (var j = 0; j < numCols; j++) {
+ padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length);
+ }
+ }
+ }
+ return padPerCol;
+ }
+ function valToString(val, pad, dtype) {
+ var valStr;
+ if (Array.isArray(val)) {
+ valStr = parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS)) + " + " +
+ (parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS)) + "j");
+ }
+ else if (isString(val)) {
+ valStr = "'" + val + "'";
+ }
+ else if (dtype === 'bool') {
+ valStr = boolNumToString(val);
+ }
+ else {
+ valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString();
+ }
+ return rightPad(valStr, pad);
+ }
+ function boolNumToString(v) {
+ return v === 0 ? 'false' : 'true';
+ }
+ function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast) {
+ if (isLast === void 0) { isLast = true; }
+ var storagePerElement = dtype === 'complex64' ? 2 : 1;
+ var size = shape[0];
+ var rank = shape.length;
+ if (rank === 0) {
+ if (dtype === 'complex64') {
+ var complexTuple = createComplexTuples(vals);
+ return [valToString(complexTuple[0], 0, dtype)];
+ }
+ if (dtype === 'bool') {
+ return [boolNumToString(vals[0])];
+ }
+ return [vals[0].toString()];
+ }
+ if (rank === 1) {
+ if (size > FORMAT_LIMIT_NUM_VALS) {
+ var firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;
+ var firstVals = Array.from(vals.slice(0, firstValsSize));
+ var lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement));
+ if (dtype === 'complex64') {
+ firstVals = createComplexTuples(firstVals);
+ lastVals = createComplexTuples(lastVals);
+ }
+ return [
+ '[' +
+ firstVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); })
+ .join(', ') +
+ ', ..., ' +
+ lastVals
+ .map(function (x, i) { return valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype); })
+ .join(', ') +
+ ']'
+ ];
+ }
+ var displayVals = dtype === 'complex64' ? createComplexTuples(vals) :
+ Array.from(vals);
+ return [
+ '[' +
+ displayVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); })
+ .join(', ') +
+ ']'
+ ];
+ }
+ // The array is rank 2 or more.
+ var subshape = shape.slice(1);
+ var substrides = strides.slice(1);
+ var stride = strides[0] * storagePerElement;
+ var lines = [];
+ if (size > FORMAT_LIMIT_NUM_VALS) {
+ for (var i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {
+ var start = i * stride;
+ var end = start + stride;
+ lines.push.apply(lines, __spread(subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */)));
+ }
+ lines.push('...');
+ for (var i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) {
+ var start = i * stride;
+ var end = start + stride;
+ lines.push.apply(lines, __spread(subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)));
+ }
+ }
+ else {
+ for (var i = 0; i < size; i++) {
+ var start = i * stride;
+ var end = start + stride;
+ lines.push.apply(lines, __spread(subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)));
+ }
+ }
+ var sep = rank === 2 ? ',' : '';
+ lines[0] = '[' + lines[0] + sep;
+ for (var i = 1; i < lines.length - 1; i++) {
+ lines[i] = ' ' + lines[i] + sep;
+ }
+ var newLineSep = ',\n';
+ for (var i = 2; i < rank; i++) {
+ newLineSep += '\n';
+ }
+ lines[lines.length - 1] =
+ ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep);
+ return lines;
+ }
+ function createComplexTuples(vals) {
+ var complexTuples = [];
+ for (var i = 0; i < vals.length; i += 2) {
+ complexTuples.push([vals[i], vals[i + 1]]);
+ }
+ return complexTuples;
+ }
+
+ /**
+ * A mutable object, similar to `tf.Tensor`, that allows users to set values
+ * at locations before converting to an immutable `tf.Tensor`.
+ *
+ * See `tf.buffer` for creating a tensor buffer.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ var TensorBuffer = /** @class */ (function () {
+ function TensorBuffer(shape, dtype, values) {
+ var _this = this;
+ this.dtype = dtype;
+ this.shape = shape.slice();
+ this.size = sizeFromShape(shape);
+ if (values != null) {
+ var n_1 = values.length;
+ assert(n_1 === this.size, function () { return "Length of values '" + n_1 + "' does not match the size " +
+ ("inferred by the shape '" + _this.size + "'."); });
+ }
+ if (dtype === 'complex64') {
+ throw new Error("complex64 dtype TensorBuffers are not supported. Please create " +
+ "a TensorBuffer for the real and imaginary parts separately and " +
+ "call tf.complex(real, imag).");
+ }
+ this.values = values || getArrayFromDType(dtype, this.size);
+ this.strides = computeStrides(shape);
+ }
+ /**
+ * Sets a value in the buffer at a given location.
+ *
+ * @param value The value to set.
+ * @param locs The location indices.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ TensorBuffer.prototype.set = function (value) {
+ var _this = this;
+ var locs = [];
+ for (var _i = 1; _i < arguments.length; _i++) {
+ locs[_i - 1] = arguments[_i];
+ }
+ if (locs.length === 0) {
+ locs = [0];
+ }
+ assert(locs.length === this.rank, function () { return "The number of provided coordinates (" + locs.length + ") must " +
+ ("match the rank (" + _this.rank + ")"); });
+ var index = this.locToIndex(locs);
+ this.values[index] = value;
+ };
+ /**
+ * Returns the value in the buffer at the provided location.
+ *
+ * @param locs The location indices.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ TensorBuffer.prototype.get = function () {
+ var e_1, _b;
+ var locs = [];
+ for (var _i = 0; _i < arguments.length; _i++) {
+ locs[_i] = arguments[_i];
+ }
+ if (locs.length === 0) {
+ locs = [0];
+ }
+ var i = 0;
+ try {
+ for (var locs_1 = __values(locs), locs_1_1 = locs_1.next(); !locs_1_1.done; locs_1_1 = locs_1.next()) {
+ var loc = locs_1_1.value;
+ if (loc < 0 || loc >= this.shape[i]) {
+ var msg = "Requested out of range element at " + locs + ". " +
+ (" Buffer shape=" + this.shape);
+ throw new Error(msg);
+ }
+ i++;
+ }
+ }
+ catch (e_1_1) { e_1 = { error: e_1_1 }; }
+ finally {
+ try {
+ if (locs_1_1 && !locs_1_1.done && (_b = locs_1.return)) _b.call(locs_1);
+ }
+ finally { if (e_1) throw e_1.error; }
+ }
+ var index = locs[locs.length - 1];
+ for (var i_1 = 0; i_1 < locs.length - 1; ++i_1) {
+ index += this.strides[i_1] * locs[i_1];
+ }
+ return this.values[index];
+ };
+ TensorBuffer.prototype.locToIndex = function (locs) {
+ if (this.rank === 0) {
+ return 0;
+ }
+ else if (this.rank === 1) {
+ return locs[0];
+ }
+ var index = locs[locs.length - 1];
+ for (var i = 0; i < locs.length - 1; ++i) {
+ index += this.strides[i] * locs[i];
+ }
+ return index;
+ };
+ TensorBuffer.prototype.indexToLoc = function (index) {
+ if (this.rank === 0) {
+ return [];
+ }
+ else if (this.rank === 1) {
+ return [index];
+ }
+ var locs = new Array(this.shape.length);
+ for (var i = 0; i < locs.length - 1; ++i) {
+ locs[i] = Math.floor(index / this.strides[i]);
+ index -= locs[i] * this.strides[i];
+ }
+ locs[locs.length - 1] = index;
+ return locs;
+ };
+ Object.defineProperty(TensorBuffer.prototype, "rank", {
+ get: function () {
+ return this.shape.length;
+ },
+ enumerable: true,
+ configurable: true
+ });
+ /**
+ * Creates an immutable `tf.Tensor` object from the buffer.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ TensorBuffer.prototype.toTensor = function () {
+ return trackerFn().makeTensor(this.values, this.shape, this.dtype);
+ };
+ return TensorBuffer;
+ }());
+ // For tracking tensor creation and disposal.
+ var trackerFn = null;
+ // Used by chaining methods to call into ops.
+ var opHandler$1 = null;
+ /**
+ * An external consumer can register itself as the tensor tracker. This way
+ * the Tensor class can notify the tracker for every tensor created and
+ * disposed.
+ */
+ function setTensorTracker(fn) {
+ trackerFn = fn;
+ }
+ /**
+ * An external consumer can register itself as the op handler. This way the
+ * Tensor class can have chaining methods that call into ops via the op
+ * handler.
+ */
+ function setOpHandler(handler) {
+ opHandler$1 = handler;
+ }
+ /**
+ * A `tf.Tensor` object represents an immutable, multidimensional array of
+ * numbers that has a shape and a data type.
+ *
+ * For performance reasons, functions that create tensors do not necessarily
+ * perform a copy of the data passed to them (e.g. if the data is passed as a
+ * `Float32Array`), and changes to the data will change the tensor. This is not
+ * a feature and is not supported. To avoid this behavior, use the tensor before
+ * changing the input data or create a copy with `copy = tf.add(yourTensor, 0)`.
+ *
+ * See `tf.tensor` for details on how to create a `tf.Tensor`.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ var Tensor = /** @class */ (function () {
+ function Tensor(shape, dtype, dataId, id) {
+ /** Whether this tensor has been globally kept. */
+ this.kept = false;
+ this.isDisposedInternal = false;
+ this.shape = shape.slice();
+ this.dtype = dtype || 'float32';
+ this.size = sizeFromShape(shape);
+ this.strides = computeStrides(shape);
+ this.dataId = dataId;
+ this.id = id;
+ this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher');
+ }
+ Object.defineProperty(Tensor.prototype, "rank", {
+ get: function () {
+ return this.shape.length;
+ },
+ enumerable: true,
+ configurable: true
+ });
+ /**
+ * Returns a promise of `tf.TensorBuffer` that holds the underlying data.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ Tensor.prototype.buffer = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ var vals;
+ return __generator(this, function (_b) {
+ switch (_b.label) {
+ case 0: return [4 /*yield*/, this.data()];
+ case 1:
+ vals = _b.sent();
+ return [2 /*return*/, opHandler$1.buffer(this.shape, this.dtype, vals)];
+ }
+ });
+ });
+ };
+ /**
+ * Returns a `tf.TensorBuffer` that holds the underlying data.
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ Tensor.prototype.bufferSync = function () {
+ return opHandler$1.buffer(this.shape, this.dtype, this.dataSync());
+ };
+ /**
+ * Returns the tensor data as a nested array. The transfer of data is done
+ * asynchronously.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ Tensor.prototype.array = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ var vals;
+ return __generator(this, function (_b) {
+ switch (_b.label) {
+ case 0: return [4 /*yield*/, this.data()];
+ case 1:
+ vals = _b.sent();
+ return [2 /*return*/, toNestedArray(this.shape, vals, this.dtype === 'complex64')];
+ }
+ });
+ });
+ };
+ /**
+ * Returns the tensor data as a nested array. The transfer of data is done
+ * synchronously.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ Tensor.prototype.arraySync = function () {
+ return toNestedArray(this.shape, this.dataSync(), this.dtype === 'complex64');
+ };
+ /**
+ * Asynchronously downloads the values from the `tf.Tensor`. Returns a
+ * promise of `TypedArray` that resolves when the computation has finished.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ Tensor.prototype.data = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ var data, bytes;
+ return __generator(this, function (_b) {
+ switch (_b.label) {
+ case 0:
+ this.throwIfDisposed();
+ data = trackerFn().read(this.dataId);
+ if (!(this.dtype === 'string')) return [3 /*break*/, 2];
+ return [4 /*yield*/, data];
+ case 1:
+ bytes = _b.sent();
+ try {
+ return [2 /*return*/, bytes.map(function (b) { return decodeString(b); })];
+ }
+ catch (_a) {
+ throw new Error('Failed to decode the string bytes into utf-8. ' +
+ 'To get the original bytes, call tensor.bytes().');
+ }
+ _b.label = 2;
+ case 2: return [2 /*return*/, data];
+ }
+ });
+ });
+ };
+ /**
+ * Synchronously downloads the values from the `tf.Tensor`. This blocks the
+ * UI thread until the values are ready, which can cause performance issues.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ Tensor.prototype.dataSync = function () {
+ this.throwIfDisposed();
+ var data = trackerFn().readSync(this.dataId);
+ if (this.dtype === 'string') {
+ try {
+ return data.map(function (b) { return decodeString(b); });
+ }
+ catch (_a) {
+ throw new Error('Failed to decode the string bytes into utf-8. ' +
+ 'To get the original bytes, call tensor.bytes().');
+ }
+ }
+ return data;
+ };
+ /** Returns the underlying bytes of the tensor's data. */
+ Tensor.prototype.bytes = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ var data;
+ return __generator(this, function (_b) {
+ switch (_b.label) {
+ case 0:
+ this.throwIfDisposed();
+ return [4 /*yield*/, trackerFn().read(this.dataId)];
+ case 1:
+ data = _b.sent();
+ if (this.dtype === 'string') {
+ return [2 /*return*/, data];
+ }
+ else {
+ return [2 /*return*/, new Uint8Array(data.buffer)];
+ }
+ }
+ });
+ });
+ };
+ /**
+ * Disposes `tf.Tensor` from memory.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ Tensor.prototype.dispose = function () {
+ if (this.isDisposed) {
+ return;
+ }
+ trackerFn().disposeTensor(this);
+ this.isDisposedInternal = true;
+ };
+ Object.defineProperty(Tensor.prototype, "isDisposed", {
+ get: function () {
+ return this.isDisposedInternal;
+ },
+ enumerable: true,
+ configurable: true
+ });
+ Tensor.prototype.throwIfDisposed = function () {
+ if (this.isDisposed) {
+ throw new Error("Tensor is disposed.");
+ }
+ };
+ /**
+ * Prints the `tf.Tensor`. See `tf.print` for details.
+ *
+ * @param verbose Whether to print verbose information about the tensor,
+ * including dtype and size.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ Tensor.prototype.print = function (verbose) {
+ if (verbose === void 0) { verbose = false; }
+ return opHandler$1.print(this, verbose);
+ };
+ /**
+ * Returns a copy of the tensor. See `tf.clone` for details.
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ Tensor.prototype.clone = function () {
+ this.throwIfDisposed();
+ return opHandler$1.clone(this);
+ };
+ /**
+ * Returns a human-readable description of the tensor. Useful for logging.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ Tensor.prototype.toString = function (verbose) {
+ if (verbose === void 0) { verbose = false; }
+ var vals = this.dataSync();
+ return tensorToString(vals, this.shape, this.dtype, verbose);
+ };
+ Tensor.prototype.cast = function (dtype) {
+ this.throwIfDisposed();
+ return opHandler$1.cast(this, dtype);
+ };
+ Tensor.prototype.variable = function (trainable, name, dtype) {
+ if (trainable === void 0) { trainable = true; }
+ this.throwIfDisposed();
+ return trackerFn().makeVariable(this, trainable, name, dtype);
+ };
+ return Tensor;
+ }());
+ Object.defineProperty(Tensor, Symbol.hasInstance, {
+ value: function (instance) {
+ // Implementation note: we should use properties of the object that will be
+ // defined before the constructor body has finished executing (methods).
+ // This is because when this code is transpiled by babel, babel will call
+ // classCallCheck before the constructor body is run.
+ // See https://github.com/tensorflow/tfjs/issues/3384 for backstory.
+ return !!instance && instance.data != null && instance.dataSync != null &&
+ instance.throwIfDisposed != null;
+ }
+ });
+ function getGlobalTensorClass() {
+ // Use getGlobal so that we can augment the Tensor class across package
+ // boundaries becase the node resolution alg may result in different modules
+ // being returned for this file depending on the path they are loaded from.
+ return getGlobal('Tensor', function () {
+ return Tensor;
+ });
+ }
+ // Global side effect. Cache global reference to Tensor class
+ getGlobalTensorClass();
+ /**
+ * A mutable `tf.Tensor`, useful for persisting state, e.g. for training.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ var Variable = /** @class */ (function (_super) {
+ __extends(Variable, _super);
+ function Variable(initialValue, trainable, name, tensorId) {
+ var _this = _super.call(this, initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId) || this;
+ _this.trainable = trainable;
+ _this.name = name;
+ return _this;
+ }
+ /**
+ * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have
+ * the same shape and dtype as the old `tf.Tensor`.
+ *
+ * @param newValue New tensor to be assigned to this variable.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Classes'}
+ */
+ Variable.prototype.assign = function (newValue) {
+ if (newValue.dtype !== this.dtype) {
+ throw new Error("dtype of the new value (" + newValue.dtype + ") and " +
+ ("previous value (" + this.dtype + ") must match"));
+ }
+ if (!arraysEqual(newValue.shape, this.shape)) {
+ throw new Error("shape of the new value (" + newValue.shape + ") and " +
+ ("previous value (" + this.shape + ") must match"));
+ }
+ trackerFn().disposeTensor(this);
+ this.dataId = newValue.dataId;
+ trackerFn().incRef(this, null /* backend */);
+ };
+ Variable.prototype.dispose = function () {
+ trackerFn().disposeVariable(this);
+ this.isDisposedInternal = true;
+ };
+ return Variable;
+ }(Tensor));
+ Object.defineProperty(Variable, Symbol.hasInstance, {
+ value: function (instance) {
+ return instance instanceof Tensor && instance.assign != null &&
+ instance.assign instanceof Function;
+ }
+ });
+
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ exports.Rank = void 0;
+ (function (Rank) {
+ Rank["R0"] = "R0";
+ Rank["R1"] = "R1";
+ Rank["R2"] = "R2";
+ Rank["R3"] = "R3";
+ Rank["R4"] = "R4";
+ Rank["R5"] = "R5";
+ Rank["R6"] = "R6";
+ })(exports.Rank || (exports.Rank = {}));
+ // Looks for upcasting types. Used, for example, in operations with mixed dtype
+ // inputs.
+ var UpcastInt32AndMap;
+ (function (UpcastInt32AndMap) {
+ UpcastInt32AndMap["float32"] = "float32";
+ UpcastInt32AndMap["int32"] = "int32";
+ UpcastInt32AndMap["bool"] = "int32";
+ UpcastInt32AndMap["complex64"] = "complex64";
+ })(UpcastInt32AndMap || (UpcastInt32AndMap = {}));
+ var UpcastBoolAndMap;
+ (function (UpcastBoolAndMap) {
+ UpcastBoolAndMap["float32"] = "float32";
+ UpcastBoolAndMap["int32"] = "int32";
+ UpcastBoolAndMap["bool"] = "bool";
+ UpcastBoolAndMap["complex64"] = "complex64";
+ })(UpcastBoolAndMap || (UpcastBoolAndMap = {}));
+ var UpcastFloat32AndMap;
+ (function (UpcastFloat32AndMap) {
+ UpcastFloat32AndMap["float32"] = "float32";
+ UpcastFloat32AndMap["int32"] = "float32";
+ UpcastFloat32AndMap["bool"] = "float32";
+ UpcastFloat32AndMap["complex64"] = "complex64";
+ })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {}));
+ var UpcastComplex64AndMap;
+ (function (UpcastComplex64AndMap) {
+ UpcastComplex64AndMap["float32"] = "complex64";
+ UpcastComplex64AndMap["int32"] = "complex64";
+ UpcastComplex64AndMap["bool"] = "complex64";
+ UpcastComplex64AndMap["complex64"] = "complex64";
+ })(UpcastComplex64AndMap || (UpcastComplex64AndMap = {}));
+ var upcastTypeMap = {
+ 'float32': UpcastFloat32AndMap,
+ 'int32': UpcastInt32AndMap,
+ 'bool': UpcastBoolAndMap,
+ 'complex64': UpcastComplex64AndMap
+ };
+ function upcastType(typeA, typeB) {
+ if (typeA === 'string' || typeB === 'string') {
+ if (typeA === 'string' && typeB === 'string') {
+ return 'string';
+ }
+ throw new Error("Can not upcast " + typeA + " with " + typeB);
+ }
+ return upcastTypeMap[typeA][typeB];
+ }
+ /** Returns the output type after summation. */
+ function sumOutType(type) {
+ return upcastType(type, 'int32');
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function makeTypesMatch(a, b) {
+ if (a.dtype === b.dtype) {
+ return [a, b];
+ }
+ var dtype = upcastType(a.dtype, b.dtype);
+ return [a.cast(dtype), b.cast(dtype)];
+ }
+ function assertTypesMatch(a, b) {
+ assert(a.dtype === b.dtype, function () { return "The dtypes of the first(" + a.dtype + ") and" +
+ (" second(" + b.dtype + ") input must match"); });
+ }
+ function isTensorInList(tensor, tensorList) {
+ return tensorList.some(function (x) { return x.id === tensor.id; });
+ }
+ /**
+ * Extracts any `Tensor`s found within the provided object.
+ *
+ * @param container an object that may be a `Tensor` or may directly contain
+ * `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it
+ * is safe to pass any object here, except that `Promise`s are not
+ * supported.
+ * @returns An array of `Tensors` found within the passed object. If the
+ * argument is simply a `Tensor', a list containing that `Tensor` is
+ * returned. If the object is not a `Tensor` or does not
+ * contain `Tensors`, an empty list is returned.
+ */
+ function getTensorsInContainer(result) {
+ var list = [];
+ var seen = new Set();
+ walkTensorContainer(result, list, seen);
+ return list;
+ }
+ function walkTensorContainer(container, list, seen) {
+ if (container == null) {
+ return;
+ }
+ if (container instanceof Tensor) {
+ list.push(container);
+ return;
+ }
+ if (!isIterable(container)) {
+ return;
+ }
+ // Iteration over keys works also for arrays.
+ var iterable = container;
+ for (var k in iterable) {
+ var val = iterable[k];
+ if (!seen.has(val)) {
+ seen.add(val);
+ walkTensorContainer(val, list, seen);
+ }
+ }
+ }
+ // tslint:disable-next-line:no-any
+ function isIterable(obj) {
+ return Array.isArray(obj) || typeof obj === 'object';
+ }
+
+ var tensor_util = {
+ __proto__: null,
+ makeTypesMatch: makeTypesMatch,
+ assertTypesMatch: assertTypesMatch,
+ isTensorInList: isTensorInList,
+ getTensorsInContainer: getTensorsInContainer
+ };
+
+ function isRegisteredKernelInvocation(kernelInvocation) {
+ return kernelInvocation.kernelName != null;
+ }
+ var EngineState = /** @class */ (function () {
+ function EngineState() {
+ // Public since optimizers will use it.
+ this.registeredVariables = {};
+ this.nextTapeNodeId = 0;
+ this.numBytes = 0;
+ this.numTensors = 0;
+ this.numStringTensors = 0;
+ this.numDataBuffers = 0;
+ // Number of nested tf.grad() statements when computing higher-order
+ // gradients. E.g. `1` for first-order gradients and `2` for second-order
+ // gradients. Used to track if the tape should be removed after a backprop.
+ this.gradientDepth = 0;
+ // Number of nested kernel calls. When kernel depth is greater than 1, we turn
+ // off the tape.
+ this.kernelDepth = 0;
+ this.scopeStack = [];
+ /**
+ * Keeps track of the number of data moves during a kernel execution. We
+ * maintain a stack since kernels can call other kernels, recursively.
+ */
+ this.numDataMovesStack = [];
+ this.nextScopeId = 0;
+ this.tensorInfo = new WeakMap();
+ this.profiling = false;
+ this.activeProfile = {
+ newBytes: 0,
+ newTensors: 0,
+ peakBytes: 0,
+ kernels: [],
+ result: null,
+ get kernelNames() {
+ return Array.from(new Set(this.kernels.map(function (k) { return k.name; })));
+ }
+ };
+ }
+ EngineState.prototype.dispose = function () {
+ for (var variableName in this.registeredVariables) {
+ this.registeredVariables[variableName].dispose();
+ }
+ };
+ return EngineState;
+ }());
+ var Engine = /** @class */ (function () {
+ function Engine(ENV) {
+ this.ENV = ENV;
+ this.registry = {};
+ this.registryFactory = {};
+ this.pendingBackendInitId = 0;
+ this.state = new EngineState();
+ }
+ Engine.prototype.ready = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ var sortedBackends, i, backendName, success;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ if (this.pendingBackendInit != null) {
+ return [2 /*return*/, this.pendingBackendInit.then(function () { })];
+ }
+ if (this.backendInstance != null) {
+ return [2 /*return*/];
+ }
+ sortedBackends = this.getSortedBackends();
+ i = 0;
+ _a.label = 1;
+ case 1:
+ if (!(i < sortedBackends.length)) return [3 /*break*/, 5];
+ backendName = sortedBackends[i];
+ return [4 /*yield*/, this.initializeBackend(backendName).success];
+ case 2:
+ success = _a.sent();
+ if (!success) return [3 /*break*/, 4];
+ return [4 /*yield*/, this.setBackend(backendName)];
+ case 3:
+ _a.sent();
+ return [2 /*return*/];
+ case 4:
+ i++;
+ return [3 /*break*/, 1];
+ case 5: throw new Error("Could not initialize any backends, all backend initializations " +
+ "failed.");
+ }
+ });
+ });
+ };
+ Object.defineProperty(Engine.prototype, "backend", {
+ get: function () {
+ if (this.pendingBackendInit != null) {
+ throw new Error("Backend '" + this.backendName + "' has not yet been initialized. Make " +
+ "sure to await tf.ready() or await tf.setBackend() before calling " +
+ "other methods");
+ }
+ if (this.backendInstance == null) {
+ var _a = this.initializeBackendsAndReturnBest(), name = _a.name, asyncInit = _a.asyncInit;
+ if (asyncInit) {
+ throw new Error("The highest priority backend '" + name + "' has not yet been " +
+ "initialized. Make sure to await tf.ready() or " +
+ "await tf.setBackend() before calling other methods");
+ }
+ this.setBackend(name);
+ }
+ return this.backendInstance;
+ },
+ enumerable: true,
+ configurable: true
+ });
+ Engine.prototype.backendNames = function () {
+ return Object.keys(this.registryFactory);
+ };
+ Engine.prototype.findBackend = function (backendName) {
+ if (!(backendName in this.registry)) {
+ // If the backend hasn't been initialized but we have a registry entry for
+ // it, initialize it and return it.
+ if (backendName in this.registryFactory) {
+ var asyncInit = this.initializeBackend(backendName).asyncInit;
+ if (asyncInit) {
+ // Backend is not ready yet.
+ return null;
+ }
+ }
+ else {
+ return null;
+ }
+ }
+ return this.registry[backendName];
+ };
+ Engine.prototype.findBackendFactory = function (backendName) {
+ if (!(backendName in this.registryFactory)) {
+ return null;
+ }
+ return this.registryFactory[backendName].factory;
+ };
+ Engine.prototype.registerBackend = function (backendName, factory, priority) {
+ if (priority === void 0) { priority = 1; }
+ if (backendName in this.registryFactory) {
+ warn(backendName + " backend was already registered. " +
+ "Reusing existing backend factory.");
+ return false;
+ }
+ this.registryFactory[backendName] = { factory: factory, priority: priority };
+ return true;
+ };
+ Engine.prototype.setBackend = function (backendName) {
+ return __awaiter(this, void 0, void 0, function () {
+ var _a, success, asyncInit, result, _b;
+ return __generator(this, function (_c) {
+ switch (_c.label) {
+ case 0:
+ if (this.registryFactory[backendName] == null) {
+ throw new Error("Backend name '" + backendName + "' not found in registry");
+ }
+ this.backendName = backendName;
+ if (!(this.registry[backendName] == null)) return [3 /*break*/, 4];
+ this.backendInstance = null;
+ _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit;
+ if (!asyncInit) return [3 /*break*/, 2];
+ return [4 /*yield*/, success];
+ case 1:
+ _b = _c.sent();
+ return [3 /*break*/, 3];
+ case 2:
+ _b = success;
+ _c.label = 3;
+ case 3:
+ result = _b;
+ if (!result) {
+ return [2 /*return*/, false];
+ }
+ _c.label = 4;
+ case 4:
+ this.backendInstance = this.registry[backendName];
+ this.setupRegisteredKernels();
+ // Reset the profiler.
+ this.profiler = new Profiler(this.backendInstance);
+ return [2 /*return*/, true];
+ }
+ });
+ });
+ };
+ Engine.prototype.setupRegisteredKernels = function () {
+ var _this = this;
+ var kernels = getKernelsForBackend(this.backendName);
+ kernels.forEach(function (kernel) {
+ if (kernel.setupFunc != null) {
+ kernel.setupFunc(_this.backendInstance);
+ }
+ });
+ };
+ Engine.prototype.disposeRegisteredKernels = function (backendName) {
+ var _this = this;
+ var kernels = getKernelsForBackend(backendName);
+ kernels.forEach(function (kernel) {
+ if (kernel.disposeFunc != null) {
+ kernel.disposeFunc(_this.registry[backendName]);
+ }
+ });
+ };
+ /**
+ * Initializes a backend by looking up the backend name in the factory
+ * registry and calling the factory method. Returns a boolean representing
+ * whether the initialization of the backend suceeded. Throws an error if
+ * there is no backend in the factory registry.
+ */
+ Engine.prototype.initializeBackend = function (backendName) {
+ var _this = this;
+ var registryFactoryEntry = this.registryFactory[backendName];
+ if (registryFactoryEntry == null) {
+ throw new Error("Cannot initialize backend " + backendName + ", no registration found.");
+ }
+ try {
+ var backend = registryFactoryEntry.factory();
+ /* Test if the factory returns a promise.
+ Done in a more liberal way than
+ previous 'Promise.resolve(backend)===backend'
+ as we needed to account for custom Promise
+ implementations (e.g. Angular) */
+ if (backend && !(backend instanceof KernelBackend) &&
+ typeof backend.then === 'function') {
+ var promiseId_1 = ++this.pendingBackendInitId;
+ var success = backend
+ .then(function (backendInstance) {
+ // Outdated promise. Another backend was set in the meantime.
+ if (promiseId_1 < _this.pendingBackendInitId) {
+ return false;
+ }
+ _this.registry[backendName] = backendInstance;
+ _this.pendingBackendInit = null;
+ return true;
+ })
+ .catch(function (err) {
+ // Outdated promise. Another backend was set in the meantime.
+ if (promiseId_1 < _this.pendingBackendInitId) {
+ return false;
+ }
+ _this.pendingBackendInit = null;
+ warn("Initialization of backend " + backendName + " failed");
+ warn(err.stack || err.message);
+ return false;
+ });
+ this.pendingBackendInit = success;
+ return { success: success, asyncInit: true };
+ }
+ else {
+ this.registry[backendName] = backend;
+ return { success: true, asyncInit: false };
+ }
+ }
+ catch (err) {
+ warn("Initialization of backend " + backendName + " failed");
+ warn(err.stack || err.message);
+ return { success: false, asyncInit: false };
+ }
+ };
+ Engine.prototype.removeBackend = function (backendName) {
+ if (!(backendName in this.registryFactory)) {
+ throw new Error(backendName + " backend not found in registry");
+ }
+ if (this.backendName === backendName && this.pendingBackendInit != null) {
+ // There is a pending promise of the backend we want to remove. Make it
+ // obsolete.
+ this.pendingBackendInitId++;
+ }
+ if (backendName in this.registry) {
+ this.disposeRegisteredKernels(backendName);
+ this.registry[backendName].dispose();
+ delete this.registry[backendName];
+ }
+ delete this.registryFactory[backendName];
+ // Unset the backend if it is active.
+ if (this.backendName === backendName) {
+ this.pendingBackendInit = null;
+ this.backendName = null;
+ this.backendInstance = null;
+ }
+ };
+ Engine.prototype.getSortedBackends = function () {
+ var _this = this;
+ if (Object.keys(this.registryFactory).length === 0) {
+ throw new Error('No backend found in registry.');
+ }
+ return Object.keys(this.registryFactory).sort(function (a, b) {
+ // Highest priority comes first.
+ return _this.registryFactory[b].priority -
+ _this.registryFactory[a].priority;
+ });
+ };
+ Engine.prototype.initializeBackendsAndReturnBest = function () {
+ var sortedBackends = this.getSortedBackends();
+ for (var i = 0; i < sortedBackends.length; i++) {
+ var backendName = sortedBackends[i];
+ var _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit;
+ if (asyncInit || success) {
+ return { name: backendName, asyncInit: asyncInit };
+ }
+ }
+ throw new Error("Could not initialize any backends, all backend initializations " +
+ "failed.");
+ };
+ Engine.prototype.moveData = function (backend, dataId) {
+ var info = this.state.tensorInfo.get(dataId);
+ var srcBackend = info.backend;
+ var values = this.readSync(dataId);
+ var refCount = srcBackend.refCount(dataId);
+ // Delete the tensor from the old backend and move it to the new
+ // backend.
+ srcBackend.disposeData(dataId, true);
+ info.backend = backend;
+ backend.move(dataId, values, info.shape, info.dtype, refCount);
+ if (this.shouldCheckForMemLeaks()) {
+ // Track the number of moves during a kernel execution to correctly
+ // detect memory leaks.
+ this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
+ }
+ };
+ Engine.prototype.tidy = function (nameOrFn, fn) {
+ var _this = this;
+ var name = null;
+ if (fn == null) {
+ // Called with only 1 argument.
+ if (typeof nameOrFn !== 'function') {
+ throw new Error('Please provide a function to tidy()');
+ }
+ fn = nameOrFn;
+ }
+ else {
+ // Called with 2 arguments.
+ if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
+ throw new Error('When calling with two arguments, the first argument ' +
+ 'to tidy() must be a string');
+ }
+ if (typeof fn !== 'function') {
+ throw new Error('When calling with two arguments, the 2nd argument ' +
+ 'to tidy() must be a function');
+ }
+ name = nameOrFn;
+ // TODO(nsthorat,smilkov): Do operation logging and performance
+ // profiling.
+ }
+ var result;
+ return this.scopedRun(function () { return _this.startScope(name); }, function () { return _this.endScope(result); }, function () {
+ result = fn();
+ if (result instanceof Promise) {
+ console.error('Cannot return a Promise inside of tidy.');
+ }
+ return result;
+ });
+ };
+ Engine.prototype.scopedRun = function (start, end, f) {
+ start();
+ try {
+ var res = f();
+ end();
+ return res;
+ }
+ catch (ex) {
+ end();
+ throw ex;
+ }
+ };
+ Engine.prototype.nextTensorId = function () {
+ return Engine.nextTensorId++;
+ };
+ Engine.prototype.nextVariableId = function () {
+ return Engine.nextVariableId++;
+ };
+ /**
+ * This method is called instead of the public-facing tensor.clone() when
+ * saving a tensor for backwards pass. It makes sure to add the clone
+ * operation to the tape regardless of being called inside a kernel
+ * execution.
+ */
+ Engine.prototype.clone = function (x) {
+ var y = ENGINE.runKernel(Identity, { x: x });
+ var inputs = { x: x };
+ var grad = function (dy) { return ({
+ x: function () {
+ var dtype = 'float32';
+ var gradInputs = { x: dy };
+ var attrs = { dtype: dtype };
+ return ENGINE.runKernel(Cast, gradInputs,
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ attrs);
+ }
+ }); };
+ var saved = [];
+ this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
+ return y;
+ };
+ /**
+ * Execute a kernel with the given name and return the output tensor.
+ *
+ * @param kernelName The name of the kernel to execute.
+ * @param inputs A map of input names to tensors.
+ * @param attrs A map of attribute names to their values. An attribute is a
+ * primitive (non-tensor) input to the kernel.
+ * @param inputsToSave A list of tensors, inputs to save for the backprop
+ * computation.
+ * @param outputsToSave A list of booleans, specifying which output to save
+ * for the backprop computation. These are booleans since the output
+ * tensors are not visible to the user.
+ */
+ Engine.prototype.runKernel = function (kernelName, inputs, attrs) {
+ if (this.backendName == null) {
+ // backend has not been initialized yet (backend initialization is lazy
+ // can be deferred until an op/ kernel is run).
+ // The below getter has side effects that will try to initialize the
+ // backend and set properties like this.backendName
+ // tslint:disable-next-line: no-unused-expression
+ this.backend;
+ }
+ var hasKernel = getKernel(kernelName, this.backendName) != null;
+ if (!hasKernel) {
+ throw new Error("Kernel '" + kernelName + "' not registered for backend '" + this.backendName + "'");
+ }
+ return this.runKernelFunc({ kernelName: kernelName, inputs: inputs, attrs: attrs });
+ };
+ Engine.prototype.shouldCheckForMemLeaks = function () {
+ return this.ENV.getBool('IS_TEST');
+ };
+ Engine.prototype.checkKernelForMemLeak = function (kernelName, numDataIdsBefore, outInfos) {
+ var numDataIdsAfter = this.backend.numDataIds();
+ // Count the number of data ids associated with the result of the kernel.
+ var numOutputDataIds = 0;
+ outInfos.forEach(function (info) {
+ // Complex numbers allocate 3 data ids, one for 'real', one for
+ // 'imaginary', and one for the container that holds the former two.
+ numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1);
+ });
+ // Account for the number of moves during kernel execution. A "data move"
+ // can happen in the middle of a kernel execution, placing a new (key,value)
+ // pair in the data storage. Since data moves have net zero effect (we
+ // always remove the data from the old backend), we have to cancel them out
+ // when detecting memory leaks.
+ var numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
+ var dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
+ if (dataIdsLeaked > 0) {
+ throw new Error("Backend '" + this.backendName + "' has an internal memory leak " +
+ ("(" + dataIdsLeaked + " data ids) after running '" + kernelName + "'"));
+ }
+ };
+ /**
+ * Internal helper method to execute a kernel Func
+ *
+ * Use `runKernel` to execute kernels from outside of engine.
+ */
+ Engine.prototype.runKernelFunc = function (kernelParams) {
+ var _this = this;
+ var outputs;
+ var saved = [];
+ var isTapeOn = this.isTapeOn();
+ var startingBytecount = this.state.numBytes;
+ var startingNumTensors = this.state.numTensors;
+ if (this.shouldCheckForMemLeaks()) {
+ this.state.numDataMovesStack.push(0);
+ }
+ var kernelFunc;
+ if (this.backendName == null) {
+ // backend has not been initialized yet (backend initialization is lazy
+ // can be deferred until an op/ kernel is run).
+ // The below getter has side effects that will try to initialize the
+ // backend and set properties like this.backendName
+ // tslint:disable-next-line: no-unused-expression
+ this.backend;
+ }
+ var out;
+ var kernelOrScopeName = isRegisteredKernelInvocation(kernelParams) ?
+ kernelParams.kernelName :
+ this.state.activeScope != null ? this.state.activeScope.name : '';
+ // Create the kernelFunc from either a registered kernel OR passed in
+ // forward/backward functions (used by custom grad). In this context a
+ // kernelFunc wraps a kernel implementation with some bookkeeping.
+ if (isRegisteredKernelInvocation(kernelParams)) {
+ var kernelName_1 = kernelParams.kernelName, inputs_1 = kernelParams.inputs, attrs_1 = kernelParams.attrs;
+ if (this.backendName == null) {
+ // backend has not been initialized yet (backend initialization is lazy
+ // can be deferred until an op/ kernel is run).
+ // The below getter has side effects that will try to initialize the
+ // backend and set properties like this.backendName
+ // tslint:disable-next-line: no-unused-expression
+ this.backend;
+ }
+ var kernel_1 = getKernel(kernelName_1, this.backendName);
+ assert(kernel_1 != null, function () { return "Cannot find registered kernel '" + kernelName_1 + "' for backend '" + _this.backendName + "'"; });
+ kernelFunc = function () {
+ var numDataIdsBefore = _this.backend.numDataIds();
+ out = kernel_1.kernelFunc({ inputs: inputs_1, attrs: attrs_1, backend: _this.backend });
+ var outInfos = Array.isArray(out) ? out : [out];
+ if (_this.shouldCheckForMemLeaks()) {
+ _this.checkKernelForMemLeak(kernelName_1, numDataIdsBefore, outInfos);
+ }
+ var outTensors = outInfos.map(function (outInfo) {
+ // todo (yassogba) remove this option (Tensor) when node backend
+ // methods have been modularized and they all return tensorInfo.
+ // TensorInfos do not have a rank attribute.
+ if (outInfo.rank != null) {
+ return outInfo;
+ }
+ var dataId = outInfo.dataId, shape = outInfo.shape, dtype = outInfo.dtype;
+ return _this.makeTensorFromDataId(dataId, shape, dtype);
+ });
+ // Save any required inputs and outputs.
+ // Do not save unless we are recording to the tape. Otherwise it would
+ // cause a mem leak since there would be no backprop for these tensors
+ // (which would otherwise dispose them).
+ if (isTapeOn) {
+ var tensorsToSave = _this.getTensorsForGradient(kernelName_1, inputs_1, outTensors);
+ saved = _this.saveTensorsForBackwardMode(tensorsToSave);
+ }
+ return outTensors;
+ };
+ }
+ else {
+ var forwardFunc_1 = kernelParams.forwardFunc;
+ // Running a customGrad op.
+ var saveFunc_1 = function (tensors) {
+ // Do not save unless we are recording to the tape. Otherwise it would
+ // cause a mem leak since we would never run backprop, which disposes
+ // the kept tensors.
+ if (!isTapeOn) {
+ return;
+ }
+ saved = tensors.map(function (tensor) { return _this.keep(_this.clone(tensor)); });
+ };
+ kernelFunc = function () {
+ var numDataIdsBefore = _this.backend.numDataIds();
+ out = _this.tidy(function () { return forwardFunc_1(_this.backend, saveFunc_1); });
+ var outs = (Array.isArray(out) ? out : [out]);
+ if (_this.shouldCheckForMemLeaks()) {
+ // Scope name is used to print a more helpful error message if needed.
+ _this.checkKernelForMemLeak(kernelOrScopeName, numDataIdsBefore, outs);
+ }
+ return outs;
+ };
+ }
+ //
+ // Run the kernelFunc. Optionally profiling it.
+ //
+ var inputs = kernelParams.inputs, attrs = kernelParams.attrs;
+ var backwardsFunc = isRegisteredKernelInvocation(kernelParams) ?
+ null :
+ kernelParams.backwardsFunc;
+ var kernelProfile;
+ this.scopedRun(
+ // Stop recording to a tape when running a kernel.
+ function () { return _this.state.kernelDepth++; }, function () { return _this.state.kernelDepth--; }, function () {
+ if (!_this.ENV.getBool('DEBUG') && !_this.state.profiling) {
+ outputs = kernelFunc();
+ }
+ else {
+ kernelProfile = _this.profiler.profileKernel(kernelOrScopeName, inputs, function () { return kernelFunc(); });
+ if (_this.ENV.getBool('DEBUG')) {
+ _this.profiler.logKernelProfile(kernelProfile);
+ }
+ outputs = kernelProfile.outputs;
+ }
+ });
+ if (isTapeOn) {
+ this.addTapeNode(kernelOrScopeName, inputs, outputs, backwardsFunc, saved, attrs);
+ }
+ if (this.state.profiling) {
+ this.state.activeProfile.kernels.push({
+ name: kernelOrScopeName,
+ bytesAdded: this.state.numBytes - startingBytecount,
+ totalBytesSnapshot: this.state.numBytes,
+ tensorsAdded: this.state.numTensors - startingNumTensors,
+ totalTensorsSnapshot: this.state.numTensors,
+ inputShapes: Object.keys(inputs).map(function (key) { return inputs[key] != null ? inputs[key].shape : null; }),
+ outputShapes: outputs.map(function (item) { return item.shape; }),
+ kernelTimeMs: kernelProfile.timeMs,
+ extraInfo: kernelProfile.extraInfo
+ });
+ }
+ return (Array.isArray(out) ? outputs : outputs[0]);
+ };
+ /**
+ * Saves tensors used in forward mode for use in backward mode.
+ *
+ * @param tensors the list of tensors to save.
+ */
+ Engine.prototype.saveTensorsForBackwardMode = function (tensors) {
+ var _this = this;
+ var saved = tensors.map(function (tensor) { return _this.keep(_this.clone(tensor)); });
+ return saved;
+ };
+ /**
+ * Returns a list of tensors to save for a given gradient calculation.
+ *
+ * @param kernelName name of kernel to look up gradient for.
+ * @param inputs a map of input tensors.
+ * @param outputs an array of output tensors from forward mode of kernel.
+ */
+ Engine.prototype.getTensorsForGradient = function (kernelName, inputs, outputs) {
+ var gradConfig = getGradient(kernelName);
+ if (gradConfig != null) {
+ var inputsToSave = gradConfig.inputsToSave || [];
+ var outputsToSave_1 = gradConfig.outputsToSave || [];
+ // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs
+ // specified in inputsToSave will be saved.
+ var inputTensorsToSave = void 0;
+ if (gradConfig.saveAllInputs) {
+ assert(Array.isArray(inputs), function () { return 'saveAllInputs is true, expected inputs to be an array.'; });
+ inputTensorsToSave = Object.keys(inputs).map(function (key) { return inputs[key]; });
+ }
+ else {
+ inputTensorsToSave = inputsToSave.map(function (inputName) { return inputs[inputName]; });
+ }
+ var outputTensorsToSave = outputs.filter(function (_, i) { return outputsToSave_1[i]; });
+ return inputTensorsToSave.concat(outputTensorsToSave);
+ }
+ // We return an empty list rather than throw an error because the kernel we
+ // are looking up may not actually be relevant to backproping through the
+ // overall function
+ //
+ // See 'does not error if irrelevant (pruned) ops are missing grads' test
+ // in gradients_test.ts for an example.
+ return [];
+ };
+ /**
+ * Internal method used by public APIs for tensor creation. Makes a new
+ * tensor with the provided shape, dtype and values. It always
+ * creates a new data id and writes the values to the underlying backend.
+ */
+ Engine.prototype.makeTensor = function (values, shape, dtype, backend) {
+ if (values == null) {
+ throw new Error('Values passed to engine.makeTensor() are null');
+ }
+ dtype = dtype || 'float32';
+ backend = backend || this.backend;
+ var backendVals = values;
+ if (dtype === 'string' && isString(values[0])) {
+ backendVals = values.map(function (d) { return encodeString(d); });
+ }
+ var dataId = backend.write(backendVals, shape, dtype);
+ var t = new Tensor(shape, dtype, dataId, this.nextTensorId());
+ this.trackTensor(t, backend);
+ // Count bytes for string tensors.
+ if (dtype === 'string') {
+ var info = this.state.tensorInfo.get(dataId);
+ var newBytes = bytesFromStringArray(backendVals);
+ this.state.numBytes += newBytes - info.bytes;
+ info.bytes = newBytes;
+ }
+ return t;
+ };
+ /**
+ * Internal method used by backends. Makes a new tensor
+ * that is a wrapper around an existing data id. It doesn't create
+ * a new data id, only increments the ref count used in memory tracking.
+ */
+ Engine.prototype.makeTensorFromDataId = function (dataId, shape, dtype, backend) {
+ dtype = dtype || 'float32';
+ var t = new Tensor(shape, dtype, dataId, this.nextTensorId());
+ this.trackTensor(t, backend);
+ return t;
+ };
+ Engine.prototype.makeVariable = function (initialValue, trainable, name, dtype) {
+ if (trainable === void 0) { trainable = true; }
+ name = name || this.nextVariableId().toString();
+ if (dtype != null && dtype !== initialValue.dtype) {
+ initialValue = initialValue.cast(dtype);
+ }
+ var v = new Variable(initialValue, trainable, name, this.nextTensorId());
+ if (this.state.registeredVariables[v.name] != null) {
+ throw new Error("Variable with name " + v.name + " was already registered");
+ }
+ this.state.registeredVariables[v.name] = v;
+ this.incRef(v, this.backend);
+ return v;
+ };
+ Engine.prototype.trackTensor = function (a, backend) {
+ this.state.numTensors++;
+ if (a.dtype === 'string') {
+ this.state.numStringTensors++;
+ }
+ // Bytes for complex numbers are counted by their components. Bytes for
+ // string tensors are counted when writing values.
+ var bytes = 0;
+ if (a.dtype !== 'complex64' && a.dtype !== 'string') {
+ bytes = a.size * bytesPerElement(a.dtype);
+ }
+ this.state.numBytes += bytes;
+ if (!this.state.tensorInfo.has(a.dataId)) {
+ this.state.numDataBuffers++;
+ this.state.tensorInfo.set(a.dataId, {
+ backend: backend || this.backend,
+ dtype: a.dtype,
+ shape: a.shape,
+ bytes: bytes
+ });
+ }
+ if (!(a instanceof Variable)) {
+ this.track(a);
+ }
+ };
+ // Track the tensor by dataId and increase the refCount for the dataId in the
+ // backend.
+ // TODO(pyu10055): This is currently used by makeVariable method, to increase
+ // refCount on the backend for the dataId. It can potentially be replaced with
+ // Identity op indead of calling backend directly.
+ Engine.prototype.incRef = function (a, backend) {
+ this.trackTensor(a, backend);
+ this.backend.incRef(a.dataId);
+ };
+ Engine.prototype.removeDataId = function (dataId, backend) {
+ if (this.state.tensorInfo.has(dataId) &&
+ this.state.tensorInfo.get(dataId).backend === backend) {
+ this.state.tensorInfo.delete(dataId);
+ this.state.numDataBuffers--;
+ }
+ };
+ Engine.prototype.disposeTensor = function (a) {
+ if (!this.state.tensorInfo.has(a.dataId)) {
+ return;
+ }
+ var info = this.state.tensorInfo.get(a.dataId);
+ this.state.numTensors--;
+ if (a.dtype === 'string') {
+ this.state.numStringTensors--;
+ this.state.numBytes -= info.bytes;
+ }
+ // Don't count bytes for complex numbers as they are counted by their
+ // components.
+ if (a.dtype !== 'complex64' && a.dtype !== 'string') {
+ var bytes = a.size * bytesPerElement(a.dtype);
+ this.state.numBytes -= bytes;
+ }
+ // Remove the reference to dataId if backend dispose the data successfully
+ if (info.backend.disposeData(a.dataId)) {
+ this.removeDataId(a.dataId, info.backend);
+ }
+ // TODO(nsthorat): Construct an error and save the stack trace for
+ // debugging when in debug mode. Creating a stack trace is too expensive
+ // to do unconditionally.
+ };
+ Engine.prototype.disposeVariables = function () {
+ for (var varName in this.state.registeredVariables) {
+ var v = this.state.registeredVariables[varName];
+ this.disposeVariable(v);
+ }
+ };
+ Engine.prototype.disposeVariable = function (v) {
+ this.disposeTensor(v);
+ if (this.state.registeredVariables[v.name] != null) {
+ delete this.state.registeredVariables[v.name];
+ }
+ };
+ Engine.prototype.memory = function () {
+ var info = this.backend.memory();
+ info.numTensors = this.state.numTensors;
+ info.numDataBuffers = this.state.numDataBuffers;
+ info.numBytes = this.state.numBytes;
+ if (this.state.numStringTensors > 0) {
+ info.unreliable = true;
+ if (info.reasons == null) {
+ info.reasons = [];
+ }
+ info.reasons.push('Memory usage by string tensors is approximate ' +
+ '(2 bytes per character)');
+ }
+ return info;
+ };
+ Engine.prototype.profile = function (query) {
+ return __awaiter(this, void 0, void 0, function () {
+ var startBytes, startNumTensors, _a, _b, _c, kernel, _d, _e, e_1_1;
+ var e_1, _f;
+ return __generator(this, function (_g) {
+ switch (_g.label) {
+ case 0:
+ this.state.profiling = true;
+ startBytes = this.state.numBytes;
+ startNumTensors = this.state.numTensors;
+ this.state.activeProfile.kernels = [];
+ _a = this.state.activeProfile;
+ return [4 /*yield*/, query()];
+ case 1:
+ _a.result = _g.sent();
+ this.state.profiling = false;
+ this.state.activeProfile.peakBytes = Math.max.apply(Math, __spread(this.state.activeProfile.kernels.map(function (d) { return d.totalBytesSnapshot; })));
+ this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
+ this.state.activeProfile.newTensors =
+ this.state.numTensors - startNumTensors;
+ _g.label = 2;
+ case 2:
+ _g.trys.push([2, 8, 9, 10]);
+ _b = __values(this.state.activeProfile.kernels), _c = _b.next();
+ _g.label = 3;
+ case 3:
+ if (!!_c.done) return [3 /*break*/, 7];
+ kernel = _c.value;
+ _d = kernel;
+ return [4 /*yield*/, kernel.kernelTimeMs];
+ case 4:
+ _d.kernelTimeMs = _g.sent();
+ _e = kernel;
+ return [4 /*yield*/, kernel.extraInfo];
+ case 5:
+ _e.extraInfo = _g.sent();
+ _g.label = 6;
+ case 6:
+ _c = _b.next();
+ return [3 /*break*/, 3];
+ case 7: return [3 /*break*/, 10];
+ case 8:
+ e_1_1 = _g.sent();
+ e_1 = { error: e_1_1 };
+ return [3 /*break*/, 10];
+ case 9:
+ try {
+ if (_c && !_c.done && (_f = _b.return)) _f.call(_b);
+ }
+ finally { if (e_1) throw e_1.error; }
+ return [7 /*endfinally*/];
+ case 10: return [2 /*return*/, this.state.activeProfile];
+ }
+ });
+ });
+ };
+ Engine.prototype.isTapeOn = function () {
+ return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
+ };
+ Engine.prototype.addTapeNode = function (kernelName, inputs, outputs, gradientsFunc, saved, attrs) {
+ var _this = this;
+ var tapeNode = { id: this.state.nextTapeNodeId++, kernelName: kernelName, inputs: inputs, outputs: outputs, saved: saved };
+ var gradConfig = getGradient(kernelName);
+ if (gradConfig != null) {
+ gradientsFunc = gradConfig.gradFunc;
+ }
+ if (gradientsFunc != null) {
+ tapeNode.gradient = function (dys) {
+ // TODO(smilkov): To optimize back-prop, pass dys that are not used in
+ // the backprop graph to the user as null instead of zeros
+ dys = dys.map(function (dy, i) {
+ if (dy == null) {
+ var output = outputs[i];
+ var vals = makeZerosTypedArray(output.size, output.dtype);
+ return _this.makeTensor(vals, output.shape, output.dtype);
+ }
+ return dy;
+ });
+ // Grad functions of ops with single outputs expect a dy, while ops
+ // with multiple outputs expect dys (array of dy).
+ return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
+ };
+ }
+ this.state.activeTape.push(tapeNode);
+ };
+ Engine.prototype.keep = function (result) {
+ result.kept = true;
+ return result;
+ };
+ Engine.prototype.startTape = function () {
+ if (this.state.gradientDepth === 0) {
+ this.state.activeTape = [];
+ }
+ this.state.gradientDepth++;
+ };
+ Engine.prototype.endTape = function () {
+ this.state.gradientDepth--;
+ };
+ /**
+ * Start a scope. Use this with endScope() to achieve the same functionality
+ * as scope() without the need for a function closure.
+ */
+ Engine.prototype.startScope = function (name) {
+ var scopeInfo = {
+ track: [],
+ name: 'unnamed scope',
+ id: this.state.nextScopeId++
+ };
+ if (name) {
+ scopeInfo.name = name;
+ }
+ this.state.scopeStack.push(scopeInfo);
+ this.state.activeScope = scopeInfo;
+ };
+ /**
+ * End a scope. Use this with startScope() to achieve the same functionality
+ * as scope() without the need for a function closure.
+ */
+ Engine.prototype.endScope = function (result) {
+ var _this = this;
+ var tensorsToTrackInParent = getTensorsInContainer(result);
+ var tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(function (t) { return t.id; }));
+ // Dispose the arrays tracked in this scope.
+ for (var i = 0; i < this.state.activeScope.track.length; i++) {
+ var tensor = this.state.activeScope.track[i];
+ if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
+ tensor.dispose();
+ }
+ }
+ var oldScope = this.state.scopeStack.pop();
+ this.state.activeScope = this.state.scopeStack.length === 0 ?
+ null :
+ this.state.scopeStack[this.state.scopeStack.length - 1];
+ // Track the current result in the parent scope.
+ tensorsToTrackInParent.forEach(function (tensor) {
+ // Only track the tensor if was allocated in the inner scope and is not
+ // globally kept.
+ if (!tensor.kept && tensor.scopeId === oldScope.id) {
+ _this.track(tensor);
+ }
+ });
+ };
+ /**
+ * Returns gradients of `f` with respect to each of the `xs`. The gradients
+ * returned are of the same length as `xs`, but some might be null if `f`
+ * was not a function of that `x`. It also takes optional dy to multiply the
+ * gradient, which defaults to `1`.
+ */
+ Engine.prototype.gradients = function (f, xs, dy, allowNoGradients) {
+ var _this = this;
+ if (allowNoGradients === void 0) { allowNoGradients = false; }
+ assert(xs.length > 0, function () { return 'gradients() received an empty list of xs.'; });
+ if (dy != null && dy.dtype !== 'float32') {
+ throw new Error("dy must have 'float32' dtype, but has '" + dy.dtype + "'");
+ }
+ var y = this.scopedRun(function () { return _this.startTape(); }, function () { return _this.endTape(); }, function () { return _this.tidy('forward', f); });
+ assert(y instanceof Tensor, function () { return 'The result y returned by f() must be a tensor.'; });
+ // Filter out the nodes that don't connect x => y.
+ var filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
+ if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
+ throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
+ 'that the f you passed encloses all operations that lead from x ' +
+ 'to y.');
+ }
+ return this.tidy('backward', function () {
+ var accumulatedGradientMap = {};
+ accumulatedGradientMap[y.id] = (dy == null) ? ones$1(y.shape) : dy;
+ // Backprop gradients through the filtered nodes.
+ backpropagateGradients(accumulatedGradientMap, filteredTape,
+ // Pass the tidy function to avoid circular dep with `tape.ts`.
+ function (f) { return _this.tidy(f); },
+ // Pass an add function to avoide a circular dep with `tape.ts`.
+ add$1);
+ var grads = xs.map(function (x) { return accumulatedGradientMap[x.id]; });
+ if (_this.state.gradientDepth === 0) {
+ // This means that we are not computing higher-order gradients
+ // and can clean up the tape.
+ _this.state.activeTape.forEach(function (node) {
+ var e_2, _a;
+ try {
+ for (var _b = __values(node.saved), _c = _b.next(); !_c.done; _c = _b.next()) {
+ var tensor = _c.value;
+ tensor.dispose();
+ }
+ }
+ catch (e_2_1) { e_2 = { error: e_2_1 }; }
+ finally {
+ try {
+ if (_c && !_c.done && (_a = _b.return)) _a.call(_b);
+ }
+ finally { if (e_2) throw e_2.error; }
+ }
+ });
+ _this.state.activeTape = null;
+ }
+ return { value: y, grads: grads };
+ });
+ };
+ Engine.prototype.customGrad = function (f) {
+ var _this = this;
+ assert(isFunction(f), function () { return 'The f passed in customGrad(f) must be a function.'; });
+ return function () {
+ var inputs = [];
+ for (var _i = 0; _i < arguments.length; _i++) {
+ inputs[_i] = arguments[_i];
+ }
+ assert(inputs.every(function (t) { return t instanceof Tensor; }), function () { return 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
+ 'tensors'; });
+ var res;
+ var inputMap = {};
+ inputs.forEach(function (input, i) {
+ inputMap[i] = input;
+ });
+ var forwardFunc = function (_, save) {
+ res = f.apply(void 0, __spread(inputs, [save]));
+ assert(res.value instanceof Tensor, function () { return 'The function f passed in customGrad(f) must return an ' +
+ 'object where `obj.value` is a tensor'; });
+ assert(isFunction(res.gradFunc), function () { return 'The function f passed in customGrad(f) must return an ' +
+ 'object where `obj.gradFunc` is a function.'; });
+ return res.value;
+ };
+ var backwardsFunc = function (dy, saved) {
+ var gradRes = res.gradFunc(dy, saved);
+ var grads = Array.isArray(gradRes) ? gradRes : [gradRes];
+ assert(grads.length === inputs.length, function () { return 'The function f passed in customGrad(f) must return an ' +
+ 'object where `obj.gradFunc` is a function that returns ' +
+ 'the same number of tensors as inputs passed to f(...).'; });
+ assert(grads.every(function (t) { return t instanceof Tensor; }), function () { return 'The function f passed in customGrad(f) must return an ' +
+ 'object where `obj.gradFunc` is a function that returns ' +
+ 'a list of only tensors.'; });
+ var gradMap = {};
+ grads.forEach(function (grad, i) {
+ gradMap[i] = function () { return grad; };
+ });
+ return gradMap;
+ };
+ return _this.runKernelFunc({
+ forwardFunc: forwardFunc,
+ backwardsFunc: backwardsFunc,
+ inputs: inputMap,
+ });
+ };
+ };
+ Engine.prototype.readSync = function (dataId) {
+ // Route the read to the correct backend.
+ var info = this.state.tensorInfo.get(dataId);
+ return info.backend.readSync(dataId);
+ };
+ Engine.prototype.read = function (dataId) {
+ // Route the read to the correct backend.
+ var info = this.state.tensorInfo.get(dataId);
+ return info.backend.read(dataId);
+ };
+ Engine.prototype.time = function (query) {
+ return __awaiter(this, void 0, void 0, function () {
+ var start, timingInfo;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ start = now();
+ return [4 /*yield*/, this.backend.time(query)];
+ case 1:
+ timingInfo = _a.sent();
+ timingInfo.wallMs = now() - start;
+ return [2 /*return*/, timingInfo];
+ }
+ });
+ });
+ };
+ /**
+ * Tracks a Tensor in the current scope to be automatically cleaned up
+ * when the current scope ends, and returns the value.
+ *
+ * @param result The Tensor to track in the current scope.
+ */
+ Engine.prototype.track = function (result) {
+ if (this.state.activeScope != null) {
+ result.scopeId = this.state.activeScope.id;
+ this.state.activeScope.track.push(result);
+ }
+ return result;
+ };
+ Object.defineProperty(Engine.prototype, "registeredVariables", {
+ get: function () {
+ return this.state.registeredVariables;
+ },
+ enumerable: true,
+ configurable: true
+ });
+ /**
+ * Resets the engine state. Removes all backends but does not remove
+ * registered backend factories.
+ */
+ Engine.prototype.reset = function () {
+ // Make any pending promise obsolete.
+ this.pendingBackendInitId++;
+ this.state.dispose();
+ this.ENV.reset();
+ this.state = new EngineState();
+ for (var backendName in this.registry) {
+ this.disposeRegisteredKernels(backendName);
+ this.registry[backendName].dispose();
+ delete this.registry[backendName];
+ }
+ this.backendName = null;
+ this.backendInstance = null;
+ this.pendingBackendInit = null;
+ };
+ return Engine;
+ }());
+ Engine.nextTensorId = 0;
+ Engine.nextVariableId = 0;
+ function ones$1(shape) {
+ var values = makeOnesTypedArray(sizeFromShape(shape), 'float32');
+ return ENGINE.makeTensor(values, shape, 'float32');
+ }
+ function getOrMakeEngine() {
+ var ns = getGlobalNamespace();
+ if (ns._tfengine == null) {
+ var environment = new Environment(ns);
+ ns._tfengine = new Engine(environment);
+ }
+ setEnvironmentGlobal(ns._tfengine.ENV);
+ // Tell the current tensor interface that the global engine is responsible
+ // for tracking.
+ setTensorTracker(function () { return ns._tfengine; });
+ return ns._tfengine;
+ }
+ var ENGINE = getOrMakeEngine();
+ /**
+ * A implementation of the add op for use within engine and tape.
+ *
+ * This allows us to avoid a circular dependency between add.ts and engine.
+ * It is exported to be available in tape tests.
+ */
+ function add$1(a, b) {
+ // We duplicate Add here to avoid a circular dependency with add.ts.
+ var inputs = { a: a, b: b };
+ return ENGINE.runKernel(Add, inputs);
+ }
+
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ // tslint:disable-next-line:no-any
+ function _isNavigatorDefined() {
+ return typeof navigator !== 'undefined' && navigator != null;
+ }
+ var isMobileMockValue;
+ function mockIsMobile(value) {
+ isMobileMockValue = value;
+ }
+ function isMobile(nav) {
+ if (isMobileMockValue !== undefined) {
+ return isMobileMockValue;
+ }
+ if (nav || _isNavigatorDefined()) {
+ if (!nav) {
+ nav = navigator;
+ }
+ if (nav.product === 'ReactNative') {
+ return true;
+ }
+ var a = nav.userAgent || nav.vendor ||
+ // tslint:disable-next-line:no-any
+ (typeof window !== 'undefined' ? window.opera : '');
+ // Use `navigator.userAgentData.mobile` as fallback.
+ if (!a) {
+ // tslint:disable-next-line:no-any
+ var navAny = nav;
+ return navAny.userAgentData && navAny.userAgentData.mobile;
+ }
+ // tslint:disable-next-line:max-line-length
+ return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i
+ .test(a) ||
+ // tslint:disable-next-line:max-line-length
+ /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i
+ .test(a.substr(0, 4));
+ }
+ return false;
+ }
+ function isBrowser() {
+ return (typeof window !== 'undefined' && window.document != null) ||
+ //@ts-ignore
+ (typeof WorkerGlobalScope !== 'undefined');
+ }
+
+ var device_util = {
+ __proto__: null,
+ mockIsMobile: mockIsMobile,
+ isMobile: isMobile,
+ isBrowser: isBrowser
+ };
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var ENV = env();
+ /**
+ * This file contains environment-related flag registrations.
+ */
+ /** Whether to enable debug mode. */
+ ENV.registerFlag('DEBUG', function () { return false; }, function (debugValue) {
+ if (debugValue) {
+ console.warn('Debugging mode is ON. The output of every math call will ' +
+ 'be downloaded to CPU and checked for NaNs. ' +
+ 'This significantly impacts performance.');
+ }
+ });
+ /** Whether we are in a browser (as versus, say, node.js) environment. */
+ ENV.registerFlag('IS_BROWSER', function () { return isBrowser(); });
+ /** Whether we are in a browser (as versus, say, node.js) environment. */
+ ENV.registerFlag('IS_NODE', function () { return (typeof process !== 'undefined') &&
+ (typeof process.versions !== 'undefined') &&
+ (typeof process.versions.node !== 'undefined'); });
+ /** Whether this browser is Chrome. */
+ ENV.registerFlag('IS_CHROME', function () { return typeof navigator !== 'undefined' && navigator != null &&
+ navigator.userAgent != null && /Chrome/.test(navigator.userAgent) &&
+ /Google Inc/.test(navigator.vendor); });
+ /**
+ * True when the environment is "production" where we disable safety checks
+ * to gain performance.
+ */
+ ENV.registerFlag('PROD', function () { return false; });
+ /**
+ * Whether to do sanity checks when inferring a shape from user-provided
+ * values, used when creating a new tensor.
+ */
+ ENV.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', function () { return ENV.getBool('DEBUG'); });
+ /** Whether deprecation warnings are enabled. */
+ ENV.registerFlag('DEPRECATION_WARNINGS_ENABLED', function () { return true; });
+ /** True if running unit tests. */
+ ENV.registerFlag('IS_TEST', function () { return false; });
+ /** Whether to check computation result for errors. */
+ ENV.registerFlag('CHECK_COMPUTATION_FOR_ERRORS', function () { return true; });
+ /** Whether the backend needs to wrap input to imageBitmap. */
+ ENV.registerFlag('WRAP_TO_IMAGEBITMAP', function () { return false; });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function inferShape(val, dtype) {
+ var firstElem = val;
+ if (isTypedArray(val)) {
+ return dtype === 'string' ? [] : [val.length];
+ }
+ if (!Array.isArray(val)) {
+ return []; // Scalar.
+ }
+ var shape = [];
+ while (Array.isArray(firstElem) ||
+ isTypedArray(firstElem) && dtype !== 'string') {
+ shape.push(firstElem.length);
+ firstElem = firstElem[0];
+ }
+ if (Array.isArray(val) &&
+ env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {
+ deepAssertShapeConsistency(val, shape, []);
+ }
+ return shape;
+ }
+ function deepAssertShapeConsistency(val, shape, indices) {
+ indices = indices || [];
+ if (!(Array.isArray(val)) && !isTypedArray(val)) {
+ assert(shape.length === 0, function () { return "Element arr[" + indices.join('][') + "] is a primitive, " +
+ ("but should be an array/TypedArray of " + shape[0] + " elements"); });
+ return;
+ }
+ assert(shape.length > 0, function () { return "Element arr[" + indices.join('][') + "] should be a primitive, " +
+ ("but is an array of " + val.length + " elements"); });
+ assert(val.length === shape[0], function () { return "Element arr[" + indices.join('][') + "] should have " + shape[0] + " " +
+ ("elements, but has " + val.length + " elements"); });
+ var subShape = shape.slice(1);
+ for (var i = 0; i < val.length; ++i) {
+ deepAssertShapeConsistency(val[i], subShape, indices.concat(i));
+ }
+ }
+ function assertDtype(expectedDtype, actualDType, argName, functionName) {
+ if (expectedDtype === 'string_or_numeric') {
+ return;
+ }
+ if (expectedDtype == null) {
+ throw new Error("Expected dtype cannot be null.");
+ }
+ if (expectedDtype !== 'numeric' && expectedDtype !== actualDType ||
+ expectedDtype === 'numeric' && actualDType === 'string') {
+ throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must " +
+ ("be " + expectedDtype + " tensor, but got " + actualDType + " tensor"));
+ }
+ }
+ function convertToTensor(x, argName, functionName, parseAsDtype) {
+ if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; }
+ if (x instanceof Tensor) {
+ assertDtype(parseAsDtype, x.dtype, argName, functionName);
+ return x;
+ }
+ var inferredDtype = inferDtype(x);
+ // If the user expects a bool/int/float, use that info to update the
+ // inferredDtype when it is not a string.
+ if (inferredDtype !== 'string' &&
+ ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {
+ inferredDtype = parseAsDtype;
+ }
+ assertDtype(parseAsDtype, inferredDtype, argName, functionName);
+ if ((x == null) ||
+ (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' &&
+ typeof x !== 'boolean' && typeof x !== 'string')) {
+ var type = x == null ? 'null' : x.constructor.name;
+ throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must be a " +
+ ("Tensor or TensorLike, but got '" + type + "'"));
+ }
+ var inferredShape = inferShape(x, inferredDtype);
+ if (!isTypedArray(x) && !Array.isArray(x)) {
+ x = [x];
+ }
+ var skipTypedArray = true;
+ var values = inferredDtype !== 'string' ?
+ toTypedArray(x, inferredDtype) :
+ flatten(x, [], skipTypedArray);
+ return ENGINE.makeTensor(values, inferredShape, inferredDtype);
+ }
+ function convertToTensorArray(arg, argName, functionName, parseAsDtype) {
+ if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; }
+ if (!Array.isArray(arg)) {
+ throw new Error("Argument " + argName + " passed to " + functionName + " must be a " +
+ '`Tensor[]` or `TensorLike[]`');
+ }
+ var tensors = arg;
+ return tensors.map(function (t, i) { return convertToTensor(t, argName + "[" + i + "]", functionName, parseAsDtype); });
+ }
+
+ var OP_SCOPE_SUFFIX = '__op';
+ /**
+ * Used for wrapping functions that perform math operations on
+ * Tensors. The function will be wrapped in a named scope that cleans all
+ * memory usage after the function is done.
+ */
+ function op(f) {
+ var keys = Object.keys(f);
+ if (keys.length !== 1) {
+ throw new Error("Please provide an object with a single key " +
+ "(operation name) mapping to a function. Got an object with " +
+ (keys.length + " keys."));
+ }
+ var opName = keys[0];
+ var fn = f[opName];
+ // Strip the underscore from the end of the function name.
+ if (opName.endsWith('_')) {
+ opName = opName.substring(0, opName.length - 1);
+ }
+ // add an __op suffix to distinguish ops from kernels in tf.profile
+ opName = opName + OP_SCOPE_SUFFIX;
+ // tslint:disable-next-line:no-any
+ var f2 = function () {
+ var args = [];
+ for (var _i = 0; _i < arguments.length; _i++) {
+ args[_i] = arguments[_i];
+ }
+ ENGINE.startScope(opName);
+ try {
+ var result = fn.apply(void 0, __spread(args));
+ if (isPromise(result)) {
+ console.error('Cannot return a Promise inside of tidy.');
+ }
+ ENGINE.endScope(result);
+ return result;
+ }
+ catch (ex) {
+ ENGINE.endScope(null);
+ throw ex;
+ }
+ };
+ Object.defineProperty(f2, 'name', { value: opName, configurable: true });
+ // tslint:disable-next-line:no-any
+ return f2;
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Converts two real numbers to a complex number.
+ *
+ * Given a tensor `real` representing the real part of a complex number, and a
+ * tensor `imag` representing the imaginary part of a complex number, this
+ * operation returns complex numbers elementwise of the form [r0, i0, r1, i1],
+ * where r represents the real part and i represents the imag part.
+ *
+ * The input tensors real and imag must have the same shape.
+ *
+ * ```js
+ * const real = tf.tensor1d([2.25, 3.25]);
+ * const imag = tf.tensor1d([4.75, 5.75]);
+ * const complex = tf.complex(real, imag);
+ *
+ * complex.print();
+ * ```
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function complex_(real, imag) {
+ var $real = convertToTensor(real, 'real', 'complex');
+ var $imag = convertToTensor(imag, 'imag', 'complex');
+ assertShapesMatch($real.shape, $imag.shape, "real and imag shapes, " + $real.shape + " and " + $imag.shape + ", " +
+ "must match in call to tf.complex().");
+ var inputs = { real: $real, imag: $imag };
+ return ENGINE.runKernel(Complex, inputs);
+ }
+ var complex = op({ complex_: complex_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /** This is shared code across all tensor creation methods. */
+ function makeTensor(values, shape, inferredShape, dtype) {
+ if (dtype == null) {
+ dtype = inferDtype(values);
+ }
+ if (dtype === 'complex64') {
+ throw new Error("Cannot construct a complex64 tensor directly. " +
+ "Please use tf.complex(real, imag).");
+ }
+ if (!isTypedArray(values) && !Array.isArray(values) &&
+ typeof values !== 'number' && typeof values !== 'boolean' &&
+ typeof values !== 'string') {
+ throw new Error('values passed to tensor(values) must be a number/boolean/string or ' +
+ 'an array of numbers/booleans/strings, or a TypedArray');
+ }
+ if (shape != null) {
+ assertNonNegativeIntegerDimensions(shape);
+ var providedSize_1 = sizeFromShape(shape);
+ var inferredSize_1 = sizeFromShape(inferredShape);
+ assert(providedSize_1 === inferredSize_1, function () { return "Based on the provided shape, [" + shape + "], the tensor should have " +
+ (providedSize_1 + " values but has " + inferredSize_1); });
+ for (var i = 0; i < inferredShape.length; ++i) {
+ var inferred = inferredShape[i];
+ var flatDimsDontMatch = i === inferredShape.length - 1 ?
+ inferred !== sizeFromShape(shape.slice(i)) :
+ true;
+ assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, function () { return "Error creating a new Tensor. Inferred shape " +
+ ("(" + inferredShape + ") does not match the provided ") +
+ ("shape (" + shape + "). "); });
+ }
+ }
+ if (!isTypedArray(values) && !Array.isArray(values)) {
+ values = [values];
+ }
+ shape = shape || inferredShape;
+ values = dtype !== 'string' ?
+ toTypedArray(values, dtype) :
+ flatten(values, [], true);
+ return ENGINE.makeTensor(values, shape, dtype);
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a `tf.Tensor` with the provided values, shape and dtype.
+ *
+ * ```js
+ * // Pass an array of values to create a vector.
+ * tf.tensor([1, 2, 3, 4]).print();
+ * ```
+ *
+ * ```js
+ * // Pass a nested array of values to make a matrix or a higher
+ * // dimensional tensor.
+ * tf.tensor([[1, 2], [3, 4]]).print();
+ * ```
+ *
+ * ```js
+ * // Pass a flat array and specify a shape yourself.
+ * tf.tensor([1, 2, 3, 4], [2, 2]).print();
+ * ```
+ *
+ * @param values The values of the tensor. Can be nested array of numbers,
+ * or a flat array, or a `TypedArray`. If the values are strings,
+ * they will be encoded as utf-8 and kept as `Uint8Array[]`.
+ * @param shape The shape of the tensor. Optional. If not provided,
+ * it is inferred from `values`.
+ * @param dtype The data type.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function tensor(values, shape, dtype) {
+ var inferredShape = inferShape(values, dtype);
+ return makeTensor(values, shape, inferredShape, dtype);
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /* Type definitions for exporting and importing of models. */
+ /**
+ * A map from Tensor dtype to number of bytes per element of the Tensor.
+ */
+ var DTYPE_VALUE_SIZE_MAP = {
+ 'float32': 4,
+ 'float16': 2,
+ 'int32': 4,
+ 'uint16': 2,
+ 'uint8': 1,
+ 'bool': 1,
+ 'complex64': 8
+ };
+
+ /** Number of bytes reserved for the length of the string. (32bit integer). */
+ var NUM_BYTES_STRING_LENGTH = 4;
+ /**
+ * Encode a map from names to weight values as an ArrayBuffer, along with an
+ * `Array` of `WeightsManifestEntry` as specification of the encoded weights.
+ *
+ * This function does not perform sharding.
+ *
+ * This function is the reverse of `decodeWeights`.
+ *
+ * @param tensors A map ("dict") from names to tensors.
+ * @param group Group to which the weights belong (optional).
+ * @returns A `Promise` of
+ * - A flat `ArrayBuffer` with all the binary values of the `Tensor`s
+ * concatenated.
+ * - An `Array` of `WeightManifestEntry`s, carrying information including
+ * tensor names, `dtype`s and shapes.
+ * @throws Error: on unsupported tensor `dtype`.
+ */
+ function encodeWeights(tensors, group) {
+ return __awaiter(this, void 0, void 0, function () {
+ var specs, dataPromises, names, _loop_1, i, tensorValues;
+ var _this = this;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ specs = [];
+ dataPromises = [];
+ names = Array.isArray(tensors) ?
+ tensors.map(function (tensor) { return tensor.name; }) :
+ Object.keys(tensors);
+ _loop_1 = function (i) {
+ var name = names[i];
+ var t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name];
+ if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' &&
+ t.dtype !== 'string' && t.dtype !== 'complex64') {
+ throw new Error("Unsupported dtype in weight '" + name + "': " + t.dtype);
+ }
+ var spec = { name: name, shape: t.shape, dtype: t.dtype };
+ if (t.dtype === 'string') {
+ var utf8bytes = new Promise(function (resolve) { return __awaiter(_this, void 0, void 0, function () {
+ var vals, totalNumBytes, bytes, offset, i_1, val, bytesOfLength;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0: return [4 /*yield*/, t.bytes()];
+ case 1:
+ vals = _a.sent();
+ totalNumBytes = vals.reduce(function (p, c) { return p + c.length; }, 0) +
+ NUM_BYTES_STRING_LENGTH * vals.length;
+ bytes = new Uint8Array(totalNumBytes);
+ offset = 0;
+ for (i_1 = 0; i_1 < vals.length; i_1++) {
+ val = vals[i_1];
+ bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer);
+ bytes.set(bytesOfLength, offset);
+ offset += NUM_BYTES_STRING_LENGTH;
+ bytes.set(val, offset);
+ offset += val.length;
+ }
+ resolve(bytes);
+ return [2 /*return*/];
+ }
+ });
+ }); });
+ dataPromises.push(utf8bytes);
+ }
+ else {
+ dataPromises.push(t.data());
+ }
+ if (group != null) {
+ spec.group = group;
+ }
+ specs.push(spec);
+ };
+ for (i = 0; i < names.length; ++i) {
+ _loop_1(i);
+ }
+ return [4 /*yield*/, Promise.all(dataPromises)];
+ case 1:
+ tensorValues = _a.sent();
+ return [2 /*return*/, { data: concatenateTypedArrays(tensorValues), specs: specs }];
+ }
+ });
+ });
+ }
+ /**
+ * Decode flat ArrayBuffer as weights.
+ *
+ * This function does not handle sharding.
+ *
+ * This function is the reverse of `encodeWeights`.
+ *
+ * @param buffer A flat ArrayBuffer carrying the binary values of the tensors
+ * concatenated in the order specified in `specs`.
+ * @param specs Specifications of the names, dtypes and shapes of the tensors
+ * whose value are encoded by `buffer`.
+ * @return A map from tensor name to tensor value, with the names corresponding
+ * to names in `specs`.
+ * @throws Error, if any of the tensors has unsupported dtype.
+ */
+ function decodeWeights(buffer, specs) {
+ var e_1, _a;
+ // TODO(adarob, cais): Support quantization.
+ var out = {};
+ var float16Decode;
+ var offset = 0;
+ try {
+ for (var specs_1 = __values(specs), specs_1_1 = specs_1.next(); !specs_1_1.done; specs_1_1 = specs_1.next()) {
+ var spec = specs_1_1.value;
+ var name = spec.name;
+ var dtype = spec.dtype;
+ var shape = spec.shape;
+ var size = sizeFromShape(shape);
+ var values = void 0;
+ if ('quantization' in spec) {
+ var quantization = spec.quantization;
+ if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
+ if (!('min' in quantization && 'scale' in quantization)) {
+ throw new Error("Weight " + spec.name + " with quantization " + quantization.dtype + " " +
+ "doesn't have corresponding metadata min and scale.");
+ }
+ }
+ else if (quantization.dtype === 'float16') {
+ if (dtype !== 'float32') {
+ throw new Error("Weight " + spec.name + " is quantized with " + quantization.dtype + " " +
+ ("which only supports weights of type float32 not " + dtype + "."));
+ }
+ }
+ else {
+ throw new Error("Weight " + spec.name + " has unknown " +
+ ("quantization dtype " + quantization.dtype + ". ") +
+ "Supported quantization dtypes are: " +
+ "'uint8', 'uint16', and 'float16'.");
+ }
+ var quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
+ var byteBuffer = buffer.slice(offset, offset + size * quantizationSizeFactor);
+ var quantizedArray = (quantization.dtype === 'uint8') ?
+ new Uint8Array(byteBuffer) :
+ new Uint16Array(byteBuffer);
+ if (dtype === 'float32') {
+ if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
+ values = new Float32Array(quantizedArray.length);
+ for (var i = 0; i < quantizedArray.length; i++) {
+ var v = quantizedArray[i];
+ values[i] = v * quantization.scale + quantization.min;
+ }
+ }
+ else if (quantization.dtype === 'float16') {
+ if (float16Decode === undefined) {
+ float16Decode = getFloat16Decoder();
+ }
+ values = float16Decode(quantizedArray);
+ }
+ else {
+ throw new Error("Unsupported quantization type " + quantization.dtype + " " +
+ "for weight type float32.");
+ }
+ }
+ else if (dtype === 'int32') {
+ if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
+ throw new Error("Unsupported quantization type " + quantization.dtype + " " +
+ "for weight type int32.");
+ }
+ values = new Int32Array(quantizedArray.length);
+ for (var i = 0; i < quantizedArray.length; i++) {
+ var v = quantizedArray[i];
+ values[i] = Math.round(v * quantization.scale + quantization.min);
+ }
+ }
+ else {
+ throw new Error("Unsupported dtype in weight '" + name + "': " + dtype);
+ }
+ offset += size * quantizationSizeFactor;
+ }
+ else if (dtype === 'string') {
+ var size_1 = sizeFromShape(spec.shape);
+ values = [];
+ for (var i = 0; i < size_1; i++) {
+ var byteLength = new Uint32Array(buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
+ offset += NUM_BYTES_STRING_LENGTH;
+ var bytes = new Uint8Array(buffer.slice(offset, offset + byteLength));
+ values.push(bytes);
+ offset += byteLength;
+ }
+ }
+ else {
+ var dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];
+ var byteBuffer = buffer.slice(offset, offset + size * dtypeFactor);
+ if (dtype === 'float32') {
+ values = new Float32Array(byteBuffer);
+ }
+ else if (dtype === 'int32') {
+ values = new Int32Array(byteBuffer);
+ }
+ else if (dtype === 'bool') {
+ values = new Uint8Array(byteBuffer);
+ }
+ else if (dtype === 'complex64') {
+ values = new Float32Array(byteBuffer);
+ var real = new Float32Array(values.length / 2);
+ var image = new Float32Array(values.length / 2);
+ for (var i = 0; i < real.length; i++) {
+ real[i] = values[i * 2];
+ image[i] = values[i * 2 + 1];
+ }
+ var realTensor = tensor(real, shape, 'float32');
+ var imageTensor = tensor(image, shape, 'float32');
+ out[name] = complex(realTensor, imageTensor);
+ realTensor.dispose();
+ imageTensor.dispose();
+ }
+ else {
+ throw new Error("Unsupported dtype in weight '" + name + "': " + dtype);
+ }
+ offset += size * dtypeFactor;
+ }
+ if (dtype !== 'complex64') {
+ out[name] = tensor(values, shape, dtype);
+ }
+ }
+ }
+ catch (e_1_1) { e_1 = { error: e_1_1 }; }
+ finally {
+ try {
+ if (specs_1_1 && !specs_1_1.done && (_a = specs_1.return)) _a.call(specs_1);
+ }
+ finally { if (e_1) throw e_1.error; }
+ }
+ return out;
+ }
+ /**
+ * Concatenate TypedArrays into an ArrayBuffer.
+ */
+ function concatenateTypedArrays(xs) {
+ // TODO(adarob, cais): Support quantization.
+ if (xs === null) {
+ throw new Error("Invalid input value: " + JSON.stringify(xs));
+ }
+ var totalByteLength = 0;
+ // `normalizedXs` is here for this reason: a `TypedArray`'s `buffer'
+ // can have a different byte length from that of the `TypedArray` itself,
+ // for example, when the `TypedArray` is created from an offset in an
+ // `ArrayBuffer`. `normliazedXs` holds `TypedArray`s whose `buffer`s match
+ // the `TypedArray` in byte length. If an element of `xs` does not show
+ // this property, a new `TypedArray` that satisfy this property will be
+ // constructed and pushed into `normalizedXs`.
+ var normalizedXs = [];
+ xs.forEach(function (x) {
+ totalByteLength += x.byteLength;
+ // tslint:disable:no-any
+ normalizedXs.push(x.byteLength === x.buffer.byteLength ? x :
+ new x.constructor(x));
+ if (!(x instanceof Float32Array || x instanceof Int32Array ||
+ x instanceof Uint8Array)) {
+ throw new Error("Unsupported TypedArray subtype: " + x.constructor.name);
+ }
+ // tslint:enable:no-any
+ });
+ var y = new Uint8Array(totalByteLength);
+ var offset = 0;
+ normalizedXs.forEach(function (x) {
+ y.set(new Uint8Array(x.buffer), offset);
+ offset += x.byteLength;
+ });
+ return y.buffer;
+ }
+ // Use Buffer on Node.js instead of Blob/atob/btoa
+ var useNodeBuffer = typeof Buffer !== 'undefined' &&
+ (typeof Blob === 'undefined' || typeof atob === 'undefined' ||
+ typeof btoa === 'undefined');
+ /**
+ * Calculate the byte length of a JavaScript string.
+ *
+ * Note that a JavaScript string can contain wide characters, therefore the
+ * length of the string is not necessarily equal to the byte length.
+ *
+ * @param str Input string.
+ * @returns Byte length.
+ */
+ function stringByteLength(str) {
+ if (useNodeBuffer) {
+ return Buffer.byteLength(str);
+ }
+ return new Blob([str]).size;
+ }
+ /**
+ * Encode an ArrayBuffer as a base64 encoded string.
+ *
+ * @param buffer `ArrayBuffer` to be converted.
+ * @returns A string that base64-encodes `buffer`.
+ */
+ function arrayBufferToBase64String(buffer) {
+ if (useNodeBuffer) {
+ return Buffer.from(buffer).toString('base64');
+ }
+ var buf = new Uint8Array(buffer);
+ var s = '';
+ for (var i = 0, l = buf.length; i < l; i++) {
+ s += String.fromCharCode(buf[i]);
+ }
+ return btoa(s);
+ }
+ /**
+ * Decode a base64 string as an ArrayBuffer.
+ *
+ * @param str Base64 string.
+ * @returns Decoded `ArrayBuffer`.
+ */
+ function base64StringToArrayBuffer(str) {
+ if (useNodeBuffer) {
+ var buf = Buffer.from(str, 'base64');
+ return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
+ }
+ var s = atob(str);
+ var buffer = new Uint8Array(s.length);
+ for (var i = 0; i < s.length; ++i) {
+ buffer.set([s.charCodeAt(i)], i);
+ }
+ return buffer.buffer;
+ }
+ /**
+ * Concatenate a number of ArrayBuffers into one.
+ *
+ * @param buffers A number of array buffers to concatenate.
+ * @returns Result of concatenating `buffers` in order.
+ */
+ function concatenateArrayBuffers(buffers) {
+ if (buffers.length === 1) {
+ return buffers[0];
+ }
+ var totalByteLength = 0;
+ buffers.forEach(function (buffer) {
+ totalByteLength += buffer.byteLength;
+ });
+ var temp = new Uint8Array(totalByteLength);
+ var offset = 0;
+ buffers.forEach(function (buffer) {
+ temp.set(new Uint8Array(buffer), offset);
+ offset += buffer.byteLength;
+ });
+ return temp.buffer;
+ }
+ /**
+ * Get the basename of a path.
+ *
+ * Behaves in a way analogous to Linux's basename command.
+ *
+ * @param path
+ */
+ function basename(path) {
+ var SEPARATOR = '/';
+ path = path.trim();
+ while (path.endsWith(SEPARATOR)) {
+ path = path.slice(0, path.length - 1);
+ }
+ var items = path.split(SEPARATOR);
+ return items[items.length - 1];
+ }
+ /**
+ * Create `ModelJSON` from `ModelArtifacts`.
+ *
+ * @param artifacts Model artifacts, describing the model and its weights.
+ * @param manifest Weight manifest, describing where the weights of the
+ * `ModelArtifacts` are stored, and some metadata about them.
+ * @returns Object representing the `model.json` file describing the model
+ * artifacts and weights
+ */
+ function getModelJSONForModelArtifacts(artifacts, manifest) {
+ var result = {
+ modelTopology: artifacts.modelTopology,
+ format: artifacts.format,
+ generatedBy: artifacts.generatedBy,
+ convertedBy: artifacts.convertedBy,
+ weightsManifest: manifest
+ };
+ if (artifacts.signature != null) {
+ result.signature = artifacts.signature;
+ }
+ if (artifacts.userDefinedMetadata != null) {
+ result.userDefinedMetadata = artifacts.userDefinedMetadata;
+ }
+ if (artifacts.modelInitializer != null) {
+ result.modelInitializer = artifacts.modelInitializer;
+ }
+ if (artifacts.trainingConfig != null) {
+ result.trainingConfig = artifacts.trainingConfig;
+ }
+ return result;
+ }
+ /**
+ * Create `ModelArtifacts` from a JSON file.
+ *
+ * @param modelJSON Object containing the parsed JSON of `model.json`
+ * @param loadWeights Function that takes the JSON file's weights manifest,
+ * reads weights from the listed path(s), and returns a Promise of the
+ * weight manifest entries along with the weights data.
+ * @returns A Promise of the `ModelArtifacts`, as described by the JSON file.
+ */
+ function getModelArtifactsForJSON(modelJSON, loadWeights) {
+ return __awaiter(this, void 0, void 0, function () {
+ var modelArtifacts, _a, weightSpecs, weightData;
+ return __generator(this, function (_b) {
+ switch (_b.label) {
+ case 0:
+ modelArtifacts = {
+ modelTopology: modelJSON.modelTopology,
+ format: modelJSON.format,
+ generatedBy: modelJSON.generatedBy,
+ convertedBy: modelJSON.convertedBy
+ };
+ if (modelJSON.trainingConfig != null) {
+ modelArtifacts.trainingConfig = modelJSON.trainingConfig;
+ }
+ if (!(modelJSON.weightsManifest != null)) return [3 /*break*/, 2];
+ return [4 /*yield*/, loadWeights(modelJSON.weightsManifest)];
+ case 1:
+ _a = __read.apply(void 0, [_b.sent(), 2]), weightSpecs = _a[0], weightData = _a[1];
+ modelArtifacts.weightSpecs = weightSpecs;
+ modelArtifacts.weightData = weightData;
+ _b.label = 2;
+ case 2:
+ if (modelJSON.signature != null) {
+ modelArtifacts.signature = modelJSON.signature;
+ }
+ if (modelJSON.userDefinedMetadata != null) {
+ modelArtifacts.userDefinedMetadata = modelJSON.userDefinedMetadata;
+ }
+ if (modelJSON.modelInitializer != null) {
+ modelArtifacts.modelInitializer = modelJSON.modelInitializer;
+ }
+ return [2 /*return*/, modelArtifacts];
+ }
+ });
+ });
+ }
+ /**
+ * Populate ModelArtifactsInfo fields for a model with JSON topology.
+ * @param modelArtifacts
+ * @returns A ModelArtifactsInfo object.
+ */
+ function getModelArtifactsInfoForJSON(modelArtifacts) {
+ if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
+ throw new Error('Expected JSON model topology, received ArrayBuffer.');
+ }
+ return {
+ dateSaved: new Date(),
+ modelTopologyType: 'JSON',
+ modelTopologyBytes: modelArtifacts.modelTopology == null ?
+ 0 :
+ stringByteLength(JSON.stringify(modelArtifacts.modelTopology)),
+ weightSpecsBytes: modelArtifacts.weightSpecs == null ?
+ 0 :
+ stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)),
+ weightDataBytes: modelArtifacts.weightData == null ?
+ 0 :
+ modelArtifacts.weightData.byteLength,
+ };
+ }
+ /**
+ * Computes mantisa table for casting Float16 to Float32
+ * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
+ *
+ * @returns Uint32Array, 2048 mantissa lookup values.
+ */
+ function computeFloat16MantisaTable() {
+ var convertMantissa = function (i) {
+ var m = i << 13;
+ var e = 0;
+ while ((m & 0x00800000) === 0) {
+ e -= 0x00800000;
+ m <<= 1;
+ }
+ m &= ~0x00800000;
+ e += 0x38800000;
+ return m | e;
+ };
+ var mantisaTable = new Uint32Array(2048);
+ mantisaTable[0] = 0;
+ for (var i = 1; i < 1024; i++) {
+ mantisaTable[i] = convertMantissa(i);
+ }
+ for (var i = 1024; i < 2048; i++) {
+ mantisaTable[i] = 0x38000000 + ((i - 1024) << 13);
+ }
+ return mantisaTable;
+ }
+ /**
+ * Computes exponent table for casting Float16 to Float32
+ * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
+ *
+ * @returns Uint32Array, 64 exponent lookup values.
+ */
+ function computeFloat16ExponentTable() {
+ var exponentTable = new Uint32Array(64);
+ exponentTable[0] = 0;
+ exponentTable[31] = 0x47800000;
+ exponentTable[32] = 0x80000000;
+ exponentTable[63] = 0xc7800000;
+ for (var i = 1; i < 31; i++) {
+ exponentTable[i] = i << 23;
+ }
+ for (var i = 33; i < 63; i++) {
+ exponentTable[i] = 0x80000000 + ((i - 32) << 23);
+ }
+ return exponentTable;
+ }
+ /**
+ * Computes offset table for casting Float16 to Float32
+ * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
+ *
+ * @returns Uint32Array, 6d offset values.
+ */
+ function computeFloat16OffsetTable() {
+ var offsetTable = new Uint32Array(64);
+ for (var i = 0; i < 64; i++) {
+ offsetTable[i] = 1024;
+ }
+ offsetTable[0] = offsetTable[32] = 0;
+ return offsetTable;
+ }
+ /**
+ * Retrieve a Float16 decoder which will decode a ByteArray of Float16 values
+ * to a Float32Array.
+ *
+ * @returns Function (buffer: Uint16Array) => Float32Array which decodes
+ * the Uint16Array of Float16 bytes to a Float32Array.
+ */
+ function getFloat16Decoder() {
+ // Algorithm is based off of
+ // http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
+ // Cache lookup tables
+ var mantisaTable = computeFloat16MantisaTable();
+ var exponentTable = computeFloat16ExponentTable();
+ var offsetTable = computeFloat16OffsetTable();
+ return function (quantizedArray) {
+ var buffer = new ArrayBuffer(4 * quantizedArray.length);
+ var bufferUint32View = new Uint32Array(buffer);
+ for (var index = 0; index < quantizedArray.length; index++) {
+ var float16Bits = quantizedArray[index];
+ var float32Bits = mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 0x3ff)] +
+ exponentTable[float16Bits >> 10];
+ bufferUint32View[index] = float32Bits;
+ }
+ return new Float32Array(buffer);
+ };
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var IORouterRegistry = /** @class */ (function () {
+ function IORouterRegistry() {
+ this.saveRouters = [];
+ this.loadRouters = [];
+ }
+ IORouterRegistry.getInstance = function () {
+ if (IORouterRegistry.instance == null) {
+ IORouterRegistry.instance = new IORouterRegistry();
+ }
+ return IORouterRegistry.instance;
+ };
+ /**
+ * Register a save-handler router.
+ *
+ * @param saveRouter A function that maps a URL-like string onto an instance
+ * of `IOHandler` with the `save` method defined or `null`.
+ */
+ IORouterRegistry.registerSaveRouter = function (saveRouter) {
+ IORouterRegistry.getInstance().saveRouters.push(saveRouter);
+ };
+ /**
+ * Register a load-handler router.
+ *
+ * @param loadRouter A function that maps a URL-like string onto an instance
+ * of `IOHandler` with the `load` method defined or `null`.
+ */
+ IORouterRegistry.registerLoadRouter = function (loadRouter) {
+ IORouterRegistry.getInstance().loadRouters.push(loadRouter);
+ };
+ /**
+ * Look up IOHandler for saving, given a URL-like string.
+ *
+ * @param url
+ * @returns If only one match is found, an instance of IOHandler with the
+ * `save` method defined. If no match is found, `null`.
+ * @throws Error, if more than one match is found.
+ */
+ IORouterRegistry.getSaveHandlers = function (url) {
+ return IORouterRegistry.getHandlers(url, 'save');
+ };
+ /**
+ * Look up IOHandler for loading, given a URL-like string.
+ *
+ * @param url
+ * @param loadOptions Optional, custom load options.
+ * @returns All valid handlers for `url`, given the currently registered
+ * handler routers.
+ */
+ IORouterRegistry.getLoadHandlers = function (url, loadOptions) {
+ return IORouterRegistry.getHandlers(url, 'load', loadOptions);
+ };
+ IORouterRegistry.getHandlers = function (url, handlerType, loadOptions) {
+ var validHandlers = [];
+ var routers = handlerType === 'load' ?
+ IORouterRegistry.getInstance().loadRouters :
+ IORouterRegistry.getInstance().saveRouters;
+ routers.forEach(function (router) {
+ var handler = router(url, loadOptions);
+ if (handler !== null) {
+ validHandlers.push(handler);
+ }
+ });
+ return validHandlers;
+ };
+ return IORouterRegistry;
+ }());
+ var registerSaveRouter = function (loudRouter) { return IORouterRegistry.registerSaveRouter(loudRouter); };
+ var registerLoadRouter = function (loudRouter) { return IORouterRegistry.registerLoadRouter(loudRouter); };
+ var getSaveHandlers = function (url) { return IORouterRegistry.getSaveHandlers(url); };
+ var getLoadHandlers = function (url, loadOptions) { return IORouterRegistry.getLoadHandlers(url, loadOptions); };
+
+ var DATABASE_NAME = 'tensorflowjs';
+ var DATABASE_VERSION = 1;
+ // Model data and ModelArtifactsInfo (metadata) are stored in two separate
+ // stores for efficient access of the list of stored models and their metadata.
+ // 1. The object store for model data: topology, weights and weight manifests.
+ var MODEL_STORE_NAME = 'models_store';
+ // 2. The object store for ModelArtifactsInfo, including meta-information such
+ // as the type of topology (JSON vs binary), byte size of the topology, byte
+ // size of the weights, etc.
+ var INFO_STORE_NAME = 'model_info_store';
+ function getIndexedDBFactory() {
+ if (!env().getBool('IS_BROWSER')) {
+ // TODO(cais): Add more info about what IOHandler subtypes are available.
+ // Maybe point to a doc page on the web and/or automatically determine
+ // the available IOHandlers and print them in the error message.
+ throw new Error('Failed to obtain IndexedDB factory because the current environment' +
+ 'is not a web browser.');
+ }
+ // tslint:disable-next-line:no-any
+ var theWindow = typeof window === 'undefined' ? self : window;
+ var factory = theWindow.indexedDB || theWindow.mozIndexedDB ||
+ theWindow.webkitIndexedDB || theWindow.msIndexedDB ||
+ theWindow.shimIndexedDB;
+ if (factory == null) {
+ throw new Error('The current browser does not appear to support IndexedDB.');
+ }
+ return factory;
+ }
+ function setUpDatabase(openRequest) {
+ var db = openRequest.result;
+ db.createObjectStore(MODEL_STORE_NAME, { keyPath: 'modelPath' });
+ db.createObjectStore(INFO_STORE_NAME, { keyPath: 'modelPath' });
+ }
+ /**
+ * IOHandler subclass: Browser IndexedDB.
+ *
+ * See the doc string of `browserIndexedDB` for more details.
+ */
+ var BrowserIndexedDB = /** @class */ (function () {
+ function BrowserIndexedDB(modelPath) {
+ this.indexedDB = getIndexedDBFactory();
+ if (modelPath == null || !modelPath) {
+ throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.');
+ }
+ this.modelPath = modelPath;
+ }
+ BrowserIndexedDB.prototype.save = function (modelArtifacts) {
+ return __awaiter(this, void 0, void 0, function () {
+ return __generator(this, function (_a) {
+ // TODO(cais): Support saving GraphDef models.
+ if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
+ throw new Error('BrowserLocalStorage.save() does not support saving model topology ' +
+ 'in binary formats yet.');
+ }
+ return [2 /*return*/, this.databaseAction(this.modelPath, modelArtifacts)];
+ });
+ });
+ };
+ BrowserIndexedDB.prototype.load = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ return __generator(this, function (_a) {
+ return [2 /*return*/, this.databaseAction(this.modelPath)];
+ });
+ });
+ };
+ /**
+ * Perform database action to put model artifacts into or read model artifacts
+ * from IndexedDB object store.
+ *
+ * Whether the action is put or get depends on whether `modelArtifacts` is
+ * specified. If it is specified, the action will be put; otherwise the action
+ * will be get.
+ *
+ * @param modelPath A unique string path for the model.
+ * @param modelArtifacts If specified, it will be the model artifacts to be
+ * stored in IndexedDB.
+ * @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise`
+ * of `ModelArtifacts`, if the action is get.
+ */
+ BrowserIndexedDB.prototype.databaseAction = function (modelPath, modelArtifacts) {
+ var _this = this;
+ return new Promise(function (resolve, reject) {
+ var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
+ openRequest.onupgradeneeded = function () { return setUpDatabase(openRequest); };
+ openRequest.onsuccess = function () {
+ var db = openRequest.result;
+ if (modelArtifacts == null) {
+ // Read model out from object store.
+ var modelTx = db.transaction(MODEL_STORE_NAME, 'readonly');
+ var modelStore = modelTx.objectStore(MODEL_STORE_NAME);
+ var getRequest_1 = modelStore.get(_this.modelPath);
+ getRequest_1.onsuccess = function () {
+ if (getRequest_1.result == null) {
+ db.close();
+ return reject(new Error("Cannot find model with path '" + _this.modelPath + "' " +
+ "in IndexedDB."));
+ }
+ else {
+ resolve(getRequest_1.result.modelArtifacts);
+ }
+ };
+ getRequest_1.onerror = function (error) {
+ db.close();
+ return reject(getRequest_1.error);
+ };
+ modelTx.oncomplete = function () { return db.close(); };
+ }
+ else {
+ // Put model into object store.
+ var modelArtifactsInfo_1 = getModelArtifactsInfoForJSON(modelArtifacts);
+ // First, put ModelArtifactsInfo into info store.
+ var infoTx_1 = db.transaction(INFO_STORE_NAME, 'readwrite');
+ var infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME);
+ var putInfoRequest_1 = infoStore_1.put({ modelPath: _this.modelPath, modelArtifactsInfo: modelArtifactsInfo_1 });
+ var modelTx_1;
+ putInfoRequest_1.onsuccess = function () {
+ // Second, put model data into model store.
+ modelTx_1 = db.transaction(MODEL_STORE_NAME, 'readwrite');
+ var modelStore = modelTx_1.objectStore(MODEL_STORE_NAME);
+ var putModelRequest = modelStore.put({
+ modelPath: _this.modelPath,
+ modelArtifacts: modelArtifacts,
+ modelArtifactsInfo: modelArtifactsInfo_1
+ });
+ putModelRequest.onsuccess = function () { return resolve({ modelArtifactsInfo: modelArtifactsInfo_1 }); };
+ putModelRequest.onerror = function (error) {
+ // If the put-model request fails, roll back the info entry as
+ // well.
+ infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME);
+ var deleteInfoRequest = infoStore_1.delete(_this.modelPath);
+ deleteInfoRequest.onsuccess = function () {
+ db.close();
+ return reject(putModelRequest.error);
+ };
+ deleteInfoRequest.onerror = function (error) {
+ db.close();
+ return reject(putModelRequest.error);
+ };
+ };
+ };
+ putInfoRequest_1.onerror = function (error) {
+ db.close();
+ return reject(putInfoRequest_1.error);
+ };
+ infoTx_1.oncomplete = function () {
+ if (modelTx_1 == null) {
+ db.close();
+ }
+ else {
+ modelTx_1.oncomplete = function () { return db.close(); };
+ }
+ };
+ }
+ };
+ openRequest.onerror = function (error) { return reject(openRequest.error); };
+ });
+ };
+ return BrowserIndexedDB;
+ }());
+ BrowserIndexedDB.URL_SCHEME = 'indexeddb://';
+ var indexedDBRouter = function (url) {
+ if (!env().getBool('IS_BROWSER')) {
+ return null;
+ }
+ else {
+ if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) {
+ return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length));
+ }
+ else {
+ return null;
+ }
+ }
+ };
+ IORouterRegistry.registerSaveRouter(indexedDBRouter);
+ IORouterRegistry.registerLoadRouter(indexedDBRouter);
+ /**
+ * Creates a browser IndexedDB IOHandler for saving and loading models.
+ *
+ * ```js
+ * const model = tf.sequential();
+ * model.add(
+ * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
+ *
+ * const saveResult = await model.save('indexeddb://MyModel'));
+ * console.log(saveResult);
+ * ```
+ *
+ * @param modelPath A unique identifier for the model to be saved. Must be a
+ * non-empty string.
+ * @returns An instance of `BrowserIndexedDB` (sublcass of `IOHandler`),
+ * which can be used with, e.g., `tf.Model.save`.
+ */
+ function browserIndexedDB(modelPath) {
+ return new BrowserIndexedDB(modelPath);
+ }
+ function maybeStripScheme$1(key) {
+ return key.startsWith(BrowserIndexedDB.URL_SCHEME) ?
+ key.slice(BrowserIndexedDB.URL_SCHEME.length) :
+ key;
+ }
+ var BrowserIndexedDBManager = /** @class */ (function () {
+ function BrowserIndexedDBManager() {
+ this.indexedDB = getIndexedDBFactory();
+ }
+ BrowserIndexedDBManager.prototype.listModels = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ var _this = this;
+ return __generator(this, function (_a) {
+ return [2 /*return*/, new Promise(function (resolve, reject) {
+ var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
+ openRequest.onupgradeneeded = function () { return setUpDatabase(openRequest); };
+ openRequest.onsuccess = function () {
+ var db = openRequest.result;
+ var tx = db.transaction(INFO_STORE_NAME, 'readonly');
+ var store = tx.objectStore(INFO_STORE_NAME);
+ // tslint:disable:max-line-length
+ // Need to cast `store` as `any` here because TypeScript's DOM
+ // library does not have the `getAll()` method even though the
+ // method is supported in the latest version of most mainstream
+ // browsers:
+ // https://developer.mozilla.org/en-US/docs/Web/API/IDBObjectStore/getAll
+ // tslint:enable:max-line-length
+ // tslint:disable-next-line:no-any
+ var getAllInfoRequest = store.getAll();
+ getAllInfoRequest.onsuccess = function () {
+ var e_1, _a;
+ var out = {};
+ try {
+ for (var _b = __values(getAllInfoRequest.result), _c = _b.next(); !_c.done; _c = _b.next()) {
+ var item = _c.value;
+ out[item.modelPath] = item.modelArtifactsInfo;
+ }
+ }
+ catch (e_1_1) { e_1 = { error: e_1_1 }; }
+ finally {
+ try {
+ if (_c && !_c.done && (_a = _b.return)) _a.call(_b);
+ }
+ finally { if (e_1) throw e_1.error; }
+ }
+ resolve(out);
+ };
+ getAllInfoRequest.onerror = function (error) {
+ db.close();
+ return reject(getAllInfoRequest.error);
+ };
+ tx.oncomplete = function () { return db.close(); };
+ };
+ openRequest.onerror = function (error) { return reject(openRequest.error); };
+ })];
+ });
+ });
+ };
+ BrowserIndexedDBManager.prototype.removeModel = function (path) {
+ return __awaiter(this, void 0, void 0, function () {
+ var _this = this;
+ return __generator(this, function (_a) {
+ path = maybeStripScheme$1(path);
+ return [2 /*return*/, new Promise(function (resolve, reject) {
+ var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
+ openRequest.onupgradeneeded = function () { return setUpDatabase(openRequest); };
+ openRequest.onsuccess = function () {
+ var db = openRequest.result;
+ var infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');
+ var infoStore = infoTx.objectStore(INFO_STORE_NAME);
+ var getInfoRequest = infoStore.get(path);
+ var modelTx;
+ getInfoRequest.onsuccess = function () {
+ if (getInfoRequest.result == null) {
+ db.close();
+ return reject(new Error("Cannot find model with path '" + path + "' " +
+ "in IndexedDB."));
+ }
+ else {
+ // First, delete the entry in the info store.
+ var deleteInfoRequest = infoStore.delete(path);
+ var deleteModelData_1 = function () {
+ // Second, delete the entry in the model store.
+ modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');
+ var modelStore = modelTx.objectStore(MODEL_STORE_NAME);
+ var deleteModelRequest = modelStore.delete(path);
+ deleteModelRequest.onsuccess = function () { return resolve(getInfoRequest.result.modelArtifactsInfo); };
+ deleteModelRequest.onerror = function (error) { return reject(getInfoRequest.error); };
+ };
+ // Proceed with deleting model data regardless of whether deletion
+ // of info data succeeds or not.
+ deleteInfoRequest.onsuccess = deleteModelData_1;
+ deleteInfoRequest.onerror = function (error) {
+ deleteModelData_1();
+ db.close();
+ return reject(getInfoRequest.error);
+ };
+ }
+ };
+ getInfoRequest.onerror = function (error) {
+ db.close();
+ return reject(getInfoRequest.error);
+ };
+ infoTx.oncomplete = function () {
+ if (modelTx == null) {
+ db.close();
+ }
+ else {
+ modelTx.oncomplete = function () { return db.close(); };
+ }
+ };
+ };
+ openRequest.onerror = function (error) { return reject(openRequest.error); };
+ })];
+ });
+ });
+ };
+ return BrowserIndexedDBManager;
+ }());
+
+ var PATH_SEPARATOR = '/';
+ var PATH_PREFIX = 'tensorflowjs_models';
+ var INFO_SUFFIX = 'info';
+ var MODEL_TOPOLOGY_SUFFIX = 'model_topology';
+ var WEIGHT_SPECS_SUFFIX = 'weight_specs';
+ var WEIGHT_DATA_SUFFIX = 'weight_data';
+ var MODEL_METADATA_SUFFIX = 'model_metadata';
+ function getModelKeys(path) {
+ return {
+ info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR),
+ topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR),
+ weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR),
+ weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
+ modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR)
+ };
+ }
+ function removeItems(keys) {
+ var e_1, _a;
+ try {
+ for (var _b = __values(Object.values(keys)), _c = _b.next(); !_c.done; _c = _b.next()) {
+ var key = _c.value;
+ window.localStorage.removeItem(key);
+ }
+ }
+ catch (e_1_1) { e_1 = { error: e_1_1 }; }
+ finally {
+ try {
+ if (_c && !_c.done && (_a = _b.return)) _a.call(_b);
+ }
+ finally { if (e_1) throw e_1.error; }
+ }
+ }
+ /**
+ * Get model path from a local-storage key.
+ *
+ * E.g., 'tensorflowjs_models/my/model/1/info' --> 'my/model/1'
+ *
+ * @param key
+ */
+ function getModelPathFromKey(key) {
+ var items = key.split(PATH_SEPARATOR);
+ if (items.length < 3) {
+ throw new Error("Invalid key format: " + key);
+ }
+ return items.slice(1, items.length - 1).join(PATH_SEPARATOR);
+ }
+ function maybeStripScheme(key) {
+ return key.startsWith(BrowserLocalStorage.URL_SCHEME) ?
+ key.slice(BrowserLocalStorage.URL_SCHEME.length) :
+ key;
+ }
+ /**
+ * IOHandler subclass: Browser Local Storage.
+ *
+ * See the doc string to `browserLocalStorage` for more details.
+ */
+ var BrowserLocalStorage = /** @class */ (function () {
+ function BrowserLocalStorage(modelPath) {
+ if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' ||
+ typeof window.localStorage === 'undefined') {
+ // TODO(cais): Add more info about what IOHandler subtypes are
+ // available.
+ // Maybe point to a doc page on the web and/or automatically determine
+ // the available IOHandlers and print them in the error message.
+ throw new Error('The current environment does not support local storage.');
+ }
+ this.LS = window.localStorage;
+ if (modelPath == null || !modelPath) {
+ throw new Error('For local storage, modelPath must not be null, undefined or empty.');
+ }
+ this.modelPath = modelPath;
+ this.keys = getModelKeys(this.modelPath);
+ }
+ /**
+ * Save model artifacts to browser local storage.
+ *
+ * See the documentation to `browserLocalStorage` for details on the saved
+ * artifacts.
+ *
+ * @param modelArtifacts The model artifacts to be stored.
+ * @returns An instance of SaveResult.
+ */
+ BrowserLocalStorage.prototype.save = function (modelArtifacts) {
+ return __awaiter(this, void 0, void 0, function () {
+ var topology, weightSpecs, modelArtifactsInfo, metadata;
+ return __generator(this, function (_a) {
+ if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
+ throw new Error('BrowserLocalStorage.save() does not support saving model topology ' +
+ 'in binary formats yet.');
+ }
+ else {
+ topology = JSON.stringify(modelArtifacts.modelTopology);
+ weightSpecs = JSON.stringify(modelArtifacts.weightSpecs);
+ modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
+ try {
+ this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo));
+ this.LS.setItem(this.keys.topology, topology);
+ this.LS.setItem(this.keys.weightSpecs, weightSpecs);
+ this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(modelArtifacts.weightData));
+ metadata = {
+ format: modelArtifacts.format,
+ generatedBy: modelArtifacts.generatedBy,
+ convertedBy: modelArtifacts.convertedBy,
+ signature: modelArtifacts.signature != null ?
+ modelArtifacts.signature :
+ undefined,
+ userDefinedMetadata: modelArtifacts.userDefinedMetadata != null ?
+ modelArtifacts.userDefinedMetadata :
+ undefined,
+ modelInitializer: modelArtifacts.modelInitializer != null ?
+ modelArtifacts.modelInitializer :
+ undefined,
+ trainingConfig: modelArtifacts.trainingConfig != null ?
+ modelArtifacts.trainingConfig :
+ undefined
+ };
+ this.LS.setItem(this.keys.modelMetadata, JSON.stringify(metadata));
+ return [2 /*return*/, { modelArtifactsInfo: modelArtifactsInfo }];
+ }
+ catch (err) {
+ // If saving failed, clean up all items saved so far.
+ removeItems(this.keys);
+ throw new Error("Failed to save model '" + this.modelPath + "' to local storage: " +
+ "size quota being exceeded is a possible cause of this failure: " +
+ ("modelTopologyBytes=" + modelArtifactsInfo.modelTopologyBytes + ", ") +
+ ("weightSpecsBytes=" + modelArtifactsInfo.weightSpecsBytes + ", ") +
+ ("weightDataBytes=" + modelArtifactsInfo.weightDataBytes + "."));
+ }
+ }
+ return [2 /*return*/];
+ });
+ });
+ };
+ /**
+ * Load a model from local storage.
+ *
+ * See the documentation to `browserLocalStorage` for details on the saved
+ * artifacts.
+ *
+ * @returns The loaded model (if loading succeeds).
+ */
+ BrowserLocalStorage.prototype.load = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ var info, out, topology, weightSpecs, metadataString, metadata, weightDataBase64;
+ return __generator(this, function (_a) {
+ info = JSON.parse(this.LS.getItem(this.keys.info));
+ if (info == null) {
+ throw new Error("In local storage, there is no model with name '" + this.modelPath + "'");
+ }
+ if (info.modelTopologyType !== 'JSON') {
+ throw new Error('BrowserLocalStorage does not support loading non-JSON model ' +
+ 'topology yet.');
+ }
+ out = {};
+ topology = JSON.parse(this.LS.getItem(this.keys.topology));
+ if (topology == null) {
+ throw new Error("In local storage, the topology of model '" + this.modelPath + "' " +
+ "is missing.");
+ }
+ out.modelTopology = topology;
+ weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs));
+ if (weightSpecs == null) {
+ throw new Error("In local storage, the weight specs of model '" + this.modelPath + "' " +
+ "are missing.");
+ }
+ out.weightSpecs = weightSpecs;
+ metadataString = this.LS.getItem(this.keys.modelMetadata);
+ if (metadataString != null) {
+ metadata = JSON.parse(metadataString);
+ out.format = metadata.format;
+ out.generatedBy = metadata.generatedBy;
+ out.convertedBy = metadata.convertedBy;
+ if (metadata.signature != null) {
+ out.signature = metadata.signature;
+ }
+ if (metadata.userDefinedMetadata != null) {
+ out.userDefinedMetadata = metadata.userDefinedMetadata;
+ }
+ if (metadata.modelInitializer != null) {
+ out.modelInitializer = metadata.modelInitializer;
+ }
+ if (metadata.trainingConfig != null) {
+ out.trainingConfig = metadata.trainingConfig;
+ }
+ }
+ weightDataBase64 = this.LS.getItem(this.keys.weightData);
+ if (weightDataBase64 == null) {
+ throw new Error("In local storage, the binary weight values of model " +
+ ("'" + this.modelPath + "' are missing."));
+ }
+ out.weightData = base64StringToArrayBuffer(weightDataBase64);
+ return [2 /*return*/, out];
+ });
+ });
+ };
+ return BrowserLocalStorage;
+ }());
+ BrowserLocalStorage.URL_SCHEME = 'localstorage://';
+ var localStorageRouter = function (url) {
+ if (!env().getBool('IS_BROWSER')) {
+ return null;
+ }
+ else {
+ if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) {
+ return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length));
+ }
+ else {
+ return null;
+ }
+ }
+ };
+ IORouterRegistry.registerSaveRouter(localStorageRouter);
+ IORouterRegistry.registerLoadRouter(localStorageRouter);
+ /**
+ * Factory function for local storage IOHandler.
+ *
+ * This `IOHandler` supports both `save` and `load`.
+ *
+ * For each model's saved artifacts, four items are saved to local storage.
+ * - `${PATH_SEPARATOR}/${modelPath}/info`: Contains meta-info about the
+ * model, such as date saved, type of the topology, size in bytes, etc.
+ * - `${PATH_SEPARATOR}/${modelPath}/topology`: Model topology. For Keras-
+ * style models, this is a stringized JSON.
+ * - `${PATH_SEPARATOR}/${modelPath}/weight_specs`: Weight specs of the
+ * model, can be used to decode the saved binary weight values (see
+ * item below).
+ * - `${PATH_SEPARATOR}/${modelPath}/weight_data`: Concatenated binary
+ * weight values, stored as a base64-encoded string.
+ *
+ * Saving may throw an `Error` if the total size of the artifacts exceed the
+ * browser-specific quota.
+ *
+ * @param modelPath A unique identifier for the model to be saved. Must be a
+ * non-empty string.
+ * @returns An instance of `IOHandler`, which can be used with, e.g.,
+ * `tf.Model.save`.
+ */
+ function browserLocalStorage(modelPath) {
+ return new BrowserLocalStorage(modelPath);
+ }
+ var BrowserLocalStorageManager = /** @class */ (function () {
+ function BrowserLocalStorageManager() {
+ assert(env().getBool('IS_BROWSER'), function () { return 'Current environment is not a web browser'; });
+ assert(typeof window === 'undefined' ||
+ typeof window.localStorage !== 'undefined', function () { return 'Current browser does not appear to support localStorage'; });
+ this.LS = window.localStorage;
+ }
+ BrowserLocalStorageManager.prototype.listModels = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ var out, prefix, suffix, i, key, modelPath;
+ return __generator(this, function (_a) {
+ out = {};
+ prefix = PATH_PREFIX + PATH_SEPARATOR;
+ suffix = PATH_SEPARATOR + INFO_SUFFIX;
+ for (i = 0; i < this.LS.length; ++i) {
+ key = this.LS.key(i);
+ if (key.startsWith(prefix) && key.endsWith(suffix)) {
+ modelPath = getModelPathFromKey(key);
+ out[modelPath] = JSON.parse(this.LS.getItem(key));
+ }
+ }
+ return [2 /*return*/, out];
+ });
+ });
+ };
+ BrowserLocalStorageManager.prototype.removeModel = function (path) {
+ return __awaiter(this, void 0, void 0, function () {
+ var keys, info;
+ return __generator(this, function (_a) {
+ path = maybeStripScheme(path);
+ keys = getModelKeys(path);
+ if (this.LS.getItem(keys.info) == null) {
+ throw new Error("Cannot find model at path '" + path + "'");
+ }
+ info = JSON.parse(this.LS.getItem(keys.info));
+ removeItems(keys);
+ return [2 /*return*/, info];
+ });
+ });
+ };
+ return BrowserLocalStorageManager;
+ }());
+
+ var URL_SCHEME_SUFFIX = '://';
+ var ModelStoreManagerRegistry = /** @class */ (function () {
+ function ModelStoreManagerRegistry() {
+ this.managers = {};
+ }
+ ModelStoreManagerRegistry.getInstance = function () {
+ if (ModelStoreManagerRegistry.instance == null) {
+ ModelStoreManagerRegistry.instance = new ModelStoreManagerRegistry();
+ }
+ return ModelStoreManagerRegistry.instance;
+ };
+ /**
+ * Register a save-handler router.
+ *
+ * @param saveRouter A function that maps a URL-like string onto an instance
+ * of `IOHandler` with the `save` method defined or `null`.
+ */
+ ModelStoreManagerRegistry.registerManager = function (scheme, manager) {
+ assert(scheme != null, function () { return 'scheme must not be undefined or null.'; });
+ if (scheme.endsWith(URL_SCHEME_SUFFIX)) {
+ scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX));
+ }
+ assert(scheme.length > 0, function () { return 'scheme must not be an empty string.'; });
+ var registry = ModelStoreManagerRegistry.getInstance();
+ assert(registry.managers[scheme] == null, function () { return "A model store manager is already registered for scheme '" + scheme + "'."; });
+ registry.managers[scheme] = manager;
+ };
+ ModelStoreManagerRegistry.getManager = function (scheme) {
+ var manager = this.getInstance().managers[scheme];
+ if (manager == null) {
+ throw new Error("Cannot find model manager for scheme '" + scheme + "'");
+ }
+ return manager;
+ };
+ ModelStoreManagerRegistry.getSchemes = function () {
+ return Object.keys(this.getInstance().managers);
+ };
+ return ModelStoreManagerRegistry;
+ }());
+ /**
+ * Helper method for parsing a URL string into a scheme and a path.
+ *
+ * @param url E.g., 'localstorage://my-model'
+ * @returns A dictionary with two fields: scheme and path.
+ * Scheme: e.g., 'localstorage' in the example above.
+ * Path: e.g., 'my-model' in the example above.
+ */
+ function parseURL(url) {
+ if (url.indexOf(URL_SCHEME_SUFFIX) === -1) {
+ throw new Error("The url string provided does not contain a scheme. " +
+ "Supported schemes are: " +
+ ("" + ModelStoreManagerRegistry.getSchemes().join(',')));
+ }
+ return {
+ scheme: url.split(URL_SCHEME_SUFFIX)[0],
+ path: url.split(URL_SCHEME_SUFFIX)[1],
+ };
+ }
+ function cloneModelInternal(sourceURL, destURL, deleteSource) {
+ if (deleteSource === void 0) { deleteSource = false; }
+ return __awaiter(this, void 0, void 0, function () {
+ var loadHandlers, loadHandler, saveHandlers, saveHandler, sourceScheme, sourcePath, sameMedium, modelArtifacts, saveResult;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ assert(sourceURL !== destURL, function () { return "Old path and new path are the same: '" + sourceURL + "'"; });
+ loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL);
+ assert(loadHandlers.length > 0, function () { return "Copying failed because no load handler is found for source URL " + sourceURL + "."; });
+ assert(loadHandlers.length < 2, function () { return "Copying failed because more than one (" + loadHandlers.length + ") " +
+ ("load handlers for source URL " + sourceURL + "."); });
+ loadHandler = loadHandlers[0];
+ saveHandlers = IORouterRegistry.getSaveHandlers(destURL);
+ assert(saveHandlers.length > 0, function () { return "Copying failed because no save handler is found for destination " +
+ ("URL " + destURL + "."); });
+ assert(saveHandlers.length < 2, function () { return "Copying failed because more than one (" + loadHandlers.length + ") " +
+ ("save handlers for destination URL " + destURL + "."); });
+ saveHandler = saveHandlers[0];
+ sourceScheme = parseURL(sourceURL).scheme;
+ sourcePath = parseURL(sourceURL).path;
+ sameMedium = sourceScheme === parseURL(sourceURL).scheme;
+ return [4 /*yield*/, loadHandler.load()];
+ case 1:
+ modelArtifacts = _a.sent();
+ if (!(deleteSource && sameMedium)) return [3 /*break*/, 3];
+ return [4 /*yield*/, ModelStoreManagerRegistry.getManager(sourceScheme)
+ .removeModel(sourcePath)];
+ case 2:
+ _a.sent();
+ _a.label = 3;
+ case 3: return [4 /*yield*/, saveHandler.save(modelArtifacts)];
+ case 4:
+ saveResult = _a.sent();
+ if (!(deleteSource && !sameMedium)) return [3 /*break*/, 6];
+ return [4 /*yield*/, ModelStoreManagerRegistry.getManager(sourceScheme)
+ .removeModel(sourcePath)];
+ case 5:
+ _a.sent();
+ _a.label = 6;
+ case 6: return [2 /*return*/, saveResult.modelArtifactsInfo];
+ }
+ });
+ });
+ }
+ /**
+ * List all models stored in registered storage mediums.
+ *
+ * For a web browser environment, the registered mediums are Local Storage and
+ * IndexedDB.
+ *
+ * ```js
+ * // First create and save a model.
+ * const model = tf.sequential();
+ * model.add(tf.layers.dense(
+ * {units: 1, inputShape: [10], activation: 'sigmoid'}));
+ * await model.save('localstorage://demo/management/model1');
+ *
+ * // Then list existing models.
+ * console.log(JSON.stringify(await tf.io.listModels()));
+ *
+ * // Delete the model.
+ * await tf.io.removeModel('localstorage://demo/management/model1');
+ *
+ * // List models again.
+ * console.log(JSON.stringify(await tf.io.listModels()));
+ * ```
+ *
+ * @returns A `Promise` of a dictionary mapping URLs of existing models to
+ * their model artifacts info. URLs include medium-specific schemes, e.g.,
+ * 'indexeddb://my/model/1'. Model artifacts info include type of the
+ * model's topology, byte sizes of the topology, weights, etc.
+ *
+ * @doc {
+ * heading: 'Models',
+ * subheading: 'Management',
+ * namespace: 'io',
+ * ignoreCI: true
+ * }
+ */
+ function listModels() {
+ return __awaiter(this, void 0, void 0, function () {
+ var schemes, out, schemes_1, schemes_1_1, scheme, schemeOut, path, url, e_1_1;
+ var e_1, _a;
+ return __generator(this, function (_b) {
+ switch (_b.label) {
+ case 0:
+ schemes = ModelStoreManagerRegistry.getSchemes();
+ out = {};
+ _b.label = 1;
+ case 1:
+ _b.trys.push([1, 6, 7, 8]);
+ schemes_1 = __values(schemes), schemes_1_1 = schemes_1.next();
+ _b.label = 2;
+ case 2:
+ if (!!schemes_1_1.done) return [3 /*break*/, 5];
+ scheme = schemes_1_1.value;
+ return [4 /*yield*/, ModelStoreManagerRegistry.getManager(scheme).listModels()];
+ case 3:
+ schemeOut = _b.sent();
+ for (path in schemeOut) {
+ url = scheme + URL_SCHEME_SUFFIX + path;
+ out[url] = schemeOut[path];
+ }
+ _b.label = 4;
+ case 4:
+ schemes_1_1 = schemes_1.next();
+ return [3 /*break*/, 2];
+ case 5: return [3 /*break*/, 8];
+ case 6:
+ e_1_1 = _b.sent();
+ e_1 = { error: e_1_1 };
+ return [3 /*break*/, 8];
+ case 7:
+ try {
+ if (schemes_1_1 && !schemes_1_1.done && (_a = schemes_1.return)) _a.call(schemes_1);
+ }
+ finally { if (e_1) throw e_1.error; }
+ return [7 /*endfinally*/];
+ case 8: return [2 /*return*/, out];
+ }
+ });
+ });
+ }
+ /**
+ * Remove a model specified by URL from a reigstered storage medium.
+ *
+ * ```js
+ * // First create and save a model.
+ * const model = tf.sequential();
+ * model.add(tf.layers.dense(
+ * {units: 1, inputShape: [10], activation: 'sigmoid'}));
+ * await model.save('localstorage://demo/management/model1');
+ *
+ * // Then list existing models.
+ * console.log(JSON.stringify(await tf.io.listModels()));
+ *
+ * // Delete the model.
+ * await tf.io.removeModel('localstorage://demo/management/model1');
+ *
+ * // List models again.
+ * console.log(JSON.stringify(await tf.io.listModels()));
+ * ```
+ *
+ * @param url A URL to a stored model, with a scheme prefix, e.g.,
+ * 'localstorage://my-model-1', 'indexeddb://my/model/2'.
+ * @returns ModelArtifactsInfo of the deleted model (if and only if deletion
+ * is successful).
+ * @throws Error if deletion fails, e.g., if no model exists at `path`.
+ *
+ * @doc {
+ * heading: 'Models',
+ * subheading: 'Management',
+ * namespace: 'io',
+ * ignoreCI: true
+ * }
+ */
+ function removeModel(url) {
+ return __awaiter(this, void 0, void 0, function () {
+ var schemeAndPath, manager;
+ return __generator(this, function (_a) {
+ schemeAndPath = parseURL(url);
+ manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme);
+ return [2 /*return*/, manager.removeModel(schemeAndPath.path)];
+ });
+ });
+ }
+ /**
+ * Copy a model from one URL to another.
+ *
+ * This function supports:
+ *
+ * 1. Copying within a storage medium, e.g.,
+ * `tf.io.copyModel('localstorage://model-1', 'localstorage://model-2')`
+ * 2. Copying between two storage mediums, e.g.,
+ * `tf.io.copyModel('localstorage://model-1', 'indexeddb://model-1')`
+ *
+ * ```js
+ * // First create and save a model.
+ * const model = tf.sequential();
+ * model.add(tf.layers.dense(
+ * {units: 1, inputShape: [10], activation: 'sigmoid'}));
+ * await model.save('localstorage://demo/management/model1');
+ *
+ * // Then list existing models.
+ * console.log(JSON.stringify(await tf.io.listModels()));
+ *
+ * // Copy the model, from Local Storage to IndexedDB.
+ * await tf.io.copyModel(
+ * 'localstorage://demo/management/model1',
+ * 'indexeddb://demo/management/model1');
+ *
+ * // List models again.
+ * console.log(JSON.stringify(await tf.io.listModels()));
+ *
+ * // Remove both models.
+ * await tf.io.removeModel('localstorage://demo/management/model1');
+ * await tf.io.removeModel('indexeddb://demo/management/model1');
+ * ```
+ *
+ * @param sourceURL Source URL of copying.
+ * @param destURL Destination URL of copying.
+ * @returns ModelArtifactsInfo of the copied model (if and only if copying
+ * is successful).
+ * @throws Error if copying fails, e.g., if no model exists at `sourceURL`, or
+ * if `oldPath` and `newPath` are identical.
+ *
+ * @doc {
+ * heading: 'Models',
+ * subheading: 'Management',
+ * namespace: 'io',
+ * ignoreCI: true
+ * }
+ */
+ function copyModel(sourceURL, destURL) {
+ return __awaiter(this, void 0, void 0, function () {
+ var deleteSource;
+ return __generator(this, function (_a) {
+ deleteSource = false;
+ return [2 /*return*/, cloneModelInternal(sourceURL, destURL, deleteSource)];
+ });
+ });
+ }
+ /**
+ * Move a model from one URL to another.
+ *
+ * This function supports:
+ *
+ * 1. Moving within a storage medium, e.g.,
+ * `tf.io.moveModel('localstorage://model-1', 'localstorage://model-2')`
+ * 2. Moving between two storage mediums, e.g.,
+ * `tf.io.moveModel('localstorage://model-1', 'indexeddb://model-1')`
+ *
+ * ```js
+ * // First create and save a model.
+ * const model = tf.sequential();
+ * model.add(tf.layers.dense(
+ * {units: 1, inputShape: [10], activation: 'sigmoid'}));
+ * await model.save('localstorage://demo/management/model1');
+ *
+ * // Then list existing models.
+ * console.log(JSON.stringify(await tf.io.listModels()));
+ *
+ * // Move the model, from Local Storage to IndexedDB.
+ * await tf.io.moveModel(
+ * 'localstorage://demo/management/model1',
+ * 'indexeddb://demo/management/model1');
+ *
+ * // List models again.
+ * console.log(JSON.stringify(await tf.io.listModels()));
+ *
+ * // Remove the moved model.
+ * await tf.io.removeModel('indexeddb://demo/management/model1');
+ * ```
+ *
+ * @param sourceURL Source URL of moving.
+ * @param destURL Destination URL of moving.
+ * @returns ModelArtifactsInfo of the copied model (if and only if copying
+ * is successful).
+ * @throws Error if moving fails, e.g., if no model exists at `sourceURL`, or
+ * if `oldPath` and `newPath` are identical.
+ *
+ * @doc {
+ * heading: 'Models',
+ * subheading: 'Management',
+ * namespace: 'io',
+ * ignoreCI: true
+ * }
+ */
+ function moveModel(sourceURL, destURL) {
+ return __awaiter(this, void 0, void 0, function () {
+ var deleteSource;
+ return __generator(this, function (_a) {
+ deleteSource = true;
+ return [2 /*return*/, cloneModelInternal(sourceURL, destURL, deleteSource)];
+ });
+ });
+ }
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var PlatformBrowser = /** @class */ (function () {
+ function PlatformBrowser() {
+ }
+ PlatformBrowser.prototype.fetch = function (path, init) {
+ return fetch(path, init);
+ };
+ PlatformBrowser.prototype.now = function () {
+ return performance.now();
+ };
+ PlatformBrowser.prototype.encode = function (text, encoding) {
+ if (encoding !== 'utf-8' && encoding !== 'utf8') {
+ throw new Error("Browser's encoder only supports utf-8, but got " + encoding);
+ }
+ if (this.textEncoder == null) {
+ this.textEncoder = new TextEncoder();
+ }
+ return this.textEncoder.encode(text);
+ };
+ PlatformBrowser.prototype.decode = function (bytes, encoding) {
+ return new TextDecoder(encoding).decode(bytes);
+ };
+ return PlatformBrowser;
+ }());
+ if (env().get('IS_BROWSER')) {
+ env().setPlatform('browser', new PlatformBrowser());
+ // Register LocalStorage IOHandler
+ try {
+ ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager());
+ }
+ catch (err) {
+ }
+ // Register IndexedDB IOHandler
+ try {
+ ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager());
+ }
+ catch (err) {
+ }
+ }
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ // We are wrapping this within an object so it can be stubbed by Jasmine.
+ var getNodeFetch = {
+ // tslint:disable-next-line:no-require-imports
+ importFetch: function () { return require('node-fetch'); }
+ };
+ var systemFetch;
+ var PlatformNode = /** @class */ (function () {
+ function PlatformNode() {
+ // tslint:disable-next-line:no-require-imports
+ this.util = require('util');
+ // According to the spec, the built-in encoder can do only UTF-8 encoding.
+ // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder
+ this.textEncoder = new this.util.TextEncoder();
+ }
+ PlatformNode.prototype.fetch = function (path, requestInits) {
+ if (env().global.fetch != null) {
+ return env().global.fetch(path, requestInits);
+ }
+ if (systemFetch == null) {
+ systemFetch = getNodeFetch.importFetch();
+ }
+ return systemFetch(path, requestInits);
+ };
+ PlatformNode.prototype.now = function () {
+ var time = process.hrtime();
+ return time[0] * 1000 + time[1] / 1000000;
+ };
+ PlatformNode.prototype.encode = function (text, encoding) {
+ if (encoding !== 'utf-8' && encoding !== 'utf8') {
+ throw new Error("Node built-in encoder only supports utf-8, but got " + encoding);
+ }
+ return this.textEncoder.encode(text);
+ };
+ PlatformNode.prototype.decode = function (bytes, encoding) {
+ if (bytes.length === 0) {
+ return '';
+ }
+ return new this.util.TextDecoder(encoding).decode(bytes);
+ };
+ return PlatformNode;
+ }());
+ if (env().get('IS_NODE')) {
+ env().setPlatform('node', new PlatformNode());
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`.
+ *
+ * The values are stored in CPU as `TypedArray`. Fill the buffer using
+ * `buffer.set()`, or by modifying directly `buffer.values`.
+ *
+ * When done, call `buffer.toTensor()` to get an immutable `tf.Tensor` with
+ * those values.
+ *
+ * ```js
+ * // Create a buffer and set values at particular indices.
+ * const buffer = tf.buffer([2, 2]);
+ * buffer.set(3, 0, 0);
+ * buffer.set(5, 1, 0);
+ *
+ * // Convert the buffer back to a tensor.
+ * buffer.toTensor().print();
+ * ```
+ *
+ * @param shape An array of integers defining the output tensor shape.
+ * @param dtype The dtype of the buffer. Defaults to 'float32'.
+ * @param values The values of the buffer as `TypedArray`. Defaults to
+ * zeros.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function buffer(shape, dtype, values) {
+ if (dtype === void 0) { dtype = 'float32'; }
+ dtype = dtype || 'float32';
+ assertNonNegativeIntegerDimensions(shape);
+ return new TensorBuffer(shape, dtype, values);
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Casts a `tf.Tensor` to a new dtype.
+ *
+ * ```js
+ * const x = tf.tensor1d([1.5, 2.5, 3]);
+ * tf.cast(x, 'int32').print();
+ * ```
+ * @param x The input tensor to be casted.
+ * @param dtype The dtype to cast the input tensor to.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Transformations'}
+ */
+ function cast_(x, dtype) {
+ var $x = convertToTensor(x, 'x', 'cast');
+ // Sanity checks.
+ if (!isValidDtype(dtype)) {
+ throw new Error("Failed to cast to unknown dtype " + dtype);
+ }
+ if (dtype === 'string' && $x.dtype !== 'string' ||
+ dtype !== 'string' && $x.dtype === 'string') {
+ throw new Error('Only strings can be casted to strings');
+ }
+ var inputs = { x: $x };
+ var attrs = { dtype: dtype };
+ return ENGINE.runKernel(Cast, inputs, attrs);
+ }
+ var cast = op({ cast_: cast_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a new tensor with the same values and shape as the specified
+ * tensor.
+ *
+ * ```js
+ * const x = tf.tensor([1, 2]);
+ *
+ * x.clone().print();
+ * ```
+ *
+ * @param x The tensor to clone.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function clone_(x) {
+ var $x = convertToTensor(x, 'x', 'clone', 'string_or_numeric');
+ var inputs = { x: $x };
+ // Note this op is called tf.identity in python. Hence the kernel name used
+ // here.
+ return ENGINE.runKernel(Identity, inputs);
+ }
+ var clone = op({ clone_: clone_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Prints information about the `tf.Tensor` including its data.
+ *
+ * ```js
+ * const verbose = true;
+ * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose);
+ * ```
+ * @param x The tensor to be printed.
+ * @param verbose Whether to print verbose information about the ` Tensor`,
+ * including dtype and size.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function print(x, verbose) {
+ if (verbose === void 0) { verbose = false; }
+ console.log(x.toString(verbose));
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ getOrMakeEngine();
+ var opHandler = {
+ buffer: buffer,
+ cast: cast,
+ clone: clone,
+ print: print
+ };
+ setOpHandler(opHandler);
+
+ var DEFAULT_FILE_NAME_PREFIX = 'model';
+ var DEFAULT_JSON_EXTENSION_NAME = '.json';
+ var DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin';
+ function defer(f) {
+ return new Promise(function (resolve) { return setTimeout(resolve); }).then(f);
+ }
+ var BrowserDownloads = /** @class */ (function () {
+ function BrowserDownloads(fileNamePrefix) {
+ if (!env().getBool('IS_BROWSER')) {
+ // TODO(cais): Provide info on what IOHandlers are available under the
+ // current environment.
+ throw new Error('browserDownloads() cannot proceed because the current environment ' +
+ 'is not a browser.');
+ }
+ if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) {
+ fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length);
+ }
+ if (fileNamePrefix == null || fileNamePrefix.length === 0) {
+ fileNamePrefix = DEFAULT_FILE_NAME_PREFIX;
+ }
+ this.modelJsonFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME;
+ this.weightDataFileName =
+ fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME;
+ }
+ BrowserDownloads.prototype.save = function (modelArtifacts) {
+ return __awaiter(this, void 0, void 0, function () {
+ var weightsURL, weightsManifest, modelJSON, modelJsonURL, jsonAnchor_1, weightDataAnchor_1;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ if (typeof (document) === 'undefined') {
+ throw new Error('Browser downloads are not supported in ' +
+ 'this environment since `document` is not present');
+ }
+ weightsURL = window.URL.createObjectURL(new Blob([modelArtifacts.weightData], { type: 'application/octet-stream' }));
+ if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) return [3 /*break*/, 1];
+ throw new Error('BrowserDownloads.save() does not support saving model topology ' +
+ 'in binary formats yet.');
+ case 1:
+ weightsManifest = [{
+ paths: ['./' + this.weightDataFileName],
+ weights: modelArtifacts.weightSpecs
+ }];
+ modelJSON = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest);
+ modelJsonURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelJSON)], { type: 'application/json' }));
+ jsonAnchor_1 = this.modelJsonAnchor == null ?
+ document.createElement('a') :
+ this.modelJsonAnchor;
+ jsonAnchor_1.download = this.modelJsonFileName;
+ jsonAnchor_1.href = modelJsonURL;
+ // Trigger downloads by evoking a click event on the download anchors.
+ // When multiple downloads are started synchronously, Firefox will only
+ // save the last one.
+ return [4 /*yield*/, defer(function () { return jsonAnchor_1.dispatchEvent(new MouseEvent('click')); })];
+ case 2:
+ // Trigger downloads by evoking a click event on the download anchors.
+ // When multiple downloads are started synchronously, Firefox will only
+ // save the last one.
+ _a.sent();
+ if (!(modelArtifacts.weightData != null)) return [3 /*break*/, 4];
+ weightDataAnchor_1 = this.weightDataAnchor == null ?
+ document.createElement('a') :
+ this.weightDataAnchor;
+ weightDataAnchor_1.download = this.weightDataFileName;
+ weightDataAnchor_1.href = weightsURL;
+ return [4 /*yield*/, defer(function () { return weightDataAnchor_1.dispatchEvent(new MouseEvent('click')); })];
+ case 3:
+ _a.sent();
+ _a.label = 4;
+ case 4: return [2 /*return*/, { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) }];
+ }
+ });
+ });
+ };
+ return BrowserDownloads;
+ }());
+ BrowserDownloads.URL_SCHEME = 'downloads://';
+ var BrowserFiles = /** @class */ (function () {
+ function BrowserFiles(files) {
+ if (files == null || files.length < 1) {
+ throw new Error("When calling browserFiles, at least 1 file is required, " +
+ ("but received " + files));
+ }
+ this.jsonFile = files[0];
+ this.weightsFiles = files.slice(1);
+ }
+ BrowserFiles.prototype.load = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ var _this = this;
+ return __generator(this, function (_a) {
+ return [2 /*return*/, new Promise(function (resolve, reject) {
+ var jsonReader = new FileReader();
+ jsonReader.onload = function (event) {
+ // tslint:disable-next-line:no-any
+ var modelJSON = JSON.parse(event.target.result);
+ var modelTopology = modelJSON.modelTopology;
+ if (modelTopology == null) {
+ reject(new Error("modelTopology field is missing from file " + _this.jsonFile.name));
+ return;
+ }
+ var weightsManifest = modelJSON.weightsManifest;
+ if (weightsManifest == null) {
+ reject(new Error("weightManifest field is missing from file " + _this.jsonFile.name));
+ return;
+ }
+ if (_this.weightsFiles.length === 0) {
+ resolve({ modelTopology: modelTopology });
+ return;
+ }
+ var modelArtifactsPromise = getModelArtifactsForJSON(modelJSON, function (weightsManifest) { return _this.loadWeights(weightsManifest); });
+ resolve(modelArtifactsPromise);
+ };
+ jsonReader.onerror = function (error) { return reject("Failed to read model topology and weights manifest JSON " +
+ ("from file '" + _this.jsonFile.name + "'. BrowserFiles supports loading ") +
+ "Keras-style tf.Model artifacts only."); };
+ jsonReader.readAsText(_this.jsonFile);
+ })];
+ });
+ });
+ };
+ BrowserFiles.prototype.loadWeights = function (weightsManifest) {
+ var e_1, _a;
+ var _this = this;
+ var weightSpecs = [];
+ var paths = [];
+ try {
+ for (var weightsManifest_1 = __values(weightsManifest), weightsManifest_1_1 = weightsManifest_1.next(); !weightsManifest_1_1.done; weightsManifest_1_1 = weightsManifest_1.next()) {
+ var entry = weightsManifest_1_1.value;
+ weightSpecs.push.apply(weightSpecs, __spread(entry.weights));
+ paths.push.apply(paths, __spread(entry.paths));
+ }
+ }
+ catch (e_1_1) { e_1 = { error: e_1_1 }; }
+ finally {
+ try {
+ if (weightsManifest_1_1 && !weightsManifest_1_1.done && (_a = weightsManifest_1.return)) _a.call(weightsManifest_1);
+ }
+ finally { if (e_1) throw e_1.error; }
+ }
+ var pathToFile = this.checkManifestAndWeightFiles(weightsManifest);
+ var promises = paths.map(function (path) { return _this.loadWeightsFile(path, pathToFile[path]); });
+ return Promise.all(promises).then(function (buffers) { return [weightSpecs, concatenateArrayBuffers(buffers)]; });
+ };
+ BrowserFiles.prototype.loadWeightsFile = function (path, file) {
+ return new Promise(function (resolve, reject) {
+ var weightFileReader = new FileReader();
+ weightFileReader.onload = function (event) {
+ // tslint:disable-next-line:no-any
+ var weightData = event.target.result;
+ resolve(weightData);
+ };
+ weightFileReader.onerror = function (error) { return reject("Failed to weights data from file of path '" + path + "'."); };
+ weightFileReader.readAsArrayBuffer(file);
+ });
+ };
+ /**
+ * Check the compatibility between weights manifest and weight files.
+ */
+ BrowserFiles.prototype.checkManifestAndWeightFiles = function (manifest) {
+ var e_2, _a;
+ var _this = this;
+ var basenames = [];
+ var fileNames = this.weightsFiles.map(function (file) { return basename(file.name); });
+ var pathToFile = {};
+ try {
+ for (var manifest_1 = __values(manifest), manifest_1_1 = manifest_1.next(); !manifest_1_1.done; manifest_1_1 = manifest_1.next()) {
+ var group = manifest_1_1.value;
+ group.paths.forEach(function (path) {
+ var pathBasename = basename(path);
+ if (basenames.indexOf(pathBasename) !== -1) {
+ throw new Error("Duplicate file basename found in weights manifest: " +
+ ("'" + pathBasename + "'"));
+ }
+ basenames.push(pathBasename);
+ if (fileNames.indexOf(pathBasename) === -1) {
+ throw new Error("Weight file with basename '" + pathBasename + "' is not provided.");
+ }
+ else {
+ pathToFile[path] = _this.weightsFiles[fileNames.indexOf(pathBasename)];
+ }
+ });
+ }
+ }
+ catch (e_2_1) { e_2 = { error: e_2_1 }; }
+ finally {
+ try {
+ if (manifest_1_1 && !manifest_1_1.done && (_a = manifest_1.return)) _a.call(manifest_1);
+ }
+ finally { if (e_2) throw e_2.error; }
+ }
+ if (basenames.length !== this.weightsFiles.length) {
+ throw new Error("Mismatch in the number of files in weights manifest " +
+ ("(" + basenames.length + ") and the number of weight files provided ") +
+ ("(" + this.weightsFiles.length + ")."));
+ }
+ return pathToFile;
+ };
+ return BrowserFiles;
+ }());
+ var browserDownloadsRouter = function (url) {
+ if (!env().getBool('IS_BROWSER')) {
+ return null;
+ }
+ else {
+ if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) {
+ return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length));
+ }
+ else {
+ return null;
+ }
+ }
+ };
+ IORouterRegistry.registerSaveRouter(browserDownloadsRouter);
+ /**
+ * Creates an IOHandler that triggers file downloads from the browser.
+ *
+ * The returned `IOHandler` instance can be used as model exporting methods such
+ * as `tf.Model.save` and supports only saving.
+ *
+ * ```js
+ * const model = tf.sequential();
+ * model.add(tf.layers.dense(
+ * {units: 1, inputShape: [10], activation: 'sigmoid'}));
+ * const saveResult = await model.save('downloads://mymodel');
+ * // This will trigger downloading of two files:
+ * // 'mymodel.json' and 'mymodel.weights.bin'.
+ * console.log(saveResult);
+ * ```
+ *
+ * @param fileNamePrefix Prefix name of the files to be downloaded. For use with
+ * `tf.Model`, `fileNamePrefix` should follow either of the following two
+ * formats:
+ * 1. `null` or `undefined`, in which case the default file
+ * names will be used:
+ * - 'model.json' for the JSON file containing the model topology and
+ * weights manifest.
+ * - 'model.weights.bin' for the binary file containing the binary weight
+ * values.
+ * 2. A single string or an Array of a single string, as the file name prefix.
+ * For example, if `'foo'` is provided, the downloaded JSON
+ * file and binary weights file will be named 'foo.json' and
+ * 'foo.weights.bin', respectively.
+ * @param config Additional configuration for triggering downloads.
+ * @returns An instance of `BrowserDownloads` `IOHandler`.
+ *
+ * @doc {
+ * heading: 'Models',
+ * subheading: 'Loading',
+ * namespace: 'io',
+ * ignoreCI: true
+ * }
+ */
+ function browserDownloads(fileNamePrefix) {
+ if (fileNamePrefix === void 0) { fileNamePrefix = 'model'; }
+ return new BrowserDownloads(fileNamePrefix);
+ }
+ /**
+ * Creates an IOHandler that loads model artifacts from user-selected files.
+ *
+ * This method can be used for loading from files such as user-selected files
+ * in the browser.
+ * When used in conjunction with `tf.loadLayersModel`, an instance of
+ * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
+ *
+ * ```js
+ * // Note: This code snippet won't run properly without the actual file input
+ * // elements in the HTML DOM.
+ *
+ * // Suppose there are two HTML file input (`<input type="file" ...>`)
+ * // elements.
+ * const uploadJSONInput = document.getElementById('upload-json');
+ * const uploadWeightsInput = document.getElementById('upload-weights');
+ * const model = await tf.loadLayersModel(tf.io.browserFiles(
+ * [uploadJSONInput.files[0], uploadWeightsInput.files[0]]));
+ * ```
+ *
+ * @param files `File`s to load from. Currently, this function supports only
+ * loading from files that contain Keras-style models (i.e., `tf.Model`s), for
+ * which an `Array` of `File`s is expected (in that order):
+ * - A JSON file containing the model topology and weight manifest.
+ * - Optionally, One or more binary files containing the binary weights.
+ * These files must have names that match the paths in the `weightsManifest`
+ * contained by the aforementioned JSON file, or errors will be thrown
+ * during loading. These weights files have the same format as the ones
+ * generated by `tensorflowjs_converter` that comes with the `tensorflowjs`
+ * Python PIP package. If no weights files are provided, only the model
+ * topology will be loaded from the JSON file above.
+ * @returns An instance of `Files` `IOHandler`.
+ *
+ * @doc {
+ * heading: 'Models',
+ * subheading: 'Loading',
+ * namespace: 'io',
+ * ignoreCI: true
+ * }
+ */
+ function browserFiles(files) {
+ return new BrowserFiles(files);
+ }
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Monitor Promise.all progress, fire onProgress callback function.
+ *
+ * @param promises Promise list going to be monitored
+ * @param onProgress Callback function. Fired when a promise resolved.
+ * @param startFraction Optional fraction start. Default to 0.
+ * @param endFraction Optional fraction end. Default to 1.
+ */
+ function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) {
+ checkPromises(promises);
+ startFraction = startFraction == null ? 0 : startFraction;
+ endFraction = endFraction == null ? 1 : endFraction;
+ checkFraction(startFraction, endFraction);
+ var resolvedPromise = 0;
+ var registerMonitor = function (promise) {
+ promise.then(function (value) {
+ var fraction = startFraction +
+ ++resolvedPromise / promises.length * (endFraction - startFraction);
+ // pass fraction as parameter to callback function.
+ onProgress(fraction);
+ return value;
+ });
+ return promise;
+ };
+ function checkPromises(promises) {
+ assert(promises != null && Array.isArray(promises) && promises.length > 0, function () { return 'promises must be a none empty array'; });
+ }
+ function checkFraction(startFraction, endFraction) {
+ assert(startFraction >= 0 && startFraction <= 1, function () { return "Progress fraction must be in range [0, 1], but " +
+ ("got startFraction " + startFraction); });
+ assert(endFraction >= 0 && endFraction <= 1, function () { return "Progress fraction must be in range [0, 1], but " +
+ ("got endFraction " + endFraction); });
+ assert(endFraction >= startFraction, function () { return "startFraction must be no more than endFraction, but " +
+ ("got startFraction " + startFraction + " and endFraction ") +
+ ("" + endFraction); });
+ }
+ return Promise.all(promises.map(registerMonitor));
+ }
+
+ /**
+ * Reads binary weights data from a number of URLs.
+ *
+ * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls.
+ * @param requestOptions RequestInit (options) for the HTTP requests.
+ * @param fetchFunc Optional overriding value for the `window.fetch` function.
+ * @param onProgress Optional, progress callback function, fired periodically
+ * before the load is completed.
+ * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same
+ * length as `fetchURLs`.
+ */
+ function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) {
+ return __awaiter(this, void 0, void 0, function () {
+ var fetchFunc, requests, fetchStartFraction, fetchEndFraction, responses, _a, bufferPromises, bufferStartFraction, bufferEndFraction, buffers, _b;
+ return __generator(this, function (_c) {
+ switch (_c.label) {
+ case 0:
+ if (loadOptions == null) {
+ loadOptions = {};
+ }
+ fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch :
+ loadOptions.fetchFunc;
+ requests = fetchURLs.map(function (fetchURL) { return fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true }); });
+ fetchStartFraction = 0;
+ fetchEndFraction = 0.5;
+ if (!(loadOptions.onProgress == null)) return [3 /*break*/, 2];
+ return [4 /*yield*/, Promise.all(requests)];
+ case 1:
+ _a = _c.sent();
+ return [3 /*break*/, 4];
+ case 2: return [4 /*yield*/, monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction)];
+ case 3:
+ _a = _c.sent();
+ _c.label = 4;
+ case 4:
+ responses = _a;
+ bufferPromises = responses.map(function (response) { return response.arrayBuffer(); });
+ bufferStartFraction = 0.5;
+ bufferEndFraction = 1;
+ if (!(loadOptions.onProgress == null)) return [3 /*break*/, 6];
+ return [4 /*yield*/, Promise.all(bufferPromises)];
+ case 5:
+ _b = _c.sent();
+ return [3 /*break*/, 8];
+ case 6: return [4 /*yield*/, monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction)];
+ case 7:
+ _b = _c.sent();
+ _c.label = 8;
+ case 8:
+ buffers = _b;
+ return [2 /*return*/, buffers];
+ }
+ });
+ });
+ }
+ /**
+ * Reads a weights manifest JSON configuration, fetches the weights and
+ * returns them as `Tensor`s.
+ *
+ * @param manifest The weights manifest JSON.
+ * @param filePathPrefix The path prefix for filenames given in the manifest.
+ * Defaults to the empty string.
+ * @param weightNames The names of the weights to be fetched.
+ */
+ function loadWeights(manifest, filePathPrefix, weightNames, requestInit) {
+ if (filePathPrefix === void 0) { filePathPrefix = ''; }
+ return __awaiter(this, void 0, void 0, function () {
+ var fetchWeights, loadWeights;
+ return __generator(this, function (_a) {
+ fetchWeights = function (fetchUrls) { return loadWeightsAsArrayBuffer(fetchUrls, { requestInit: requestInit }); };
+ loadWeights = weightsLoaderFactory(fetchWeights);
+ return [2 /*return*/, loadWeights(manifest, filePathPrefix, weightNames)];
+ });
+ });
+ }
+ /**
+ * Creates a function, which reads a weights manifest JSON configuration,
+ * fetches the weight files using the specified function and returns them as
+ * `Tensor`s.
+ *
+ * ```js
+ * // example for creating a nodejs weight loader, which reads the weight files
+ * // from disk using fs.readFileSync
+ *
+ * import * as fs from 'fs'
+ *
+ * const fetchWeightsFromDisk = (filePaths: string[]) =>
+ * filePaths.map(filePath => fs.readFileSync(filePath).buffer)
+ *
+ * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk)
+ *
+ * const manifest = JSON.parse(
+ * fs.readFileSync('./my_model-weights_manifest').toString()
+ * )
+ * const weightMap = await loadWeights(manifest, './')
+ * ```
+ * @param fetchWeightsFunction The function used for fetching the weight files.
+ * @returns Weight loading function.
+ */
+ function weightsLoaderFactory(fetchWeightsFunction) {
+ var _this = this;
+ return function (manifest, filePathPrefix, weightNames) {
+ if (filePathPrefix === void 0) { filePathPrefix = ''; }
+ return __awaiter(_this, void 0, void 0, function () {
+ var groupIndicesToFetchMap, groupWeightsToFetch, weightsFound, allManifestWeightNames, weightsNotFound, groupIndicesToFetch, fetchUrls, buffers, weightsTensorMap, bufferIndexOffset;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ groupIndicesToFetchMap = manifest.map(function () { return false; });
+ groupWeightsToFetch = {};
+ weightsFound = weightNames != null ? weightNames.map(function () { return false; }) : [];
+ allManifestWeightNames = [];
+ manifest.forEach(function (manifestGroupConfig, groupIndex) {
+ var groupOffset = 0;
+ manifestGroupConfig.weights.forEach(function (weightsEntry) {
+ var rawDtype = ('quantization' in weightsEntry) ?
+ weightsEntry.quantization.dtype :
+ weightsEntry.dtype;
+ var weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] *
+ sizeFromShape(weightsEntry.shape);
+ var enqueueWeightsForFetchingFn = function () {
+ groupIndicesToFetchMap[groupIndex] = true;
+ if (groupWeightsToFetch[groupIndex] == null) {
+ groupWeightsToFetch[groupIndex] = [];
+ }
+ groupWeightsToFetch[groupIndex].push({
+ manifestEntry: weightsEntry,
+ groupOffset: groupOffset,
+ sizeBytes: weightsBytes
+ });
+ };
+ if (weightNames != null) {
+ weightNames.forEach(function (weightName, weightIndex) {
+ if (weightName === weightsEntry.name) {
+ enqueueWeightsForFetchingFn();
+ weightsFound[weightIndex] = true;
+ }
+ });
+ }
+ else {
+ enqueueWeightsForFetchingFn();
+ }
+ allManifestWeightNames.push(weightsEntry.name);
+ groupOffset += weightsBytes;
+ });
+ });
+ if (!weightsFound.every(function (found) { return found; })) {
+ weightsNotFound = weightNames.filter(function (_, i) { return !weightsFound[i]; });
+ throw new Error("Could not find weights in manifest with names: " +
+ (weightsNotFound.join(', ') + ". \n") +
+ "Manifest JSON has weights with names: " +
+ (allManifestWeightNames.join(', ') + "."));
+ }
+ groupIndicesToFetch = groupIndicesToFetchMap.reduce(function (accumulator, shouldFetch, i) {
+ if (shouldFetch) {
+ accumulator.push(i);
+ }
+ return accumulator;
+ }, []);
+ fetchUrls = [];
+ groupIndicesToFetch.forEach(function (i) {
+ manifest[i].paths.forEach(function (filepath) {
+ var fetchUrl = filePathPrefix +
+ (!filePathPrefix.endsWith('/') ? '/' : '') + filepath;
+ fetchUrls.push(fetchUrl);
+ });
+ });
+ return [4 /*yield*/, fetchWeightsFunction(fetchUrls)];
+ case 1:
+ buffers = _a.sent();
+ weightsTensorMap = {};
+ bufferIndexOffset = 0;
+ groupIndicesToFetch.forEach(function (i) {
+ var numBuffers = manifest[i].paths.length;
+ var groupBytes = 0;
+ for (var i_1 = 0; i_1 < numBuffers; i_1++) {
+ groupBytes += buffers[bufferIndexOffset + i_1].byteLength;
+ }
+ // Create a buffer for the whole group.
+ var groupBuffer = new ArrayBuffer(groupBytes);
+ var groupByteBuffer = new Uint8Array(groupBuffer);
+ var groupBufferOffset = 0;
+ for (var i_2 = 0; i_2 < numBuffers; i_2++) {
+ var buffer = new Uint8Array(buffers[bufferIndexOffset + i_2]);
+ groupByteBuffer.set(buffer, groupBufferOffset);
+ groupBufferOffset += buffer.byteLength;
+ }
+ var weightsEntries = groupWeightsToFetch[i];
+ weightsEntries.forEach(function (weightsEntry) {
+ var byteBuffer = groupBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes);
+ var nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);
+ for (var name in nameToTensorMap) {
+ weightsTensorMap[name] = nameToTensorMap[name];
+ }
+ });
+ bufferIndexOffset += numBuffers;
+ });
+ return [2 /*return*/, weightsTensorMap];
+ }
+ });
+ });
+ };
+ }
+
+ var OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
+ var JSON_TYPE = 'application/json';
+ var HTTPRequest = /** @class */ (function () {
+ function HTTPRequest(path, loadOptions) {
+ this.DEFAULT_METHOD = 'POST';
+ if (loadOptions == null) {
+ loadOptions = {};
+ }
+ this.weightPathPrefix = loadOptions.weightPathPrefix;
+ this.onProgress = loadOptions.onProgress;
+ this.weightUrlConverter = loadOptions.weightUrlConverter;
+ if (loadOptions.fetchFunc != null) {
+ assert(typeof loadOptions.fetchFunc === 'function', function () { return 'Must pass a function that matches the signature of ' +
+ '`fetch` (see ' +
+ 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)'; });
+ this.fetch = loadOptions.fetchFunc;
+ }
+ else {
+ this.fetch = env().platform.fetch;
+ }
+ assert(path != null && path.length > 0, function () { return 'URL path for http must not be null, undefined or ' +
+ 'empty.'; });
+ if (Array.isArray(path)) {
+ assert(path.length === 2, function () { return 'URL paths for http must have a length of 2, ' +
+ ("(actual length is " + path.length + ")."); });
+ }
+ this.path = path;
+ if (loadOptions.requestInit != null &&
+ loadOptions.requestInit.body != null) {
+ throw new Error('requestInit is expected to have no pre-existing body, but has one.');
+ }
+ this.requestInit = loadOptions.requestInit || {};
+ }
+ HTTPRequest.prototype.save = function (modelArtifacts) {
+ return __awaiter(this, void 0, void 0, function () {
+ var init, weightsManifest, modelTopologyAndWeightManifest, response;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
+ throw new Error('BrowserHTTPRequest.save() does not support saving model topology ' +
+ 'in binary formats yet.');
+ }
+ init = Object.assign({ method: this.DEFAULT_METHOD }, this.requestInit);
+ init.body = new FormData();
+ weightsManifest = [{
+ paths: ['./model.weights.bin'],
+ weights: modelArtifacts.weightSpecs,
+ }];
+ modelTopologyAndWeightManifest = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest);
+ init.body.append('model.json', new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: JSON_TYPE }), 'model.json');
+ if (modelArtifacts.weightData != null) {
+ init.body.append('model.weights.bin', new Blob([modelArtifacts.weightData], { type: OCTET_STREAM_MIME_TYPE }), 'model.weights.bin');
+ }
+ return [4 /*yield*/, this.fetch(this.path, init)];
+ case 1:
+ response = _a.sent();
+ if (response.ok) {
+ return [2 /*return*/, {
+ modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts),
+ responses: [response],
+ }];
+ }
+ else {
+ throw new Error("BrowserHTTPRequest.save() failed due to HTTP response status " +
+ (response.status + "."));
+ }
+ }
+ });
+ });
+ };
+ /**
+ * Load model artifacts via HTTP request(s).
+ *
+ * See the documentation to `tf.io.http` for details on the saved
+ * artifacts.
+ *
+ * @returns The loaded model artifacts (if loading succeeds).
+ */
+ HTTPRequest.prototype.load = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ var modelConfigRequest, modelJSON, message, modelTopology, weightsManifest;
+ var _this = this;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0: return [4 /*yield*/, this.fetch(this.path, this.requestInit)];
+ case 1:
+ modelConfigRequest = _a.sent();
+ if (!modelConfigRequest.ok) {
+ throw new Error("Request to " + this.path + " failed with status code " +
+ (modelConfigRequest.status + ". Please verify this URL points to ") +
+ "the model JSON of the model to load.");
+ }
+ _a.label = 2;
+ case 2:
+ _a.trys.push([2, 4, , 5]);
+ return [4 /*yield*/, modelConfigRequest.json()];
+ case 3:
+ modelJSON = _a.sent();
+ return [3 /*break*/, 5];
+ case 4:
+ _a.sent();
+ message = "Failed to parse model JSON of response from " + this.path + ".";
+ // TODO(nsthorat): Remove this after some time when we're comfortable that
+ // .pb files are mostly gone.
+ if (this.path.endsWith('.pb')) {
+ message += ' Your path contains a .pb file extension. ' +
+ 'Support for .pb models have been removed in TensorFlow.js 1.0 ' +
+ 'in favor of .json models. You can re-convert your Python ' +
+ 'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' +
+ 'or you can convert your.pb models with the \'pb2json\'' +
+ 'NPM script in the tensorflow/tfjs-converter repository.';
+ }
+ else {
+ message += ' Please make sure the server is serving valid ' +
+ 'JSON for this request.';
+ }
+ throw new Error(message);
+ case 5:
+ modelTopology = modelJSON.modelTopology;
+ weightsManifest = modelJSON.weightsManifest;
+ if (modelTopology == null && weightsManifest == null) {
+ throw new Error("The JSON from HTTP path " + this.path + " contains neither model " +
+ "topology or manifest for weights.");
+ }
+ return [2 /*return*/, getModelArtifactsForJSON(modelJSON, function (weightsManifest) { return _this.loadWeights(weightsManifest); })];
+ }
+ });
+ });
+ };
+ HTTPRequest.prototype.loadWeights = function (weightsManifest) {
+ return __awaiter(this, void 0, void 0, function () {
+ var weightPath, _a, prefix, suffix, pathPrefix, weightSpecs, weightsManifest_1, weightsManifest_1_1, entry, fetchURLs, urlPromises, weightsManifest_2, weightsManifest_2_1, weightsGroup, _b, _c, path, _d, _e, _f, buffers;
+ var e_2, _g, e_3, _h, e_4, _j;
+ return __generator(this, function (_k) {
+ switch (_k.label) {
+ case 0:
+ weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
+ _a = __read(parseUrl(weightPath), 2), prefix = _a[0], suffix = _a[1];
+ pathPrefix = this.weightPathPrefix || prefix;
+ weightSpecs = [];
+ try {
+ for (weightsManifest_1 = __values(weightsManifest), weightsManifest_1_1 = weightsManifest_1.next(); !weightsManifest_1_1.done; weightsManifest_1_1 = weightsManifest_1.next()) {
+ entry = weightsManifest_1_1.value;
+ weightSpecs.push.apply(weightSpecs, __spread(entry.weights));
+ }
+ }
+ catch (e_2_1) { e_2 = { error: e_2_1 }; }
+ finally {
+ try {
+ if (weightsManifest_1_1 && !weightsManifest_1_1.done && (_g = weightsManifest_1.return)) _g.call(weightsManifest_1);
+ }
+ finally { if (e_2) throw e_2.error; }
+ }
+ fetchURLs = [];
+ urlPromises = [];
+ try {
+ for (weightsManifest_2 = __values(weightsManifest), weightsManifest_2_1 = weightsManifest_2.next(); !weightsManifest_2_1.done; weightsManifest_2_1 = weightsManifest_2.next()) {
+ weightsGroup = weightsManifest_2_1.value;
+ try {
+ for (_b = (e_4 = void 0, __values(weightsGroup.paths)), _c = _b.next(); !_c.done; _c = _b.next()) {
+ path = _c.value;
+ if (this.weightUrlConverter != null) {
+ urlPromises.push(this.weightUrlConverter(path));
+ }
+ else {
+ fetchURLs.push(pathPrefix + path + suffix);
+ }
+ }
+ }
+ catch (e_4_1) { e_4 = { error: e_4_1 }; }
+ finally {
+ try {
+ if (_c && !_c.done && (_j = _b.return)) _j.call(_b);
+ }
+ finally { if (e_4) throw e_4.error; }
+ }
+ }
+ }
+ catch (e_3_1) { e_3 = { error: e_3_1 }; }
+ finally {
+ try {
+ if (weightsManifest_2_1 && !weightsManifest_2_1.done && (_h = weightsManifest_2.return)) _h.call(weightsManifest_2);
+ }
+ finally { if (e_3) throw e_3.error; }
+ }
+ if (!this.weightUrlConverter) return [3 /*break*/, 2];
+ _e = (_d = fetchURLs.push).apply;
+ _f = [fetchURLs];
+ return [4 /*yield*/, Promise.all(urlPromises)];
+ case 1:
+ _e.apply(_d, _f.concat([__spread.apply(void 0, [_k.sent()])]));
+ _k.label = 2;
+ case 2: return [4 /*yield*/, loadWeightsAsArrayBuffer(fetchURLs, {
+ requestInit: this.requestInit,
+ fetchFunc: this.fetch,
+ onProgress: this.onProgress
+ })];
+ case 3:
+ buffers = _k.sent();
+ return [2 /*return*/, [weightSpecs, concatenateArrayBuffers(buffers)]];
+ }
+ });
+ });
+ };
+ return HTTPRequest;
+ }());
+ HTTPRequest.URL_SCHEME_REGEX = /^https?:\/\//;
+ /**
+ * Extract the prefix and suffix of the url, where the prefix is the path before
+ * the last file, and suffix is the search params after the last file.
+ * ```
+ * const url = 'http://tfhub.dev/model/1/tensorflowjs_model.pb?tfjs-format=file'
+ * [prefix, suffix] = parseUrl(url)
+ * // prefix = 'http://tfhub.dev/model/1/'
+ * // suffix = '?tfjs-format=file'
+ * ```
+ * @param url the model url to be parsed.
+ */
+ function parseUrl(url) {
+ var lastSlash = url.lastIndexOf('/');
+ var lastSearchParam = url.lastIndexOf('?');
+ var prefix = url.substring(0, lastSlash);
+ var suffix = lastSearchParam > lastSlash ? url.substring(lastSearchParam) : '';
+ return [prefix + '/', suffix];
+ }
+ function isHTTPScheme(url) {
+ return url.match(HTTPRequest.URL_SCHEME_REGEX) != null;
+ }
+ var httpRouter = function (url, loadOptions) {
+ if (typeof fetch === 'undefined' &&
+ (loadOptions == null || loadOptions.fetchFunc == null)) {
+ // `http` uses `fetch` or `node-fetch`, if one wants to use it in
+ // an environment that is not the browser or node they have to setup a
+ // global fetch polyfill.
+ return null;
+ }
+ else {
+ var isHTTP = true;
+ if (Array.isArray(url)) {
+ isHTTP = url.every(function (urlItem) { return isHTTPScheme(urlItem); });
+ }
+ else {
+ isHTTP = isHTTPScheme(url);
+ }
+ if (isHTTP) {
+ return http(url, loadOptions);
+ }
+ }
+ return null;
+ };
+ IORouterRegistry.registerSaveRouter(httpRouter);
+ IORouterRegistry.registerLoadRouter(httpRouter);
+ /**
+ * Creates an IOHandler subtype that sends model artifacts to HTTP server.
+ *
+ * An HTTP request of the `multipart/form-data` mime type will be sent to the
+ * `path` URL. The form data includes artifacts that represent the topology
+ * and/or weights of the model. In the case of Keras-style `tf.Model`, two
+ * blobs (files) exist in form-data:
+ * - A JSON file consisting of `modelTopology` and `weightsManifest`.
+ * - A binary weights file consisting of the concatenated weight values.
+ * These files are in the same format as the one generated by
+ * [tfjs_converter](https://js.tensorflow.org/tutorials/import-keras.html).
+ *
+ * The following code snippet exemplifies the client-side code that uses this
+ * function:
+ *
+ * ```js
+ * const model = tf.sequential();
+ * model.add(
+ * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
+ *
+ * const saveResult = await model.save(tf.io.http(
+ * 'http://model-server:5000/upload', {requestInit: {method: 'PUT'}}));
+ * console.log(saveResult);
+ * ```
+ *
+ * If the default `POST` method is to be used, without any custom parameters
+ * such as headers, you can simply pass an HTTP or HTTPS URL to `model.save`:
+ *
+ * ```js
+ * const saveResult = await model.save('http://model-server:5000/upload');
+ * ```
+ *
+ * The following GitHub Gist
+ * https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864
+ * implements a server based on [flask](https://github.com/pallets/flask) that
+ * can receive the request. Upon receiving the model artifacts via the requst,
+ * this particular server reconsistutes instances of [Keras
+ * Models](https://keras.io/models/model/) in memory.
+ *
+ *
+ * @param path A URL path to the model.
+ * Can be an absolute HTTP path (e.g.,
+ * 'http://localhost:8000/model-upload)') or a relative path (e.g.,
+ * './model-upload').
+ * @param requestInit Request configurations to be used when sending
+ * HTTP request to server using `fetch`. It can contain fields such as
+ * `method`, `credentials`, `headers`, `mode`, etc. See
+ * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request
+ * for more information. `requestInit` must not have a body, because the
+ * body will be set by TensorFlow.js. File blobs representing the model
+ * topology (filename: 'model.json') and the weights of the model (filename:
+ * 'model.weights.bin') will be appended to the body. If `requestInit` has a
+ * `body`, an Error will be thrown.
+ * @param loadOptions Optional configuration for the loading. It includes the
+ * following fields:
+ * - weightPathPrefix Optional, this specifies the path prefix for weight
+ * files, by default this is calculated from the path param.
+ * - fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
+ * the `fetch` from node-fetch can be used here.
+ * - onProgress Optional, progress callback function, fired periodically
+ * before the load is completed.
+ * @returns An instance of `IOHandler`.
+ *
+ * @doc {
+ * heading: 'Models',
+ * subheading: 'Loading',
+ * namespace: 'io',
+ * ignoreCI: true
+ * }
+ */
+ function http(path, loadOptions) {
+ return new HTTPRequest(path, loadOptions);
+ }
+ /**
+ * Deprecated. Use `tf.io.http`.
+ * @param path
+ * @param loadOptions
+ */
+ function browserHTTPRequest(path, loadOptions) {
+ return http(path, loadOptions);
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var PassthroughLoader = /** @class */ (function () {
+ function PassthroughLoader(modelArtifacts) {
+ this.modelArtifacts = modelArtifacts;
+ }
+ PassthroughLoader.prototype.load = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ return __generator(this, function (_a) {
+ return [2 /*return*/, this.modelArtifacts];
+ });
+ });
+ };
+ return PassthroughLoader;
+ }());
+ var PassthroughSaver = /** @class */ (function () {
+ function PassthroughSaver(saveHandler) {
+ this.saveHandler = saveHandler;
+ }
+ PassthroughSaver.prototype.save = function (modelArtifacts) {
+ return __awaiter(this, void 0, void 0, function () {
+ return __generator(this, function (_a) {
+ return [2 /*return*/, this.saveHandler(modelArtifacts)];
+ });
+ });
+ };
+ return PassthroughSaver;
+ }());
+ /**
+ * Creates an IOHandler that loads model artifacts from memory.
+ *
+ * When used in conjunction with `tf.loadLayersModel`, an instance of
+ * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
+ *
+ * ```js
+ * const model = await tf.loadLayersModel(tf.io.fromMemory(
+ * modelTopology, weightSpecs, weightData));
+ * ```
+ *
+ * @param modelArtifacts a object containing model topology (i.e., parsed from
+ * the JSON format).
+ * @param weightSpecs An array of `WeightsManifestEntry` objects describing the
+ * names, shapes, types, and quantization of the weight data.
+ * @param weightData A single `ArrayBuffer` containing the weight data,
+ * concatenated in the order described by the weightSpecs.
+ * @param trainingConfig Model training configuration. Optional.
+ *
+ * @returns A passthrough `IOHandler` that simply loads the provided data.
+ */
+ function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) {
+ if (arguments.length === 1) {
+ var isModelArtifacts = modelArtifacts.modelTopology != null ||
+ modelArtifacts.weightSpecs != null;
+ if (isModelArtifacts) {
+ return new PassthroughLoader(modelArtifacts);
+ }
+ else {
+ // Legacy support: with only modelTopology.
+ // TODO(cais): Remove this deprecated API.
+ console.warn('Please call tf.io.fromMemory() with only one argument. ' +
+ 'The argument should be of type ModelArtifacts. ' +
+ 'The multi-argument signature of tf.io.fromMemory() has been ' +
+ 'deprecated and will be removed in a future release.');
+ return new PassthroughLoader({ modelTopology: modelArtifacts });
+ }
+ }
+ else {
+ // Legacy support.
+ // TODO(cais): Remove this deprecated API.
+ console.warn('Please call tf.io.fromMemory() with only one argument. ' +
+ 'The argument should be of type ModelArtifacts. ' +
+ 'The multi-argument signature of tf.io.fromMemory() has been ' +
+ 'deprecated and will be removed in a future release.');
+ return new PassthroughLoader({
+ modelTopology: modelArtifacts,
+ weightSpecs: weightSpecs,
+ weightData: weightData,
+ trainingConfig: trainingConfig
+ });
+ }
+ }
+ /**
+ * Creates an IOHandler that passes saved model artifacts to a callback.
+ *
+ * ```js
+ * function handleSave(artifacts) {
+ * // ... do something with the artifacts ...
+ * return {modelArtifactsInfo: {...}, ...};
+ * }
+ *
+ * const saveResult = model.save(tf.io.withSaveHandler(handleSave));
+ * ```
+ *
+ * @param saveHandler A function that accepts a `ModelArtifacts` and returns a
+ * `SaveResult`.
+ */
+ function withSaveHandler(saveHandler) {
+ return new PassthroughSaver(saveHandler);
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+ var io = {
+ __proto__: null,
+ browserFiles: browserFiles,
+ browserHTTPRequest: browserHTTPRequest,
+ concatenateArrayBuffers: concatenateArrayBuffers,
+ decodeWeights: decodeWeights,
+ encodeWeights: encodeWeights,
+ fromMemory: fromMemory,
+ getLoadHandlers: getLoadHandlers,
+ getModelArtifactsForJSON: getModelArtifactsForJSON,
+ getModelArtifactsInfoForJSON: getModelArtifactsInfoForJSON,
+ getSaveHandlers: getSaveHandlers,
+ http: http,
+ isHTTPScheme: isHTTPScheme,
+ loadWeights: loadWeights,
+ registerLoadRouter: registerLoadRouter,
+ registerSaveRouter: registerSaveRouter,
+ weightsLoaderFactory: weightsLoaderFactory,
+ withSaveHandler: withSaveHandler,
+ copyModel: copyModel,
+ listModels: listModels,
+ moveModel: moveModel,
+ removeModel: removeModel
+ };
+
+ /**
+ * Computes the dot product of two matrices, A * B. These must be matrices.
+ *
+ * ```js
+ * const a = tf.tensor2d([1, 2], [1, 2]);
+ * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+ *
+ * a.matMul(b).print(); // or tf.matMul(a, b)
+ * ```
+ * @param a First matrix in dot product operation.
+ * @param b Second matrix in dot product operation.
+ * @param transposeA If true, `a` is transposed before multiplication.
+ * @param transposeB If true, `b` is transposed before multiplication.
+ *
+ * @doc {heading: 'Operations', subheading: 'Matrices'}
+ */
+ function matMul_(a, b, transposeA, transposeB) {
+ var _a;
+ if (transposeA === void 0) { transposeA = false; }
+ if (transposeB === void 0) { transposeB = false; }
+ var $a = convertToTensor(a, 'a', 'matMul');
+ var $b = convertToTensor(b, 'b', 'matMul');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ var inputs = { a: $a, b: $b };
+ var attrs = { transposeA: transposeA, transposeB: transposeB };
+ return ENGINE.runKernel(BatchMatMul, inputs, attrs);
+ }
+ var matMul$1 = op({ matMul_: matMul_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take
+ * value `onValue` (defaults to 1), while all other locations take value
+ * `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank
+ * `R+1` with the last axis of size `depth`.
+ *
+ * ```js
+ * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print();
+ * ```
+ *
+ * @param indices `tf.Tensor` of indices with dtype `int32`.
+ * @param depth The depth of the one hot dimension.
+ * @param onValue A number used to fill in the output when the index matches
+ * the location.
+ * @param offValue A number used to fill in the output when the index does
+ * not match the location.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function oneHot_(indices, depth, onValue, offValue) {
+ if (onValue === void 0) { onValue = 1; }
+ if (offValue === void 0) { offValue = 0; }
+ if (depth < 2) {
+ throw new Error("Error in oneHot: depth must be >=2, but it is " + depth);
+ }
+ var $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32');
+ var inputs = { indices: $indices };
+ var attrs = { depth: depth, onValue: onValue, offValue: offValue };
+ return ENGINE.runKernel(OneHot, inputs, attrs);
+ }
+ var oneHot = op({ oneHot_: oneHot_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`.
+ *
+ * The returned `tf.Tensor`'s dimension `i` will correspond to the input
+ * dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`,
+ * where `n` is the rank of the input `tf.Tensor`. Hence by default, this
+ * operation performs a regular matrix transpose on 2-D input `tf.Tensor`s.
+ *
+ * ```js
+ * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
+ *
+ * a.transpose().print(); // or tf.transpose(a)
+ * ```
+ *
+ * @param x The tensor to transpose.
+ * @param perm The permutation of the dimensions of a.
+ *
+ * @doc {heading: 'Operations', subheading: 'Matrices'}
+ */
+ function transpose_(x, perm) {
+ var $x = convertToTensor(x, 'x', 'transpose');
+ if (perm == null) {
+ perm = $x.shape.map(function (s, i) { return i; }).reverse();
+ }
+ assert($x.rank === perm.length, function () { return "Error in transpose: rank of input " + $x.rank + " " +
+ ("must match length of perm " + perm + "."); });
+ perm.forEach(function (axis) {
+ assert(axis >= 0 && axis < $x.rank, function () { return "All entries in 'perm' must be between 0 and " + ($x.rank - 1) +
+ (" but got " + perm); });
+ });
+ if ($x.rank <= 1) {
+ return $x.clone();
+ }
+ var inputs = { x: $x };
+ var attrs = { perm: perm };
+ return ENGINE.runKernel(Transpose, inputs, attrs);
+ }
+ var transpose = op({ transpose_: transpose_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the confusion matrix from true labels and predicted labels.
+ *
+ * ```js
+ * const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32');
+ * const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32');
+ * const numClasses = 3;
+ * const out = tf.math.confusionMatrix(labels, predictions, numClasses);
+ * out.print();
+ * // Expected output matrix:
+ * // [[2, 0, 0],
+ * // [0, 1, 1],
+ * // [0, 0, 1]]
+ * ```
+ *
+ * @param labels The target labels, assumed to be 0-based integers
+ * for the classes. The shape is `[numExamples]`, where
+ * `numExamples` is the number of examples included.
+ * @param predictions The predicted classes, assumed to be
+ * 0-based integers for the classes. Must have the same shape as `labels`.
+ * @param numClasses Number of all classes, as an integer.
+ * Its value must be larger than the largest element in `labels` and
+ * `predictions`.
+ * @returns The confusion matrix as a int32-type 2D tensor. The value at
+ * row `r` and column `c` is the number of times examples of actual class
+ * `r` were predicted as class `c`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Evaluation'}
+ */
+ function confusionMatrix_(labels, predictions, numClasses) {
+ var $labels = convertToTensor(labels, 'labels', 'confusionMatrix');
+ var $predictions = convertToTensor(predictions, 'predictions', 'confusionMatrix');
+ assert(numClasses == null || numClasses > 0 && Number.isInteger(numClasses), function () { return "If provided, numClasses must be a positive integer, " +
+ ("but got " + numClasses); });
+ assert($labels.rank === 1, function () { return "Expected the rank of labels to be 1, but got " + $labels.rank; });
+ assert($predictions.rank === 1, function () { return "Expected the rank of predictions to be 1, " +
+ ("but got " + $predictions.rank); });
+ assert($labels.shape[0] === $predictions.shape[0], function () { return "Mismatch in the number of examples: " +
+ ($labels.shape[0] + " vs. " + $predictions.shape[0] + ". ") +
+ "Labels and predictions should have the same number of elements."; });
+ assert(numClasses > 0 && Number.isInteger(numClasses), function () { return "numClasses is required to be a positive integer, but got " +
+ ("" + numClasses); });
+ // TODO(cais): In the future, if oneHot supports tensors inputs for
+ // `numClasses`, `confusionMatrix` can make `numClasses` optional.
+ var oneHotLabels = oneHot(cast($labels, 'int32'), numClasses);
+ var oneHotPredictions = oneHot(cast($predictions, 'int32'), numClasses);
+ var oneHotLabelsT = transpose(oneHotLabels);
+ var product = matMul$1(oneHotLabelsT, oneHotPredictions);
+ return cast(product, 'int32');
+ }
+ var confusionMatrix = op({ confusionMatrix_: confusionMatrix_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+ var math = {
+ __proto__: null,
+ confusionMatrix: confusionMatrix
+ };
+
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns the dimensions in the input shape that are broadcasted to
+ * produce the provided output shape.
+ *
+ * The returned dimensions are 0-indexed and sorted. An example:
+ * inShape = [4, 1, 3]
+ * outShape = [5, 4, 3, 3]
+ * result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3.
+ */
+ function getBroadcastDims(inShape, outShape) {
+ var inRank = inShape.length;
+ var dims = [];
+ for (var i = 0; i < inRank; i++) {
+ var dim = inRank - 1 - i;
+ var a = inShape[dim] || 1;
+ var b = outShape[outShape.length - 1 - i] || 1;
+ if (b > 1 && a === 1) {
+ dims.unshift(dim);
+ }
+ }
+ return dims;
+ }
+ /**
+ * Returns the axes in the output space that should be reduced to produce
+ * the input space.
+ */
+ function getReductionAxes(inShape, outShape) {
+ var result = [];
+ for (var i = 0; i < outShape.length; i++) {
+ var inDim = inShape[inShape.length - i - 1];
+ var outAxis = outShape.length - i - 1;
+ var outDim = outShape[outAxis];
+ if (inDim == null || (inDim === 1 && outDim > 1)) {
+ result.unshift(outAxis);
+ }
+ }
+ return result;
+ }
+ function assertAndGetBroadcastShape(shapeA, shapeB) {
+ var result = [];
+ var l = Math.max(shapeA.length, shapeB.length);
+ for (var i = 0; i < l; i++) {
+ var a = shapeA[shapeA.length - i - 1];
+ if (a == null) {
+ a = 1;
+ }
+ var b = shapeB[shapeB.length - i - 1];
+ if (b == null) {
+ b = 1;
+ }
+ if (a === 1) {
+ result.unshift(b);
+ }
+ else if (b === 1) {
+ result.unshift(a);
+ }
+ else if (a !== b) {
+ var errMsg = "Operands could not be broadcast together with shapes " +
+ (shapeA + " and " + shapeB + ".");
+ throw Error(errMsg);
+ }
+ else {
+ result.unshift(a);
+ }
+ }
+ return result;
+ }
+
+ var broadcast_util = {
+ __proto__: null,
+ getBroadcastDims: getBroadcastDims,
+ getReductionAxes: getReductionAxes,
+ assertAndGetBroadcastShape: assertAndGetBroadcastShape
+ };
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates rank-3 `tf.Tensor` with the provided values, shape and dtype.
+ *
+ * The same functionality can be achieved with `tf.tensor`, but in general
+ * we recommend using `tf.tensor3d` as it makes the code more readable.
+ *
+ * ```js
+ * // Pass a nested array.
+ * tf.tensor3d([[[1], [2]], [[3], [4]]]).print();
+ * ```
+ * ```js
+ * // Pass a flat array and specify a shape.
+ * tf.tensor3d([1, 2, 3, 4], [2, 2, 1]).print();
+ * ```
+ *
+ * @param values The values of the tensor. Can be nested array of numbers,
+ * or a flat array, or a `TypedArray`.
+ * @param shape The shape of the tensor. If not provided, it is inferred from
+ * `values`.
+ * @param dtype The data type.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function tensor3d(values, shape, dtype) {
+ assertNonNull(values);
+ if (shape != null && shape.length !== 3) {
+ throw new Error('tensor3d() requires shape to have three numbers');
+ }
+ var inferredShape = inferShape(values, dtype);
+ if (inferredShape.length !== 3 && inferredShape.length !== 1) {
+ throw new Error('tensor3d() requires values to be number[][][] or flat/TypedArray');
+ }
+ if (inferredShape.length === 1 && shape == null) {
+ throw new Error('tensor3d() requires shape to be provided when `values` ' +
+ 'are a flat array');
+ }
+ return makeTensor(values, shape, inferredShape, dtype);
+ }
+
+ var fromPixels2DContext;
+ /**
+ * Creates a `tf.Tensor` from an image.
+ *
+ * ```js
+ * const image = new ImageData(1, 1);
+ * image.data[0] = 100;
+ * image.data[1] = 150;
+ * image.data[2] = 200;
+ * image.data[3] = 255;
+ *
+ * tf.browser.fromPixels(image).print();
+ * ```
+ *
+ * @param pixels The input image to construct the tensor from. The
+ * supported image types are all 4-channel. You can also pass in an image
+ * object with following attributes:
+ * `{data: Uint8Array; width: number; height: number}`
+ * @param numChannels The number of channels of the output tensor. A
+ * numChannels value less than 4 allows you to ignore channels. Defaults to
+ * 3 (ignores alpha channel of input image).
+ *
+ * @returns A Tensor3D with the shape `[height, width, numChannels]`.
+ *
+ * Note: fromPixels can be lossy in some cases, same image may result in
+ * slightly different tensor values, if rendered by different rendering
+ * engines. This means that results from different browsers, or even same
+ * browser with CPU and GPU rendering engines can be different. See discussion
+ * in details:
+ * https://github.com/tensorflow/tfjs/issues/5482
+ *
+ * @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true}
+ */
+ function fromPixels_(pixels, numChannels) {
+ if (numChannels === void 0) { numChannels = 3; }
+ // Sanity checks.
+ if (numChannels > 4) {
+ throw new Error('Cannot construct Tensor with more than 4 channels from pixels.');
+ }
+ if (pixels == null) {
+ throw new Error('pixels passed to tf.browser.fromPixels() can not be null');
+ }
+ var isPixelData = false;
+ var isImageData = false;
+ var isVideo = false;
+ var isImage = false;
+ var isCanvasLike = false;
+ var isImageBitmap = false;
+ if (pixels.data instanceof Uint8Array) {
+ isPixelData = true;
+ }
+ else if (typeof (ImageData) !== 'undefined' && pixels instanceof ImageData) {
+ isImageData = true;
+ }
+ else if (typeof (HTMLVideoElement) !== 'undefined' &&
+ pixels instanceof HTMLVideoElement) {
+ isVideo = true;
+ }
+ else if (typeof (HTMLImageElement) !== 'undefined' &&
+ pixels instanceof HTMLImageElement) {
+ isImage = true;
+ // tslint:disable-next-line: no-any
+ }
+ else if (pixels.getContext != null) {
+ isCanvasLike = true;
+ }
+ else if (typeof (ImageBitmap) !== 'undefined' && pixels instanceof ImageBitmap) {
+ isImageBitmap = true;
+ }
+ else {
+ throw new Error('pixels passed to tf.browser.fromPixels() must be either an ' +
+ "HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData " +
+ "in browser, or OffscreenCanvas, ImageData in webworker" +
+ " or {data: Uint32Array, width: number, height: number}, " +
+ ("but was " + pixels.constructor.name));
+ }
+ if (isVideo) {
+ var HAVE_CURRENT_DATA_READY_STATE = 2;
+ if (isVideo &&
+ pixels.readyState <
+ HAVE_CURRENT_DATA_READY_STATE) {
+ throw new Error('The video element has not loaded data yet. Please wait for ' +
+ '`loadeddata` event on the <video> element.');
+ }
+ }
+ // If the current backend has 'FromPixels' registered, it has a more
+ // efficient way of handling pixel uploads, so we call that.
+ var kernel = getKernel(FromPixels, ENGINE.backendName);
+ if (kernel != null) {
+ var inputs = { pixels: pixels };
+ var attrs = { numChannels: numChannels };
+ return ENGINE.runKernel(FromPixels, inputs, attrs);
+ }
+ var _a = __read(isVideo ?
+ [
+ pixels.videoWidth,
+ pixels.videoHeight
+ ] :
+ [pixels.width, pixels.height], 2), width = _a[0], height = _a[1];
+ var vals;
+ if (isCanvasLike) {
+ vals =
+ // tslint:disable-next-line:no-any
+ pixels.getContext('2d').getImageData(0, 0, width, height).data;
+ }
+ else if (isImageData || isPixelData) {
+ vals = pixels.data;
+ }
+ else if (isImage || isVideo || isImageBitmap) {
+ if (fromPixels2DContext == null) {
+ if (typeof document === 'undefined') {
+ if (typeof OffscreenCanvas !== 'undefined' &&
+ typeof OffscreenCanvasRenderingContext2D !== 'undefined') {
+ // @ts-ignore
+ fromPixels2DContext = new OffscreenCanvas(1, 1).getContext('2d');
+ }
+ else {
+ throw new Error('Cannot parse input in current context. ' +
+ 'Reason: OffscreenCanvas Context2D rendering is not supported.');
+ }
+ }
+ else {
+ fromPixels2DContext = document.createElement('canvas').getContext('2d');
+ }
+ }
+ fromPixels2DContext.canvas.width = width;
+ fromPixels2DContext.canvas.height = height;
+ fromPixels2DContext.drawImage(pixels, 0, 0, width, height);
+ vals = fromPixels2DContext.getImageData(0, 0, width, height).data;
+ }
+ var values;
+ if (numChannels === 4) {
+ values = new Int32Array(vals);
+ }
+ else {
+ var numPixels = width * height;
+ values = new Int32Array(numPixels * numChannels);
+ for (var i = 0; i < numPixels; i++) {
+ for (var channel = 0; channel < numChannels; ++channel) {
+ values[i * numChannels + channel] = vals[i * 4 + channel];
+ }
+ }
+ }
+ var outShape = [height, width, numChannels];
+ return tensor3d(values, outShape, 'int32');
+ }
+ // Helper functions for |fromPixelsAsync| to check whether the input can
+ // be wrapped into imageBitmap.
+ function isPixelData(pixels) {
+ return (pixels != null) && (pixels.data instanceof Uint8Array);
+ }
+ function isImageBitmapFullySupported() {
+ return typeof window !== 'undefined' &&
+ typeof (ImageBitmap) !== 'undefined' &&
+ window.hasOwnProperty('createImageBitmap');
+ }
+ function isNonEmptyPixels(pixels) {
+ return pixels != null && pixels.width !== 0 && pixels.height !== 0;
+ }
+ function canWrapPixelsToImageBitmap(pixels) {
+ return isImageBitmapFullySupported() && !(pixels instanceof ImageBitmap) &&
+ isNonEmptyPixels(pixels) && !isPixelData(pixels);
+ }
+ /**
+ * Creates a `tf.Tensor` from an image in async way.
+ *
+ * ```js
+ * const image = new ImageData(1, 1);
+ * image.data[0] = 100;
+ * image.data[1] = 150;
+ * image.data[2] = 200;
+ * image.data[3] = 255;
+ *
+ * (await tf.browser.fromPixelsAsync(image)).print();
+ * ```
+ * This API is the async version of fromPixels. The API will first
+ * check |WRAP_TO_IMAGEBITMAP| flag, and try to wrap the input to
+ * imageBitmap if the flag is set to true.
+ *
+ * @param pixels The input image to construct the tensor from. The
+ * supported image types are all 4-channel. You can also pass in an image
+ * object with following attributes:
+ * `{data: Uint8Array; width: number; height: number}`
+ * @param numChannels The number of channels of the output tensor. A
+ * numChannels value less than 4 allows you to ignore channels. Defaults to
+ * 3 (ignores alpha channel of input image).
+ *
+ * @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true}
+ */
+ function fromPixelsAsync(pixels, numChannels) {
+ if (numChannels === void 0) { numChannels = 3; }
+ return __awaiter(this, void 0, void 0, function () {
+ var inputs, imageBitmap;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ inputs = null;
+ if (!(env().getBool('WRAP_TO_IMAGEBITMAP') &&
+ canWrapPixelsToImageBitmap(pixels))) return [3 /*break*/, 5];
+ imageBitmap = void 0;
+ _a.label = 1;
+ case 1:
+ _a.trys.push([1, 3, , 4]);
+ return [4 /*yield*/, createImageBitmap(pixels, { premultiplyAlpha: 'none' })];
+ case 2:
+ // wrap in try-catch block, because createImageBitmap may not work
+ // properly in some browsers, e.g.
+ // https://bugzilla.mozilla.org/show_bug.cgi?id=1335594
+ // tslint:disable-next-line: no-any
+ imageBitmap = _a.sent();
+ return [3 /*break*/, 4];
+ case 3:
+ _a.sent();
+ imageBitmap = null;
+ return [3 /*break*/, 4];
+ case 4:
+ // createImageBitmap will clip the source size.
+ // In some cases, the input will have larger size than its content.
+ // E.g. new Image(10, 10) but with 1 x 1 content. Using
+ // createImageBitmap will clip the size from 10 x 10 to 1 x 1, which
+ // is not correct. We should avoid wrapping such resouce to
+ // imageBitmap.
+ if (imageBitmap != null && imageBitmap.width === pixels.width &&
+ imageBitmap.height === pixels.height) {
+ inputs = imageBitmap;
+ }
+ else {
+ inputs = pixels;
+ }
+ return [3 /*break*/, 6];
+ case 5:
+ inputs = pixels;
+ _a.label = 6;
+ case 6: return [2 /*return*/, fromPixels_(inputs, numChannels)];
+ }
+ });
+ });
+ }
+ /**
+ * Draws a `tf.Tensor` of pixel values to a byte array or optionally a
+ * canvas.
+ *
+ * When the dtype of the input is 'float32', we assume values in the range
+ * [0-1]. Otherwise, when input is 'int32', we assume values in the range
+ * [0-255].
+ *
+ * Returns a promise that resolves when the canvas has been drawn to.
+ *
+ * @param img A rank-2 tensor with shape `[height, width]`, or a rank-3 tensor
+ * of shape `[height, width, numChannels]`. If rank-2, draws grayscale. If
+ * rank-3, must have depth of 1, 3 or 4. When depth of 1, draws
+ * grayscale. When depth of 3, we draw with the first three components of
+ * the depth dimension corresponding to r, g, b and alpha = 1. When depth of
+ * 4, all four components of the depth dimension correspond to r, g, b, a.
+ * @param canvas The canvas to draw to.
+ *
+ * @doc {heading: 'Browser', namespace: 'browser'}
+ */
+ function toPixels(img, canvas) {
+ return __awaiter(this, void 0, void 0, function () {
+ var $img, originalImgTensor, _a, height, width, depth, data, multiplier, bytes, i, rgba, d, value, j, ctx, imageData;
+ return __generator(this, function (_b) {
+ switch (_b.label) {
+ case 0:
+ $img = convertToTensor(img, 'img', 'toPixels');
+ if (!(img instanceof Tensor)) {
+ originalImgTensor = $img;
+ $img = cast(originalImgTensor, 'int32');
+ originalImgTensor.dispose();
+ }
+ if ($img.rank !== 2 && $img.rank !== 3) {
+ throw new Error("toPixels only supports rank 2 or 3 tensors, got rank " + $img.rank + ".");
+ }
+ _a = __read($img.shape.slice(0, 2), 2), height = _a[0], width = _a[1];
+ depth = $img.rank === 2 ? 1 : $img.shape[2];
+ if (depth > 4 || depth === 2) {
+ throw new Error("toPixels only supports depth of size " +
+ ("1, 3 or 4 but got " + depth));
+ }
+ if ($img.dtype !== 'float32' && $img.dtype !== 'int32') {
+ throw new Error("Unsupported type for toPixels: " + $img.dtype + "." +
+ " Please use float32 or int32 tensors.");
+ }
+ return [4 /*yield*/, $img.data()];
+ case 1:
+ data = _b.sent();
+ multiplier = $img.dtype === 'float32' ? 255 : 1;
+ bytes = new Uint8ClampedArray(width * height * 4);
+ for (i = 0; i < height * width; ++i) {
+ rgba = [0, 0, 0, 255];
+ for (d = 0; d < depth; d++) {
+ value = data[i * depth + d];
+ if ($img.dtype === 'float32') {
+ if (value < 0 || value > 1) {
+ throw new Error("Tensor values for a float32 Tensor must be in the " +
+ ("range [0 - 1] but encountered " + value + "."));
+ }
+ }
+ else if ($img.dtype === 'int32') {
+ if (value < 0 || value > 255) {
+ throw new Error("Tensor values for a int32 Tensor must be in the " +
+ ("range [0 - 255] but encountered " + value + "."));
+ }
+ }
+ if (depth === 1) {
+ rgba[0] = value * multiplier;
+ rgba[1] = value * multiplier;
+ rgba[2] = value * multiplier;
+ }
+ else {
+ rgba[d] = value * multiplier;
+ }
+ }
+ j = i * 4;
+ bytes[j + 0] = Math.round(rgba[0]);
+ bytes[j + 1] = Math.round(rgba[1]);
+ bytes[j + 2] = Math.round(rgba[2]);
+ bytes[j + 3] = Math.round(rgba[3]);
+ }
+ if (canvas != null) {
+ canvas.width = width;
+ canvas.height = height;
+ ctx = canvas.getContext('2d');
+ imageData = new ImageData(bytes, width, height);
+ ctx.putImageData(imageData, 0, 0);
+ }
+ if ($img !== img) {
+ $img.dispose();
+ }
+ return [2 /*return*/, bytes];
+ }
+ });
+ });
+ }
+ var fromPixels = op({ fromPixels_: fromPixels_ });
+
+ var browser = {
+ __proto__: null,
+ fromPixelsAsync: fromPixelsAsync,
+ toPixels: toPixels,
+ fromPixels: fromPixels
+ };
+
+ /**
+ * Validate gather nd inputs.
+ *
+ * @param tensor The tensor contains the source values.
+ * @param indices The tensor contains the indices to slice the source.
+ *
+ * @returns [resultShape, numUpdates, sliceSize, strides]
+ */
+ function prepareAndValidate(tensor, indices) {
+ var tensorRank = tensor.shape.length;
+ var indicesRank = indices.shape.length;
+ if (tensorRank < 1) {
+ throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' +
+ (" but the rank was " + tensorRank + "."));
+ }
+ if (indicesRank < 1) {
+ throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' +
+ (" but the rank was " + indicesRank + "."));
+ }
+ if (indices.dtype !== 'int32') {
+ throw new Error('tf.gatherND() expects the indices to be int32 type,' +
+ (" but the dtype was " + indices.dtype + "."));
+ }
+ if (indices.shape[indicesRank - 1] > tensorRank) {
+ throw new Error('index innermost dimension length must be <= tensor rank; saw: ' +
+ (indices.shape[indicesRank - 1] + " vs. " + tensorRank));
+ }
+ if (sizeFromShape(tensor.shape) === 0) {
+ throw new Error('Requested more than 0 entries, but input is empty.' +
+ (" Input shape: " + tensor.shape + "."));
+ }
+ var indicesShape = indices.shape;
+ var sliceRank = indicesShape[indicesShape.length - 1];
+ // The result shape is
+ // indices.shape[:-1] + params.shape[indices.shape[-1]:]
+ var nResult = 1;
+ for (var i = 0; i < indicesShape.length - 1; ++i) {
+ nResult *= indicesShape[i];
+ }
+ var inputShape = tensor.shape;
+ var resultShape = indicesShape.slice();
+ resultShape.pop();
+ var sliceSize = 1;
+ for (var i = sliceRank; i < tensorRank; ++i) {
+ sliceSize *= inputShape[i];
+ resultShape.push(inputShape[i]);
+ }
+ var strides = __spread(computeStrides(tensor.shape).map(function (stride) { return stride / sliceSize; }), [1]).slice(0, sliceRank);
+ return [resultShape, nResult, sliceSize, strides];
+ }
+
+ var gather_nd_util = {
+ __proto__: null,
+ prepareAndValidate: prepareAndValidate
+ };
+
+ /**
+ * Check whether updates.shape = indices.shape[:batchDim] +
+ * shape[sliceDim:]
+ *
+ * @param x The input tensor.
+ */
+ function validateUpdateShape(shape, indices, updates) {
+ var sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1;
+ var batchDim = (indices.rank > 1) ? indices.rank - 1 : 1;
+ var shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' +
+ ("shape[sliceDim:], got updates.shape: " + updates.shape) +
+ (", indices.shape: " + indices.shape + ", shape: " + shape) +
+ (", sliceDim: " + sliceDim + ", and batchDim: " + batchDim + ".");
+ if (updates.rank < batchDim) {
+ throw new Error(shapeError + (" update.rank < " + batchDim + ". "));
+ }
+ if (shape.length < sliceDim + (updates.rank - batchDim)) {
+ throw new Error(shapeError +
+ (" Output shape length < " + (sliceDim + (updates.rank - batchDim))));
+ }
+ if (updates.rank !== batchDim + shape.length - sliceDim) {
+ throw new Error(shapeError + (" update.rank != " + (batchDim + shape.length - sliceDim)));
+ }
+ for (var d = 0; d < batchDim; ++d) {
+ if (updates.shape[d] !== indices.shape[d]) {
+ throw new Error(shapeError +
+ (" updates.shape[" + d + "] (" + updates.shape[d] + ") != indices.shape[" + d + "] (" + indices.shape[d] + ")."));
+ }
+ }
+ for (var d = 0; d < updates.rank - batchDim; ++d) {
+ if (updates.shape[d + batchDim] !== shape[d + sliceDim]) {
+ throw new Error(shapeError +
+ (" updates.shape[" + (d + batchDim) + "] (" + updates.shape[d + batchDim] + ") != shape[" + (d + batchDim) + "] (" + shape[d + batchDim] + ")"));
+ }
+ }
+ }
+ /**
+ * Validate scatter nd inputs.
+ *
+ * @param update The tensor contains the update values.
+ * @param indices The tensor contains the indices for the update values.
+ * @param shape The shape of the output tensor.
+ */
+ function validateInput$1(updates, indices, shape) {
+ if (indices.rank < 1) {
+ throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' +
+ (" but the rank was " + indices.rank + "."));
+ }
+ if (updates.rank < 1) {
+ throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' +
+ (" but the rank was " + updates.rank + "."));
+ }
+ if (indices.dtype !== 'int32') {
+ throw new Error("The dtype of 'indices' should be int32, but got dtype: " + indices.dtype);
+ }
+ if (shape.length < 1) {
+ throw new Error("Output rank must be greater or equal to 1, but got shape: " + shape);
+ }
+ if (shape.length === 0) {
+ if (indices.size === 0) {
+ throw new Error("Indices specified for empty output. indices shape: " + indices.shape);
+ }
+ if (updates.size === 0) {
+ throw new Error("Updates specified for empty output. updates shape: " + updates.shape);
+ }
+ }
+ validateUpdateShape(shape, indices, updates);
+ }
+ /**
+ * Calculate the shape information for the output.
+ *
+ * @param update The tensor contains the update values.
+ * @param indices The tensor contains the indices for the update values.
+ * @param shape The shape of the output tensor.
+ *
+ * @returns ScatterShapeInfo
+ */
+ function calculateShapes(updates, indices, shape) {
+ // Calculate the number of dimensions in indices
+ var indicesRank = indices.shape.length;
+ var sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1;
+ // Calculate the number of elements that make up each slice of our updated
+ // tensor. This allows us to work with flattened tensors and copy over whole
+ // slices at a time.
+ var totalNd = shape.length;
+ var sliceSize = 1;
+ for (var i = sliceRank; i < totalNd; ++i) {
+ sliceSize *= shape[i];
+ }
+ var safeSliceDim = (sliceRank < 1) ? 1 : sliceRank;
+ var numUpdates = sizeFromShape(indices.shape) / safeSliceDim;
+ var strides = __spread(computeStrides(shape.slice(0, sliceRank)), [1]);
+ var outputSize = sizeFromShape(shape);
+ return { sliceRank: sliceRank, numUpdates: numUpdates, sliceSize: sliceSize, strides: strides, outputSize: outputSize };
+ }
+
+ var scatter_nd_util = {
+ __proto__: null,
+ validateUpdateShape: validateUpdateShape,
+ validateInput: validateInput$1,
+ calculateShapes: calculateShapes
+ };
+
+ var NEW_AXIS = -2;
+ var SHRINK_AXIS = -1;
+ function assertParamsValid(input, begin, size) {
+ var inputRank = input.shape.length;
+ assert(inputRank === begin.length, function () { return "Error in slice" + inputRank + "D: Length of begin " + begin + " must " +
+ ("match the rank of the array (" + inputRank + ")."); });
+ assert(inputRank === size.length, function () { return "Error in slice" + inputRank + "D: Length of size " + size + " must " +
+ ("match the rank of the array (" + inputRank + ")."); });
+ var _loop_1 = function (i) {
+ assert(begin[i] + size[i] <= input.shape[i], function () { return "Error in slice" + inputRank + "D: begin[" + i + "] + size[" + i + "] " +
+ ("(" + (begin[i] + size[i]) + ") would overflow input.shape[" + i + "] (" + input.shape[i] + ")"); });
+ };
+ for (var i = 0; i < inputRank; ++i) {
+ _loop_1(i);
+ }
+ }
+ /** Converts a binary mask to an array of axes. Used in stridedSlice(). */
+ function maskToAxes(mask) {
+ var axes = [];
+ var axis = 0;
+ while (mask > 0) {
+ if (mask & 1) {
+ axes.push(axis);
+ }
+ mask /= 2;
+ axis++;
+ }
+ return axes;
+ }
+ /** Computes the output shape given the strided slice params. */
+ function computeOutShape$2(begin, end, strides) {
+ var size = [];
+ for (var axis = 0; axis < begin.length; axis++) {
+ size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]);
+ }
+ return size;
+ }
+ // Creates full selection at the elided dimensions. If the dimension matches
+ // the ellipsis mask, override the current stride value. Otherwise, insert.
+ function stridesWithElidedDims(strides, ellipsisInsertionIndex, numElidedAxes, inputShape) {
+ var newStrides = __spread(strides);
+ for (var i = newStrides.length; i < inputShape.length; i++) {
+ newStrides.push(1);
+ }
+ for (var i = 0; i < numElidedAxes; i++) {
+ if (i === 0) {
+ newStrides[ellipsisInsertionIndex] = 1;
+ }
+ else {
+ newStrides.splice(ellipsisInsertionIndex, 0 /* num elements to delete */, 1 /* element to add */);
+ newStrides.pop();
+ }
+ }
+ return newStrides;
+ }
+ function unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, normalizedAxis) {
+ if (normalizedAxis <= ellipsisInsertionIndex) {
+ return normalizedAxis;
+ }
+ return normalizedAxis - (numElidedAxes - 1);
+ }
+ function getElidedAxes(numElidedAxes, ellipsisInsertionIndex) {
+ var elidedAxes = [];
+ for (var i = 0; i < numElidedAxes; i++) {
+ elidedAxes.push(ellipsisInsertionIndex + i);
+ }
+ return elidedAxes;
+ }
+ // Normalize the start, end and strides.
+ function getNormalizedAxes(inputShape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, beginMask, endMask, ellipsisMask) {
+ var inputRank = inputShape.length;
+ var normalizedBegin = new Array(inputRank), normalizedEnd = new Array(inputRank), normalizedStrides = new Array(inputRank);
+ if (ellipsisAxes.length && numInterpolatedAxes > 0) {
+ var fullIndex = ellipsisAxes[0];
+ // The ellipsis applies to the masked index as well as any dimensions
+ // that are interpolated.
+ var numElidedAxes = numInterpolatedAxes + 1;
+ normalizedBegin = startIndicesWithElidedDims(beginMask, fullIndex, numElidedAxes, begin, inputShape);
+ normalizedEnd = stopIndicesWithElidedDims(endMask, fullIndex, numElidedAxes, end, inputShape);
+ normalizedStrides =
+ stridesWithElidedDims(strides, fullIndex, numElidedAxes, inputShape);
+ }
+ else {
+ for (var axis = 0; axis < inputRank; axis++) {
+ normalizedBegin[axis] = startForAxis(beginMask, begin, strides, inputShape, axis, ellipsisMask);
+ normalizedEnd[axis] =
+ stopForAxis(endMask, end, strides, inputShape, axis, ellipsisMask);
+ normalizedStrides[axis] = stridesForAxis(strides, axis, ellipsisMask);
+ }
+ }
+ return {
+ begin: normalizedBegin,
+ end: normalizedEnd,
+ strides: normalizedStrides
+ };
+ }
+ // Creates full selection at the elided dimensions. If the dimension matches
+ // the ellipsis mask, override the current start value. Otherwise, insert.
+ function startIndicesWithElidedDims(beginMask, ellipsisInsertionIndex, numElidedAxes, originalBegin, inputShape) {
+ var newIndices = __spread(inputShape);
+ var elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
+ for (var axis = 0; axis < newIndices.length; axis++) {
+ if (elidedAxes.indexOf(axis) > -1) {
+ newIndices[axis] = 0;
+ }
+ else {
+ var originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
+ var originalValue = originalBegin[originalAxis];
+ if (beginMask & 1 << originalAxis) {
+ originalValue = 0;
+ }
+ newIndices[axis] = originalValue;
+ }
+ }
+ return newIndices;
+ }
+ // Creates full selection at the elided dimensions. If the dimension matches
+ // the ellipsis mask, override the current stop value. Otherwise, insert.
+ function stopIndicesWithElidedDims(endMask, ellipsisInsertionIndex, numElidedAxes, originalEnd, inputShape) {
+ var newIndices = __spread(inputShape);
+ var elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
+ for (var axis = 0; axis < newIndices.length; axis++) {
+ if (elidedAxes.indexOf(axis) > -1) {
+ newIndices[axis] = Number.MAX_SAFE_INTEGER;
+ }
+ else {
+ var originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
+ var originalValue = originalEnd[originalAxis];
+ if (endMask & 1 << originalAxis) {
+ originalValue = Number.MAX_SAFE_INTEGER;
+ }
+ newIndices[axis] = originalValue;
+ }
+ }
+ for (var i = 0; i < newIndices.length; i++) {
+ // Handle negative indices
+ var axisSize = inputShape[i];
+ if (newIndices[i] < 0) {
+ newIndices[i] += axisSize;
+ }
+ newIndices[i] = clamp(0, newIndices[i], inputShape[i]);
+ }
+ return newIndices;
+ }
+ function stridesForAxis(strides, axis, ellipsisMask) {
+ var stride = strides[axis];
+ if (ellipsisMask & (1 << axis) || stride == null) {
+ stride = 1;
+ }
+ return stride;
+ }
+ function startForAxis(beginMask, startIndices, strides, inputShape, axis, ellipsisMask) {
+ // Begin with the specified index
+ var start = startIndices[axis];
+ var stride = strides[axis] || 1;
+ // Check the axis bit from right of masked axes, or the begin index is not set
+ // for the axis.
+ if (beginMask & 1 << axis || ellipsisMask & 1 << axis || start == null) {
+ if (stride > 0) {
+ // Forward iteration - use the first element. These values will get
+ // clamped below (Note: We could have set them to 0 and axis_size-1, but
+ // use lowest() and max() to maintain symmetry with StopForAxis())
+ start = Number.MIN_SAFE_INTEGER;
+ }
+ else {
+ // Backward iteration - use the last element.
+ start = Number.MAX_SAFE_INTEGER;
+ }
+ }
+ // Handle negative indices
+ var axisSize = inputShape[axis];
+ if (start < 0) {
+ start += axisSize;
+ }
+ // Clamping
+ start = clamp(0, start, axisSize - 1);
+ return start;
+ }
+ function stopForAxis(endMask, stopIndices, strides, inputShape, axis, ellipsisMask) {
+ // Begin with the specified index
+ var stop = stopIndices[axis];
+ var stride = strides[axis] || 1;
+ // Check the axis bit from right of masked axes, or if the stop index is not
+ // set for this axis.
+ if (endMask & (1 << axis) || ellipsisMask & (1 << axis) || stop == null) {
+ if (stride > 0) {
+ // Forward iteration - use the last element. These values will get
+ // clamped below
+ stop = Number.MAX_SAFE_INTEGER;
+ }
+ else {
+ // Backward iteration - use the first element.
+ stop = Number.MIN_SAFE_INTEGER;
+ }
+ }
+ // Handle negative indices
+ var axisSize = inputShape[axis];
+ if (stop < 0) {
+ stop += axisSize;
+ }
+ // Clamping
+ // Because the end index points one past the last element, we need slightly
+ // different clamping ranges depending on the direction.
+ if (stride > 0) {
+ // Forward iteration
+ stop = clamp(0, stop, axisSize);
+ }
+ else {
+ // Backward iteration
+ stop = clamp(-1, stop, axisSize - 1);
+ }
+ return stop;
+ }
+ /**
+ * Returns true if the slice occupies a continous set of elements in the
+ * 'flat' space.
+ */
+ function isSliceContinous(shape, begin, size) {
+ // Index of the first axis that has size > 1.
+ var firstNonOneAxis = size.length;
+ for (var i = 0; i < size.length; i++) {
+ if (size[i] > 1) {
+ firstNonOneAxis = i;
+ break;
+ }
+ }
+ for (var i = firstNonOneAxis + 1; i < size.length; i++) {
+ if (begin[i] > 0 || size[i] !== shape[i]) {
+ return false;
+ }
+ }
+ return true;
+ }
+ function computeFlatOffset(begin, strides) {
+ var flatOffset = begin.length > 0 ? begin[begin.length - 1] : 1;
+ for (var i = 0; i < begin.length - 1; i++) {
+ flatOffset += begin[i] * strides[i];
+ }
+ return flatOffset;
+ }
+ function parseSliceParams(x, begin, size) {
+ // The following logic allows for more ergonomic calls.
+ var begin_;
+ var xRank = x.shape.length;
+ if (typeof begin === 'number') {
+ begin_ = __spread([begin], new Array(xRank - 1).fill(0));
+ }
+ else if (begin.length < xRank) {
+ begin_ = begin.concat(new Array(xRank - begin.length).fill(0));
+ }
+ else {
+ begin_ = begin.slice();
+ }
+ begin_.forEach(function (d) {
+ assert(d !== -1, function () { return 'slice() does not support negative begin indexing.'; });
+ });
+ var size_;
+ if (size == null) {
+ size_ = new Array(xRank).fill(-1);
+ }
+ else if (typeof size === 'number') {
+ size_ = __spread([size], new Array(xRank - 1).fill(-1));
+ }
+ else if (size.length < xRank) {
+ size_ = size.concat(new Array(xRank - size.length).fill(-1));
+ }
+ else {
+ size_ = size;
+ }
+ size_ = size_.map(function (d, i) {
+ if (d >= 0) {
+ return d;
+ }
+ else {
+ assert(d === -1, function () { return "Negative size values should be exactly -1 but got " +
+ (d + " for the slice() size at index " + i + "."); });
+ return x.shape[i] - begin_[i];
+ }
+ });
+ return [begin_, size_];
+ }
+ // Convert the slicing specification from a sparse representation to a dense
+ // representation. This means that all ellipses and newaxis are expanded out.
+ function sliceInfo(xShape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
+ var stridesNonNull;
+ if (strides == null) {
+ stridesNonNull = new Array(begin.length);
+ stridesNonNull.fill(1);
+ }
+ else {
+ stridesNonNull = strides;
+ }
+ // Only one non-zero bit is allowed in ellipsisMask, which means ellipsisMask
+ // is a power of 2. Use bit compares to ensure ellipsisMask is 0 or a power
+ // of 2. When i is a power of 2, i & (i - 1) is always 0.
+ // Also ref:
+ // https://stackoverflow.com/questions/600293/how-to-check-if-a-number-is-a-power-of-2
+ if (ellipsisMask != null && (ellipsisMask & (ellipsisMask - 1)) !== 0) {
+ throw new Error('Multiple ellipses in slice is not allowed.');
+ }
+ // Step 1: Account for ellipsis and new axis.
+ // Check for ellipsis and count how many non-newaxis there are after.
+ var ellipsisSeen = false;
+ var sparseSpec = {
+ dims: stridesNonNull.length,
+ numAddAxisAfterEllipsis: 0,
+ begin: begin.slice(),
+ end: end.slice(),
+ strides: stridesNonNull.slice(),
+ beginMask: beginMask,
+ endMask: endMask,
+ ellipsisMask: ellipsisMask,
+ newAxisMask: newAxisMask,
+ shrinkAxisMask: shrinkAxisMask
+ };
+ for (var i = 0; i < sparseSpec.dims; i++) {
+ if (ellipsisSeen && ((1 << i) & newAxisMask) !== 0) {
+ sparseSpec.numAddAxisAfterEllipsis++;
+ }
+ if ((1 << i) & ellipsisMask) {
+ ellipsisSeen = true;
+ }
+ }
+ // If no ellipsis insert one at the end.
+ if (!ellipsisSeen) {
+ sparseSpec.ellipsisMask |= (1 << sparseSpec.dims);
+ sparseSpec.dims++; // this effects loop iteration below
+ }
+ // Step 2: Make a sparse spec into a full index spec.
+ //
+ // The sparse spec deos not correspond to the number of dimensions.
+ // Make a dense spec that cooresponds to the number of dimensions.
+ //
+ // For example suppose foo[...,3:] on foo.shape = [2, 2, 3] then we need to
+ // produce the missing beginMask for the first two dimensions i.e. from
+ // beginMaskSpec = 0, endMaskSpec = 2, we achieve beginMask = 6 (110),
+ // endMask = 7 (111).
+ var denseSpec = {
+ dims: xShape.length,
+ beginMask: 0,
+ endMask: 0,
+ beginValid: false,
+ endValid: false
+ };
+ buildDenseSpec(sparseSpec, denseSpec);
+ // Step 3: Make implicit ranges (non-zero beginMasks and endMasks) explicit
+ // and bounds check.
+ var isIdentity = true;
+ var sliceDim0 = true;
+ var isSimpleSlice = true;
+ var processingShape = [];
+ var finalShape = [];
+ for (var i = 0; i < xShape.length; ++i) {
+ if (denseSpec.strides[i] === 0) {
+ throw Error("strides[" + i + "] must be non-zero");
+ }
+ var shrinkI = !!(denseSpec.shrinkAxisMask & (1 << i));
+ var dimI = xShape[i];
+ if (dimI === -1) {
+ processingShape.push(shrinkI ? 1 : -1);
+ continue;
+ }
+ var masks = [denseSpec.beginMask & (1 << i), denseSpec.endMask & (1 << i)];
+ var validRange = [
+ denseSpec.strides[i] > 0 ? 0 : -1,
+ denseSpec.strides[i] > 0 ? dimI : dimI - 1
+ ];
+ if (shrinkI && denseSpec.strides[i] <= 0) {
+ throw Error('only stride 1 allowed on non-range indexing.');
+ }
+ isSimpleSlice = isSimpleSlice && (denseSpec.strides[i] === 1);
+ var beginAndEndMasked = !!((denseSpec.beginMask & (1 << i)) && (denseSpec.endMask & (1 << i)));
+ if (denseSpec.beginValid && denseSpec.endValid) {
+ if (shrinkI) {
+ // If we are shrinking, the end index is now possibly incorrect. In
+ // particular foo[-1] produces sparseBegin = -1, sparseEnd = 0.
+ // and canonical puts these to n-1 and 0, which implies a degenerate
+ // interval. Fortunately, it is now safe to re-create end as begin + 1.
+ var xFwd = denseSpec.begin[i] < 0 ? dimI + denseSpec.begin[i] :
+ denseSpec.begin[i];
+ denseSpec.begin[i] = xFwd;
+ denseSpec.end[i] = denseSpec.begin[i] + 1;
+ if (xFwd < 0 || xFwd >= dimI) {
+ throw Error("slice index " + denseSpec.begin[i] + " of dimension " + i + " out of bounds.");
+ }
+ }
+ else {
+ denseSpec.begin[i] = canonical(denseSpec.begin[i], 0, denseSpec.strides[i], dimI, masks, validRange);
+ denseSpec.end[i] = canonical(denseSpec.end[i], 1, denseSpec.strides[i], dimI, masks, validRange);
+ }
+ // Update optimization values
+ var takeAllInDimension = denseSpec.strides[i] === 1 &&
+ denseSpec.begin[i] === 0 && denseSpec.end[i] === dimI;
+ isIdentity = isIdentity && takeAllInDimension;
+ sliceDim0 = sliceDim0 &&
+ ((i === 0 && denseSpec.strides[i] === 1) || takeAllInDimension);
+ }
+ else {
+ isIdentity =
+ isIdentity && ((denseSpec.strides[i] === 1) && beginAndEndMasked);
+ sliceDim0 = sliceDim0 &&
+ ((i === 0 && denseSpec.strides[i] === 1) || beginAndEndMasked);
+ }
+ // Compute the processing shape (the intermediate Eigen will produce)
+ var intervalLength = void 0;
+ var knownInterval = false;
+ if (denseSpec.beginValid && denseSpec.endValid) {
+ intervalLength = denseSpec.end[i] - denseSpec.begin[i];
+ knownInterval = true;
+ }
+ else if (shrinkI) {
+ // The dimension is still known as 1 for the processingShape, but will be
+ // discarded for the final shape.
+ intervalLength = 1;
+ knownInterval = true;
+ }
+ else if (beginAndEndMasked) {
+ // Even if we don't have values for begin or end, we do know that this
+ // dimension covers the whole interval. If we have shape information for
+ // this dimension, that tells us the interval length.
+ if (dimI >= 0) {
+ if (denseSpec.strides[i] < 0) {
+ intervalLength = -dimI;
+ }
+ else {
+ intervalLength = dimI;
+ }
+ knownInterval = true;
+ }
+ }
+ if (knownInterval) {
+ var sizeI = void 0;
+ // Hold zero if the interval is degenerate, otherwise account for
+ // remainder
+ if (intervalLength === 0 ||
+ ((intervalLength < 0) !== (denseSpec.strides[i] < 0))) {
+ sizeI = 0;
+ }
+ else {
+ sizeI = Math.trunc(intervalLength / denseSpec.strides[i]) +
+ (intervalLength % denseSpec.strides[i] !== 0 ? 1 : 0);
+ }
+ processingShape.push(sizeI);
+ }
+ else {
+ processingShape.push(-1);
+ }
+ }
+ // Step 4: Compute the final shape
+ //
+ // newAxis will increase dimension by 1 (with a one-size dimension)
+ // slices like foo[3, ...] will reduce dimension by 1.
+ // This cannot be done earlier, because it depends on Step 3.
+ for (var denseDim = 0; denseDim < denseSpec.finalShapeGatherIndices.length; ++denseDim) {
+ var gatherIndex = denseSpec.finalShapeGatherIndices[denseDim];
+ if (gatherIndex >= 0) {
+ finalShape.push(processingShape[gatherIndex]);
+ }
+ else if (gatherIndex === NEW_AXIS) {
+ finalShape.push(1);
+ }
+ }
+ var finalShapeSparse = finalShape.filter(function (dim, i) { return denseSpec.finalShapeGatherIndices[i] !== NEW_AXIS; });
+ return {
+ finalShapeSparse: finalShapeSparse,
+ finalShape: finalShape,
+ isIdentity: isIdentity,
+ sliceDim0: sliceDim0,
+ isSimpleSlice: isSimpleSlice,
+ begin: denseSpec.begin,
+ end: denseSpec.end,
+ strides: denseSpec.strides
+ };
+ }
+ function buildDenseSpec(sparse, dense) {
+ dense.beginMask = 0;
+ dense.endMask = 0;
+ dense.shrinkAxisMask = 0;
+ var fullIndex = 0;
+ dense.beginValid = sparse.begin != null;
+ dense.endValid = sparse.end != null;
+ dense.begin = new Array(dense.dims);
+ dense.end = new Array(dense.dims);
+ dense.strides = new Array(dense.dims);
+ dense.finalShapeGatherIndices = [];
+ dense.finalShapeGatherIndicesSparse = [];
+ dense.inputShapeGatherIndicesSparse = new Array(dense.dims);
+ for (var i = 0; i < sparse.dims; i++) {
+ if ((1 << i) & sparse.ellipsisMask) {
+ // Only the bit that has ellipsis will fall in this condition.
+ // Expand the ellipsis into the appropriate indices
+ // Note: this only works because we guaranteed one ellipsis.
+ var nextIndex = Math.min(dense.dims - (sparse.dims - i) + 1 + sparse.numAddAxisAfterEllipsis, dense.dims);
+ for (; fullIndex < nextIndex; fullIndex++) {
+ // newAxis aren't real axis so you have to skip.
+ dense.begin[fullIndex] = 0;
+ dense.end[fullIndex] = 0;
+ dense.strides[fullIndex] = 1;
+ dense.beginMask |= (1 << fullIndex);
+ dense.endMask |= (1 << fullIndex);
+ dense.finalShapeGatherIndices.push(fullIndex);
+ dense.finalShapeGatherIndicesSparse.push(-1);
+ dense.inputShapeGatherIndicesSparse[fullIndex] = i;
+ }
+ }
+ else if ((1 << i) & sparse.newAxisMask) {
+ // Only the bit that has newAxis will fall in this condition.
+ dense.finalShapeGatherIndices.push(NEW_AXIS);
+ dense.finalShapeGatherIndicesSparse.push(-1);
+ }
+ else {
+ if (fullIndex === dense.begin.length) {
+ throw Error("Index out of range using input dim " + fullIndex + "; input " +
+ ("has only " + dense.dims + " dims, " + dense.begin.length + "."));
+ }
+ // Gather slicing spec into appropriate index.
+ if (sparse.begin != null) {
+ dense.begin[fullIndex] = sparse.begin[i];
+ }
+ if (sparse.end != null) {
+ dense.end[fullIndex] = sparse.end[i];
+ }
+ dense.strides[fullIndex] = sparse.strides[i];
+ if (sparse.beginMask & (1 << i)) {
+ dense.beginMask |= (1 << fullIndex);
+ }
+ if (sparse.endMask & (1 << i)) {
+ dense.endMask |= (1 << fullIndex);
+ }
+ // If shrink, record where to get the dimensionality from (i.e. newAxis)
+ // creates a fake 1 size dimension. Also remember shrink axis (now in
+ // dense form) so we can ignore dense.end below.
+ if (sparse.shrinkAxisMask & (1 << i)) {
+ dense.finalShapeGatherIndices.push(SHRINK_AXIS);
+ dense.finalShapeGatherIndicesSparse.push(-1);
+ dense.shrinkAxisMask |= (1 << fullIndex);
+ }
+ else {
+ dense.finalShapeGatherIndices.push(fullIndex);
+ // Remember that where in the sparse shape the dense dim comes from.
+ dense.finalShapeGatherIndicesSparse.push(i);
+ }
+ dense.inputShapeGatherIndicesSparse[fullIndex] = i;
+ fullIndex++;
+ }
+ }
+ }
+ function canonical(x, c, strideI, dimI, masks, validRange) {
+ if (masks[c]) {
+ return strideI > 0 ? validRange[c] : validRange[(c + 1) & 1];
+ }
+ else {
+ var xFwd = x < 0 ? dimI + x : x; // make negative indices positive
+ return xFwd < validRange[0] ? validRange[0] :
+ xFwd > validRange[1] ? validRange[1] : xFwd;
+ }
+ }
+
+ var slice_util = {
+ __proto__: null,
+ assertParamsValid: assertParamsValid,
+ maskToAxes: maskToAxes,
+ computeOutShape: computeOutShape$2,
+ stridesWithElidedDims: stridesWithElidedDims,
+ getNormalizedAxes: getNormalizedAxes,
+ startIndicesWithElidedDims: startIndicesWithElidedDims,
+ stopIndicesWithElidedDims: stopIndicesWithElidedDims,
+ stridesForAxis: stridesForAxis,
+ startForAxis: startForAxis,
+ stopForAxis: stopForAxis,
+ isSliceContinous: isSliceContinous,
+ computeFlatOffset: computeFlatOffset,
+ parseSliceParams: parseSliceParams,
+ sliceInfo: sliceInfo
+ };
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Serializable defines the serialization contract.
+ *
+ * TFJS requires serializable classes to return their className when asked
+ * to avoid issues with minification.
+ */
+ var Serializable = /** @class */ (function () {
+ function Serializable() {
+ }
+ /**
+ * Return the class name for this class to use in serialization contexts.
+ *
+ * Generally speaking this will be the same thing that constructor.name
+ * would have returned. However, the class name needs to be robust
+ * against minification for serialization/deserialization to work properly.
+ *
+ * There's also places such as initializers.VarianceScaling, where
+ * implementation details between different languages led to different
+ * class hierarchies and a non-leaf node is used for serialization purposes.
+ */
+ Serializable.prototype.getClassName = function () {
+ return this.constructor
+ .className;
+ };
+ /**
+ * Creates an instance of T from a ConfigDict.
+ *
+ * This works for most descendants of serializable. A few need to
+ * provide special handling.
+ * @param cls A Constructor for the class to instantiate.
+ * @param config The Configuration for the object.
+ */
+ /** @nocollapse */
+ Serializable.fromConfig = function (cls, config) {
+ return new cls(config);
+ };
+ return Serializable;
+ }());
+ /**
+ * Maps string keys to class constructors.
+ *
+ * Used during (de)serialization from the cross-language JSON format, which
+ * requires the class name in the serialization format matches the class
+ * names as used in Python, should it exist.
+ */
+ var SerializationMap = /** @class */ (function () {
+ function SerializationMap() {
+ this.classNameMap = {};
+ }
+ /**
+ * Returns the singleton instance of the map.
+ */
+ SerializationMap.getMap = function () {
+ if (SerializationMap.instance == null) {
+ SerializationMap.instance = new SerializationMap();
+ }
+ return SerializationMap.instance;
+ };
+ /**
+ * Registers the class as serializable.
+ */
+ SerializationMap.register = function (cls) {
+ SerializationMap.getMap().classNameMap[cls.className] =
+ [cls, cls.fromConfig];
+ };
+ return SerializationMap;
+ }());
+ /**
+ * Register a class with the serialization map of TensorFlow.js.
+ *
+ * This is often used for registering custom Layers, so they can be
+ * serialized and deserialized.
+ *
+ * Example:
+ *
+ * ```js
+ * class MyCustomLayer extends tf.layers.Layer {
+ * static className = 'MyCustomLayer';
+ *
+ * constructor(config) {
+ * super(config);
+ * }
+ * }
+ * tf.serialization.registerClass(MyCustomLayer);
+ * ```
+ *
+ * @param cls The class to be registered. It must have a public static member
+ * called `className` defined and the value must be a non-empty string.
+ *
+ * @doc {heading: 'Models', subheading: 'Serialization', ignoreCI: true}
+ */
+ function registerClass(cls) {
+ assert(cls.className != null, function () { return "Class being registered does not have the static className " +
+ "property defined."; });
+ assert(typeof cls.className === 'string', function () { return "className is required to be a string, but got type " +
+ typeof cls.className; });
+ assert(cls.className.length > 0, function () { return "Class being registered has an empty-string as its className, " +
+ "which is disallowed."; });
+ SerializationMap.register(cls);
+ }
+
+ var serialization = {
+ __proto__: null,
+ Serializable: Serializable,
+ SerializationMap: SerializationMap,
+ registerClass: registerClass
+ };
+
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var TEST_EPSILON_FLOAT32 = 1e-3;
+ var TEST_EPSILON_FLOAT16 = 1e-1;
+ function expectArraysClose(actual, expected, epsilon) {
+ if (epsilon == null) {
+ epsilon = testEpsilon();
+ }
+ return expectArraysPredicate(actual, expected, function (a, b) { return areClose(a, b, epsilon); });
+ }
+ function testEpsilon() {
+ return ENGINE.backend.floatPrecision() === 32 ? TEST_EPSILON_FLOAT32 :
+ TEST_EPSILON_FLOAT16;
+ }
+ function expectArraysPredicate(actual, expected, predicate) {
+ var checkClassType = true;
+ if (isTypedArray(actual) || isTypedArray(expected)) {
+ checkClassType = false;
+ }
+ if (isTypedArray(actual) && isTypedArray(expected)) {
+ checkClassType = true;
+ }
+ if (checkClassType) {
+ var aType = actual.constructor.name;
+ var bType = expected.constructor.name;
+ if (aType !== bType) {
+ throw new Error("Arrays are of different type. Actual: " + aType + ". " +
+ ("Expected: " + bType));
+ }
+ }
+ if (Array.isArray(actual) && Array.isArray(expected)) {
+ var actualShape = inferShape(actual);
+ var expectedShape = inferShape(expected);
+ if (!arraysEqual(actualShape, expectedShape)) {
+ throw new Error("Arrays have different shapes. " +
+ ("Actual: [" + actualShape + "]. Expected: [" + expectedShape + "]"));
+ }
+ }
+ var actualFlat = isTypedArray(actual) ? actual : flatten(actual);
+ var expectedFlat = isTypedArray(expected) ?
+ expected :
+ flatten(expected);
+ if (actualFlat.length !== expectedFlat.length) {
+ throw new Error("Arrays have different lengths actual: " + actualFlat.length + " vs " +
+ ("expected: " + expectedFlat.length + ".\n") +
+ ("Actual: " + actualFlat + ".\n") +
+ ("Expected: " + expectedFlat + "."));
+ }
+ for (var i = 0; i < expectedFlat.length; ++i) {
+ var a = actualFlat[i];
+ var e = expectedFlat[i];
+ if (!predicate(a, e)) {
+ throw new Error("Arrays differ: actual[" + i + "] = " + a + ", expected[" + i + "] = " + e + ".\n" +
+ ("Actual: " + actualFlat + ".\n") +
+ ("Expected: " + expectedFlat + "."));
+ }
+ }
+ }
+ function expectPromiseToFail(fn, done) {
+ fn().then(function () { return done.fail(); }, function () { return done(); });
+ }
+ function expectArraysEqual(actual, expected) {
+ var exp = typeof expected === 'string' || typeof expected === 'number' ||
+ typeof expected === 'boolean' ?
+ [expected] :
+ expected;
+ if (isString(actual) || isString(actual[0]) ||
+ isString(expected) || isString(expected[0])) {
+ // tslint:disable-next-line: triple-equals
+ return expectArraysPredicate(actual, exp, function (a, b) { return a == b; });
+ }
+ return expectArraysPredicate(actual, expected, function (a, b) { return areClose(a, b, 0); });
+ }
+ function expectNumbersClose(a, e, epsilon) {
+ if (epsilon == null) {
+ epsilon = testEpsilon();
+ }
+ if (!areClose(a, e, epsilon)) {
+ throw new Error("Numbers differ: actual === " + a + ", expected === " + e);
+ }
+ }
+ function areClose(a, e, epsilon) {
+ if (!isFinite(a) && !isFinite(e)) {
+ return true;
+ }
+ if (isNaN(a) || isNaN(e) || Math.abs(a - e) > epsilon) {
+ return false;
+ }
+ return true;
+ }
+ function expectValuesInRange(actual, low, high) {
+ for (var i = 0; i < actual.length; i++) {
+ if (actual[i] < low || actual[i] > high) {
+ throw new Error("Value out of range:" + actual[i] + " low: " + low + ", high: " + high);
+ }
+ }
+ }
+ function expectArrayBuffersEqual(actual, expected) {
+ // Safari & Jasmine don't like comparing ArrayBuffers directly. Wrapping in
+ // a Float32Array solves this issue.
+ expect(new Float32Array(actual)).toEqual(new Float32Array(expected));
+ }
+ /** Encodes strings into utf-8 bytes. */
+ function encodeStrings(a) {
+ for (var i = 0; i < a.length; i++) {
+ var val = a[i];
+ if (Array.isArray(val)) {
+ encodeStrings(val);
+ }
+ else {
+ a[i] = encodeString(val);
+ }
+ }
+ return a;
+ }
+
+ var test_util = {
+ __proto__: null,
+ TEST_EPSILON_FLOAT16: TEST_EPSILON_FLOAT16,
+ expectArraysClose: expectArraysClose,
+ testEpsilon: testEpsilon,
+ expectPromiseToFail: expectPromiseToFail,
+ expectArraysEqual: expectArraysEqual,
+ expectNumbersClose: expectNumbersClose,
+ expectValuesInRange: expectValuesInRange,
+ expectArrayBuffersEqual: expectArrayBuffersEqual,
+ encodeStrings: encodeStrings
+ };
+
+ /** @license See the LICENSE file. */
+ // This code is auto-generated, do not modify this file!
+ var version = '3.12.0';
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Enables production mode which disables correctness checks in favor of
+ * performance.
+ *
+ * @doc {heading: 'Environment'}
+ */
+ function enableProdMode() {
+ env().set('PROD', true);
+ }
+ /**
+ * Enables debug mode which will log information about all executed kernels:
+ * the elapsed time of the kernel execution, as well as the rank, shape, and
+ * size of the output tensor.
+ *
+ * Debug mode will significantly slow down your application as it will
+ * download the result of every operation to the CPU. This should not be used in
+ * production. Debug mode does not affect the timing information of the kernel
+ * execution as we do not measure download time in the kernel execution time.
+ *
+ * See also: `tf.profile`, `tf.memory`.
+ *
+ * @doc {heading: 'Environment'}
+ */
+ function enableDebugMode() {
+ env().set('DEBUG', true);
+ }
+ /** Globally disables deprecation warnings */
+ function disableDeprecationWarnings() {
+ env().set('DEPRECATION_WARNINGS_ENABLED', false);
+ console.warn("TensorFlow.js deprecation warnings have been disabled.");
+ }
+ /** Warn users about deprecated functionality. */
+ function deprecationWarn(msg) {
+ if (env().getBool('DEPRECATION_WARNINGS_ENABLED')) {
+ console.warn(msg + ' You can disable deprecation warnings with ' +
+ 'tf.disableDeprecationWarnings().');
+ }
+ }
+ /**
+ * Dispose all variables kept in backend engine.
+ *
+ * @doc {heading: 'Environment'}
+ */
+ function disposeVariables() {
+ ENGINE.disposeVariables();
+ }
+ /**
+ * It returns the global engine that keeps track of all tensors and backends.
+ *
+ * @doc {heading: 'Environment'}
+ */
+ function engine() {
+ return ENGINE;
+ }
+ /**
+ * Returns memory info at the current time in the program. The result is an
+ * object with the following properties:
+ *
+ * - `numBytes`: Number of bytes allocated (undisposed) at this time.
+ * - `numTensors`: Number of unique tensors allocated.
+ * - `numDataBuffers`: Number of unique data buffers allocated
+ * (undisposed) at this time, which is ≤ the number of tensors
+ * (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same
+ * data buffer with `a`).
+ * - `unreliable`: True if the memory usage is unreliable. See `reasons` when
+ * `unreliable` is true.
+ * - `reasons`: `string[]`, reasons why the memory is unreliable, present if
+ * `unreliable` is true.
+ *
+ * WebGL Properties:
+ * - `numBytesInGPU`: Number of bytes allocated (undisposed) in the GPU only at
+ * this time.
+ *
+ * @doc {heading: 'Performance', subheading: 'Memory'}
+ */
+ function memory() {
+ return ENGINE.memory();
+ }
+ /**
+ * Executes the provided function `f()` and returns a promise that resolves
+ * with information about the function's memory use:
+ * - `newBytes`: the number of new bytes allocated
+ * - `newTensors`: the number of new tensors created
+ * - `peakBytes`: the peak number of bytes allocated
+ * - `kernels`: an array of objects for each kernel involved that reports
+ * their input and output shapes, number of bytes used, and number of new
+ * tensors created.
+ * - `kernelNames`: an array of unique strings with just the names of the
+ * kernels in the `kernels` array.
+ *
+ * ```js
+ * const profile = await tf.profile(() => {
+ * const x = tf.tensor1d([1, 2, 3]);
+ * let x2 = x.square();
+ * x2.dispose();
+ * x2 = x.square();
+ * x2.dispose();
+ * return x;
+ * });
+ *
+ * console.log(`newBytes: ${profile.newBytes}`);
+ * console.log(`newTensors: ${profile.newTensors}`);
+ * console.log(`byte usage over all kernels: ${profile.kernels.map(k =>
+ * k.totalBytesSnapshot)}`);
+ * ```
+ *
+ *
+ * @doc {heading: 'Performance', subheading: 'Profile'}
+ */
+ function profile(f) {
+ return ENGINE.profile(f);
+ }
+ /**
+ * Executes the provided function `fn` and after it is executed, cleans up all
+ * intermediate tensors allocated by `fn` except those returned by `fn`.
+ * `fn` must not return a Promise (async functions not allowed). The returned
+ * result can be a complex object.
+ *
+ * Using this method helps avoid memory leaks. In general, wrap calls to
+ * operations in `tf.tidy` for automatic memory cleanup.
+ *
+ * NOTE: Variables do *not* get cleaned up when inside a tidy(). If you want to
+ * dispose variables, please use `tf.disposeVariables` or call dispose()
+ * directly on variables.
+ *
+ * ```js
+ * // y = 2 ^ 2 + 1
+ * const y = tf.tidy(() => {
+ * // a, b, and one will be cleaned up when the tidy ends.
+ * const one = tf.scalar(1);
+ * const a = tf.scalar(2);
+ * const b = a.square();
+ *
+ * console.log('numTensors (in tidy): ' + tf.memory().numTensors);
+ *
+ * // The value returned inside the tidy function will return
+ * // through the tidy, in this case to the variable y.
+ * return b.add(one);
+ * });
+ *
+ * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
+ * y.print();
+ * ```
+ *
+ * @param nameOrFn The name of the closure, or the function to execute.
+ * If a name is provided, the 2nd argument should be the function.
+ * If debug mode is on, the timing and the memory usage of the function
+ * will be tracked and displayed on the console using the provided name.
+ * @param fn The function to execute.
+ *
+ * @doc {heading: 'Performance', subheading: 'Memory'}
+ */
+ function tidy(nameOrFn, fn) {
+ return ENGINE.tidy(nameOrFn, fn);
+ }
+ /**
+ * Disposes any `tf.Tensor`s found within the provided object.
+ *
+ * @param container an object that may be a `tf.Tensor` or may directly
+ * contain `tf.Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If
+ * the object is not a `tf.Tensor` or does not contain `Tensors`, nothing
+ * happens. In general it is safe to pass any object here, except that
+ * `Promise`s are not supported.
+ *
+ * @doc {heading: 'Performance', subheading: 'Memory'}
+ */
+ function dispose(container) {
+ var tensors = getTensorsInContainer(container);
+ tensors.forEach(function (tensor) { return tensor.dispose(); });
+ }
+ /**
+ * Keeps a `tf.Tensor` generated inside a `tf.tidy` from being disposed
+ * automatically.
+ *
+ * ```js
+ * let b;
+ * const y = tf.tidy(() => {
+ * const one = tf.scalar(1);
+ * const a = tf.scalar(2);
+ *
+ * // b will not be cleaned up by the tidy. a and one will be cleaned up
+ * // when the tidy ends.
+ * b = tf.keep(a.square());
+ *
+ * console.log('numTensors (in tidy): ' + tf.memory().numTensors);
+ *
+ * // The value returned inside the tidy function will return
+ * // through the tidy, in this case to the variable y.
+ * return b.add(one);
+ * });
+ *
+ * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
+ * console.log('y:');
+ * y.print();
+ * console.log('b:');
+ * b.print();
+ * ```
+ *
+ * @param result The tensor to keep from being disposed.
+ *
+ * @doc {heading: 'Performance', subheading: 'Memory'}
+ */
+ function keep(result) {
+ return ENGINE.keep(result);
+ }
+ /**
+ * Executes `f()` and returns a promise that resolves with timing
+ * information.
+ *
+ * The result is an object with the following properties:
+ *
+ * - `wallMs`: Wall execution time.
+ * - `kernelMs`: Kernel execution time, ignoring data transfer. If using the
+ * WebGL backend and the query timer extension is not available, this will
+ * return an error object.
+ * - On `WebGL` The following additional properties exist:
+ * - `uploadWaitMs`: CPU blocking time on texture uploads.
+ * - `downloadWaitMs`: CPU blocking time on texture downloads (readPixels).
+ *
+ * ```js
+ * const x = tf.randomNormal([20, 20]);
+ * const time = await tf.time(() => x.matMul(x));
+ *
+ * console.log(`kernelMs: ${time.kernelMs}, wallTimeMs: ${time.wallMs}`);
+ * ```
+ *
+ * @param f The function to execute and time.
+ *
+ * @doc {heading: 'Performance', subheading: 'Timing'}
+ */
+ function time(f) {
+ return ENGINE.time(f);
+ }
+ /**
+ * Sets the backend (cpu, webgl, wasm, etc) responsible for creating tensors and
+ * executing operations on those tensors. Returns a promise that resolves
+ * to a boolean if the backend initialization was successful.
+ *
+ * Note this disposes the current backend, if any, as well as any tensors
+ * associated with it. A new backend is initialized, even if it is of the
+ * same type as the previous one.
+ *
+ * @param backendName The name of the backend. Currently supports
+ * `'webgl'|'cpu'` in the browser, `'tensorflow'` under node.js
+ * (requires tfjs-node), and `'wasm'` (requires tfjs-backend-wasm).
+ *
+ * @doc {heading: 'Backends'}
+ */
+ function setBackend(backendName) {
+ return ENGINE.setBackend(backendName);
+ }
+ /**
+ * Returns a promise that resolves when the currently selected backend (or the
+ * highest priority one) has initialized. Await this promise when you are using
+ * a backend that has async initialization.
+ *
+ * @doc {heading: 'Backends'}
+ */
+ function ready() {
+ return ENGINE.ready();
+ }
+ /**
+ * Returns the current backend name (cpu, webgl, etc). The backend is
+ * responsible for creating tensors and executing operations on those tensors.
+ *
+ * @doc {heading: 'Backends'}
+ */
+ function getBackend() {
+ return ENGINE.backendName;
+ }
+ /**
+ * Removes a backend and the registered factory.
+ *
+ * @doc {heading: 'Backends'}
+ */
+ function removeBackend(name) {
+ ENGINE.removeBackend(name);
+ }
+ /**
+ * Finds the backend registered under the provided name. Returns null if the
+ * name is not in the registry, or the registration hasn't finished yet.
+ */
+ function findBackend(name) {
+ return ENGINE.findBackend(name);
+ }
+ /**
+ * Finds the backend factory registered under the provided name. Returns a
+ * function that produces a new backend when called. Returns null if the name
+ * is not in the registry.
+ */
+ function findBackendFactory(name) {
+ return ENGINE.findBackendFactory(name);
+ }
+ /**
+ * Registers a global backend. The registration should happen when importing
+ * a module file (e.g. when importing `backend_webgl.ts`), and is used for
+ * modular builds (e.g. custom tfjs bundle with only webgl support).
+ *
+ * @param factory The backend factory function. When called, it should
+ * return a backend instance, or a promise of an instance.
+ * @param priority The priority of the backend (higher = more important).
+ * In case multiple backends are registered, the priority is used to find
+ * the best backend. Defaults to 1.
+ * @return False if there is already a registered backend under this name, true
+ * if not.
+ *
+ * @doc {heading: 'Backends'}
+ */
+ function registerBackend(name, factory, priority) {
+ if (priority === void 0) { priority = 1; }
+ return ENGINE.registerBackend(name, factory, priority);
+ }
+ /**
+ * Gets the current backend. If no backends have been initialized, this will
+ * attempt to initialize the best backend. Will throw an error if the highest
+ * priority backend has async initialization, in which case, you should call
+ * 'await tf.ready()' before running other code.
+ *
+ * @doc {heading: 'Backends'}
+ */
+ function backend() {
+ return ENGINE.backend;
+ }
+ /**
+ * Sets the global platform.
+ *
+ * @param platformName The name of this platform.
+ * @param platform A platform implementation.
+ */
+ function setPlatform(platformName, platform) {
+ env().setPlatform(platformName, platform);
+ }
+
+ /**
+ * Adds two `tf.Tensor`s element-wise, A + B. Supports broadcasting.
+ *
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2, 3, 4]);
+ * const b = tf.tensor1d([10, 20, 30, 40]);
+ *
+ * a.add(b).print(); // or tf.add(a, b)
+ * ```
+ *
+ * ```js
+ * // Broadcast add a with b.
+ * const a = tf.scalar(5);
+ * const b = tf.tensor1d([10, 20, 30, 40]);
+ *
+ * a.add(b).print(); // or tf.add(a, b)
+ * ```
+ * @param a The first `tf.Tensor` to add.
+ * @param b The second `tf.Tensor` to add. Must have the same type as `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Arithmetic'}
+ */
+ function add_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'add');
+ var $b = convertToTensor(b, 'b', 'add');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(Add, inputs);
+ }
+ var add = op({ add_: add_ });
+
+ /**
+ * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
+ * The result is rounded with floor function.
+ *
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 4, 9, 16]);
+ * const b = tf.tensor1d([1, 2, 3, 4]);
+ *
+ * a.floorDiv(b).print(); // or tf.div(a, b)
+ * ```
+ *
+ * ```js
+ * // Broadcast div a with b.
+ * const a = tf.tensor1d([2, 4, 6, 8]);
+ * const b = tf.scalar(2);
+ *
+ * a.floorDiv(b).print(); // or tf.floorDiv(a, b)
+ * ```
+ *
+ * @param a The first tensor as the numerator.
+ * @param b The second tensor as the denominator. Must have the same dtype as
+ * `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Arithmetic'}
+ */
+ function floorDiv_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'floorDiv');
+ var $b = convertToTensor(b, 'b', 'floorDiv');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(FloorDiv, inputs);
+ }
+ var floorDiv = op({ floorDiv_: floorDiv_ });
+
+ /**
+ * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 4, 9, 16]);
+ * const b = tf.tensor1d([1, 2, 3, 4]);
+ *
+ * a.div(b).print(); // or tf.div(a, b)
+ * ```
+ *
+ * ```js
+ * // Broadcast div a with b.
+ * const a = tf.tensor1d([2, 4, 6, 8]);
+ * const b = tf.scalar(2);
+ *
+ * a.div(b).print(); // or tf.div(a, b)
+ * ```
+ *
+ * @param a The first tensor as the numerator.
+ * @param b The second tensor as the denominator. Must have the same dtype as
+ * `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Arithmetic'}
+ */
+ function div_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'div');
+ var $b = convertToTensor(b, 'b', 'div');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ if ($a.dtype === 'int32' && $b.dtype === 'int32') {
+ return floorDiv($a, $b);
+ }
+ var inputs = { a: $a, b: $b };
+ var attrs = {};
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ return ENGINE.runKernel(RealDiv, inputs, attrs);
+ }
+ var div = op({ div_: div_ });
+
+ /**
+ * Multiplies two `tf.Tensor`s element-wise, A * B. Supports broadcasting.
+ *
+ * We also expose `tf.mulStrict` which has the same signature as this op and
+ * asserts that `a` and `b` are the same shape (does not broadcast).
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2, 3, 4]);
+ * const b = tf.tensor1d([2, 3, 4, 5]);
+ *
+ * a.mul(b).print(); // or tf.mul(a, b)
+ * ```
+ *
+ * ```js
+ * // Broadcast mul a with b.
+ * const a = tf.tensor1d([1, 2, 3, 4]);
+ * const b = tf.scalar(5);
+ *
+ * a.mul(b).print(); // or tf.mul(a, b)
+ * ```
+ * @param a The first tensor to multiply.
+ * @param b The second tensor to multiply. Must have the same dtype as `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Arithmetic'}
+ */
+ function mul_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'mul');
+ var $b = convertToTensor(b, 'b', 'mul');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(Multiply, inputs);
+ }
+ var mul = op({ mul_: mul_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes absolute value element-wise: `abs(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([-1, 2, -3, 4]);
+ *
+ * x.abs().print(); // or tf.abs(x)
+ * ```
+ * @param x The input `tf.Tensor`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function abs_(x) {
+ var $x = convertToTensor(x, 'x', 'abs');
+ if ($x.dtype === 'complex64') {
+ var inputs = { x: $x };
+ return ENGINE.runKernel(ComplexAbs, inputs);
+ }
+ else {
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Abs, inputs);
+ }
+ }
+ var abs = op({ abs_: abs_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes acos of the input `tf.Tensor` element-wise: `acos(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, 1, -1, .7]);
+ *
+ * x.acos().print(); // or tf.acos(x)
+ * ```
+ * @param x The input tensor.
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function acos_(x) {
+ var $x = convertToTensor(x, 'x', 'acos');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Acos, inputs);
+ }
+ var acos = op({ acos_: acos_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the inverse hyperbolic cos of the input `tf.Tensor` element-wise:
+ * `acosh(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([10, 1, 3, 5.7]);
+ *
+ * x.acosh().print(); // or tf.acosh(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function acosh_(x) {
+ var $x = convertToTensor(x, 'x', 'acosh');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Acosh, inputs);
+ }
+ var acosh = op({ acosh_: acosh_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Adds a list of `tf.Tensor`s element-wise, each with the same shape and dtype.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2]);
+ * const b = tf.tensor1d([3, 4]);
+ * const c = tf.tensor1d([5, 6]);
+ *
+ * tf.addN([a, b, c]).print();
+ * ```
+ * @param tensors A list of tensors with the same shape and dtype.
+ * @doc {heading: 'Operations', subheading: 'Arithmetic'}
+ */
+ function addN_(tensors) {
+ assert(Array.isArray(tensors), function () { return 'The argument passed to tf.addN() must be a list of tensors'; });
+ assert(tensors.length >= 1, function () { return "Must pass at least one tensor to tf.addN(), but got " +
+ ("" + tensors.length); });
+ var $tensors = tensors.map(function (t, i) { return convertToTensor(t, "tensors" + i, 'addN'); });
+ var firstTensor = $tensors[0];
+ $tensors.forEach(function (t) {
+ if (t.dtype !== firstTensor.dtype) {
+ throw new Error('All tensors passed to tf.addN() must have the same dtype');
+ }
+ });
+ $tensors.forEach(function (t) {
+ if (!arraysEqual(t.shape, firstTensor.shape)) {
+ throw new Error('All tensors passed to tf.addN() must have the same shape');
+ }
+ });
+ var inputs = $tensors;
+ return ENGINE.runKernel(AddN, inputs);
+ }
+ var addN = op({ addN_: addN_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the logical and of elements across dimensions of a `tf.Tensor`.
+ *
+ * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
+ * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
+ * `axes`. If `keepDims` is true, the reduced dimensions are retained with
+ * length 1. If `axes` has no entries, all dimensions are reduced, and an
+ * `tf.Tensor` with a single element is returned.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 1, 1], 'bool');
+ *
+ * x.all().print(); // or tf.all(x)
+ * ```
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
+ *
+ * const axis = 1;
+ * x.all(axis).print(); // or tf.all(x, axis)
+ * ```
+ *
+ * @param x The input tensor. Must be of dtype bool.
+ * @param axis The dimension(s) to reduce. By default it reduces
+ * all dimensions.
+ * @param keepDims If true, retains reduced dimensions with size 1.
+ *
+ * @doc {heading: 'Operations', subheading: 'Reduction'}
+ */
+ function all_(x, axis, keepDims) {
+ if (axis === void 0) { axis = null; }
+ if (keepDims === void 0) { keepDims = false; }
+ var $x = convertToTensor(x, 'x', 'all', 'bool');
+ var inputs = { x: $x };
+ var attrs = { axis: axis, keepDims: keepDims };
+ return ENGINE.runKernel(All, inputs, attrs);
+ }
+ var all = op({ all_: all_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the logical or of elements across dimensions of a `tf.Tensor`.
+ *
+ * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
+ * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
+ * `axes`. If `keepDims` is true, the reduced dimensions are retained with
+ * length 1. If `axes` has no entries, all dimensions are reduced, and an
+ * `tf.Tensor` with a single element is returned.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 1, 1], 'bool');
+ *
+ * x.any().print(); // or tf.any(x)
+ * ```
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
+ *
+ * const axis = 1;
+ * x.any(axis).print(); // or tf.any(x, axis)
+ * ```
+ *
+ * @param x The input tensor. Must be of dtype bool.
+ * @param axis The dimension(s) to reduce. By default it reduces
+ * all dimensions.
+ * @param keepDims If true, retains reduced dimensions with size 1.
+ *
+ * @doc {heading: 'Operations', subheading: 'Reduction'}
+ */
+ function any_(x, axis, keepDims) {
+ if (axis === void 0) { axis = null; }
+ if (keepDims === void 0) { keepDims = false; }
+ var $x = convertToTensor(x, 'x', 'any', 'bool');
+ var inputs = { x: $x };
+ var attrs = { axis: axis, keepDims: keepDims };
+ return ENGINE.runKernel(Any, inputs, attrs);
+ }
+ // tslint:disable-next-line:variable-name
+ var any = op({ any_: any_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns the indices of the maximum values along an `axis`.
+ *
+ * The result has the same shape as `input` with the dimension along `axis`
+ * removed.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3]);
+ *
+ * x.argMax().print(); // or tf.argMax(x)
+ * ```
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);
+ *
+ * const axis = 1;
+ * x.argMax(axis).print(); // or tf.argMax(x, axis)
+ * ```
+ *
+ * @param x The input tensor.
+ * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).
+ *
+ * @doc {heading: 'Operations', subheading: 'Reduction'}
+ */
+ function argMax_(x, axis) {
+ if (axis === void 0) { axis = 0; }
+ var $x = convertToTensor(x, 'x', 'argMax');
+ var inputs = { x: $x };
+ var attrs = { axis: axis };
+ return ENGINE.runKernel(ArgMax, inputs, attrs);
+ }
+ var argMax = op({ argMax_: argMax_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns the indices of the minimum values along an `axis`.
+ *
+ * The result has the same shape as `input` with the dimension along `axis`
+ * removed.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3]);
+ *
+ * x.argMin().print(); // or tf.argMin(x)
+ * ```
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);
+ *
+ * const axis = 1;
+ * x.argMin(axis).print(); // or tf.argMin(x, axis)
+ * ```
+ *
+ * @param x The input tensor.
+ * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).
+ *
+ * @doc {heading: 'Operations', subheading: 'Reduction'}
+ */
+ function argMin_(x, axis) {
+ if (axis === void 0) { axis = 0; }
+ var $x = convertToTensor(x, 'x', 'argMin');
+ var inputs = { x: $x };
+ var attrs = { axis: axis };
+ return ENGINE.runKernel(ArgMin, inputs, attrs);
+ }
+ var argMin = op({ argMin_: argMin_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes asin of the input `tf.Tensor` element-wise: `asin(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, 1, -1, .7]);
+ *
+ * x.asin().print(); // or tf.asin(x)
+ * ```
+ * @param x The input tensor.
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function asin_(x) {
+ var $x = convertToTensor(x, 'x', 'asin');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Asin, inputs);
+ }
+ var asin = op({ asin_: asin_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes inverse hyperbolic sin of the input `tf.Tensor` element-wise:
+ * `asinh(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, 1, -1, .7]);
+ *
+ * x.asinh().print(); // or tf.asinh(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function asinh_(x) {
+ var $x = convertToTensor(x, 'x', 'asinh');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Asinh, inputs);
+ }
+ var asinh = op({ asinh_: asinh_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes atan of the input `tf.Tensor` element-wise: `atan(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, 1, -1, .7]);
+ *
+ * x.atan().print(); // or tf.atan(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function atan_(x) {
+ var $x = convertToTensor(x, 'x', 'atan');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Atan, inputs);
+ }
+ var atan = op({ atan_: atan_ });
+
+ /**
+ * Computes arctangent of `tf.Tensor`s a / b element-wise: `atan2(a, b)`.
+ * Supports broadcasting.
+ *
+ * ```js
+ * const a = tf.tensor1d([1.0, 1.0, -1.0, .7]);
+ * const b = tf.tensor1d([2.0, 13.0, 3.5, .21]);
+ *
+ * tf.atan2(a, b).print()
+ * ```
+ *
+ * @param a The first tensor.
+ * @param b The second tensor. Must have the same dtype as `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function atan2_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'atan2');
+ var $b = convertToTensor(b, 'b', 'atan2');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(Atan2, inputs);
+ }
+ var atan2 = op({ atan2_: atan2_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes inverse hyperbolic tan of the input `tf.Tensor` element-wise:
+ * `atanh(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, .1, -.1, .7]);
+ *
+ * x.atanh().print(); // or tf.atanh(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function atanh_(x) {
+ var $x = convertToTensor(x, 'x', 'atanh');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Atanh, inputs);
+ }
+ var atanh = op({ atanh_: atanh_ });
+
+ /**
+ *
+ * @param inputShape Input tensor shape is of the following dimensions:
+ * `[batch, height, width, inChannels]`.
+ * @param filterShape The filter shape is of the following dimensions:
+ * `[filterHeight, filterWidth, depth]`.
+ * @param strides The strides of the sliding window for each dimension of the
+ * input tensor: `[strideHeight, strideWidth]`.
+ * If `strides` is a single number,
+ * then `strideHeight == strideWidth`.
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1*1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dataFormat The data format of the input and output data.
+ * Defaults to 'NHWC'.
+ * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`.
+ * Defaults to `[1, 1]`. If `dilations` is a single number, then
+ * `dilationHeight == dilationWidth`.
+ */
+ function computeDilation2DInfo(inputShape, filterShape, strides, pad, dataFormat, dilations) {
+ if (dataFormat === void 0) { dataFormat = 'NHWC'; }
+ // `computerConv2DInfo` require filterShape to be in the dimension of:
+ // `[filterHeight, filterWidth, depth, outDepth]`, dilation2d doesn't have
+ // outDepth, it should have the same depth as the input.
+ // Input shape: [batch, height, width, inChannels]
+ var inputChannels = inputShape[3];
+ var $filterShape = __spread(filterShape, [inputChannels]);
+ var $dataFormat = convertConv2DDataFormat(dataFormat);
+ return computeConv2DInfo(inputShape, $filterShape, strides, dilations, pad, null /* roundingMode */, null /* depthWise */, $dataFormat);
+ }
+ function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) {
+ if (dataFormat === void 0) { dataFormat = 'channelsLast'; }
+ var _a = __read(parseTupleParam(filterSize), 2), filterHeight = _a[0], filterWidth = _a[1];
+ var filterShape;
+ if (dataFormat === 'channelsLast') {
+ filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]];
+ }
+ else if (dataFormat === 'channelsFirst') {
+ filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]];
+ }
+ else {
+ throw new Error("Unknown dataFormat " + dataFormat);
+ }
+ return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false, dataFormat);
+ }
+ /**
+ * Computes the information for a forward pass of a pooling3D operation.
+ */
+ function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) {
+ if (dataFormat === void 0) { dataFormat = 'NDHWC'; }
+ var _a = __read(parse3TupleParam(filterSize), 3), filterDepth = _a[0], filterHeight = _a[1], filterWidth = _a[2];
+ var filterShape;
+ var $dataFormat;
+ if (dataFormat === 'NDHWC') {
+ $dataFormat = 'channelsLast';
+ filterShape =
+ [filterDepth, filterHeight, filterWidth, inShape[4], inShape[4]];
+ }
+ else if (dataFormat === 'NCDHW') {
+ $dataFormat = 'channelsFirst';
+ filterShape =
+ [filterDepth, filterHeight, filterWidth, inShape[1], inShape[1]];
+ }
+ else {
+ throw new Error("Unknown dataFormat " + dataFormat);
+ }
+ return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false, $dataFormat, roundingMode);
+ }
+ /**
+ * Computes the information for a forward pass of a convolution/pooling
+ * operation.
+ */
+ function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise, dataFormat) {
+ var _a, _b;
+ if (depthwise === void 0) { depthwise = false; }
+ if (dataFormat === void 0) { dataFormat = 'channelsLast'; }
+ var _c = __read([-1, -1, -1, -1], 4), batchSize = _c[0], inHeight = _c[1], inWidth = _c[2], inChannels = _c[3];
+ if (dataFormat === 'channelsLast') {
+ _a = __read(inShape, 4), batchSize = _a[0], inHeight = _a[1], inWidth = _a[2], inChannels = _a[3];
+ }
+ else if (dataFormat === 'channelsFirst') {
+ _b = __read(inShape, 4), batchSize = _b[0], inChannels = _b[1], inHeight = _b[2], inWidth = _b[3];
+ }
+ else {
+ throw new Error("Unknown dataFormat " + dataFormat);
+ }
+ var _d = __read(filterShape, 4), filterHeight = _d[0], filterWidth = _d[1], filterChannels = _d[3];
+ var _e = __read(parseTupleParam(strides), 2), strideHeight = _e[0], strideWidth = _e[1];
+ var _f = __read(parseTupleParam(dilations), 2), dilationHeight = _f[0], dilationWidth = _f[1];
+ var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
+ var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
+ var _g = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode, dataFormat), padInfo = _g.padInfo, outHeight = _g.outHeight, outWidth = _g.outWidth;
+ var outChannels = depthwise ? filterChannels * inChannels : filterChannels;
+ var outShape;
+ if (dataFormat === 'channelsFirst') {
+ outShape = [batchSize, outChannels, outHeight, outWidth];
+ }
+ else if (dataFormat === 'channelsLast') {
+ outShape = [batchSize, outHeight, outWidth, outChannels];
+ }
+ return {
+ batchSize: batchSize,
+ dataFormat: dataFormat,
+ inHeight: inHeight,
+ inWidth: inWidth,
+ inChannels: inChannels,
+ outHeight: outHeight,
+ outWidth: outWidth,
+ outChannels: outChannels,
+ padInfo: padInfo,
+ strideHeight: strideHeight,
+ strideWidth: strideWidth,
+ filterHeight: filterHeight,
+ filterWidth: filterWidth,
+ effectiveFilterHeight: effectiveFilterHeight,
+ effectiveFilterWidth: effectiveFilterWidth,
+ dilationHeight: dilationHeight,
+ dilationWidth: dilationWidth,
+ inShape: inShape,
+ outShape: outShape,
+ filterShape: filterShape
+ };
+ }
+ /**
+ * Computes the information for a forward pass of a 3D convolution/pooling
+ * operation.
+ */
+ function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise, dataFormat, roundingMode) {
+ var _a, _b;
+ if (depthwise === void 0) { depthwise = false; }
+ if (dataFormat === void 0) { dataFormat = 'channelsLast'; }
+ var _c = __read([-1, -1, -1, -1, -1], 5), batchSize = _c[0], inDepth = _c[1], inHeight = _c[2], inWidth = _c[3], inChannels = _c[4];
+ if (dataFormat === 'channelsLast') {
+ _a = __read(inShape, 5), batchSize = _a[0], inDepth = _a[1], inHeight = _a[2], inWidth = _a[3], inChannels = _a[4];
+ }
+ else if (dataFormat === 'channelsFirst') {
+ _b = __read(inShape, 5), batchSize = _b[0], inChannels = _b[1], inDepth = _b[2], inHeight = _b[3], inWidth = _b[4];
+ }
+ else {
+ throw new Error("Unknown dataFormat " + dataFormat);
+ }
+ var _d = __read(filterShape, 5), filterDepth = _d[0], filterHeight = _d[1], filterWidth = _d[2], filterChannels = _d[4];
+ var _e = __read(parse3TupleParam(strides), 3), strideDepth = _e[0], strideHeight = _e[1], strideWidth = _e[2];
+ var _f = __read(parse3TupleParam(dilations), 3), dilationDepth = _f[0], dilationHeight = _f[1], dilationWidth = _f[2];
+ var effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth);
+ var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
+ var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
+ var _g = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode), padInfo = _g.padInfo, outDepth = _g.outDepth, outHeight = _g.outHeight, outWidth = _g.outWidth;
+ var outChannels = depthwise ? filterChannels * inChannels : filterChannels;
+ var outShape;
+ if (dataFormat === 'channelsFirst') {
+ outShape = [batchSize, outChannels, outDepth, outHeight, outWidth];
+ }
+ else if (dataFormat === 'channelsLast') {
+ outShape = [batchSize, outDepth, outHeight, outWidth, outChannels];
+ }
+ return {
+ batchSize: batchSize,
+ dataFormat: dataFormat,
+ inDepth: inDepth,
+ inHeight: inHeight,
+ inWidth: inWidth,
+ inChannels: inChannels,
+ outDepth: outDepth,
+ outHeight: outHeight,
+ outWidth: outWidth,
+ outChannels: outChannels,
+ padInfo: padInfo,
+ strideDepth: strideDepth,
+ strideHeight: strideHeight,
+ strideWidth: strideWidth,
+ filterDepth: filterDepth,
+ filterHeight: filterHeight,
+ filterWidth: filterWidth,
+ effectiveFilterDepth: effectiveFilterDepth,
+ effectiveFilterHeight: effectiveFilterHeight,
+ effectiveFilterWidth: effectiveFilterWidth,
+ dilationDepth: dilationDepth,
+ dilationHeight: dilationHeight,
+ dilationWidth: dilationWidth,
+ inShape: inShape,
+ outShape: outShape,
+ filterShape: filterShape
+ };
+ }
+ function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) {
+ if (zeroPad == null) {
+ zeroPad = computeDefaultPad(inShape, fieldSize, stride);
+ }
+ var inputRows = inShape[0];
+ var inputCols = inShape[1];
+ var outputRows = round$1((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
+ var outputCols = round$1((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
+ return [outputRows, outputCols];
+ }
+ function computeOutputShape4D(inShape, fieldSize, outChannels, stride, zeroPad, roundingMode) {
+ if (zeroPad == null) {
+ zeroPad = computeDefaultPad(inShape, fieldSize, stride);
+ }
+ var inputDepth = inShape[0];
+ var inputRows = inShape[1];
+ var inputCols = inShape[2];
+ var outputDepths = round$1((inputDepth - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
+ var outputRows = round$1((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
+ var outputCols = round$1((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
+ return [outputDepths, outputRows, outputCols, outChannels];
+ }
+ function computeDefaultPad(inputShape, fieldSize, stride, dilation) {
+ if (dilation === void 0) { dilation = 1; }
+ var effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation);
+ return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2);
+ }
+ function parseTupleParam(param) {
+ if (typeof param === 'number') {
+ return [param, param, param];
+ }
+ if (param.length === 2) {
+ return [param[0], param[1], 1];
+ }
+ return param;
+ }
+ function parse3TupleParam(param) {
+ return typeof param === 'number' ? [param, param, param] : param;
+ }
+ /* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d
+ * Atrous convolution is equivalent to standard convolution with upsampled
+ * filters with effective_filter_height =
+ * filter_height + (filter_height - 1) * (dilation - 1)
+ * and effective_filter_width =
+ * filter_width + (filter_width - 1) * (dilation - 1),
+ * produced by inserting dilation - 1 zeros along consecutive elements across
+ * the filters' spatial dimensions.
+ * When there is a dilation, this converts a filter dimension to the
+ * effective filter dimension, so it can be used in a standard convolution.
+ */
+ function getEffectiveFilterSize(filterSize, dilation) {
+ if (dilation <= 1) {
+ return filterSize;
+ }
+ return filterSize + (filterSize - 1) * (dilation - 1);
+ }
+ function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode, dataFormat) {
+ var padInfo;
+ var outHeight;
+ var outWidth;
+ if (typeof pad === 'number') {
+ var padType = (pad === 0) ? 'VALID' : 'NUMBER';
+ padInfo = { top: pad, bottom: pad, left: pad, right: pad, type: padType };
+ var outShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode);
+ outHeight = outShape[0];
+ outWidth = outShape[1];
+ }
+ else if (pad === 'same') {
+ outHeight = Math.ceil(inHeight / strideHeight);
+ outWidth = Math.ceil(inWidth / strideWidth);
+ var padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
+ var padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
+ var top = Math.floor(padAlongHeight / 2);
+ var bottom = padAlongHeight - top;
+ var left = Math.floor(padAlongWidth / 2);
+ var right = padAlongWidth - left;
+ padInfo = { top: top, bottom: bottom, left: left, right: right, type: 'SAME' };
+ }
+ else if (pad === 'valid') {
+ padInfo = { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' };
+ outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
+ outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
+ }
+ else if (typeof pad === 'object') {
+ var top = dataFormat === 'channelsLast' ? pad[1][0] : pad[2][0];
+ var bottom = dataFormat === 'channelsLast' ? pad[1][1] : pad[2][1];
+ var left = dataFormat === 'channelsLast' ? pad[2][0] : pad[3][0];
+ var right = dataFormat === 'channelsLast' ? pad[2][1] : pad[3][1];
+ var padType = (top === 0 && bottom === 0 && left === 0 && right === 0) ?
+ 'VALID' :
+ 'EXPLICIT';
+ padInfo = { top: top, bottom: bottom, left: left, right: right, type: padType };
+ outHeight = round$1((inHeight - filterHeight + top + bottom) / strideHeight + 1, roundingMode);
+ outWidth = round$1((inWidth - filterWidth + left + right) / strideWidth + 1, roundingMode);
+ }
+ else {
+ throw Error("Unknown padding parameter: " + pad);
+ }
+ return { padInfo: padInfo, outHeight: outHeight, outWidth: outWidth };
+ }
+ function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) {
+ var padInfo;
+ var outDepth;
+ var outHeight;
+ var outWidth;
+ if (typeof pad === 'number') {
+ var padType = (pad === 0) ? 'VALID' : 'NUMBER';
+ padInfo = {
+ top: pad,
+ bottom: pad,
+ left: pad,
+ right: pad,
+ front: pad,
+ back: pad,
+ type: padType
+ };
+ var outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad, roundingMode);
+ outDepth = outShape[0];
+ outHeight = outShape[1];
+ outWidth = outShape[2];
+ }
+ else if (pad === 'same') {
+ outDepth = Math.ceil(inDepth / strideDepth);
+ outHeight = Math.ceil(inHeight / strideHeight);
+ outWidth = Math.ceil(inWidth / strideWidth);
+ var padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth;
+ var padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight;
+ var padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth;
+ var front = Math.floor(padAlongDepth / 2);
+ var back = padAlongDepth - front;
+ var top = Math.floor(padAlongHeight / 2);
+ var bottom = padAlongHeight - top;
+ var left = Math.floor(padAlongWidth / 2);
+ var right = padAlongWidth - left;
+ padInfo = { top: top, bottom: bottom, left: left, right: right, front: front, back: back, type: 'SAME' };
+ }
+ else if (pad === 'valid') {
+ padInfo = {
+ top: 0,
+ bottom: 0,
+ left: 0,
+ right: 0,
+ front: 0,
+ back: 0,
+ type: 'VALID'
+ };
+ outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth);
+ outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
+ outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
+ }
+ else {
+ throw Error("Unknown padding parameter: " + pad);
+ }
+ return { padInfo: padInfo, outDepth: outDepth, outHeight: outHeight, outWidth: outWidth };
+ }
+ /**
+ * Rounds a value depending on the rounding mode
+ * @param value
+ * @param roundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ */
+ function round$1(value, roundingMode) {
+ if (!roundingMode) {
+ return Math.trunc(value);
+ }
+ switch (roundingMode) {
+ case 'round':
+ // used for Caffe Conv
+ return Math.round(value);
+ case 'ceil':
+ // used for Caffe Pool
+ return Math.ceil(value);
+ case 'floor':
+ return Math.floor(value);
+ default:
+ throw new Error("Unknown roundingMode " + roundingMode);
+ }
+ }
+ function tupleValuesAreOne(param) {
+ var _a = __read(parseTupleParam(param), 3), dimA = _a[0], dimB = _a[1], dimC = _a[2];
+ return dimA === 1 && dimB === 1 && dimC === 1;
+ }
+ function eitherStridesOrDilationsAreOne(strides, dilations) {
+ return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
+ }
+ /**
+ * Convert Conv2D dataFormat from 'NHWC'|'NCHW' to
+ * 'channelsLast'|'channelsFirst'
+ * @param dataFormat in 'NHWC'|'NCHW' mode
+ * @return dataFormat in 'channelsLast'|'channelsFirst' mode
+ * @throws unknown dataFormat
+ */
+ function convertConv2DDataFormat(dataFormat) {
+ if (dataFormat === 'NHWC') {
+ return 'channelsLast';
+ }
+ else if (dataFormat === 'NCHW') {
+ return 'channelsFirst';
+ }
+ else {
+ throw new Error("Unknown dataFormat " + dataFormat);
+ }
+ }
+ /**
+ * Check validity of pad when using dimRoundingMode.
+ * @param opDesc A string of op description
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid` output will be smaller than input if filter is larger
+ * than 1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ * @throws unknown padding parameter
+ */
+ function checkPadOnDimRoundingMode(opDesc, pad, dimRoundingMode) {
+ if (dimRoundingMode != null) {
+ if (typeof pad === 'string') {
+ throw Error("Error in " + opDesc + ": pad must be an integer when using " +
+ ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."));
+ }
+ else if (typeof pad === 'number') {
+ assert(isInt(pad), function () { return "Error in " + opDesc + ": pad must be an integer when using " +
+ ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); });
+ }
+ else if (typeof pad === 'object') {
+ pad.forEach(function (p) {
+ p.forEach(function (v) {
+ assert(isInt(v), function () { return "Error in " + opDesc + ": pad must be an integer when using " +
+ ("dimRoundingMode " + dimRoundingMode + " but got pad " + v + "."); });
+ });
+ });
+ }
+ else {
+ throw Error("Error in " + opDesc + ": Unknown padding parameter: " + pad);
+ }
+ }
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Reshapes a `tf.Tensor` to a given shape.
+ *
+ * Given an input tensor, returns a new tensor with the same values as the
+ * input tensor with shape `shape`.
+ *
+ * If one component of shape is the special value -1, the size of that
+ * dimension is computed so that the total size remains constant. In
+ * particular, a shape of [-1] flattens into 1-D. At most one component of
+ * shape can be -1.
+ *
+ * If shape is 1-D or higher, then the operation returns a tensor with shape
+ * shape filled with the values of tensor. In this case, the number of
+ * elements implied by shape must be the same as the number of elements in
+ * tensor.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3, 4]);
+ * x.reshape([2, 2]).print();
+ * ```
+ *
+ * @param x The input tensor to be reshaped.
+ * @param shape An array of integers defining the output tensor shape.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Transformations'}
+ */
+ function reshape_(x, shape) {
+ var $x = convertToTensor(x, 'x', 'reshape', 'string_or_numeric');
+ var inputs = { x: $x };
+ var attrs = { shape: shape };
+ return ENGINE.runKernel(Reshape, inputs, attrs);
+ }
+ var reshape = op({ reshape_: reshape_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the 2D average pooling of an image.
+ *
+ * @param x The input tensor, of rank 4 or rank 3 of shape
+ * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
+ * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
+ * `filterSize` is a single number, then `filterHeight == filterWidth`.
+ * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
+ * `strides` is a single number, then `strideHeight == strideWidth`.
+ * @param pad The type of padding algorithm:
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ */
+ function avgPool_(x, filterSize, strides, pad, dimRoundingMode) {
+ var $x = convertToTensor(x, 'x', 'avgPool', 'float32');
+ var dilations = 1;
+ assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in avgPool: Either strides or dilations must be 1. ' +
+ ("Got strides " + strides + " and dilations '" + dilations + "'"); });
+ var x4D = $x;
+ var reshapedTo4D = false;
+ if ($x.rank === 3) {
+ reshapedTo4D = true;
+ x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
+ }
+ assert(x4D.rank === 4, function () { return "Error in avgPool: x must be rank 4 but got rank " + x4D.rank + "."; });
+ checkPadOnDimRoundingMode('avgPool', pad, dimRoundingMode);
+ var inputs = { x: x4D };
+ var attrs = { filterSize: filterSize, strides: strides, pad: pad, dimRoundingMode: dimRoundingMode };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(AvgPool, inputs, attrs);
+ res = cast(res, $x.dtype);
+ if (reshapedTo4D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return res;
+ }
+ var avgPool = op({ avgPool_: avgPool_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the 3D average pooling.
+ *
+ * ```js
+ * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);
+ * const result = tf.avgPool3d(x, 2, 1, 'valid');
+ * result.print();
+ * ```
+ *
+ * @param x The input tensor, of rank 5 or rank 4 of shape
+ * `[batch, depth, height, width, inChannels]`.
+ * @param filterSize The filter size:
+ * `[filterDepth, filterHeight, filterWidth]`.
+ * If `filterSize` is a single number,
+ * then `filterDepth == filterHeight == filterWidth`.
+ * @param strides The strides of the pooling:
+ * `[strideDepth, strideHeight, strideWidth]`.
+ * If `strides` is a single number,
+ * then `strideDepth == strideHeight == strideWidth`.
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1*1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
+ * "NDHWC". Specify the data format of the input and output data. With the
+ * default format "NDHWC", the data is stored in the order of: [batch,
+ * depth, height, width, channels]. Only "NDHWC" is currently supported.
+ *
+ * @doc {heading: 'Operations', subheading: 'Convolution'}
+ */
+ function avgPool3d_(x, filterSize, strides, pad, dimRoundingMode, dataFormat) {
+ if (dataFormat === void 0) { dataFormat = 'NDHWC'; }
+ var $x = convertToTensor(x, 'x', 'avgPool3d', 'float32');
+ var x5D = $x;
+ var reshapedTo5D = false;
+ if ($x.rank === 4) {
+ reshapedTo5D = true;
+ x5D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
+ }
+ assert(x5D.rank === 5, function () { return "Error in avgPool3d: x must be rank 5 but got rank " + x5D.rank + "."; });
+ assert(dataFormat === 'NDHWC', function () { return "Error in avgPool3d: Only NDHWC is currently supported, " +
+ ("but got dataFormat of " + dataFormat); });
+ checkPadOnDimRoundingMode('avgPool3d', pad, dimRoundingMode);
+ var inputs = { x: x5D };
+ var attrs = { filterSize: filterSize, strides: strides, pad: pad, dimRoundingMode: dimRoundingMode, dataFormat: dataFormat };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(AvgPool3D, inputs, attrs);
+ res = cast(res, x5D.dtype);
+ if (reshapedTo5D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
+ }
+ return res;
+ }
+ var avgPool3d = op({ avgPool3d_: avgPool3d_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Concatenates a list of `tf.Tensor`s along a given axis.
+ *
+ * The tensors ranks and types must match, and their sizes must match in all
+ * dimensions except `axis`.
+ *
+ * Also available are stricter rank-specific methods that assert that
+ * `tensors` are of the given rank:
+ * - `tf.concat1d`
+ * - `tf.concat2d`
+ * - `tf.concat3d`
+ * - `tf.concat4d`
+ *
+ * Except `tf.concat1d` (which does not have axis param), all methods have
+ * same signature as this method.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2]);
+ * const b = tf.tensor1d([3, 4]);
+ * a.concat(b).print(); // or a.concat(b)
+ * ```
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2]);
+ * const b = tf.tensor1d([3, 4]);
+ * const c = tf.tensor1d([5, 6]);
+ * tf.concat([a, b, c]).print();
+ * ```
+ *
+ * ```js
+ * const a = tf.tensor2d([[1, 2], [10, 20]]);
+ * const b = tf.tensor2d([[3, 4], [30, 40]]);
+ * const axis = 1;
+ * tf.concat([a, b], axis).print();
+ * ```
+ * @param tensors A list of tensors to concatenate.
+ * @param axis The axis to concate along. Defaults to 0 (the first dim).
+ *
+ * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
+ */
+ function concat_(tensors, axis) {
+ if (axis === void 0) { axis = 0; }
+ assert(tensors.length >= 1, function () { return 'Pass at least one tensor to concat'; });
+ var $tensors = convertToTensorArray(tensors, 'tensors', 'concat', 'string_or_numeric');
+ if ($tensors[0].dtype === 'complex64') {
+ $tensors.forEach(function (tensor) {
+ if (tensor.dtype !== 'complex64') {
+ throw new Error("Cannot concatenate complex64 tensors with a tensor\n with dtype " + tensor.dtype + ". ");
+ }
+ });
+ }
+ if ($tensors.length === 1) {
+ return clone($tensors[0]);
+ }
+ var inputs = $tensors;
+ var attr = { axis: axis };
+ return ENGINE.runKernel(Concat, inputs, attr);
+ }
+ var concat = op({ concat_: concat_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes sigmoid element-wise, `1 / (1 + exp(-x))`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, -1, 2, -3]);
+ *
+ * x.sigmoid().print(); // or tf.sigmoid(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function sigmoid_(x) {
+ var $x = convertToTensor(x, 'x', 'sigmoid', 'float32');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Sigmoid, inputs);
+ }
+ var sigmoid = op({ sigmoid_: sigmoid_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Extracts a slice from a `tf.Tensor` starting at coordinates `begin`
+ * and is of size `size`.
+ *
+ * Also available are stricter rank-specific methods with the same signature
+ * as this method that assert that `x` is of the given rank:
+ * - `tf.slice1d`
+ * - `tf.slice2d`
+ * - `tf.slice3d`
+ * - `tf.slice4d`
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3, 4]);
+ *
+ * x.slice([1], [2]).print();
+ * ```
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+ *
+ * x.slice([1, 0], [1, 2]).print();
+ * ```
+ * @param x The input `tf.Tensor` to slice from.
+ * @param begin The coordinates to start the slice from. The length can be
+ * less than the rank of x - the rest of the axes will have implicit 0 as
+ * start. Can also be a single number, in which case it specifies the
+ * first axis.
+ * @param size The size of the slice. The length can be less than the rank of
+ * x - the rest of the axes will have implicit -1. A value of -1 requests
+ * the rest of the dimensions in the axis. Can also be a single number,
+ * in which case it specifies the size of the first axis.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
+ */
+ function slice_(x, begin, size) {
+ var $x = convertToTensor(x, 'x', 'slice', 'string_or_numeric');
+ if ($x.rank === 0) {
+ throw new Error('Slicing scalar is not possible');
+ }
+ var inputs = { x: $x };
+ var attrs = { begin: begin, size: size };
+ return ENGINE.runKernel(Slice, inputs, attrs);
+ }
+ var slice = op({ slice_: slice_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes hyperbolic tangent of the input `tf.Tensor` element-wise: `tanh(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, 1, -1, 70]);
+ *
+ * x.tanh().print(); // or tf.tanh(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function tanh_(x) {
+ var $x = convertToTensor(x, 'x', 'tanh', 'float32');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Tanh, inputs);
+ }
+ var tanh = op({ tanh_: tanh_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the next state and output of a BasicLSTMCell.
+ *
+ * Returns `[newC, newH]`.
+ *
+ * Derived from tf.contrib.rnn.BasicLSTMCell.
+ *
+ * @param forgetBias Forget bias for the cell.
+ * @param lstmKernel The weights for the cell.
+ * @param lstmBias The bias for the cell.
+ * @param data The input to the cell.
+ * @param c Previous cell state.
+ * @param h Previous cell output.
+ *
+ * @doc {heading: 'Operations', subheading: 'RNN'}
+ */
+ function basicLSTMCell_(forgetBias, lstmKernel, lstmBias, data, c, h) {
+ var $forgetBias = convertToTensor(forgetBias, 'forgetBias', 'basicLSTMCell');
+ var $lstmKernel = convertToTensor(lstmKernel, 'lstmKernel', 'basicLSTMCell');
+ var $lstmBias = convertToTensor(lstmBias, 'lstmBias', 'basicLSTMCell');
+ var $data = convertToTensor(data, 'data', 'basicLSTMCell');
+ var $c = convertToTensor(c, 'c', 'basicLSTMCell');
+ var $h = convertToTensor(h, 'h', 'basicLSTMCell');
+ var combined = concat([$data, $h], 1);
+ var weighted = matMul$1(combined, $lstmKernel);
+ var res = add(weighted, $lstmBias);
+ // i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ var batchSize = res.shape[0];
+ var sliceCols = res.shape[1] / 4;
+ var sliceSize = [batchSize, sliceCols];
+ var i = slice(res, [0, 0], sliceSize);
+ var j = slice(res, [0, sliceCols], sliceSize);
+ var f = slice(res, [0, sliceCols * 2], sliceSize);
+ var o = slice(res, [0, sliceCols * 3], sliceSize);
+ var newC = add(mul(sigmoid(i), tanh(j)), mul($c, sigmoid(add($forgetBias, f))));
+ var newH = mul(tanh(newC), sigmoid(o));
+ return [newC, newH];
+ }
+ var basicLSTMCell = op({ basicLSTMCell_: basicLSTMCell_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
+ * shape `blockShape + [batch]`, interleaves these blocks back into the grid
+ * defined by the spatial dimensions `[1, ..., M]`, to obtain a result with
+ * the same rank as the input. The spatial dimensions of this intermediate
+ * result are then optionally cropped according to `crops` to produce the
+ * output. This is the reverse of `tf.spaceToBatchND`. See below for a precise
+ * description.
+ *
+ * ```js
+ * const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
+ * const blockShape = [2, 2];
+ * const crops = [[0, 0], [0, 0]];
+ *
+ * x.batchToSpaceND(blockShape, crops).print();
+ * ```
+ *
+ * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
+ * remainingShape`, where spatialShape has `M` dimensions.
+ * @param blockShape A 1-D array. Must have shape `[M]`, all values must
+ * be >= 1.
+ * @param crops A 2-D array. Must have shape `[M, 2]`, all values must be >= 0.
+ * `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input
+ * dimension `i + 1`, which corresponds to spatial dimension `i`. It is required
+ * that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]`
+ *
+ * This operation is equivalent to the following steps:
+ *
+ * 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ...,
+ * blockShape[M-1], batch / prod(blockShape), x.shape[1], ...,
+ * x.shape[N-1]]`
+ *
+ * 2. Permute dimensions of `reshaped`to produce `permuted` of shape `[batch /
+ * prod(blockShape),x.shape[1], blockShape[0], ..., x.shape[M],
+ * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
+ *
+ * 3. Reshape `permuted` to produce `reshapedPermuted` of shape `[batch /
+ * prod(blockShape),x.shape[1] * blockShape[0], ..., x.shape[M] *
+ * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
+ *
+ * 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted`
+ * according to `crops` to produce the output of shape: `[batch /
+ * prod(blockShape),x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1],
+ * ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] -
+ * crops[M-1,1],x.shape[M+1], ..., x.shape[N-1]]`
+ *
+ * @doc {heading: 'Tensors', subheading: 'Transformations'}
+ */
+ function batchToSpaceND_(x, blockShape, crops) {
+ var $x = convertToTensor(x, 'x', 'batchToSpaceND');
+ var prod = blockShape.reduce(function (a, b) { return a * b; });
+ assert($x.rank >= 1 + blockShape.length, function () { return "input rank is " + $x.rank + " but should be > than blockShape.length " + blockShape.length; });
+ assert(crops.length === blockShape.length, function () { return "crops.length is " + crops.length + " but should be equal to blockShape.length " + blockShape.length; });
+ assert($x.shape[0] % prod === 0, function () { return "input tensor batch is " + $x.shape[0] + " but is not divisible by the product of " +
+ ("the elements of blockShape " + blockShape.join(' * ') + " === " + prod); });
+ var inputs = { x: $x };
+ var attrs = { blockShape: blockShape, crops: crops };
+ return ENGINE.runKernel(BatchToSpaceND, inputs, attrs);
+ }
+ var batchToSpaceND = op({ batchToSpaceND_: batchToSpaceND_ });
+
+ function xAs4D(x) {
+ var x4D;
+ if (x.rank === 0 || x.rank === 1) {
+ x4D = reshape(x, [1, 1, 1, x.size]);
+ }
+ else if (x.rank === 2) {
+ x4D = reshape(x, [1, 1, x.shape[0], x.shape[1]]);
+ }
+ else if (x.rank === 3) {
+ x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
+ }
+ else {
+ x4D = x;
+ }
+ return x4D;
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Batch normalization.
+ *
+ * As described in
+ * [http://arxiv.org/abs/1502.03167](http://arxiv.org/abs/1502.03167).
+ *
+ * Mean, variance, scale, and offset can be of two shapes:
+ * - The same shape as the input.
+ * - In the common case, the depth dimension is the last dimension of x, so
+ * the values would be an `tf.Tensor1D` of shape [depth].
+ *
+ * Also available are stricter rank-specific methods with the same signature
+ * as this method that assert that parameters passed are of given rank
+ * - `tf.batchNorm2d`
+ * - `tf.batchNorm3d`
+ * - `tf.batchNorm4d`
+ *
+ * @param x The input Tensor.
+ * @param mean A mean Tensor.
+ * @param variance A variance Tensor.
+ * @param offset An offset Tensor.
+ * @param scale A scale Tensor.
+ * @param varianceEpsilon A small float number to avoid dividing by 0.
+ *
+ * @doc {heading: 'Operations', subheading: 'Normalization'}
+ */
+ function batchNorm_(x, mean, variance, offset, scale, varianceEpsilon) {
+ if (varianceEpsilon == null) {
+ varianceEpsilon = 0.001;
+ }
+ var $x = convertToTensor(x, 'x', 'batchNorm');
+ var $mean = convertToTensor(mean, 'mean', 'batchNorm');
+ var $variance = convertToTensor(variance, 'variance', 'batchNorm');
+ var $scale;
+ if (scale != null) {
+ $scale = convertToTensor(scale, 'scale', 'batchNorm');
+ }
+ var $offset;
+ if (offset != null) {
+ $offset = convertToTensor(offset, 'offset', 'batchNorm');
+ }
+ assert($mean.rank === $variance.rank, function () { return 'Batch normalization gradient requires mean and variance to have ' +
+ 'equal ranks.'; });
+ assert($offset == null || $mean.rank === $offset.rank, function () { return 'Batch normalization gradient requires mean and offset to have ' +
+ 'equal ranks.'; });
+ assert($scale == null || $mean.rank === $scale.rank, function () { return 'Batch normalization gradient requires mean and scale to have ' +
+ 'equal ranks.'; });
+ var x4D = xAs4D($x);
+ var inputs = {
+ x: x4D,
+ scale: $scale,
+ offset: $offset,
+ mean: $mean,
+ variance: $variance
+ };
+ var attrs = { varianceEpsilon: varianceEpsilon };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(FusedBatchNorm, inputs, attrs);
+ return reshape(res, $x.shape);
+ }
+ var batchNorm = op({ batchNorm_: batchNorm_ });
+
+ /**
+ * Batch normalization, strictly for 2D. For the more relaxed version, see
+ * `tf.batchNorm`.
+ *
+ * @param x The input Tensor.
+ * @param mean A mean Tensor.
+ * @param variance A variance Tensor.
+ * @param offset An offset Tensor.
+ * @param scale A scale Tensor.
+ * @param varianceEpsilon A small float number to avoid dividing by 0.
+ */
+ function batchNorm2d_(x, mean, variance, offset, scale, varianceEpsilon) {
+ var $x = convertToTensor(x, 'x', 'batchNorm');
+ var $mean = convertToTensor(mean, 'mean', 'batchNorm');
+ var $variance = convertToTensor(variance, 'variance', 'batchNorm');
+ var $scale;
+ if (scale != null) {
+ $scale = convertToTensor(scale, 'scale', 'batchNorm');
+ }
+ var $offset;
+ if (offset != null) {
+ $offset = convertToTensor(offset, 'offset', 'batchNorm');
+ }
+ assert($x.rank === 2, function () { return "Error in batchNorm2D: x must be rank 2 but got rank " +
+ ($x.rank + "."); });
+ assert($mean.rank === 2 || $mean.rank === 1, function () { return "Error in batchNorm2D: mean must be rank 2 or rank 1 but " +
+ ("got rank " + $mean.rank + "."); });
+ assert($variance.rank === 2 || $variance.rank === 1, function () { return "Error in batchNorm2D: variance must be rank 2 or rank 1 " +
+ ("but got rank " + $variance.rank + "."); });
+ if ($scale != null) {
+ assert($scale.rank === 2 || $scale.rank === 1, function () { return "Error in batchNorm2D: scale must be rank 2 or rank 1 " +
+ ("but got rank " + $scale.rank + "."); });
+ }
+ if ($offset != null) {
+ assert($offset.rank === 2 || $offset.rank === 1, function () { return "Error in batchNorm2D: offset must be rank 2 or rank 1 " +
+ ("but got rank " + $offset.rank + "."); });
+ }
+ return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
+ }
+ var batchNorm2d = op({ batchNorm2d_: batchNorm2d_ });
+
+ /**
+ * Batch normalization, strictly for 3D. For the more relaxed version, see
+ * `tf.batchNorm`.
+ *
+ * @param x The input Tensor.
+ * @param mean A mean Tensor.
+ * @param variance A variance Tensor.
+ * @param offset An offset Tensor.
+ * @param scale A scale Tensor.
+ * @param varianceEpsilon A small float number to avoid dividing by 0.
+ */
+ function batchNorm3d_(x, mean, variance, offset, scale, varianceEpsilon) {
+ var $x = convertToTensor(x, 'x', 'batchNorm');
+ var $mean = convertToTensor(mean, 'mean', 'batchNorm');
+ var $variance = convertToTensor(variance, 'variance', 'batchNorm');
+ var $scale;
+ if (scale != null) {
+ $scale = convertToTensor(scale, 'scale', 'batchNorm');
+ }
+ var $offset;
+ if (offset != null) {
+ $offset = convertToTensor(offset, 'offset', 'batchNorm');
+ }
+ assert($x.rank === 3, function () { return "Error in batchNorm3D: x must be rank 3 but got rank " +
+ ($x.rank + "."); });
+ assert($mean.rank === 3 || $mean.rank === 1, function () { return "Error in batchNorm3D: mean must be rank 3 or rank 1 but " +
+ ("got rank " + $mean.rank + "."); });
+ assert($variance.rank === 3 || $variance.rank === 1, function () { return "Error in batchNorm3D: variance must be rank 3 or rank 1 " +
+ ("but got rank " + $variance.rank + "."); });
+ if ($scale != null) {
+ assert($scale.rank === 3 || $scale.rank === 1, function () { return "Error in batchNorm3D: scale must be rank 3 or rank 1 " +
+ ("but got rank " + $scale.rank + "."); });
+ }
+ if ($offset != null) {
+ assert($offset.rank === 3 || $offset.rank === 1, function () { return "Error in batchNorm3D: offset must be rank 3 or rank 1 " +
+ ("but got rank " + $offset.rank + "."); });
+ }
+ return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
+ }
+ var batchNorm3d = op({ batchNorm3d_: batchNorm3d_ });
+
+ /**
+ * Batch normalization, strictly for 4D. For the more relaxed version, see
+ * `tf.batchNorm`.
+ *
+ * @param x The input Tensor.
+ * @param mean A mean Tensor.
+ * @param variance A variance Tensor.
+ * @param offset An offset Tensor.
+ * @param scale A scale Tensor.
+ * @param varianceEpsilon A small float number to avoid dividing by 0.
+ */
+ function batchNorm4d_(x, mean, variance, offset, scale, varianceEpsilon) {
+ var $x = convertToTensor(x, 'x', 'batchNorm');
+ var $mean = convertToTensor(mean, 'mean', 'batchNorm');
+ var $variance = convertToTensor(variance, 'variance', 'batchNorm');
+ var $scale;
+ if (scale != null) {
+ $scale = convertToTensor(scale, 'scale', 'batchNorm');
+ }
+ var $offset;
+ if (offset != null) {
+ $offset = convertToTensor(offset, 'offset', 'batchNorm');
+ }
+ assert($x.rank === 4, function () { return "Error in batchNorm4D: x must be rank 4 but got rank " +
+ ($x.rank + "."); });
+ assert($mean.rank === 4 || $mean.rank === 1, function () { return "Error in batchNorm4D: mean must be rank 4 or rank 1 but " +
+ ("got rank " + $mean.rank + "."); });
+ assert($variance.rank === 4 || $variance.rank === 1, function () { return "Error in batchNorm4D: variance must be rank 4 or rank 1 " +
+ ("but got rank " + $variance.rank + "."); });
+ if ($scale != null) {
+ assert($scale.rank === 4 || $scale.rank === 1, function () { return "Error in batchNorm4D: scale must be rank 4 or rank 1 " +
+ ("but got rank " + $scale.rank + "."); });
+ }
+ if ($offset != null) {
+ assert($offset.rank === 4 || $offset.rank === 1, function () { return "Error in batchNorm4D: offset must be rank 4 or rank 1 " +
+ ("but got rank " + $offset.rank + "."); });
+ }
+ return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
+ }
+ var batchNorm4d = op({ batchNorm4d_: batchNorm4d_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Outputs a vector with length `size` and the same dtype as `weights`.
+ *
+ * If `weights` are empty, then index `i` stores the number of times the value
+ * `i` is counted in `x`. If `weights` are non-empty, then index `i` stores the
+ * sum of the value in `weights` at each index where the corresponding value in
+ * `x` is `i`.
+ *
+ * Values in `x` outside of the range [0, size) are ignored.
+ *
+ * @param x The input int tensor, rank 1.
+ * @param weights The weights tensor, must have the same shape as x, or a
+ * length-0 Tensor, in which case it acts as all weights equal to 1.
+ * @param size Non-negative integer.
+ *
+ * @doc {heading: 'Operations', subheading: 'Reduction'}
+ */
+ function bincount_(x, weights, size) {
+ var $x = convertToTensor(x, 'x', 'bincount');
+ var $weights = convertToTensor(weights, 'weights', 'bincount');
+ assert($x.dtype === 'int32', function () { return "Error in bincount: input " +
+ ("dtype must be int32, but got " + $x.dtype); });
+ assert(size >= 0, function () { return "size must be non-negative, but got " + size + "."; });
+ assert($weights.size === $x.size || $weights.size === 0, function () { return "Error in bincount: weights must have the same size as input or" +
+ ("0-length, but got input shape: " + $x.shape + ", weights shape: ") +
+ ($weights.shape + "."); });
+ var inputs = { x: $x, weights: $weights };
+ var attrs = { size: size };
+ return ENGINE.runKernel(Bincount, inputs, attrs);
+ }
+ var bincount = op({ bincount_: bincount_ });
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Return the shape of s0 op s1 with broadcast.
+ *
+ * compute r0, the broadcasted shape as a tensor.
+ * s0, s1 and r0 are all integer vectors.
+ *
+ * This function returns the shape of the result of an operation between
+ * two tensors of size s0 and s1 performed with broadcast.
+ *
+ * @param s0 A tensor representing a shape
+ * @param s1 A tensor representing a shape
+ *
+ * @doc {heading: 'Tensors', subheading: 'Transformations'}
+ */
+ function broadcastArgs_(s0, s1) {
+ var shape1Input = convertToTensor(s0, 's0', 'broadcastArgs', 'int32');
+ var shape2Input = convertToTensor(s1, 's1', 'broadcastArgs', 'int32');
+ if (shape1Input.rank !== 1) {
+ throw new Error('broadcastArgs(): first input must be a vector (rank=1). ' +
+ ("Has rank " + shape1Input.rank));
+ }
+ if (shape2Input.rank !== 1) {
+ throw new Error('broadcastArgs(): second input must be a vector (rank=1). ' +
+ ("Has rank " + shape2Input.rank));
+ }
+ var inputs = { s0: shape1Input, s1: shape2Input };
+ return ENGINE.runKernel(BroadcastArgs, inputs);
+ }
+ var broadcastArgs = op({ broadcastArgs_: broadcastArgs_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Broadcast an array to a compatible shape NumPy-style.
+ *
+ * The tensor's shape is compared to the broadcast shape from end to beginning.
+ * Ones are prepended to the tensor's shape until is has the same length as
+ * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
+ * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
+ * the input tensor is tiled N times along that axis (using tf.tile).
+ *
+ * @param input The tensor that is to be broadcasted.
+ * @param shape The input is to be broadcast to this shape.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Transformations'}
+ */
+ function broadcastTo_(x, shape) {
+ var input = convertToTensor(x, 'broadcastTo', 'x');
+ var xShape = input.shape;
+ if (shape.some(function (d) { return !(d > 0) || d % 1 !== 0; })) {
+ throw new Error("broadcastTo(): Invalid broadcast shape [" + shape + "].");
+ }
+ if (shape.length < input.rank) {
+ throw new Error("broadcastTo(): shape.length=" + shape.length + " < input.rank=" + input.rank + ".");
+ }
+ if (shape.length > input.rank) {
+ var newShape = input.shape.slice();
+ while (newShape.length < shape.length) {
+ newShape.unshift(1);
+ }
+ input = reshape(input, newShape);
+ }
+ var inputShape = input.shape;
+ var reps = Array.from(shape);
+ for (var i = shape.length - 1; i >= 0; i--) {
+ if (inputShape[i] === shape[i]) {
+ reps[i] = 1;
+ }
+ else if (input.shape[i] !== 1) {
+ throw new Error("broadcastTo(): [" + xShape + "] cannot be broadcast to [" + shape + "].");
+ }
+ }
+ var axes = reps.map(function (n, i) { return n > 1 ? i : -1; }).filter(function (i) { return i >= 0; });
+ if (axes.length === 0) {
+ return clone(input);
+ }
+ // TODO call broadcastTo kernel directly once backends implement broadcstTo
+ var inputs = { x: input };
+ var attrs = { reps: reps };
+ return ENGINE.runKernel(Tile, inputs, attrs);
+ }
+ var broadcastTo = op({ broadcastTo_: broadcastTo_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes ceiling of input `tf.Tensor` element-wise: `ceil(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([.6, 1.1, -3.3]);
+ *
+ * x.ceil().print(); // or tf.ceil(x)
+ * ```
+ * @param x The input Tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function ceil_(x) {
+ var $x = convertToTensor(x, 'x', 'ceil', 'float32');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Ceil, inputs);
+ }
+ var ceil = op({ ceil_: ceil_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)`
+ *
+ * ```js
+ * const x = tf.tensor1d([-1, 2, -3, 4]);
+ *
+ * x.clipByValue(-2, 3).print(); // or tf.clipByValue(x, -2, 3)
+ * ```
+ * @param x The input tensor.
+ * @param clipValueMin Lower-bound of range to be clipped to.
+ * @param clipValueMax Upper-bound of range to be clipped to.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function clipByValue_(x, clipValueMin, clipValueMax) {
+ var $x = convertToTensor(x, 'x', 'clipByValue');
+ assert((clipValueMin <= clipValueMax), function () { return "Error in clip: min (" + clipValueMin + ") must be " +
+ ("less than or equal to max (" + clipValueMax + ")."); });
+ var inputs = { x: $x };
+ var attrs = { clipValueMin: clipValueMin, clipValueMax: clipValueMax };
+ return ENGINE.runKernel(ClipByValue, inputs, attrs);
+ }
+ var clipByValue = op({ clipByValue_: clipByValue_ });
+
+ /**
+ * Concatenates a list of`tf.Tensor1D`s along an axis. See `concat` for details.
+ *
+ * For example, if:
+ * A: shape(3) = |r1, g1, b1|
+ * B: shape(2) = |r2, g2|
+ * C = tf.concat1d([A, B]) == |r1, g1, b1, r2, g2|
+ *
+ * @param tensors A list of`tf.Tensor`s to concatenate.
+ * @return The concatenated array.
+ */
+ function concat1d_(tensors) {
+ return concat(tensors, 0 /* axis */);
+ }
+ var concat1d = op({ concat1d_: concat1d_ });
+
+ /**
+ * Concatenates a list of`tf.Tensor2D`s along an axis. See `concat` for details.
+ *
+ * For example, if:
+ * A: shape(2, 3) = | r1, g1, b1 |
+ * | r2, g2, b2 |
+ *
+ * B: shape(2, 3) = | r3, g3, b3 |
+ * | r4, g4, b4 |
+ *
+ * C = tf.concat2d([A, B], axis)
+ *
+ * if axis = 0:
+ * C: shape(4, 3) = | r1, g1, b1 |
+ * | r2, g2, b2 |
+ * | r3, g3, b3 |
+ * | r4, g4, b4 |
+ *
+ * if axis = 1:
+ * C = shape(2, 6) = | r1, g1, b1, r3, g3, b3 |
+ * | r2, g2, b2, r4, g4, b4 |
+ *
+ *
+ * @param tensors A list of `tf.Tensor`s to concatenate.
+ * @param axis The axis to concatenate along.
+ * @return The concatenated array.
+ */
+ function concat2d_(tensors, axis) {
+ return concat(tensors, axis);
+ }
+ var concat2d = op({ concat2d_: concat2d_ });
+
+ /**
+ * Concatenates a list of `tf.Tensor3D`s along an axis.
+ * See `concat` for details.
+ *
+ * For example, if:
+ * A: shape(2, 1, 3) = | r1, g1, b1 |
+ * | r2, g2, b2 |
+ *
+ * B: shape(2, 1, 3) = | r3, g3, b3 |
+ * | r4, g4, b4 |
+ *
+ * C = tf.concat3d([A, B], axis)
+ *
+ * if axis = 0:
+ * C: shape(4, 1, 3) = | r1, g1, b1 |
+ * | r2, g2, b2 |
+ * | r3, g3, b3 |
+ * | r4, g4, b4 |
+ *
+ * if axis = 1:
+ * C: shape(2, 2, 3) = | r1, g1, b1, r3, g3, b3 |
+ * | r2, g2, b2, r4, g4, b4 |
+ *
+ * if axis = 2:
+ * C = shape(2, 1, 6) = | r1, g1, b1, r3, g3, b3 |
+ * | r2, g2, b2, r4, g4, b4 |
+ *
+ * @param tensors A list of`tf.Tensor`s to concatenate.
+ * @param axis The axis to concate along.
+ * @return The concatenated array.
+ */
+ function concat3d_(tensors, axis) {
+ return concat(tensors, axis);
+ }
+ var concat3d = op({ concat3d_: concat3d_ });
+
+ /**
+ * Concatenates a list of `tf.Tensor4D`s along an axis.
+ * See `concat` for details.
+ *
+ * @param tensors A list of `tf.Tensor`s to concatenate.
+ * @param axis The axis to concate along.
+ * @return The concatenated array.
+ */
+ function concat4d_(tensors, axis) {
+ return concat(tensors, axis);
+ }
+ var concat4d = op({ concat4d_: concat4d_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes a 2D convolution over the input x.
+ *
+ * @param x The input tensor, of rank 4 or rank 3, of shape
+ * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
+ * assumed.
+ * @param filter The filter, rank 4, of shape
+ * `[filterHeight, filterWidth, inDepth, outDepth]`.
+ * @param strides The strides of the convolution: `[strideHeight,
+ * strideWidth]`.
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
+ * "NHWC". Specify the data format of the input and output data. With the
+ * default format "NHWC", the data is stored in the order of: [batch,
+ * height, width, channels].
+ * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
+ * in which we sample input values across the height and width dimensions
+ * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
+ * number, then `dilationHeight == dilationWidth`. If it is greater than
+ * 1, then all values of `strides` must be 1.
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ *
+ * @doc {heading: 'Operations', subheading: 'Convolution'}
+ */
+ function conv2d_(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
+ if (dataFormat === void 0) { dataFormat = 'NHWC'; }
+ if (dilations === void 0) { dilations = [1, 1]; }
+ var $x = convertToTensor(x, 'x', 'conv2d', 'float32');
+ var $filter = convertToTensor(filter, 'filter', 'conv2d', 'float32');
+ var x4D = $x;
+ var reshapedTo4D = false;
+ if ($x.rank === 3) {
+ reshapedTo4D = true;
+ x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
+ }
+ assert(x4D.rank === 4, function () { return "Error in conv2d: input must be rank 4, but got rank " + x4D.rank + "."; });
+ assert($filter.rank === 4, function () { return "Error in conv2d: filter must be rank 4, but got rank " +
+ ($filter.rank + "."); });
+ checkPadOnDimRoundingMode('conv2d', pad, dimRoundingMode);
+ var inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
+ assert(inDepth === $filter.shape[2], function () { return "Error in conv2d: depth of input (" + inDepth + ") must match " +
+ ("input depth for filter " + $filter.shape[2] + "."); });
+ assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in conv2D: Either strides or dilations must be 1. ' +
+ ("Got strides " + strides + " and dilations '" + dilations + "'"); });
+ var inputs = { x: x4D, filter: $filter };
+ var attrs = { strides: strides, pad: pad, dataFormat: dataFormat, dilations: dilations, dimRoundingMode: dimRoundingMode };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(Conv2D, inputs, attrs);
+ if (reshapedTo4D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return res;
+ }
+ var conv2d$1 = op({ conv2d_: conv2d_ });
+
+ /**
+ * Computes a 1D convolution over the input x.
+ *
+ * @param x The input tensor, of rank 3 or rank 2, of shape
+ * `[batch, width, inChannels]`. If rank 2, batch of 1 is assumed.
+ * @param filter The filter, rank 3, of shape
+ * `[filterWidth, inDepth, outDepth]`.
+ * @param stride The number of entries by which the filter is moved right at
+ * each step.
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dataFormat An optional string from "NWC", "NCW". Defaults to "NWC",
+ * the data is stored in the order of [batch, in_width, in_channels]. Only
+ * "NWC" is currently supported.
+ * @param dilation The dilation rate in which we sample input values in
+ * atrous convolution. Defaults to `1`. If it is greater than 1, then
+ * stride must be `1`.
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ *
+ * @doc {heading: 'Operations', subheading: 'Convolution'}
+ */
+ function conv1d_(x, filter, stride, pad, dataFormat, dilation, dimRoundingMode) {
+ if (dataFormat === void 0) { dataFormat = 'NWC'; }
+ if (dilation === void 0) { dilation = 1; }
+ var $x = convertToTensor(x, 'x', 'conv1d');
+ var $filter = convertToTensor(filter, 'filter', 'conv1d');
+ var x3D = $x;
+ var reshapedTo3D = false;
+ if ($x.rank === 2) {
+ reshapedTo3D = true;
+ x3D = reshape($x, [1, $x.shape[0], $x.shape[1]]);
+ }
+ assert(x3D.rank === 3, function () { return "Error in conv1d: input must be rank 3, but got rank " + x3D.rank + "."; });
+ assert($filter.rank === 3, function () { return "Error in conv1d: filter must be rank 3, but got rank " +
+ ($filter.rank + "."); });
+ checkPadOnDimRoundingMode('conv1d', pad, dimRoundingMode);
+ assert(x3D.shape[2] === $filter.shape[1], function () { return "Error in conv1d: depth of input (" + x3D.shape[2] + ") must match " +
+ ("input depth for filter " + $filter.shape[1] + "."); });
+ assert(eitherStridesOrDilationsAreOne(stride, dilation), function () { return 'Error in conv1D: Either stride or dilation must be 1. ' +
+ ("Got stride " + stride + " and dilation '" + dilation + "'"); });
+ assert(dataFormat === 'NWC', function () { return "Error in conv1d: got dataFormat of " + dataFormat + " but only NWC is currently supported."; });
+ var filter4D = reshape($filter, [1, $filter.shape[0], $filter.shape[1], $filter.shape[2]]);
+ var input4D = reshape(x3D, [x3D.shape[0], 1, x3D.shape[1], x3D.shape[2]]);
+ var strides = [1, stride];
+ var dilations = [1, dilation];
+ var conv2dDataFormat = 'NHWC';
+ var res = conv2d$1(input4D, filter4D, strides, pad, conv2dDataFormat, dilations, dimRoundingMode);
+ if (reshapedTo3D) {
+ return reshape(res, [res.shape[2], res.shape[3]]);
+ }
+ return reshape(res, [res.shape[0], res.shape[2], res.shape[3]]);
+ }
+ var conv1d = op({ conv1d_: conv1d_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the derivative of the input of a 2D convolution.
+ *
+ * @param xShape The shape of the input: [batch, height, width, inDepth].
+ * If length of 3, batch of 1 is assumed.
+ * @param dy The derivative of the output, of rank 4 or rank 3 of shape
+ * `[batch, outHeight, outWidth, outDepth]`. If rank 3, batch of 1 is
+ * assumed.
+ * @param filter The filter, rank 4, of shape
+ * `[filterHeight, filterWidth, inDepth, outDepth]`.
+ * @param strides The strides of the convolution: `[strideHeight,
+ * strideWidth]`.
+ * @param pad The type of padding algorithm used:
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1x1.
+ * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
+ * "NHWC". Specify the data format of the input and output data. With the
+ * default format "NHWC", the data is stored in the order of: [batch,
+ * height, width, channels].
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ */
+ function conv2DBackpropInput_(xShape, dy, filter, strides, pad, dataFormat, dimRoundingMode) {
+ if (dataFormat === void 0) { dataFormat = 'NHWC'; }
+ assert(xShape.length === dy.rank, function () { return "Length of inShape " +
+ ("(" + xShape.length + ") and rank of dy (" + dy.rank + ") must match"); });
+ var xShape4D = xShape;
+ var dy4D = dy;
+ var reshapedTo4D = false;
+ if (dy.rank === 3) {
+ reshapedTo4D = true;
+ dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
+ xShape4D = [1, xShape[0], xShape[1], xShape[2]];
+ }
+ assert(xShape4D.length === 4, function () { return "Error in conv2dDerInput: inShape must be length 4, but got length " +
+ (xShape4D.length + "."); });
+ assert(dy4D.rank === 4, function () { return "Error in conv2dDerInput: dy must be rank 4, but got " +
+ ("rank " + dy4D.rank); });
+ assert(filter.rank === 4, function () { return "Error in conv2dDerInput: filter must be rank 4, but got " +
+ ("rank " + filter.rank); });
+ var inDepth = dataFormat === 'NHWC' ? xShape4D[3] : xShape4D[1];
+ var outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
+ assert(inDepth === filter.shape[2], function () { return "Error in conv2dDerInput: depth of input (" + inDepth + ") must " +
+ ("match input depth for filter " + filter.shape[2] + "."); });
+ assert(outDepth === filter.shape[3], function () { return "Error in conv2dDerInput: depth of output (" + outDepth + ") must " +
+ ("match output depth for filter " + filter.shape[3] + "."); });
+ checkPadOnDimRoundingMode('conv2dDerInput', pad, dimRoundingMode);
+ var inputs = { dy: dy4D, filter: filter };
+ var attrs = { strides: strides, pad: pad, dataFormat: dataFormat, dimRoundingMode: dimRoundingMode, inputShape: xShape4D };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(Conv2DBackpropInput, inputs, attrs);
+ if (reshapedTo4D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return res;
+ }
+ var conv2DBackpropInput = op({ conv2DBackpropInput_: conv2DBackpropInput_ });
+
+ /**
+ * Computes the transposed 2D convolution of an image, also known as a
+ * deconvolution.
+ *
+ * @param x The input image, of rank 4 or rank 3, of shape
+ * `[batch, height, width, inDepth]`. If rank 3, batch of 1 is assumed.
+ * @param filter The filter, rank 4, of shape
+ * `[filterHeight, filterWidth, outDepth, inDepth]`.
+ * `inDepth` must match `inDepth` in `x`.
+ * @param outputShape Output shape, of rank 4 or rank 3:
+ * `[batch, height, width, outDepth]`. If rank 3, batch of 1 is assumed.
+ * @param strides The strides of the original convolution:
+ * `[strideHeight, strideWidth]`.
+ * @param pad The type of padding algorithm used in the non-transpose version
+ * of the op.
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ *
+ * @doc {heading: 'Operations', subheading: 'Convolution'}
+ */
+ function conv2dTranspose_(x, filter, outputShape, strides, pad, dimRoundingMode) {
+ var $x = convertToTensor(x, 'x', 'conv2dTranspose');
+ var $filter = convertToTensor(filter, 'filter', 'conv2dTranspose');
+ return conv2DBackpropInput(outputShape, $x, $filter, strides, pad, 'NHWC', dimRoundingMode);
+ }
+ var conv2dTranspose = op({ conv2dTranspose_: conv2dTranspose_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes a 3D convolution over the input x.
+ *
+ * @param x The input tensor, of rank 5 or rank 4, of shape
+ * `[batch, depth, height, width, channels]`. If rank 4,
+ * batch of 1 is assumed.
+ * @param filter The filter, rank 5, of shape
+ * `[filterDepth, filterHeight, filterWidth, inChannels, outChannels]`.
+ * inChannels must match between input and filter.
+ * @param strides The strides of the convolution: `[strideDepth, strideHeight,
+ * strideWidth]`.
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dataFormat: An optional string from: "NDHWC", "NCDHW". Defaults to
+ * "NDHWC". Specify the data format of the input and output data. With the
+ * default format "NDHWC", the data is stored in the order of: [batch,
+ * depth, height, width, channels]. Only "NDHWC" is currently supported.
+ * @param dilations The dilation rates: `[dilationDepth, dilationHeight,
+ * dilationWidth]` in which we sample input values across the height
+ * and width dimensions in atrous convolution. Defaults to `[1, 1, 1]`.
+ * If `dilations` is a single number, then
+ * `dilationDepth == dilationHeight == dilationWidth`. If it is greater
+ * than 1, then all values of `strides` must be 1.
+ *
+ * @doc {heading: 'Operations', subheading: 'Convolution'}
+ */
+ function conv3d_(x, filter, strides, pad, dataFormat, dilations) {
+ if (dataFormat === void 0) { dataFormat = 'NDHWC'; }
+ if (dilations === void 0) { dilations = [1, 1, 1]; }
+ var $x = convertToTensor(x, 'x', 'conv3d');
+ var $filter = convertToTensor(filter, 'filter', 'conv3d');
+ var x5D = $x;
+ var reshapedTo5D = false;
+ if ($x.rank === 4) {
+ reshapedTo5D = true;
+ x5D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
+ }
+ assert(x5D.rank === 5, function () { return "Error in conv3d: input must be rank 5, but got rank " + x5D.rank + "."; });
+ assert($filter.rank === 5, function () { return "Error in conv3d: filter must be rank 5, but got rank " +
+ ($filter.rank + "."); });
+ assert(x5D.shape[4] === $filter.shape[3], function () { return "Error in conv3d: depth of input (" + x5D.shape[4] + ") must match " +
+ ("input depth for filter " + $filter.shape[3] + "."); });
+ assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in conv3D: Either strides or dilations must be 1. ' +
+ ("Got strides " + strides + " and dilations '" + dilations + "'"); });
+ assert(dataFormat === 'NDHWC', function () { return "Error in conv3d: got dataFormat of " + dataFormat + " but only NDHWC is currently supported."; });
+ var inputs = { x: x5D, filter: $filter };
+ var attrs = { strides: strides, pad: pad, dataFormat: dataFormat, dilations: dilations };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(Conv3D, inputs, attrs);
+ if (reshapedTo5D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
+ }
+ return res;
+ }
+ var conv3d = op({ conv3d_: conv3d_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the derivative of the input of a 3D convolution.
+ *
+ * @param xShape The shape of the input: [batch, depth, height, width,
+ * in_channels]. If length of 4, batch of 1 is assumed.
+ * @param dy The derivative of the output, of rank 5 or rank 4 of shape
+ * `[batch, outDepth, outHeight, outWidth, in_channels]`.
+ * If rank 4, batch of 1 is assumed.
+ * @param filter The filter, rank 5, of shape
+ * `[filterDepth, filterHeight, filterWidth, inDepth, outDepth]`.
+ * @param strides The strides of the convolution: `[strideDepth, strideHeight,
+ * strideWidth]`.
+ * @param pad The type of padding algorithm used:
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1x1.
+ */
+ function conv3DBackpropInput_(xShape, dy, filter, strides, pad) {
+ assert(xShape.length === dy.rank, function () { return "Length of inShape " +
+ ("(" + xShape.length + ") and rank of dy (" + dy.rank + ") must match"); });
+ var xShape5D = xShape;
+ var dy5D = dy;
+ var reshapedTo5D = false;
+ if (dy.rank === 4) {
+ reshapedTo5D = true;
+ dy5D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
+ xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]];
+ }
+ var inDepth = xShape5D[4];
+ var outDepth = dy5D.shape[4];
+ assert(xShape5D.length === 5, function () { return "Error in conv3dDerInput: inShape must be length 5, but got length " +
+ (xShape5D.length + "."); });
+ assert(dy5D.rank === 5, function () { return "Error in conv3dDerInput: dy must be rank 5, but got " +
+ ("rank " + dy5D.rank); });
+ assert(filter.rank === 5, function () { return "Error in conv3dDerInput: filter must be rank 5, but got " +
+ ("rank " + filter.rank); });
+ assert(inDepth === filter.shape[3], function () { return "Error in conv3dDerInput: depth of input (" + inDepth + ") must " +
+ ("match input depth for filter " + filter.shape[3] + "."); });
+ assert(outDepth === filter.shape[4], function () { return "Error in conv3dDerInput: depth of output (" + outDepth + ") must " +
+ ("match output depth for filter " + filter.shape[4] + "."); });
+ var inputs = { dy: dy5D, filter: filter };
+ var attrs = { pad: pad, strides: strides, inputShape: xShape5D };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(Conv3DBackpropInputV2, inputs, attrs);
+ if (reshapedTo5D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
+ }
+ return res;
+ }
+ var conv3DBackpropInput = op({ conv3DBackpropInput_: conv3DBackpropInput_ });
+
+ /**
+ * Computes the transposed 3D convolution of a volume, also known as a
+ * deconvolution.
+ *
+ * @param x The input image, of rank 5 or rank 4, of shape
+ * `[batch, depth, height, width, inDepth]`. If rank 4, batch of 1 is assumed.
+ * @param filter The filter, rank 4, of shape
+ * `[depth, filterHeight, filterWidth, outDepth, inDepth]`.
+ * `inDepth` must match `inDepth` in `x`.
+ * @param outputShape Output shape, of rank 5 or rank 4:
+ * `[batch, depth, height, width, outDepth]`. If rank 3, batch of 1 is
+ * assumed.
+ * @param strides The strides of the original convolution:
+ * `[strideDepth, strideHeight, strideWidth]`.
+ * @param pad The type of padding algorithm used in the non-transpose version
+ * of the op.
+ *
+ * @doc {heading: 'Operations', subheading: 'Convolution'}
+ */
+ function conv3dTranspose_(x, filter, outputShape, strides, pad) {
+ var $x = convertToTensor(x, 'x', 'conv3dTranspose');
+ var $filter = convertToTensor(filter, 'filter', 'conv3dTranspose');
+ return conv3DBackpropInput(outputShape, $x, $filter, strides, pad);
+ }
+ var conv3dTranspose = op({ conv3dTranspose_: conv3dTranspose_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes cos of the input `tf.Tensor` element-wise: `cos(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
+ *
+ * x.cos().print(); // or tf.cos(x)
+ * ```
+ * @param x The input tensor. Must be float32 type.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function cos_(x) {
+ var $x = convertToTensor(x, 'x', 'cos', 'float32');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Cos, inputs);
+ }
+ var cos = op({ cos_: cos_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, 1, -1, .7]);
+ *
+ * x.cosh().print(); // or tf.cosh(x)
+ * ```
+ * @param x The input tensor. Must be float32 type.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function cosh_(x) {
+ var $x = convertToTensor(x, 'x', 'cosh', 'float32');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Cosh, inputs);
+ }
+ var cosh = op({ cosh_: cosh_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the cumulative sum of a `tf.Tensor` along `axis`.
+ *
+ * ```js
+ * const x = tf.tensor([1, 2, 3, 4]);
+ * x.cumsum().print();
+ * ```
+ * ```js
+ * const x = tf.tensor([[1, 2], [3, 4]]);
+ * x.cumsum().print();
+ * ```
+ *
+ * @param x The input tensor to be summed.
+ * @param axis The axis along which to sum. Optional. Defaults to 0.
+ * @param exclusive Whether to perform exclusive cumulative sum. Optional.
+ * Defaults to false. If set to true then the sum of each tensor entry
+ * does not include its own value, but only the values previous to it
+ * along the specified axis.
+ * @param reverse Whether to sum in the opposite direction. Optional.
+ * Defaults to false.
+ *
+ * @doc {heading: 'Operations', subheading: 'Scan'}
+ */
+ function cumsum_(x, axis, exclusive, reverse) {
+ if (axis === void 0) { axis = 0; }
+ if (exclusive === void 0) { exclusive = false; }
+ if (reverse === void 0) { reverse = false; }
+ var $x = convertToTensor(x, 'x', 'cumsum');
+ var inputs = { x: $x };
+ var attrs = { axis: axis, exclusive: exclusive, reverse: reverse };
+ return ENGINE.runKernel(Cumsum, inputs, attrs);
+ }
+ var cumsum = op({ cumsum_: cumsum_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Outputs a vector with length `size` and the same dtype as `weights`.
+ *
+ * If `weights` are empty, then index `i` stores the number of times the value
+ * `i` is counted in `x`. If `weights` are non-empty, then index `i` stores the
+ * sum of the value in `weights` at each index where the corresponding value in
+ * `x` is `i`.
+ *
+ * Values in `x` outside of the range [0, size) are ignored.
+ *
+ * @param x The input int tensor, rank 1 or rank 2.
+ * @param weights The weights tensor, must have the same shape as x, or a
+ * length-0 Tensor, in which case it acts as all weights equal to 1.
+ * @param size Non-negative integer.
+ * @param binaryOutput Optional. Whether the kernel should count the appearance
+ * or number of occurrences. Defaults to False.
+ *
+ * @doc {heading: 'Operations', subheading: 'Reduction'}
+ */
+ function denseBincount_(x, weights, size, binaryOutput) {
+ if (binaryOutput === void 0) { binaryOutput = false; }
+ var $x = convertToTensor(x, 'x', 'denseBincount');
+ var $weights = convertToTensor(weights, 'weights', 'denseBincount');
+ assert($x.dtype === 'int32', function () { return "Error in denseBincount: input " +
+ ("dtype must be int32, but got " + $x.dtype); });
+ assert($x.rank <= 2, function () { return "Error in denseBincount: input must be at most rank 2, but got " +
+ ("rank " + $x.rank + "."); });
+ assert(size >= 0, function () { return "size must be non-negative, but got " + size + "."; });
+ assert($weights.size === $x.size || $weights.size === 0, function () { return "Error in denseBincount: weights must have the same shape as x or " +
+ ("0-length, but got x shape: " + $x.shape + ", weights shape: ") +
+ ($weights.shape + "."); });
+ var inputs = { x: $x, weights: $weights };
+ var attrs = { size: size, binaryOutput: binaryOutput };
+ return ENGINE.runKernel(DenseBincount, inputs, attrs);
+ }
+ var denseBincount = op({ denseBincount_: denseBincount_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Rearranges data from depth into blocks of spatial data. More specifically,
+ * this op outputs a copy of the input tensor where values from the `depth`
+ * dimension are moved in spatial blocks to the `height` and `width` dimensions.
+ * The attr `blockSize` indicates the input block size and how the data is
+ * moved.
+ *
+ * - Chunks of data of size `blockSize * blockSize` from depth are rearranged
+ * into non-overlapping blocks of size `blockSize x blockSize`
+ *
+ * - The width the output tensor is `inputWidth * blockSize`, whereas the
+ * height is `inputHeight * blockSize`
+ *
+ * - The Y, X coordinates within each block of the output image are determined
+ * by the high order component of the input channel index
+ *
+ * - The depth of the input tensor must be divisible by `blockSize *
+ * blockSize`
+ *
+ * The `dataFormat` attr specifies the layout of the input and output tensors
+ * with the following options: "NHWC": [ `batch, height, width, channels` ]
+ * "NCHW": [ `batch, channels, height, width` ]
+ *
+ * ```js
+ * const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]);
+ * const blockSize = 2;
+ * const dataFormat = "NHWC";
+ *
+ * tf.depthToSpace(x, blockSize, dataFormat).print();
+ * ```
+ *
+ * @param x The input tensor of rank 4
+ * @param blockSIze An `int` that is `>= 2`. The size of the spatial block
+ * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to "NHWC"
+ *
+ * @doc {heading: 'Tensors', subheading: 'Transformations'}
+ */
+ function depthToSpace_(x, blockSize, dataFormat) {
+ if (dataFormat === void 0) { dataFormat = 'NHWC'; }
+ var $x = convertToTensor(x, 'x', 'depthToSpace', 'float32');
+ var inputHeight = (dataFormat === 'NHWC') ? $x.shape[1] : $x.shape[2];
+ var inputWidth = (dataFormat === 'NHWC') ? $x.shape[2] : $x.shape[3];
+ var inputDepth = (dataFormat === 'NHWC') ? $x.shape[3] : $x.shape[1];
+ assert(blockSize > 1, function () { return "blockSize should be > 1 for depthToSpace, but was: " + blockSize; });
+ assert(inputHeight * blockSize >= 0, function () { return "Negative dimension size caused by overflow when multiplying\n " + inputHeight + " and " + blockSize + " for depthToSpace with input shape\n " + $x.shape; });
+ assert(inputWidth * blockSize >= 0, function () { return "Negative dimension size caused by overflow when multiplying\n " + inputWidth + " and " + blockSize + " for depthToSpace with input shape\n " + $x.shape; });
+ assert((inputDepth % (blockSize * blockSize) === 0), function () { return "Dimension size must be evenly divisible by " + blockSize * blockSize + " but is " + inputDepth + " for depthToSpace with input shape " + $x.shape; });
+ var inputs = { x: $x };
+ var attrs = { blockSize: blockSize, dataFormat: dataFormat };
+ return ENGINE.runKernel(DepthToSpace, inputs, attrs);
+ }
+ var depthToSpace = op({ depthToSpace_: depthToSpace_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Depthwise 2D convolution.
+ *
+ * Given a 4D `input` array and a `filter` array of shape
+ * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing
+ * `inChannels` convolutional filters of depth 1, this op applies a
+ * different filter to each input channel (expanding from 1 channel to
+ * `channelMultiplier` channels for each), then concatenates the results
+ * together. The output has `inChannels * channelMultiplier` channels.
+ *
+ * See
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)
+ * for more details.
+ *
+ * @param x The input tensor, of rank 4 or rank 3, of shape
+ * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
+ * assumed.
+ * @param filter The filter tensor, rank 4, of shape
+ * `[filterHeight, filterWidth, inChannels, channelMultiplier]`.
+ * @param strides The strides of the convolution: `[strideHeight,
+ * strideWidth]`. If strides is a single number, then `strideHeight ==
+ * strideWidth`.
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
+ * in which we sample input values across the height and width dimensions
+ * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
+ * number, then `dilationHeight == dilationWidth`. If it is greater than
+ * 1, then all values of `strides` must be 1.
+ * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
+ * "NHWC". Specify the data format of the input and output data. With the
+ * default format "NHWC", the data is stored in the order of: [batch,
+ * height, width, channels]. Only "NHWC" is currently supported.
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ *
+ * @doc {heading: 'Operations', subheading: 'Convolution'}
+ */
+ function depthwiseConv2d_(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
+ if (dataFormat === void 0) { dataFormat = 'NHWC'; }
+ if (dilations === void 0) { dilations = [1, 1]; }
+ var $x = convertToTensor(x, 'x', 'depthwiseConv2d', 'float32');
+ var $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d', 'float32');
+ var x4D = $x;
+ var reshapedTo4D = false;
+ if ($x.rank === 3) {
+ reshapedTo4D = true;
+ x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
+ }
+ assert(x4D.rank === 4, function () { return "Error in depthwiseConv2d: input must be rank 4, but got " +
+ ("rank " + x4D.rank + "."); });
+ assert($filter.rank === 4, function () { return "Error in depthwiseConv2d: filter must be rank 4, but got rank " +
+ ($filter.rank + "."); });
+ assert(x4D.shape[3] === $filter.shape[2], function () { return "Error in depthwiseConv2d: number of input channels " +
+ ("(" + x4D.shape[3] + ") must match the inChannels dimension in ") +
+ ("filter " + $filter.shape[2] + "."); });
+ checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode);
+ var inputs = { x: x4D, filter: $filter };
+ var attrs = { strides: strides, pad: pad, dataFormat: dataFormat, dilations: dilations, dimRoundingMode: dimRoundingMode };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(DepthwiseConv2dNative, inputs, attrs);
+ if (reshapedTo4D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return res;
+ }
+ var depthwiseConv2d$1 = op({ depthwiseConv2d_: depthwiseConv2d_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns a diagonal tensor with a given diagonal values.
+ *
+ * Given a diagonal, this operation returns a tensor with the diagonal and
+ * everything else padded with zeros.
+ *
+ * Assume the input has dimensions `[D1,..., Dk]`, then the output is a tensor
+ * of rank 2k with dimensions `[D1,..., Dk, D1,..., Dk]`
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3, 4]);
+ *
+ * tf.diag(x).print()
+ * ```
+ * ```js
+ * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 6, 8], [4, 2])
+ *
+ * tf.diag(x).print()
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function diag_(x) {
+ var $x = convertToTensor(x, 'x', 'diag');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Diag, inputs);
+ }
+ var diag = op({ diag_: diag_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the grayscale dilation over the input `x`.
+ *
+ * @param x The input tensor, rank 3 or rank 4 of shape
+ * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
+ * @param filter The filter tensor, rank 3, of shape
+ * `[filterHeight, filterWidth, depth]`.
+ * @param strides The strides of the sliding window for each dimension of the
+ * input tensor: `[strideHeight, strideWidth]`.
+ * If `strides` is a single number,
+ * then `strideHeight == strideWidth`.
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1*1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dataFormat Specify the data format of the input and output data.
+ * Defaults to 'NHWC'. Only 'NHWC' is currently supported. With the
+ * default format "NHWC", the data is stored in the order of: [batch,
+ * height, width, channels].
+ * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
+ * in which we sample input values across the height and width dimensions
+ * for atrous morphological dilation. Defaults to `[1, 1]`. If `dilations`
+ * is a single number, then `dilationHeight == dilationWidth`. If it is
+ * greater than 1, then all values of `strides` must be 1.
+ *
+ * @doc {heading: 'Operations', subheading: 'Convolution'}
+ */
+ function dilation2d_(x, filter, strides, pad, dilations, dataFormat) {
+ if (dilations === void 0) { dilations = [1, 1]; }
+ if (dataFormat === void 0) { dataFormat = 'NHWC'; }
+ var $x = convertToTensor(x, 'x', 'dilation2d');
+ var $filter = convertToTensor(filter, 'filter', 'dilation2d');
+ assert($x.rank === 3 || $x.rank === 4, function () { return "Error in dilation2d: input must be rank 3 or 4, but got rank " +
+ ($x.rank + "."); });
+ assert($filter.rank === 3, function () { return "Error in dilation2d: filter must be rank 3, but got rank " +
+ ($filter.rank + "."); });
+ assert(dataFormat === 'NHWC', function () { return "Error in dilation2d: Only NHWC is currently supported, " +
+ ("but got dataFormat of " + dataFormat); });
+ var x4D = $x;
+ var reshapedTo4D = false;
+ if ($x.rank === 3) {
+ x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
+ reshapedTo4D = true;
+ }
+ var inputs = { x: x4D, filter: $filter };
+ var attrs = { strides: strides, pad: pad, dilations: dilations };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(Dilation2D, inputs, attrs);
+ if (reshapedTo4D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return res;
+ }
+ var dilation2d = op({ dilation2d_: dilation2d_ });
+
+ /**
+ * Returns the truth value of (a == b) element-wise. Supports broadcasting.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2, 3]);
+ * const b = tf.tensor1d([2, 2, 2]);
+ *
+ * a.equal(b).print();
+ * ```
+ *
+ * @param a The first input tensor.
+ * @param b The second input tensor. Must have the same dtype as `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Logical'}
+ */
+ function equal_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'equal', 'string_or_numeric');
+ var $b = convertToTensor(b, 'b', 'equal', 'string_or_numeric');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ assertAndGetBroadcastShape($a.shape, $b.shape);
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(Equal, inputs);
+ }
+ var equal = op({ equal_: equal_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns the elements, either `a` or `b` depending on the `condition`.
+ *
+ * If the condition is true, select from `a`, otherwise select from `b`.
+ *
+ * ```js
+ * const cond = tf.tensor1d([false, false, true], 'bool');
+ * const a = tf.tensor1d([1 , 2, 3]);
+ * const b = tf.tensor1d([-1, -2, -3]);
+ *
+ * a.where(cond, b).print();
+ * ```
+ *
+ * @param condition The input condition. Must be of dtype bool.
+ * @param a If `condition` is rank 1, `a` may have a higher rank but
+ * its first dimension must match the size of `condition`.
+ * @param b A tensor with the same dtype as `a` and with shape that is
+ * compatible with `a`.
+ * @return A tensor with same dtype as `a` and `b`, and shape that is
+ * broadcastable from `a` and `b`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Logical'}
+ */
+ function where_(condition, a, b) {
+ var $a = convertToTensor(a, 'a', 'where');
+ var $b = convertToTensor(b, 'b', 'where');
+ var $condition = convertToTensor(condition, 'condition', 'where', 'bool');
+ // TODO: move this logic to forward function when the broadcastTo op is
+ // implemented in WASM.
+ // Find the broadcastable shape for $condition, $a, and $b.
+ var broadcastShape = assertAndGetBroadcastShape(assertAndGetBroadcastShape($condition.shape, $a.shape), $b.shape);
+ var $broadcastedCondition = broadcastTo($condition, broadcastShape);
+ var $broadcastedA = broadcastTo($a, broadcastShape);
+ var $broadcastedB = broadcastTo($b, broadcastShape);
+ var inputs = {
+ condition: $broadcastedCondition,
+ t: $broadcastedA,
+ e: $broadcastedB
+ };
+ return ENGINE.runKernel(Select, inputs);
+ }
+ var where = op({ where_: where_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a `tf.Tensor` with all elements set to 0 with the same shape as the
+ * given tensor.
+ *
+ * ```js
+ * const x = tf.tensor([1, 2]);
+ * tf.zerosLike(x).print();
+ * ```
+ *
+ * @param x The tensor of required shape.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function zerosLike_(x) {
+ var $x = convertToTensor(x, 'x', 'zerosLike');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(ZerosLike, inputs);
+ }
+ var zerosLike = op({ zerosLike_: zerosLike_ });
+
+ /**
+ * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. Return 0
+ * if denominator is 0.
+ *
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 4, 9, 16]);
+ * const b = tf.tensor1d([1, 2, 3, 4]);
+ * const c = tf.tensor1d([0, 0, 0, 0]);
+ *
+ * a.divNoNan(b).print(); // or tf.divNoNan(a, b)
+ * a.divNoNan(c).print(); // or tf.divNoNan(a, c)
+ * ```
+ *
+ * ```js
+ * // Broadcast div a with b.
+ * const a = tf.tensor1d([2, 4, 6, 8]);
+ * const b = tf.scalar(2);
+ * const c = tf.scalar(0);
+ *
+ * a.divNoNan(b).print(); // or tf.divNoNan(a, b)
+ * a.divNoNan(c).print(); // or tf.divNoNan(a, c)
+ * ```
+ *
+ * @param a The first tensor as the numerator.
+ * @param b The second tensor as the denominator. Must have the same dtype as
+ * `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Arithmetic'}
+ */
+ function divNoNan_(a, b) {
+ var _a;
+ // TODO: Make this into its own kernel.
+ var $a = convertToTensor(a, 'a', 'div');
+ var $b = convertToTensor(b, 'b', 'div');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ var divResult = div($a, $b);
+ var zeros = zerosLike(divResult);
+ var bEqualsZero = equal($b, zeros);
+ return where(bEqualsZero, zeros, divResult);
+ }
+ var divNoNan = op({ divNoNan_: divNoNan_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the dot product of two matrices and/or vectors, `t1` and `t2`.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2]);
+ * const b = tf.tensor2d([[1, 2], [3, 4]]);
+ * const c = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
+ *
+ * a.dot(b).print(); // or tf.dot(a, b)
+ * b.dot(a).print();
+ * b.dot(c).print();
+ * ```
+ * @param t1 The first tensor in the dot operation.
+ * @param t2 The second tensor in the dot operation.
+ *
+ * @doc {heading: 'Operations', subheading: 'Matrices'}
+ */
+ function dot_(t1, t2) {
+ var $t1 = convertToTensor(t1, 't1', 'dot');
+ var $t2 = convertToTensor(t2, 't2', 'dot');
+ assert(($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2), function () { return "Error in dot: inputs must all be rank 1 or 2, but got ranks " +
+ ($t1.rank + " and " + $t2.rank + "."); });
+ var t1Inner = ($t1.rank === 1 ? $t1.size : $t1.shape[1]);
+ var t2Inner = ($t2.rank === 1 ? $t2.size : $t2.shape[0]);
+ assert(t1Inner === t2Inner, function () { return "Error in dot: inner dimensions of inputs must match, but got " +
+ (t1Inner + " and " + t2Inner + "."); });
+ if ($t1.rank === 1 && $t2.rank === 1) {
+ var t12D = reshape($t1, [1, -1]);
+ var t22D = reshape($t2, [-1, 1]);
+ var t1t2 = matMul$1(t12D, t22D);
+ return reshape(t1t2, []);
+ }
+ else if ($t1.rank === 1 && $t2.rank === 2) {
+ var t12D = reshape($t1, [1, -1]);
+ var t22D = reshape($t2, [$t2.shape[0], $t2.shape[1]]);
+ var t1t2 = matMul$1(t12D, t22D);
+ return reshape(t1t2, [t1t2.size]);
+ }
+ else if ($t1.rank === 2 && $t2.rank === 1) {
+ var t22D = reshape($t2, [-1, 1]);
+ var t1t2 = matMul$1($t1, t22D);
+ return reshape(t1t2, [t1t2.size]);
+ }
+ else {
+ var t22D = reshape($t2, [$t2.shape[0], $t2.shape[1]]);
+ var t1t2 = matMul$1($t1, t22D);
+ return t1t2;
+ }
+ }
+ var dot = op({ dot_: dot_ });
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Tensor contraction over specified indices and outer product.
+ *
+ * `einsum` allows defining Tensors by defining their element-wise computation.
+ * This computation is based on
+ * [Einstein summation](https://en.wikipedia.org/wiki/Einstein_notation).
+ *
+ * Some special cases include:
+ *
+ * Matrix multiplication:
+ * ```js
+ * const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
+ * const y = tf.tensor2d([[0, 1], [2, 3], [4, 5]]);
+ * x.print();
+ * y.print();
+ * tf.einsum('ij,jk->ik', x, y).print();
+ * ```
+ *
+ * Dot product:
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3]);
+ * const y = tf.tensor1d([0, 1, 2]);
+ * x.print();
+ * y.print();
+ * tf.einsum('i,i->', x, y).print();
+ * ```
+ *
+ * Batch dot product:
+ * ```js
+ * const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
+ * const y = tf.tensor2d([[0, 1, 2], [3, 4, 5]]);
+ * x.print();
+ * y.print();
+ * tf.einsum('bi,bi->b', x, y).print();
+ * ```
+ *
+ * Outer prouduct:
+ * ```js
+ * const x = tf.tensor1d([1, 3, 5]);
+ * const y = tf.tensor1d([2, 4, 6]);
+ * x.print();
+ * y.print();
+ * tf.einsum('i,j->ij', x, y).print();
+ * ```
+ *
+ * Matrix transpose:
+ * ```js
+ * const x = tf.tensor2d([[1, 2], [3, 4]]);
+ * x.print();
+ * tf.einsum('ij->ji', x).print();
+ * ```
+ *
+ * Batch matrix transpose:
+ * ```js
+ * const x = tf.tensor3d([[[1, 2], [3, 4]], [[-1, -2], [-3, -4]]]);
+ * x.print();
+ * tf.einsum('bij->bji', x).print();
+ * ```
+ *
+ * Limitations:
+ *
+ * This implementation of einsum has the following limitations:
+ *
+ * - Does not support >2 input tensors.
+ * - Does not support duplicate axes for any given input tensor. E.g., equation
+ * 'ii->' is not suppoted.
+ * - The `...` notation is not supported.
+ *
+ * @param equation a string describing the contraction, in the same format as
+ * [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).
+ * @param tensors the input(s) to contract (each one a Tensor), whose shapes
+ * should be consistent with equation.
+ * @returns The output tensor.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Matrices'}
+ */
+ function einsum_(equation) {
+ var tensors = [];
+ for (var _i = 1; _i < arguments.length; _i++) {
+ tensors[_i - 1] = arguments[_i];
+ }
+ var $tensors = tensors.map(function (t, i) { return convertToTensor(t, "tensors" + i, 'einsum'); });
+ var attrs = { equation: equation };
+ return ENGINE.runKernel(Einsum, $tensors, attrs);
+ }
+ var einsum = op({ einsum_: einsum_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes exponential linear element-wise: `x > 0 ? x : (e ^ x) - 1`.
+ *
+ * ```js
+ * const x = tf.tensor1d([-1, 1, -3, 2]);
+ *
+ * x.elu().print(); // or tf.elu(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function elu_(x) {
+ var $x = convertToTensor(x, 'x', 'elu', 'float32');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Elu, inputs);
+ }
+ var elu = op({ elu_: elu_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes gause error function of the input `tf.Tensor` element-wise:
+ * `erf(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, .1, -.1, .7]);
+ *
+ * x.erf().print(); // or tf.erf(x);
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function erf_(x) {
+ var $x = convertToTensor(x, 'x', 'erf');
+ assert($x.dtype === 'int32' || $x.dtype === 'float32', function () { return 'Input dtype must be `int32` or `float32`.'; });
+ if ($x.dtype === 'int32') {
+ $x = cast($x, 'float32');
+ }
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Erf, inputs);
+ }
+ var erf = op({ erf_: erf_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes exponential of the input `tf.Tensor` element-wise. `e ^ x`
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, -3]);
+ *
+ * x.exp().print(); // or tf.exp(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function exp_(x) {
+ var $x = convertToTensor(x, 'x', 'exp');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Exp, inputs);
+ }
+ var exp = op({ exp_: exp_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension
+ * into the tensor's shape.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3, 4]);
+ * const axis = 1;
+ * x.expandDims(axis).print();
+ * ```
+ *
+ * @param x The input tensor whose dimensions to be expanded.
+ * @param axis The dimension index at which to insert shape of `1`. Defaults
+ * to 0 (the first dimension).
+ *
+ * @doc {heading: 'Tensors', subheading: 'Transformations'}
+ */
+ function expandDims_(x, axis) {
+ if (axis === void 0) { axis = 0; }
+ var $x = convertToTensor(x, 'x', 'expandDims', 'string_or_numeric');
+ assert(axis <= $x.rank, function () { return 'Axis must be <= rank of the tensor'; });
+ var inputs = { input: $x };
+ var attrs = { dim: axis };
+ return ENGINE.runKernel(ExpandDims, inputs, attrs);
+ }
+ var expandDims = op({ expandDims_: expandDims_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes exponential of the input `tf.Tensor` minus one element-wise.
+ * `e ^ x - 1`
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, -3]);
+ *
+ * x.expm1().print(); // or tf.expm1(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function expm1_(x) {
+ var $x = convertToTensor(x, 'x', 'expm1');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Expm1, inputs);
+ }
+ var expm1 = op({ expm1_: expm1_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Construct a tensor by repeating it the number of times given by reps.
+ *
+ * This operation creates a new tensor by replicating `input` `reps`
+ * times. The output tensor's i'th dimension has `input.shape[i] *
+ * reps[i]` elements, and the values of `input` are replicated
+ * `reps[i]` times along the i'th dimension. For example, tiling
+ * `[a, b, c, d]` by `[2]` produces `[a, b, c, d, a, b, c, d]`.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2]);
+ *
+ * a.tile([2]).print(); // or a.tile([2])
+ * ```
+ *
+ * ```js
+ * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+ *
+ * a.tile([1, 2]).print(); // or a.tile([1, 2])
+ * ```
+ * @param x The tensor to tile.
+ * @param reps Determines the number of replications per dimension.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
+ */
+ function tile_(x, reps) {
+ var $x = convertToTensor(x, 'x', 'tile', 'string_or_numeric');
+ assert($x.rank === reps.length, function () { return "Error in transpose: rank of input " + $x.rank + " " +
+ ("must match length of reps " + reps + "."); });
+ var inputs = { x: $x };
+ var attrs = { reps: reps };
+ return ENGINE.runKernel(Tile, inputs, attrs);
+ }
+ var tile = op({ tile_: tile_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Create an identity matrix.
+ *
+ * @param numRows Number of rows.
+ * @param numColumns Number of columns. Defaults to `numRows`.
+ * @param batchShape If provided, will add the batch shape to the beginning
+ * of the shape of the returned `tf.Tensor` by repeating the identity
+ * matrix.
+ * @param dtype Data type.
+ * @returns Identity matrix of the specified size and data type, possibly
+ * with batch repetition if `batchShape` is specified.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function eye_(numRows, numColumns, batchShape, dtype) {
+ if (dtype === void 0) { dtype = 'float32'; }
+ if (numColumns == null) {
+ numColumns = numRows;
+ }
+ var buff = buffer([numRows, numColumns], dtype);
+ var n = numRows <= numColumns ? numRows : numColumns;
+ for (var i = 0; i < n; ++i) {
+ buff.set(1, i, i);
+ }
+ var out = reshape(buff.toTensor(), [numRows, numColumns]);
+ if (batchShape == null) {
+ return out;
+ }
+ else {
+ if (batchShape.length === 1) {
+ return tile(expandDims(out, 0), [batchShape[0], 1, 1]);
+ }
+ else if (batchShape.length === 2) {
+ // tslint:disable-next-line:no-unnecessary-type-assertion
+ return tile(expandDims(expandDims(out, 0), 0), [batchShape[0], batchShape[1], 1, 1]);
+ }
+ else if (batchShape.length === 3) {
+ // tslint:disable-next-line:no-unnecessary-type-assertion
+ return tile(expandDims(expandDims(expandDims(out, 0), 0), 0), [
+ batchShape[0], batchShape[1], batchShape[2], 1, 1
+ ]);
+ }
+ else {
+ throw new Error("eye() currently supports only 1D and 2D " +
+ (
+ // tslint:disable-next-line:no-any
+ "batchShapes, but received " + batchShape.length + "D."));
+ }
+ }
+ }
+ var eye = op({ eye_: eye_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a `tf.Tensor` filled with a scalar value.
+ *
+ * ```js
+ * tf.fill([2, 2], 4).print();
+ * ```
+ *
+ * @param shape An array of integers defining the output tensor shape.
+ * @param value The scalar value to fill the tensor with.
+ * @param dtype The type of an element in the resulting tensor. Defaults to
+ * 'float'.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function fill(shape, value, dtype) {
+ var attrs = { shape: shape, value: value, dtype: dtype };
+ return ENGINE.runKernel(Fill, {}, attrs);
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes floor of input `tf.Tensor` element-wise: `floor(x)`.
+ *
+ * ```js
+ * const x = tf.tensor1d([.6, 1.1, -3.3]);
+ *
+ * x.floor().print(); // or tf.floor(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function floor_(x) {
+ var $x = convertToTensor(x, 'x', 'floor', 'float32');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Floor, inputs);
+ }
+ var floor = op({ floor_: floor_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Gather slices from tensor `x`'s axis `axis` according to `indices`.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3, 4]);
+ * const indices = tf.tensor1d([1, 3, 3], 'int32');
+ *
+ * x.gather(indices).print();
+ * ```
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+ * const indices = tf.tensor1d([1, 1, 0], 'int32');
+ *
+ * x.gather(indices).print();
+ * ```
+ * @param x The input tensor whose slices to be gathered.
+ * @param indices The indices of the values to extract.
+ * @param axis The axis over which to select values. Defaults to 0.
+ * @param batchDims Optional. The number of batch dimensions. It must be less
+ * than or equal to rank(indices). Defaults to 0.
+ * The output tensor will have shape of
+ * `x.shape[:axis] + indices.shape[batchDims:] + x.shape[axis + 1:]`
+ *
+ * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
+ */
+ function gather_(x, indices, axis, batchDims) {
+ if (axis === void 0) { axis = 0; }
+ if (batchDims === void 0) { batchDims = 0; }
+ var $x = convertToTensor(x, 'x', 'gather');
+ var $indices = convertToTensor(indices, 'indices', 'gather', 'int32');
+ var inputs = { x: $x, indices: $indices };
+ var attrs = { axis: axis, batchDims: batchDims };
+ return ENGINE.runKernel(GatherV2, inputs, attrs);
+ }
+ var gather = op({ gather_: gather_ });
+
+ /**
+ * Returns the truth value of (a > b) element-wise. Supports broadcasting.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2, 3]);
+ * const b = tf.tensor1d([2, 2, 2]);
+ *
+ * a.greater(b).print();
+ * ```
+ *
+ * @param a The first input tensor.
+ * @param b The second input tensor. Must have the same dtype as `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Logical'}
+ */
+ function greater_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'greater', 'string_or_numeric');
+ var $b = convertToTensor(b, 'b', 'greater', 'string_or_numeric');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ assertAndGetBroadcastShape($a.shape, $b.shape);
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(Greater, inputs);
+ }
+ var greater = op({ greater_: greater_ });
+
+ /**
+ * Returns the truth value of (a >= b) element-wise. Supports broadcasting.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2, 3]);
+ * const b = tf.tensor1d([2, 2, 2]);
+ *
+ * a.greaterEqual(b).print();
+ * ```
+ *
+ * @param a The first input tensor.
+ * @param b The second input tensor. Must have the same dtype as `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Logical'}
+ */
+ function greaterEqual_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'greaterEqual', 'string_or_numeric');
+ var $b = convertToTensor(b, 'b', 'greaterEqual', 'string_or_numeric');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ assertAndGetBroadcastShape($a.shape, $b.shape);
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(GreaterEqual, inputs);
+ }
+ var greaterEqual = op({ greaterEqual_: greaterEqual_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns the imaginary part of a complex (or real) tensor.
+ *
+ * Given a tensor input, this operation returns a tensor of type float that is
+ * the imaginary part of each element in input considered as a complex number.
+ * If input is real, a tensor of all zeros is returned.
+ *
+ * ```js
+ * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
+ * tf.imag(x).print();
+ * ```
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function imag_(input) {
+ var $input = convertToTensor(input, 'input', 'imag');
+ var inputs = { input: $input };
+ return ENGINE.runKernel(Imag, inputs);
+ }
+ var imag = op({ imag_: imag_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns which elements of x are finite.
+ *
+ * ```js
+ * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
+ *
+ * x.isFinite().print(); // or tf.isNaN(x)
+ * ```
+ * @param x The input Tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function isFinite_(x) {
+ var $x = convertToTensor(x, 'x', 'isFinite');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(IsFinite, inputs);
+ }
+ var isFinite$1 = op({ isFinite_: isFinite_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns which elements of x are Infinity or -Infinity.
+ *
+ * ```js
+ * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
+ *
+ * x.isInf().print(); // or tf.isNaN(x)
+ * ```
+ * @param x The input Tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function isInf_(x) {
+ var $x = convertToTensor(x, 'x', 'isInf');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(IsInf, inputs);
+ }
+ var isInf = op({ isInf_: isInf_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * RReturns which elements of x are NaN.
+ *
+ * ```js
+ * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
+ *
+ * x.isNaN().print(); // or tf.isNaN(x)
+ * ```
+ * @param x The input Tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function isNaN_(x) {
+ var $x = convertToTensor(x, 'x', 'isNaN');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(IsNan, inputs);
+ }
+ var isNaN$1 = op({ isNaN_: isNaN_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes leaky rectified linear element-wise.
+ *
+ * See
+ * [http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf](
+ * http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf)
+ *
+ * ```js
+ * const x = tf.tensor1d([-1, 2, -3, 4]);
+ *
+ * x.leakyRelu(0.1).print(); // or tf.leakyRelu(x, 0.1)
+ * ```
+ * @param x The input tensor.
+ * @param alpha The scaling factor for negative values, defaults to 0.2.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function leakyRelu_(x, alpha) {
+ if (alpha === void 0) { alpha = 0.2; }
+ var $x = convertToTensor(x, 'x', 'leakyRelu');
+ var inputs = { x: $x };
+ var attrs = { alpha: alpha };
+ return ENGINE.runKernel(LeakyRelu, inputs, attrs);
+ }
+ var leakyRelu = op({ leakyRelu_: leakyRelu_ });
+
+ /**
+ * Returns the truth value of (a < b) element-wise. Supports broadcasting.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2, 3]);
+ * const b = tf.tensor1d([2, 2, 2]);
+ *
+ * a.less(b).print();
+ * ```
+ * @param a The first input tensor.
+ * @param b The second input tensor. Must have the same dtype as `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Logical'}
+ */
+ function less_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'less', 'string_or_numeric');
+ var $b = convertToTensor(b, 'b', 'less', 'string_or_numeric');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ assertAndGetBroadcastShape($a.shape, $b.shape);
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(Less, inputs);
+ }
+ var less = op({ less_: less_ });
+
+ /**
+ * Returns the truth value of (a <= b) element-wise. Supports broadcasting.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2, 3]);
+ * const b = tf.tensor1d([2, 2, 2]);
+ *
+ * a.lessEqual(b).print();
+ * ```
+ *
+ * @param a The first input tensor.
+ * @param b The second input tensor. Must have the same dtype as `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Logical'}
+ */
+ function lessEqual_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'lessEqual', 'string_or_numeric');
+ var $b = convertToTensor(b, 'b', 'lessEqual', 'string_or_numeric');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ assertAndGetBroadcastShape($a.shape, $b.shape);
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(LessEqual, inputs);
+ }
+ var lessEqual = op({ lessEqual_: lessEqual_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Return an evenly spaced sequence of numbers over the given interval.
+ *
+ * ```js
+ * tf.linspace(0, 9, 10).print();
+ * ```
+ * @param start The start value of the sequence.
+ * @param stop The end value of the sequence.
+ * @param num The number of values to generate.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function linspace(start, stop, num) {
+ if (num <= 0) {
+ throw new Error('The number of values should be positive.');
+ }
+ var attrs = { start: start, stop: stop, num: num };
+ return ENGINE.runKernel(LinSpace, {}, attrs);
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Normalizes the activation of a local neighborhood across or within
+ * channels.
+ *
+ * @param x The input tensor. The 4-D input tensor is treated as a 3-D array
+ * of 1D vectors (along the last dimension), and each vector is
+ * normalized independently.
+ * @param depthRadius The number of adjacent channels in the 1D normalization
+ * window.
+ * @param bias A constant bias term for the basis.
+ * @param alpha A scale factor, usually positive.
+ * @param beta An exponent.
+ *
+ * @doc {heading: 'Operations', subheading: 'Normalization'}
+ */
+ function localResponseNormalization_(x, depthRadius, bias, alpha, beta) {
+ if (depthRadius === void 0) { depthRadius = 5; }
+ if (bias === void 0) { bias = 1; }
+ if (alpha === void 0) { alpha = 1; }
+ if (beta === void 0) { beta = 0.5; }
+ var $x = convertToTensor(x, 'x', 'localResponseNormalization');
+ assert($x.rank === 4 || $x.rank === 3, function () { return "Error in localResponseNormalization: x must be rank 3 or 4 but got\n rank " + $x.rank + "."; });
+ assert(isInt(depthRadius), function () { return "Error in localResponseNormalization: depthRadius must be an " +
+ ("integer but got depthRadius " + depthRadius + "."); });
+ var x4D = $x;
+ var reshapedTo4D = false;
+ if ($x.rank === 3) {
+ reshapedTo4D = true;
+ x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
+ }
+ var inputs = { x: x4D };
+ var attrs = { depthRadius: depthRadius, bias: bias, alpha: alpha, beta: beta };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(LRN, inputs, attrs);
+ if (reshapedTo4D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ else {
+ return res;
+ }
+ }
+ var localResponseNormalization = op({ localResponseNormalization_: localResponseNormalization_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, Math.E]);
+ *
+ * x.log().print(); // or tf.log(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function log_(x) {
+ var $x = convertToTensor(x, 'x', 'log', 'float32');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Log, inputs);
+ }
+ var log = op({ log_: log_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes natural logarithm of the input `tf.Tensor` plus one
+ * element-wise: `ln(1 + x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, Math.E - 1]);
+ *
+ * x.log1p().print(); // or tf.log1p(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function log1p_(x) {
+ var $x = convertToTensor(x, 'x', 'log1p');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Log1p, inputs);
+ }
+ var log1p = op({ log1p_: log1p_ });
+
+ /**
+ * Provided `f(x)`, returns another function `g(x, dy?)`, which gives the
+ * gradient of `f(x)` with respect to `x`.
+ *
+ * If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to
+ * `x` is computed instead. `f(x)` must take a single tensor `x` and return a
+ * single tensor `y`. If `f()` takes multiple inputs, use `tf.grads` instead.
+ *
+ * ```js
+ * // f(x) = x ^ 2
+ * const f = x => x.square();
+ * // f'(x) = 2x
+ * const g = tf.grad(f);
+ *
+ * const x = tf.tensor1d([2, 3]);
+ * g(x).print();
+ * ```
+ *
+ * ```js
+ * // f(x) = x ^ 3
+ * const f = x => x.pow(tf.scalar(3, 'int32'));
+ * // f'(x) = 3x ^ 2
+ * const g = tf.grad(f);
+ * // f''(x) = 6x
+ * const gg = tf.grad(g);
+ *
+ * const x = tf.tensor1d([2, 3]);
+ * gg(x).print();
+ * ```
+ *
+ * @param f The function f(x), to compute gradient for.
+ *
+ * @doc {heading: 'Training', subheading: 'Gradients'}
+ */
+ function grad(f) {
+ assert(isFunction(f), function () { return 'The f passed in grad(f) must be a function'; });
+ return function (x, dy) {
+ // x can be of any dtype, thus null as the last argument.
+ var $x = convertToTensor(x, 'x', 'tf.grad', 'string_or_numeric');
+ var $dy = (dy != null) ? convertToTensor(dy, 'dy', 'tf.grad') : null;
+ return ENGINE.tidy(function () {
+ var _a = ENGINE.gradients(function () { return f($x); }, [$x], $dy), value = _a.value, grads = _a.grads;
+ if ($dy != null) {
+ assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grad(f)(x, dy) must match the shape ' +
+ 'returned by f(x)');
+ }
+ checkGrads(grads);
+ return grads[0];
+ });
+ };
+ }
+ /**
+ * Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`,
+ * which gives an array of gradients of `f()` with respect to each input
+ * [`x1`,`x2`,...].
+ *
+ * If `dy` is passed when calling `g()`, the gradient of
+ * `f(x1,...).mul(dy).sum()` with respect to each input is computed instead.
+ * The provided `f` must take one or more tensors and return a single tensor
+ * `y`. If `f()` takes a single input, we recommend using `tf.grad` instead.
+ *
+ * ```js
+ * // f(a, b) = a * b
+ * const f = (a, b) => a.mul(b);
+ * // df / da = b, df / db = a
+ * const g = tf.grads(f);
+ *
+ * const a = tf.tensor1d([2, 3]);
+ * const b = tf.tensor1d([-2, -3]);
+ * const [da, db] = g([a, b]);
+ * console.log('da');
+ * da.print();
+ * console.log('db');
+ * db.print();
+ * ```
+ *
+ * @param f The function `f(x1, x2,...)` to compute gradients for.
+ *
+ * @doc {heading: 'Training', subheading: 'Gradients'}
+ */
+ function grads(f) {
+ assert(isFunction(f), function () { return 'The f passed in grads(f) must be a function'; });
+ return function (args, dy) {
+ assert(Array.isArray(args), function () { return 'The args passed in grads(f)(args) must be an array ' +
+ 'of `Tensor`s or `TensorLike`s'; });
+ // args can be of any dtype, thus null as the last argument.
+ var $args = convertToTensorArray(args, 'args', 'tf.grads', 'string_or_numeric');
+ var $dy = (dy != null) ? convertToTensor(dy, 'dy', 'tf.grads') : null;
+ return ENGINE.tidy(function () {
+ var _a = ENGINE.gradients(function () { return f.apply(void 0, __spread($args)); }, $args, $dy), value = _a.value, grads = _a.grads;
+ if ($dy != null) {
+ assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grads(f)([x1,...], dy) must ' +
+ 'match the shape returned by f([x1,...])');
+ }
+ checkGrads(grads);
+ return grads;
+ });
+ };
+ }
+ /**
+ * Like `tf.grad`, but also returns the value of `f()`. Useful when `f()`
+ * returns a metric you want to show.
+ *
+ * The result is a rich object with the following properties:
+ * - grad: The gradient of `f(x)` w.r.t `x` (result of `tf.grad`).
+ * - value: The value returned by `f(x)`.
+ *
+ * ```js
+ * // f(x) = x ^ 2
+ * const f = x => x.square();
+ * // f'(x) = 2x
+ * const g = tf.valueAndGrad(f);
+ *
+ * const x = tf.tensor1d([2, 3]);
+ * const {value, grad} = g(x);
+ *
+ * console.log('value');
+ * value.print();
+ * console.log('grad');
+ * grad.print();
+ * ```
+ *
+ * @doc {heading: 'Training', subheading: 'Gradients'}
+ */
+ function valueAndGrad(f) {
+ assert(isFunction(f), function () { return 'The f passed in valueAndGrad(f) must be a function'; });
+ return function (x, dy) {
+ assert(x instanceof Tensor, function () { return 'The x passed in valueAndGrad(f)(x) must be a tensor'; });
+ assert(dy == null || dy instanceof Tensor, function () { return 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor'; });
+ var _a = ENGINE.gradients(function () { return f(x); }, [x], dy), grads = _a.grads, value = _a.value;
+ checkGrads(grads);
+ return { grad: grads[0], value: value };
+ };
+ }
+ /**
+ * Like `tf.grads`, but returns also the value of `f()`. Useful when `f()`
+ * returns a metric you want to show.
+ *
+ * The result is a rich object with the following properties:
+ * - grads: The gradients of `f()` w.r.t each input (result of `tf.grads`).
+ * - value: The value returned by `f(x)`.
+ *
+ * ```js
+ * // f(a, b) = a * b
+ * const f = (a, b) => a.mul(b);
+ * // df/da = b, df/db = a
+ * const g = tf.valueAndGrads(f);
+ *
+ * const a = tf.tensor1d([2, 3]);
+ * const b = tf.tensor1d([-2, -3]);
+ * const {value, grads} = g([a, b]);
+ *
+ * const [da, db] = grads;
+ *
+ * console.log('value');
+ * value.print();
+ *
+ * console.log('da');
+ * da.print();
+ * console.log('db');
+ * db.print();
+ * ```
+ *
+ * @doc {heading: 'Training', subheading: 'Gradients'}
+ */
+ function valueAndGrads(f) {
+ assert(isFunction(f), function () { return 'The f passed in valueAndGrads(f) must be a function'; });
+ return function (args, dy) {
+ assert(Array.isArray(args) && args.every(function (arg) { return arg instanceof Tensor; }), function () { return 'The args passed in valueAndGrads(f)(args) must be array of ' +
+ 'tensors'; });
+ assert(dy == null || dy instanceof Tensor, function () { return 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor'; });
+ var res = ENGINE.gradients(function () { return f.apply(void 0, __spread(args)); }, args, dy);
+ if (dy != null) {
+ assertShapesMatch(res.value.shape, dy.shape, 'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' +
+ 'match the shape returned by f([x1,...])');
+ }
+ checkGrads(res.grads);
+ return res;
+ };
+ }
+ /**
+ * Computes and returns the gradient of f(x) with respect to the list of
+ * trainable variables provided by `varList`. If no list is provided, it
+ * defaults to all trainable variables.
+ *
+ * ```js
+ * const a = tf.variable(tf.tensor1d([3, 4]));
+ * const b = tf.variable(tf.tensor1d([5, 6]));
+ * const x = tf.tensor1d([1, 2]);
+ *
+ * // f(a, b) = a * x ^ 2 + b * x
+ * const f = () => a.mul(x.square()).add(b.mul(x)).sum();
+ * // df/da = x ^ 2, df/db = x
+ * const {value, grads} = tf.variableGrads(f);
+ *
+ * Object.keys(grads).forEach(varName => grads[varName].print());
+ * ```
+ *
+ * @param f The function to execute. f() should return a scalar.
+ * @param varList The list of variables to compute the gradients with respect
+ * to. Defaults to all trainable variables.
+ * @returns An object with the following keys and values:
+ * - `value`: The value of the function `f`.
+ * - `grads`: A map from the names of the variables to the gradients.
+ * If the `varList` argument is provided explicitly and contains a subset of
+ * non-trainable variables, this map in the return value will contain keys
+ * that map the names of the non-trainable variables to `null`.
+ *
+ * @doc {heading: 'Training', subheading: 'Gradients'}
+ */
+ function variableGrads(f, varList) {
+ assert(isFunction(f), function () { return 'The f passed in variableGrads(f) must be a function'; });
+ assert(varList == null ||
+ Array.isArray(varList) && varList.every(function (v) { return v instanceof Variable; }), function () { return 'The varList passed in variableGrads(f, varList) must be an array ' +
+ 'of variables'; });
+ var specifiedVarList = varList != null;
+ if (!specifiedVarList) {
+ // Get all of the trainable variables.
+ varList = [];
+ for (var varName in ENGINE.registeredVariables) {
+ varList.push(ENGINE.registeredVariables[varName]);
+ }
+ }
+ var specifiedNonTrainable = specifiedVarList ? varList.filter(function (variable) { return !variable.trainable; }) : null;
+ // Prune non-trainable variables.
+ var originalVarCount = varList.length;
+ varList = varList.filter(function (variable) { return variable.trainable; });
+ assert(varList.length > 0, function () { return "variableGrads() expects at least one of the input variables to " +
+ ("be trainable, but none of the " + originalVarCount + " variables is ") +
+ "trainable."; });
+ var allowNoGradients = true;
+ var _a = ENGINE.gradients(f, varList, null, allowNoGradients), value = _a.value, grads = _a.grads;
+ assert(grads.some(function (g) { return g != null; }), function () { return 'Cannot find a connection between any variable and the result of ' +
+ 'the loss function y=f(x). Please make sure the operations that ' +
+ 'use variables are inside the function f passed to minimize().'; });
+ assert(value.rank === 0, function () { return "The f passed in variableGrads(f) must return a scalar, but it " +
+ ("returned a rank-" + value.rank + " tensor"); });
+ var namedGrads = {};
+ varList.forEach(function (v, i) {
+ if (grads[i] != null) {
+ namedGrads[v.name] = grads[i];
+ }
+ });
+ if (specifiedNonTrainable != null) {
+ // If varList is explicitly provided and contains non-trainable values,
+ // add them to the returned gradients with `null` values.
+ specifiedNonTrainable.forEach(function (v) { return namedGrads[v.name] = null; });
+ }
+ return { value: value, grads: namedGrads };
+ }
+ /**
+ * Overrides the gradient computation of a function `f`.
+ *
+ * Takes a function
+ * `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}`
+ * and returns another function `g(...inputs)` which takes the same inputs as
+ * `f`. When called, `g` returns `f().value`. In backward mode, custom gradients
+ * with respect to each input of `f` are computed using `f().gradFunc`.
+ *
+ * The `save` function passsed to `f` should be used for saving tensors needed
+ * in the gradient. And the `saved` passed to the `gradFunc` is a
+ * `NamedTensorMap`, which contains those saved tensor.
+ *
+ * ```js
+ * const customOp = tf.customGrad((x, save) => {
+ * // Save x to make sure it's available later for the gradient.
+ * save([x]);
+ * // Override gradient of our custom x ^ 2 op to be dy * abs(x);
+ * return {
+ * value: x.square(),
+ * // Note `saved.x` which points to the `x` we saved earlier.
+ * gradFunc: (dy, saved) => [dy.mul(saved[0].abs())]
+ * };
+ * });
+ *
+ * const x = tf.tensor1d([-1, -2, 3]);
+ * const dx = tf.grad(x => customOp(x));
+ *
+ * console.log(`f(x):`);
+ * customOp(x).print();
+ * console.log(`f'(x):`);
+ * dx(x).print();
+ * ```
+ *
+ * @param f The function to evaluate in forward mode, which should return
+ * `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc`
+ * returns the custom gradients of `f` with respect to its inputs.
+ *
+ * @doc {heading: 'Training', subheading: 'Gradients'}
+ */
+ function customGrad(f) {
+ return ENGINE.customGrad(f);
+ }
+ function checkGrads(grads) {
+ var numNullGradients = grads.filter(function (g) { return g == null; }).length;
+ if (numNullGradients > 0) {
+ throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that\n the f you passed encloses all operations that lead from x to y.");
+ }
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes `-1 * x` element-wise.
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 2, -2, 0], [2, 2]);
+ *
+ * x.neg().print(); // or tf.neg(x)
+ * ```
+ *
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function neg_(x) {
+ var $x = convertToTensor(x, 'x', 'neg');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Neg, inputs);
+ }
+ var neg = op({ neg_: neg_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, 1, -1, .7]);
+ *
+ * x.softplus().print(); // or tf.softplus(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function softplus_(x) {
+ var $x = convertToTensor(x, 'x', 'softplus');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Softplus, inputs);
+ }
+ var softplus = op({ softplus_: softplus_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes log sigmoid of the input `tf.Tensor` element-wise:
+ * `logSigmoid(x)`. For numerical stability, we use `-tf.softplus(-x)`.
+ *
+ * ```js
+ * const x = tf.tensor1d([0, 1, -1, .7]);
+ *
+ * x.logSigmoid().print(); // or tf.logSigmoid(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function logSigmoid_(x) {
+ var $x = convertToTensor(x, 'x', 'logSigmoid');
+ // Use a custom gradient to maintain previous implementation.
+ // There is no LogSigmoid kernel in TF so we can't use engine.runKernel
+ // directly
+ var customOp = customGrad(function (x) {
+ // TODO(yassogba) we can remove the chained softplus call here only
+ // after backends have modualrized softplus at which point we can call
+ // engine runKernel(..., Sotfplus, ...) directly.
+ var value = neg(softplus(neg(x)));
+ var gradFunc = function (dy) {
+ var derX = mul(dy, sigmoid(neg(x)));
+ return derX;
+ };
+ return { value: value, gradFunc: gradFunc };
+ });
+ return customOp($x);
+ }
+ var logSigmoid = op({ logSigmoid_: logSigmoid_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the maximum of elements across dimensions of a `tf.Tensor`.
+ *
+ * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
+ * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
+ * `axes`. If `keepDims` is true, the reduced dimensions are retained with
+ * length 1. If `axes` has no entries, all dimensions are reduced, and an
+ * `tf.Tensor` with a single element is returned.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3]);
+ *
+ * x.max().print(); // or tf.max(x)
+ * ```
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+ *
+ * const axis = 1;
+ * x.max(axis).print(); // or tf.max(x, axis)
+ * ```
+ *
+ * @param x The input tensor.
+ * @param axis The dimension(s) to reduce. By default it reduces
+ * all dimensions.
+ * @param keepDims If true, retains reduced dimensions with size 1.
+ *
+ * @doc {heading: 'Operations', subheading: 'Reduction'}
+ */
+ function max_(x, axis, keepDims) {
+ if (axis === void 0) { axis = null; }
+ if (keepDims === void 0) { keepDims = false; }
+ var $x = convertToTensor(x, 'x', 'max');
+ var inputs = { x: $x };
+ var attrs = { reductionIndices: axis, keepDims: keepDims };
+ return ENGINE.runKernel(Max, inputs, attrs);
+ }
+ var max = op({ max_: max_ });
+
+ /**
+ * Subtracts two `tf.Tensor`s element-wise, A - B. Supports broadcasting.
+ *
+ * ```js
+ * const a = tf.tensor1d([10, 20, 30, 40]);
+ * const b = tf.tensor1d([1, 2, 3, 4]);
+ *
+ * a.sub(b).print(); // or tf.sub(a, b)
+ * ```
+ *
+ * ```js
+ * // Broadcast subtract a with b.
+ * const a = tf.tensor1d([10, 20, 30, 40]);
+ * const b = tf.scalar(5);
+ *
+ * a.sub(b).print(); // or tf.sub(a, b)
+ * ```
+ * @param a The first `tf.Tensor` to subtract from.
+ * @param b The second `tf.Tensor` to be subtracted. Must have the same dtype as
+ * `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Arithmetic'}
+ */
+ function sub_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'sub');
+ var $b = convertToTensor(b, 'b', 'sub');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(Sub, inputs);
+ }
+ var sub = op({ sub_: sub_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the sum of elements across dimensions of a `tf.Tensor`.
+ *
+ * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
+ * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
+ * `axes`. If `keepDims` is true, the reduced dimensions are retained with
+ * length 1. If axes has no entries, all dimensions are reduced, and a
+ * `tf.Tensor` with a single element is returned.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3]);
+ *
+ * x.sum().print(); // or tf.sum(x)
+ * ```
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+ *
+ * const axis = 1;
+ * x.sum(axis).print(); // or tf.sum(x, axis)
+ * ```
+ *
+ * @param x The input tensor to compute the sum over. If the dtype is `bool`
+ * it will be converted to `int32` and the output dtype will be `int32`.
+ * @param axis The dimension(s) to reduce. By default it reduces
+ * all dimensions.
+ * @param keepDims If true, retains reduced dimensions with size 1.
+ *
+ * @doc {heading: 'Operations', subheading: 'Reduction'}
+ */
+ function sum_(x, axis, keepDims) {
+ if (axis === void 0) { axis = null; }
+ if (keepDims === void 0) { keepDims = false; }
+ var $x = convertToTensor(x, 'x', 'sum');
+ if ($x.dtype === 'bool') {
+ $x = cast($x, 'int32');
+ }
+ var inputs = { x: $x };
+ var attrs = { axis: axis, keepDims: keepDims };
+ return ENGINE.runKernel(Sum, inputs, attrs);
+ }
+ var sum = op({ sum_: sum_ });
+
+ /**
+ * Computes the log softmax.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2, 3]);
+ *
+ * a.logSoftmax().print(); // or tf.logSoftmax(a)
+ * ```
+ *
+ * ```js
+ * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
+ *
+ * a.logSoftmax().print(); // or tf.logSoftmax(a)
+ * ```
+ *
+ * @param logits The logits array.
+ * @param axis The dimension softmax would be performed on. Defaults to `-1`
+ * which indicates the last dimension.
+ *
+ * @doc {heading: 'Operations', subheading: 'Normalization'}
+ */
+ function logSoftmax_(logits, axis) {
+ if (axis === void 0) { axis = -1; }
+ var $logits = convertToTensor(logits, 'logits', 'logSoftmax');
+ if (axis === -1) {
+ axis = $logits.rank - 1;
+ }
+ if (axis !== $logits.rank - 1) {
+ throw Error('Log Softmax along a non-last dimension is not yet supported. ' +
+ ("Logits was rank " + $logits.rank + " and axis was " + axis));
+ }
+ // const forward: ForwardFunc<Tensor> = (backend, save) => {
+ // const keepDims = true;
+ // const xMax = max(logits, axis, true);
+ // const shifted = sub(logits, xMax);
+ // const value =
+ // sub(cast(shifted, 'float32'), log(sum(exp(shifted), axis,
+ // keepDims)));
+ // save([value]);
+ // return value;
+ // };
+ // Use a custom gradient for numerical stability.
+ var customOp = customGrad(function (logits, save) {
+ var keepDims = true;
+ var xMax = max(logits, axis, true);
+ var shifted = sub(logits, xMax);
+ var value = sub(cast(shifted, 'float32'), log(sum(exp(shifted), axis, keepDims)));
+ save([value]);
+ var gradFunc = function (dy, saved) {
+ var _a = __read(saved, 1), value = _a[0];
+ var keepDims = true;
+ var softmax = exp(value);
+ return sub(dy, mul(sum(dy, axis, keepDims), softmax));
+ };
+ return { value: value, gradFunc: gradFunc };
+ });
+ return customOp($logits);
+ // TODO Use Engine.runKernel when CPU/WebGL/WASM backends implement this.
+ // const inputs: LogSoftmaxInputs = {logits: $logits};
+ // const attrs: LogSoftmaxAttrs = {axis};
+ // return ENGINE.runKernel(
+ // LogSoftmax, inputs as {} as NamedTensorMap,
+ // attrs as {} as NamedAttrMap);
+ }
+ var logSoftmax = op({ logSoftmax_: logSoftmax_ });
+
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns true if the axis specifies the inner most dimensions of the
+ * array.
+ */
+ function axesAreInnerMostDims(axes, rank) {
+ for (var i = 0; i < axes.length; ++i) {
+ if (axes[axes.length - i - 1] !== rank - 1 - i) {
+ return false;
+ }
+ }
+ return true;
+ }
+ function combineLocations(outputLoc, reduceLoc, axes) {
+ var rank = outputLoc.length + reduceLoc.length;
+ var loc = [];
+ var outIdx = 0;
+ var reduceIdx = 0;
+ for (var dim = 0; dim < rank; dim++) {
+ if (axes.indexOf(dim) === -1) {
+ loc.push(outputLoc[outIdx++]);
+ }
+ else {
+ loc.push(reduceLoc[reduceIdx++]);
+ }
+ }
+ return loc;
+ }
+ function computeOutAndReduceShapes(aShape, axes) {
+ var outShape = [];
+ var rank = aShape.length;
+ for (var dim = 0; dim < rank; dim++) {
+ if (axes.indexOf(dim) === -1) {
+ outShape.push(aShape[dim]);
+ }
+ }
+ var reduceShape = axes.map(function (dim) { return aShape[dim]; });
+ return [outShape, reduceShape];
+ }
+ function expandShapeToKeepDim(shape, axes) {
+ var reduceSubShape = axes.map(function (x) { return 1; });
+ return combineLocations(shape, reduceSubShape, axes);
+ }
+ function assertAxesAreInnerMostDims(msg, axes, rank) {
+ assert(axesAreInnerMostDims(axes, rank), function () { return msg + " supports only inner-most axes for now. " +
+ ("Got axes " + axes + " and rank-" + rank + " input."); });
+ }
+ /**
+ * Returns the axes permutation to be used with `tf.transpose`, if such
+ * permutation is necessary. Otherwise it returns null. This method is used by
+ * operations that operate only on inner-most axes.
+ */
+ function getAxesPermutation(axes, rank) {
+ if (axesAreInnerMostDims(axes, rank)) {
+ return null;
+ }
+ var result = [];
+ for (var i = 0; i < rank; ++i) {
+ if (axes.indexOf(i) === -1) {
+ result.push(i);
+ }
+ }
+ axes.forEach(function (axis) { return result.push(axis); });
+ return result;
+ }
+ /** Returns the axes permutation that undoes the original permutation. */
+ function getUndoAxesPermutation(axes) {
+ return axes.map(function (axis, i) { return [i, axis]; })
+ .sort(function (a, b) { return a[1] - b[1]; })
+ .map(function (x) { return x[0]; });
+ }
+ function getInnerMostAxes(numAxes, rank) {
+ var res = [];
+ for (var i = rank - numAxes; i < rank; ++i) {
+ res.push(i);
+ }
+ return res;
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the log(sum(exp(elements across the reduction dimensions)).
+ *
+ * Reduces the input along the dimensions given in `axis`. Unless `keepDims`
+ * is true, the rank of the array is reduced by 1 for each entry in `axis`.
+ * If `keepDims` is true, the reduced dimensions are retained with length 1.
+ * If `axis` has no entries, all dimensions are reduced, and an array with a
+ * single element is returned.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3]);
+ *
+ * x.logSumExp().print(); // or tf.logSumExp(x)
+ * ```
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+ *
+ * const axis = 1;
+ * x.logSumExp(axis).print(); // or tf.logSumExp(a, axis)
+ * ```
+ * @param x The input tensor.
+ * @param axis The dimension(s) to reduce. If null (the default),
+ * reduces all dimensions.
+ * @param keepDims If true, retains reduced dimensions with length
+ * of 1. Defaults to false.
+ *
+ * @doc {heading: 'Operations', subheading: 'Reduction'}
+ */
+ function logSumExp_(x, axis, keepDims) {
+ if (axis === void 0) { axis = null; }
+ if (keepDims === void 0) { keepDims = false; }
+ var $x = convertToTensor(x, 'x', 'logSumExp');
+ var axes = parseAxisParam(axis, $x.shape);
+ var xMax = max($x, axes, true /* keepDims */);
+ var a = sub($x, xMax);
+ var b = exp(a);
+ var c = sum(b, axes);
+ var d = log(c);
+ var res = add(reshape(xMax, d.shape), d);
+ if (keepDims) {
+ var newShape = expandShapeToKeepDim(res.shape, axes);
+ return reshape(res, newShape);
+ }
+ return res;
+ }
+ var logSumExp = op({ logSumExp_: logSumExp_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns the truth value of `a AND b` element-wise. Supports broadcasting.
+ *
+ * ```js
+ * const a = tf.tensor1d([false, false, true, true], 'bool');
+ * const b = tf.tensor1d([false, true, false, true], 'bool');
+ *
+ * a.logicalAnd(b).print();
+ * ```
+ *
+ * @param a The first input tensor. Must be of dtype bool.
+ * @param b The second input tensor. Must be of dtype bool.
+ *
+ * @doc {heading: 'Operations', subheading: 'Logical'}
+ */
+ function logicalAnd_(a, b) {
+ var $a = convertToTensor(a, 'a', 'logicalAnd', 'bool');
+ var $b = convertToTensor(b, 'b', 'logicalAnd', 'bool');
+ assertAndGetBroadcastShape($a.shape, $b.shape);
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(LogicalAnd, inputs);
+ }
+ var logicalAnd = op({ logicalAnd_: logicalAnd_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns the truth value of `NOT x` element-wise.
+ *
+ * ```js
+ * const a = tf.tensor1d([false, true], 'bool');
+ *
+ * a.logicalNot().print();
+ * ```
+ *
+ * @param x The input tensor. Must be of dtype 'bool'.
+ *
+ * @doc {heading: 'Operations', subheading: 'Logical'}
+ */
+ function logicalNot_(x) {
+ var $x = convertToTensor(x, 'x', 'logicalNot', 'bool');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(LogicalNot, inputs);
+ }
+ var logicalNot = op({ logicalNot_: logicalNot_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns the truth value of `a OR b` element-wise. Supports broadcasting.
+ *
+ * ```js
+ * const a = tf.tensor1d([false, false, true, true], 'bool');
+ * const b = tf.tensor1d([false, true, false, true], 'bool');
+ *
+ * a.logicalOr(b).print();
+ * ```
+ * @param a The first input tensor. Must be of dtype bool.
+ * @param b The second input tensor. Must be of dtype bool.
+ *
+ * @doc {heading: 'Operations', subheading: 'Logical'}
+ */
+ function logicalOr_(a, b) {
+ var $a = convertToTensor(a, 'a', 'logicalOr', 'bool');
+ var $b = convertToTensor(b, 'b', 'logicalOr', 'bool');
+ assertAndGetBroadcastShape($a.shape, $b.shape);
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(LogicalOr, inputs);
+ }
+ var logicalOr = op({ logicalOr_: logicalOr_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns the truth value of `a XOR b` element-wise. Supports broadcasting.
+ *
+ * ```js
+ * const a = tf.tensor1d([false, false, true, true], 'bool');
+ * const b = tf.tensor1d([false, true, false, true], 'bool');
+ *
+ * a.logicalXor(b).print();
+ * ```
+ *
+ * @param a The first input tensor. Must be of dtype bool.
+ * @param b The second input tensor. Must be of dtype bool.
+ *
+ * @doc {heading: 'Operations', subheading: 'Logical'}
+ */
+ function logicalXor_(a, b) {
+ var $a = convertToTensor(a, 'a', 'logicalXor', 'bool');
+ var $b = convertToTensor(b, 'b', 'logicalXor', 'bool');
+ assertAndGetBroadcastShape($a.shape, $b.shape);
+ // x ^ y = (x | y) & ~(x & y)
+ return logicalAnd(logicalOr(a, b), logicalNot(logicalAnd(a, b)));
+ }
+ var logicalXor = op({ logicalXor_: logicalXor_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the 2D max pooling of an image.
+ *
+ * @param x The input tensor, of rank 4 or rank 3 of shape
+ * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
+ * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
+ * `filterSize` is a single number, then `filterHeight == filterWidth`.
+ * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
+ * `strides` is a single number, then `strideHeight == strideWidth`.
+ * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
+ * in which we sample input values across the height and width dimensions
+ * in dilated pooling. Defaults to `[1, 1]`. If `dilations` is a single
+ * number, then `dilationHeight == dilationWidth`. If it is greater than
+ * 1, then all values of `strides` must be 1.
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ */
+ function maxPool_(x, filterSize, strides, pad, dimRoundingMode) {
+ var $x = convertToTensor(x, 'x', 'maxPool');
+ var dilations = 1;
+ var x4D = $x;
+ var reshapedTo4D = false;
+ if ($x.rank === 3) {
+ reshapedTo4D = true;
+ x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
+ }
+ assert(x4D.rank === 4, function () { return "Error in maxPool: input must be rank 4 but got rank " + x4D.rank + "."; });
+ assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in maxPool: Either strides or dilations must be 1. ' +
+ ("Got strides " + strides + " and dilations '" + dilations + "'"); });
+ checkPadOnDimRoundingMode('maxPool', pad, dimRoundingMode);
+ var inputs = { x: x4D };
+ var attrs = { filterSize: filterSize, strides: strides, pad: pad, dimRoundingMode: dimRoundingMode };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(MaxPool, inputs, attrs);
+ if (reshapedTo4D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return res;
+ }
+ var maxPool = op({ maxPool_: maxPool_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the 3D max pooling.
+ *
+ * ```js
+ * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);
+ * const result = tf.maxPool3d(x, 2, 1, 'valid');
+ * result.print();
+ * ```
+ *
+ * @param x The input tensor, of rank 5 or rank 4 of shape
+ * `[batch, depth, height, width, inChannels]`.
+ * @param filterSize The filter size:
+ * `[filterDepth, filterHeight, filterWidth]`.
+ * If `filterSize` is a single number,
+ * then `filterDepth == filterHeight == filterWidth`.
+ * @param strides The strides of the pooling:
+ * `[strideDepth, strideHeight, strideWidth]`.
+ * If `strides` is a single number,
+ * then `strideDepth == strideHeight == strideWidth`.
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1*1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
+ * "NDHWC". Specify the data format of the input and output data. With the
+ * default format "NDHWC", the data is stored in the order of: [batch,
+ * depth, height, width, channels]. Only "NDHWC" is currently supported.
+ * @doc {heading: 'Operations', subheading: 'Convolution'}
+ */
+ function maxPool3d_(x, filterSize, strides, pad, dimRoundingMode, dataFormat) {
+ if (filterSize === void 0) { filterSize = [1, 1, 1]; }
+ if (dataFormat === void 0) { dataFormat = 'NDHWC'; }
+ var $x = convertToTensor(x, 'x', 'maxPool3d');
+ var x5D = $x;
+ var reshapedTo5D = false;
+ if ($x.rank === 4) {
+ reshapedTo5D = true;
+ x5D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
+ }
+ assert(x5D.rank === 5, function () { return "Error in maxPool3d: x must be rank 5 but got rank " + x5D.rank + "."; });
+ assert(dataFormat === 'NDHWC', function () { return "Error in maxPool3d: Only NDHWC is currently supported, " +
+ ("but got dataFormat of " + dataFormat); });
+ checkPadOnDimRoundingMode('maxPool3d', pad, dimRoundingMode);
+ var inputs = { x: x5D };
+ var attrs = { filterSize: filterSize, strides: strides, pad: pad, dimRoundingMode: dimRoundingMode, dataFormat: dataFormat };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(MaxPool3D, inputs, attrs);
+ if (reshapedTo5D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
+ }
+ return res;
+ }
+ var maxPool3d = op({ maxPool3d_: maxPool3d_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the 2D max pooling of an image with Argmax index.
+ * The indices in argmax are flattened, so that a maximum value at position `[b,
+ * y, x, c]` becomes flattened index: `(y * width + x) * channels + c` if
+ * include_batch_in_index is False; `((b * height + y) * width + x) * channels
+ * +c` if include_batch_in_index is True.
+ *
+ * The indices returned are always in `[0, height) x [0, width)` before
+ * flattening.
+ *
+ * @param x The input tensor, of rank 4 or rank 3 of shape
+ * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
+ * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
+ * `filterSize` is a single number, then `filterHeight == filterWidth`.
+ * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
+ * `strides` is a single number, then `strideHeight == strideWidth`.
+ * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
+ * "NDHWC". Specify the data format of the input and output data. With the
+ * default format "NDHWC", the data is stored in the order of: [batch,
+ * depth, height, width, channels]. Only "NDHWC" is currently supported.
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param includeBatchIndex Defaults to False. Whether to include batch
+ * dimension in flattened index of argmax.
+ *
+ * @doc {heading: 'Operations', subheading: 'Convolution'}
+ */
+ function maxPoolWithArgmax_(x, filterSize, strides, pad, includeBatchInIndex) {
+ if (includeBatchInIndex === void 0) { includeBatchInIndex = false; }
+ var $x = convertToTensor(x, 'x', 'maxPoolWithArgmax');
+ var inputs = { x: $x };
+ var attrs = { filterSize: filterSize, strides: strides, pad: pad, includeBatchInIndex: includeBatchInIndex };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var result = ENGINE.runKernel(MaxPoolWithArgmax, inputs, attrs);
+ return { result: result[0], indexes: result[1] };
+ }
+ var maxPoolWithArgmax = op({ maxPoolWithArgmax_: maxPoolWithArgmax_ });
+
+ /**
+ * Returns the max of a and b (`a > b ? a : b`) element-wise.
+ * Supports broadcasting.
+ *
+ * We also expose `tf.maximumStrict` which has the same signature as this op and
+ * asserts that `a` and `b` are the same shape (does not broadcast).
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 4, 3, 16]);
+ * const b = tf.tensor1d([1, 2, 9, 4]);
+ *
+ * a.maximum(b).print(); // or tf.maximum(a, b)
+ * ```
+ *
+ * ```js
+ * // Broadcast maximum a with b.
+ * const a = tf.tensor1d([2, 4, 6, 8]);
+ * const b = tf.scalar(5);
+ *
+ * a.maximum(b).print(); // or tf.maximum(a, b)
+ * ```
+ *
+ * @param a The first tensor.
+ * @param b The second tensor. Must have the same type as `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Arithmetic'}
+ */
+ function maximum_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'maximum');
+ var $b = convertToTensor(b, 'b', 'maximum');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ if ($a.dtype === 'bool') {
+ $a = cast($a, 'int32');
+ $b = cast($b, 'int32');
+ }
+ assertAndGetBroadcastShape($a.shape, $b.shape);
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(Maximum, inputs);
+ }
+ var maximum = op({ maximum_: maximum_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the mean of elements across dimensions of a `tf.Tensor`.
+ *
+ * Reduces `x` along the dimensions given in `axis`. Unless `keepDims` is
+ * true, the rank of the `tf.Tensor` is reduced by 1 for each entry in `axis`.
+ * If `keepDims` is true, the reduced dimensions are retained with length 1.
+ * If `axis` has no entries, all dimensions are reduced, and a `tf.Tensor` with
+ * a single element is returned.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3]);
+ *
+ * x.mean().print(); // or tf.mean(a)
+ * ```
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+ *
+ * const axis = 1;
+ * x.mean(axis).print(); // or tf.mean(x, axis)
+ * ```
+ *
+ * @param x The input tensor.
+ * @param axis The dimension(s) to reduce. By default it reduces
+ * all dimensions.
+ * @param keepDims If true, retains reduced dimensions with size 1.
+ *
+ * @doc {heading: 'Operations', subheading: 'Reduction'}
+ */
+ function mean_(x, axis, keepDims) {
+ if (axis === void 0) { axis = null; }
+ if (keepDims === void 0) { keepDims = false; }
+ var $x = convertToTensor(x, 'x', 'mean');
+ var inputs = { x: $x };
+ var attrs = { axis: axis, keepDims: keepDims };
+ return ENGINE.runKernel(Mean, inputs, attrs);
+ }
+ var mean = op({ mean_: mean_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a `tf.Tensor` with all elements set to 0.
+ *
+ * ```js
+ * tf.zeros([2, 2]).print();
+ * ```
+ *
+ * @param shape An array of integers defining the output tensor shape.
+ * @param dtype The type of an element in the resulting tensor. Can
+ * be 'float32', 'int32' or 'bool'. Defaults to 'float'.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function zeros(shape, dtype) {
+ if (dtype === void 0) { dtype = 'float32'; }
+ if (dtype === 'complex64') {
+ var real = zeros(shape, 'float32');
+ var imag = zeros(shape, 'float32');
+ return complex(real, imag);
+ }
+ var values = makeZerosTypedArray(sizeFromShape(shape), dtype);
+ return ENGINE.makeTensor(values, shape, dtype);
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a `tf.Tensor` with all elements set to 1.
+ *
+ * ```js
+ * tf.ones([2, 2]).print();
+ * ```
+ *
+ * @param shape An array of integers defining the output tensor shape.
+ * @param dtype The type of an element in the resulting tensor. Defaults to
+ * 'float'.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function ones(shape, dtype) {
+ if (dtype === void 0) { dtype = 'float32'; }
+ if (dtype === 'complex64') {
+ var real = ones(shape, 'float32');
+ var imag = zeros(shape, 'float32');
+ return complex(real, imag);
+ }
+ var values = makeOnesTypedArray(sizeFromShape(shape), dtype);
+ return ENGINE.makeTensor(values, shape, dtype);
+ }
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Broadcasts parameters for evaluation on an N-D grid.
+ *
+ * Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
+ * of N-D coordinate arrays for evaluating expressions on an N-D grid.
+ *
+ * Notes:
+ * `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions.
+ * When the `indexing` argument is set to 'xy' (the default), the broadcasting
+ * instructions for the first two dimensions are swapped.
+ * Examples:
+ * Calling `const [X, Y] = meshgrid(x, y)` with the tensors
+ *
+ * ```javascript
+ * const x = [1, 2, 3];
+ * const y = [4, 5, 6];
+ * const [X, Y] = tf.meshgrid(x, y);
+ * // X = [[1, 2, 3],
+ * // [1, 2, 3],
+ * // [1, 2, 3]]
+ * // Y = [[4, 4, 4],
+ * // [5, 5, 5],
+ * // [6, 6, 6]]
+ * ```
+ *
+ * @param x Tensor with rank geq 1.
+ * @param y Tensor with rank geq 1.
+ * @param indexing
+ *
+ * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
+ */
+ function meshgrid(x, y, _a) {
+ var _b = (_a === void 0 ? {} : _a).indexing, indexing = _b === void 0 ? 'xy' : _b;
+ if (indexing !== 'xy' && indexing !== 'ij') {
+ throw new TypeError(indexing + " is not a valid third argument to meshgrid");
+ }
+ if (x === undefined) {
+ return [];
+ }
+ var $x = convertToTensor(x, 'x', 'meshgrid', x instanceof Tensor ? x.dtype : 'float32');
+ if (y === undefined) {
+ return [$x];
+ }
+ var $y = convertToTensor(y, 'y', 'meshgrid', y instanceof Tensor ? y.dtype : 'float32');
+ var w = sizeFromShape($x.shape);
+ var h = sizeFromShape($y.shape);
+ if (indexing === 'xy') {
+ $x = reshape($x, [1, -1]);
+ $y = reshape($y, [-1, 1]);
+ return [
+ matMul$1(ones([h, 1], $x.dtype), $x),
+ matMul$1($y, ones([1, w], $y.dtype)),
+ ];
+ }
+ $x = reshape($x, [-1, 1]);
+ $y = reshape($y, [1, -1]);
+ return [
+ matMul$1($x, ones([1, h], $x.dtype)),
+ matMul$1(ones([w, 1], $y.dtype), $y),
+ ];
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the minimum value from the input.
+ *
+ * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
+ * is true, the rank of the array is reduced by 1 for each entry in `axes`.
+ * If `keepDims` is true, the reduced dimensions are retained with length 1.
+ * If `axes` has no entries, all dimensions are reduced, and an array with a
+ * single element is returned.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3]);
+ *
+ * x.min().print(); // or tf.min(x)
+ * ```
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+ *
+ * const axis = 1;
+ * x.min(axis).print(); // or tf.min(x, axis)
+ * ```
+ *
+ * @param x The input Tensor.
+ * @param axis The dimension(s) to reduce. By default it reduces
+ * all dimensions.
+ * @param keepDims If true, retains reduced dimensions with size 1.
+ *
+ * @doc {heading: 'Operations', subheading: 'Reduction'}
+ */
+ function min_(x, axis, keepDims) {
+ if (axis === void 0) { axis = null; }
+ if (keepDims === void 0) { keepDims = false; }
+ var $x = convertToTensor(x, 'x', 'min');
+ var inputs = { x: $x };
+ var attrs = { axis: axis, keepDims: keepDims };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ return ENGINE.runKernel(Min, inputs, attrs);
+ }
+ var min = op({ min_: min_ });
+
+ /**
+ * Returns the min of a and b (`a < b ? a : b`) element-wise.
+ * Supports broadcasting.
+ *
+ * We also expose `minimumStrict` which has the same signature as this op and
+ * asserts that `a` and `b` are the same shape (does not broadcast).
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 4, 3, 16]);
+ * const b = tf.tensor1d([1, 2, 9, 4]);
+ *
+ * a.minimum(b).print(); // or tf.minimum(a, b)
+ * ```
+ *
+ * ```js
+ * // Broadcast minimum a with b.
+ * const a = tf.tensor1d([2, 4, 6, 8]);
+ * const b = tf.scalar(5);
+ *
+ * a.minimum(b).print(); // or tf.minimum(a, b)
+ * ```
+ *
+ * @param a The first tensor.
+ * @param b The second tensor. Must have the same type as `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Arithmetic'}
+ */
+ function minimum_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'minimum');
+ var $b = convertToTensor(b, 'b', 'minimum');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ if ($a.dtype === 'bool') {
+ $a = cast($a, 'int32');
+ $b = cast($b, 'int32');
+ }
+ assertAndGetBroadcastShape($a.shape, $b.shape);
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(Minimum, inputs);
+ }
+ var minimum = op({ minimum_: minimum_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Pads a `tf.Tensor` using mirror padding.
+ *
+ * This operation implements the `REFLECT` and `SYMMETRIC` modes of pad.
+ *
+ * ```js
+ * const x = tf.range(0, 9).reshape([1, 1, 3, 3]);
+ * x.mirrorPad([[0, 0], [0, 0], [2, 2], [2, 2]], 'reflect').print();
+ * ```
+ * @param x The tensor to pad.
+ * @param paddings An array of length `R` (the rank of the tensor), where
+ * each element is a length-2 tuple of ints `[padBefore, padAfter]`,
+ * specifying how much to pad along each dimension of the tensor.
+ * In "reflect" mode, the padded regions do not include the borders,
+ * while in "symmetric" mode the padded regions do include the borders.
+ * For example, if the input is `[1, 2, 3]` and paddings is `[0, 2]`,
+ * then the output is `[1, 2, 3, 2, 1]` in "reflect" mode, and
+ * `[1, 2, 3, 3, 2]` in "symmetric" mode.
+ * If `mode` is "reflect" then both `paddings[D, 0]` and `paddings[D, 1]`
+ * must be no greater than `x.shape[D] - 1`. If mode is "symmetric"
+ * then both `paddings[D, 0]` and `paddings[D, 1]` must be no greater than
+ * `x.shape[D]`
+ * @param mode String to specify padding mode. Can be `'reflect' | 'symmetric'`
+ */
+ /** @doc {heading: 'Tensors', subheading: 'Transformations'} */
+ function mirrorPad_(x, paddings, mode) {
+ assert(mode === 'reflect' || mode === 'symmetric', function () { return "Invalid mode. Mode must be either reflect or symmetric. " +
+ ("Got " + mode + "."); });
+ var $x = convertToTensor(x, 'x', 'mirrorPad');
+ if ($x.rank === 0) {
+ throw new Error('mirrorPad(scalar) is not defined. ' +
+ 'Pass non-scalar to mirrorPad');
+ }
+ assert(paddings.length === $x.rank, function () { return "Padding doesn't match input. Must be " + $x.rank + ". " +
+ ("Got " + paddings.length + "."); });
+ var shapeOffset = mode === 'reflect' ? 1 : 0;
+ var _loop_1 = function (i) {
+ assert(paddings[i].length === 2, function () { return "Invalid number of paddings. Must be length of 2 each."; });
+ assert(paddings[i][0] >= 0 && paddings[i][0] <= $x.shape[i] - shapeOffset &&
+ paddings[i][1] >= 0 && paddings[i][1] <= $x.shape[i] - shapeOffset, function () { return "Padding in dimension " + i + " cannot be greater than or equal " +
+ ("to " + ($x.shape[i] - shapeOffset) + " or less than 0 for input of ") +
+ ("shape " + $x.shape); });
+ };
+ for (var i = 0; i < $x.rank; i++) {
+ _loop_1(i);
+ }
+ var attrs = { paddings: paddings, mode: mode };
+ var inputs = { x: $x };
+ return ENGINE.runKernel(MirrorPad, inputs, attrs);
+ }
+ var mirrorPad = op({ mirrorPad_: mirrorPad_ });
+
+ /**
+ * Returns the mod of a and b element-wise.
+ * `floor(x / y) * y + mod(x, y) = x`
+ * Supports broadcasting.
+ *
+ * We also expose `tf.modStrict` which has the same signature as this op and
+ * asserts that `a` and `b` are the same shape (does not broadcast).
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 4, 3, 16]);
+ * const b = tf.tensor1d([1, 2, 9, 4]);
+ *
+ * a.mod(b).print(); // or tf.mod(a, b)
+ * ```
+ *
+ * ```js
+ * // Broadcast a mod b.
+ * const a = tf.tensor1d([2, 4, 6, 8]);
+ * const b = tf.scalar(5);
+ *
+ * a.mod(b).print(); // or tf.mod(a, b)
+ * ```
+ *
+ * @param a The first tensor.
+ * @param b The second tensor. Must have the same type as `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Arithmetic'}
+ */
+ function mod_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'mod');
+ var $b = convertToTensor(b, 'b', 'mod');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(Mod, inputs);
+ }
+ var mod = op({ mod_: mod_ });
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes square of `x` element-wise: `x ^ 2`
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, Math.sqrt(2), -1]);
+ *
+ * x.square().print(); // or tf.square(x)
+ * ```
+ * @param x The input Tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function square_(x) {
+ var $x = convertToTensor(x, 'x', 'square');
+ var attrs = {};
+ return ENGINE.runKernel('Square', { x: $x }, attrs);
+ }
+ var square = op({ square_: square_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Calculates the mean and variance of `x`. The mean and variance are
+ * calculated by aggregating the contents of `x` across `axes`. If `x` is
+ * 1-D and `axes = [0]` this is just the mean and variance of a vector.
+ *
+ * @param x The input tensor.
+ * @param axis The dimension(s) along with to compute mean and
+ * variance. By default it reduces all dimensions.
+ * @param keepDims If true, the moments have the same dimensionality as the
+ * input.
+ * @return An object with two keys: `mean` and `variance`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Normalization'}
+ */
+ function moments_(x, axis, keepDims) {
+ if (axis === void 0) { axis = null; }
+ if (keepDims === void 0) { keepDims = false; }
+ x = convertToTensor(x, 'x', 'moments');
+ var axes = parseAxisParam(axis, x.shape);
+ var xMean = mean(x, axes, keepDims);
+ var keepDimsShape = xMean.shape;
+ if (!keepDims) {
+ keepDimsShape = expandShapeToKeepDim(xMean.shape, axes);
+ }
+ var devSquared = square(sub(cast(x, 'float32'), reshape(xMean, keepDimsShape)));
+ var variance = mean(devSquared, axes, keepDims);
+ return { mean: xMean, variance: variance };
+ }
+ var moments = op({ moments_: moments_ });
+
+ /**
+ * Computes the next states and outputs of a stack of LSTMCells.
+ *
+ * Each cell output is used as input to the next cell.
+ *
+ * Returns `[cellState, cellOutput]`.
+ *
+ * Derived from tf.contrib.rn.MultiRNNCell.
+ *
+ * @param lstmCells Array of LSTMCell functions.
+ * @param data The input to the cell.
+ * @param c Array of previous cell states.
+ * @param h Array of previous cell outputs.
+ *
+ * @doc {heading: 'Operations', subheading: 'RNN'}
+ */
+ function multiRNNCell_(lstmCells, data, c, h) {
+ var $data = convertToTensor(data, 'data', 'multiRNNCell');
+ var $c = convertToTensorArray(c, 'c', 'multiRNNCell');
+ var $h = convertToTensorArray(h, 'h', 'multiRNNCell');
+ var input = $data;
+ var newStates = [];
+ for (var i = 0; i < lstmCells.length; i++) {
+ var output = lstmCells[i](input, $c[i], $h[i]);
+ newStates.push(output[0]);
+ newStates.push(output[1]);
+ input = output[1];
+ }
+ var newC = [];
+ var newH = [];
+ for (var i = 0; i < newStates.length; i += 2) {
+ newC.push(newStates[i]);
+ newH.push(newStates[i + 1]);
+ }
+ return [newC, newH];
+ }
+ var multiRNNCell = op({ multiRNNCell_: multiRNNCell_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a `tf.Tensor` with values drawn from a multinomial distribution.
+ *
+ * ```js
+ * const probs = tf.tensor([.75, .25]);
+ * tf.multinomial(probs, 3).print();
+ * ```
+ *
+ * @param logits 1D array with unnormalized log-probabilities, or
+ * 2D array of shape `[batchSize, numOutcomes]`. See the `normalized`
+ * parameter.
+ * @param numSamples Number of samples to draw for each row slice.
+ * @param seed The seed number.
+ * @param normalized Whether the provided `logits` are normalized true
+ * probabilities (sum to 1). Defaults to false.
+ * @return 1D array of shape `[numSamples]`, or 2D array of shape
+ * `[batchSize, numSamples]`, depending on the rank of the input.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Random'}
+ */
+ function multinomial_(logits, numSamples, seed, normalized) {
+ if (normalized === void 0) { normalized = false; }
+ var $logits = convertToTensor(logits, 'logits', 'multinomial');
+ var numOutcomes = $logits.size;
+ var origRank = $logits.rank;
+ if (numOutcomes < 2) {
+ throw new Error("Error in multinomial: you need at least 2 outcomes, but got " +
+ (numOutcomes + "."));
+ }
+ if (origRank > 2) {
+ throw new Error("Rank of probabilities must be 1 or 2, but is " + origRank);
+ }
+ // TODO(lina128): Investigate correct seed behavior. The code seems not allow
+ // setting see to 0.
+ seed = seed || Math.random();
+ // The kernel only accepts (and returns) rank 2 tensors.
+ var logits2D = origRank === 1 ? reshape($logits, [1, -1]) : $logits;
+ var inputs = { logits: logits2D };
+ var attrs = { numSamples: numSamples, seed: seed, normalized: normalized };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(Multinomial, inputs, attrs);
+ // tslint:disable-next-line:no-unnecessary-type-assertion
+ return origRank === 1 ? reshape(res, [res.size]) : res;
+ }
+ var multinomial = op({ multinomial_: multinomial_ });
+
+ /**
+ * Returns the truth value of (a != b) element-wise. Supports broadcasting.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2, 3]);
+ * const b = tf.tensor1d([0, 2, 3]);
+ *
+ * a.notEqual(b).print();
+ * ```
+ * @param a The first input tensor.
+ * @param b The second input tensor. Must have the same dtype as `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Logical'}
+ */
+ function notEqual_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'notEqual', 'string_or_numeric');
+ var $b = convertToTensor(b, 'b', 'notEqual', 'string_or_numeric');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ assertAndGetBroadcastShape($a.shape, $b.shape);
+ var inputs = { a: $a, b: $b };
+ return ENGINE.runKernel(NotEqual, inputs);
+ }
+ var notEqual = op({ notEqual_: notEqual_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a `tf.Tensor` with all elements set to 1 with the same shape as the
+ * given tensor.
+ *
+ * ```js
+ * const x = tf.tensor([1, 2]);
+ * tf.onesLike(x).print();
+ * ```
+ * @param x A tensor.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function onesLike_(x) {
+ var $x = convertToTensor(x, 'x', 'onesLike');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(OnesLike, inputs);
+ }
+ var onesLike = op({ onesLike_: onesLike_ });
+
+ /**
+ * Computes the outer product of two vectors, `v1` and `v2`.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2, 3]);
+ * const b = tf.tensor1d([3, 4, 5]);
+ *
+ * tf.outerProduct(a, b).print();
+ * ```
+ * @param v1 The first vector in the outer product operation.
+ * @param v2 The second vector in the outer product operation.
+ *
+ * @doc {heading: 'Operations', subheading: 'Matrices'}
+ */
+ function outerProduct_(v1, v2) {
+ var $v1 = convertToTensor(v1, 'v1', 'outerProduct');
+ var $v2 = convertToTensor(v2, 'v2', 'outerProduct');
+ assert($v1.rank === 1 && $v2.rank === 1, function () { return "Error in outerProduct: inputs must be rank 1, but got ranks " +
+ ($v1.rank + " and " + $v2.rank + "."); });
+ var v12D = reshape($v1, [-1, 1]);
+ var v22D = reshape($v2, [1, -1]);
+ return matMul$1(v12D, v22D);
+ }
+ var outerProduct = op({ outerProduct_: outerProduct_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Pads a `tf.Tensor` with a given value and paddings.
+ *
+ * This operation implements `CONSTANT` mode. For `REFLECT` and `SYMMETRIC`,
+ * refer to `tf.mirrorPad`
+ *
+ * Also available are stricter rank-specific methods with the same signature
+ * as this method that assert that `paddings` is of given length.
+ * - `tf.pad1d`
+ * - `tf.pad2d`
+ * - `tf.pad3d`
+ * - `tf.pad4d`
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3, 4]);
+ * x.pad([[1, 2]]).print();
+ * ```
+ * @param x The tensor to pad.
+ * @param paddings An array of length `R` (the rank of the tensor), where
+ * each element is a length-2 tuple of ints `[padBefore, padAfter]`,
+ * specifying how much to pad along each dimension of the tensor.
+ * @param constantValue The pad value to use. Defaults to 0.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Transformations'}
+ */
+ function pad_(x, paddings, constantValue) {
+ if (constantValue === void 0) { constantValue = 0; }
+ var $x = convertToTensor(x, 'x', 'pad');
+ if ($x.rank === 0) {
+ throw new Error('pad(scalar) is not defined. Pass non-scalar to pad');
+ }
+ var attrs = { paddings: paddings, constantValue: constantValue };
+ var inputs = { x: $x };
+ return ENGINE.runKernel(PadV2, inputs, attrs);
+ }
+ var pad = op({ pad_: pad_ });
+
+ /**
+ * Pads a `tf.Tensor1D` with a given value and paddings. See `pad` for details.
+ */
+ function pad1d_(x, paddings, constantValue) {
+ if (constantValue === void 0) { constantValue = 0; }
+ assert(paddings.length === 2, function () { return 'Invalid number of paddings. Must be length of 2.'; });
+ return pad(x, [paddings], constantValue);
+ }
+ var pad1d = op({ pad1d_: pad1d_ });
+
+ /**
+ * Pads a `tf.Tensor2D` with a given value and paddings. See `pad` for details.
+ */
+ function pad2d_(x, paddings, constantValue) {
+ if (constantValue === void 0) { constantValue = 0; }
+ assert(paddings.length === 2 && paddings[0].length === 2 &&
+ paddings[1].length === 2, function () { return 'Invalid number of paddings. Must be length of 2 each.'; });
+ return pad(x, paddings, constantValue);
+ }
+ var pad2d = op({ pad2d_: pad2d_ });
+
+ /**
+ * Pads a `tf.Tensor3D` with a given value and paddings. See `pad` for details.
+ */
+ function pad3d_(x, paddings, constantValue) {
+ if (constantValue === void 0) { constantValue = 0; }
+ assert(paddings.length === 3 && paddings[0].length === 2 &&
+ paddings[1].length === 2 && paddings[2].length === 2, function () { return 'Invalid number of paddings. Must be length of 2 each.'; });
+ return pad(x, paddings, constantValue);
+ }
+ var pad3d = op({ pad3d_: pad3d_ });
+
+ /**
+ * Pads a `tf.Tensor4D` with a given value and paddings. See `pad` for details.
+ */
+ function pad4d_(x, paddings, constantValue) {
+ if (constantValue === void 0) { constantValue = 0; }
+ assert(paddings.length === 4 && paddings[0].length === 2 &&
+ paddings[1].length === 2 && paddings[2].length === 2 &&
+ paddings[3].length === 2, function () { return 'Invalid number of paddings. Must be length of 2 each.'; });
+ return pad(x, paddings, constantValue);
+ }
+ var pad4d = op({ pad4d_: pad4d_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * This operation divides "spatial" dimensions `[1, ..., M]` of the input into
+ * a grid of blocks of shape `blockShape`, and interleaves these blocks with
+ * the "batch" dimension (0) such that in the output, the spatial
+ * dimensions `[1, ..., M]` correspond to the position within the grid,
+ * and the batch dimension combines both the position within a spatial block
+ * and the original batch position. Prior to division into blocks,
+ * the spatial dimensions of the input are optionally zero padded
+ * according to `paddings`. See below for a precise description.
+ *
+ * ```js
+ * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
+ * const blockShape = [2, 2];
+ * const paddings = [[0, 0], [0, 0]];
+ *
+ * x.spaceToBatchND(blockShape, paddings).print();
+ * ```
+ *
+ * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
+ * remainingShape`, where spatialShape has `M` dimensions.
+ * @param blockShape A 1-D array. Must have shape `[M]`, all values must
+ * be >= 1.
+ * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >=
+ * 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad
+ * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It
+ * is required that
+ * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0`
+ *
+ * This operation is equivalent to the following steps:
+ *
+ * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input
+ * according to `paddings` to produce `padded` of shape paddedShape.
+ *
+ * 2. Reshape `padded` to `reshapedPadded` of shape:
+ * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ...,
+ * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape`
+ *
+ * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded`
+ * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ...,
+ * paddedShape[M] / blockShape[M-1]] + remainingShape`
+ *
+ * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the
+ * batch dimension, producing an output tensor of shape:
+ * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ...,
+ * paddedShape[M] / blockShape[M-1]] + remainingShape`
+ *
+ * @doc {heading: 'Tensors', subheading: 'Transformations'}
+ */
+ function spaceToBatchND_(x, blockShape, paddings) {
+ var $x = convertToTensor(x, 'x', 'spaceToBatchND');
+ assert($x.rank >= 1 + blockShape.length, function () { return "input rank " + $x.rank + " should be > than [blockShape] " + blockShape.length; });
+ assert(paddings.length === blockShape.length, function () { return "paddings.shape[0] " + paddings.length + " must be equal to [blockShape] " + blockShape.length; });
+ assert($x.shape.reduce(function (a, b, i) {
+ if (i > 0 && i <= blockShape.length) {
+ return a &&
+ ((b + paddings[i - 1][0] + paddings[i - 1][1]) %
+ blockShape[i - 1] ===
+ 0);
+ }
+ return a;
+ }, true), function () { return "input spatial dimensions " + $x.shape.slice(1) + " with paddings " + paddings.toString() + " must be divisible by blockShapes " + blockShape.toString(); });
+ var inputs = { x: $x };
+ var attrs = { blockShape: blockShape, paddings: paddings };
+ return ENGINE.runKernel(SpaceToBatchND, inputs, attrs);
+ }
+ var spaceToBatchND = op({ spaceToBatchND_: spaceToBatchND_ });
+
+ /**
+ * Performs an N-D pooling operation
+ *
+ * @param input The input tensor, of rank 4 or rank 3 of shape
+ * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
+ * @param windowShape The filter size: `[filterHeight, filterWidth]`. If
+ * `filterSize` is a single number, then `filterHeight == filterWidth`.
+ * @param poolingType The type of pooling, either 'max' or 'avg'.
+ * @param pad The type of padding algorithm:
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_guides/python/nn#Convolution](
+ * https://www.tensorflow.org/api_guides/python/nn#Convolution)
+ * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
+ * in which we sample input values across the height and width dimensions
+ * in dilated pooling. Defaults to `[1, 1]`. If `dilationRate` is a single
+ * number, then `dilationHeight == dilationWidth`. If it is greater than
+ * 1, then all values of `strides` must be 1.
+ * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
+ * `strides` is a single number, then `strideHeight == strideWidth`.
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ *
+ * @doc {heading: 'Operations', subheading: 'Convolution'}
+ */
+ function pool_(input, windowShape, poolingType, pad, dilations, strides, dimRoundingMode) {
+ if (dilations == null) {
+ dilations = [1, 1];
+ }
+ if (strides == null) {
+ strides = 1;
+ }
+ if (pad === 0) {
+ pad = 'valid';
+ }
+ var $x = convertToTensor(input, 'x', 'maxPool');
+ var x4D = $x;
+ var reshapedTo4D = false;
+ if ($x.rank === 3) {
+ reshapedTo4D = true;
+ x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
+ }
+ assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in pool: Either strides or dilations must be 1. ' +
+ ("Got strides " + strides + " and dilations '" + dilations + "'"); });
+ var convInfo = computePool2DInfo(x4D.shape, windowShape, strides, dilations, pad);
+ var dilation = [convInfo.dilationHeight, convInfo.dilationWidth];
+ // The following implementation does batchToSpace(pool(spaceToBatch(x)))
+ // whenever dilation > 1 since the TF kernels do not support dilation > 1.
+ // tslint:disable-next-line:max-line-length
+ // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L1037
+ var basePadding;
+ if (pad === 'same') {
+ basePadding = withSpaceToBatchBasePaddings([convInfo.filterHeight, convInfo.filterWidth], dilation);
+ }
+ else {
+ basePadding = [[0, 0], [0, 0]];
+ }
+ var isDilationOne = dilation[0] === 1 && dilation[1] === 1;
+ var _a = __read(requiredSpaceToBatchPaddings([convInfo.inHeight, convInfo.inWidth], dilation, basePadding), 2), adjustedPadding = _a[0], adjustedCrops = _a[1];
+ var convertedPad = isDilationOne ? pad : 'valid';
+ var convertedX = isDilationOne ? x4D : spaceToBatchND(x4D, dilation, adjustedPadding);
+ var forwardOp = poolingType === 'avg' ?
+ function () { return avgPool(convertedX, windowShape, strides, convertedPad, dimRoundingMode); } :
+ function () { return maxPool(convertedX, windowShape, strides, convertedPad, dimRoundingMode); };
+ var y = forwardOp();
+ var res = isDilationOne ? y : batchToSpaceND(y, dilation, adjustedCrops);
+ if (reshapedTo4D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return res;
+ }
+ // Helper function to compute crops and paddings for pool with dilation > 1.
+ // tslint:disable-next-line:max-line-length
+ // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/array_ops.py#L2184
+ function requiredSpaceToBatchPaddings(inputShape, blockShape, basePadding) {
+ var padStart = basePadding.map(function (b) { return b[0]; });
+ var origPadEnd = basePadding.map(function (b) { return b[1]; });
+ var fullInputShape = inputShape.concat(padStart, origPadEnd);
+ var padEndExtra = blockShape.map(function (b, i) { return (b - fullInputShape[i] % b) % b; });
+ var padEnd = origPadEnd.map(function (s, i) { return s + padEndExtra[i]; });
+ var paddings = blockShape.map(function (_, i) { return [padStart[i], padEnd[i]]; });
+ var crops = blockShape.map(function (_, i) { return [0, padEndExtra[i]]; });
+ return [paddings, crops];
+ }
+ // Helper function to compute base paddings for pool with dilation > 1.
+ // tslint:disable-next-line:max-line-length
+ // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L524
+ function withSpaceToBatchBasePaddings(filterShape, dilation) {
+ // Spatial dimensions of the filters and the upsampled filters in which we
+ // introduce (rate - 1) zeros between consecutive filter values.
+ var dilatedFilterShape = filterShape.map(function (s, i) {
+ return s + (s - 1) * (dilation[i] - 1);
+ });
+ var padExtraShape = dilatedFilterShape.map(function (s) { return s - 1; });
+ // When padding is odd, we pad more at end, following the same
+ // convention as conv2d.
+ var padExtraStart = padExtraShape.map(function (s) { return Math.floor(s / 2); });
+ var padExtraEnd = padExtraShape.map(function (s, i) { return s - padExtraStart[i]; });
+ return padExtraShape.map(function (_, i) {
+ return [padExtraStart[i], padExtraEnd[i]];
+ });
+ }
+ var pool = op({ pool_: pool_ });
+
+ /**
+ * Computes the power of one `tf.Tensor` to another. Supports broadcasting.
+ *
+ * Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for
+ * corresponding elements in x and y. The result's dtype will be the upcasted
+ * type of the `base` and `exp` dtypes.
+ *
+ * ```js
+ * const a = tf.tensor([[2, 3], [4, 5]])
+ * const b = tf.tensor([[1, 2], [3, 0]]).toInt();
+ *
+ * a.pow(b).print(); // or tf.pow(a, b)
+ * ```
+ *
+ * ```js
+ * const a = tf.tensor([[1, 2], [3, 4]])
+ * const b = tf.tensor(2).toInt();
+ *
+ * a.pow(b).print(); // or tf.pow(a, b)
+ * ```
+ * We also expose `powStrict` which has the same signature as this op and
+ * asserts that `base` and `exp` are the same shape (does not broadcast).
+ *
+ * @param base The base `tf.Tensor` to pow element-wise.
+ * @param exp The exponent `tf.Tensor` to pow element-wise.
+ *
+ * @doc {heading: 'Operations', subheading: 'Arithmetic'}
+ */
+ function pow_(base, exp) {
+ var _a;
+ var $base = convertToTensor(base, 'base', 'pow');
+ var $exp = convertToTensor(exp, 'exp', 'pow');
+ _a = __read(makeTypesMatch($base, $exp), 2), $base = _a[0], $exp = _a[1];
+ var inputs = { a: $base, b: $exp };
+ return ENGINE.runKernel(Pow, inputs);
+ }
+ var pow = op({ pow_: pow_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes leaky rectified linear element-wise with parametric alphas.
+ *
+ * `x < 0 ? alpha * x : f(x) = x`
+ *
+ * ```js
+ * const x = tf.tensor1d([-1, 2, -3, 4]);
+ * const alpha = tf.scalar(0.1);
+ *
+ * x.prelu(alpha).print(); // or tf.prelu(x, alpha)
+ * ```
+ * @param x The input tensor.
+ * @param alpha Scaling factor for negative values.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function prelu_(x, alpha) {
+ var $x = convertToTensor(x, 'x', 'prelu');
+ var $alpha = convertToTensor(alpha, 'alpha', 'prelu');
+ var inputs = { x: $x, alpha: $alpha };
+ return ENGINE.runKernel(Prelu, inputs);
+ }
+ var prelu = op({ prelu_: prelu_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the product of elements across dimensions of a `tf.Tensor`.
+ *
+ * Reduces the input along the dimensions given in `axes`. Unless `keepDims`
+ * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
+ * `axes`. If `keepDims` is true, the reduced dimensions are retained with
+ * length 1. If `axes` has no entries, all dimensions are reduced, and a
+ * `tf.Tensor` with a single element is returned.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3]);
+ *
+ * x.prod().print(); // or tf.prod(x)
+ * ```
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+ *
+ * const axis = 1;
+ * x.prod(axis).print(); // or tf.prod(x, axis)
+ * ```
+ *
+ * @param x The input tensor to compute the product over. If the dtype is `bool`
+ * it will be converted to `int32` and the output dtype will be `int32`.
+ * @param axis The dimension(s) to reduce. By default it reduces
+ * all dimensions.
+ * @param keepDims If true, retains reduced dimensions with size 1.
+ *
+ * @doc {heading: 'Operations', subheading: 'Reduction'}
+ */
+ function prod_(x, axis, keepDims) {
+ if (axis === void 0) { axis = null; }
+ if (keepDims === void 0) { keepDims = false; }
+ var $x = convertToTensor(x, 'x', 'prod');
+ if ($x.dtype === 'bool') {
+ // bool is not an allowed type for the underlying kernel.
+ $x = cast($x, 'int32');
+ }
+ var inputs = { x: $x };
+ var attrs = { axis: axis, keepDims: keepDims };
+ return ENGINE.runKernel(Prod, inputs, attrs);
+ }
+ var prod = op({ prod_: prod_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a `tf.Tensor` with values sampled from a random number generator
+ * function defined by the user.
+ *
+ * @param shape An array of integers defining the output tensor shape.
+ * @param randFunction A random number generator function which is called
+ * for each element in the output tensor.
+ * @param dtype The data type of the output tensor. Defaults to 'float32'.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Random'}
+ */
+ function rand_(shape, randFunction, dtype) {
+ var size = sizeFromShape(shape);
+ var values = null;
+ if (dtype == null || dtype === 'float32') {
+ values = new Float32Array(size);
+ }
+ else if (dtype === 'int32') {
+ values = new Int32Array(size);
+ }
+ else if (dtype === 'bool') {
+ values = new Uint8Array(size);
+ }
+ else {
+ throw new Error("Unknown data type " + dtype);
+ }
+ for (var i = 0; i < size; i++) {
+ values[i] = randFunction();
+ }
+ return ENGINE.makeTensor(values, shape, dtype);
+ }
+ var rand = op({ rand_: rand_ });
+
+ var commonjsGlobal = typeof globalThis !== 'undefined' ? globalThis : typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof self !== 'undefined' ? self : {};
+ function createCommonjsModule(fn) {
+ var module = { exports: {} };
+ return fn(module, module.exports), module.exports;
+ }
+
+ var alea = createCommonjsModule(function (module) {
+ // A port of an algorithm by Johannes Baagøe <[email protected]>, 2010
+ // http://baagoe.com/en/RandomMusings/javascript/
+ // https://github.com/nquinlan/better-random-numbers-for-javascript-mirror
+ // Original work is under MIT license -
+ // Copyright (C) 2010 by Johannes Baagøe <[email protected]>
+ //
+ // Permission is hereby granted, free of charge, to any person obtaining a copy
+ // of this software and associated documentation files (the "Software"), to deal
+ // in the Software without restriction, including without limitation the rights
+ // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ // copies of the Software, and to permit persons to whom the Software is
+ // furnished to do so, subject to the following conditions:
+ //
+ // The above copyright notice and this permission notice shall be included in
+ // all copies or substantial portions of the Software.
+ //
+ // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ // THE SOFTWARE.
+ (function (global, module, define) {
+ function Alea(seed) {
+ var me = this, mash = Mash();
+ me.next = function () {
+ var t = 2091639 * me.s0 + me.c * 2.3283064365386963e-10; // 2^-32
+ me.s0 = me.s1;
+ me.s1 = me.s2;
+ return me.s2 = t - (me.c = t | 0);
+ };
+ // Apply the seeding algorithm from Baagoe.
+ me.c = 1;
+ me.s0 = mash(' ');
+ me.s1 = mash(' ');
+ me.s2 = mash(' ');
+ me.s0 -= mash(seed);
+ if (me.s0 < 0) {
+ me.s0 += 1;
+ }
+ me.s1 -= mash(seed);
+ if (me.s1 < 0) {
+ me.s1 += 1;
+ }
+ me.s2 -= mash(seed);
+ if (me.s2 < 0) {
+ me.s2 += 1;
+ }
+ mash = null;
+ }
+ function copy(f, t) {
+ t.c = f.c;
+ t.s0 = f.s0;
+ t.s1 = f.s1;
+ t.s2 = f.s2;
+ return t;
+ }
+ function impl(seed, opts) {
+ var xg = new Alea(seed), state = opts && opts.state, prng = xg.next;
+ prng.int32 = function () { return (xg.next() * 0x100000000) | 0; };
+ prng.double = function () {
+ return prng() + (prng() * 0x200000 | 0) * 1.1102230246251565e-16; // 2^-53
+ };
+ prng.quick = prng;
+ if (state) {
+ if (typeof (state) == 'object')
+ copy(state, xg);
+ prng.state = function () { return copy(xg, {}); };
+ }
+ return prng;
+ }
+ function Mash() {
+ var n = 0xefc8249d;
+ var mash = function (data) {
+ data = String(data);
+ for (var i = 0; i < data.length; i++) {
+ n += data.charCodeAt(i);
+ var h = 0.02519603282416938 * n;
+ n = h >>> 0;
+ h -= n;
+ h *= n;
+ n = h >>> 0;
+ h -= n;
+ n += h * 0x100000000; // 2^32
+ }
+ return (n >>> 0) * 2.3283064365386963e-10; // 2^-32
+ };
+ return mash;
+ }
+ if (module && module.exports) {
+ module.exports = impl;
+ }
+ else if (define && define.amd) {
+ define(function () { return impl; });
+ }
+ else {
+ this.alea = impl;
+ }
+ })(commonjsGlobal, module, // present in node.js
+ (typeof undefined) == 'function' // present with an AMD loader
+ );
+ });
+
+ var xor128 = createCommonjsModule(function (module) {
+ // A Javascript implementaion of the "xor128" prng algorithm by
+ // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper
+ (function (global, module, define) {
+ function XorGen(seed) {
+ var me = this, strseed = '';
+ me.x = 0;
+ me.y = 0;
+ me.z = 0;
+ me.w = 0;
+ // Set up generator function.
+ me.next = function () {
+ var t = me.x ^ (me.x << 11);
+ me.x = me.y;
+ me.y = me.z;
+ me.z = me.w;
+ return me.w ^= (me.w >>> 19) ^ t ^ (t >>> 8);
+ };
+ if (seed === (seed | 0)) {
+ // Integer seed.
+ me.x = seed;
+ }
+ else {
+ // String seed.
+ strseed += seed;
+ }
+ // Mix in string seed, then discard an initial batch of 64 values.
+ for (var k = 0; k < strseed.length + 64; k++) {
+ me.x ^= strseed.charCodeAt(k) | 0;
+ me.next();
+ }
+ }
+ function copy(f, t) {
+ t.x = f.x;
+ t.y = f.y;
+ t.z = f.z;
+ t.w = f.w;
+ return t;
+ }
+ function impl(seed, opts) {
+ var xg = new XorGen(seed), state = opts && opts.state, prng = function () { return (xg.next() >>> 0) / 0x100000000; };
+ prng.double = function () {
+ do {
+ var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 0x100000000, result = (top + bot) / (1 << 21);
+ } while (result === 0);
+ return result;
+ };
+ prng.int32 = xg.next;
+ prng.quick = prng;
+ if (state) {
+ if (typeof (state) == 'object')
+ copy(state, xg);
+ prng.state = function () { return copy(xg, {}); };
+ }
+ return prng;
+ }
+ if (module && module.exports) {
+ module.exports = impl;
+ }
+ else if (define && define.amd) {
+ define(function () { return impl; });
+ }
+ else {
+ this.xor128 = impl;
+ }
+ })(commonjsGlobal, module, // present in node.js
+ (typeof undefined) == 'function' // present with an AMD loader
+ );
+ });
+
+ var xorwow = createCommonjsModule(function (module) {
+ // A Javascript implementaion of the "xorwow" prng algorithm by
+ // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper
+ (function (global, module, define) {
+ function XorGen(seed) {
+ var me = this, strseed = '';
+ // Set up generator function.
+ me.next = function () {
+ var t = (me.x ^ (me.x >>> 2));
+ me.x = me.y;
+ me.y = me.z;
+ me.z = me.w;
+ me.w = me.v;
+ return (me.d = (me.d + 362437 | 0)) +
+ (me.v = (me.v ^ (me.v << 4)) ^ (t ^ (t << 1))) | 0;
+ };
+ me.x = 0;
+ me.y = 0;
+ me.z = 0;
+ me.w = 0;
+ me.v = 0;
+ if (seed === (seed | 0)) {
+ // Integer seed.
+ me.x = seed;
+ }
+ else {
+ // String seed.
+ strseed += seed;
+ }
+ // Mix in string seed, then discard an initial batch of 64 values.
+ for (var k = 0; k < strseed.length + 64; k++) {
+ me.x ^= strseed.charCodeAt(k) | 0;
+ if (k == strseed.length) {
+ me.d = me.x << 10 ^ me.x >>> 4;
+ }
+ me.next();
+ }
+ }
+ function copy(f, t) {
+ t.x = f.x;
+ t.y = f.y;
+ t.z = f.z;
+ t.w = f.w;
+ t.v = f.v;
+ t.d = f.d;
+ return t;
+ }
+ function impl(seed, opts) {
+ var xg = new XorGen(seed), state = opts && opts.state, prng = function () { return (xg.next() >>> 0) / 0x100000000; };
+ prng.double = function () {
+ do {
+ var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 0x100000000, result = (top + bot) / (1 << 21);
+ } while (result === 0);
+ return result;
+ };
+ prng.int32 = xg.next;
+ prng.quick = prng;
+ if (state) {
+ if (typeof (state) == 'object')
+ copy(state, xg);
+ prng.state = function () { return copy(xg, {}); };
+ }
+ return prng;
+ }
+ if (module && module.exports) {
+ module.exports = impl;
+ }
+ else if (define && define.amd) {
+ define(function () { return impl; });
+ }
+ else {
+ this.xorwow = impl;
+ }
+ })(commonjsGlobal, module, // present in node.js
+ (typeof undefined) == 'function' // present with an AMD loader
+ );
+ });
+
+ var xorshift7 = createCommonjsModule(function (module) {
+ // A Javascript implementaion of the "xorshift7" algorithm by
+ // François Panneton and Pierre L'ecuyer:
+ // "On the Xorgshift Random Number Generators"
+ // http://saluc.engr.uconn.edu/refs/crypto/rng/panneton05onthexorshift.pdf
+ (function (global, module, define) {
+ function XorGen(seed) {
+ var me = this;
+ // Set up generator function.
+ me.next = function () {
+ // Update xor generator.
+ var X = me.x, i = me.i, t, v;
+ t = X[i];
+ t ^= (t >>> 7);
+ v = t ^ (t << 24);
+ t = X[(i + 1) & 7];
+ v ^= t ^ (t >>> 10);
+ t = X[(i + 3) & 7];
+ v ^= t ^ (t >>> 3);
+ t = X[(i + 4) & 7];
+ v ^= t ^ (t << 7);
+ t = X[(i + 7) & 7];
+ t = t ^ (t << 13);
+ v ^= t ^ (t << 9);
+ X[i] = v;
+ me.i = (i + 1) & 7;
+ return v;
+ };
+ function init(me, seed) {
+ var j, X = [];
+ if (seed === (seed | 0)) {
+ // Seed state array using a 32-bit integer.
+ X[0] = seed;
+ }
+ else {
+ // Seed state using a string.
+ seed = '' + seed;
+ for (j = 0; j < seed.length; ++j) {
+ X[j & 7] = (X[j & 7] << 15) ^
+ (seed.charCodeAt(j) + X[(j + 1) & 7] << 13);
+ }
+ }
+ // Enforce an array length of 8, not all zeroes.
+ while (X.length < 8)
+ X.push(0);
+ for (j = 0; j < 8 && X[j] === 0; ++j)
+ ;
+ if (j == 8)
+ X[7] = -1;
+ me.x = X;
+ me.i = 0;
+ // Discard an initial 256 values.
+ for (j = 256; j > 0; --j) {
+ me.next();
+ }
+ }
+ init(me, seed);
+ }
+ function copy(f, t) {
+ t.x = f.x.slice();
+ t.i = f.i;
+ return t;
+ }
+ function impl(seed, opts) {
+ if (seed == null)
+ seed = +(new Date);
+ var xg = new XorGen(seed), state = opts && opts.state, prng = function () { return (xg.next() >>> 0) / 0x100000000; };
+ prng.double = function () {
+ do {
+ var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 0x100000000, result = (top + bot) / (1 << 21);
+ } while (result === 0);
+ return result;
+ };
+ prng.int32 = xg.next;
+ prng.quick = prng;
+ if (state) {
+ if (state.x)
+ copy(state, xg);
+ prng.state = function () { return copy(xg, {}); };
+ }
+ return prng;
+ }
+ if (module && module.exports) {
+ module.exports = impl;
+ }
+ else if (define && define.amd) {
+ define(function () { return impl; });
+ }
+ else {
+ this.xorshift7 = impl;
+ }
+ })(commonjsGlobal, module, // present in node.js
+ (typeof undefined) == 'function' // present with an AMD loader
+ );
+ });
+
+ var xor4096 = createCommonjsModule(function (module) {
+ // A Javascript implementaion of Richard Brent's Xorgens xor4096 algorithm.
+ //
+ // This fast non-cryptographic random number generator is designed for
+ // use in Monte-Carlo algorithms. It combines a long-period xorshift
+ // generator with a Weyl generator, and it passes all common batteries
+ // of stasticial tests for randomness while consuming only a few nanoseconds
+ // for each prng generated. For background on the generator, see Brent's
+ // paper: "Some long-period random number generators using shifts and xors."
+ // http://arxiv.org/pdf/1004.3115v1.pdf
+ //
+ // Usage:
+ //
+ // var xor4096 = require('xor4096');
+ // random = xor4096(1); // Seed with int32 or string.
+ // assert.equal(random(), 0.1520436450538547); // (0, 1) range, 53 bits.
+ // assert.equal(random.int32(), 1806534897); // signed int32, 32 bits.
+ //
+ // For nonzero numeric keys, this impelementation provides a sequence
+ // identical to that by Brent's xorgens 3 implementaion in C. This
+ // implementation also provides for initalizing the generator with
+ // string seeds, or for saving and restoring the state of the generator.
+ //
+ // On Chrome, this prng benchmarks about 2.1 times slower than
+ // Javascript's built-in Math.random().
+ (function (global, module, define) {
+ function XorGen(seed) {
+ var me = this;
+ // Set up generator function.
+ me.next = function () {
+ var w = me.w, X = me.X, i = me.i, t, v;
+ // Update Weyl generator.
+ me.w = w = (w + 0x61c88647) | 0;
+ // Update xor generator.
+ v = X[(i + 34) & 127];
+ t = X[i = ((i + 1) & 127)];
+ v ^= v << 13;
+ t ^= t << 17;
+ v ^= v >>> 15;
+ t ^= t >>> 12;
+ // Update Xor generator array state.
+ v = X[i] = v ^ t;
+ me.i = i;
+ // Result is the combination.
+ return (v + (w ^ (w >>> 16))) | 0;
+ };
+ function init(me, seed) {
+ var t, v, i, j, w, X = [], limit = 128;
+ if (seed === (seed | 0)) {
+ // Numeric seeds initialize v, which is used to generates X.
+ v = seed;
+ seed = null;
+ }
+ else {
+ // String seeds are mixed into v and X one character at a time.
+ seed = seed + '\0';
+ v = 0;
+ limit = Math.max(limit, seed.length);
+ }
+ // Initialize circular array and weyl value.
+ for (i = 0, j = -32; j < limit; ++j) {
+ // Put the unicode characters into the array, and shuffle them.
+ if (seed)
+ v ^= seed.charCodeAt((j + 32) % seed.length);
+ // After 32 shuffles, take v as the starting w value.
+ if (j === 0)
+ w = v;
+ v ^= v << 10;
+ v ^= v >>> 15;
+ v ^= v << 4;
+ v ^= v >>> 13;
+ if (j >= 0) {
+ w = (w + 0x61c88647) | 0; // Weyl.
+ t = (X[j & 127] ^= (v + w)); // Combine xor and weyl to init array.
+ i = (0 == t) ? i + 1 : 0; // Count zeroes.
+ }
+ }
+ // We have detected all zeroes; make the key nonzero.
+ if (i >= 128) {
+ X[(seed && seed.length || 0) & 127] = -1;
+ }
+ // Run the generator 512 times to further mix the state before using it.
+ // Factoring this as a function slows the main generator, so it is just
+ // unrolled here. The weyl generator is not advanced while warming up.
+ i = 127;
+ for (j = 4 * 128; j > 0; --j) {
+ v = X[(i + 34) & 127];
+ t = X[i = ((i + 1) & 127)];
+ v ^= v << 13;
+ t ^= t << 17;
+ v ^= v >>> 15;
+ t ^= t >>> 12;
+ X[i] = v ^ t;
+ }
+ // Storing state as object members is faster than using closure variables.
+ me.w = w;
+ me.X = X;
+ me.i = i;
+ }
+ init(me, seed);
+ }
+ function copy(f, t) {
+ t.i = f.i;
+ t.w = f.w;
+ t.X = f.X.slice();
+ return t;
+ }
+ function impl(seed, opts) {
+ if (seed == null)
+ seed = +(new Date);
+ var xg = new XorGen(seed), state = opts && opts.state, prng = function () { return (xg.next() >>> 0) / 0x100000000; };
+ prng.double = function () {
+ do {
+ var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 0x100000000, result = (top + bot) / (1 << 21);
+ } while (result === 0);
+ return result;
+ };
+ prng.int32 = xg.next;
+ prng.quick = prng;
+ if (state) {
+ if (state.X)
+ copy(state, xg);
+ prng.state = function () { return copy(xg, {}); };
+ }
+ return prng;
+ }
+ if (module && module.exports) {
+ module.exports = impl;
+ }
+ else if (define && define.amd) {
+ define(function () { return impl; });
+ }
+ else {
+ this.xor4096 = impl;
+ }
+ })(commonjsGlobal, // window object or global
+ module, // present in node.js
+ (typeof undefined) == 'function' // present with an AMD loader
+ );
+ });
+
+ var tychei = createCommonjsModule(function (module) {
+ // A Javascript implementaion of the "Tyche-i" prng algorithm by
+ // Samuel Neves and Filipe Araujo.
+ // See https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf
+ (function (global, module, define) {
+ function XorGen(seed) {
+ var me = this, strseed = '';
+ // Set up generator function.
+ me.next = function () {
+ var b = me.b, c = me.c, d = me.d, a = me.a;
+ b = (b << 25) ^ (b >>> 7) ^ c;
+ c = (c - d) | 0;
+ d = (d << 24) ^ (d >>> 8) ^ a;
+ a = (a - b) | 0;
+ me.b = b = (b << 20) ^ (b >>> 12) ^ c;
+ me.c = c = (c - d) | 0;
+ me.d = (d << 16) ^ (c >>> 16) ^ a;
+ return me.a = (a - b) | 0;
+ };
+ /* The following is non-inverted tyche, which has better internal
+ * bit diffusion, but which is about 25% slower than tyche-i in JS.
+ me.next = function() {
+ var a = me.a, b = me.b, c = me.c, d = me.d;
+ a = (me.a + me.b | 0) >>> 0;
+ d = me.d ^ a; d = d << 16 ^ d >>> 16;
+ c = me.c + d | 0;
+ b = me.b ^ c; b = b << 12 ^ d >>> 20;
+ me.a = a = a + b | 0;
+ d = d ^ a; me.d = d = d << 8 ^ d >>> 24;
+ me.c = c = c + d | 0;
+ b = b ^ c;
+ return me.b = (b << 7 ^ b >>> 25);
+ }
+ */
+ me.a = 0;
+ me.b = 0;
+ me.c = 2654435769 | 0;
+ me.d = 1367130551;
+ if (seed === Math.floor(seed)) {
+ // Integer seed.
+ me.a = (seed / 0x100000000) | 0;
+ me.b = seed | 0;
+ }
+ else {
+ // String seed.
+ strseed += seed;
+ }
+ // Mix in string seed, then discard an initial batch of 64 values.
+ for (var k = 0; k < strseed.length + 20; k++) {
+ me.b ^= strseed.charCodeAt(k) | 0;
+ me.next();
+ }
+ }
+ function copy(f, t) {
+ t.a = f.a;
+ t.b = f.b;
+ t.c = f.c;
+ t.d = f.d;
+ return t;
+ }
+ function impl(seed, opts) {
+ var xg = new XorGen(seed), state = opts && opts.state, prng = function () { return (xg.next() >>> 0) / 0x100000000; };
+ prng.double = function () {
+ do {
+ var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 0x100000000, result = (top + bot) / (1 << 21);
+ } while (result === 0);
+ return result;
+ };
+ prng.int32 = xg.next;
+ prng.quick = prng;
+ if (state) {
+ if (typeof (state) == 'object')
+ copy(state, xg);
+ prng.state = function () { return copy(xg, {}); };
+ }
+ return prng;
+ }
+ if (module && module.exports) {
+ module.exports = impl;
+ }
+ else if (define && define.amd) {
+ define(function () { return impl; });
+ }
+ else {
+ this.tychei = impl;
+ }
+ })(commonjsGlobal, module, // present in node.js
+ (typeof undefined) == 'function' // present with an AMD loader
+ );
+ });
+
+ /*
+ Copyright 2019 David Bau.
+
+ Permission is hereby granted, free of charge, to any person obtaining
+ a copy of this software and associated documentation files (the
+ "Software"), to deal in the Software without restriction, including
+ without limitation the rights to use, copy, modify, merge, publish,
+ distribute, sublicense, and/or sell copies of the Software, and to
+ permit persons to whom the Software is furnished to do so, subject to
+ the following conditions:
+
+ The above copyright notice and this permission notice shall be
+ included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ */
+ var seedrandom$1 = createCommonjsModule(function (module) {
+ (function (global, pool, math) {
+ //
+ // The following constants are related to IEEE 754 limits.
+ //
+ var width = 256, // each RC4 output is 0 <= x < 256
+ chunks = 6, // at least six RC4 outputs for each double
+ digits = 52, // there are 52 significant digits in a double
+ rngname = 'random', // rngname: name for Math.random and Math.seedrandom
+ startdenom = math.pow(width, chunks), significance = math.pow(2, digits), overflow = significance * 2, mask = width - 1, nodecrypto; // node.js crypto module, initialized at the bottom.
+ //
+ // seedrandom()
+ // This is the seedrandom function described above.
+ //
+ function seedrandom(seed, options, callback) {
+ var key = [];
+ options = (options == true) ? { entropy: true } : (options || {});
+ // Flatten the seed string or build one from local entropy if needed.
+ var shortseed = mixkey(flatten(options.entropy ? [seed, tostring(pool)] :
+ (seed == null) ? autoseed() : seed, 3), key);
+ // Use the seed to initialize an ARC4 generator.
+ var arc4 = new ARC4(key);
+ // This function returns a random double in [0, 1) that contains
+ // randomness in every bit of the mantissa of the IEEE 754 value.
+ var prng = function () {
+ var n = arc4.g(chunks), // Start with a numerator n < 2 ^ 48
+ d = startdenom, // and denominator d = 2 ^ 48.
+ x = 0; // and no 'extra last byte'.
+ while (n < significance) { // Fill up all significant digits by
+ n = (n + x) * width; // shifting numerator and
+ d *= width; // denominator and generating a
+ x = arc4.g(1); // new least-significant-byte.
+ }
+ while (n >= overflow) { // To avoid rounding up, before adding
+ n /= 2; // last byte, shift everything
+ d /= 2; // right using integer math until
+ x >>>= 1; // we have exactly the desired bits.
+ }
+ return (n + x) / d; // Form the number within [0, 1).
+ };
+ prng.int32 = function () { return arc4.g(4) | 0; };
+ prng.quick = function () { return arc4.g(4) / 0x100000000; };
+ prng.double = prng;
+ // Mix the randomness into accumulated entropy.
+ mixkey(tostring(arc4.S), pool);
+ // Calling convention: what to return as a function of prng, seed, is_math.
+ return (options.pass || callback ||
+ function (prng, seed, is_math_call, state) {
+ if (state) {
+ // Load the arc4 state from the given state if it has an S array.
+ if (state.S) {
+ copy(state, arc4);
+ }
+ // Only provide the .state method if requested via options.state.
+ prng.state = function () { return copy(arc4, {}); };
+ }
+ // If called as a method of Math (Math.seedrandom()), mutate
+ // Math.random because that is how seedrandom.js has worked since v1.0.
+ if (is_math_call) {
+ math[rngname] = prng;
+ return seed;
+ }
+ // Otherwise, it is a newer calling convention, so return the
+ // prng directly.
+ else
+ return prng;
+ })(prng, shortseed, 'global' in options ? options.global : (this == math), options.state);
+ }
+ //
+ // ARC4
+ //
+ // An ARC4 implementation. The constructor takes a key in the form of
+ // an array of at most (width) integers that should be 0 <= x < (width).
+ //
+ // The g(count) method returns a pseudorandom integer that concatenates
+ // the next (count) outputs from ARC4. Its return value is a number x
+ // that is in the range 0 <= x < (width ^ count).
+ //
+ function ARC4(key) {
+ var t, keylen = key.length, me = this, i = 0, j = me.i = me.j = 0, s = me.S = [];
+ // The empty key [] is treated as [0].
+ if (!keylen) {
+ key = [keylen++];
+ }
+ // Set up S using the standard key scheduling algorithm.
+ while (i < width) {
+ s[i] = i++;
+ }
+ for (i = 0; i < width; i++) {
+ s[i] = s[j = mask & (j + key[i % keylen] + (t = s[i]))];
+ s[j] = t;
+ }
+ // The "g" method returns the next (count) outputs as one number.
+ (me.g = function (count) {
+ // Using instance members instead of closure state nearly doubles speed.
+ var t, r = 0, i = me.i, j = me.j, s = me.S;
+ while (count--) {
+ t = s[i = mask & (i + 1)];
+ r = r * width + s[mask & ((s[i] = s[j = mask & (j + t)]) + (s[j] = t))];
+ }
+ me.i = i;
+ me.j = j;
+ return r;
+ // For robust unpredictability, the function call below automatically
+ // discards an initial batch of values. This is called RC4-drop[256].
+ // See http://google.com/search?q=rsa+fluhrer+response&btnI
+ })(width);
+ }
+ //
+ // copy()
+ // Copies internal state of ARC4 to or from a plain object.
+ //
+ function copy(f, t) {
+ t.i = f.i;
+ t.j = f.j;
+ t.S = f.S.slice();
+ return t;
+ }
+ //
+ // flatten()
+ // Converts an object tree to nested arrays of strings.
+ //
+ function flatten(obj, depth) {
+ var result = [], typ = (typeof obj), prop;
+ if (depth && typ == 'object') {
+ for (prop in obj) {
+ try {
+ result.push(flatten(obj[prop], depth - 1));
+ }
+ catch (e) { }
+ }
+ }
+ return (result.length ? result : typ == 'string' ? obj : obj + '\0');
+ }
+ //
+ // mixkey()
+ // Mixes a string seed into a key that is an array of integers, and
+ // returns a shortened string seed that is equivalent to the result key.
+ //
+ function mixkey(seed, key) {
+ var stringseed = seed + '', smear, j = 0;
+ while (j < stringseed.length) {
+ key[mask & j] =
+ mask & ((smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++));
+ }
+ return tostring(key);
+ }
+ //
+ // autoseed()
+ // Returns an object for autoseeding, using window.crypto and Node crypto
+ // module if available.
+ //
+ function autoseed() {
+ try {
+ var out;
+ if (nodecrypto && (out = nodecrypto.randomBytes)) {
+ // The use of 'out' to remember randomBytes makes tight minified code.
+ out = out(width);
+ }
+ else {
+ out = new Uint8Array(width);
+ (global.crypto || global.msCrypto).getRandomValues(out);
+ }
+ return tostring(out);
+ }
+ catch (e) {
+ var browser = global.navigator, plugins = browser && browser.plugins;
+ return [+new Date, global, plugins, global.screen, tostring(pool)];
+ }
+ }
+ //
+ // tostring()
+ // Converts an array of charcodes to a string
+ //
+ function tostring(a) {
+ return String.fromCharCode.apply(0, a);
+ }
+ //
+ // When seedrandom.js is loaded, we immediately mix a few bits
+ // from the built-in RNG into the entropy pool. Because we do
+ // not want to interfere with deterministic PRNG state later,
+ // seedrandom will not call math.random on its own again after
+ // initialization.
+ //
+ mixkey(math.random(), pool);
+ //
+ // Nodejs and AMD support: export the implementation as a module using
+ // either convention.
+ //
+ if (module.exports) {
+ module.exports = seedrandom;
+ // When in node.js, try using crypto package for autoseeding.
+ try {
+ nodecrypto = require$$0__default['default'];
+ }
+ catch (ex) { }
+ }
+ else {
+ // When included as a plain script, set up Math.seedrandom global.
+ math['seed' + rngname] = seedrandom;
+ }
+ // End anonymous scope, and pass initial values.
+ })(
+ // global: `self` in browsers (including strict mode and web workers),
+ // otherwise `this` in Node and other environments
+ (typeof self !== 'undefined') ? self : commonjsGlobal, [], // pool: entropy pool starts empty
+ Math // math: package containing random, pow, and seedrandom
+ );
+ });
+
+ // A library of seedable RNGs implemented in Javascript.
+ //
+ // Usage:
+ //
+ // var seedrandom = require('seedrandom');
+ // var random = seedrandom(1); // or any seed.
+ // var x = random(); // 0 <= x < 1. Every bit is random.
+ // var x = random.quick(); // 0 <= x < 1. 32 bits of randomness.
+ // alea, a 53-bit multiply-with-carry generator by Johannes Baagøe.
+ // Period: ~2^116
+ // Reported to pass all BigCrush tests.
+ // xor128, a pure xor-shift generator by George Marsaglia.
+ // Period: 2^128-1.
+ // Reported to fail: MatrixRank and LinearComp.
+ // xorwow, George Marsaglia's 160-bit xor-shift combined plus weyl.
+ // Period: 2^192-2^32
+ // Reported to fail: CollisionOver, SimpPoker, and LinearComp.
+ // xorshift7, by François Panneton and Pierre L'ecuyer, takes
+ // a different approach: it adds robustness by allowing more shifts
+ // than Marsaglia's original three. It is a 7-shift generator
+ // with 256 bits, that passes BigCrush with no systmatic failures.
+ // Period 2^256-1.
+ // No systematic BigCrush failures reported.
+ // xor4096, by Richard Brent, is a 4096-bit xor-shift with a
+ // very long period that also adds a Weyl generator. It also passes
+ // BigCrush with no systematic failures. Its long period may
+ // be useful if you have many generators and need to avoid
+ // collisions.
+ // Period: 2^4128-2^32.
+ // No systematic BigCrush failures reported.
+ // Tyche-i, by Samuel Neves and Filipe Araujo, is a bit-shifting random
+ // number generator derived from ChaCha, a modern stream cipher.
+ // https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf
+ // Period: ~2^127
+ // No systematic BigCrush failures reported.
+ // The original ARC4-based prng included in this library.
+ // Period: ~2^1600
+ seedrandom$1.alea = alea;
+ seedrandom$1.xor128 = xor128;
+ seedrandom$1.xorwow = xorwow;
+ seedrandom$1.xorshift7 = xorshift7;
+ seedrandom$1.xor4096 = xor4096;
+ seedrandom$1.tychei = tychei;
+ var seedrandom = seedrandom$1;
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ // https://en.wikipedia.org/wiki/Marsaglia_polar_method
+ var MPRandGauss = /** @class */ (function () {
+ function MPRandGauss(mean, stdDeviation, dtype, truncated, seed) {
+ this.mean = mean;
+ this.stdDev = stdDeviation;
+ this.dtype = dtype;
+ this.nextVal = NaN;
+ this.truncated = truncated;
+ if (this.truncated) {
+ this.upper = this.mean + this.stdDev * 2;
+ this.lower = this.mean - this.stdDev * 2;
+ }
+ var seedValue = seed ? seed : Math.random();
+ this.random = seedrandom.alea(seedValue.toString());
+ }
+ /** Returns next sample from a Gaussian distribution. */
+ MPRandGauss.prototype.nextValue = function () {
+ if (!isNaN(this.nextVal)) {
+ var value = this.nextVal;
+ this.nextVal = NaN;
+ return value;
+ }
+ var resultX, resultY;
+ var isValid = false;
+ while (!isValid) {
+ var v1 = void 0, v2 = void 0, s = void 0;
+ do {
+ v1 = 2 * this.random() - 1;
+ v2 = 2 * this.random() - 1;
+ s = v1 * v1 + v2 * v2;
+ } while (s >= 1 || s === 0);
+ var mul = Math.sqrt(-2.0 * Math.log(s) / s);
+ resultX = this.mean + this.stdDev * v1 * mul;
+ resultY = this.mean + this.stdDev * v2 * mul;
+ if (!this.truncated || this.isValidTruncated(resultX)) {
+ isValid = true;
+ }
+ }
+ if (!this.truncated || this.isValidTruncated(resultY)) {
+ this.nextVal = this.convertValue(resultY);
+ }
+ return this.convertValue(resultX);
+ };
+ /** Handles proper rounding for non-floating-point numbers. */
+ MPRandGauss.prototype.convertValue = function (value) {
+ if (this.dtype == null || this.dtype === 'float32') {
+ return value;
+ }
+ return Math.round(value);
+ };
+ /** Returns true if less than 2-standard-deviations from the mean. */
+ MPRandGauss.prototype.isValidTruncated = function (value) {
+ return value <= this.upper && value >= this.lower;
+ };
+ return MPRandGauss;
+ }());
+ // Marsaglia, George, and Wai Wan Tsang. 2000. "A Simple Method for Generating
+ // Gamma Variables."
+ var RandGamma = /** @class */ (function () {
+ function RandGamma(alpha, beta, dtype, seed) {
+ this.alpha = alpha;
+ this.beta = 1 / beta; // convert rate to scale parameter
+ this.dtype = dtype;
+ var seedValue = seed ? seed : Math.random();
+ this.randu = seedrandom.alea(seedValue.toString());
+ this.randn = new MPRandGauss(0, 1, dtype, false, this.randu());
+ if (alpha < 1) {
+ this.d = alpha + (2 / 3);
+ }
+ else {
+ this.d = alpha - (1 / 3);
+ }
+ this.c = 1 / Math.sqrt(9 * this.d);
+ }
+ /** Returns next sample from a gamma distribution. */
+ RandGamma.prototype.nextValue = function () {
+ var x2, v0, v1, x, u, v;
+ while (true) {
+ do {
+ x = this.randn.nextValue();
+ v = 1 + (this.c * x);
+ } while (v <= 0);
+ v *= v * v;
+ x2 = x * x;
+ v0 = 1 - (0.331 * x2 * x2);
+ v1 = (0.5 * x2) + (this.d * (1 - v + Math.log(v)));
+ u = this.randu();
+ if (u < v0 || Math.log(u) < v1) {
+ break;
+ }
+ }
+ v = (1 / this.beta) * this.d * v;
+ if (this.alpha < 1) {
+ v *= Math.pow(this.randu(), 1 / this.alpha);
+ }
+ return this.convertValue(v);
+ };
+ /** Handles proper rounding for non-floating-point numbers. */
+ RandGamma.prototype.convertValue = function (value) {
+ if (this.dtype === 'float32') {
+ return value;
+ }
+ return Math.round(value);
+ };
+ return RandGamma;
+ }());
+ var UniformRandom = /** @class */ (function () {
+ function UniformRandom(min, max, dtype, seed) {
+ var _this = this;
+ if (min === void 0) { min = 0; }
+ if (max === void 0) { max = 1; }
+ /** Handles proper rounding for non floating point numbers. */
+ this.canReturnFloat = function () { return (_this.dtype == null || _this.dtype === 'float32'); };
+ this.min = min;
+ this.range = max - min;
+ this.dtype = dtype;
+ if (seed == null) {
+ seed = Math.random();
+ }
+ if (typeof seed === 'number') {
+ seed = seed.toString();
+ }
+ if (!this.canReturnFloat() && this.range <= 1) {
+ throw new Error("The difference between " + min + " - " + max + " <= 1 and dtype is not float");
+ }
+ this.random = seedrandom.alea(seed);
+ }
+ UniformRandom.prototype.convertValue = function (value) {
+ if (this.canReturnFloat()) {
+ return value;
+ }
+ return Math.round(value);
+ };
+ UniformRandom.prototype.nextValue = function () {
+ return this.convertValue(this.min + this.range * this.random());
+ };
+ return UniformRandom;
+ }());
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a `tf.Tensor` with values sampled from a gamma distribution.
+ *
+ * ```js
+ * tf.randomGamma([2, 2], 1).print();
+ * ```
+ *
+ * @param shape An array of integers defining the output tensor shape.
+ * @param alpha The shape parameter of the gamma distribution.
+ * @param beta The inverse scale parameter of the gamma distribution. Defaults
+ * to 1.
+ * @param dtype The data type of the output. Defaults to float32.
+ * @param seed The seed for the random number generator.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Random'}
+ */
+ function randomGamma_(shape, alpha, beta, dtype, seed) {
+ if (beta === void 0) { beta = 1; }
+ if (dtype === void 0) { dtype = 'float32'; }
+ if (beta == null) {
+ beta = 1;
+ }
+ if (dtype == null) {
+ dtype = 'float32';
+ }
+ if (dtype !== 'float32' && dtype !== 'int32') {
+ throw new Error("Unsupported data type " + dtype);
+ }
+ var rgamma = new RandGamma(alpha, beta, dtype, seed);
+ var res = buffer(shape, dtype);
+ for (var i = 0; i < res.values.length; i++) {
+ res.values[i] = rgamma.nextValue();
+ }
+ return res.toTensor();
+ }
+ var randomGamma = op({ randomGamma_: randomGamma_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a `tf.Tensor` with values sampled from a normal distribution.
+ *
+ * ```js
+ * tf.randomNormal([2, 2]).print();
+ * ```
+ *
+ * @param shape An array of integers defining the output tensor shape.
+ * @param mean The mean of the normal distribution.
+ * @param stdDev The standard deviation of the normal distribution.
+ * @param dtype The data type of the output.
+ * @param seed The seed for the random number generator.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Random'}
+ */
+ function randomNormal_(shape, mean, stdDev, dtype, seed) {
+ if (mean === void 0) { mean = 0; }
+ if (stdDev === void 0) { stdDev = 1; }
+ if (dtype != null && dtype === 'bool') {
+ throw new Error("Unsupported data type " + dtype);
+ }
+ var randGauss = new MPRandGauss(mean, stdDev, dtype, false /* truncated */, seed);
+ var res = buffer(shape, dtype);
+ for (var i = 0; i < res.values.length; i++) {
+ res.values[i] = randGauss.nextValue();
+ }
+ return res.toTensor();
+ }
+ var randomNormal = op({ randomNormal_: randomNormal_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a `tf.Tensor` with values sampled from a uniform distribution.
+ *
+ * The generated values follow a uniform distribution in the range [minval,
+ * maxval). The lower bound minval is included in the range, while the upper
+ * bound maxval is excluded.
+ *
+ * ```js
+ * tf.randomUniform([2, 2]).print();
+ * ```
+ *
+ * @param shape An array of integers defining the output tensor shape.
+ * @param minval The lower bound on the range of random values to generate.
+ * Defaults to 0.
+ * @param maxval The upper bound on the range of random values to generate.
+ * Defaults to 1.
+ * @param dtype The data type of the output tensor. Defaults to 'float32'.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Random'}
+ */
+ function randomUniform_(shape, minval, maxval, dtype, seed) {
+ if (minval === void 0) { minval = 0; }
+ if (maxval === void 0) { maxval = 1; }
+ if (dtype === void 0) { dtype = 'float32'; }
+ var res = buffer(shape, dtype);
+ var random = new UniformRandom(minval, maxval, null, seed);
+ for (var i = 0; i < res.values.length; i++) {
+ res.values[i] = random.nextValue();
+ }
+ return res.toTensor();
+ }
+ var randomUniform = op({ randomUniform_: randomUniform_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a new `tf.Tensor1D` filled with the numbers in the range provided.
+ *
+ * The tensor is a is half-open interval meaning it includes start, but
+ * excludes stop. Decrementing ranges and negative step values are also
+ * supported.sv
+ *
+ *
+ * ```js
+ * tf.range(0, 9, 2).print();
+ * ```
+ *
+ * @param start An integer start value
+ * @param stop An integer stop value
+ * @param step An integer increment (will default to 1 or -1)
+ * @param dtype The data type of the output tensor. Defaults to 'float32'.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function range(start, stop, step, dtype) {
+ if (step === void 0) { step = 1; }
+ if (dtype === void 0) { dtype = 'float32'; }
+ if (step === 0) {
+ throw new Error('Cannot have a step of zero');
+ }
+ var attrs = { start: start, stop: stop, step: step, dtype: dtype };
+ return ENGINE.runKernel(Range, {} /* inputs */, attrs);
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns the real part of a complex (or real) tensor.
+ *
+ * Given a tensor input, this operation returns a tensor of type float that is
+ * the real part of each element in input considered as a complex number.
+ *
+ * If the input is real, it simply makes a clone.
+ *
+ * ```js
+ * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
+ * tf.real(x).print();
+ * ```
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function real_(input) {
+ var $input = convertToTensor(input, 'input', 'real');
+ var inputs = { input: $input };
+ return ENGINE.runKernel(Real, inputs);
+ }
+ var real = op({ real_: real_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes reciprocal of x element-wise: `1 / x`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, 1, 2]);
+ *
+ * x.reciprocal().print(); // or tf.reciprocal(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function reciprocal_(x) {
+ var $x = convertToTensor(x, 'x', 'reciprocal');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Reciprocal, inputs);
+ }
+ var reciprocal = op({ reciprocal_: reciprocal_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes rectified linear element-wise: `max(x, 0)`.
+ *
+ * ```js
+ * const x = tf.tensor1d([-1, 2, -3, 4]);
+ *
+ * x.relu().print(); // or tf.relu(x)
+ * ```
+ * @param x The input tensor. If the dtype is `bool`, the output dtype will be
+ * `int32'.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function relu_(x) {
+ var $x = convertToTensor(x, 'x', 'relu');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Relu, inputs);
+ }
+ var relu = op({ relu_: relu_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes rectified linear 6 element-wise: `min(max(x, 0), 6)`.
+ *
+ * ```js
+ * const x = tf.tensor1d([-1, 2, -3, 8]);
+ *
+ * x.relu6().print(); // or tf.relu6(x)
+ * ```
+ * @param x The input tensor. If the dtype is `bool`, the output dtype will be
+ * `int32'.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function relu6_(x) {
+ var $x = convertToTensor(x, 'x', 'relu6');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Relu6, inputs);
+ }
+ var relu6 = op({ relu6_: relu6_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Reverses a `tf.Tensor` along a specified axis.
+ *
+ * Also available are stricter rank-specific methods that assert that `x` is
+ * of the given rank:
+ * - `tf.reverse1d`
+ * - `tf.reverse2d`
+ * - `tf.reverse3d`
+ * - `tf.reverse4d`
+ *
+ * Except `tf.reverse1d` (which does not have axis param), all methods have
+ * same signature as this method.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3, 4]);
+ *
+ * x.reverse().print();
+ * ```
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+ *
+ * const axis = 1;
+ * x.reverse(axis).print();
+ * ```
+ * @param x The input tensor to be reversed.
+ * @param axis The set of dimensions to reverse. Must be in the
+ * range [-rank(x), rank(x)). Defaults to all axes.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
+ */
+ function reverse_(x, axis) {
+ var $x = convertToTensor(x, 'x', 'reverse');
+ var inputs = { x: $x };
+ var attrs = { dims: axis };
+ return ENGINE.runKernel(Reverse, inputs, attrs);
+ }
+ var reverse = op({ reverse_: reverse_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Reverses a `tf.Tensor1D`.
+ *
+ * @param x The input tensor.
+ */
+ function reverse1d_(x) {
+ var $x = convertToTensor(x, 'x', 'reverse');
+ assert($x.rank === 1, function () { return "Error in reverse1D: x must be rank 1 but got rank " + $x.rank + "."; });
+ return reverse($x, 0);
+ }
+ var reverse1d = op({ reverse1d_: reverse1d_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Reverses a `tf.Tensor2D` along a specified axis.
+ *
+ * @param x The input tensor.
+ * @param axis The set of dimensions to reverse. Must be in the
+ * range [-rank(x), rank(x)). Defaults to all axes.
+ */
+ function reverse2d_(x, axis) {
+ var $x = convertToTensor(x, 'x', 'reverse');
+ assert($x.rank === 2, function () { return "Error in reverse2D: x must be rank 2 but got rank " + $x.rank + "."; });
+ return reverse($x, axis);
+ }
+ var reverse2d = op({ reverse2d_: reverse2d_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Reverses a `tf.Tensor3D` along a specified axis.
+ *
+ * @param x The input tensor.
+ * @param axis The set of dimensions to reverse. Must be in the
+ * range [-rank(x), rank(x)). Defaults to all axes.
+ */
+ function reverse3d_(x, axis) {
+ var $x = convertToTensor(x, 'x', 'reverse');
+ assert($x.rank === 3, function () { return "Error in reverse3D: x must be rank 3 but got rank " + $x.rank + "."; });
+ return reverse($x, axis);
+ }
+ var reverse3d = op({ reverse3d_: reverse3d_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Reverses a `tf.Tensor4D` along a specified axis.
+ *
+ * @param x The input tensor.
+ * @param axis The set of dimensions to reverse. Must be in the
+ * range [-rank(x), rank(x)). Defaults to all axes.
+ */
+ function reverse4d_(x, axis) {
+ var $x = convertToTensor(x, 'x', 'reverse');
+ assert($x.rank === 4, function () { return "Error in reverse4D: x must be rank 4 but got rank " + $x.rank + "."; });
+ return reverse($x, axis);
+ }
+ var reverse4d = op({ reverse4d_: reverse4d_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes round of input `tf.Tensor` element-wise: `round(x)`.
+ * It implements banker's rounding.
+ *
+ * ```js
+ * const x = tf.tensor1d([.6, 1.1, -3.3]);
+ *
+ * x.round().print(); // or tf.round(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function round_(x) {
+ var $x = convertToTensor(x, 'x', 'round');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Round, inputs);
+ }
+ var round = op({ round_: round_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes reciprocal of square root of the input `tf.Tensor` element-wise:
+ * `y = 1 / sqrt(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 4, -1]);
+ *
+ * x.rsqrt().print(); // or tf.rsqrt(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function rsqrt_(x) {
+ var $x = convertToTensor(x, 'x', 'rsqrt', 'float32');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Rsqrt, inputs);
+ }
+ var rsqrt = op({ rsqrt_: rsqrt_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates rank-0 `tf.Tensor` (scalar) with the provided value and dtype.
+ *
+ * The same functionality can be achieved with `tf.tensor`, but in general
+ * we recommend using `tf.scalar` as it makes the code more readable.
+ *
+ * ```js
+ * tf.scalar(3.14).print();
+ * ```
+ *
+ * @param value The value of the scalar.
+ * @param dtype The data type.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function scalar(value, dtype) {
+ if (((isTypedArray(value) && dtype !== 'string') || Array.isArray(value)) &&
+ dtype !== 'complex64') {
+ throw new Error('Error creating a new Scalar: value must be a primitive ' +
+ '(number|boolean|string)');
+ }
+ if (dtype === 'string' && isTypedArray(value) &&
+ !(value instanceof Uint8Array)) {
+ throw new Error('When making a scalar from encoded string, ' +
+ 'the value must be `Uint8Array`.');
+ }
+ var shape = [];
+ var inferredShape = [];
+ return makeTensor(value, shape, inferredShape, dtype);
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes scaled exponential linear element-wise.
+ *
+ * `x < 0 ? scale * alpha * (exp(x) - 1) : x`
+ *
+ * ```js
+ * const x = tf.tensor1d([-1, 2, -3, 4]);
+ *
+ * x.selu().print(); // or tf.selu(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function selu_(x) {
+ var $x = convertToTensor(x, 'x', 'selu');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Selu, inputs);
+ }
+ var selu = op({ selu_: selu_ });
+
+ /**
+ * 2-D convolution with separable filters.
+ *
+ * Performs a depthwise convolution that acts separately on channels followed
+ * by a pointwise convolution that mixes channels. Note that this is
+ * separability between dimensions [1, 2] and 3, not spatial separability
+ * between dimensions 1 and 2.
+ *
+ * See
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d)
+ * for more details.
+ *
+ * @param x The input tensor, of rank 4 or rank 3, of shape
+ * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
+ * assumed.
+ * @param depthwiseFilter The depthwise filter tensor, rank 4, of shape
+ * `[filterHeight, filterWidth, inChannels, channelMultiplier]`. This is
+ * the filter used in the first step.
+ * @param pointwiseFilter The pointwise filter tensor, rank 4, of shape
+ * `[1, 1, inChannels * channelMultiplier, outChannels]`. This is
+ * the filter used in the second step.
+ * @param strides The strides of the convolution: `[strideHeight,
+ * strideWidth]`. If strides is a single number, then `strideHeight ==
+ * strideWidth`.
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
+ * in which we sample input values across the height and width dimensions
+ * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
+ * number, then `dilationHeight == dilationWidth`. If it is greater than
+ * 1, then all values of `strides` must be 1.
+ * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
+ * "NHWC". Specify the data format of the input and output data. With the
+ * default format "NHWC", the data is stored in the order of: [batch,
+ * height, width, channels]. Only "NHWC" is currently supported.
+ *
+ * @doc {heading: 'Operations', subheading: 'Convolution'}
+ */
+ function separableConv2d_(x, depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat) {
+ if (dilation === void 0) { dilation = [1, 1]; }
+ if (dataFormat === void 0) { dataFormat = 'NHWC'; }
+ var $x = convertToTensor(x, 'x', 'separableConv2d');
+ var $depthwiseFilter = convertToTensor(depthwiseFilter, 'depthwiseFilter', 'separableConv2d');
+ var $pointwiseFilter = convertToTensor(pointwiseFilter, 'pointwiseFilter', 'separableConv2d');
+ var x4D = $x;
+ var reshapedTo4D = false;
+ if ($x.rank === 3) {
+ reshapedTo4D = true;
+ x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
+ }
+ if (dataFormat === 'NCHW') {
+ throw new Error('separableConv2d currently does not support dataFormat NCHW; only ' +
+ 'NHWC is supported');
+ }
+ assert(x4D.rank === 4, function () { return "Error in separableConv2d: input must be rank 4, but got " +
+ ("rank " + x4D.rank + "."); });
+ assert($depthwiseFilter.rank === 4, function () { return "Error in separableConv2d: depthwise filter must be rank 4, but " +
+ ("got rank " + $depthwiseFilter.rank + "."); });
+ assert($pointwiseFilter.rank === 4, function () { return "Error in separableConv2d: pointwise filter must be rank 4, but " +
+ ("got rank " + $depthwiseFilter.rank + "."); });
+ assert($pointwiseFilter.shape[0] === 1, function () { return "Error in separableConv2d: the first dimension of pointwise filter " +
+ (" must be 1, but got " + $pointwiseFilter.shape[0] + "."); });
+ assert($pointwiseFilter.shape[1] === 1, function () { return "Error in separableConv2d: the second dimension of pointwise " +
+ ("filter must be 1, but got " + $pointwiseFilter.shape[1] + "."); });
+ var inChannels = $depthwiseFilter.shape[2];
+ var channelMultiplier = $depthwiseFilter.shape[3];
+ assert($pointwiseFilter.shape[2] === inChannels * channelMultiplier, function () { return "Error in separableConv2d: the third dimension of pointwise filter " +
+ ("must be " + inChannels * channelMultiplier + ", ") +
+ ("but got " + $pointwiseFilter.shape[2] + "."); });
+ var depthwise = depthwiseConv2d$1(x4D, $depthwiseFilter, strides, pad, dataFormat, dilation);
+ var pointwiseStride = 1;
+ var res = conv2d$1(depthwise, $pointwiseFilter, pointwiseStride, 'valid', dataFormat);
+ if (reshapedTo4D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return res;
+ }
+ var separableConv2d = op({ separableConv2d_: separableConv2d_ });
+
+ /**
+ * Computes the difference between two lists of numbers.
+ *
+ * Given a Tensor `x` and a Tensor `y`, this operation returns a Tensor `out`
+ * that represents all values that are in `x` but not in `y`. The returned
+ * Tensor `out` is sorted in the same order that the numbers appear in `x`
+ * (duplicates are preserved). This operation also returns a Tensor indices that
+ * represents the position of each out element in `x`. In other words:
+ *
+ * `out[i] = x[idx[i]] for i in [0, 1, ..., out.length - 1]`
+ *
+ * ```js
+ * const x = [1, 2, 3, 4, 5, 6];
+ * const y = [1, 3, 5];
+ *
+ * const [out, indices] = await tf.setdiff1dAsync(x, y);
+ * out.print(); // [2, 4, 6]
+ * indices.print(); // [1, 3, 5]
+ * ```
+ *
+ * @param x 1-D Tensor. Values to keep.
+ * @param y 1-D Tensor. Must have the same type as x. Values to exclude in the
+ * output.
+ * @returns Promise of Tensor tuple [out, indices].
+ * out: Tensor with the same type as x.
+ * indices: A Tensor of type int32.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Transformations'}
+ */
+ function setdiff1dAsync_(x, y) {
+ return __awaiter(this, void 0, void 0, function () {
+ var $x, $y, xVals, yVals, ySet, outputSize, i, buffer, indices, i, p;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ $x = convertToTensor(x, 'x', 'setdiff1d');
+ $y = convertToTensor(y, 'y', 'setdiff1d');
+ assert($x.dtype === $y.dtype, function () { return "x and y should have the same dtype, but got x (" + $x.dtype + ") and y (" + $y.dtype + ")."; });
+ assert($x.rank === 1, function () { return "x should be 1D tensor, but got x (" + $x.shape + ")."; });
+ assert($y.rank === 1, function () { return "y should be 1D tensor, but got y (" + $y.shape + ")."; });
+ return [4 /*yield*/, $x.data()];
+ case 1:
+ xVals = _a.sent();
+ return [4 /*yield*/, $y.data()];
+ case 2:
+ yVals = _a.sent();
+ ySet = new Set(yVals);
+ outputSize = 0;
+ for (i = 0; i < xVals.length; i++) {
+ if (!ySet.has(xVals[i])) {
+ outputSize++;
+ }
+ }
+ buffer = new TensorBuffer([outputSize], $x.dtype);
+ indices = new TensorBuffer([outputSize], 'int32');
+ for (i = 0, p = 0; i < xVals.length; i++) {
+ if (!ySet.has(xVals[i])) {
+ buffer.values[p] = xVals[i];
+ indices.values[p] = i;
+ p++;
+ }
+ }
+ return [2 /*return*/, [buffer.toTensor(), indices.toTensor()]];
+ }
+ });
+ });
+ }
+ var setdiff1dAsync = setdiff1dAsync_;
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Returns an element-wise indication of the sign of a number.
+ *
+ * ```js
+ * const x = tf.tensor1d([.6, 1.1, -3.3, NaN, 0]);
+ *
+ * x.sign().print(); // or tf.sign(x)
+ * ```
+ * @param x The input Tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function sign_(x) {
+ var $x = convertToTensor(x, 'x', 'sign');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Sign, inputs);
+ }
+ var sign = op({ sign_: sign_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes sin of the input Tensor element-wise: `sin(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
+ *
+ * x.sin().print(); // or tf.sin(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function sin_(x) {
+ var $x = convertToTensor(x, 'x', 'sin', 'float32');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Sin, inputs);
+ }
+ var sin = op({ sin_: sin_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, 1, -1, .7]);
+ *
+ * x.sinh().print(); // or tf.sinh(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function sinh_(x) {
+ var $x = convertToTensor(x, 'x', 'sinh');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Sinh, inputs);
+ }
+ var sinh = op({ sinh_: sinh_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Extracts a 1D slice from 1D array starting at coordinates `begin` and is
+ * of length `size`. See `slice` for details.
+ */
+ function slice1d_(x, begin, size) {
+ var $x = convertToTensor(x, 'x', 'slice1d');
+ assert($x.rank === 1, function () { return "slice1d expects a rank-1 tensor, but got a rank-" + $x.rank + " tensor"; });
+ return slice($x, [begin], [size]);
+ }
+ var slice1d = op({ slice1d_: slice1d_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Extracts a 2D slice from a 2D array starting at coordinates `begin` and
+ * is of size `size`. See `slice` for details.
+ */
+ function slice2d_(x, begin, size) {
+ var $x = convertToTensor(x, 'x', 'slice2d');
+ assert($x.rank === 2, function () { return "slice2d expects a rank-2 tensor, but got a rank-" + $x.rank + " tensor"; });
+ return slice($x, begin, size);
+ }
+ var slice2d = op({ slice2d_: slice2d_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Extracts a 3D slice from a 3D array starting at coordinates `begin` and
+ * is of size `size`. See `slice` for details.
+ */
+ function slice3d_(x, begin, size) {
+ var $x = convertToTensor(x, 'x', 'slice3d');
+ assert($x.rank === 3, function () { return "slice3d expects a rank-3 tensor, but got a rank-" + $x.rank + " tensor"; });
+ return slice($x, begin, size);
+ }
+ var slice3d = op({ slice3d_: slice3d_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Extracts a 4D slice from a 4D array starting at coordinates `begin` and
+ * is of size `size`. See `slice` for details.
+ */
+ function slice4d_(x, begin, size) {
+ var $x = convertToTensor(x, 'x', 'slice4d');
+ assert($x.rank === 4, function () { return "slice4d expects a rank-4 tensor, but got a rank-" + $x.rank + " tensor"; });
+ return slice($x, begin, size);
+ }
+ var slice4d = op({ slice4d_: slice4d_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the softmax normalized vector given the logits.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2, 3]);
+ *
+ * a.softmax().print(); // or tf.softmax(a)
+ * ```
+ *
+ * ```js
+ * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
+ *
+ * a.softmax().print(); // or tf.softmax(a)
+ * ```
+ *
+ * @param logits The logits array.
+ * @param dim The dimension softmax would be performed on. Defaults to `-1`
+ * which indicates the last dimension.
+ *
+ * @doc {heading: 'Operations', subheading: 'Normalization'}
+ */
+ function softmax_(logits, dim) {
+ if (dim === void 0) { dim = -1; }
+ var $logits = convertToTensor(logits, 'logits', 'softmax', 'float32');
+ if (dim === -1) {
+ dim = $logits.rank - 1;
+ }
+ if (dim !== $logits.rank - 1) {
+ throw Error('Softmax along a non-last dimension is not yet supported. ' +
+ ("Logits was rank " + $logits.rank + " and dim was " + dim));
+ }
+ var inputs = { logits: $logits };
+ var attrs = { dim: dim };
+ return ENGINE.runKernel(Softmax, inputs, attrs);
+ }
+ var softmax = op({ softmax_: softmax_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Fast Fourier transform.
+ *
+ * Computes the 1-dimensional discrete Fourier transform over the inner-most
+ * dimension of input.
+ *
+ * ```js
+ * const real = tf.tensor1d([1, 2, 3]);
+ * const imag = tf.tensor1d([1, 2, 3]);
+ * const x = tf.complex(real, imag);
+ *
+ * x.fft().print(); // tf.spectral.fft(x).print();
+ * ```
+ * @param input The complex input to compute an fft over.
+ *
+ * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
+ */
+ function fft_(input) {
+ assert(input.dtype === 'complex64', function () { return "The dtype for tf.spectral.fft() must be complex64 " +
+ ("but got " + input.dtype + "."); });
+ var inputs = { input: input };
+ return ENGINE.runKernel(FFT, inputs);
+ }
+ var fft = op({ fft_: fft_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Inverse fast Fourier transform.
+ *
+ * Computes the inverse 1-dimensional discrete Fourier transform over the
+ * inner-most dimension of input.
+ *
+ * ```js
+ * const real = tf.tensor1d([1, 2, 3]);
+ * const imag = tf.tensor1d([1, 2, 3]);
+ * const x = tf.complex(real, imag);
+ *
+ * x.ifft().print(); // tf.spectral.ifft(x).print();
+ * ```
+ * @param input The complex input to compute an ifft over.
+ *
+ * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
+ */
+ function ifft_(input) {
+ assert(input.dtype === 'complex64', function () { return "The dtype for tf.spectral.ifft() must be complex64 " +
+ ("but got " + input.dtype + "."); });
+ var inputs = { input: input };
+ return ENGINE.runKernel(IFFT, inputs);
+ }
+ var ifft = op({ ifft_: ifft_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Inversed real value input fast Fourier transform.
+ *
+ * Computes the 1-dimensional inversed discrete Fourier transform over the
+ * inner-most dimension of the real input.
+ *
+ * ```js
+ * const real = tf.tensor1d([1, 2, 3]);
+ * const imag = tf.tensor1d([0, 0, 0]);
+ * const x = tf.complex(real, imag);
+ *
+ * x.irfft().print();
+ * ```
+ * @param input The real value input to compute an irfft over.
+ *
+ * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
+ */
+ function irfft_(input) {
+ var innerDimensionSize = input.shape[input.shape.length - 1];
+ var batch = input.size / innerDimensionSize;
+ var ret;
+ if (innerDimensionSize <= 2) {
+ var complexInput = reshape(input, [batch, innerDimensionSize]);
+ ret = ifft(complexInput);
+ }
+ else {
+ // The length of unique components of the DFT of a real-valued signal
+ // is 2 * (input_len - 1)
+ var outputShape = [batch, 2 * (innerDimensionSize - 1)];
+ var realInput = reshape(real(input), [batch, innerDimensionSize]);
+ var imagInput = reshape(imag(input), [batch, innerDimensionSize]);
+ var realConjugate = reverse(slice(realInput, [0, 1], [batch, innerDimensionSize - 2]), 1);
+ var imagConjugate = mul(reverse(slice(imagInput, [0, 1], [batch, innerDimensionSize - 2]), 1), scalar(-1));
+ var r = concat([realInput, realConjugate], 1);
+ var i = concat([imagInput, imagConjugate], 1);
+ var complexInput = reshape(complex(r, i), [outputShape[0], outputShape[1]]);
+ ret = ifft(complexInput);
+ }
+ ret = real(ret);
+ // reshape the result if the input is 3D tensor.
+ if (input.rank === 3 && input.shape[0] !== 0) {
+ var temp = ret;
+ var batch_1 = input.shape[0];
+ ret = reshape(ret, [batch_1, ret.shape[0] / batch_1, ret.shape[1]]);
+ temp.dispose();
+ }
+ return ret;
+ }
+ var irfft = op({ irfft_: irfft_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Splits a `tf.Tensor` into sub tensors.
+ *
+ * If `numOrSizeSplits` is a number, splits `x` along dimension `axis`
+ * into `numOrSizeSplits` smaller tensors.
+ * Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`.
+ *
+ * If `numOrSizeSplits` is a number array, splits `x` into
+ * `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the
+ * same size as `x` except along dimension `axis` where the size is
+ * `numOrSizeSplits[i]`.
+ *
+ * ```js
+ * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);
+ * const [a, b] = tf.split(x, 2, 1);
+ * a.print();
+ * b.print();
+ *
+ * const [c, d, e] = tf.split(x, [1, 2, 1], 1);
+ * c.print();
+ * d.print();
+ * e.print();
+ * ```
+ *
+ * @param x The input tensor to split.
+ * @param numOrSizeSplits Either an integer indicating the number of
+ * splits along the axis or an array of integers containing the sizes of
+ * each output tensor along the axis. If a number then it must evenly divide
+ * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`.
+ * Can contain one -1 indicating that dimension is to be inferred.
+ * @param axis The dimension along which to split. Defaults to 0 (the first
+ * dim).
+ *
+ * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
+ */
+ function split_(x, numOrSizeSplits, axis) {
+ if (axis === void 0) { axis = 0; }
+ var $x = convertToTensor(x, 'x', 'split');
+ var inputs = { x: $x };
+ var attr = { numOrSizeSplits: numOrSizeSplits, axis: axis };
+ return ENGINE.runKernel(SplitV, inputs, attr);
+ }
+ var split = op({ split_: split_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Real value input fast Fourier transform.
+ *
+ * Computes the 1-dimensional discrete Fourier transform over the
+ * inner-most dimension of the real input.
+ *
+ * ```js
+ * const real = tf.tensor1d([1, 2, 3]);
+ *
+ * real.rfft().print();
+ * ```
+ * @param input The real value input to compute an rfft over.
+ *
+ * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
+ */
+ function rfft_(input, fftLength) {
+ assert(input.dtype === 'float32', function () { return "The dtype for rfft() must be real value but got " + input.dtype; });
+ var innerDimensionSize = input.shape[input.shape.length - 1];
+ var batch = input.size / innerDimensionSize;
+ var adjustedInput;
+ if (fftLength != null && fftLength < innerDimensionSize) {
+ // Need to crop
+ var begin = input.shape.map(function (v) { return 0; });
+ var size = input.shape.map(function (v) { return v; });
+ size[input.shape.length - 1] = fftLength;
+ adjustedInput = slice(input, begin, size);
+ innerDimensionSize = fftLength;
+ }
+ else if (fftLength != null && fftLength > innerDimensionSize) {
+ // Need to pad with zeros
+ var zerosShape = input.shape.map(function (v) { return v; });
+ zerosShape[input.shape.length - 1] = fftLength - innerDimensionSize;
+ adjustedInput = concat([input, zeros(zerosShape)], input.shape.length - 1);
+ innerDimensionSize = fftLength;
+ }
+ else {
+ adjustedInput = input;
+ }
+ // Complement the input with zero imaginary numbers.
+ var zerosInput = zerosLike(adjustedInput);
+ var complexInput = reshape(complex(adjustedInput, zerosInput), [batch, innerDimensionSize]);
+ var ret = fft(complexInput);
+ // Exclude complex conjugations. These conjugations are put symmetrically.
+ var half = Math.floor(innerDimensionSize / 2) + 1;
+ var realValues = real(ret);
+ var imagValues = imag(ret);
+ var realComplexConjugate = split(realValues, [half, innerDimensionSize - half], realValues.shape.length - 1);
+ var imagComplexConjugate = split(imagValues, [half, innerDimensionSize - half], imagValues.shape.length - 1);
+ var outputShape = adjustedInput.shape.slice();
+ outputShape[adjustedInput.shape.length - 1] = half;
+ return reshape(complex(realComplexConjugate[0], imagComplexConjugate[0]), outputShape);
+ }
+ var rfft = op({ rfft_: rfft_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 4, -1]);
+ *
+ * x.sqrt().print(); // or tf.sqrt(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function sqrt_(x) {
+ var $x = convertToTensor(x, 'x', 'sqrt', 'float32');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Sqrt, inputs);
+ }
+ var sqrt = op({ sqrt_: sqrt_ });
+
+ /**
+ * Returns (a - b) * (a - b) element-wise.
+ * Supports broadcasting.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 4, 3, 16]);
+ * const b = tf.tensor1d([1, 2, 9, 4]);
+ *
+ * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
+ * ```
+ *
+ * ```js
+ * // Broadcast squared difference a with b.
+ * const a = tf.tensor1d([2, 4, 6, 8]);
+ * const b = tf.scalar(5);
+ *
+ * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
+ * ```
+ *
+ * @param a The first tensor.
+ * @param b The second tensor. Must have the same type as `a`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Arithmetic'}
+ */
+ function squaredDifference_(a, b) {
+ var _a;
+ var $a = convertToTensor(a, 'a', 'squaredDifference');
+ var $b = convertToTensor(b, 'b', 'squaredDifference');
+ _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1];
+ assertAndGetBroadcastShape($a.shape, $b.shape);
+ var inputs = { a: $a, b: $b };
+ var attrs = {};
+ return ENGINE.runKernel(SquaredDifference, inputs, attrs);
+ }
+ var squaredDifference = op({ squaredDifference_: squaredDifference_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Removes dimensions of size 1 from the shape of a `tf.Tensor`.
+ *
+ * ```js
+ * const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]);
+ * x.squeeze().print();
+ * ```
+ *
+ * @param x The input tensor to be squeezed.
+ * @param axis An optional list of numbers. If specified, only
+ * squeezes the dimensions listed. The dimension index starts at 0. It
+ * is an error to squeeze a dimension that is not 1.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Transformations'}
+ */
+ function squeeze_(x, axis) {
+ var $x = convertToTensor(x, 'x', 'squeeze');
+ return reshape($x, squeezeShape($x.shape, axis).newShape);
+ }
+ var squeeze = op({ squeeze_: squeeze_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`.
+ *
+ * ```js
+ * const a = tf.tensor1d([1, 2]);
+ * const b = tf.tensor1d([3, 4]);
+ * const c = tf.tensor1d([5, 6]);
+ * tf.stack([a, b, c]).print();
+ * ```
+ *
+ * @param tensors A list of tensor objects with the same shape and dtype.
+ * @param axis The axis to stack along. Defaults to 0 (the first dim).
+ *
+ * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
+ */
+ function stack_(tensors, axis) {
+ if (axis === void 0) { axis = 0; }
+ var $tensors = convertToTensorArray(tensors, 'tensors', 'stack', 'string_or_numeric');
+ assert($tensors.length >= 1, function () { return 'Pass at least one tensor to tf.stack'; });
+ if ($tensors.length > 0) {
+ assert(axis <= $tensors[0].rank, function () { return 'Axis must be <= rank of the tensor'; });
+ }
+ var inputs = $tensors;
+ var attrs = { axis: axis };
+ return ENGINE.runKernel(Pack, inputs, attrs);
+ }
+ var stack = op({ stack_: stack_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha * x`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, 2, -1, -3]);
+ *
+ * x.step(.5).print(); // or tf.step(x, .5)
+ * ```
+ * @param x The input tensor.
+ * @param alpha The gradient when input is negative.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function step_(x, alpha) {
+ if (alpha === void 0) { alpha = 0.0; }
+ var $x = convertToTensor(x, 'x', 'step');
+ var inputs = { x: $x };
+ var attrs = { alpha: alpha };
+ return ENGINE.runKernel(Step, inputs, attrs);
+ }
+ var step = op({ step_: step_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Extracts a strided slice of a tensor.
+ *
+ * Roughly speaking, this op extracts a slice of size (end-begin)/stride from
+ * the given input tensor (x). Starting at the location specified by begin the
+ * slice continues by adding stride to the index until all dimensions are not
+ * less than end. Note that a stride can be negative, which causes a reverse
+ * slice.
+ *
+ * ```js
+ * const t = tf.tensor3d([1, 1, 1 ,2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
+ * [3, 2, 3]);
+ * t.stridedSlice([1, 0, 0], [2, 1, 3], [1, 1, 1]).print() // [[[3, 3, 3]]]
+ * t.stridedSlice([1, 0, 0], [2, 2, 3], [1, 1, 1]).print() // [[[3, 3, 3],
+ * // [4, 4, 4]]]
+ * t.stridedSlice([1, -1, 0], [2, -3, 3], [1, -1, 1]).print() // [[[4, 4, 4],
+ * // [3, 3, 3]]]
+ * ```
+ *
+ * @param x The tensor to stride slice.
+ * @param begin The coordinates to start the slice from.
+ * @param end: The coordinates to end the slice at.
+ * @param strides: The size of the slice.
+ * @param beginMask: If the ith bit of beginMask is set, begin[i] is ignored
+ * and the fullest possible range in that dimension is used instead.
+ * @param endMask: If the ith bit of endMask is set, end[i] is ignored
+ * and the fullest possible range in that dimension is used instead.
+ * @param shrinkAxisMask: a bitmask where bit i implies that
+ * the ith specification should shrink the dimensionality. begin and end must
+ * imply a slice of size 1 in the dimension.
+ *
+ * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
+ */
+ function stridedSlice_(x, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
+ if (beginMask === void 0) { beginMask = 0; }
+ if (endMask === void 0) { endMask = 0; }
+ if (ellipsisMask === void 0) { ellipsisMask = 0; }
+ if (newAxisMask === void 0) { newAxisMask = 0; }
+ if (shrinkAxisMask === void 0) { shrinkAxisMask = 0; }
+ var $x = convertToTensor(x, 'x', 'stridedSlice', 'string_or_numeric');
+ var inputs = { x: $x };
+ var attrs = {
+ begin: begin,
+ end: end,
+ strides: strides,
+ beginMask: beginMask,
+ endMask: endMask,
+ ellipsisMask: ellipsisMask,
+ newAxisMask: newAxisMask,
+ shrinkAxisMask: shrinkAxisMask
+ };
+ return ENGINE.runKernel(StridedSlice, inputs, attrs);
+ }
+ var stridedSlice = op({ stridedSlice_: stridedSlice_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes tan of the input `tf.Tensor` element-wise, `tan(x)`
+ *
+ * ```js
+ * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
+ *
+ * x.tan().print(); // or tf.tan(x)
+ * ```
+ * @param x The input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Basic math'}
+ */
+ function tan_(x) {
+ var $x = convertToTensor(x, 'x', 'tan', 'float32');
+ var inputs = { x: $x };
+ return ENGINE.runKernel(Tan, inputs);
+ }
+ var tan = op({ tan_: tan_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates rank-1 `tf.Tensor` with the provided values, shape and dtype.
+ *
+ * The same functionality can be achieved with `tf.tensor`, but in general
+ * we recommend using `tf.tensor1d` as it makes the code more readable.
+ *
+ * ```js
+ * tf.tensor1d([1, 2, 3]).print();
+ * ```
+ *
+ * @param values The values of the tensor. Can be array of numbers,
+ * or a `TypedArray`.
+ * @param dtype The data type.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function tensor1d(values, dtype) {
+ assertNonNull(values);
+ var inferredShape = inferShape(values, dtype);
+ if (inferredShape.length !== 1) {
+ throw new Error('tensor1d() requires values to be a flat/TypedArray');
+ }
+ var shape = null;
+ return makeTensor(values, shape, inferredShape, dtype);
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates rank-2 `tf.Tensor` with the provided values, shape and dtype.
+ *
+ * The same functionality can be achieved with `tf.tensor`, but in general
+ * we recommend using `tf.tensor2d` as it makes the code more readable.
+ *
+ * ```js
+ * // Pass a nested array.
+ * tf.tensor2d([[1, 2], [3, 4]]).print();
+ * ```
+ * ```js
+ * // Pass a flat array and specify a shape.
+ * tf.tensor2d([1, 2, 3, 4], [2, 2]).print();
+ * ```
+ *
+ * @param values The values of the tensor. Can be nested array of numbers,
+ * or a flat array, or a `TypedArray`.
+ * @param shape The shape of the tensor. If not provided, it is inferred from
+ * `values`.
+ * @param dtype The data type.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function tensor2d(values, shape, dtype) {
+ assertNonNull(values);
+ if (shape != null && shape.length !== 2) {
+ throw new Error('tensor2d() requires shape to have two numbers');
+ }
+ var inferredShape = inferShape(values, dtype);
+ if (inferredShape.length !== 2 && inferredShape.length !== 1) {
+ throw new Error('tensor2d() requires values to be number[][] or flat/TypedArray');
+ }
+ if (inferredShape.length === 1 && shape == null) {
+ throw new Error('tensor2d() requires shape to be provided when `values` ' +
+ 'are a flat/TypedArray');
+ }
+ return makeTensor(values, shape, inferredShape, dtype);
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates rank-4 `tf.Tensor` with the provided values, shape and dtype.
+ *
+ * The same functionality can be achieved with `tf.tensor`, but in general
+ * we recommend using `tf.tensor4d` as it makes the code more readable.
+ *
+ * ```js
+ * // Pass a nested array.
+ * tf.tensor4d([[[[1], [2]], [[3], [4]]]]).print();
+ * ```
+ * ```js
+ * // Pass a flat array and specify a shape.
+ * tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]).print();
+ * ```
+ *
+ * @param values The values of the tensor. Can be nested array of numbers,
+ * or a flat array, or a `TypedArray`.
+ * @param shape The shape of the tensor. Optional. If not provided,
+ * it is inferred from `values`.
+ * @param dtype The data type.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function tensor4d(values, shape, dtype) {
+ assertNonNull(values);
+ if (shape != null && shape.length !== 4) {
+ throw new Error('tensor4d() requires shape to have four numbers');
+ }
+ var inferredShape = inferShape(values, dtype);
+ if (inferredShape.length !== 4 && inferredShape.length !== 1) {
+ throw new Error('tensor4d() requires values to be number[][][][] or flat/TypedArray');
+ }
+ if (inferredShape.length === 1 && shape == null) {
+ throw new Error('tensor4d() requires shape to be provided when `values` ' +
+ 'are a flat array');
+ }
+ return makeTensor(values, shape, inferredShape, dtype);
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates rank-5 `tf.Tensor` with the provided values, shape and dtype.
+ *
+ * The same functionality can be achieved with `tf.tensor`, but in general
+ * we recommend using `tf.tensor5d` as it makes the code more readable.
+ *
+ * ```js
+ * // Pass a nested array.
+ * tf.tensor5d([[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]).print();
+ * ```
+ * ```js
+ * // Pass a flat array and specify a shape.
+ * tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]).print();
+ * ```
+ *
+ * @param values The values of the tensor. Can be nested array of numbers,
+ * or a flat array, or a `TypedArray`.
+ * @param shape The shape of the tensor. Optional. If not provided,
+ * it is inferred from `values`.
+ * @param dtype The data type.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function tensor5d(values, shape, dtype) {
+ assertNonNull(values);
+ if (shape != null && shape.length !== 5) {
+ throw new Error('tensor5d() requires shape to have five numbers');
+ }
+ var inferredShape = inferShape(values, dtype);
+ if (inferredShape.length !== 5 && inferredShape.length !== 1) {
+ throw new Error('tensor5d() requires values to be ' +
+ 'number[][][][][] or flat/TypedArray');
+ }
+ if (inferredShape.length === 1 && shape == null) {
+ throw new Error('tensor5d() requires shape to be provided when `values` ' +
+ 'are a flat array');
+ }
+ return makeTensor(values, shape, inferredShape, dtype);
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates rank-6 `tf.Tensor` with the provided values, shape and dtype.
+ *
+ * The same functionality can be achieved with `tf.tensor`, but in general
+ * we recommend using `tf.tensor6d` as it makes the code more readable.
+ *
+ * ```js
+ * // Pass a nested array.
+ * tf.tensor6d([[[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]]).print();
+ * ```
+ * ```js
+ * // Pass a flat array and specify a shape.
+ * tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 2, 2, 2, 1]).print();
+ * ```
+ *
+ * @param values The values of the tensor. Can be nested array of numbers,
+ * or a flat array, or a `TypedArray`.
+ * @param shape The shape of the tensor. Optional. If not provided,
+ * it is inferred from `values`.
+ * @param dtype The data type.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function tensor6d(values, shape, dtype) {
+ assertNonNull(values);
+ if (shape != null && shape.length !== 6) {
+ throw new Error('tensor6d() requires shape to have six numbers');
+ }
+ var inferredShape = inferShape(values, dtype);
+ if (inferredShape.length !== 6 && inferredShape.length !== 1) {
+ throw new Error('tensor6d() requires values to be number[][][][][][] or ' +
+ 'flat/TypedArray');
+ }
+ if (inferredShape.length === 1 && shape == null) {
+ throw new Error('tensor6d() requires shape to be provided when `values` ' +
+ 'are a flat array');
+ }
+ shape = shape ||
+ inferredShape;
+ return makeTensor(values, shape, inferredShape, dtype);
+ }
+
+ /**
+ * Finds the values and indices of the `k` largest entries along the last
+ * dimension.
+ *
+ * If the input is a vector (rank=1), finds the k largest entries in the vector
+ * and outputs their values and indices as vectors. Thus values[j] is the j-th
+ * largest entry in input, and its index is indices[j].
+ * For higher rank inputs, computes the top k entries along the last dimension.
+ *
+ * If two elements are equal, the lower-index element appears first.
+ *
+ * ```js
+ * const a = tf.tensor2d([[1, 5], [4, 3]]);
+ * const {values, indices} = tf.topk(a);
+ * values.print();
+ * indices.print();
+ * ```
+ * @param x 1-D or higher `tf.Tensor` with last dimension being at least `k`.
+ * @param k Number of top elements to look for along the last dimension.
+ * @param sorted If true, the resulting `k` elements will be sorted by the
+ * values in descending order.
+ *
+ * @doc {heading: 'Operations', subheading: 'Evaluation'}
+ */
+ function topk_(x, k, sorted) {
+ if (k === void 0) { k = 1; }
+ if (sorted === void 0) { sorted = true; }
+ var $x = convertToTensor(x, 'x', 'topk');
+ if ($x.rank === 0) {
+ throw new Error('topk() expects the input to be of rank 1 or higher');
+ }
+ var lastDim = $x.shape[$x.shape.length - 1];
+ if (k < 0) {
+ throw new Error("'k' passed to topk() must be >= 0 but got " + k);
+ }
+ if (k > lastDim) {
+ throw new Error("'k' passed to topk() must be <= the last dimension (" + lastDim + ") " +
+ ("but got " + k));
+ }
+ var inputs = { x: $x };
+ var attrs = { k: k, sorted: sorted };
+ var _a = __read(ENGINE.runKernel(TopK, inputs, attrs), 2), values = _a[0], indices = _a[1];
+ return { values: values, indices: indices };
+ }
+ var topk = op({ topk_: topk_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a `tf.Tensor` with values sampled from a truncated normal
+ * distribution.
+ *
+ * ```js
+ * tf.truncatedNormal([2, 2]).print();
+ * ```
+ *
+ * The generated values follow a normal distribution with specified mean and
+ * standard deviation, except that values whose magnitude is more than 2
+ * standard deviations from the mean are dropped and re-picked.
+ *
+ * @param shape An array of integers defining the output tensor shape.
+ * @param mean The mean of the normal distribution.
+ * @param stdDev The standard deviation of the normal distribution.
+ * @param dtype The data type of the output tensor.
+ * @param seed The seed for the random number generator.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function truncatedNormal_(shape, mean, stdDev, dtype, seed) {
+ if (mean === void 0) { mean = 0; }
+ if (stdDev === void 0) { stdDev = 1; }
+ if (dtype != null && dtype === 'bool') {
+ throw new Error("Unsupported data type $ { dtype }");
+ }
+ var randGauss = new MPRandGauss(mean, stdDev, dtype, true /* truncated */, seed);
+ var res = buffer(shape, dtype);
+ for (var i = 0; i < res.values.length; i++) {
+ res.values[i] = randGauss.nextValue();
+ }
+ return res.toTensor();
+ }
+ var truncatedNormal = op({ truncatedNormal_: truncatedNormal_ });
+
+ /**
+ * Finds unique elements along an axis of a tensor.
+ *
+ * It returns a tensor `values` containing all of the unique elements along the
+ * `axis` of the given tensor `x` in the same order that they occur along the
+ * `axis` in `x`; `x` does not need to be sorted. It also returns a tensor
+ * `indices` the same size as the number of the elements in `x` along the `axis`
+ * dimension. It contains the index in the unique output `values`.
+ *
+ * ```js
+ * // A 1-D tensor
+ * const a = tf.tensor1d([1, 1, 2, 4, 4, 4, 7, 8, 8]);
+ * const {values, indices} = tf.unique(a);
+ * values.print(); // [1, 2, 4, 7, 8,]
+ * indices.print(); // [0, 0, 1, 2, 2, 2, 3, 4, 4]
+ * ```
+ *
+ * ```js
+ * // A 2-D tensor with axis=0
+ * //
+ * // 'a' is: [[1, 0, 0],
+ * // [1, 0, 0],
+ * // [2, 0, 0]]
+ * const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
+ * const {values, indices} = tf.unique(a, 0)
+ * values.print(); // [[1, 0, 0],
+ * // [2, 0, 0]]
+ * indices.print(); // [0, 0, 1]
+ * ```
+ *
+ * ```js
+ * // A 2-D tensor with axis=1
+ * //
+ * // 'a' is: [[1, 0, 0],
+ * // [1, 0, 0],
+ * // [2, 0, 0]]
+ * const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
+ * const {values, indices} = tf.unique(a, 1)
+ * values.print(); // [[1, 0],
+ * // [1, 0],
+ * // [2, 0]]
+ * indices.print(); // [0, 1, 1]
+ * ```
+ * @param x A tensor (int32, string, bool).
+ * @param axis The axis of the tensor to find the unique elements.
+ * @returns [uniqueElements, indices] (see above for details)
+ *
+ * @doc {heading: 'Operations', subheading: 'Evaluation'}
+ */
+ function unique_(x, axis) {
+ if (axis === void 0) { axis = 0; }
+ var $x = convertToTensor(x, 'x', 'unique', 'string_or_numeric');
+ assert($x.rank > 0, function () { return 'The input tensor must be at least 1D'; });
+ var inputs = { x: $x };
+ var attrs = { axis: axis };
+ var _a = __read(ENGINE.runKernel(Unique, inputs, attrs), 2), values = _a[0], indices = _a[1];
+ return { values: values, indices: indices };
+ }
+ var unique = op({ unique_: unique_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the sum along segments of a `tf.Tensor`.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3, 4]);
+ * const segmentIds = tf.tensor1d([1, 2, 0, 1], 'int32');
+ * const numSegments = 3;
+ *
+ * x.unsortedSegmentSum(segmentIds, numSegments).print()
+ * //or tf.unsortedSegmentSum(x, segmentIds, numSegments)
+ * ```
+ * @param x The `tf.Tensor` that will be summed along its segments.
+ * @param segmentIds A `tf.Tensor1D` whose rank is equal to the rank of `x`'s
+ * dimension along the `axis`. Maps each element of `x` to a segment.
+ * @param numSegments The number of distinct `segmentIds`.
+ *
+ * @doc {heading: 'Operations', subheading: 'Segment'}
+ */
+ function unsortedSegmentSum_(x, segmentIds, numSegments) {
+ var $x = convertToTensor(x, 'x', 'unsortedSegmentSum');
+ var $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32');
+ assert(isInt(numSegments), function () { return 'numSegments must be of dtype int'; });
+ var inputs = { x: $x, segmentIds: $segmentIds };
+ var attrs = { numSegments: numSegments };
+ return ENGINE.runKernel(UnsortedSegmentSum, inputs, attrs);
+ }
+ var unsortedSegmentSum = op({ unsortedSegmentSum_: unsortedSegmentSum_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s.
+ *
+ * ```js
+ * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+ *
+ * tf.unstack(a).forEach(tensor => tensor.print());
+ * ```
+ *
+ * @param x A tensor object.
+ * @param axis The axis to unstack along. Defaults to 0 (the first dim).
+ *
+ * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
+ */
+ function unstack_(x, axis) {
+ if (axis === void 0) { axis = 0; }
+ var $x = convertToTensor(x, 'x', 'unstack', 'string_or_numeric');
+ assert(axis >= -$x.shape.length && axis < $x.shape.length, function () { return "Axis = " + axis + " is not in [-" + $x.shape.length + ", " + $x.shape.length + ")"; });
+ var inputs = { value: $x };
+ var attrs = { axis: axis };
+ return ENGINE.runKernel(Unpack, inputs, attrs);
+ }
+ var unstack = op({ unstack_: unstack_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a new variable with the provided initial value.
+ * ```js
+ * const x = tf.variable(tf.tensor([1, 2, 3]));
+ * x.assign(tf.tensor([4, 5, 6]));
+ *
+ * x.print();
+ * ```
+ *
+ * @param initialValue Initial value for the tensor.
+ * @param trainable If true, optimizers are allowed to update it.
+ * @param name Name of the variable. Defaults to a unique id.
+ * @param dtype If set, initialValue will be converted to the given type.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Creation'}
+ */
+ function variable(initialValue, trainable, name, dtype) {
+ if (trainable === void 0) { trainable = true; }
+ return ENGINE.makeVariable(initialValue, trainable, name, dtype);
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function whereImpl(condShape, condVals) {
+ var indices = [];
+ for (var i = 0; i < condVals.length; i++) {
+ if (condVals[i]) {
+ indices.push(i);
+ }
+ }
+ var inBuffer = buffer(condShape, 'int32');
+ var out = buffer([indices.length, condShape.length], 'int32');
+ for (var i = 0; i < indices.length; i++) {
+ var loc = inBuffer.indexToLoc(indices[i]);
+ var offset = i * condShape.length;
+ out.values.set(loc, offset);
+ }
+ return out.toTensor();
+ }
+
+ /**
+ * Returns the coordinates of true elements of condition.
+ *
+ * The coordinates are returned in a 2-D tensor where the first dimension (rows)
+ * represents the number of true elements, and the second dimension (columns)
+ * represents the coordinates of the true elements. Keep in mind, the shape of
+ * the output tensor can vary depending on how many true values there are in
+ * input. Indices are output in row-major order. The resulting tensor has the
+ * shape `[numTrueElems, condition.rank]`.
+ *
+ * This is analogous to calling the python `tf.where(cond)` without an x or y.
+ *
+ * ```js
+ * const cond = tf.tensor1d([false, false, true], 'bool');
+ * const result = await tf.whereAsync(cond);
+ * result.print();
+ * ```
+ *
+ * @doc {heading: 'Operations', subheading: 'Logical'}
+ */
+ function whereAsync_(condition) {
+ return __awaiter(this, void 0, void 0, function () {
+ var $condition, vals, res;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ $condition = convertToTensor(condition, 'condition', 'whereAsync', 'bool');
+ return [4 /*yield*/, $condition.data()];
+ case 1:
+ vals = _a.sent();
+ res = whereImpl($condition.shape, vals);
+ if (condition !== $condition) {
+ $condition.dispose();
+ }
+ return [2 /*return*/, res];
+ }
+ });
+ });
+ }
+ var whereAsync = whereAsync_;
+
+ /**
+ * Apply boolean mask to tensor.
+ *
+ * ```js
+ * const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
+ * const mask = tf.tensor1d([1, 0, 1], 'bool');
+ * const result = await tf.booleanMaskAsync(tensor, mask);
+ * result.print();
+ * ```
+ *
+ * @param tensor N-D tensor.
+ * @param mask K-D boolean tensor, K <= N and K must be known statically.
+ * @param axis A 0-D int Tensor representing the axis in tensor to mask from.
+ * By default, axis is 0 which will mask from the first dimension.
+ * Otherwise K + axis <= N.
+ *
+ * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
+ */
+ function booleanMaskAsync_(tensor, mask, axis) {
+ return __awaiter(this, void 0, void 0, function () {
+ var $tensor, $mask, axisFrom, maskDim, tensorShape, leadingSize, i, targetTensorShape, reshapedTensor, reshapedMask, positivePositions, indices, res;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ $tensor = convertToTensor(tensor, 'tensor', 'boolMask');
+ $mask = convertToTensor(mask, 'mask', 'boolMask', 'bool');
+ axisFrom = axis == null ? 0 : axis;
+ maskDim = $mask.rank;
+ tensorShape = $tensor.shape;
+ assert(maskDim > 0, function () { return 'mask cannot be scalar'; });
+ assertShapesMatch(tensorShape.slice(axisFrom, axisFrom + maskDim), $mask.shape, "mask's shape must match the first K dimensions of tensor's shape,");
+ leadingSize = 1;
+ for (i = axisFrom; i < axisFrom + maskDim; i++) {
+ leadingSize *= tensorShape[i];
+ }
+ targetTensorShape = tensorShape.slice(0, axisFrom)
+ .concat([leadingSize], tensorShape.slice(axisFrom + maskDim));
+ reshapedTensor = reshape($tensor, targetTensorShape);
+ reshapedMask = reshape($mask, [-1]);
+ return [4 /*yield*/, whereAsync(reshapedMask)];
+ case 1:
+ positivePositions = _a.sent();
+ indices = squeeze(positivePositions, [1]);
+ res = gather(reshapedTensor, indices, axisFrom);
+ // Ensure no memory leak.
+ if (tensor !== $tensor) {
+ $tensor.dispose();
+ }
+ if (mask !== $mask) {
+ $mask.dispose();
+ }
+ indices.dispose();
+ reshapedTensor.dispose();
+ reshapedMask.dispose();
+ positivePositions.dispose();
+ return [2 /*return*/, res];
+ }
+ });
+ });
+ }
+ var booleanMaskAsync = booleanMaskAsync_;
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the norm of scalar, vectors, and matrices.
+ * This function can compute several different vector norms (the 1-norm, the
+ * Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0)
+ * and matrix norms (Frobenius, 1-norm, and inf-norm).
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 3, 4]);
+ *
+ * x.norm().print(); // or tf.norm(x)
+ * ```
+ *
+ * @param x The input array.
+ * @param ord Optional. Order of the norm. Supported norm types are
+ * following:
+ *
+ * | ord | norm for matrices | norm for vectors
+ * |------------|---------------------------|---------------------
+ * |'euclidean' |Frobenius norm |2-norm
+ * |'fro' |Frobenius norm |
+ * |Infinity |max(sum(abs(x), axis=1)) |max(abs(x))
+ * |-Infinity |min(sum(abs(x), axis=1)) |min(abs(x))
+ * |1 |max(sum(abs(x), axis=0)) |sum(abs(x))
+ * |2 | |sum(abs(x)^2)^1/2*
+ *
+ * @param axis Optional. If axis is null (the default), the input is
+ * considered a vector and a single vector norm is computed over the entire
+ * set of values in the Tensor, i.e. norm(x, ord) is equivalent
+ * to norm(x.reshape([-1]), ord). If axis is a integer, the input
+ * is considered a batch of vectors, and axis determines the axis in x
+ * over which to compute vector norms. If axis is a 2-tuple of integer it is
+ * considered a batch of matrices and axis determines the axes in NDArray
+ * over which to compute a matrix norm.
+ * @param keepDims Optional. If true, the norm have the same dimensionality
+ * as the input.
+ *
+ * @doc {heading: 'Operations', subheading: 'Matrices'}
+ */
+ function norm_(x, ord, axis, keepDims) {
+ if (ord === void 0) { ord = 'euclidean'; }
+ if (axis === void 0) { axis = null; }
+ if (keepDims === void 0) { keepDims = false; }
+ x = convertToTensor(x, 'x', 'norm');
+ var norm = normImpl(x, ord, axis);
+ var keepDimsShape = norm.shape;
+ if (keepDims) {
+ var axes = parseAxisParam(axis, x.shape);
+ keepDimsShape = expandShapeToKeepDim(norm.shape, axes);
+ }
+ return reshape(norm, keepDimsShape);
+ }
+ function normImpl(x, p, axis) {
+ if (axis === void 0) { axis = null; }
+ if (x.rank === 0) {
+ return abs(x);
+ }
+ // consider vector when no axis is specified
+ if (x.rank !== 1 && axis === null) {
+ return normImpl(reshape(x, [-1]), p, axis);
+ }
+ // vector
+ if (x.rank === 1 || typeof axis === 'number' ||
+ Array.isArray(axis) && axis.length === 1) {
+ if (p === 1) {
+ return sum(abs(x), axis);
+ }
+ if (p === Infinity) {
+ return max(abs(x), axis);
+ }
+ if (p === -Infinity) {
+ return min(abs(x), axis);
+ }
+ if (p === 'euclidean' || p === 2) {
+ // norm(x, 2) = sum(abs(xi) ^ 2) ^ 1/2
+ return sqrt(sum(pow(abs(x), scalar(2, 'int32')), axis));
+ }
+ throw new Error("Error in norm: invalid ord value: " + p);
+ }
+ // matrix (assumption axis[0] < axis[1])
+ if (Array.isArray(axis) && axis.length === 2) {
+ if (p === 1) {
+ return max(sum(abs(x), axis[0]), axis[1] - 1);
+ }
+ if (p === Infinity) {
+ return max(sum(abs(x), axis[1]), axis[0]);
+ }
+ if (p === -Infinity) {
+ return min(sum(abs(x), axis[1]), axis[0]);
+ }
+ if (p === 'fro' || p === 'euclidean') {
+ // norm(x) = sqrt(sum(pow(x, 2)))
+ return sqrt(sum(square(x), axis));
+ }
+ throw new Error("Error in norm: invalid ord value: " + p);
+ }
+ throw new Error("Error in norm: invalid axis: " + axis);
+ }
+ var norm = op({ norm_: norm_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Compute the moving average of a variable.
+ *
+ * Without zeroDebias, the moving average operation is defined by:
+ * `v += delta`
+ * where
+ * `delta = (1 - decay) * (x - v)`
+ *
+ * With zeroDebias (default), the `delta` term is scaled to debias the
+ * effect of the (assumed) zero-initialization of `v`.
+ * `delta /= (1 - decay ^ step)`
+ *
+ * For more details on the zero-debiasing algorithm, see:
+ * https://arxiv.org/abs/1412.6980
+ *
+ * Note that this function is completely stateless and does not keep track of
+ * step count. The step count needs to be maintained by the caller and passed
+ * in as `step`.
+ *
+ * @param v The current moving average value.
+ * @param x New input value, must have the same shape and dtype as `v`.
+ * @param decay The decay factor. Typical values are 0.95 and 0.99.
+ * @param step Step count.
+ * @param zeroDebias: Whether zeroDebias is to be performed (default: `true`).
+ * @returns The new moving average value.
+ *
+ * @doc {heading: 'Operations', subheading: 'Moving Average'}
+ */
+ function movingAverage_(v, x, decay, step, zeroDebias) {
+ if (zeroDebias === void 0) { zeroDebias = true; }
+ var $v = convertToTensor(v, 'v', 'movingAverage');
+ var $x = convertToTensor(x, 'x', 'movingAverage');
+ var $decay = convertToTensor(decay, 'decay', 'movingAverage');
+ assertTypesMatch($v, $x);
+ assert(arraysEqual($v.shape, $x.shape), function () { return 'Shape mismatch in v and x'; });
+ var one = scalar(1);
+ var oneMinusDecay = sub(one, $decay);
+ var update = mul(sub($x, $v), oneMinusDecay);
+ if (zeroDebias) {
+ assert(step != null, function () { return 'When using zeroDebias: true, step is required.'; });
+ var $step = convertToTensor(step, 'step', 'movingAverage');
+ update = div(update, sub(one, pow($decay, $step)));
+ }
+ return add($v, update);
+ }
+ var movingAverage = op({ movingAverage_: movingAverage_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates a new tensor by applying sparse updates to individual
+ * values or slices within a zero tensor of the given shape tensor according to
+ * indices. This operator is the inverse of the `tf.gatherND` operator which
+ * extracts values or slices from a given tensor.
+ *
+ * ```js
+ * const indices = tf.tensor2d([4, 3, 1, 7], [4, 1], 'int32');
+ * const updates = tf.tensor1d([9, 10, 11, 12]);
+ * const shape = [8];
+ * tf.scatterND(indices, updates, shape).print() //[0, 11, 0, 10, 9, 0, 0, 12]
+ * ```
+ *
+ * @param indices The tensor contains the indices into the output tensor.
+ * @param updates The tensor contains the value for the indices.
+ * @param shape: The shape of the output tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
+ */
+ function scatterND_(indices, updates, shape) {
+ var $indices = convertToTensor(indices, 'indices', 'scatterND', 'int32');
+ var $updates = convertToTensor(updates, 'updates', 'scatterND');
+ validateInput$1($updates, $indices, shape);
+ var inputs = { indices: $indices, updates: $updates };
+ var attrs = { shape: shape };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ return ENGINE.runKernel(ScatterNd, inputs, attrs);
+ }
+ var scatterND = op({ scatterND_: scatterND_ });
+
+ /**
+ * Validate sparseToDense inputs.
+ *
+ * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.
+ * sparseIndices[i] contains the complete index where sparseValues[i] will be
+ * placed.
+ * @param sparseValues A 0-D or 1-D Tensor. Values
+ * corresponding to each row of sparseIndices, or a scalar value to be used for
+ * all sparse indices.
+ * @param outputShape number[]. Shape of the dense output tensor.
+ * @param validateIndices boolean. indice validation is not supported, error
+ * will be thrown if it is set.
+ */
+ function validateInput(sparseIndices, sparseValues, outputShape, defaultValues) {
+ if (sparseIndices.dtype !== 'int32') {
+ throw new Error('tf.sparseToDense() expects the indices to be int32 type,' +
+ (" but the dtype was " + sparseIndices.dtype + "."));
+ }
+ if (sparseIndices.rank > 2) {
+ throw new Error('sparseIndices should be a scalar, vector, or matrix,' +
+ (" but got shape " + sparseIndices.shape + "."));
+ }
+ var numElems = sparseIndices.rank > 0 ? sparseIndices.shape[0] : 1;
+ var numDims = sparseIndices.rank > 1 ? sparseIndices.shape[1] : 1;
+ if (outputShape.length !== numDims) {
+ throw new Error('outputShape has incorrect number of elements:,' +
+ (" " + outputShape.length + ", should be: " + numDims + "."));
+ }
+ var numValues = sparseValues.size;
+ if (!(sparseValues.rank === 0 ||
+ sparseValues.rank === 1 && numValues === numElems)) {
+ throw new Error('sparseValues has incorrect shape ' +
+ (sparseValues.shape + ", should be [] or [" + numElems + "]"));
+ }
+ if (sparseValues.dtype !== defaultValues.dtype) {
+ throw new Error('sparseValues.dtype must match defaultValues.dtype');
+ }
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Converts a sparse representation into a dense tensor.
+ *
+ * Builds an array dense with shape outputShape such that:
+ *
+ * // If sparseIndices is scalar
+ * dense[i] = (i == sparseIndices ? sparseValues : defaultValue)
+ *
+ * // If sparseIndices is a vector, then for each i
+ * dense[sparseIndices[i]] = sparseValues[i]
+ *
+ * // If sparseIndices is an n by d matrix, then for each i in [0, n)
+ * dense[sparseIndices[i][0], ..., sparseIndices[i][d-1]] = sparseValues[i]
+ * All other values in dense are set to defaultValue. If sparseValues is a
+ * scalar, all sparse indices are set to this single value.
+ *
+ * If indices are repeated the final value is summed over all values for those
+ * indices.
+ *
+ * ```js
+ * const indices = tf.tensor1d([4, 5, 6, 1, 2, 3], 'int32');
+ * const values = tf.tensor1d([10, 11, 12, 13, 14, 15], 'float32');
+ * const shape = [8];
+ * tf.sparseToDense(indices, values, shape).print();
+ * ```
+ *
+ * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.
+ * sparseIndices[i] contains the complete index where sparseValues[i] will be
+ * placed.
+ * @param sparseValues A 0-D or 1-D Tensor. Values
+ * corresponding to each row of sparseIndices, or a scalar value to be used for
+ * all sparse indices.
+ * @param outputShape Shape of the dense output tensor. the type is inferred.
+ * @param defaultValue Scalar. Value to set for indices not specified in
+ * sparseIndices. Defaults to zero.
+ *
+ * @doc {heading: 'Operations', subheading: 'Normalization'}
+ */
+ function sparseToDense_(sparseIndices, sparseValues, outputShape, defaultValue) {
+ if (defaultValue === void 0) { defaultValue = 0; }
+ var $sparseIndices = convertToTensor(sparseIndices, 'sparseIndices', 'sparseToDense', 'int32');
+ var $sparseValues = convertToTensor(sparseValues, 'sparseValues', 'sparseToDense');
+ var $defaultValue = convertToTensor(defaultValue, 'defaultValue', 'sparseToDense', $sparseValues.dtype);
+ validateInput($sparseIndices, $sparseValues, outputShape, $defaultValue);
+ var inputs = {
+ sparseIndices: $sparseIndices,
+ sparseValues: $sparseValues,
+ defaultValue: $defaultValue
+ };
+ var attrs = { outputShape: outputShape };
+ return ENGINE.runKernel(SparseToDense, inputs, attrs);
+ }
+ var sparseToDense = op({ sparseToDense_: sparseToDense_ });
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Gather slices from input tensor into a Tensor with shape specified by
+ * `indices`.
+ *
+ * `indices` is an K-dimensional integer tensor, best thought of as a
+ * (K-1)-dimensional tensor of indices into input, where each element defines a
+ * slice of input:
+ * output[\\(i_0, ..., i_{K-2}\\)] = input[indices[\\(i_0, ..., i_{K-2}\\)]]
+ *
+ * Whereas in `tf.gather`, `indices` defines slices into the first dimension of
+ * input, in `tf.gatherND`, `indices` defines slices into the first N dimensions
+ * of input, where N = indices.shape[-1].
+ *
+ * The last dimension of indices can be at most the rank of input:
+ * indices.shape[-1] <= input.rank
+ *
+ * The last dimension of `indices` corresponds to elements
+ * (if indices.shape[-1] == input.rank) or slices
+ * (if indices.shape[-1] < input.rank) along dimension indices.shape[-1] of
+ * input.
+ * The output tensor has shape
+ * indices.shape[:-1] + input.shape[indices.shape[-1]:]
+ *
+ * Note that on CPU, if an out of bound index is found, an error is returned. On
+ * GPU, if an out of bound index is found, a 0 is stored in the corresponding
+ * output value.
+ *
+ * ```js
+ * const indices = tf.tensor2d([0, 1, 1, 0], [2,2], 'int32');
+ * const input = tf.tensor2d([9, 10, 11, 12], [2, 2]);
+ * tf.gatherND(input, indices).print() // [10, 11]
+ * ```
+ *
+ * @param x The tensor from which to gather values.
+ * @param indices Index tensor, must be of type int32.
+ *
+ * @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
+ */
+ function gatherND_(x, indices) {
+ var $indices = convertToTensor(indices, 'indices', 'gatherND', 'int32');
+ var $x = convertToTensor(x, 'x', 'gatherND', 'string_or_numeric');
+ var inputs = { params: $x, indices: $indices };
+ return ENGINE.runKernel(GatherNd, inputs);
+ }
+ var gatherND = op({ gatherND_: gatherND_ });
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Normalize noise shape based on provided tensor and noise shape.
+ *
+ * @param x Tensor.
+ * @param noiseShape The shape for the randomly generated keep/drop flags, as
+ * an array of numbers. Optional.
+ * @returns Normalized noise shape.
+ */
+ function getNoiseShape(x, noiseShape) {
+ if (noiseShape == null) {
+ return x.shape.slice();
+ }
+ if (arraysEqual(x.shape, noiseShape)) {
+ return noiseShape;
+ }
+ if (x.shape.length === noiseShape.length) {
+ var newDimension = [];
+ for (var i = 0; i < x.shape.length; i++) {
+ if (noiseShape[i] == null && x.shape[i] != null) {
+ newDimension.push(x.shape[i]);
+ }
+ else {
+ newDimension.push(noiseShape[i]);
+ }
+ }
+ return newDimension;
+ }
+ return noiseShape;
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes dropout.
+ *
+ * ```js
+ * const x = tf.tensor1d([1, 2, 2, 1]);
+ * const rate = 0.75;
+ * const output = tf.dropout(x, rate);
+ * output.print();
+ * ```
+ *
+ * @param x A floating point Tensor or TensorLike.
+ * @param rate A float in the range [0, 1). The probability that each element
+ * of x is discarded.
+ * @param noiseShape An array of numbers of type int32, representing the
+ * shape for randomly generated keep/drop flags. If the noiseShape has null
+ * value, it will be automatically replaced with the x's relative dimension
+ * size. Optional.
+ * @param seed Used to create random seeds. Optional.
+ * @returns A Tensor of the same shape of x.
+ *
+ * @doc {heading: 'Operations', subheading: 'Dropout'}
+ */
+ function dropout_(x, rate, noiseShape, seed) {
+ var $x = convertToTensor(x, 'x', 'dropout');
+ assert($x.dtype === 'float32', function () { return "x has to be a floating point tensor since it's going to be " +
+ ("scaled, but got a " + $x.dtype + " tensor instead."); });
+ assert(rate >= 0 && rate < 1, function () { return "rate must be a float in the range [0, 1), but got " + rate + "."; });
+ if (rate === 0) {
+ return x instanceof Tensor ? $x.clone() : $x;
+ }
+ var $noiseShape = getNoiseShape($x, noiseShape);
+ var keepProb = 1 - rate;
+ var multiplier = div(floor(add(randomUniform($noiseShape, 0, 1, 'float32', seed), keepProb)), keepProb);
+ return mul($x, multiplier);
+ }
+ var dropout = op({ dropout_: dropout_ });
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function enclosingPowerOfTwo(value) {
+ // Return 2**N for integer N such that 2**N >= value.
+ return Math.floor(Math.pow(2, Math.ceil(Math.log(value) / Math.log(2.0))));
+ }
+ function cosineWindow(windowLength, a, b) {
+ var even = 1 - windowLength % 2;
+ var newValues = new Float32Array(windowLength);
+ for (var i = 0; i < windowLength; ++i) {
+ var cosArg = (2.0 * Math.PI * i) / (windowLength + even - 1);
+ newValues[i] = a - b * Math.cos(cosArg);
+ }
+ return tensor1d(newValues, 'float32');
+ }
+
+ /**
+ * Returns whether the targets are in the top K predictions.
+ *
+ * ```js
+ * const predictions = tf.tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
+ * const targets = tf.tensor1d([2, 0]);
+ * const precision = await tf.inTopKAsync(predictions, targets);
+ * precision.print();
+ * ```
+ * @param predictions 2-D or higher `tf.Tensor` with last dimension being
+ * at least `k`.
+ * @param targets 1-D or higher `tf.Tensor`.
+ * @param k Optional Number of top elements to look at for computing precision,
+ * default to 1.
+ *
+ * @doc {heading: 'Operations', subheading: 'Evaluation'}
+ */
+ function inTopKAsync_(predictions, targets, k) {
+ if (k === void 0) { k = 1; }
+ return __awaiter(this, void 0, void 0, function () {
+ var $predictions, $targets, lastDim, predictionsVals, targetsVals, _a, batch, size, precision, b, offset, vals, valAndInd, i, i;
+ return __generator(this, function (_b) {
+ switch (_b.label) {
+ case 0:
+ $predictions = convertToTensor(predictions, 'predictions', 'inTopK');
+ $targets = convertToTensor(targets, 'targets', 'inTopK');
+ assert($predictions.rank > 1, function () { return 'inTopK() expects the predictions to be of rank 2 or higher, ' +
+ ("but got " + $predictions.rank); });
+ assert($predictions.rank - 1 === $targets.rank, function () { return "predictions rank should be 1 larger than " +
+ "targets rank, but got predictions rank " +
+ ($predictions.rank + " and targets rank " + $targets.rank); });
+ assertShapesMatch($predictions.shape.slice(0, $predictions.shape.length - 1), $targets.shape, "predictions's shape should be align with the targets' shape, " +
+ 'except the last dimension.');
+ lastDim = $predictions.shape[$predictions.shape.length - 1];
+ assert(k > 0 && k <= lastDim, function () { return "'k' passed to inTopK() must be > 0 && <= the predictions last " +
+ ("dimension (" + lastDim + "), but got " + k); });
+ return [4 /*yield*/, $predictions.data()];
+ case 1:
+ predictionsVals = _b.sent();
+ return [4 /*yield*/, $targets.data()];
+ case 2:
+ targetsVals = _b.sent();
+ _a = __read([predictionsVals.length / lastDim, lastDim], 2), batch = _a[0], size = _a[1];
+ precision = getTypedArrayFromDType('bool', batch);
+ for (b = 0; b < batch; b++) {
+ offset = b * size;
+ vals = predictionsVals.subarray(offset, offset + size);
+ valAndInd = [];
+ for (i = 0; i < vals.length; i++) {
+ valAndInd.push({ value: vals[i], index: i });
+ }
+ valAndInd.sort(function (a, b) { return b.value - a.value; });
+ precision[b] = 0;
+ for (i = 0; i < k; i++) {
+ if (valAndInd[i].index === targetsVals[b]) {
+ precision[b] = 1;
+ break;
+ }
+ }
+ }
+ if (predictions !== $predictions) {
+ $predictions.dispose();
+ }
+ if (targets !== $targets) {
+ $targets.dispose();
+ }
+ // Output precision has the same shape as targets.
+ return [2 /*return*/, tensor(precision, $targets.shape, 'bool')];
+ }
+ });
+ });
+ }
+ var inTopKAsync = inTopKAsync_;
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the derivative of the filter of a 2D convolution.
+ *
+ * @param x The input tensor, of rank 4 or rank 3 of shape
+ * [batch, height, width, inChannels]. If rank 3, batch of 1 is assumed.
+ * @param dy The dy image, of rank 4 or rank 3, of shape
+ * [batch, height, width, outDepth]. If rank 3, batch of 1 is assumed.
+ * @param filterShape The shape of the filter, length 4,
+ * [filterHeight, filterWidth, inDepth, outDepth].
+ * @param strides The strides of the convolution: [strideHeight,
+ * strideWidth].
+ * @param pad A string from: 'same', 'valid'. The type of padding algorithm
+ * used in the forward prop of the op.
+ * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
+ * "NHWC". Specify the data format of the input and output data. With the
+ * default format "NHWC", the data is stored in the order of: [batch,
+ * height, width, channels].
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ */
+ function conv2DBackpropFilter_(x, dy, filterShape, strides, pad, dataFormat, dimRoundingMode) {
+ if (dataFormat === void 0) { dataFormat = 'NHWC'; }
+ var x4D = x;
+ if (x.rank === 3) {
+ x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
+ }
+ var dy4D = dy;
+ if (dy4D.rank === 3) {
+ dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
+ }
+ assert(x4D.rank === 4, function () { return "Error in conv2dDerFilter: input must be rank 4, but got shape " +
+ (x4D.shape + "."); });
+ assert(dy4D.rank === 4, function () { return "Error in conv2dDerFilter: dy must be rank 4, but got shape " +
+ (dy4D.shape + "."); });
+ assert(filterShape.length === 4, function () { return "Error in conv2dDerFilter: filterShape must be length 4, but got " +
+ (filterShape + "."); });
+ var inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
+ var outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
+ assert(inDepth === filterShape[2], function () { return "Error in conv2dDerFilter: depth of input " + inDepth + ") must " +
+ ("match input depth in filter (" + filterShape[2] + "."); });
+ assert(outDepth === filterShape[3], function () { return "Error in conv2dDerFilter: depth of dy (" + outDepth + ") must " +
+ ("match output depth for filter (" + filterShape[3] + ")."); });
+ checkPadOnDimRoundingMode('conv2dDerFilter', pad, dimRoundingMode);
+ var inputs = { x: x4D, dy: dy4D };
+ var attrs = { strides: strides, pad: pad, dataFormat: dataFormat, dimRoundingMode: dimRoundingMode, filterShape: filterShape };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ return ENGINE.runKernel(Conv2DBackpropFilter, inputs, attrs);
+ }
+ var conv2DBackpropFilter = op({ conv2DBackpropFilter_: conv2DBackpropFilter_ });
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ // Returns gradient for fused activation.
+ function getFusedDyActivation(dy, y, activation) {
+ if (activation == null || activation === 'linear') {
+ return dy;
+ }
+ if (activation === 'relu') {
+ return mul(dy, step(y));
+ }
+ throw new Error("Cannot compute gradient for fused activation " + activation + ".");
+ }
+ // Returns gradient for fused bias.
+ function getFusedBiasGradient(bias, dyActivation) {
+ var res = dyActivation;
+ var reduceAxes = getReductionAxes(bias.shape, dyActivation.shape);
+ if (reduceAxes.length > 0) {
+ res = sum(res, reduceAxes);
+ }
+ return reshape(res, bias.shape);
+ }
+ function applyActivation(x, activation, preluActivationWeights, leakyreluAlpha) {
+ if (activation === 'linear') {
+ return x;
+ }
+ else if (activation === 'relu') {
+ return relu(x);
+ }
+ else if (activation === 'elu') {
+ return elu(x);
+ }
+ else if (activation === 'relu6') {
+ return relu6(x);
+ }
+ else if (activation === 'prelu') {
+ return prelu(x, preluActivationWeights);
+ }
+ else if (activation === 'leakyrelu') {
+ return leakyRelu(x, leakyreluAlpha);
+ }
+ else if (activation === 'sigmoid') {
+ return sigmoid(x);
+ }
+ throw new Error("Unknown fused activation " + activation + ".");
+ }
+ // Whether we should call fused ops.
+ var shouldFuse = function (gradientDepth, activation) {
+ var gradientMode = gradientDepth > 0;
+ return !gradientMode || activation === 'linear';
+ };
+
+ /**
+ * Computes a 2D convolution over the input x, optionally fused with adding a
+ * bias and applying an activation.
+ *
+ * ```js
+ * const inputDepth = 2;
+ * const inShape = [2, 2, 2, inputDepth];
+ * const outputDepth = 2;
+ * const fSize = 1;
+ * const pad = 0;
+ * const strides = 1;
+ *
+ * const x = tf.tensor4d( [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ * 16], inShape);
+ * const w = tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth,
+ * outputDepth]);
+ *
+ * tf.fused.conv2d({ x, filter: w, strides, pad, dataFormat: 'NHWC',
+ * dilations: [1, 1], bias: tf.scalar(5), activation: 'relu' }).print();
+ * ```
+ *
+ * @param obj An object with the following properties:
+ * @param x The input tensor, of rank 4 or rank 3, of shape
+ * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
+ * assumed.
+ * @param filter The filter, rank 4, of shape
+ * `[filterHeight, filterWidth, inDepth, outDepth]`.
+ * @param strides The strides of the convolution: `[strideHeight,
+ * strideWidth]`.
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid` output will be smaller than input if filter is larger
+ * than 1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
+ * "NHWC". Specify the data format of the input and output data. With the
+ * default format "NHWC", the data is stored in the order of: [batch,
+ * height, width, channels]. Only "NHWC" is currently supported.
+ * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
+ * in which we sample input values across the height and width dimensions
+ * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
+ * number, then `dilationHeight == dilationWidth`. If it is greater than
+ * 1, then all values of `strides` must be 1.
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ * @param bias Tensor to be added to the result.
+ * @param activation Name of activation kernel (defaults to `linear`) to be
+ * applied
+ * after biasAdd.
+ * @param preluActivationWeights Tensor of prelu weights to be applied as part
+ * of a `prelu` activation, typically the same shape as `x`.
+ * @param leakyreluAlpha Optional. Alpha to be applied as part of a `leakyrelu`
+ * activation.
+ */
+ function fusedConv2d_(_a) {
+ var _b;
+ var x = _a.x, filter = _a.filter, strides = _a.strides, pad = _a.pad, _c = _a.dataFormat, dataFormat = _c === void 0 ? 'NHWC' : _c, _d = _a.dilations, dilations = _d === void 0 ? [1, 1] : _d, dimRoundingMode = _a.dimRoundingMode, bias = _a.bias, _e = _a.activation, activation = _e === void 0 ? 'linear' : _e, preluActivationWeights = _a.preluActivationWeights, leakyreluAlpha = _a.leakyreluAlpha;
+ activation = activation || 'linear';
+ if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
+ var result = conv2d$1(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
+ if (bias != null) {
+ result = add(result, bias);
+ }
+ return applyActivation(result, activation, preluActivationWeights, leakyreluAlpha);
+ }
+ var $x = convertToTensor(x, 'x', 'conv2d', 'float32');
+ var $filter = convertToTensor(filter, 'filter', 'conv2d', 'float32');
+ var x4D = $x;
+ var reshapedTo4D = false;
+ if ($x.rank === 3) {
+ reshapedTo4D = true;
+ x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
+ }
+ assert(x4D.rank === 4, function () { return "Error in fused conv2d: input must be rank 4, but got rank " +
+ (x4D.rank + "."); });
+ assert($filter.rank === 4, function () { return "Error in fused conv2d: filter must be rank 4, but got rank " +
+ ($filter.rank + "."); });
+ checkPadOnDimRoundingMode('fused conv2d', pad, dimRoundingMode);
+ assert(x4D.shape[3] === $filter.shape[2], function () { return "Error in conv2d: depth of input (" + x4D.shape[3] + ") must match " +
+ ("input depth for filter " + $filter.shape[2] + "."); });
+ assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in conv2D: Either strides or dilations must be 1. ' +
+ ("Got strides " + strides + " and dilations '" + dilations + "'"); });
+ assert(dataFormat === 'NHWC', function () { return "Error in conv2d: got dataFormat of " + dataFormat + " but only NHWC is currently supported."; });
+ var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode);
+ var $bias;
+ if (bias != null) {
+ $bias = convertToTensor(bias, 'bias', 'fused conv2d');
+ _b = __read(makeTypesMatch($bias, $x), 1), $bias = _b[0];
+ assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
+ }
+ var $preluActivationWeights;
+ if (preluActivationWeights != null) {
+ $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused conv2d');
+ }
+ var grad = function (dy, saved) {
+ var _a = __read(saved, 4), $filter = _a[0], x4D = _a[1], y = _a[2], $bias = _a[3];
+ var dyActivation = getFusedDyActivation(dy, y, activation);
+ assert(tupleValuesAreOne(dilations), function () { return 'Error in gradient of fused conv2D: ' +
+ "dilation rates greater than 1 " +
+ ("are not yet supported in gradients. Got dilations '" + dilations + "'"); });
+ var xDer = conv2DBackpropInput(x4D.shape, dyActivation, $filter, strides, pad);
+ var filterDer = conv2DBackpropFilter(x4D, dyActivation, $filter.shape, strides, pad);
+ var der = [xDer, filterDer];
+ if ($bias != null) {
+ var biasDer = getFusedBiasGradient($bias, dyActivation);
+ der.push(biasDer);
+ }
+ return der;
+ };
+ var inputs = {
+ x: x4D,
+ filter: $filter,
+ bias: $bias,
+ preluActivationWeights: $preluActivationWeights
+ };
+ var attrs = {
+ strides: strides,
+ pad: pad,
+ dataFormat: dataFormat,
+ dilations: dilations,
+ dimRoundingMode: dimRoundingMode,
+ activation: activation,
+ leakyreluAlpha: leakyreluAlpha
+ };
+ // Depending on the the params passed in we will have different number of
+ // inputs and thus a a different number of elements in the gradient.
+ if (bias == null) {
+ var customOp = customGrad(function (x4D, filter, save) {
+ var res =
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ ENGINE.runKernel(FusedConv2D, inputs, attrs);
+ save([filter, x4D, res]);
+ if (reshapedTo4D) {
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return { value: res, gradFunc: grad };
+ });
+ return customOp(x4D, $filter);
+ }
+ else {
+ var customOpWithBias = customGrad(function (x4D, filter, bias, save) {
+ var res = ENGINE.runKernel(FusedConv2D, inputs, attrs);
+ save([filter, x4D, res, bias]);
+ if (reshapedTo4D) {
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return { value: res, gradFunc: grad };
+ });
+ return customOpWithBias(x4D, $filter, $bias);
+ }
+ }
+ var conv2d = op({ fusedConv2d_: fusedConv2d_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function depthwiseConv2dNativeBackpropFilter_(x, dy, filterShape, strides, pad, dilations, dimRoundingMode) {
+ if (dilations === void 0) { dilations = [1, 1]; }
+ var x4D = x;
+ if (x.rank === 3) {
+ x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
+ }
+ var dy4D = dy;
+ if (dy4D.rank === 3) {
+ dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
+ }
+ var inputs = { x: x4D, dy: dy4D };
+ var attrs = { strides: strides, pad: pad, dimRoundingMode: dimRoundingMode, dilations: dilations, filterShape: filterShape };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ return ENGINE.runKernel(DepthwiseConv2dNativeBackpropFilter, inputs, attrs);
+ }
+ var depthwiseConv2dNativeBackpropFilter = op({ depthwiseConv2dNativeBackpropFilter_: depthwiseConv2dNativeBackpropFilter_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function depthwiseConv2dNativeBackpropInput_(xShape, dy, filter, strides, pad, dilations, dimRoundingMode) {
+ if (dilations === void 0) { dilations = [1, 1]; }
+ var dy4D = dy;
+ var reshapedTo4D = false;
+ if (dy.rank === 3) {
+ reshapedTo4D = true;
+ dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
+ }
+ var inputs = { dy: dy4D, filter: filter };
+ var attrs = { strides: strides, pad: pad, dimRoundingMode: dimRoundingMode, dilations: dilations, inputShape: xShape };
+ var res =
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ ENGINE.runKernel(DepthwiseConv2dNativeBackpropInput, inputs, attrs);
+ if (reshapedTo4D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return res;
+ }
+ var depthwiseConv2dNativeBackpropInput = op({ depthwiseConv2dNativeBackpropInput_: depthwiseConv2dNativeBackpropInput_ });
+
+ /**
+ * Computes depthwise 2D convolution, optionally fused with adding a
+ * bias and applying an activation.
+ *
+ * Given a 4D `input` array and a `filter` array of shape
+ * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing
+ * `inChannels` convolutional filters of depth 1, this op applies a
+ * different filter to each input channel (expanding from 1 channel to
+ * `channelMultiplier` channels for each), then concatenates the results
+ * together. The output has `inChannels * channelMultiplier` channels.
+ *
+ * See
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)
+ * for more details.
+ *
+ * @param obj An object with the following properties:
+ * @param x The input tensor, of rank 4 or rank 3, of shape
+ * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
+ * assumed.
+ * @param filter The filter tensor, rank 4, of shape
+ * `[filterHeight, filterWidth, inChannels, channelMultiplier]`.
+ * @param strides The strides of the convolution: `[strideHeight,
+ * strideWidth]`. If strides is a single number, then `strideHeight ==
+ * strideWidth`.
+ * @param pad The type of padding algorithm.
+ * - `same` and stride 1: output will be of same size as input,
+ * regardless of filter size.
+ * - `valid`: output will be smaller than input if filter is larger
+ * than 1x1.
+ * - For more info, see this guide:
+ * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
+ * https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
+ * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
+ * in which we sample input values across the height and width dimensions
+ * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
+ * number, then `dilationHeight == dilationWidth`. If it is greater than
+ * 1, then all values of `strides` must be 1.
+ * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
+ * "NHWC". Specify the data format of the input and output data. With the
+ * default format "NHWC", the data is stored in the order of: [batch,
+ * height, width, channels]. Only "NHWC" is currently supported.
+ * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
+ * provided, it will default to truncate.
+ * @param bias Tensor to be added to the result.
+ * @param activation Name of activation kernel (defaults to `linear`).
+ * @param preluActivationWeights Tensor of prelu weights to be applied as part
+ * of a `prelu` activation, typically the same shape as `x`.
+ * @param leakyreluAlpha Optional. Alpha to be applied as part of a `leakyrelu`
+ * activation.
+ */
+ function fusedDepthwiseConv2d_(_a) {
+ var _b;
+ var x = _a.x, filter = _a.filter, strides = _a.strides, pad = _a.pad, _c = _a.dataFormat, dataFormat = _c === void 0 ? 'NHWC' : _c, _d = _a.dilations, dilations = _d === void 0 ? [1, 1] : _d, dimRoundingMode = _a.dimRoundingMode, bias = _a.bias, _e = _a.activation, activation = _e === void 0 ? 'linear' : _e, preluActivationWeights = _a.preluActivationWeights, leakyreluAlpha = _a.leakyreluAlpha;
+ if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
+ var result = depthwiseConv2d$1(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
+ if (bias != null) {
+ result = add(result, bias);
+ }
+ return applyActivation(result, activation, preluActivationWeights, leakyreluAlpha);
+ }
+ var $x = convertToTensor(x, 'x', 'depthwiseConv2d', 'float32');
+ var $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d', 'float32');
+ var x4D = $x;
+ var reshapedTo4D = false;
+ if ($x.rank === 3) {
+ reshapedTo4D = true;
+ x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
+ }
+ assert(x4D.rank === 4, function () { return "Error in fused depthwiseConv2d: input must be rank 4, but got " +
+ ("rank " + x4D.rank + "."); });
+ assert($filter.rank === 4, function () { return "Error in fused depthwiseConv2d: filter must be rank 4, " +
+ ("but got rank " + $filter.rank + "."); });
+ assert(x4D.shape[3] === $filter.shape[2], function () { return "Error in fused depthwiseConv2d: number of input channels " +
+ ("(" + x4D.shape[3] + ") must match the inChannels dimension in ") +
+ ("filter " + $filter.shape[2] + "."); });
+ if (dilations == null) {
+ dilations = [1, 1];
+ }
+ assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in fused depthwiseConv2d: Either strides or dilations must ' +
+ ("be 1. Got strides " + strides + " and dilations '" + dilations + "'"); });
+ checkPadOnDimRoundingMode('fused depthwiseConv2d', pad, dimRoundingMode);
+ var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
+ var $bias;
+ if (bias != null) {
+ $bias = convertToTensor(bias, 'bias', 'fused conv2d');
+ _b = __read(makeTypesMatch($bias, $x), 1), $bias = _b[0];
+ assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
+ }
+ var $preluActivationWeights;
+ if (preluActivationWeights != null) {
+ $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused depthwiseConv2d');
+ }
+ var grad = function (dy, saved) {
+ assert(tupleValuesAreOne(dilations), function () { return 'Error in gradient of fused depthwiseConv2d: dilation rates ' +
+ "greater than 1 are not yet supported. Got dilations " +
+ ("'" + dilations + "'"); });
+ var _a = __read(saved, 4), $filter = _a[0], x4D = _a[1], y = _a[2], bias = _a[3];
+ var dyActivation = getFusedDyActivation(dy, y, activation);
+ var xDer = depthwiseConv2dNativeBackpropInput(x4D.shape, dyActivation, $filter, strides, pad, dilations, dimRoundingMode);
+ var filterDer = depthwiseConv2dNativeBackpropFilter(x4D, dyActivation, $filter.shape, strides, pad, dilations, dimRoundingMode);
+ if (bias != null) {
+ var biasDer = getFusedBiasGradient($bias, dyActivation);
+ return [xDer, filterDer, biasDer];
+ }
+ return [xDer, filterDer];
+ };
+ var inputs = {
+ x: x4D,
+ filter: $filter,
+ bias: $bias,
+ preluActivationWeights: $preluActivationWeights
+ };
+ var attrs = {
+ strides: strides,
+ pad: pad,
+ dataFormat: dataFormat,
+ dilations: dilations,
+ dimRoundingMode: dimRoundingMode,
+ activation: activation,
+ leakyreluAlpha: leakyreluAlpha
+ };
+ // Depending on the the params passed in we will have different number of
+ // inputs and thus a a different number of elements in the gradient.
+ if (bias == null) {
+ var customOp = customGrad(function (x4D, filter, save) {
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(FusedDepthwiseConv2D, inputs, attrs);
+ save([filter, x4D, res]);
+ if (reshapedTo4D) {
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return { value: res, gradFunc: grad };
+ });
+ return customOp(x4D, $filter);
+ }
+ else {
+ var customOpWithBias = customGrad(function (x4D, filter, bias, save) {
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(FusedDepthwiseConv2D, inputs, attrs);
+ save([filter, x4D, res, bias]);
+ if (reshapedTo4D) {
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return { value: res, gradFunc: grad };
+ });
+ return customOpWithBias(x4D, $filter, $bias);
+ }
+ }
+ var depthwiseConv2d = op({ fusedDepthwiseConv2d_: fusedDepthwiseConv2d_ });
+
+ /**
+ * Computes the dot product of two matrices with optional activation and bias.
+ *
+ * ```js
+ * const a = tf.tensor2d([-1, -2], [1, 2]);
+ * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
+ * const bias = tf.tensor2d([1, 2], [1, 2]);
+ *
+ * tf.fused.matMul({a, b, bias, activation: 'relu'}).print();
+ * ```
+ *
+ * @param obj An object with the following properties:
+ * - `a` First matrix in dot product operation.
+ * - `b` Second matrix in dot product operation.
+ * - `transposeA` If true, `a` is transposed before multiplication.
+ * - `transposeB` If true, `b` is transposed before multiplication.
+ * - `bias` Matrix to be added to the result.
+ * - `activation` Name of activation kernel (defaults to `linear`).
+ * - `preluActivationWeights` Tensor of prelu weights.
+ * - `leakyreluAlpha` Alpha of leakyrelu.
+ */
+ function fusedMatMul_(_a) {
+ var _b, _c;
+ var a = _a.a, b = _a.b, _d = _a.transposeA, transposeA = _d === void 0 ? false : _d, _e = _a.transposeB, transposeB = _e === void 0 ? false : _e, bias = _a.bias, _f = _a.activation, activation = _f === void 0 ? 'linear' : _f, preluActivationWeights = _a.preluActivationWeights, leakyreluAlpha = _a.leakyreluAlpha;
+ if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
+ var result = matMul$1(a, b, transposeA, transposeB);
+ if (bias != null) {
+ result = add(result, bias);
+ }
+ return applyActivation(result, activation, preluActivationWeights, leakyreluAlpha);
+ }
+ var $a = convertToTensor(a, 'a', 'fused matMul');
+ var $b = convertToTensor(b, 'b', 'fused matMul');
+ _b = __read(makeTypesMatch($a, $b), 2), $a = _b[0], $b = _b[1];
+ var innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
+ var innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
+ var outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
+ var outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
+ var outerDimsA = $a.shape.slice(0, -2);
+ var outerDimsB = $b.shape.slice(0, -2);
+ var batchDimA = sizeFromShape(outerDimsA);
+ var batchDimB = sizeFromShape(outerDimsB);
+ assert(innerShapeA === innerShapeB, function () { return "Error in fused matMul: inner shapes (" + innerShapeA + ") and (" +
+ (innerShapeB + ") of Tensors with shapes " + $a.shape + " and ") +
+ ($b.shape + " and transposeA=" + transposeA) +
+ (" and transposeB=" + transposeB + " must match."); });
+ var outShapeOuterDims = assertAndGetBroadcastShape($a.shape.slice(0, -2), $b.shape.slice(0, -2));
+ var outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
+ var a3D = transposeA ?
+ reshape($a, [batchDimA, innerShapeA, outerShapeA]) :
+ reshape($a, [batchDimA, outerShapeA, innerShapeA]);
+ var b3D = transposeB ?
+ reshape($b, [batchDimB, outerShapeB, innerShapeB]) :
+ reshape($b, [batchDimB, innerShapeB, outerShapeB]);
+ var $bias;
+ if (bias != null) {
+ $bias = convertToTensor(bias, 'bias', 'fused matMul');
+ _c = __read(makeTypesMatch($bias, $a), 1), $bias = _c[0];
+ assertAndGetBroadcastShape(outShape, $bias.shape);
+ }
+ var $preluActivationWeights;
+ if (preluActivationWeights != null) {
+ $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused matMul');
+ }
+ var grad = function (dy, saved) {
+ var _a = __read(saved, 4), a3D = _a[0], b3D = _a[1], y = _a[2], $bias = _a[3];
+ // we reshape dy because the result of the forward is not
+ // necessarily going to be a 3d tensor due to a reshape done at the end of
+ // the customOp.
+ var dyActivation = getFusedDyActivation(reshape(dy, y.shape), y, activation);
+ var aDer;
+ var bDer;
+ if (!transposeA && !transposeB) {
+ aDer = matMul$1(dyActivation, b3D, false, true);
+ bDer = matMul$1(a3D, dyActivation, true, false);
+ }
+ else if (!transposeA && transposeB) {
+ aDer = matMul$1(dyActivation, b3D, false, false);
+ bDer = matMul$1(dyActivation, a3D, true, false);
+ }
+ else if (transposeA && !transposeB) {
+ aDer = matMul$1(b3D, dyActivation, false, true);
+ bDer = matMul$1(a3D, dyActivation, false, false);
+ }
+ else {
+ aDer = matMul$1(b3D, dyActivation, true, true);
+ bDer = matMul$1(dyActivation, a3D, true, true);
+ }
+ if (bias != null) {
+ var biasDer = getFusedBiasGradient($bias, dyActivation);
+ return [aDer, bDer, biasDer];
+ }
+ else {
+ return [aDer, bDer];
+ }
+ };
+ var inputs = {
+ a: a3D,
+ b: b3D,
+ bias: $bias,
+ preluActivationWeights: $preluActivationWeights
+ };
+ var attrs = { transposeA: transposeA, transposeB: transposeB, activation: activation, leakyreluAlpha: leakyreluAlpha };
+ // Depending on the the params passed in we will have different number of
+ // inputs and thus a a different number of elements in the gradient.
+ if (bias == null) {
+ var customOp = customGrad(function (a3D, b3D, save) {
+ var res =
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ ENGINE.runKernel(_FusedMatMul, inputs, attrs);
+ save([a3D, b3D, res]);
+ return { value: reshape(res, outShape), gradFunc: grad };
+ });
+ return customOp(a3D, b3D);
+ }
+ else {
+ var customOpWithBias = customGrad(function (a3D, b3D, $bias, save) {
+ var res =
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ ENGINE.runKernel(_FusedMatMul, inputs, attrs);
+ save([a3D, b3D, res, $bias]);
+ return { value: reshape(res, outShape), gradFunc: grad };
+ });
+ return customOpWithBias(a3D, b3D, $bias);
+ }
+ }
+ var matMul = op({ fusedMatMul_: fusedMatMul_ });
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+ var fused_ops = {
+ __proto__: null,
+ conv2d: conv2d,
+ depthwiseConv2d: depthwiseConv2d,
+ matMul: matMul
+ };
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Generate a hamming window.
+ *
+ * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
+ *
+ * ```js
+ * tf.signal.hammingWindow(10).print();
+ * ```
+ * @param The length of window
+ *
+ * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
+ */
+ function hammingWindow_(windowLength) {
+ return cosineWindow(windowLength, 0.54, 0.46);
+ }
+ var hammingWindow = op({ hammingWindow_: hammingWindow_ });
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Generate a Hann window.
+ *
+ * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
+ *
+ * ```js
+ * tf.signal.hannWindow(10).print();
+ * ```
+ * @param The length of window
+ *
+ * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
+ */
+ function hannWindow_(windowLength) {
+ return cosineWindow(windowLength, 0.5, 0.5);
+ }
+ var hannWindow = op({ hannWindow_: hannWindow_ });
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Expands input into frames of frameLength.
+ * Slides a window size with frameStep.
+ *
+ * ```js
+ * tf.signal.frame([1, 2, 3], 2, 1).print();
+ * ```
+ * @param signal The input tensor to be expanded
+ * @param frameLength Length of each frame
+ * @param frameStep The frame hop size in samples.
+ * @param padEnd Whether to pad the end of signal with padValue.
+ * @param padValue An number to use where the input signal does
+ * not exist when padEnd is True.
+ *
+ * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
+ */
+ function frame_(signal, frameLength, frameStep, padEnd, padValue) {
+ if (padEnd === void 0) { padEnd = false; }
+ if (padValue === void 0) { padValue = 0; }
+ var start = 0;
+ var output = [];
+ while (start + frameLength <= signal.size) {
+ output.push(slice(signal, start, frameLength));
+ start += frameStep;
+ }
+ if (padEnd) {
+ while (start < signal.size) {
+ var padLen = (start + frameLength) - signal.size;
+ var pad = concat([
+ slice(signal, start, frameLength - padLen), fill([padLen], padValue)
+ ]);
+ output.push(pad);
+ start += frameStep;
+ }
+ }
+ if (output.length === 0) {
+ return tensor2d([], [0, frameLength]);
+ }
+ return reshape(concat(output), [output.length, frameLength]);
+ }
+ var frame = op({ frame_: frame_ });
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the Short-time Fourier Transform of signals
+ * See: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
+ *
+ * ```js
+ * const input = tf.tensor1d([1, 1, 1, 1, 1])
+ * tf.signal.stft(input, 3, 1).print();
+ * ```
+ * @param signal 1-dimensional real value tensor.
+ * @param frameLength The window length of samples.
+ * @param frameStep The number of samples to step.
+ * @param fftLength The size of the FFT to apply.
+ * @param windowFn A callable that takes a window length and returns 1-d tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
+ */
+ function stft_(signal, frameLength, frameStep, fftLength, windowFn) {
+ if (windowFn === void 0) { windowFn = hannWindow; }
+ if (fftLength == null) {
+ fftLength = enclosingPowerOfTwo(frameLength);
+ }
+ var framedSignal = frame(signal, frameLength, frameStep);
+ var windowedSignal = mul(framedSignal, windowFn(frameLength));
+ return rfft(windowedSignal, fftLength);
+ }
+ var stft = op({ stft_: stft_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Extracts crops from the input image tensor and resizes them using bilinear
+ * sampling or nearest neighbor sampling (possibly with aspect ratio change)
+ * to a common output size specified by cropSize.
+ *
+ * @param image 4d tensor of shape `[batch,imageHeight,imageWidth, depth]`,
+ * where imageHeight and imageWidth must be positive, specifying the
+ * batch of images from which to take crops
+ * @param boxes 2d float32 tensor of shape `[numBoxes, 4]`. Each entry is
+ * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the normalized
+ * coordinates of the box in the boxInd[i]'th image in the batch
+ * @param boxInd 1d int32 tensor of shape `[numBoxes]` with values in range
+ * `[0, batch)` that specifies the image that the `i`-th box refers to.
+ * @param cropSize 1d int32 tensor of 2 elements `[cropHeigh, cropWidth]`
+ * specifying the size to which all crops are resized to.
+ * @param method Optional string from `'bilinear' | 'nearest'`,
+ * defaults to bilinear, which specifies the sampling method for resizing
+ * @param extrapolationValue A threshold for deciding when to remove boxes based
+ * on score. Defaults to 0.
+ * @return A 4D tensor of the shape `[numBoxes,cropHeight,cropWidth,depth]`
+ *
+ * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
+ */
+ function cropAndResize_(image, boxes, boxInd, cropSize, method, extrapolationValue) {
+ if (method === void 0) { method = 'bilinear'; }
+ if (extrapolationValue === void 0) { extrapolationValue = 0; }
+ var $image = convertToTensor(image, 'image', 'cropAndResize');
+ var $boxes = convertToTensor(boxes, 'boxes', 'cropAndResize', 'float32');
+ var $boxInd = convertToTensor(boxInd, 'boxInd', 'cropAndResize', 'int32');
+ var numBoxes = $boxes.shape[0];
+ assert($image.rank === 4, function () { return 'Error in cropAndResize: image must be rank 4,' +
+ ("but got rank " + $image.rank + "."); });
+ assert($boxes.rank === 2 && $boxes.shape[1] === 4, function () { return "Error in cropAndResize: boxes must be have size [" + numBoxes + ",4] " +
+ ("but had shape " + $boxes.shape + "."); });
+ assert($boxInd.rank === 1 && $boxInd.shape[0] === numBoxes, function () { return "Error in cropAndResize: boxInd must be have size [" + numBoxes + "] " +
+ ("but had shape " + $boxes.shape + "."); });
+ assert(cropSize.length === 2, function () { return "Error in cropAndResize: cropSize must be of length 2, but got " +
+ ("length " + cropSize.length + "."); });
+ assert(cropSize[0] >= 1 && cropSize[1] >= 1, function () { return "cropSize must be atleast [1,1], but was " + cropSize; });
+ assert(method === 'bilinear' || method === 'nearest', function () { return "method must be bilinear or nearest, but was " + method; });
+ var inputs = { image: $image, boxes: $boxes, boxInd: $boxInd };
+ var attrs = { method: method, extrapolationValue: extrapolationValue, cropSize: cropSize };
+ var res = ENGINE.runKernel(CropAndResize, inputs, attrs);
+ return res;
+ }
+ var cropAndResize = op({ cropAndResize_: cropAndResize_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Flips the image left to right. Currently available in the CPU, WebGL, and
+ * WASM backends.
+ *
+ * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
+ */
+ /** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */
+ function flipLeftRight_(image) {
+ var $image = convertToTensor(image, 'image', 'flipLeftRight', 'float32');
+ assert($image.rank === 4, function () { return 'Error in flipLeftRight: image must be rank 4,' +
+ ("but got rank " + $image.rank + "."); });
+ var inputs = { image: $image };
+ var res = ENGINE.runKernel(FlipLeftRight, inputs, {});
+ return res;
+ }
+ var flipLeftRight = op({ flipLeftRight_: flipLeftRight_ });
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Converts images from grayscale to RGB format.
+ *
+ * @param image A grayscale tensor to convert. The `image`'s last dimension must
+ * be size 1 with at least a two-dimensional shape.
+ *
+ * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
+ */
+ function grayscaleToRGB_(image) {
+ var $image = convertToTensor(image, 'image', 'grayscaleToRGB');
+ var lastDimsIdx = $image.rank - 1;
+ var lastDims = $image.shape[lastDimsIdx];
+ assert($image.rank >= 2, function () { return 'Error in grayscaleToRGB: images must be at least rank 2, ' +
+ ("but got rank " + $image.rank + "."); });
+ assert(lastDims === 1, function () { return 'Error in grayscaleToRGB: last dimension of a grayscale image ' +
+ ("should be size 1, but got size " + lastDims + "."); });
+ var reps = new Array($image.rank);
+ reps.fill(1, 0, lastDimsIdx);
+ reps[lastDimsIdx] = 3;
+ return tile($image, reps);
+ }
+ var grayscaleToRGB = op({ grayscaleToRGB_: grayscaleToRGB_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Rotates the input image tensor counter-clockwise with an optional offset
+ * center of rotation. Currently available in the CPU, WebGL, and WASM backends.
+ *
+ * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
+ * @param radians The amount of rotation.
+ * @param fillValue The value to fill in the empty space leftover
+ * after rotation. Can be either a single grayscale value (0-255), or an
+ * array of three numbers `[red, green, blue]` specifying the red, green,
+ * and blue channels. Defaults to `0` (black).
+ * @param center The center of rotation. Can be either a single value (0-1), or
+ * an array of two numbers `[centerX, centerY]`. Defaults to `0.5` (rotates
+ * the image around its center).
+ *
+ * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
+ */
+ function rotateWithOffset_(image, radians, fillValue, center) {
+ if (fillValue === void 0) { fillValue = 0; }
+ if (center === void 0) { center = 0.5; }
+ var $image = convertToTensor(image, 'image', 'rotateWithOffset', 'float32');
+ assert($image.rank === 4, function () { return 'Error in rotateWithOffset: image must be rank 4,' +
+ ("but got rank " + $image.rank + "."); });
+ var inputs = { image: $image };
+ var attrs = { radians: radians, fillValue: fillValue, center: center };
+ var res = ENGINE.runKernel(RotateWithOffset, inputs, attrs);
+ return res;
+ }
+ var rotateWithOffset = op({ rotateWithOffset_: rotateWithOffset_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function nonMaxSuppSanityCheck(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
+ if (iouThreshold == null) {
+ iouThreshold = 0.5;
+ }
+ if (scoreThreshold == null) {
+ scoreThreshold = Number.NEGATIVE_INFINITY;
+ }
+ if (softNmsSigma == null) {
+ softNmsSigma = 0.0;
+ }
+ var numBoxes = boxes.shape[0];
+ maxOutputSize = Math.min(maxOutputSize, numBoxes);
+ assert(0 <= iouThreshold && iouThreshold <= 1, function () { return "iouThreshold must be in [0, 1], but was '" + iouThreshold + "'"; });
+ assert(boxes.rank === 2, function () { return "boxes must be a 2D tensor, but was of rank '" + boxes.rank + "'"; });
+ assert(boxes.shape[1] === 4, function () { return "boxes must have 4 columns, but 2nd dimension was " + boxes.shape[1]; });
+ assert(scores.rank === 1, function () { return 'scores must be a 1D tensor'; });
+ assert(scores.shape[0] === numBoxes, function () { return "scores has incompatible shape with boxes. Expected " + numBoxes + ", " +
+ ("but was " + scores.shape[0]); });
+ assert(0 <= softNmsSigma && softNmsSigma <= 1, function () { return "softNmsSigma must be in [0, 1], but was '" + softNmsSigma + "'"; });
+ return { maxOutputSize: maxOutputSize, iouThreshold: iouThreshold, scoreThreshold: scoreThreshold, softNmsSigma: softNmsSigma };
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Performs non maximum suppression of bounding boxes based on
+ * iou (intersection over union).
+ *
+ * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
+ * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
+ * the bounding box.
+ * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
+ * @param maxOutputSize The maximum number of boxes to be selected.
+ * @param iouThreshold A float representing the threshold for deciding whether
+ * boxes overlap too much with respect to IOU. Must be between [0, 1].
+ * Defaults to 0.5 (50% box overlap).
+ * @param scoreThreshold A threshold for deciding when to remove boxes based
+ * on score. Defaults to -inf, which means any score is accepted.
+ * @return A 1D tensor with the selected box indices.
+ *
+ * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
+ */
+ function nonMaxSuppression_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
+ if (iouThreshold === void 0) { iouThreshold = 0.5; }
+ if (scoreThreshold === void 0) { scoreThreshold = Number.NEGATIVE_INFINITY; }
+ var $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression', 'float32');
+ var $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression', 'float32');
+ var inputs = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);
+ maxOutputSize = inputs.maxOutputSize;
+ iouThreshold = inputs.iouThreshold;
+ scoreThreshold = inputs.scoreThreshold;
+ var attrs = { maxOutputSize: maxOutputSize, iouThreshold: iouThreshold, scoreThreshold: scoreThreshold };
+ return ENGINE.runKernel(NonMaxSuppressionV3, { boxes: $boxes, scores: $scores }, attrs);
+ }
+ var nonMaxSuppression = op({ nonMaxSuppression_: nonMaxSuppression_ });
+
+ /**
+ * @license
+ * Copyright 2019 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Inserts a value into a sorted array. This method allows duplicate, meaning it
+ * allows inserting duplicate value, in which case, the element will be inserted
+ * at the lowest index of the value.
+ * @param arr The array to modify.
+ * @param element The element to insert.
+ * @param comparator Optional. If no comparator is specified, elements are
+ * compared using array_util.defaultComparator, which is suitable for Strings
+ * and Numbers in ascending arrays. If the array contains multiple instances of
+ * the target value, the left-most instance will be returned. To provide a
+ * comparator, it should take 2 arguments to compare and return a negative,
+ * zero, or a positive number.
+ */
+ function binaryInsert(arr, element, comparator) {
+ var index = binarySearch(arr, element, comparator);
+ var insertionPoint = index < 0 ? -(index + 1) : index;
+ arr.splice(insertionPoint, 0, element);
+ }
+ /**
+ * Searches the array for the target using binary search, returns the index
+ * of the found element, or position to insert if element not found. If no
+ * comparator is specified, elements are compared using array_
+ * util.defaultComparator, which is suitable for Strings and Numbers in
+ * ascending arrays. If the array contains multiple instances of the target
+ * value, the left-most instance will be returned.
+ * @param arr The array to be searched in.
+ * @param target The target to be searched for.
+ * @param comparator Should take 2 arguments to compare and return a negative,
+ * zero, or a positive number.
+ * @return Lowest index of the target value if found, otherwise the insertion
+ * point where the target should be inserted, in the form of
+ * (-insertionPoint - 1).
+ */
+ function binarySearch(arr, target, comparator) {
+ return binarySearch_(arr, target, comparator || defaultComparator);
+ }
+ /**
+ * Compares its two arguments for order.
+ * @param a The first element to be compared.
+ * @param b The second element to be compared.
+ * @return A negative number, zero, or a positive number as the first
+ * argument is less than, equal to, or greater than the second.
+ */
+ function defaultComparator(a, b) {
+ return a > b ? 1 : a < b ? -1 : 0;
+ }
+ function binarySearch_(arr, target, comparator) {
+ var left = 0;
+ var right = arr.length;
+ var middle = 0;
+ var found = false;
+ while (left < right) {
+ middle = left + ((right - left) >>> 1);
+ var compareResult = comparator(target, arr[middle]);
+ if (compareResult > 0) {
+ left = middle + 1;
+ }
+ else {
+ right = middle;
+ // If compareResult is 0, the value is found. We record it is found,
+ // and then keep looking because there may be duplicate.
+ found = !compareResult;
+ }
+ }
+ return found ? left : -left - 1;
+ }
+
+ function nonMaxSuppressionV3Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
+ return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */);
+ }
+ function nonMaxSuppressionV4Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
+ return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */, false /* returnScoresTensor */, padToMaxOutputSize /* padToMaxOutputSize */, true
+ /* returnValidOutputs */ );
+ }
+ function nonMaxSuppressionV5Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
+ return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, true /* returnScoresTensor */);
+ }
+ function nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor, padToMaxOutputSize, returnValidOutputs) {
+ if (returnScoresTensor === void 0) { returnScoresTensor = false; }
+ if (padToMaxOutputSize === void 0) { padToMaxOutputSize = false; }
+ if (returnValidOutputs === void 0) { returnValidOutputs = false; }
+ // The list is sorted in ascending order, so that we can always pop the
+ // candidate with the largest score in O(1) time.
+ var candidates = [];
+ for (var i = 0; i < scores.length; i++) {
+ if (scores[i] > scoreThreshold) {
+ candidates.push({ score: scores[i], boxIndex: i, suppressBeginIndex: 0 });
+ }
+ }
+ candidates.sort(ascendingComparator);
+ // If softNmsSigma is 0, the outcome of this algorithm is exactly same as
+ // before.
+ var scale = softNmsSigma > 0 ? (-0.5 / softNmsSigma) : 0.0;
+ var selectedIndices = [];
+ var selectedScores = [];
+ while (selectedIndices.length < maxOutputSize && candidates.length > 0) {
+ var candidate = candidates.pop();
+ var originalScore = candidate.score, boxIndex = candidate.boxIndex, suppressBeginIndex = candidate.suppressBeginIndex;
+ if (originalScore < scoreThreshold) {
+ break;
+ }
+ // Overlapping boxes are likely to have similar scores, therefore we
+ // iterate through the previously selected boxes backwards in order to
+ // see if candidate's score should be suppressed. We use
+ // suppressBeginIndex to track and ensure a candidate can be suppressed
+ // by a selected box no more than once. Also, if the overlap exceeds
+ // iouThreshold, we simply ignore the candidate.
+ var ignoreCandidate = false;
+ for (var j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) {
+ var iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]);
+ if (iou >= iouThreshold) {
+ ignoreCandidate = true;
+ break;
+ }
+ candidate.score =
+ candidate.score * suppressWeight(iouThreshold, scale, iou);
+ if (candidate.score <= scoreThreshold) {
+ break;
+ }
+ }
+ // At this point, if `candidate.score` has not dropped below
+ // `scoreThreshold`, then we know that we went through all of the
+ // previous selections and can safely update `suppressBeginIndex` to the
+ // end of the selected array. Then we can re-insert the candidate with
+ // the updated score and suppressBeginIndex back in the candidate list.
+ // If on the other hand, `candidate.score` has dropped below the score
+ // threshold, we will not add it back to the candidates list.
+ candidate.suppressBeginIndex = selectedIndices.length;
+ if (!ignoreCandidate) {
+ // Candidate has passed all the tests, and is not suppressed, so
+ // select the candidate.
+ if (candidate.score === originalScore) {
+ selectedIndices.push(boxIndex);
+ selectedScores.push(candidate.score);
+ }
+ else if (candidate.score > scoreThreshold) {
+ // Candidate's score is suppressed but is still high enough to be
+ // considered, so add back to the candidates list.
+ binaryInsert(candidates, candidate, ascendingComparator);
+ }
+ }
+ }
+ // NonMaxSuppressionV4 feature: padding output to maxOutputSize.
+ var validOutputs = selectedIndices.length;
+ var elemsToPad = maxOutputSize - validOutputs;
+ if (padToMaxOutputSize && elemsToPad > 0) {
+ selectedIndices.push.apply(selectedIndices, __spread(new Array(elemsToPad).fill(0)));
+ selectedScores.push.apply(selectedScores, __spread(new Array(elemsToPad).fill(0.0)));
+ }
+ var result = { selectedIndices: selectedIndices };
+ if (returnScoresTensor) {
+ result['selectedScores'] = selectedScores;
+ }
+ if (returnValidOutputs) {
+ result['validOutputs'] = validOutputs;
+ }
+ return result;
+ }
+ function intersectionOverUnion(boxes, i, j) {
+ var iCoord = boxes.subarray(i * 4, i * 4 + 4);
+ var jCoord = boxes.subarray(j * 4, j * 4 + 4);
+ var yminI = Math.min(iCoord[0], iCoord[2]);
+ var xminI = Math.min(iCoord[1], iCoord[3]);
+ var ymaxI = Math.max(iCoord[0], iCoord[2]);
+ var xmaxI = Math.max(iCoord[1], iCoord[3]);
+ var yminJ = Math.min(jCoord[0], jCoord[2]);
+ var xminJ = Math.min(jCoord[1], jCoord[3]);
+ var ymaxJ = Math.max(jCoord[0], jCoord[2]);
+ var xmaxJ = Math.max(jCoord[1], jCoord[3]);
+ var areaI = (ymaxI - yminI) * (xmaxI - xminI);
+ var areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
+ if (areaI <= 0 || areaJ <= 0) {
+ return 0.0;
+ }
+ var intersectionYmin = Math.max(yminI, yminJ);
+ var intersectionXmin = Math.max(xminI, xminJ);
+ var intersectionYmax = Math.min(ymaxI, ymaxJ);
+ var intersectionXmax = Math.min(xmaxI, xmaxJ);
+ var intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) *
+ Math.max(intersectionXmax - intersectionXmin, 0.0);
+ return intersectionArea / (areaI + areaJ - intersectionArea);
+ }
+ // A Gaussian penalty function, this method always returns values in [0, 1].
+ // The weight is a function of similarity, the more overlap two boxes are, the
+ // smaller the weight is, meaning highly overlapping boxe will be significantly
+ // penalized. On the other hand, a non-overlapping box will not be penalized.
+ function suppressWeight(iouThreshold, scale, iou) {
+ var weight = Math.exp(scale * iou * iou);
+ return iou <= iouThreshold ? weight : 0.0;
+ }
+ function ascendingComparator(c1, c2) {
+ // For objects with same scores, we make the object with the larger index go
+ // first. In an array that pops from the end, this means that the object with
+ // the smaller index will be popped first. This ensures the same output as
+ // the TensorFlow python version.
+ return (c1.score - c2.score) ||
+ ((c1.score === c2.score) && (c2.boxIndex - c1.boxIndex));
+ }
+
+ /**
+ * Performs non maximum suppression of bounding boxes based on
+ * iou (intersection over union).
+ *
+ * This is the async version of `nonMaxSuppression`
+ *
+ * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
+ * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
+ * the bounding box.
+ * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
+ * @param maxOutputSize The maximum number of boxes to be selected.
+ * @param iouThreshold A float representing the threshold for deciding whether
+ * boxes overlap too much with respect to IOU. Must be between [0, 1].
+ * Defaults to 0.5 (50% box overlap).
+ * @param scoreThreshold A threshold for deciding when to remove boxes based
+ * on score. Defaults to -inf, which means any score is accepted.
+ * @return A 1D tensor with the selected box indices.
+ *
+ * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
+ */
+ function nonMaxSuppressionAsync_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
+ if (iouThreshold === void 0) { iouThreshold = 0.5; }
+ if (scoreThreshold === void 0) { scoreThreshold = Number.NEGATIVE_INFINITY; }
+ return __awaiter(this, void 0, void 0, function () {
+ var $boxes, $scores, inputs, boxesAndScores, boxesVals, scoresVals, selectedIndices;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
+ $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
+ inputs = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);
+ maxOutputSize = inputs.maxOutputSize;
+ iouThreshold = inputs.iouThreshold;
+ scoreThreshold = inputs.scoreThreshold;
+ return [4 /*yield*/, Promise.all([$boxes.data(), $scores.data()])];
+ case 1:
+ boxesAndScores = _a.sent();
+ boxesVals = boxesAndScores[0];
+ scoresVals = boxesAndScores[1];
+ selectedIndices = nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold).selectedIndices;
+ if ($boxes !== boxes) {
+ $boxes.dispose();
+ }
+ if ($scores !== scores) {
+ $scores.dispose();
+ }
+ return [2 /*return*/, tensor1d(selectedIndices, 'int32')];
+ }
+ });
+ });
+ }
+ var nonMaxSuppressionAsync = nonMaxSuppressionAsync_;
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Performs non maximum suppression of bounding boxes based on
+ * iou (intersection over union).
+ *
+ * This op also supports a Soft-NMS mode (c.f.
+ * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
+ * of other overlapping boxes, therefore favoring different regions of the image
+ * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma`
+ * parameter to be larger than 0.
+ *
+ * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
+ * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
+ * the bounding box.
+ * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
+ * @param maxOutputSize The maximum number of boxes to be selected.
+ * @param iouThreshold A float representing the threshold for deciding whether
+ * boxes overlap too much with respect to IOU. Must be between [0, 1].
+ * Defaults to 0.5 (50% box overlap).
+ * @param scoreThreshold A threshold for deciding when to remove boxes based
+ * on score. Defaults to -inf, which means any score is accepted.
+ * @param softNmsSigma A float representing the sigma parameter for Soft NMS.
+ * When sigma is 0, it falls back to nonMaxSuppression.
+ * @return A map with the following properties:
+ * - selectedIndices: A 1D tensor with the selected box indices.
+ * - selectedScores: A 1D tensor with the corresponding scores for each
+ * selected box.
+ *
+ * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
+ */
+ function nonMaxSuppressionWithScore_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
+ if (iouThreshold === void 0) { iouThreshold = 0.5; }
+ if (scoreThreshold === void 0) { scoreThreshold = Number.NEGATIVE_INFINITY; }
+ if (softNmsSigma === void 0) { softNmsSigma = 0.0; }
+ var $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
+ var $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
+ var params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
+ maxOutputSize = params.maxOutputSize;
+ iouThreshold = params.iouThreshold;
+ scoreThreshold = params.scoreThreshold;
+ softNmsSigma = params.softNmsSigma;
+ var inputs = { boxes: $boxes, scores: $scores };
+ var attrs = { maxOutputSize: maxOutputSize, iouThreshold: iouThreshold, scoreThreshold: scoreThreshold, softNmsSigma: softNmsSigma };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var result = ENGINE.runKernel(NonMaxSuppressionV5, inputs, attrs);
+ return { selectedIndices: result[0], selectedScores: result[1] };
+ }
+ var nonMaxSuppressionWithScore = op({ nonMaxSuppressionWithScore_: nonMaxSuppressionWithScore_ });
+
+ /**
+ * Asynchronously performs non maximum suppression of bounding boxes based on
+ * iou (intersection over union).
+ *
+ * This op also supports a Soft-NMS mode (c.f.
+ * Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
+ * of other overlapping boxes, therefore favoring different regions of the image
+ * with high scores. To enable this Soft-NMS mode, set the `softNmsSigma`
+ * parameter to be larger than 0.
+ *
+ * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
+ * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
+ * the bounding box.
+ * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
+ * @param maxOutputSize The maximum number of boxes to be selected.
+ * @param iouThreshold A float representing the threshold for deciding whether
+ * boxes overlap too much with respect to IOU. Must be between [0, 1].
+ * Defaults to 0.5 (50% box overlap).
+ * @param scoreThreshold A threshold for deciding when to remove boxes based
+ * on score. Defaults to -inf, which means any score is accepted.
+ * @param softNmsSigma A float representing the sigma parameter for Soft NMS.
+ * When sigma is 0, it falls back to nonMaxSuppression.
+ * @return A map with the following properties:
+ * - selectedIndices: A 1D tensor with the selected box indices.
+ * - selectedScores: A 1D tensor with the corresponding scores for each
+ * selected box.
+ *
+ * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
+ */
+ function nonMaxSuppressionWithScoreAsync_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
+ if (iouThreshold === void 0) { iouThreshold = 0.5; }
+ if (scoreThreshold === void 0) { scoreThreshold = Number.NEGATIVE_INFINITY; }
+ if (softNmsSigma === void 0) { softNmsSigma = 0.0; }
+ return __awaiter(this, void 0, void 0, function () {
+ var $boxes, $scores, params, boxesAndScores, boxesVals, scoresVals, _a, selectedIndices, selectedScores;
+ return __generator(this, function (_b) {
+ switch (_b.label) {
+ case 0:
+ $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
+ $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
+ params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
+ maxOutputSize = params.maxOutputSize;
+ iouThreshold = params.iouThreshold;
+ scoreThreshold = params.scoreThreshold;
+ softNmsSigma = params.softNmsSigma;
+ return [4 /*yield*/, Promise.all([$boxes.data(), $scores.data()])];
+ case 1:
+ boxesAndScores = _b.sent();
+ boxesVals = boxesAndScores[0];
+ scoresVals = boxesAndScores[1];
+ _a = nonMaxSuppressionV5Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma), selectedIndices = _a.selectedIndices, selectedScores = _a.selectedScores;
+ if ($boxes !== boxes) {
+ $boxes.dispose();
+ }
+ if ($scores !== scores) {
+ $scores.dispose();
+ }
+ return [2 /*return*/, {
+ selectedIndices: tensor1d(selectedIndices, 'int32'),
+ selectedScores: tensor1d(selectedScores)
+ }];
+ }
+ });
+ });
+ }
+ var nonMaxSuppressionWithScoreAsync = nonMaxSuppressionWithScoreAsync_;
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Asynchronously performs non maximum suppression of bounding boxes based on
+ * iou (intersection over union), with an option to pad results.
+ *
+ * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
+ * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
+ * the bounding box.
+ * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
+ * @param maxOutputSize The maximum number of boxes to be selected.
+ * @param iouThreshold A float representing the threshold for deciding whether
+ * boxes overlap too much with respect to IOU. Must be between [0, 1].
+ * Defaults to 0.5 (50% box overlap).
+ * @param scoreThreshold A threshold for deciding when to remove boxes based
+ * on score. Defaults to -inf, which means any score is accepted.
+ * @param padToMaxOutputSize Defalts to false. If true, size of output
+ * `selectedIndices` is padded to maxOutputSize.
+ * @return A map with the following properties:
+ * - selectedIndices: A 1D tensor with the selected box indices.
+ * - validOutputs: A scalar denoting how many elements in `selectedIndices`
+ * are valid. Valid elements occur first, then padding.
+ *
+ * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
+ */
+ function nonMaxSuppressionPadded_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
+ if (iouThreshold === void 0) { iouThreshold = 0.5; }
+ if (scoreThreshold === void 0) { scoreThreshold = Number.NEGATIVE_INFINITY; }
+ if (padToMaxOutputSize === void 0) { padToMaxOutputSize = false; }
+ var $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
+ var $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
+ var params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, null /* softNmsSigma */);
+ var $maxOutputSize = params.maxOutputSize;
+ var $iouThreshold = params.iouThreshold;
+ var $scoreThreshold = params.scoreThreshold;
+ var inputs = { boxes: $boxes, scores: $scores };
+ var attrs = {
+ maxOutputSize: $maxOutputSize,
+ iouThreshold: $iouThreshold,
+ scoreThreshold: $scoreThreshold,
+ padToMaxOutputSize: padToMaxOutputSize
+ };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var result = ENGINE.runKernel(NonMaxSuppressionV4, inputs, attrs);
+ return { selectedIndices: result[0], validOutputs: result[1] };
+ }
+ var nonMaxSuppressionPadded = op({ nonMaxSuppressionPadded_: nonMaxSuppressionPadded_ });
+
+ /**
+ * Asynchronously performs non maximum suppression of bounding boxes based on
+ * iou (intersection over union), with an option to pad results.
+ *
+ * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
+ * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
+ * the bounding box.
+ * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
+ * @param maxOutputSize The maximum number of boxes to be selected.
+ * @param iouThreshold A float representing the threshold for deciding whether
+ * boxes overlap too much with respect to IOU. Must be between [0, 1].
+ * Defaults to 0.5 (50% box overlap).
+ * @param scoreThreshold A threshold for deciding when to remove boxes based
+ * on score. Defaults to -inf, which means any score is accepted.
+ * @param padToMaxOutputSize Defalts to false. If true, size of output
+ * `selectedIndices` is padded to maxOutputSize.
+ * @return A map with the following properties:
+ * - selectedIndices: A 1D tensor with the selected box indices.
+ * - validOutputs: A scalar denoting how many elements in `selectedIndices`
+ * are valid. Valid elements occur first, then padding.
+ *
+ * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
+ */
+ function nonMaxSuppressionPaddedAsync_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
+ if (iouThreshold === void 0) { iouThreshold = 0.5; }
+ if (scoreThreshold === void 0) { scoreThreshold = Number.NEGATIVE_INFINITY; }
+ if (padToMaxOutputSize === void 0) { padToMaxOutputSize = false; }
+ return __awaiter(this, void 0, void 0, function () {
+ var $boxes, $scores, params, $maxOutputSize, $iouThreshold, $scoreThreshold, _a, boxesVals, scoresVals, _b, selectedIndices, validOutputs;
+ return __generator(this, function (_c) {
+ switch (_c.label) {
+ case 0:
+ $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
+ $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
+ params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, null /* softNmsSigma */);
+ $maxOutputSize = params.maxOutputSize;
+ $iouThreshold = params.iouThreshold;
+ $scoreThreshold = params.scoreThreshold;
+ return [4 /*yield*/, Promise.all([$boxes.data(), $scores.data()])];
+ case 1:
+ _a = __read.apply(void 0, [_c.sent(), 2]), boxesVals = _a[0], scoresVals = _a[1];
+ _b = nonMaxSuppressionV4Impl(boxesVals, scoresVals, $maxOutputSize, $iouThreshold, $scoreThreshold, padToMaxOutputSize), selectedIndices = _b.selectedIndices, validOutputs = _b.validOutputs;
+ if ($boxes !== boxes) {
+ $boxes.dispose();
+ }
+ if ($scores !== scores) {
+ $scores.dispose();
+ }
+ return [2 /*return*/, {
+ selectedIndices: tensor1d(selectedIndices, 'int32'),
+ validOutputs: scalar(validOutputs, 'int32')
+ }];
+ }
+ });
+ });
+ }
+ var nonMaxSuppressionPaddedAsync = nonMaxSuppressionPaddedAsync_;
+
+ /**
+ * Bilinear resize a single 3D image or a batch of 3D images to a new shape.
+ *
+ * @param images The images, of rank 4 or rank 3, of shape
+ * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
+ * @param size The new shape `[newHeight, newWidth]` to resize the
+ * images to. Each channel is resized individually.
+ * @param alignCorners Defaults to `false`. If true, rescale
+ * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4
+ * corners of images and resized images. If false, rescale by
+ * `new_height / height`. Treat similarly the width dimension.
+ * @param halfPixelCenters Defaults to `false`. Whether to assume pixel centers
+ * are at 0.5, which would make the floating point coordinates of the top
+ * left pixel 0.5, 0.5.
+ *
+ * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
+ */
+ function resizeBilinear_(images, size, alignCorners, halfPixelCenters) {
+ if (alignCorners === void 0) { alignCorners = false; }
+ if (halfPixelCenters === void 0) { halfPixelCenters = false; }
+ var $images = convertToTensor(images, 'images', 'resizeBilinear');
+ assert($images.rank === 3 || $images.rank === 4, function () { return "Error in resizeBilinear: x must be rank 3 or 4, but got " +
+ ("rank " + $images.rank + "."); });
+ assert(size.length === 2, function () { return "Error in resizeBilinear: new shape must 2D, but got shape " +
+ (size + "."); });
+ assert(halfPixelCenters === false || alignCorners === false, function () { return "Error in resizeBilinear: If halfPixelCenters is true, " +
+ "alignCorners must be false."; });
+ var batchImages = $images;
+ var reshapedTo4D = false;
+ if ($images.rank === 3) {
+ reshapedTo4D = true;
+ batchImages = reshape($images, [1, $images.shape[0], $images.shape[1], $images.shape[2]]);
+ }
+ __read(size, 0);
+ var inputs = { images: batchImages };
+ var attrs = { alignCorners: alignCorners, halfPixelCenters: halfPixelCenters, size: size };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(ResizeBilinear, inputs, attrs);
+ if (reshapedTo4D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return res;
+ }
+ var resizeBilinear = op({ resizeBilinear_: resizeBilinear_ });
+
+ /**
+ * NearestNeighbor resize a batch of 3D images to a new shape.
+ *
+ * @param images The images, of rank 4 or rank 3, of shape
+ * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
+ * @param size The new shape `[newHeight, newWidth]` to resize the
+ * images to. Each channel is resized individually.
+ * @param alignCorners Defaults to False. If true, rescale
+ * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4
+ * corners of images and resized images. If false, rescale by
+ * `new_height / height`. Treat similarly the width dimension.
+ * @param halfPixelCenters Defaults to `false`. Whether to assumes pixels are of
+ * half the actual dimensions, and yields more accurate resizes. This flag
+ * would also make the floating point coordinates of the top left pixel
+ * 0.5, 0.5.
+ *
+ * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
+ */
+ function resizeNearestNeighbor_(images, size, alignCorners, halfPixelCenters) {
+ if (alignCorners === void 0) { alignCorners = false; }
+ if (halfPixelCenters === void 0) { halfPixelCenters = false; }
+ var $images = convertToTensor(images, 'images', 'resizeNearestNeighbor');
+ assert($images.rank === 3 || $images.rank === 4, function () { return "Error in resizeNearestNeighbor: x must be rank 3 or 4, but got " +
+ ("rank " + $images.rank + "."); });
+ assert(size.length === 2, function () { return "Error in resizeNearestNeighbor: new shape must 2D, but got shape " +
+ (size + "."); });
+ assert($images.dtype === 'float32' || $images.dtype === 'int32', function () { return '`images` must have `int32` or `float32` as dtype'; });
+ assert(halfPixelCenters === false || alignCorners === false, function () { return "Error in resizeNearestNeighbor: If halfPixelCenters is true, " +
+ "alignCorners must be false."; });
+ var batchImages = $images;
+ var reshapedTo4D = false;
+ if ($images.rank === 3) {
+ reshapedTo4D = true;
+ batchImages = reshape($images, [1, $images.shape[0], $images.shape[1], $images.shape[2]]);
+ }
+ __read(size, 0);
+ var inputs = { images: batchImages };
+ var attrs = { alignCorners: alignCorners, halfPixelCenters: halfPixelCenters, size: size };
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ var res = ENGINE.runKernel(ResizeNearestNeighbor, inputs, attrs);
+ if (reshapedTo4D) {
+ return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
+ }
+ return res;
+ }
+ var resizeNearestNeighbor = op({ resizeNearestNeighbor_: resizeNearestNeighbor_ });
+
+ /**
+ * Performs image binarization with corresponding threshold
+ * (depends on the method)value, which creates a binary image from a grayscale.
+ * @param image 3d tensor of shape [imageHeight,imageWidth, depth],
+ * where imageHeight and imageWidth must be positive.The image color
+ * range should be [0, 255].
+ * @param method Optional string from `'binary' | 'otsu'`
+ * which specifies the method for thresholding. Defaults to 'binary'.
+ * @param inverted Optional boolean whichspecifies
+ * if colours should be inverted. Defaults to false.
+ * @param threshValue Optional number which defines threshold value from 0 to 1.
+ * Defaults to 0.5.
+ * @return A 3d tensor of shape [imageHeight,imageWidth, depth], which
+ * contains binarized image.
+ */
+ function threshold_(image, method, inverted, threshValue) {
+ var _a;
+ if (method === void 0) { method = 'binary'; }
+ if (inverted === void 0) { inverted = false; }
+ if (threshValue === void 0) { threshValue = 0.5; }
+ var $image = convertToTensor(image, 'image', 'threshold');
+ /* 0.2989, 0.5870, 0.1140 are represent luma coefficients in CCIR601.
+ Reference for converting between RGB and grayscale: https://en.wikipedia.org/wiki/Luma_%28video%29 */
+ var RED_INTENCITY_COEF = 0.2989;
+ var GREEN_INTENCITY_COEF = 0.5870;
+ var BLUE_INTENCITY_COEF = 0.1140;
+ var totalPixelsInImage = $image.shape[0] * $image.shape[1];
+ var $threshold = mul(tensor1d([threshValue]), 255);
+ var r, g, b, grayscale;
+ assert($image.rank === 3, function () { return 'Error in threshold: image must be rank 3,' +
+ ("but got rank " + $image.rank + "."); });
+ assert($image.shape[2] === 3 || $image.shape[2] === 1, function () { return 'Error in threshold: ' +
+ 'image color channel must be equal to 3 or 1' +
+ ("but got " + $image.shape[2] + "."); });
+ assert($image.dtype === 'int32' || $image.dtype === 'float32', function () { return 'Error in dtype: image dtype must be int32 or float32,' +
+ ("but got dtype " + $image.dtype + "."); });
+ assert(method === 'otsu' || method === 'binary', function () { return "Method must be binary or otsu, but was " + method; });
+ if ($image.shape[2] === 3) {
+ _a = __read(split($image, [1, 1, 1], -1), 3), r = _a[0], g = _a[1], b = _a[2];
+ var $r = mul(r, RED_INTENCITY_COEF);
+ var $g = mul(g, GREEN_INTENCITY_COEF);
+ var $b = mul(b, BLUE_INTENCITY_COEF);
+ grayscale = add(add($r, $g), $b);
+ }
+ else {
+ grayscale = image;
+ }
+ if (method === 'otsu') {
+ var $histogram = bincount(cast(round(grayscale), 'int32'), tensor([]), 256);
+ $threshold = otsu($histogram, totalPixelsInImage);
+ }
+ var invCondition = inverted ?
+ lessEqual(grayscale, $threshold) : greater(grayscale, $threshold);
+ var result = cast(mul(invCondition, 255), 'int32');
+ return result;
+ }
+ function otsu(histogram, total) {
+ var bestThresh = tensor1d([-1]);
+ var bestInBetVar = tensor1d([0]);
+ var cInBetVar = tensor1d([0]);
+ var classFirst, classSecond, meanFirst, meanSec, weightForeground, weightBack;
+ for (var index = 0; index < histogram.size - 1; index++) {
+ classFirst = slice(histogram, 0, index + 1);
+ classSecond = slice(histogram, index + 1);
+ weightForeground = div(sum(classFirst), total);
+ weightBack = div(sum(classSecond), total);
+ var meanFirstDivA = sum(mul(classFirst, range(0, classFirst.size)));
+ meanFirst = div(meanFirstDivA, sum(classFirst));
+ var meanSecFill = fill(classSecond.shape, classFirst.size);
+ var meanSecAdd = add(range(0, classSecond.size), meanSecFill);
+ var meanSecMul = mul(classSecond, (meanSecAdd));
+ meanSec = div(sum(meanSecMul), sum(classSecond));
+ var cInBetVarSubA = sub(meanFirst, meanSec);
+ var cInBetVarSubB = sub(meanFirst, meanSec);
+ var cInBetVarMul = mul(weightForeground, weightBack);
+ cInBetVar = mul(mul(cInBetVarMul, cInBetVarSubA), cInBetVarSubB);
+ var condition = greater(cInBetVar, bestInBetVar);
+ bestInBetVar = where(condition, cInBetVar, bestInBetVar);
+ bestThresh = where(condition, tensor1d([index]), bestThresh);
+ }
+ return bestThresh;
+ }
+ var threshold = op({ threshold_: threshold_ });
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Applies the given transform(s) to the image(s).
+ *
+ * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
+ * @param transforms Projective transform matrix/matrices. A tensor1d of length
+ * 8 or tensor of size N x 8. If one row of transforms is [a0, a1, a2, b0
+ * b1, b2, c0, c1], then it maps the output point (x, y) to a transformed
+ * input point (x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k),
+ * where k = c0 x + c1 y + 1. The transforms are inverted compared to the
+ * transform mapping input points to output points.
+ * @param interpolation Interpolation mode.
+ * Supported values: 'nearest', 'bilinear'. Default to 'nearest'.
+ * @param fillMode Points outside the boundaries of the input are filled
+ * according to the given mode, one of 'constant', 'reflect', 'wrap',
+ * 'nearest'. Default to 'constant'.
+ * 'reflect': (d c b a | a b c d | d c b a ) The input is extended by
+ * reflecting about the edge of the last pixel.
+ * 'constant': (k k k k | a b c d | k k k k) The input is extended by
+ * filling all values beyond the edge with the same constant value k.
+ * 'wrap': (a b c d | a b c d | a b c d) The input is extended by
+ * wrapping around to the opposite edge.
+ * 'nearest': (a a a a | a b c d | d d d d) The input is extended by
+ * the nearest pixel.
+ * @param fillValue A float represents the value to be filled outside the
+ * boundaries when fillMode is 'constant'.
+ * @param Output dimension after the transform, [height, width]. If undefined,
+ * output is the same size as input image.
+ *
+ * @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
+ */
+ function transform_(image, transforms, interpolation, fillMode, fillValue, outputShape) {
+ if (interpolation === void 0) { interpolation = 'nearest'; }
+ if (fillMode === void 0) { fillMode = 'constant'; }
+ if (fillValue === void 0) { fillValue = 0; }
+ var $image = convertToTensor(image, 'image', 'transform', 'float32');
+ var $transforms = convertToTensor(transforms, 'transforms', 'transform', 'float32');
+ assert($image.rank === 4, function () { return 'Error in transform: image must be rank 4,' +
+ ("but got rank " + $image.rank + "."); });
+ assert($transforms.rank === 2 &&
+ ($transforms.shape[0] === $image.shape[0] ||
+ $transforms.shape[0] === 1) &&
+ $transforms.shape[1] === 8, function () { return "Error in transform: Input transform should be batch x 8 or 1 x 8"; });
+ assert(outputShape == null || outputShape.length === 2, function () { return 'Error in transform: outputShape must be [height, width] or null, ' +
+ ("but got " + outputShape + "."); });
+ var inputs = { image: $image, transforms: $transforms };
+ var attrs = { interpolation: interpolation, fillMode: fillMode, fillValue: fillValue, outputShape: outputShape };
+ return ENGINE.runKernel(Transform, inputs, attrs);
+ }
+ var transform = op({ transform_: transform_ });
+
+ /**
+ * Copy a tensor setting everything outside a central band in each innermost
+ * matrix to zero.
+ *
+ * The band part is computed as follows: Assume input has `k` dimensions
+ * `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where
+ * `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
+ * The indicator function
+ * `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower))`
+ * `&& (num_upper < 0 || (n-m) <= num_upper)`
+ *
+ * ```js
+ * const x = tf.tensor2d([[ 0, 1, 2, 3],
+ * [-1, 0, 1, 2],
+ * [-2, -1, 0, 1],
+ * [-3, -2, -1, 0]]);
+ * let y = tf.linalg.bandPart(x, 1, -1);
+ * y.print(); // [[ 0, 1, 2, 3],
+ * // [-1, 0, 1, 2],
+ * // [ 0, -1, 0, 1],
+ * // [ 0, 0 , -1, 0]]
+ * let z = tf.linalg.bandPart(x, 2, 1);
+ * z.print(); // [[ 0, 1, 0, 0],
+ * // [-1, 0, 1, 0],
+ * // [-2, -1, 0, 1],
+ * // [ 0, -2, -1, 0]]
+ * ```
+ *
+ * @param x Rank `k` tensor
+ * @param numLower Number of subdiagonals to keep.
+ * If negative, keep entire lower triangle.
+ * @param numUpper Number of subdiagonals to keep.
+ * If negative, keep entire upper triangle.
+ * @returns Rank `k` tensor of the same shape as input.
+ * The extracted banded tensor.
+ *
+ * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
+ */
+ function bandPart_(a, numLower, numUpper) {
+ assert(numLower % 1 === 0, function () { return "bandPart(): numLower must be an integer, got " + numLower + "."; });
+ assert(numUpper % 1 === 0, function () { return "bandPart(): numUpper must be an integer, got " + numUpper + "."; });
+ var $a = convertToTensor(a, 'a', 'bandPart');
+ assert($a.rank >= 2, function () { return "bandPart(): Rank must be at least 2, got " + $a.rank + "."; });
+ var shape = $a.shape;
+ var _a = __read($a.shape.slice(-2), 2), M = _a[0], N = _a[1];
+ if (!(numLower <= M)) {
+ throw new Error("bandPart(): numLower (" + numLower + ")" +
+ (" must not be greater than the number of rows (" + M + ")."));
+ }
+ if (!(numUpper <= N)) {
+ throw new Error("bandPart(): numUpper (" + numUpper + ")" +
+ (" must not be greater than the number of columns (" + N + ")."));
+ }
+ if (numLower < 0) {
+ numLower = M;
+ }
+ if (numUpper < 0) {
+ numUpper = N;
+ }
+ var i = reshape(range(0, M, 1, 'int32'), [-1, 1]);
+ var j = range(0, N, 1, 'int32');
+ var ij = sub(i, j);
+ var inBand = logicalAnd(lessEqual(ij, scalar(+numLower, 'int32')), greaterEqual(ij, scalar(-numUpper, 'int32')));
+ var zero = zeros([M, N], $a.dtype);
+ return reshape(stack(unstack(reshape($a, [-1, M, N]))
+ .map(function (mat) { return where(inBand, mat, zero); })), shape);
+ }
+ var bandPart = op({ bandPart_: bandPart_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Gram-Schmidt orthogonalization.
+ *
+ * ```js
+ * const x = tf.tensor2d([[1, 2], [3, 4]]);
+ * let y = tf.linalg.gramSchmidt(x);
+ * y.print();
+ * console.log('Othogonalized:');
+ * y.dot(y.transpose()).print(); // should be nearly the identity matrix.
+ * console.log('First row direction maintained:');
+ * const data = await y.array();
+ * console.log(data[0][1] / data[0][0]); // should be nearly 2.
+ * ```
+ *
+ * @param xs The vectors to be orthogonalized, in one of the two following
+ * formats:
+ * - An Array of `tf.Tensor1D`.
+ * - A `tf.Tensor2D`, i.e., a matrix, in which case the vectors are the rows
+ * of `xs`.
+ * In each case, all the vectors must have the same length and the length
+ * must be greater than or equal to the number of vectors.
+ * @returns The orthogonalized and normalized vectors or matrix.
+ * Orthogonalization means that the vectors or the rows of the matrix
+ * are orthogonal (zero inner products). Normalization means that each
+ * vector or each row of the matrix has an L2 norm that equals `1`.
+ *
+ * @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
+ */
+ function gramSchmidt_(xs) {
+ var inputIsTensor2D;
+ if (Array.isArray(xs)) {
+ inputIsTensor2D = false;
+ assert(xs != null && xs.length > 0, function () { return 'Gram-Schmidt process: input must not be null, undefined, or ' +
+ 'empty'; });
+ var dim_1 = xs[0].shape[0];
+ var _loop_1 = function (i) {
+ assert(xs[i].shape[0] === dim_1, function () { return 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' +
+ ("(" + xs[i].shape[0] + " vs. " + dim_1 + ")"); });
+ };
+ for (var i = 1; i < xs.length; ++i) {
+ _loop_1(i);
+ }
+ }
+ else {
+ inputIsTensor2D = true;
+ xs = split(xs, xs.shape[0], 0).map(function (x) { return squeeze(x, [0]); });
+ }
+ assert(xs.length <= xs[0].shape[0], function () { return "Gram-Schmidt: Number of vectors (" + xs.length + ") exceeds " +
+ ("number of dimensions (" + xs[0].shape[0] + ")."); });
+ var ys = [];
+ var xs1d = xs;
+ var _loop_2 = function (i) {
+ ys.push(ENGINE.tidy(function () {
+ var x = xs1d[i];
+ if (i > 0) {
+ for (var j = 0; j < i; ++j) {
+ var proj = mul(sum(mul(ys[j], x)), ys[j]);
+ x = sub(x, proj);
+ }
+ }
+ return div(x, norm(x, 'euclidean'));
+ }));
+ };
+ for (var i = 0; i < xs.length; ++i) {
+ _loop_2(i);
+ }
+ if (inputIsTensor2D) {
+ return stack(ys, 0);
+ }
+ else {
+ return ys;
+ }
+ }
+ var gramSchmidt = op({ gramSchmidt_: gramSchmidt_ });
+
+ /**
+ * Compute QR decomposition of m-by-n matrix using Householder transformation.
+ *
+ * Implementation based on
+ * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf]
+ * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf)
+ *
+ * ```js
+ * const a = tf.tensor2d([[1, 2], [3, 4]]);
+ * let [q, r] = tf.linalg.qr(a);
+ * console.log('Q');
+ * q.print();
+ * console.log('R');
+ * r.print();
+ * console.log('Orthogonalized');
+ * q.dot(q.transpose()).print() // should be nearly the identity matrix.
+ * console.log('Reconstructed');
+ * q.dot(r).print(); // should be nearly [[1, 2], [3, 4]];
+ * ```
+ *
+ * @param x The `tf.Tensor` to be QR-decomposed. Must have rank >= 2. Suppose
+ * it has the shape `[..., M, N]`.
+ * @param fullMatrices An optional boolean parameter. Defaults to `false`.
+ * If `true`, compute full-sized `Q`. If `false` (the default),
+ * compute only the leading N columns of `Q` and `R`.
+ * @returns An `Array` of two `tf.Tensor`s: `[Q, R]`. `Q` is a unitary matrix,
+ * i.e., its columns all have unit norm and are mutually orthogonal.
+ * If `M >= N`,
+ * If `fullMatrices` is `false` (default),
+ * - `Q` has a shape of `[..., M, N]`,
+ * - `R` has a shape of `[..., N, N]`.
+ * If `fullMatrices` is `true` (default),
+ * - `Q` has a shape of `[..., M, M]`,
+ * - `R` has a shape of `[..., M, N]`.
+ * If `M < N`,
+ * - `Q` has a shape of `[..., M, M]`,
+ * - `R` has a shape of `[..., M, N]`.
+ * @throws If the rank of `x` is less than 2.
+ *
+ * @doc {heading:'Operations',
+ * subheading:'Linear Algebra',
+ * namespace:'linalg'}
+ */
+ function qr_(x, fullMatrices) {
+ if (fullMatrices === void 0) { fullMatrices = false; }
+ assert(x.rank >= 2, function () { return "qr() requires input tensor to have a rank >= 2, but got rank " + x.rank; });
+ if (x.rank === 2) {
+ return qr2d(x, fullMatrices);
+ }
+ else {
+ // Rank > 2.
+ // TODO(cais): Below we split the input into individual 2D tensors,
+ // perform QR decomposition on them and then stack the results back
+ // together. We should explore whether this can be parallelized.
+ var outerDimsProd = x.shape.slice(0, x.shape.length - 2)
+ .reduce(function (value, prev) { return value * prev; });
+ var x2ds = unstack(reshape(x, [
+ outerDimsProd, x.shape[x.shape.length - 2],
+ x.shape[x.shape.length - 1]
+ ]), 0);
+ var q2ds_1 = [];
+ var r2ds_1 = [];
+ x2ds.forEach(function (x2d) {
+ var _a = __read(qr2d(x2d, fullMatrices), 2), q2d = _a[0], r2d = _a[1];
+ q2ds_1.push(q2d);
+ r2ds_1.push(r2d);
+ });
+ var q = reshape(stack(q2ds_1, 0), x.shape);
+ var r = reshape(stack(r2ds_1, 0), x.shape);
+ return [q, r];
+ }
+ }
+ function qr2d(x, fullMatrices) {
+ if (fullMatrices === void 0) { fullMatrices = false; }
+ return ENGINE.tidy(function () {
+ assert(x.shape.length === 2, function () { return "qr2d() requires a 2D Tensor, but got a " + x.shape.length + "D Tensor."; });
+ var m = x.shape[0];
+ var n = x.shape[1];
+ var q = eye(m); // Orthogonal transform so far.
+ var r = clone(x); // Transformed matrix so far.
+ var one2D = tensor2d([[1]], [1, 1]);
+ var w = clone(one2D);
+ var iters = m >= n ? n : m;
+ var _loop_1 = function (j) {
+ var _a;
+ // This tidy within the for-loop ensures we clean up temporary
+ // tensors as soon as they are no longer needed.
+ var rTemp = r;
+ var wTemp = w;
+ var qTemp = q;
+ _a = __read(ENGINE.tidy(function () {
+ // Find H = I - tau * w * w', to put zeros below R(j, j).
+ var rjEnd1 = slice(r, [j, j], [m - j, 1]);
+ var normX = norm(rjEnd1);
+ var rjj = slice(r, [j, j], [1, 1]);
+ // The sign() function returns 0 on 0, which causes division by zero.
+ var s = where(greater(rjj, 0), tensor2d([[-1]]), tensor2d([[1]]));
+ var u1 = sub(rjj, mul(s, normX));
+ var wPre = div(rjEnd1, u1);
+ if (wPre.shape[0] === 1) {
+ w = clone(one2D);
+ }
+ else {
+ w = concat([
+ one2D,
+ slice(wPre, [1, 0], [wPre.shape[0] - 1, wPre.shape[1]])
+ ], 0);
+ }
+ var tau = neg(div(matMul$1(s, u1), normX));
+ // -- R := HR, Q := QH.
+ var rjEndAll = slice(r, [j, 0], [m - j, n]);
+ var tauTimesW = mul(tau, w);
+ var wT = transpose(w);
+ if (j === 0) {
+ r = sub(rjEndAll, matMul$1(tauTimesW, matMul$1(wT, rjEndAll)));
+ }
+ else {
+ var rTimesTau = sub(rjEndAll, matMul$1(tauTimesW, matMul$1(wT, rjEndAll)));
+ r = concat([slice(r, [0, 0], [j, n]), rTimesTau], 0);
+ }
+ var tawTimesWT = transpose(tauTimesW);
+ var qAllJEnd = slice(q, [0, j], [m, q.shape[1] - j]);
+ if (j === 0) {
+ q = sub(qAllJEnd, matMul$1(matMul$1(qAllJEnd, w), tawTimesWT));
+ }
+ else {
+ var qTimesTau = sub(qAllJEnd, matMul$1(matMul$1(qAllJEnd, w), tawTimesWT));
+ q = concat([slice(q, [0, 0], [m, j]), qTimesTau], 1);
+ }
+ return [w, r, q];
+ }), 3), w = _a[0], r = _a[1], q = _a[2];
+ dispose([rTemp, wTemp, qTemp]);
+ };
+ for (var j = 0; j < iters; ++j) {
+ _loop_1(j);
+ }
+ if (!fullMatrices && m > n) {
+ q = slice(q, [0, 0], [m, n]);
+ r = slice(r, [0, 0], [n, n]);
+ }
+ return [q, r];
+ });
+ }
+ var qr = op({ qr_: qr_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ exports.Reduction = void 0;
+ (function (Reduction) {
+ Reduction[Reduction["NONE"] = 0] = "NONE";
+ Reduction[Reduction["MEAN"] = 1] = "MEAN";
+ Reduction[Reduction["SUM"] = 2] = "SUM";
+ Reduction[Reduction["SUM_BY_NONZERO_WEIGHTS"] = 3] = "SUM_BY_NONZERO_WEIGHTS";
+ })(exports.Reduction || (exports.Reduction = {}));
+
+ /**
+ * Computes the weighted loss between two tensors.
+ *
+ * @param losses Tensor of shape `[batch_size, d1, ... dN]`.
+ * @param weights Tensor whose rank is either 0, or the same rank as
+ * `losses`, and must be broadcastable to `losses` (i.e., all
+ * dimensions must be either `1`, or the same as the corresponding
+ * `losses` dimension).
+ *
+ * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
+ */
+ function computeWeightedLoss_(losses, weights, reduction) {
+ if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; }
+ var $losses = convertToTensor(losses, 'losses', 'computeWeightedLoss');
+ var $weights = null;
+ if (weights != null) {
+ $weights = convertToTensor(weights, 'weights', 'computeWeightedLoss');
+ }
+ var weightedLoss = ($weights == null) ? $losses : mul($losses, $weights);
+ if (reduction === exports.Reduction.NONE) {
+ return weightedLoss;
+ }
+ if (reduction === exports.Reduction.SUM) {
+ return sum(weightedLoss);
+ }
+ if (reduction === exports.Reduction.MEAN) {
+ if ($weights == null) {
+ return mean(weightedLoss);
+ }
+ else {
+ var broadcastFactor = $losses.size / $weights.size;
+ var result = div(sum(weightedLoss), sum($weights));
+ return broadcastFactor > 1 ? div(result, scalar(broadcastFactor)) :
+ result;
+ }
+ }
+ if (reduction === exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
+ if ($weights == null) {
+ return div(sum(weightedLoss), scalar($losses.size));
+ }
+ else {
+ var broadcastedWeights = mul($weights, ones($losses.shape));
+ var numNonZeros = cast(sum(notEqual(broadcastedWeights, scalar(0))), 'float32');
+ return div(sum(weightedLoss), numNonZeros);
+ }
+ }
+ throw Error("Unknown reduction: " + reduction);
+ }
+ var computeWeightedLoss = op({ computeWeightedLoss_: computeWeightedLoss_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the absolute difference loss between two tensors.
+ *
+ * @param labels The ground truth output tensor, same dimensions as
+ * 'predictions'.
+ * @param predictions The predicted outputs.
+ * @param weights Tensor whose rank is either 0, or the same rank as
+ * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
+ * must be either `1`, or the same as the corresponding `losses`
+ * dimension).
+ * @param reduction Type of reduction to apply to loss. Should be of type
+ * `Reduction`
+ *
+ * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
+ */
+ function absoluteDifference_(labels, predictions, weights, reduction) {
+ if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; }
+ var $labels = convertToTensor(labels, 'labels', 'absoluteDifference');
+ var $predictions = convertToTensor(predictions, 'predictions', 'absoluteDifference');
+ var $weights = null;
+ if (weights != null) {
+ $weights = convertToTensor(weights, 'weights', 'absoluteDifference');
+ }
+ assertShapesMatch($labels.shape, $predictions.shape, 'Error in absoluteDifference: ');
+ var losses = abs(sub($labels, $predictions));
+ return computeWeightedLoss(losses, $weights, reduction);
+ }
+ var absoluteDifference = op({ absoluteDifference_: absoluteDifference_ });
+
+ /**
+ * Computes the cosine distance loss between two tensors.
+ *
+ * @param labels The ground truth output tensor, same dimensions as
+ * 'predictions'.
+ * @param predictions The predicted outputs.
+ * @param axis The dimension along which the cosine distance is computed.
+ * @param weights Tensor whose rank is either 0, or the same rank as
+ * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
+ * must be either `1`, or the same as the corresponding `losses`
+ * dimension).
+ * @param reduction Type of reduction to apply to loss. Should be of type
+ * `Reduction`
+ *
+ * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
+ */
+ function cosineDistance_(labels, predictions, axis, weights, reduction) {
+ if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; }
+ var $labels = convertToTensor(labels, 'labels', 'cosineDistance');
+ var $predictions = convertToTensor(predictions, 'predictions', 'cosineDistance');
+ var $weights = null;
+ if (weights != null) {
+ $weights = convertToTensor(weights, 'weights', 'cosineDistance');
+ }
+ assertShapesMatch($labels.shape, $predictions.shape, 'Error in cosineDistance: ');
+ var one = scalar(1);
+ var losses = sub(one, sum(mul($labels, $predictions), axis, true));
+ return computeWeightedLoss(losses, $weights, reduction);
+ }
+ var cosineDistance = op({ cosineDistance_: cosineDistance_ });
+
+ /**
+ * Computes the Hinge loss between two tensors.
+ *
+ * @param labels The ground truth output tensor, same dimensions as
+ * 'predictions'.
+ * @param predictions The predicted outputs.
+ * @param weights Tensor whose rank is either 0, or the same rank as
+ * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
+ * must be either `1`, or the same as the corresponding `losses`
+ * dimension).
+ * @param reduction Type of reduction to apply to loss. Should be of type
+ * `Reduction`
+ *
+ * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
+ */
+ function hingeLoss_(labels, predictions, weights, reduction) {
+ if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; }
+ var $labels = convertToTensor(labels, 'labels', 'hingeLoss');
+ var $predictions = convertToTensor(predictions, 'predictions', 'hingeLoss');
+ var $weights = null;
+ if (weights != null) {
+ $weights = convertToTensor(weights, 'weights', 'hingeLoss');
+ }
+ assertShapesMatch($labels.shape, $predictions.shape, 'Error in hingeLoss: ');
+ var one = scalar(1);
+ // Convert binary labels to (-1, 1)
+ $labels = sub(mul(scalar(2), $labels), one);
+ var losses = relu(sub(one, mul($labels, $predictions)));
+ return computeWeightedLoss(losses, $weights, reduction);
+ }
+ var hingeLoss = op({ hingeLoss_: hingeLoss_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the huber loss between two tensors.
+ *
+ * @param labels The ground truth output tensor, same dimensions as
+ * 'predictions'.
+ * @param predictions The predicted outputs.
+ * @param weights Tensor whose rank is either 0, or the same rank as
+ * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
+ * must be either `1`, or the same as the corresponding `losses`
+ * dimension).
+ * @param delta Point where huber loss changes from quadratic to linear.
+ * @param reduction Type of reduction to apply to loss. Should be of type
+ * `Reduction`.
+ *
+ * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
+ */
+ function huberLoss_(labels, predictions, weights, delta, reduction) {
+ if (delta === void 0) { delta = 1.0; }
+ if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; }
+ var $labels = convertToTensor(labels, 'labels', 'huberLoss');
+ var $predictions = convertToTensor(predictions, 'predictions', 'huberLoss');
+ var $weights = null;
+ if (weights != null) {
+ $weights = convertToTensor(weights, 'weights', 'huberLoss');
+ }
+ assertShapesMatch($labels.shape, $predictions.shape, 'Error in huberLoss: ');
+ var deltaScalar = scalar(delta);
+ var error = abs(sub($predictions, $labels));
+ var quadratic = minimum(error, deltaScalar);
+ var linear = sub(error, quadratic);
+ var losses = add(mul(scalar(0.5), square(quadratic)), mul(deltaScalar, linear));
+ return computeWeightedLoss(losses, $weights, reduction);
+ }
+ var huberLoss = op({ huberLoss_: huberLoss_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the log loss between two tensors.
+ *
+ * @param labels The ground truth output tensor, same dimensions as
+ * 'predictions'.
+ * @param predictions The predicted outputs.
+ * @param weights Tensor whose rank is either 0, or the same rank as
+ * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
+ * must be either `1`, or the same as the corresponding `losses`
+ * dimension).
+ * @param epsilon A small increment to avoid taking log of zero
+ * @param reduction Type of reduction to apply to loss. Should be of type
+ * `Reduction`
+ *
+ * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
+ */
+ function logLoss_(labels, predictions, weights, epsilon, reduction) {
+ if (epsilon === void 0) { epsilon = 1e-7; }
+ if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; }
+ var $labels = convertToTensor(labels, 'labels', 'logLoss');
+ var $predictions = convertToTensor(predictions, 'predictions', 'logLoss');
+ var $weights = null;
+ if (weights != null) {
+ $weights = convertToTensor(weights, 'weights', 'logLoss');
+ }
+ assertShapesMatch($labels.shape, $predictions.shape, 'Error in logLoss: ');
+ var one = scalar(1);
+ var epsilonScalar = scalar(epsilon);
+ var l1 = neg(mul($labels, log(add($predictions, epsilonScalar))));
+ var l2 = mul(sub(one, $labels), log(add(sub(one, $predictions), epsilonScalar)));
+ var losses = sub(l1, l2);
+ return computeWeightedLoss(losses, $weights, reduction);
+ }
+ var logLoss = op({ logLoss_: logLoss_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the mean squared error between two tensors.
+ *
+ * @param labels The ground truth output tensor, same dimensions as
+ * 'predictions'.
+ * @param predictions The predicted outputs.
+ * @param weights Tensor whose rank is either 0, or the same rank as
+ * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
+ * must be either `1`, or the same as the corresponding `losses`
+ * dimension).
+ * @param reduction Type of reduction to apply to loss. Should be of type
+ * `Reduction`
+ *
+ * @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
+ */
+ function meanSquaredError_(labels, predictions, weights, reduction) {
+ if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; }
+ var $labels = convertToTensor(labels, 'labels', 'meanSquaredError');
+ var $predictions = convertToTensor(predictions, 'predictions', 'meanSquaredError');
+ var $weights = null;
+ if (weights != null) {
+ $weights = convertToTensor(weights, 'weights', 'meanSquaredError');
+ }
+ assertShapesMatch($labels.shape, $predictions.shape, 'Error in meanSquaredError: ');
+ var losses = squaredDifference($labels, $predictions);
+ return computeWeightedLoss(losses, $weights, reduction);
+ }
+ var meanSquaredError = op({ meanSquaredError_: meanSquaredError_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function sigmoidCrossEntropyWithLogits_(labels, logits) {
+ var $labels = convertToTensor(labels, 'labels', 'sigmoidCrossEntropyWithLogits');
+ var $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropyWithLogits');
+ assertShapesMatch($labels.shape, $logits.shape, 'Error in sigmoidCrossEntropyWithLogits: ');
+ /**
+ * Implementation Details:
+ *
+ * For brevity, let `x = logits`, `z = labels`. The logistic loss is
+ * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+ * = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
+ * = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
+ * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
+ * = (1 - z) * x + log(1 + exp(-x))
+ * = x - x * z + log(1 + exp(-x))
+ *
+ * For x < 0, to avoid overflow in exp(-x), we reformulate the above
+ * x - x * z + log(1 + exp(-x))
+ * = log(exp(x)) - x * z + log(1 + exp(-x))
+ * = - x * z + log(1 + exp(x))
+ *
+ * Hence, to ensure stability and avoid overflow, the implementation uses
+ * this equivalent formulation:
+ * max(x, 0) - x * z + log(1 + exp(-abs(x)))
+ */
+ var maxOutput = relu($logits);
+ var outputXTarget = mul($logits, $labels);
+ var sigmoidOutput = log1p(exp(neg(abs($logits))));
+ return add(sub(maxOutput, outputXTarget), sigmoidOutput);
+ }
+ /**
+ * Computes the sigmoid cross entropy loss between two tensors.
+ *
+ * If labelSmoothing is nonzero, smooth the labels towards 1/2:
+ *
+ * newMulticlassLabels = multiclassLabels * (1 - labelSmoothing)
+ * + 0.5 * labelSmoothing
+ *
+ * @param multiClassLabels The ground truth output tensor of shape
+ * [batch_size, num_classes], same dimensions as 'predictions'.
+ * @param logits The predicted outputs.
+ * @param weights Tensor whose rank is either 0, or the same rank as
+ * `labels`, and must be broadcastable to `labels` (i.e., all dimensions
+ * must be either `1`, or the same as the corresponding `losses`
+ * dimension).
+ * @param labelSmoothing If greater than 0, then smooth the labels.
+ * @param reduction Type of reduction to apply to loss. Should be of type
+ * `Reduction`
+ *
+ * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
+ */
+ function sigmoidCrossEntropy_(multiClassLabels, logits, weights, labelSmoothing, reduction) {
+ if (labelSmoothing === void 0) { labelSmoothing = 0; }
+ if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; }
+ var $multiClassLabels = convertToTensor(multiClassLabels, 'multiClassLabels', 'sigmoidCrossEntropy');
+ var $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropy');
+ var $weights = null;
+ if (weights != null) {
+ $weights = convertToTensor(weights, 'weights', 'sigmoidCrossEntropy');
+ }
+ assertShapesMatch($multiClassLabels.shape, $logits.shape, 'Error in sigmoidCrossEntropy: ');
+ if (labelSmoothing > 0) {
+ var labelSmoothingScalar = scalar(labelSmoothing);
+ var one = scalar(1);
+ var half = scalar(0.5);
+ $multiClassLabels =
+ add(mul($multiClassLabels, sub(one, labelSmoothingScalar)), mul(half, labelSmoothingScalar));
+ }
+ var losses = sigmoidCrossEntropyWithLogits_($multiClassLabels, $logits);
+ return computeWeightedLoss(losses, $weights, reduction);
+ }
+ var sigmoidCrossEntropy = op({ sigmoidCrossEntropy_: sigmoidCrossEntropy_ });
+
+ /**
+ * Computes softmax cross entropy between logits and labels.
+ *
+ * Measures the probability error in discrete classification tasks in which
+ * the classes are mutually exclusive (each entry is in exactly one class).
+ * For example, each CIFAR-10 image is labeled with one and only one label: an
+ * image can be a dog or a truck, but not both.
+ *
+ * `NOTE`: While the classes are mutually exclusive, their probabilities need
+ * not be. All that is required is that each row of labels is a valid
+ * probability distribution. If they are not, the computation of the gradient
+ * will be incorrect.
+ *
+ * `WARNING`: This op expects unscaled logits, since it performs a softmax on
+ * logits internally for efficiency. Do not call this op with the output of
+ * softmax, as it will produce incorrect results.
+ *
+ * logits and labels must have the same shape, e.g. [batch_size, num_classes]
+ * and the same dtype.
+ * @param labels The labels array.
+ * @param logits The logits array.
+ * @param dim The dimension softmax would be performed on. Defaults to `-1`
+ * which indicates the last dimension.
+ */
+ function softmaxCrossEntropyWithLogits_(labels, logits, dim) {
+ if (dim === void 0) { dim = -1; }
+ if (dim === -1) {
+ dim = logits.rank - 1;
+ }
+ if (dim !== logits.rank - 1) {
+ throw Error("Softmax cross entropy along a non-last dimension is not yet " +
+ ("supported. Labels / logits was rank " + logits.rank + " ") +
+ ("and dim was " + dim));
+ }
+ // Use a custom gradient for numerical stability.
+ var customOp = customGrad(function (labels, logits, save) {
+ // Reference:
+ // 1. http://cs231n.github.io/linear-classify/#softmax
+ // 2. https://blog.feedly.com/tricks-of-the-trade-logsumexp/
+ var keepDims = true;
+ var lse = logSumExp(logits, [dim], keepDims);
+ var logResult = sub(cast(logits, 'float32'), lse);
+ save([labels, logResult]);
+ var costVector = neg(mul(logResult, labels));
+ var value = sum(costVector, [dim]);
+ var gradFunc = function (dy, saved) {
+ var _a = __read(saved, 2), labels = _a[0], logResult = _a[1];
+ var dyShape = expandShapeToKeepDim(dy.shape, [dim]);
+ return [
+ mul(reshape(dy, dyShape), sub(cast(labels, 'float32'), exp(logResult))),
+ mul(reshape(dy, dyShape), sub(exp(logResult), cast(labels, 'float32'))),
+ ];
+ };
+ return { value: value, gradFunc: gradFunc };
+ });
+ return customOp(labels, logits);
+ }
+ /**
+ * Computes the softmax cross entropy loss between two tensors.
+ *
+ * If labelSmoothing is nonzero, smooth the labels towards 1/2:
+ *
+ * newOnehotLabels = onehotLabels * (1 - labelSmoothing)
+ * + labelSmoothing / numClasses
+ *
+ * @param onehotLabels One hot encoded labels
+ * [batch_size, num_classes], same dimensions as 'predictions'.
+ * @param logits The predicted outputs.
+ * @param weights Tensor whose rank is either 0, or 1, and must be
+ * broadcastable to `loss` of shape [batch_size]
+ * @param labelSmoothing If greater than 0, then smooth the labels.
+ * @param reduction Type of reduction to apply to loss. Should be of type
+ * `Reduction`
+ *
+ * @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
+ */
+ function softmaxCrossEntropy_(onehotLabels, logits, weights, labelSmoothing, reduction) {
+ if (labelSmoothing === void 0) { labelSmoothing = 0; }
+ if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; }
+ var $onehotLabels = convertToTensor(onehotLabels, 'onehotLabels', 'softmaxCrossEntropy');
+ var $logits = convertToTensor(logits, 'logits', 'softmaxCrossEntropy');
+ var $weights = null;
+ if (weights != null) {
+ $weights = convertToTensor(weights, 'weights', 'softmaxCrossEntropy');
+ }
+ assertShapesMatch($onehotLabels.shape, $logits.shape, 'Error in softmaxCrossEntropy: ');
+ if (labelSmoothing > 0) {
+ var labelSmoothingScalar = scalar(labelSmoothing);
+ var one = scalar(1);
+ var numClasses = scalar($onehotLabels.shape[1]);
+ $onehotLabels =
+ add(mul($onehotLabels, sub(one, labelSmoothingScalar)), div(labelSmoothingScalar, numClasses));
+ }
+ var losses = softmaxCrossEntropyWithLogits_($onehotLabels, $logits);
+ return computeWeightedLoss(losses, $weights, reduction);
+ }
+ var softmaxCrossEntropy = op({ softmaxCrossEntropy_: softmaxCrossEntropy_ });
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * The input SparseTensor is represented via the map of inputs {`indices`,
+ * `values`, `denseShape`}. The output SparseTensor has the same `denseShape`
+ * but with indices `outputIndices` and values `outputValues`. This op inserts a
+ * single entry for every row that doesn't have any values. The index is created
+ * as `[row, 0, ..., 0]` and the inserted value is `defaultValue`.
+ *
+ * For example, suppose `spInput` has shape [5, 6] and non-empty values:
+ * [0, 1]: a
+ * [0, 3]: b
+ * [2, 0]: c
+ * [3, 1]: d
+ *
+ * Rows 1 and 4 are empty, so the output will be of shape [5, 6] with values:
+ * [0, 1]: a
+ * [0, 3]: b
+ * [1, 0]: `defaultValue`
+ * [2, 0]: c
+ * [3, 1]: d
+ * [4, 0]: `defaultValue`
+ *
+ * The output SparseTensor will be in row-major order and will have the same
+ * shape as the input.
+ *
+ * This op also returns an indicator vector shaped [dense_shape[0]] such that
+ * emptyRowIndicator[i] = True iff row i was an empty row.
+ *
+ * And a reverse index map vector shaped [indices.shape[0]] that is used during
+ * backpropagation, reverseIndexMap[i] = outi s.t. indices[i, j] ==
+ * outputIndices[outi, j] for all j
+ *
+ * ```js
+ * const result = tf.sparse.sparseFillEmptyRows(
+ * [[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]],
+ * [0, 10, 13, 14, 32, 33], [5, 6], -1);
+ * console.log(result);
+ * result['outputIndices'].print(); // [[0, 0], [1, 0], [1, 3], [1, 4],
+ * // [2, 0], [3, 2], [3, 3], [4, 0]]
+ * result['outputValues'].print(); // [0, 10, 13, 14,-1, 32, 33, -1]
+ * result['emptyRowIndicator'].print(); // [false, false, true, false, true]
+ * result['reverseIndexMap'].print(); // [0, 1, 2, 3, 5, 6]
+ * ```
+ * @param indices: 2-D. the indices of the sparse tensor.
+ * @param values: 1-D. the values of the sparse tensor.
+ * @param denseShape: 1-D. the shape of the sparse tensor.
+ * @param defaultValue: 0-D. default value to insert into location [row, 0, ...,
+ * 0] for rows missing from the input sparse tensor.
+ * @return A map with the following properties:
+ * - outputIndices
+ * - outputValues: 1-D. the values of the filled sparse tensor.
+ * - emptyRowIndicator: 1-D. whether the dense row was missing in the input
+ * sparse tensor.
+ * - reverseIndexMap: 1-D. a map from the input indices to the output
+ * indices.
+ * @doc {heading: 'Operations', subheading: 'Sparse'}
+ */
+ function sparseFillEmptyRows_(indices, values, denseShape, defaultValue) {
+ var $indices = convertToTensor(indices, 'indices', 'sparseFillEmptyRows', 'int32');
+ var $values = convertToTensor(values, 'values', 'sparseFillEmptyRows');
+ var $denseShape = convertToTensor(denseShape, 'denseShape', 'sparseFillEmptyRows', 'int32');
+ var $defaultValue = convertToTensor(defaultValue, 'defaultValue', 'sparseFillEmptyRows', $values.dtype);
+ if ($indices.rank !== 2) {
+ throw new Error("Indices should be Tensor2D but received shape\n " + $indices.shape);
+ }
+ if ($values.rank !== 1) {
+ throw new Error("Values should be Tensor1D but received shape " + $values.shape);
+ }
+ if ($denseShape.rank !== 1) {
+ throw new Error("Dense shape should be Tensor1D but received shape " + $denseShape.shape);
+ }
+ if ($defaultValue.rank !== 0) {
+ throw new Error("Default value should be a scalar but received shape " + $defaultValue.shape);
+ }
+ var inputs = {
+ indices: $indices,
+ values: $values,
+ denseShape: $denseShape,
+ defaultValue: $defaultValue
+ };
+ var result = ENGINE.runKernel(SparseFillEmptyRows, inputs);
+ return {
+ outputIndices: result[0],
+ outputValues: result[1],
+ emptyRowIndicator: result[2],
+ reverseIndexMap: result[3]
+ };
+ }
+ var sparseFillEmptyRows = op({ sparseFillEmptyRows_: sparseFillEmptyRows_ });
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * This operation has the same semantics as reshape on the represented dense
+ * tensor. The `inputIndices` are recomputed based on the requested `newShape`.
+ * If one component of `newShape` is the special value -1, the size of that
+ * dimension is computed so that the total dense size remains constant. At most
+ * one component of `newShape` can be -1. The number of dense elements implied
+ * by `newShape` must be the same as the number of dense elements originally
+ * implied by `inputShape`. Reshaping does not affect the order of values in the
+ * SparseTensor. If the input tensor has rank R_in and N non-empty values, and
+ * `newShape` has length R_out, then `inputIndices` has shape [N, R_in],
+ * `inputShape` has length R_in, `outputIndices` has shape [N, R_out], and
+ * `outputShape` has length R_out.
+ *
+ * ```js
+ * const result = tf.sparse.sparseReshape(
+ * [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]],
+ * [2, 3, 6], [9, -1]);
+ * console.log(result);
+ * result['outputIndices'].print(); //[[0, 0], [0, 1], [1, 2], [4, 2], [8, 1]]
+ * result['outputShape'].print(); // [9, 4]
+ * ```
+ * @param inputIndices: 2-D. N x R_in matrix with the indices of non-empty
+ * values in a SparseTensor.
+ * @param inputShape: 1-D. R_in Tensor1D with the input SparseTensor's dense
+ * shape.
+ * @param newShape: 1-D. R_out Tensor1D with the requested new dense shape.
+ * @return A map with the following properties:
+ * - outputIndices: 2-D. N x R_out matrix with the updated indices of
+ * non-empty values in the output SparseTensor.
+ * - outputShape: 1-D. R_out vector with the full dense shape of the output
+ * SparseTensor. This is the same as newShape but with any -1 dimensions
+ * filled in.
+ * @doc {heading: 'Operations', subheading: 'Sparse'}
+ */
+ function sparseReshape_(inputIndices, inputShape, newShape) {
+ var $inputIndices = convertToTensor(inputIndices, 'inputIndices', 'sparseReshape', 'int32');
+ var $inputShape = convertToTensor(inputShape, 'inputShape', 'sparseReshape', 'int32');
+ var $newShape = convertToTensor(newShape, 'newShape', 'sparseReshape', 'int32');
+ if ($inputIndices.rank !== 2) {
+ throw new Error("Input indices should be Tensor2D but received shape\n " + $inputIndices.shape);
+ }
+ if ($inputShape.rank !== 1) {
+ throw new Error("Input shape should be Tensor1D but received shape " + $inputShape.shape);
+ }
+ if ($newShape.rank !== 1) {
+ throw new Error("New shape should be Tensor1D but received shape " + $newShape.shape);
+ }
+ var inputs = {
+ inputIndices: $inputIndices,
+ inputShape: $inputShape,
+ newShape: $newShape
+ };
+ var result = ENGINE.runKernel(SparseReshape, inputs);
+ return { outputIndices: result[0], outputShape: result[1] };
+ }
+ var sparseReshape = op({ sparseReshape_: sparseReshape_ });
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the mean along sparse segments of a tensor.
+ *
+ * ```js
+ * const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [6,7,8,9]]);
+ * // Select two rows, one segment.
+ * const result1 = tf.sparse.sparseSegmentMean(c,
+ * tf.tensor1d([0, 1], 'int32'),
+ * tf.tensor1d([0, 0], 'int32'));
+ * result1.print(); // [[0, 0, 0, 0]]
+ *
+ * // Select two rows, two segments.
+ * const result2 = tf.sparse.sparseSegmentMean(c,
+ * tf.tensor1d([0, 1], 'int32'),
+ * tf.tensor1d([0, 1], 'int32'));
+ * result2.print(); // [[1, 2, 3, 4], [-1, -2, -3, -4]]
+ *
+ * // Select all rows, two segments.
+ * const result3 = tf.sparse.sparseSegmentMean(c,
+ * tf.tensor1d([0, 1, 2], 'int32'),
+ * tf.tensor1d([0, 1, 1], 'int32'));
+ * result3.print(); // [[1.0, 2.0, 3.0, 4.0], [2.5, 2.5, 2.5, 2.5]]
+ * ```
+ * @param data: A Tensor of at least one dimension with data that will be
+ * assembled in the output.
+ * @param indices: A 1-D Tensor with indices into data. Has same rank as
+ * segmentIds.
+ * @param segmentIds: A 1-D Tensor with indices into the output Tensor. Values
+ * should be sorted and can be repeated.
+ * @return Has same shape as data, except for dimension 0 which has equal to
+ * the number of segments.
+ *
+ * @doc {heading: 'Operations', subheading: 'Sparse'}
+ */
+ function sparseSegmentMean_(data, indices, segmentIds) {
+ var $data = convertToTensor(data, 'data', 'sparseSegmentMean');
+ var $indices = convertToTensor(indices, 'indices', 'sparseSegmentMean', 'int32');
+ var $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'sparseSegmentMean', 'int32');
+ if ($data.rank < 1) {
+ throw new Error("Data should be at least 1 dimensional but received scalar");
+ }
+ if ($indices.rank !== 1) {
+ throw new Error("Indices should be Tensor1D but received shape\n " + $indices.shape);
+ }
+ if ($segmentIds.rank !== 1) {
+ throw new Error("Segment ids should be Tensor1D but received shape\n " + $segmentIds.shape);
+ }
+ var inputs = {
+ data: $data,
+ indices: $indices,
+ segmentIds: $segmentIds
+ };
+ return ENGINE.runKernel(SparseSegmentMean, inputs);
+ }
+ var sparseSegmentMean = op({ sparseSegmentMean_: sparseSegmentMean_ });
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Computes the sum along sparse segments of a tensor.
+ *
+ * ```js
+ * const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]);
+ * // Select two rows, one segment.
+ * const result1 = tf.sparse.sparseSegmentSum(c,
+ * tf.tensor1d([0, 1], 'int32'),
+ * tf.tensor1d([0, 0], 'int32'));
+ * result1.print(); // [[0, 0, 0, 0]]
+ *
+ * // Select two rows, two segment.
+ * const result2 = tf.sparse.sparseSegmentSum(c,
+ * tf.tensor1d([0, 1], 'int32'),
+ * tf.tensor1d([0, 1], 'int32'));
+ * result2.print(); // [[1, 2, 3, 4], [-1, -2, -3, -4]]
+ *
+ * // Select all rows, two segments.
+ * const result3 = tf.sparse.sparseSegmentSum(c,
+ * tf.tensor1d([0, 1, 2], 'int32'),
+ * tf.tensor1d([0, 0, 1], 'int32'));
+ * result3.print(); // [[0, 0, 0, 0], [5, 6, 7, 8]]
+ * ```
+ * @param data: A Tensor of at least one dimension with data that will be
+ * assembled in the output.
+ * @param indices: A 1-D Tensor with indices into data. Has same rank as
+ * segmentIds.
+ * @param segmentIds: A 1-D Tensor with indices into the output Tensor. Values
+ * should be sorted and can be repeated.
+ * @return Has same shape as data, except for dimension 0 which has equal to
+ * the number of segments.
+ *
+ * @doc {heading: 'Operations', subheading: 'Sparse'}
+ */
+ function sparseSegmentSum_(data, indices, segmentIds) {
+ var $data = convertToTensor(data, 'data', 'sparseSegmentSum');
+ var $indices = convertToTensor(indices, 'indices', 'sparseSegmentSum', 'int32');
+ var $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'sparseSegmentSum', 'int32');
+ if ($data.rank < 1) {
+ throw new Error("Data should be at least 1 dimensional but received scalar");
+ }
+ if ($indices.rank !== 1) {
+ throw new Error("Indices should be Tensor1D but received shape\n " + $indices.shape);
+ }
+ if ($segmentIds.rank !== 1) {
+ throw new Error("Segment ids should be Tensor1D but received shape\n " + $segmentIds.shape);
+ }
+ var inputs = {
+ data: $data,
+ indices: $indices,
+ segmentIds: $segmentIds
+ };
+ return ENGINE.runKernel(SparseSegmentSum, inputs);
+ }
+ var sparseSegmentSum = op({ sparseSegmentSum_: sparseSegmentSum_ });
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Creates ngrams from ragged string data.
+ *
+ * This op accepts a ragged tensor with 1 ragged dimension containing only
+ * strings and outputs a ragged tensor with 1 ragged dimension containing ngrams
+ * of that string, joined along the innermost axis.
+ *
+ * ```js
+ * const result = tf.string.stringNGrams(
+ * ['a', 'b', 'c', 'd'], tf.tensor1d([0, 2, 4], 'int32'),
+ * '|', [1, 2], 'LP', 'RP', -1, false);
+ * result['nGrams'].print(); // ['a', 'b', 'LP|a', 'a|b', 'b|RP',
+ * // 'c', 'd', 'LP|c', 'c|d', 'd|RP']
+ * result['nGramsSplits'].print(); // [0, 5, 10]
+ * ```
+ * @param data: The values tensor of the ragged string tensor to make ngrams out
+ * of. Must be a 1D string tensor.
+ * @param dataSplits: The splits tensor of the ragged string tensor to make
+ * ngrams out of.
+ * @param separator: The string to append between elements of the token. Use ""
+ * for no separator.
+ * @param nGramWidths: The sizes of the ngrams to create.
+ * @param leftPad: The string to use to pad the left side of the ngram sequence.
+ * Only used if pad_width !== 0.
+ * @param rightPad: The string to use to pad the right side of the ngram
+ * sequence. Only used if pad_width !== 0.
+ * @param padWidth: The number of padding elements to add to each side of each
+ * sequence. Note that padding will never be greater than `nGramWidths`-1
+ * regardless of this value. If `padWidth`=-1 , then add max(`nGramWidths)-1
+ * elements.
+ * @param preserveShortSequences: If true, then ensure that at least one ngram
+ * is generated for each input sequence. In particular, if an input sequence
+ * is shorter than min(ngramWidth) + 2*padWidth, then generate a single
+ * ngram containing the entire sequence. If false, then no ngrams are
+ * generated for these short input sequences.
+ * @return A map with the following properties:
+ * - nGrams: The values tensor of the output ngrams ragged tensor.
+ * - nGramsSplits: The splits tensor of the output ngrams ragged tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'String'}
+ */
+ function stringNGrams_(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
+ var $data = convertToTensor(data, 'data', 'stringNGrams', 'string');
+ if ($data.dtype !== 'string') {
+ throw new Error('Data must be of datatype string');
+ }
+ if ($data.shape.length !== 1) {
+ throw new Error("Data must be a vector, saw: " + $data.shape);
+ }
+ var $dataSplits = convertToTensor(dataSplits, 'dataSplits', 'stringNGrams');
+ if ($dataSplits.dtype !== 'int32') {
+ throw new Error('Data splits must be of datatype int32');
+ }
+ var attrs = {
+ separator: separator,
+ nGramWidths: nGramWidths,
+ leftPad: leftPad,
+ rightPad: rightPad,
+ padWidth: padWidth,
+ preserveShortSequences: preserveShortSequences
+ };
+ var inputs = { data: $data, dataSplits: $dataSplits };
+ var result = ENGINE.runKernel(StringNGrams, inputs, attrs);
+ return { nGrams: result[0], nGramsSplits: result[1] };
+ }
+ var stringNGrams = op({ stringNGrams_: stringNGrams_ });
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Split elements of `input` based on `delimiter` into a SparseTensor .
+ *
+ * Let N be the size of source (typically N will be the batch size). Split each
+ * element of `input` based on `delimiter` and return a SparseTensor containing
+ * the splitted tokens. Empty tokens are ignored if `skipEmpty` is set to True.
+ *
+ * `delimiter` can be empty, or a string of split characters. If `delimiter` is
+ * an empty string, each element of `input` is split into individual
+ * character strings. Otherwise every character of `delimiter` is a potential
+ * split point.
+ *
+ * ```js
+ * const result = tf.string.stringSplit(['hello world', 'a b c'], ' ');
+ * result['indices'].print(); // [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]
+ * result['values'].print(); // ['hello', 'world', 'a', 'b', 'c']
+ * result['shape'].print(); // [2, 3]
+ * ```
+ * @param input: 1-D. Strings to split.
+ * @param delimiter: 0-D. Delimiter characters, or empty string.
+ * @param skipEmpty: Optional. If true, skip the empty strings from the result.
+ * Defaults to true.
+ * @return A map with the following properties:
+ * - indices: A dense matrix of int32 representing the indices of the sparse
+ * tensor.
+ * - values: A vector of strings corresponding to the splited values.
+ * - shape: a length-2 vector of int32 representing the shape of the sparse
+ * tensor, where the first value is N and the second value is the maximum number
+ * of tokens in a single input entry.
+ *
+ * @doc {heading: 'Operations', subheading: 'String'}
+ */
+ function stringSplit_(input, delimiter, skipEmpty) {
+ if (skipEmpty === void 0) { skipEmpty = true; }
+ var $input = convertToTensor(input, 'input', 'stringSplit', 'string');
+ var $delimiter = convertToTensor(delimiter, 'delimiter', 'stringSplit', 'string');
+ if ($input.rank !== 1) {
+ throw new Error("Input should be Tensor1D but received shape " + $input.shape);
+ }
+ if ($delimiter.rank !== 0) {
+ throw new Error("Delimiter should be a scalar but received shape " + $delimiter.shape);
+ }
+ var attrs = { skipEmpty: skipEmpty };
+ var inputs = { input: $input, delimiter: $delimiter };
+ var result = ENGINE.runKernel(StringSplit, inputs, attrs);
+ return { indices: result[0], values: result[1], shape: result[2] };
+ }
+ var stringSplit = op({ stringSplit_: stringSplit_ });
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Converts each string in the input Tensor to its hash mod by a number of
+ * buckets.
+ *
+ * The hash function is deterministic on the content of the string within the
+ * process and will never change. However, it is not suitable for cryptography.
+ * This function may be used when CPU time is scarce and inputs are trusted or
+ * unimportant. There is a risk of adversaries constructing inputs that all hash
+ * to the same bucket.
+ *
+ * ```js
+ * const result = tf.string.stringToHashBucketFast(
+ * ['Hello', 'TensorFlow', '2.x'], 3);
+ * result.print(); // [0, 2, 2]
+ * ```
+ * @param input: The strings to assign a hash bucket.
+ * @param numBuckets: The number of buckets.
+ * @return A Tensor of the same shape as the input tensor.
+ *
+ * @doc {heading: 'Operations', subheading: 'String'}
+ */
+ function stringToHashBucketFast_(input, numBuckets) {
+ var $input = convertToTensor(input, 'input', 'stringToHashBucketFast', 'string');
+ var attrs = { numBuckets: numBuckets };
+ if (numBuckets <= 0) {
+ throw new Error("Number of buckets must be at least 1");
+ }
+ var inputs = { input: $input };
+ return ENGINE.runKernel(StringToHashBucketFast, inputs, attrs);
+ }
+ var stringToHashBucketFast = op({ stringToHashBucketFast_: stringToHashBucketFast_ });
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var spectral = {
+ fft: fft,
+ ifft: ifft,
+ rfft: rfft,
+ irfft: irfft
+ };
+ var signal = {
+ hammingWindow: hammingWindow,
+ hannWindow: hannWindow,
+ frame: frame,
+ stft: stft,
+ };
+ var image = {
+ flipLeftRight: flipLeftRight,
+ grayscaleToRGB: grayscaleToRGB,
+ resizeNearestNeighbor: resizeNearestNeighbor,
+ resizeBilinear: resizeBilinear,
+ rotateWithOffset: rotateWithOffset,
+ cropAndResize: cropAndResize,
+ nonMaxSuppression: nonMaxSuppression,
+ nonMaxSuppressionAsync: nonMaxSuppressionAsync,
+ nonMaxSuppressionWithScore: nonMaxSuppressionWithScore,
+ nonMaxSuppressionWithScoreAsync: nonMaxSuppressionWithScoreAsync,
+ nonMaxSuppressionPadded: nonMaxSuppressionPadded,
+ nonMaxSuppressionPaddedAsync: nonMaxSuppressionPaddedAsync,
+ threshold: threshold,
+ transform: transform
+ };
+ var linalg = {
+ bandPart: bandPart,
+ gramSchmidt: gramSchmidt,
+ qr: qr
+ };
+ var losses = {
+ absoluteDifference: absoluteDifference,
+ computeWeightedLoss: computeWeightedLoss,
+ cosineDistance: cosineDistance,
+ hingeLoss: hingeLoss,
+ huberLoss: huberLoss,
+ logLoss: logLoss,
+ meanSquaredError: meanSquaredError,
+ sigmoidCrossEntropy: sigmoidCrossEntropy,
+ softmaxCrossEntropy: softmaxCrossEntropy
+ };
+ var sparse = {
+ sparseFillEmptyRows: sparseFillEmptyRows,
+ sparseReshape: sparseReshape,
+ sparseSegmentMean: sparseSegmentMean,
+ sparseSegmentSum: sparseSegmentSum
+ };
+ // tslint:disable-next-line:variable-name
+ var string = {
+ stringNGrams: stringNGrams,
+ stringSplit: stringSplit,
+ stringToHashBucketFast: stringToHashBucketFast
+ };
+
+ /** @doc {heading: 'Training', subheading: 'Classes', namespace: 'train'} */
+ var Optimizer = /** @class */ (function (_super) {
+ __extends(Optimizer, _super);
+ function Optimizer() {
+ return _super !== null && _super.apply(this, arguments) || this;
+ }
+ /**
+ * Executes `f()` and minimizes the scalar output of `f()` by computing
+ * gradients of y with respect to the list of trainable variables provided by
+ * `varList`. If no list is provided, it defaults to all trainable variables.
+ *
+ * @param f The function to execute and whose output to minimize.
+ * @param returnCost Whether to return the scalar cost value produced by
+ * executing `f()`.
+ * @param varList An optional list of variables to update. If specified, only
+ * the trainable variables in varList will be updated by minimize. Defaults to
+ * all trainable variables.
+ *
+ * @doc {heading: 'Training', subheading: 'Optimizers'}
+ */
+ Optimizer.prototype.minimize = function (f, returnCost, varList) {
+ if (returnCost === void 0) { returnCost = false; }
+ var _a = this.computeGradients(f, varList), value = _a.value, grads = _a.grads;
+ if (varList != null) {
+ var gradArray = varList.map(function (v) { return ({ name: v.name, tensor: grads[v.name] }); });
+ this.applyGradients(gradArray);
+ }
+ else {
+ this.applyGradients(grads);
+ }
+ // Dispose gradients.
+ dispose(grads);
+ if (returnCost) {
+ return value;
+ }
+ else {
+ value.dispose();
+ return null;
+ }
+ };
+ Object.defineProperty(Optimizer.prototype, "iterations", {
+ /**
+ * The number of iterations that this optimizer instance has been invoked for.
+ */
+ get: function () {
+ if (this.iterations_ == null) {
+ this.iterations_ = 0;
+ }
+ return this.iterations_;
+ },
+ enumerable: true,
+ configurable: true
+ });
+ Optimizer.prototype.incrementIterations = function () {
+ this.iterations_ = this.iterations + 1;
+ };
+ /**
+ * Executes f() and computes the gradient of the scalar output of f() with
+ * respect to the list of trainable variables provided by `varList`. If no
+ * list is provided, it defaults to all trainable variables.
+ *
+ * @param f The function to execute and whose output to use for computing
+ * gradients with respect to variables.
+ * @param varList An optional list of variables to compute gradients with
+ * respect to. If specified, only the trainable variables in varList will have
+ * gradients computed with respect to. Defaults to all trainable variables.
+ *
+ * @doc {heading: 'Training', subheading: 'Optimizers'}
+ */
+ Optimizer.prototype.computeGradients = function (f, varList) {
+ return variableGrads(f, varList);
+ };
+ /**
+ * Dispose the variables (if any) owned by this optimizer instance.
+ */
+ Optimizer.prototype.dispose = function () {
+ if (this.iterations_ != null) {
+ dispose(this.iterations_);
+ }
+ };
+ Optimizer.prototype.saveIterations = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ return __generator(this, function (_a) {
+ if (this.iterations_ == null) {
+ this.iterations_ = 0;
+ }
+ return [2 /*return*/, {
+ name: 'iter',
+ // TODO(cais): Use 'int64' type when available.
+ tensor: scalar(this.iterations_, 'int32')
+ }];
+ });
+ });
+ };
+ Optimizer.prototype.getWeights = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ return __generator(this, function (_a) {
+ throw new Error('getWeights() is not implemented for this optimizer yet.');
+ });
+ });
+ };
+ Optimizer.prototype.setWeights = function (weightValues) {
+ return __awaiter(this, void 0, void 0, function () {
+ return __generator(this, function (_a) {
+ throw new Error("setWeights() is not implemented for this optimizer class " +
+ ("" + this.getClassName()));
+ });
+ });
+ };
+ /**
+ * Extract the first element of the weight values and set it
+ * as the iterations counter variable of this instance of optimizer.
+ *
+ * @param weightValues
+ * @returns Weight values with the first element consumed and excluded.
+ */
+ Optimizer.prototype.extractIterations = function (weightValues) {
+ return __awaiter(this, void 0, void 0, function () {
+ var _a;
+ return __generator(this, function (_b) {
+ switch (_b.label) {
+ case 0:
+ _a = this;
+ return [4 /*yield*/, weightValues[0].tensor.data()];
+ case 1:
+ _a.iterations_ = (_b.sent())[0];
+ return [2 /*return*/, weightValues.slice(1)];
+ }
+ });
+ });
+ };
+ return Optimizer;
+ }(Serializable));
+ Object.defineProperty(Optimizer, Symbol.hasInstance, {
+ value: function (instance) {
+ return instance.minimize != null && instance.computeGradients != null &&
+ instance.applyGradients != null;
+ }
+ });
+
+ /** @doclink Optimizer */
+ var AdadeltaOptimizer = /** @class */ (function (_super) {
+ __extends(AdadeltaOptimizer, _super);
+ function AdadeltaOptimizer(learningRate, rho, epsilon) {
+ if (epsilon === void 0) { epsilon = null; }
+ var _this = _super.call(this) || this;
+ _this.learningRate = learningRate;
+ _this.rho = rho;
+ _this.epsilon = epsilon;
+ _this.accumulatedGrads = [];
+ _this.accumulatedUpdates = [];
+ if (epsilon == null) {
+ _this.epsilon = ENGINE.backend.epsilon();
+ }
+ return _this;
+ }
+ AdadeltaOptimizer.prototype.applyGradients = function (variableGradients) {
+ var _this = this;
+ var variableNames = Array.isArray(variableGradients) ?
+ variableGradients.map(function (item) { return item.name; }) :
+ Object.keys(variableGradients);
+ variableNames.forEach(function (name, i) {
+ var value = ENGINE.registeredVariables[name];
+ var trainable = false;
+ if (_this.accumulatedGrads[i] == null) {
+ _this.accumulatedGrads[i] = {
+ originalName: name + "/accum_grad",
+ variable: tidy(function () { return zerosLike(value).variable(trainable); })
+ };
+ }
+ if (_this.accumulatedUpdates[i] == null) {
+ _this.accumulatedUpdates[i] = {
+ originalName: name + "/accum_var",
+ variable: tidy(function () { return zerosLike(value).variable(trainable); })
+ };
+ }
+ var gradient = Array.isArray(variableGradients) ?
+ variableGradients[i].tensor :
+ variableGradients[name];
+ if (gradient == null) {
+ return;
+ }
+ var accumulatedGrad = _this.accumulatedGrads[i].variable;
+ var accumulatedUpdate = _this.accumulatedUpdates[i].variable;
+ tidy(function () {
+ var newAccumulatedGrad = add(mul(accumulatedGrad, _this.rho), mul(square(gradient), 1 - _this.rho));
+ var updates = mul(div(sqrt(add(accumulatedUpdate, _this.epsilon)), sqrt(add(accumulatedGrad, _this.epsilon))), gradient);
+ var newAccumulatedUpdate = add(mul(accumulatedUpdate, _this.rho), mul(square(updates), 1 - _this.rho));
+ accumulatedGrad.assign(newAccumulatedGrad);
+ accumulatedUpdate.assign(newAccumulatedUpdate);
+ var newValue = add(mul(updates, -_this.learningRate), value);
+ value.assign(newValue);
+ });
+ });
+ this.incrementIterations();
+ };
+ AdadeltaOptimizer.prototype.dispose = function () {
+ if (this.accumulatedUpdates != null) {
+ dispose(this.accumulatedGrads.map(function (v) { return v.variable; }));
+ dispose(this.accumulatedUpdates.map(function (v) { return v.variable; }));
+ }
+ };
+ AdadeltaOptimizer.prototype.getWeights = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ var variables;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ variables = __spread(this.accumulatedGrads, this.accumulatedUpdates);
+ return [4 /*yield*/, this.saveIterations()];
+ case 1: return [2 /*return*/, [_a.sent()].concat(variables.map(function (v) { return ({ name: v.originalName, tensor: v.variable }); }))];
+ }
+ });
+ });
+ };
+ AdadeltaOptimizer.prototype.setWeights = function (weightValues) {
+ return __awaiter(this, void 0, void 0, function () {
+ var variableCount, trainable;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0: return [4 /*yield*/, this.extractIterations(weightValues)];
+ case 1:
+ weightValues = _a.sent();
+ variableCount = weightValues.length / 2;
+ trainable = false;
+ this.accumulatedGrads =
+ weightValues.slice(0, variableCount).map(function (v) { return ({
+ originalName: v.name,
+ variable: v.tensor.variable(trainable)
+ }); });
+ this.accumulatedUpdates =
+ weightValues.slice(variableCount, variableCount * 2)
+ .map(function (v) { return ({
+ originalName: v.name,
+ variable: v.tensor.variable(trainable)
+ }); });
+ return [2 /*return*/];
+ }
+ });
+ });
+ };
+ AdadeltaOptimizer.prototype.getConfig = function () {
+ return {
+ 'learningRate': this.learningRate,
+ 'rho': this.rho,
+ 'epsilon': this.epsilon
+ };
+ };
+ /** @nocollapse */
+ AdadeltaOptimizer.fromConfig = function (cls, config) {
+ return new cls(config['learningRate'], config['rho'], config['epsilon']);
+ };
+ return AdadeltaOptimizer;
+ }(Optimizer));
+ /** @nocollapse */
+ AdadeltaOptimizer.className = 'Adadelta'; // Name matters for Python compatibility.
+ registerClass(AdadeltaOptimizer);
+
+ /** @doclink Optimizer */
+ var AdagradOptimizer = /** @class */ (function (_super) {
+ __extends(AdagradOptimizer, _super);
+ function AdagradOptimizer(learningRate, initialAccumulatorValue) {
+ if (initialAccumulatorValue === void 0) { initialAccumulatorValue = 0.1; }
+ var _this = _super.call(this) || this;
+ _this.learningRate = learningRate;
+ _this.initialAccumulatorValue = initialAccumulatorValue;
+ _this.accumulatedGrads = [];
+ return _this;
+ }
+ AdagradOptimizer.prototype.applyGradients = function (variableGradients) {
+ var _this = this;
+ var variableNames = Array.isArray(variableGradients) ?
+ variableGradients.map(function (item) { return item.name; }) :
+ Object.keys(variableGradients);
+ variableNames.forEach(function (name, i) {
+ var value = ENGINE.registeredVariables[name];
+ if (_this.accumulatedGrads[i] == null) {
+ var trainable_1 = false;
+ _this.accumulatedGrads[i] = {
+ originalName: name + "/accumulator",
+ variable: tidy(function () { return fill(value.shape, _this.initialAccumulatorValue)
+ .variable(trainable_1); })
+ };
+ }
+ var gradient = Array.isArray(variableGradients) ?
+ variableGradients[i].tensor :
+ variableGradients[name];
+ if (gradient == null) {
+ return;
+ }
+ var accumulatedGrad = _this.accumulatedGrads[i].variable;
+ tidy(function () {
+ var newAccumulatedGrad = add(accumulatedGrad, square(gradient));
+ accumulatedGrad.assign(newAccumulatedGrad);
+ var newValue = add(mul(div(gradient, sqrt(add(newAccumulatedGrad, ENGINE.backend.epsilon()))), -_this.learningRate), value);
+ value.assign(newValue);
+ });
+ });
+ this.incrementIterations();
+ };
+ AdagradOptimizer.prototype.dispose = function () {
+ if (this.accumulatedGrads != null) {
+ dispose(this.accumulatedGrads.map(function (v) { return v.variable; }));
+ }
+ };
+ AdagradOptimizer.prototype.getWeights = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0: return [4 /*yield*/, this.saveIterations()];
+ case 1:
+ // Order matters for Python compatibility.
+ return [2 /*return*/, [_a.sent()].concat(this.accumulatedGrads.map(function (v) { return ({ name: v.originalName, tensor: v.variable }); }))];
+ }
+ });
+ });
+ };
+ AdagradOptimizer.prototype.setWeights = function (weightValues) {
+ return __awaiter(this, void 0, void 0, function () {
+ var trainable;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0: return [4 /*yield*/, this.extractIterations(weightValues)];
+ case 1:
+ weightValues = _a.sent();
+ trainable = false;
+ this.accumulatedGrads = weightValues.map(function (v) { return ({ originalName: v.name, variable: v.tensor.variable(trainable) }); });
+ return [2 /*return*/];
+ }
+ });
+ });
+ };
+ AdagradOptimizer.prototype.getConfig = function () {
+ return {
+ 'learningRate': this.learningRate,
+ 'initialAccumulatorValue': this.initialAccumulatorValue,
+ };
+ };
+ /** @nocollapse */
+ AdagradOptimizer.fromConfig = function (cls, config) {
+ return new cls(config['learningRate'], config['initialAccumulatorValue']);
+ };
+ return AdagradOptimizer;
+ }(Optimizer));
+ /** @nocollapse */
+ AdagradOptimizer.className = 'Adagrad'; // Note: Name matters for Python compatibility.
+ registerClass(AdagradOptimizer);
+
+ var AdamOptimizer = /** @class */ (function (_super) {
+ __extends(AdamOptimizer, _super);
+ function AdamOptimizer(learningRate, beta1, beta2, epsilon) {
+ if (epsilon === void 0) { epsilon = null; }
+ var _this = _super.call(this) || this;
+ _this.learningRate = learningRate;
+ _this.beta1 = beta1;
+ _this.beta2 = beta2;
+ _this.epsilon = epsilon;
+ _this.accumulatedFirstMoment = [];
+ _this.accumulatedSecondMoment = [];
+ tidy(function () {
+ // accB* will be updated by batch.
+ _this.accBeta1 = scalar(beta1).variable();
+ _this.accBeta2 = scalar(beta2).variable();
+ });
+ if (epsilon == null) {
+ _this.epsilon = ENGINE.backend.epsilon();
+ }
+ return _this;
+ }
+ AdamOptimizer.prototype.applyGradients = function (variableGradients) {
+ var _this = this;
+ var varNames = Array.isArray(variableGradients) ?
+ variableGradients.map(function (v) { return v.name; }) :
+ Object.keys(variableGradients);
+ tidy(function () {
+ var oneMinusAccBeta1 = sub(1, _this.accBeta1);
+ var oneMinusAccBeta2 = sub(1, _this.accBeta2);
+ varNames.forEach(function (name, i) {
+ var value = ENGINE.registeredVariables[name];
+ var trainable = false;
+ if (_this.accumulatedFirstMoment[i] == null) {
+ _this.accumulatedFirstMoment[i] = {
+ originalName: name + "/m",
+ variable: tidy(function () { return zerosLike(value).variable(trainable); })
+ };
+ }
+ if (_this.accumulatedSecondMoment[i] == null) {
+ _this.accumulatedSecondMoment[i] = {
+ originalName: name + "/v",
+ variable: tidy(function () { return zerosLike(value).variable(trainable); })
+ };
+ }
+ var gradient = Array.isArray(variableGradients) ?
+ variableGradients[i].tensor :
+ variableGradients[name];
+ if (gradient == null) {
+ return;
+ }
+ var firstMoment = _this.accumulatedFirstMoment[i].variable;
+ var secondMoment = _this.accumulatedSecondMoment[i].variable;
+ var newFirstMoment = add(mul(firstMoment, _this.beta1), mul(gradient, 1 - _this.beta1));
+ var newSecondMoment = add(mul(secondMoment, _this.beta2), mul(square(gradient), 1 - _this.beta2));
+ var biasCorrectedFirstMoment = div(newFirstMoment, oneMinusAccBeta1);
+ var biasCorrectedSecondMoment = div(newSecondMoment, oneMinusAccBeta2);
+ firstMoment.assign(newFirstMoment);
+ secondMoment.assign(newSecondMoment);
+ var newValue = add(mul(div(biasCorrectedFirstMoment, add(sqrt(biasCorrectedSecondMoment), _this.epsilon)), -_this.learningRate), value);
+ value.assign(newValue);
+ });
+ _this.accBeta1.assign(mul(_this.accBeta1, _this.beta1));
+ _this.accBeta2.assign(mul(_this.accBeta2, _this.beta2));
+ });
+ this.incrementIterations();
+ };
+ AdamOptimizer.prototype.dispose = function () {
+ this.accBeta1.dispose();
+ this.accBeta2.dispose();
+ if (this.accumulatedFirstMoment != null) {
+ dispose(this.accumulatedFirstMoment.map(function (v) { return v.variable; }));
+ }
+ if (this.accumulatedSecondMoment != null) {
+ dispose(this.accumulatedSecondMoment.map(function (v) { return v.variable; }));
+ }
+ };
+ AdamOptimizer.prototype.getWeights = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ var variables;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ variables = __spread(this.accumulatedFirstMoment, this.accumulatedSecondMoment);
+ return [4 /*yield*/, this.saveIterations()];
+ case 1: return [2 /*return*/, [_a.sent()].concat(variables.map(function (v) { return ({ name: v.originalName, tensor: v.variable }); }))];
+ }
+ });
+ });
+ };
+ AdamOptimizer.prototype.setWeights = function (weightValues) {
+ return __awaiter(this, void 0, void 0, function () {
+ var variableCount, trainable;
+ var _this = this;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0: return [4 /*yield*/, this.extractIterations(weightValues)];
+ case 1:
+ weightValues = _a.sent();
+ tidy(function () {
+ _this.accBeta1.assign(pow(_this.beta1, _this.iterations_ + 1));
+ _this.accBeta2.assign(pow(_this.beta2, _this.iterations_ + 1));
+ });
+ variableCount = weightValues.length / 2;
+ trainable = false;
+ this.accumulatedFirstMoment =
+ weightValues.slice(0, variableCount).map(function (v) { return ({
+ originalName: v.name,
+ variable: v.tensor.variable(trainable)
+ }); });
+ this.accumulatedSecondMoment =
+ weightValues.slice(variableCount, variableCount * 2)
+ .map(function (v) { return ({
+ originalName: v.name,
+ variable: v.tensor.variable(trainable)
+ }); });
+ return [2 /*return*/];
+ }
+ });
+ });
+ };
+ AdamOptimizer.prototype.getConfig = function () {
+ return {
+ 'learningRate': this.learningRate,
+ 'beta1': this.beta1,
+ 'beta2': this.beta2,
+ 'epsilon': this.epsilon,
+ };
+ };
+ /** @nocollapse */
+ AdamOptimizer.fromConfig = function (cls, config) {
+ return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon']);
+ };
+ return AdamOptimizer;
+ }(Optimizer));
+ /** @nocollapse */
+ AdamOptimizer.className = 'Adam'; // Note: Name matters for Python compatibility.
+ registerClass(AdamOptimizer);
+
+ var AdamaxOptimizer = /** @class */ (function (_super) {
+ __extends(AdamaxOptimizer, _super);
+ function AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay) {
+ if (epsilon === void 0) { epsilon = null; }
+ if (decay === void 0) { decay = 0.0; }
+ var _this = _super.call(this) || this;
+ _this.learningRate = learningRate;
+ _this.beta1 = beta1;
+ _this.beta2 = beta2;
+ _this.epsilon = epsilon;
+ _this.decay = decay;
+ _this.accumulatedFirstMoment = [];
+ _this.accumulatedWeightedInfNorm = [];
+ tidy(function () {
+ _this.iteration = scalar(0).variable();
+ _this.accBeta1 = scalar(beta1).variable();
+ });
+ if (epsilon == null) {
+ _this.epsilon = ENGINE.backend.epsilon();
+ }
+ return _this;
+ }
+ AdamaxOptimizer.prototype.applyGradients = function (variableGradients) {
+ var _this = this;
+ var variableNames = Array.isArray(variableGradients) ?
+ variableGradients.map(function (item) { return item.name; }) :
+ Object.keys(variableGradients);
+ tidy(function () {
+ var oneMinusAccBeta1 = sub(1, _this.accBeta1);
+ var lr = div(-_this.learningRate, add(mul(_this.iteration, _this.decay), 1));
+ variableNames.forEach(function (name, i) {
+ var value = ENGINE.registeredVariables[name];
+ var trainable = false;
+ if (_this.accumulatedFirstMoment[i] == null) {
+ _this.accumulatedFirstMoment[i] = {
+ originalName: name + "/m",
+ variable: zerosLike(value).variable(trainable)
+ };
+ }
+ if (_this.accumulatedWeightedInfNorm[i] == null) {
+ _this.accumulatedWeightedInfNorm[i] = {
+ originalName: name + "/v",
+ variable: zerosLike(value).variable(trainable)
+ };
+ }
+ var gradient = Array.isArray(variableGradients) ?
+ variableGradients[i].tensor :
+ variableGradients[name];
+ if (gradient == null) {
+ return;
+ }
+ var firstMoment = _this.accumulatedFirstMoment[i].variable;
+ var weightedInfNorm = _this.accumulatedWeightedInfNorm[i].variable;
+ var newFirstMoment = add(mul(firstMoment, _this.beta1), mul(gradient, 1 - _this.beta1));
+ var ut0 = mul(weightedInfNorm, _this.beta2);
+ var ut1 = abs(gradient);
+ var newWeightedInfNorm = maximum(ut0, ut1);
+ firstMoment.assign(newFirstMoment);
+ weightedInfNorm.assign(newWeightedInfNorm);
+ var newValue = add(mul(div(lr, oneMinusAccBeta1), div(newFirstMoment, add(newWeightedInfNorm, _this.epsilon))), value);
+ value.assign(newValue);
+ });
+ _this.iteration.assign(add(_this.iteration, 1));
+ _this.accBeta1.assign(mul(_this.accBeta1, _this.beta1));
+ });
+ this.incrementIterations();
+ };
+ AdamaxOptimizer.prototype.dispose = function () {
+ this.accBeta1.dispose();
+ this.iteration.dispose();
+ if (this.accumulatedFirstMoment != null) {
+ dispose(this.accumulatedFirstMoment.map(function (v) { return v.variable; }));
+ }
+ if (this.accumulatedWeightedInfNorm != null) {
+ dispose(this.accumulatedWeightedInfNorm.map(function (v) { return v.variable; }));
+ }
+ };
+ AdamaxOptimizer.prototype.getWeights = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ return __generator(this, function (_a) {
+ throw new Error('getWeights() is not implemented for Adamax yet.');
+ });
+ });
+ };
+ AdamaxOptimizer.prototype.setWeights = function (weightValues) {
+ return __awaiter(this, void 0, void 0, function () {
+ return __generator(this, function (_a) {
+ throw new Error('setWeights() is not implemented for Adamax yet.');
+ });
+ });
+ };
+ AdamaxOptimizer.prototype.getConfig = function () {
+ return {
+ 'learningRate': this.learningRate,
+ 'beta1': this.beta1,
+ 'beta2': this.beta2,
+ 'epsilon': this.epsilon,
+ 'decay': this.decay
+ };
+ };
+ /** @nocollapse */
+ AdamaxOptimizer.fromConfig = function (cls, config) {
+ return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon'], config['decay']);
+ };
+ return AdamaxOptimizer;
+ }(Optimizer));
+ /** @nocollapse */
+ AdamaxOptimizer.className = 'Adamax'; // Note: Name matters for Python compatbility.
+ registerClass(AdamaxOptimizer);
+
+ /** @doclink Optimizer */
+ var SGDOptimizer = /** @class */ (function (_super) {
+ __extends(SGDOptimizer, _super);
+ function SGDOptimizer(learningRate) {
+ var _this = _super.call(this) || this;
+ _this.learningRate = learningRate;
+ _this.setLearningRate(learningRate);
+ return _this;
+ }
+ SGDOptimizer.prototype.applyGradients = function (variableGradients) {
+ var _this = this;
+ var varNames = Array.isArray(variableGradients) ?
+ variableGradients.map(function (v) { return v.name; }) :
+ Object.keys(variableGradients);
+ varNames.forEach(function (name, i) {
+ var gradient = Array.isArray(variableGradients) ?
+ variableGradients[i].tensor :
+ variableGradients[name];
+ if (gradient == null) {
+ return;
+ }
+ var value = ENGINE.registeredVariables[name];
+ tidy(function () {
+ var newValue = add(mul(_this.c, gradient), value);
+ value.assign(newValue);
+ });
+ });
+ this.incrementIterations();
+ };
+ /**
+ * Sets the learning rate of the optimizer.
+ */
+ SGDOptimizer.prototype.setLearningRate = function (learningRate) {
+ this.learningRate = learningRate;
+ if (this.c != null) {
+ this.c.dispose();
+ }
+ this.c = keep(scalar(-learningRate));
+ };
+ SGDOptimizer.prototype.dispose = function () {
+ this.c.dispose();
+ };
+ SGDOptimizer.prototype.getWeights = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0: return [4 /*yield*/, this.saveIterations()];
+ case 1: return [2 /*return*/, [_a.sent()]];
+ }
+ });
+ });
+ };
+ SGDOptimizer.prototype.setWeights = function (weightValues) {
+ return __awaiter(this, void 0, void 0, function () {
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0: return [4 /*yield*/, this.extractIterations(weightValues)];
+ case 1:
+ weightValues = _a.sent();
+ if (weightValues.length !== 0) {
+ throw new Error('SGD optimizer does not have settable weights.');
+ }
+ return [2 /*return*/];
+ }
+ });
+ });
+ };
+ SGDOptimizer.prototype.getConfig = function () {
+ return { 'learningRate': this.learningRate };
+ };
+ /** @nocollapse */
+ SGDOptimizer.fromConfig = function (cls, config) {
+ return new cls(config['learningRate']);
+ };
+ return SGDOptimizer;
+ }(Optimizer));
+ /** @nocollapse */
+ SGDOptimizer.className = 'SGD'; // Note: Name matters for Python compatibility.
+ registerClass(SGDOptimizer);
+
+ /** @doclink Optimizer */
+ var MomentumOptimizer = /** @class */ (function (_super) {
+ __extends(MomentumOptimizer, _super);
+ function MomentumOptimizer(learningRate, momentum, useNesterov) {
+ if (useNesterov === void 0) { useNesterov = false; }
+ var _this = _super.call(this, learningRate) || this;
+ _this.learningRate = learningRate;
+ _this.momentum = momentum;
+ _this.useNesterov = useNesterov;
+ _this.accumulations = [];
+ _this.m = scalar(_this.momentum);
+ return _this;
+ }
+ MomentumOptimizer.prototype.applyGradients = function (variableGradients) {
+ var _this = this;
+ var variableNames = Array.isArray(variableGradients) ?
+ variableGradients.map(function (item) { return item.name; }) :
+ Object.keys(variableGradients);
+ variableNames.forEach(function (name, i) {
+ var value = ENGINE.registeredVariables[name];
+ if (_this.accumulations[i] == null) {
+ var trainable_1 = false;
+ _this.accumulations[i] = {
+ originalName: name + "/momentum",
+ variable: tidy(function () { return zerosLike(value).variable(trainable_1); })
+ };
+ }
+ var accumulation = _this.accumulations[i].variable;
+ var gradient = Array.isArray(variableGradients) ?
+ variableGradients[i].tensor :
+ variableGradients[name];
+ if (gradient == null) {
+ return;
+ }
+ tidy(function () {
+ var newValue;
+ var newAccumulation = add(mul(_this.m, accumulation), gradient);
+ if (_this.useNesterov) {
+ newValue = add(mul(_this.c, add(gradient, mul(newAccumulation, _this.m))), value);
+ }
+ else {
+ newValue = add(mul(_this.c, newAccumulation), value);
+ }
+ accumulation.assign(newAccumulation);
+ value.assign(newValue);
+ });
+ });
+ this.incrementIterations();
+ };
+ MomentumOptimizer.prototype.dispose = function () {
+ this.m.dispose();
+ if (this.accumulations != null) {
+ dispose(this.accumulations.map(function (v) { return v.variable; }));
+ }
+ };
+ /**
+ * Sets the momentum of the optimizer.
+ *
+ * @param momentum
+ */
+ MomentumOptimizer.prototype.setMomentum = function (momentum) {
+ this.momentum = momentum;
+ };
+ MomentumOptimizer.prototype.getWeights = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0: return [4 /*yield*/, this.saveIterations()];
+ case 1:
+ // Order matters for Python compatibility.
+ return [2 /*return*/, [_a.sent()].concat(this.accumulations.map(function (v) { return ({ name: v.originalName, tensor: v.variable }); }))];
+ }
+ });
+ });
+ };
+ MomentumOptimizer.prototype.setWeights = function (weightValues) {
+ return __awaiter(this, void 0, void 0, function () {
+ var trainable;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0: return [4 /*yield*/, this.extractIterations(weightValues)];
+ case 1:
+ weightValues = _a.sent();
+ trainable = false;
+ this.accumulations = weightValues.map(function (v) { return ({ originalName: v.name, variable: v.tensor.variable(trainable) }); });
+ return [2 /*return*/];
+ }
+ });
+ });
+ };
+ MomentumOptimizer.prototype.getConfig = function () {
+ return {
+ 'learningRate': this.learningRate,
+ 'momentum': this.momentum,
+ 'useNesterov': this.useNesterov
+ };
+ };
+ /** @nocollapse */
+ MomentumOptimizer.fromConfig = function (cls, config) {
+ return new cls(config['learningRate'], config['momentum'], config['useNesterov']);
+ };
+ return MomentumOptimizer;
+ }(SGDOptimizer));
+ /** @nocollapse */
+ MomentumOptimizer.className = 'Momentum'; // Name matters for Python compatibility.
+ registerClass(MomentumOptimizer);
+
+ /** @doclink Optimizer */
+ var RMSPropOptimizer = /** @class */ (function (_super) {
+ __extends(RMSPropOptimizer, _super);
+ function RMSPropOptimizer(learningRate, decay, momentum, epsilon, centered) {
+ if (decay === void 0) { decay = 0.9; }
+ if (momentum === void 0) { momentum = 0.0; }
+ if (epsilon === void 0) { epsilon = null; }
+ if (centered === void 0) { centered = false; }
+ var _this = _super.call(this) || this;
+ _this.learningRate = learningRate;
+ _this.decay = decay;
+ _this.momentum = momentum;
+ _this.epsilon = epsilon;
+ _this.accumulatedMeanSquares = [];
+ _this.accumulatedMoments = [];
+ _this.accumulatedMeanGrads = [];
+ _this.centered = centered;
+ if (epsilon == null) {
+ _this.epsilon = ENGINE.backend.epsilon();
+ }
+ if (learningRate == null) {
+ throw new Error("learningRate for RMSPropOptimizer must be defined.");
+ }
+ return _this;
+ }
+ RMSPropOptimizer.prototype.applyGradients = function (variableGradients) {
+ var _this = this;
+ var variableNames = Array.isArray(variableGradients) ?
+ variableGradients.map(function (item) { return item.name; }) :
+ Object.keys(variableGradients);
+ variableNames.forEach(function (name, i) {
+ var value = ENGINE.registeredVariables[name];
+ var trainable = false;
+ if (_this.accumulatedMeanSquares[i] == null) {
+ _this.accumulatedMeanSquares[i] = {
+ originalName: name + "/rms",
+ variable: tidy(function () { return zerosLike(value).variable(trainable); })
+ };
+ }
+ if (_this.accumulatedMoments[i] == null) {
+ _this.accumulatedMoments[i] = {
+ originalName: name + "/momentum",
+ variable: tidy(function () { return zerosLike(value).variable(trainable); })
+ };
+ }
+ if (_this.accumulatedMeanGrads[i] == null && _this.centered) {
+ _this.accumulatedMeanGrads[i] = {
+ originalName: name + "/mg",
+ variable: tidy(function () { return zerosLike(value).variable(trainable); })
+ };
+ }
+ var gradient = Array.isArray(variableGradients) ?
+ variableGradients[i].tensor :
+ variableGradients[name];
+ if (gradient == null) {
+ return;
+ }
+ var accumulatedMeanSquare = _this.accumulatedMeanSquares[i].variable;
+ var accumulatedMoments = _this.accumulatedMoments[i].variable;
+ tidy(function () {
+ var newAccumulatedMeanSquare = add(mul(accumulatedMeanSquare, _this.decay), mul(square(gradient), 1 - _this.decay));
+ if (_this.centered) {
+ var accumulatedMeanGrad = _this.accumulatedMeanGrads[i].variable;
+ // Centered gradient
+ var newAccumulatedMeanGrad = add(mul(accumulatedMeanGrad, _this.decay), mul(gradient, 1 - _this.decay));
+ var gradContribution = div(mul(gradient, _this.learningRate), sqrt(sub(newAccumulatedMeanSquare, add(square(newAccumulatedMeanGrad), _this.epsilon))));
+ var newAccumulatedMoments = add(mul(accumulatedMoments, _this.momentum), gradContribution);
+ accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
+ accumulatedMeanGrad.assign(newAccumulatedMeanGrad);
+ accumulatedMoments.assign(newAccumulatedMoments);
+ var newValue = sub(value, newAccumulatedMoments);
+ value.assign(newValue);
+ }
+ else {
+ // Plain gradient
+ var newAccumulatedMeanSquare_1 = add(mul(accumulatedMeanSquare, _this.decay), mul(square(gradient), 1 - _this.decay));
+ var newAccumulatedMoments = add(mul(accumulatedMoments, _this.momentum), div(mul(gradient, _this.learningRate), sqrt(add(newAccumulatedMeanSquare_1, _this.epsilon))));
+ accumulatedMeanSquare.assign(newAccumulatedMeanSquare_1);
+ accumulatedMoments.assign(newAccumulatedMoments);
+ var newValue = sub(value, newAccumulatedMoments);
+ value.assign(newValue);
+ }
+ });
+ });
+ this.incrementIterations();
+ };
+ RMSPropOptimizer.prototype.dispose = function () {
+ if (this.accumulatedMeanSquares != null) {
+ dispose(this.accumulatedMeanSquares.map(function (v) { return v.variable; }));
+ }
+ if (this.accumulatedMeanGrads != null && this.centered) {
+ dispose(this.accumulatedMeanGrads.map(function (v) { return v.variable; }));
+ }
+ if (this.accumulatedMoments != null) {
+ dispose(this.accumulatedMoments.map(function (v) { return v.variable; }));
+ }
+ };
+ RMSPropOptimizer.prototype.getWeights = function () {
+ return __awaiter(this, void 0, void 0, function () {
+ var variables;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ variables = __spread(this.accumulatedMeanSquares, this.accumulatedMoments);
+ if (this.centered) {
+ variables.push.apply(variables, __spread(this.accumulatedMeanGrads));
+ }
+ return [4 /*yield*/, this.saveIterations()];
+ case 1: return [2 /*return*/, [_a.sent()].concat(variables.map(function (v) { return ({ name: v.originalName, tensor: v.variable }); }))];
+ }
+ });
+ });
+ };
+ RMSPropOptimizer.prototype.setWeights = function (weightValues) {
+ return __awaiter(this, void 0, void 0, function () {
+ var variableCount, trainable;
+ return __generator(this, function (_a) {
+ switch (_a.label) {
+ case 0: return [4 /*yield*/, this.extractIterations(weightValues)];
+ case 1:
+ weightValues = _a.sent();
+ variableCount = this.centered ? weightValues.length / 3 : weightValues.length / 2;
+ trainable = false;
+ this.accumulatedMeanSquares =
+ weightValues.slice(0, variableCount).map(function (v) { return ({
+ originalName: v.name,
+ variable: v.tensor.variable(trainable)
+ }); });
+ this.accumulatedMoments =
+ weightValues.slice(variableCount, variableCount * 2)
+ .map(function (v) { return ({
+ originalName: v.name,
+ variable: v.tensor.variable(trainable)
+ }); });
+ if (this.centered) {
+ this.accumulatedMeanGrads =
+ weightValues.slice(variableCount * 2, variableCount * 3)
+ .map(function (v) { return ({
+ originalName: v.name,
+ variable: v.tensor.variable(trainable)
+ }); });
+ }
+ return [2 /*return*/];
+ }
+ });
+ });
+ };
+ RMSPropOptimizer.prototype.getConfig = function () {
+ return {
+ 'learningRate': this.learningRate,
+ 'decay': this.decay,
+ 'momentum': this.momentum,
+ 'epsilon': this.epsilon,
+ 'centered': this.centered
+ };
+ };
+ /** @nocollapse */
+ RMSPropOptimizer.fromConfig = function (cls, config) {
+ return new cls(config['learningRate'], config['decay'], config['momentum'], config['epsilon'], config['centered']);
+ };
+ return RMSPropOptimizer;
+ }(Optimizer));
+ /** @nocollapse */
+ RMSPropOptimizer.className = 'RMSProp'; // Note: Name matters for Python compatibility.
+ registerClass(RMSPropOptimizer);
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var OptimizerConstructors = /** @class */ (function () {
+ function OptimizerConstructors() {
+ }
+ /**
+ * Constructs a `tf.SGDOptimizer` that uses stochastic gradient descent.
+ *
+ * ```js
+ * // Fit a quadratic function by learning the coefficients a, b, c.
+ * const xs = tf.tensor1d([0, 1, 2, 3]);
+ * const ys = tf.tensor1d([1.1, 5.9, 16.8, 33.9]);
+ *
+ * const a = tf.scalar(Math.random()).variable();
+ * const b = tf.scalar(Math.random()).variable();
+ * const c = tf.scalar(Math.random()).variable();
+ *
+ * // y = a * x^2 + b * x + c.
+ * const f = x => a.mul(x.square()).add(b.mul(x)).add(c);
+ * const loss = (pred, label) => pred.sub(label).square().mean();
+ *
+ * const learningRate = 0.01;
+ * const optimizer = tf.train.sgd(learningRate);
+ *
+ * // Train the model.
+ * for (let i = 0; i < 10; i++) {
+ * optimizer.minimize(() => loss(f(xs), ys));
+ * }
+ *
+ * // Make predictions.
+ * console.log(
+ * `a: ${a.dataSync()}, b: ${b.dataSync()}, c: ${c.dataSync()}`);
+ * const preds = f(xs).dataSync();
+ * preds.forEach((pred, i) => {
+ * console.log(`x: ${i}, pred: ${pred}`);
+ * });
+ * ```
+ *
+ * @param learningRate The learning rate to use for the SGD algorithm.
+ *
+ * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
+ */
+ OptimizerConstructors.sgd = function (learningRate) {
+ return new SGDOptimizer(learningRate);
+ };
+ /**
+ * Constructs a `tf.MomentumOptimizer` that uses momentum gradient
+ * descent.
+ *
+ * See
+ * [http://proceedings.mlr.press/v28/sutskever13.pdf](
+ * http://proceedings.mlr.press/v28/sutskever13.pdf)
+ *
+ * @param learningRate The learning rate to use for the Momentum gradient
+ * descent algorithm.
+ * @param momentum The momentum to use for the momentum gradient descent
+ * algorithm.
+ *
+ * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
+ */
+ OptimizerConstructors.momentum = function (learningRate, momentum, useNesterov) {
+ if (useNesterov === void 0) { useNesterov = false; }
+ return new MomentumOptimizer(learningRate, momentum, useNesterov);
+ };
+ /**
+ * Constructs a `tf.RMSPropOptimizer` that uses RMSProp gradient
+ * descent. This implementation uses plain momentum and is not centered
+ * version of RMSProp.
+ *
+ * See
+ * [http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf](
+ * http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+ *
+ * @param learningRate The learning rate to use for the RMSProp gradient
+ * descent algorithm.
+ * @param decay The discounting factor for the history/coming gradient.
+ * @param momentum The momentum to use for the RMSProp gradient descent
+ * algorithm.
+ * @param epsilon Small value to avoid zero denominator.
+ * @param centered If true, gradients are normalized by the estimated
+ * variance of the gradient.
+ *
+ * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
+ */
+ OptimizerConstructors.rmsprop = function (learningRate, decay, momentum, epsilon, centered) {
+ if (decay === void 0) { decay = .9; }
+ if (momentum === void 0) { momentum = 0.0; }
+ if (epsilon === void 0) { epsilon = null; }
+ if (centered === void 0) { centered = false; }
+ return new RMSPropOptimizer(learningRate, decay, momentum, epsilon, centered);
+ };
+ /**
+ * Constructs a `tf.AdamOptimizer` that uses the Adam algorithm.
+ * See [https://arxiv.org/abs/1412.6980](https://arxiv.org/abs/1412.6980)
+ *
+ * @param learningRate The learning rate to use for the Adam gradient
+ * descent algorithm.
+ * @param beta1 The exponential decay rate for the 1st moment estimates.
+ * @param beta2 The exponential decay rate for the 2nd moment estimates.
+ * @param epsilon A small constant for numerical stability.
+ *
+ * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
+ */
+ OptimizerConstructors.adam = function (learningRate, beta1, beta2, epsilon) {
+ if (learningRate === void 0) { learningRate = 0.001; }
+ if (beta1 === void 0) { beta1 = 0.9; }
+ if (beta2 === void 0) { beta2 = 0.999; }
+ if (epsilon === void 0) { epsilon = null; }
+ return new AdamOptimizer(learningRate, beta1, beta2, epsilon);
+ };
+ /**
+ * Constructs a `tf.AdadeltaOptimizer` that uses the Adadelta algorithm.
+ * See [https://arxiv.org/abs/1212.5701](https://arxiv.org/abs/1212.5701)
+ *
+ * @param learningRate The learning rate to use for the Adadelta gradient
+ * descent algorithm.
+ * @param rho The learning rate decay over each update.
+ * @param epsilon A constant epsilon used to better condition the grad
+ * update.
+ *
+ * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
+ */
+ OptimizerConstructors.adadelta = function (learningRate, rho, epsilon) {
+ if (learningRate === void 0) { learningRate = .001; }
+ if (rho === void 0) { rho = .95; }
+ if (epsilon === void 0) { epsilon = null; }
+ return new AdadeltaOptimizer(learningRate, rho, epsilon);
+ };
+ /**
+ * Constructs a `tf.AdamaxOptimizer` that uses the Adamax algorithm.
+ * See [https://arxiv.org/abs/1412.6980](https://arxiv.org/abs/1412.6980)
+ *
+ * @param learningRate The learning rate to use for the Adamax gradient
+ * descent algorithm.
+ * @param beta1 The exponential decay rate for the 1st moment estimates.
+ * @param beta2 The exponential decay rate for the 2nd moment estimates.
+ * @param epsilon A small constant for numerical stability.
+ * @param decay The learning rate decay over each update.
+ *
+ * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
+ */
+ OptimizerConstructors.adamax = function (learningRate, beta1, beta2, epsilon, decay) {
+ if (learningRate === void 0) { learningRate = 0.002; }
+ if (beta1 === void 0) { beta1 = 0.9; }
+ if (beta2 === void 0) { beta2 = 0.999; }
+ if (epsilon === void 0) { epsilon = null; }
+ if (decay === void 0) { decay = 0.0; }
+ return new AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay);
+ };
+ /**
+ * Constructs a `tf.AdagradOptimizer` that uses the Adagrad algorithm.
+ * See
+ * [http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf](
+ * http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
+ * or
+ * [http://ruder.io/optimizing-gradient-descent/index.html#adagrad](
+ * http://ruder.io/optimizing-gradient-descent/index.html#adagrad)
+ *
+ * @param learningRate The learning rate to use for the Adagrad gradient
+ * descent algorithm.
+ * @param initialAccumulatorValue Starting value for the accumulators, must be
+ * positive.
+ *
+ * @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
+ */
+ OptimizerConstructors.adagrad = function (learningRate, initialAccumulatorValue) {
+ if (initialAccumulatorValue === void 0) { initialAccumulatorValue = 0.1; }
+ return new AdagradOptimizer(learningRate, initialAccumulatorValue);
+ };
+ return OptimizerConstructors;
+ }());
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var train = {
+ sgd: OptimizerConstructors.sgd,
+ momentum: OptimizerConstructors.momentum,
+ adadelta: OptimizerConstructors.adadelta,
+ adagrad: OptimizerConstructors.adagrad,
+ rmsprop: OptimizerConstructors.rmsprop,
+ adamax: OptimizerConstructors.adamax,
+ adam: OptimizerConstructors.adam
+ };
+
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var delayCallback = (function () {
+ if (typeof requestAnimationFrame !== 'undefined') {
+ return requestAnimationFrame;
+ }
+ else if (typeof setImmediate !== 'undefined') {
+ return setImmediate;
+ }
+ return function (f) { return f(); }; // no delays
+ })();
+ /**
+ * Returns a promise that resolve when a requestAnimationFrame has completed.
+ *
+ * On Node.js this uses setImmediate instead of requestAnimationFrame.
+ *
+ * This is simply a sugar method so that users can do the following:
+ * `await tf.nextFrame();`
+ *
+ * @doc {heading: 'Performance', subheading: 'Timing'}
+ */
+ function nextFrame() {
+ return new Promise(function (resolve) { return delayCallback(function () { return resolve(); }); });
+ }
+
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function assertParamsConsistent(shapes, axis) {
+ var rank = shapes[0].length;
+ shapes.forEach(function (shape, i) {
+ assert(shape.length === rank, function () { return "Error in concat" + rank + "D: rank of tensors[" + i + "] must be the same " +
+ ("as the rank of the rest (" + rank + ")"); });
+ });
+ assert(axis >= 0 && axis < rank, function () { return "Error in concat" + rank + "D: axis must be between 0 and " + (rank - 1) + "."; });
+ var firstShape = shapes[0];
+ shapes.forEach(function (shape, i) {
+ for (var r = 0; r < rank; r++) {
+ assert((r === axis) || (shape[r] === firstShape[r]), function () { return "Error in concat" + rank + "D: Shape of tensors[" + i + "] (" + shape + ") " +
+ ("does not match the shape of the rest (" + firstShape + ") ") +
+ ("along the non-concatenated axis " + i + "."); });
+ }
+ });
+ }
+ function computeOutShape$1(shapes, axis) {
+ var outputShape = shapes[0].slice();
+ for (var i = 1; i < shapes.length; i++) {
+ outputShape[axis] += shapes[i][axis];
+ }
+ return outputShape;
+ }
+
+ /**
+ * @license
+ * Copyright 2017 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var PARALLELIZE_THRESHOLD = 30;
+ function computeOptimalWindowSize(inSize) {
+ if (inSize <= PARALLELIZE_THRESHOLD) {
+ return inSize;
+ }
+ return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
+ }
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ // Returns the image center in pixels.
+ function getImageCenter(center, imageHeight, imageWidth) {
+ var centerX = imageWidth * (typeof center === 'number' ? center : center[0]);
+ var centerY = imageHeight * (typeof center === 'number' ? center : center[1]);
+ return [centerX, centerY];
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Gets the new shape of the input Tensor after it's been reshaped
+ * to:
+ * [blockShape[0], ..., blockShape[M-1], batch / prod(blockShape),
+ * inputShape[1], ..., inputShape[N-1]]
+ *
+ * See step 1: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
+ */
+ function getReshaped(inputShape, blockShape, prod, batchToSpace) {
+ if (batchToSpace === void 0) { batchToSpace = true; }
+ var reshaped = [];
+ if (batchToSpace) {
+ reshaped = reshaped.concat(blockShape.slice(0));
+ reshaped.push(inputShape[0] / prod);
+ reshaped = reshaped.concat(inputShape.slice(1));
+ }
+ else {
+ reshaped = reshaped.concat(inputShape[0]);
+ var spatialLength = blockShape.length;
+ for (var i = 0; i < spatialLength; ++i) {
+ reshaped =
+ reshaped.concat([inputShape[i + 1] / blockShape[i], blockShape[i]]);
+ }
+ reshaped = reshaped.concat(inputShape.slice(spatialLength + 1));
+ }
+ return reshaped;
+ }
+ /**
+ * Gets the permutation that will transpose the dimensions of the
+ * reshaped tensor to shape:
+ *
+ * [batch / prod(block_shape),inputShape[1], blockShape[0], ...,
+ * inputShape[M], blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]
+ *
+ * see step 2: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
+ */
+ function getPermuted(reshapedRank, blockShapeRank, batchToSpace) {
+ if (batchToSpace === void 0) { batchToSpace = true; }
+ var permuted = [];
+ if (batchToSpace) {
+ permuted.push(blockShapeRank);
+ for (var i = blockShapeRank + 1; i < reshapedRank; ++i) {
+ if (i <= 2 * blockShapeRank) {
+ permuted.push(i);
+ permuted.push(i - (blockShapeRank + 1));
+ }
+ else {
+ permuted.push(i);
+ }
+ }
+ }
+ else {
+ var permutedBeforeBatch = [];
+ var permutedAfterBatch = [];
+ for (var i = 1; i < reshapedRank; ++i) {
+ if (i >= blockShapeRank * 2 + 1 || i % 2 === 1) {
+ permutedAfterBatch.push(i);
+ }
+ else {
+ permutedBeforeBatch.push(i);
+ }
+ }
+ permuted.push.apply(permuted, __spread(permutedBeforeBatch));
+ permuted.push(0);
+ permuted.push.apply(permuted, __spread(permutedAfterBatch));
+ }
+ return permuted;
+ }
+ /**
+ * Gets the shape of the reshaped and permuted input Tensor before any cropping
+ * is applied. The new shape will be:
+ *
+ * [batch / prod(blockShape),inputShape[1] * blockShape[0], ...,
+ * inputShape[M] * blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]
+ *
+ * See step 3: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
+ */
+ function getReshapedPermuted(inputShape, blockShape, prod, batchToSpace) {
+ if (batchToSpace === void 0) { batchToSpace = true; }
+ var reshapedPermuted = [];
+ if (batchToSpace) {
+ reshapedPermuted.push(inputShape[0] / prod);
+ }
+ else {
+ reshapedPermuted.push(inputShape[0] * prod);
+ }
+ for (var i = 1; i < inputShape.length; ++i) {
+ if (i <= blockShape.length) {
+ if (batchToSpace) {
+ reshapedPermuted.push(blockShape[i - 1] * inputShape[i]);
+ }
+ else {
+ reshapedPermuted.push(inputShape[i] / blockShape[i - 1]);
+ }
+ }
+ else {
+ reshapedPermuted.push(inputShape[i]);
+ }
+ }
+ return reshapedPermuted;
+ }
+ /**
+ * Converts the crops argument into the beginning coordinates of a slice
+ * operation.
+ */
+ function getSliceBeginCoords(crops, blockShape) {
+ var sliceBeginCoords = [0];
+ for (var i = 0; i < blockShape; ++i) {
+ sliceBeginCoords.push(crops[i][0]);
+ }
+ return sliceBeginCoords;
+ }
+ /**
+ * Converts the crops argument into the size of a slice operation. When
+ * combined with getSliceBeginCoords this function allows the reshaped and
+ * permuted Tensor to be cropped to its final output shape of:
+ *
+ * inputShape[1] * blockShape[0] - crops[0,0] - crops[0,1], ...,
+ * inputShape[M] * blockShape[M-1] -crops[M-1,0] -
+ * crops[M-1,1],inputShape[M+1], ..., inputShape[N-1]]
+ *
+ * See step 4: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
+ */
+ function getSliceSize(uncroppedShape, crops, blockShape) {
+ var sliceSize = uncroppedShape.slice(0, 1);
+ for (var i = 0; i < blockShape; ++i) {
+ sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]);
+ }
+ return sliceSize;
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var SELU_SCALEALPHA = 1.7580993408473768599402175208123;
+ var SELU_SCALE = 1.0507009873554804934193349852946;
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ var ERF_P = 0.3275911;
+ var ERF_A1 = 0.254829592;
+ var ERF_A2 = -0.284496736;
+ var ERF_A3 = 1.421413741;
+ var ERF_A4 = -1.453152027;
+ var ERF_A5 = 1.061405429;
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Merges real and imaginary Float32Arrays into a single complex Float32Array.
+ *
+ * The memory layout is interleaved as follows:
+ * real: [r0, r1, r2]
+ * imag: [i0, i1, i2]
+ * complex: [r0, i0, r1, i1, r2, i2]
+ *
+ * This is the inverse of splitRealAndImagArrays.
+ *
+ * @param real The real values of the complex tensor values.
+ * @param imag The imag values of the complex tensor values.
+ * @returns A complex tensor as a Float32Array with merged values.
+ */
+ function mergeRealAndImagArrays(real, imag) {
+ if (real.length !== imag.length) {
+ throw new Error("Cannot merge real and imag arrays of different lengths. real:" +
+ (real.length + ", imag: " + imag.length + "."));
+ }
+ var result = new Float32Array(real.length * 2);
+ for (var i = 0; i < result.length; i += 2) {
+ result[i] = real[i / 2];
+ result[i + 1] = imag[i / 2];
+ }
+ return result;
+ }
+ /**
+ * Splits a complex Float32Array into real and imag parts.
+ *
+ * The memory layout is interleaved as follows:
+ * complex: [r0, i0, r1, i1, r2, i2]
+ * real: [r0, r1, r2]
+ * imag: [i0, i1, i2]
+ *
+ * This is the inverse of mergeRealAndImagArrays.
+ *
+ * @param complex The complex tensor values.
+ * @returns An object with real and imag Float32Array components of the complex
+ * tensor.
+ */
+ function splitRealAndImagArrays(complex) {
+ var real = new Float32Array(complex.length / 2);
+ var imag = new Float32Array(complex.length / 2);
+ for (var i = 0; i < complex.length; i += 2) {
+ real[i / 2] = complex[i];
+ imag[i / 2] = complex[i + 1];
+ }
+ return { real: real, imag: imag };
+ }
+ /**
+ * Extracts even indexed complex values in the given array.
+ * @param complex The complex tensor values
+ */
+ function complexWithEvenIndex(complex) {
+ var len = Math.ceil(complex.length / 4);
+ var real = new Float32Array(len);
+ var imag = new Float32Array(len);
+ for (var i = 0; i < complex.length; i += 4) {
+ real[Math.floor(i / 4)] = complex[i];
+ imag[Math.floor(i / 4)] = complex[i + 1];
+ }
+ return { real: real, imag: imag };
+ }
+ /**
+ * Extracts odd indexed comple values in the given array.
+ * @param complex The complex tensor values
+ */
+ function complexWithOddIndex(complex) {
+ var len = Math.floor(complex.length / 4);
+ var real = new Float32Array(len);
+ var imag = new Float32Array(len);
+ for (var i = 2; i < complex.length; i += 4) {
+ real[Math.floor(i / 4)] = complex[i];
+ imag[Math.floor(i / 4)] = complex[i + 1];
+ }
+ return { real: real, imag: imag };
+ }
+ /**
+ * Get the map representing a complex value in the given array.
+ * @param complex The complex tensor values.
+ * @param index An index of the target complex value.
+ */
+ function getComplexWithIndex(complex, index) {
+ var real = complex[index * 2];
+ var imag = complex[index * 2 + 1];
+ return { real: real, imag: imag };
+ }
+ /**
+ * Insert a given complex value into the TypedArray.
+ * @param data The array in which the complex value is inserted.
+ * @param c The complex value to be inserted.
+ * @param index An index of the target complex value.
+ */
+ function assignToTypedArray(data, real, imag, index) {
+ data[index * 2] = real;
+ data[index * 2 + 1] = imag;
+ }
+ /**
+ * Make the list of exponent terms used by FFT.
+ */
+ function exponents(n, inverse) {
+ var real = new Float32Array(n / 2);
+ var imag = new Float32Array(n / 2);
+ for (var i = 0; i < Math.ceil(n / 2); i++) {
+ var x = (inverse ? 2 : -2) * Math.PI * (i / n);
+ real[i] = Math.cos(x);
+ imag[i] = Math.sin(x);
+ }
+ return { real: real, imag: imag };
+ }
+ /**
+ * Make the exponent term used by FFT.
+ */
+ function exponent(k, n, inverse) {
+ var x = (inverse ? 2 : -2) * Math.PI * (k / n);
+ var real = Math.cos(x);
+ var imag = Math.sin(x);
+ return { real: real, imag: imag };
+ }
+
+ var ARROW = '->';
+ var ARROW_REGEX = /->/g;
+ var COMMA = ',';
+ var ELLIPSIS = '...';
+ /**
+ * Parse an equation for einsum.
+ *
+ * @param equation The einsum equation (e.g., "ij,jk->ik").
+ * @param numTensors Number of tensors provided along with `equation`. Used to
+ * check matching number of input tensors.
+ * @returns An object consisting of the following fields:
+ * - allDims: all dimension names as strings.
+ * - summedDims: a list of all dimensions being summed over, as indices to
+ * the elements of `allDims`.
+ * - idDims: indices of the dimensions in each input tensor, as indices to
+ * the elements of `allDims.
+ */
+ function decodeEinsumEquation(equation, numTensors) {
+ equation = equation.replace(/\s/g, ''); // Remove witespace in equation.
+ var numArrows = (equation.length - equation.replace(ARROW_REGEX, '').length) /
+ ARROW.length;
+ if (numArrows < 1) {
+ throw new Error('Equations without an arrow are not supported.');
+ }
+ else if (numArrows > 1) {
+ throw new Error("Equation must contain exactly one arrow (\"" + ARROW + "\").");
+ }
+ var _a = __read(equation.split(ARROW), 2), inputString = _a[0], outputString = _a[1];
+ assert(inputString.indexOf(ELLIPSIS) === -1, function () { return "The ellipsis notation (\"" + ELLIPSIS + "\") is not supported yet."; });
+ var inputTerms = inputString.split(COMMA);
+ var numInputs = inputTerms.length;
+ if (numTensors !== numInputs) {
+ throw new Error("Expected " + numInputs + " input tensors, received " + numTensors);
+ }
+ if (numInputs > 2) {
+ throw new Error('Support for more than 2 input tensors is not implemented yet.');
+ }
+ var allDims = [];
+ var _loop_1 = function (i) {
+ var dimName = outputString[i];
+ if (!inputTerms.some(function (inputTerm) { return inputTerm.indexOf(dimName) !== -1; })) {
+ throw new Error("Output subscripts contain the label " + dimName + " " +
+ "not present in the input subscripts.");
+ }
+ if (allDims.indexOf(dimName) === -1) {
+ allDims.push(dimName);
+ }
+ };
+ for (var i = 0; i < outputString.length; ++i) {
+ _loop_1(i);
+ }
+ for (var i = 0; i < inputString.length; ++i) {
+ var dimName = inputString[i];
+ if (allDims.indexOf(dimName) === -1 && dimName !== COMMA) {
+ allDims.push(dimName);
+ }
+ }
+ var idDims = new Array(inputTerms.length);
+ for (var i = 0; i < numInputs; ++i) {
+ if (new Set(inputTerms[i].split('')).size !== inputTerms[i].length) {
+ throw new Error("Found duplicate axes in input component " + inputTerms[i] + ". " +
+ "Support for duplicate axes in input is not implemented yet.");
+ }
+ idDims[i] = [];
+ for (var j = 0; j < inputTerms[i].length; ++j) {
+ idDims[i].push(allDims.indexOf(inputTerms[i][j]));
+ }
+ }
+ var numDims = allDims.length; // Number of unique dimensions.
+ var numOutDims = outputString.length; // Number of output dimensions.
+ var summedDims = []; // Dimensions being summed over.
+ for (var i = numOutDims; i < numDims; ++i) {
+ summedDims.push(i);
+ }
+ return { allDims: allDims, summedDims: summedDims, idDims: idDims };
+ }
+ /**
+ * Get the permutation for a given input tensor.
+ *
+ * @param nDims Total number of dimension of all tensors involved in the einsum
+ * operation.
+ * @param idDims Dimension indices involve in the tensor in question.
+ * @returns An object consisting of the following fields:
+ * - permutationIndices: Indices to permute the axes of the tensor with.
+ * - expandDims: Indices to the dimension that need to be expanded from the
+ * tensor after permutation.
+ */
+ function getEinsumPermutation(nDims, idDims) {
+ var permutationIndices = new Array(nDims);
+ permutationIndices.fill(-1);
+ for (var i = 0; i < idDims.length; ++i) {
+ permutationIndices[idDims[i]] = i;
+ }
+ var expandDims = [];
+ for (var i = 0; i < nDims; ++i) {
+ if (permutationIndices[i] === -1) {
+ expandDims.push(i);
+ }
+ }
+ permutationIndices = permutationIndices.filter(function (d) { return d !== -1; });
+ return { permutationIndices: permutationIndices, expandDims: expandDims };
+ }
+ /**
+ * Checks that the dimension sizes from different input tensors match the
+ * equation.
+ */
+ function checkEinsumDimSizes(nDims, idDims, tensors) {
+ var dimSizes = new Array(nDims);
+ var _loop_2 = function (i) {
+ var shape = tensors[i].shape;
+ var _loop_3 = function (j) {
+ if (dimSizes[idDims[i][j]] === undefined) {
+ dimSizes[idDims[i][j]] = shape[j];
+ }
+ else {
+ assert(dimSizes[idDims[i][j]] === shape[j], function () { return "Expected dimension " + dimSizes[idDims[i][j]] + " at axis " + j + " " +
+ ("of input shaped " + JSON.stringify(shape) + ", ") +
+ ("but got dimension " + shape[j]); });
+ }
+ };
+ for (var j = 0; j < idDims[i].length; ++j) {
+ _loop_3(j);
+ }
+ };
+ for (var i = 0; i < tensors.length; ++i) {
+ _loop_2(i);
+ }
+ }
+ /**
+ * Gets path of computation for einsum.
+ *
+ * @param summedDims indices to the dimensions being summed over.
+ * @param idDims A look up table for the dimensions present in each input
+ * tensor. Each consituent array contains indices for the dimensions in the
+ * corresponding input tensor.
+ *
+ * @return A map with two fields:
+ * - path: The path of computation, with each element indicating the dimension
+ * being summed over after the element-wise multiplication in that step.
+ * - steps: With the same length as `path`. Each element contains the indices
+ * to the input tensors being used for element-wise multiplication in the
+ * corresponding step.
+ */
+ function getEinsumComputePath(summedDims, idDims) {
+ var e_1, _a;
+ var path = summedDims;
+ var steps = [];
+ var nSteps = 0;
+ if (summedDims.length === 0) {
+ // Einsum that involes no summing: e.g., transpose and outer product.
+ path.push(-1);
+ }
+ nSteps = summedDims.length + 1;
+ for (var i = 0; i < nSteps; ++i) {
+ steps.push([]);
+ }
+ var computedTermIndices = [];
+ for (var i = 0; i < path.length; ++i) {
+ var summedDim = path[i];
+ var termIndices = findTermsWithDim(idDims, summedDim);
+ try {
+ for (var termIndices_1 = (e_1 = void 0, __values(termIndices)), termIndices_1_1 = termIndices_1.next(); !termIndices_1_1.done; termIndices_1_1 = termIndices_1.next()) {
+ var termIndex = termIndices_1_1.value;
+ if (computedTermIndices.indexOf(termIndex) === -1) {
+ steps[i].push(termIndex);
+ computedTermIndices.push(termIndex);
+ }
+ }
+ }
+ catch (e_1_1) { e_1 = { error: e_1_1 }; }
+ finally {
+ try {
+ if (termIndices_1_1 && !termIndices_1_1.done && (_a = termIndices_1.return)) _a.call(termIndices_1);
+ }
+ finally { if (e_1) throw e_1.error; }
+ }
+ }
+ return { path: path, steps: steps };
+ }
+ /** Determines if an axes permutation is the identity permutation. */
+ function isIdentityPermutation(perm) {
+ return perm.every(function (dim, index) { return dim === index; });
+ }
+ function findTermsWithDim(idDims, dim) {
+ var termIndices = [];
+ for (var i = 0; i < idDims.length; ++i) {
+ if (idDims[i].length === 0 || idDims[i].indexOf(dim) !== -1 || dim === -1) {
+ termIndices.push(i);
+ }
+ }
+ return termIndices;
+ }
+
+ /**
+ * Prepare the split size array. When the input is a number, the axis is evenly
+ * divided among the split size. When the input contains the negative value, the
+ * rest of the axis is allocated toward that.
+ */
+ function prepareSplitSize(x, numOrSizeSplits, axis) {
+ if (axis === void 0) { axis = 0; }
+ var splitSizes = [];
+ if (typeof (numOrSizeSplits) === 'number') {
+ assert(x.shape[axis] % numOrSizeSplits === 0, function () { return 'Number of splits must evenly divide the axis.'; });
+ splitSizes =
+ new Array(numOrSizeSplits).fill(x.shape[axis] / numOrSizeSplits);
+ }
+ else {
+ var numOfNegs = numOrSizeSplits.reduce(function (count, value) {
+ if (value === -1) {
+ count += 1;
+ }
+ return count;
+ }, 0);
+ assert(numOfNegs <= 1, function () { return 'There should be only one negative value in split array.'; });
+ var negIndex = numOrSizeSplits.indexOf(-1);
+ // Allow the number of split array to be -1, which indicates the rest
+ // of dimension is allocated to that split.
+ if (negIndex !== -1) {
+ var total = numOrSizeSplits.reduce(function (a, b) { return b > 0 ? a + b : a; });
+ numOrSizeSplits[negIndex] = x.shape[axis] - total;
+ }
+ assert(x.shape[axis] === numOrSizeSplits.reduce(function (a, b) { return a + b; }), function () { return 'The sum of sizes must match the size of the axis dimension.'; });
+ splitSizes = numOrSizeSplits;
+ }
+ return splitSizes;
+ }
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Generates sparse fill empty rows indices, dense shape mismatch error message.
+ *
+ * @param indicesLength The first dimension of indices.
+ */
+ function getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesLength) {
+ return "Received SparseTensor with denseShape[0] = 0 but\n indices.shape[0] = " + indicesLength;
+ }
+ /**
+ * Generates sparse fill empty rows negative index error message.
+ *
+ * @param index The index with a negative value.
+ * @param value The negative value.
+ */
+ function getSparseFillEmptyRowsNegativeIndexErrorMessage(index, value) {
+ return "indices(" + index + ", 0) is invalid: " + value + " < 0";
+ }
+ /**
+ * Generates sparse fill empty rows out of range index error message.
+ *
+ * @param index The index with an out of range value.
+ * @param value The out of range value.
+ * @param limit The upper limit for indices.
+ */
+ function getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(index, value, limit) {
+ return "indices(" + index + ", 0) is invalid: " + value + " >= " + limit;
+ }
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Generates sparse reshape multiple negative 1 output dimension error message.
+ *
+ * @param dim1 The first dimension with a negative 1 value.
+ * @param dim2 The second dimension with a negative 1 value.
+ */
+ function getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(dim1, dim2) {
+ return "only one output dimension may be -1, not both " + dim1 + " and " + dim2;
+ }
+ /**
+ * Generates sparse reshape negative output dimension error message.
+ *
+ * @param dim The dimension with a negative value.
+ * @param value The negative value.
+ */
+ function getSparseReshapeNegativeOutputDimErrorMessage(dim, value) {
+ return "size " + dim + " must be non-negative, not " + value;
+ }
+ /**
+ * Generates sparse reshape empty tensor zero output dimension error message.
+ *
+ */
+ function getSparseReshapeEmptyTensorZeroOutputDimErrorMessage() {
+ return 'reshape cannot infer the missing input size for an empty tensor ' +
+ 'unless all specified input sizes are non-zero';
+ }
+ /**
+ * Generates sparse reshape input output multiple mismatch error message.
+ *
+ * @param inputShape the input shape.
+ * @param outputShape the requested output shape.
+ */
+ function getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape) {
+ var inputSize = sizeFromShape(inputShape);
+ var outputSize = sizeFromShape(outputShape);
+ return "Input to reshape is a SparseTensor with " + inputSize + "\n dense values, but the requested shape requires a multiple of " + outputSize + ". inputShape=" + inputShape + " outputShape= " + outputShape;
+ }
+ /**
+ * Generates sparse reshape input output inequality error message.
+ *
+ * @param inputShape the input shape.
+ * @param outputShape the requested output shape.
+ */
+ function getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape) {
+ var inputSize = sizeFromShape(inputShape);
+ var outputSize = sizeFromShape(outputShape);
+ return "Input to reshape is a tensor with " + inputSize + " dense values, but the requested shape has " + outputSize + ". inputShape=" + inputShape + " outputShape=" + outputShape;
+ }
+
+ /**
+ * @license
+ * Copyright 2021 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ /**
+ * Generates sparse segment reduction negative segment ids error message.
+ *
+ */
+ function getSparseSegmentReductionNegativeSegmentIdsErrorMessage() {
+ return "segment ids must be >= 0";
+ }
+ /**
+ * Generates sparse segment reduction non increasing segment ids error message.
+ *
+ */
+ function getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage() {
+ return "segment ids are not increasing";
+ }
+ /**
+ * Generates sparse segment reduction segment id out of range error message.
+ *
+ * @param segmentId The segment id index that is out of range.
+ * @param outputRows Upper bound of valid segment id values.
+ */
+ function getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(segmentId, outputRows) {
+ return "Segment id " + segmentId + " out of range [0, " + outputRows + "), possibly because segmentIds input is not sorted.";
+ }
+ /**
+ * Generates sparse segment reduction input indice out of range error message.
+ *
+ * @param index The index that holds the out of range value.
+ * @param indexValue The value that is out of range.
+ * @param inputRows Upper bound of valid index values.
+ */
+ function getSparseSegmentReductionIndicesOutOfRangeErrorMessage(index, indexValue, inputRows) {
+ return "Bad: indices[" + index + "] == " + indexValue + " out of range [0, " + inputRows + ")";
+ }
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function segOpComputeOptimalWindowSize(inSize, numSegments) {
+ var done = false;
+ var res;
+ if (inSize <= PARALLELIZE_THRESHOLD) {
+ res = inSize;
+ done = true;
+ }
+ else {
+ res = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
+ }
+ while (!done) {
+ if (res > numSegments || res === inSize) {
+ done = true;
+ }
+ else {
+ res = nearestDivisor(inSize, res + 1);
+ }
+ }
+ return res;
+ }
+ function computeOutShape(aShape, axis, numSegments) {
+ var outShape = [];
+ var rank = aShape.length;
+ for (var dim = 0; dim < rank; dim++) {
+ if (dim !== axis) {
+ outShape.push(aShape[dim]);
+ }
+ else {
+ outShape.push(numSegments);
+ }
+ }
+ return outShape;
+ }
+ function collectGatherOpShapeInfo(x, indices, axis, batchDims) {
+ var indicesRank = indices.shape.length;
+ var xRank = x.shape.length;
+ if (batchDims !== 0) {
+ if (batchDims < -indicesRank || batchDims > indicesRank) {
+ throw new Error("Expect batchDims in the range of [-" + indicesRank + ", " + indicesRank + "], but got " + batchDims);
+ }
+ }
+ if (batchDims < 0) {
+ batchDims += indicesRank;
+ }
+ if (batchDims > xRank) {
+ throw new Error("batchDims (" + batchDims + ") must be less than rank(x) (\n " + xRank + ").");
+ }
+ if (axis < batchDims) {
+ throw new Error("batchDims (" + batchDims + ") must be less than or equal to axis (" + axis + ").");
+ }
+ for (var i = 0; i < batchDims; ++i) {
+ if (x.shape[i] !== indices.shape[i]) {
+ throw new Error("x.shape[" + i + "]: " + x.shape[i] + " should be equal to indices.shape[" + i + "]: " + indices.shape[i] + ".");
+ }
+ }
+ var dimSize = x.shape[axis];
+ var outputShape = [];
+ var batchSize = 1;
+ var outerSize = 1;
+ var sliceSize = 1;
+ for (var i = 0; i < batchDims; ++i) {
+ outputShape.push(x.shape[i]);
+ batchSize *= x.shape[i];
+ }
+ for (var i = batchDims; i < axis; i++) {
+ outputShape.push(x.shape[i]);
+ outerSize *= x.shape[i];
+ }
+ for (var i = batchDims; i < indicesRank; i++) {
+ outputShape.push(indices.shape[i]);
+ }
+ for (var i = axis + 1; i < xRank; i++) {
+ outputShape.push(x.shape[i]);
+ sliceSize *= x.shape[i];
+ }
+ return { batchSize: batchSize, sliceSize: sliceSize, outerSize: outerSize, dimSize: dimSize, outputShape: outputShape };
+ }
+
+ var segment_util = {
+ __proto__: null,
+ segOpComputeOptimalWindowSize: segOpComputeOptimalWindowSize,
+ computeOutShape: computeOutShape,
+ collectGatherOpShapeInfo: collectGatherOpShapeInfo
+ };
+
+ /**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+ function fromUint8ToStringArray(vals) {
+ try {
+ // Decode the bytes into string.
+ return vals.map(function (val) { return decodeString(val); });
+ }
+ catch (err) {
+ throw new Error("Failed to decode encoded string bytes into utf-8, error: " + err);
+ }
+ }
+ function fromStringArrayToUint8(strings) {
+ return strings.map(function (s) { return encodeString(s); });
+ }
+
+ var backend_util = {
+ __proto__: null,
+ slice_util: slice_util,
+ segment_util: segment_util,
+ fromUint8ToStringArray: fromUint8ToStringArray,
+ fromStringArrayToUint8: fromStringArrayToUint8,
+ upcastType: upcastType,
+ axesAreInnerMostDims: axesAreInnerMostDims,
+ combineLocations: combineLocations,
+ computeOutAndReduceShapes: computeOutAndReduceShapes,
+ expandShapeToKeepDim: expandShapeToKeepDim,
+ assertAxesAreInnerMostDims: assertAxesAreInnerMostDims,
+ getAxesPermutation: getAxesPermutation,
+ getUndoAxesPermutation: getUndoAxesPermutation,
+ getInnerMostAxes: getInnerMostAxes,
+ getBroadcastDims: getBroadcastDims,
+ getReductionAxes: getReductionAxes,
+ assertAndGetBroadcastShape: assertAndGetBroadcastShape,
+ assertParamsConsistent: assertParamsConsistent,
+ computeOutShape: computeOutShape$1,
+ computeDilation2DInfo: computeDilation2DInfo,
+ computePool2DInfo: computePool2DInfo,
+ computePool3DInfo: computePool3DInfo,
+ computeConv2DInfo: computeConv2DInfo,
+ computeConv3DInfo: computeConv3DInfo,
+ computeDefaultPad: computeDefaultPad,
+ tupleValuesAreOne: tupleValuesAreOne,
+ eitherStridesOrDilationsAreOne: eitherStridesOrDilationsAreOne,
+ convertConv2DDataFormat: convertConv2DDataFormat,
+ checkPadOnDimRoundingMode: checkPadOnDimRoundingMode,
+ getFusedDyActivation: getFusedDyActivation,
+ getFusedBiasGradient: getFusedBiasGradient,
+ applyActivation: applyActivation,
+ shouldFuse: shouldFuse,
+ PARALLELIZE_THRESHOLD: PARALLELIZE_THRESHOLD,
+ computeOptimalWindowSize: computeOptimalWindowSize,
+ getImageCenter: getImageCenter,
+ getReshaped: getReshaped,
+ getPermuted: getPermuted,
+ getReshapedPermuted: getReshapedPermuted,
+ getSliceBeginCoords: getSliceBeginCoords,
+ getSliceSize: getSliceSize,
+ prepareAndValidate: prepareAndValidate,
+ validateUpdateShape: validateUpdateShape,
+ validateInput: validateInput$1,
+ calculateShapes: calculateShapes,
+ SELU_SCALEALPHA: SELU_SCALEALPHA,
+ SELU_SCALE: SELU_SCALE,
+ ERF_P: ERF_P,
+ ERF_A1: ERF_A1,
+ ERF_A2: ERF_A2,
+ ERF_A3: ERF_A3,
+ ERF_A4: ERF_A4,
+ ERF_A5: ERF_A5,
+ warn: warn,
+ log: log$1,
+ mergeRealAndImagArrays: mergeRealAndImagArrays,
+ splitRealAndImagArrays: splitRealAndImagArrays,
+ complexWithEvenIndex: complexWithEvenIndex,
+ complexWithOddIndex: complexWithOddIndex,
+ getComplexWithIndex: getComplexWithIndex,
+ assignToTypedArray: assignToTypedArray,
+ exponents: exponents,
+ exponent: exponent,
+ decodeEinsumEquation: decodeEinsumEquation,
+ getEinsumPermutation: getEinsumPermutation,
+ checkEinsumDimSizes: checkEinsumDimSizes,
+ getEinsumComputePath: getEinsumComputePath,
+ isIdentityPermutation: isIdentityPermutation,
+ prepareSplitSize: prepareSplitSize,
+ getSparseFillEmptyRowsIndicesDenseShapeMismatch: getSparseFillEmptyRowsIndicesDenseShapeMismatch,
+ getSparseFillEmptyRowsNegativeIndexErrorMessage: getSparseFillEmptyRowsNegativeIndexErrorMessage,
+ getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: getSparseFillEmptyRowsOutOfRangeIndexErrorMessage,
+ getSparseReshapeMultipleNegativeOneOutputDimErrorMessage: getSparseReshapeMultipleNegativeOneOutputDimErrorMessage,
+ getSparseReshapeNegativeOutputDimErrorMessage: getSparseReshapeNegativeOutputDimErrorMessage,
+ getSparseReshapeEmptyTensorZeroOutputDimErrorMessage: getSparseReshapeEmptyTensorZeroOutputDimErrorMessage,
+ getSparseReshapeInputOutputMultipleErrorMessage: getSparseReshapeInputOutputMultipleErrorMessage,
+ getSparseReshapeInputOutputMismatchErrorMessage: getSparseReshapeInputOutputMismatchErrorMessage,
+ getSparseSegmentReductionNegativeSegmentIdsErrorMessage: getSparseSegmentReductionNegativeSegmentIdsErrorMessage,
+ getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage: getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage,
+ getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage: getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage,
+ getSparseSegmentReductionIndicesOutOfRangeErrorMessage: getSparseSegmentReductionIndicesOutOfRangeErrorMessage
+ };
+
+ /**
+ * @license
+ * Copyright 2020 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+ var kernel_impls = {
+ __proto__: null,
+ nonMaxSuppressionV3Impl: nonMaxSuppressionV3Impl,
+ nonMaxSuppressionV4Impl: nonMaxSuppressionV4Impl,
+ nonMaxSuppressionV5Impl: nonMaxSuppressionV5Impl,
+ whereImpl: whereImpl
+ };
+
+ exports.Abs = Abs;
+ exports.Acos = Acos;
+ exports.Acosh = Acosh;
+ exports.AdadeltaOptimizer = AdadeltaOptimizer;
+ exports.AdagradOptimizer = AdagradOptimizer;
+ exports.AdamOptimizer = AdamOptimizer;
+ exports.AdamaxOptimizer = AdamaxOptimizer;
+ exports.Add = Add;
+ exports.AddN = AddN;
+ exports.All = All;
+ exports.Any = Any;
+ exports.ArgMax = ArgMax;
+ exports.ArgMin = ArgMin;
+ exports.Asin = Asin;
+ exports.Asinh = Asinh;
+ exports.Atan = Atan;
+ exports.Atan2 = Atan2;
+ exports.Atanh = Atanh;
+ exports.AvgPool = AvgPool;
+ exports.AvgPool3D = AvgPool3D;
+ exports.AvgPool3DGrad = AvgPool3DGrad;
+ exports.AvgPoolGrad = AvgPoolGrad;
+ exports.BatchMatMul = BatchMatMul;
+ exports.BatchToSpaceND = BatchToSpaceND;
+ exports.Bincount = Bincount;
+ exports.BroadcastArgs = BroadcastArgs;
+ exports.BroadcastTo = BroadcastTo;
+ exports.Cast = Cast;
+ exports.Ceil = Ceil;
+ exports.ClipByValue = ClipByValue;
+ exports.Complex = Complex;
+ exports.ComplexAbs = ComplexAbs;
+ exports.Concat = Concat;
+ exports.Conv2D = Conv2D;
+ exports.Conv2DBackpropFilter = Conv2DBackpropFilter;
+ exports.Conv2DBackpropInput = Conv2DBackpropInput;
+ exports.Conv3D = Conv3D;
+ exports.Conv3DBackpropFilterV2 = Conv3DBackpropFilterV2;
+ exports.Conv3DBackpropInputV2 = Conv3DBackpropInputV2;
+ exports.Cos = Cos;
+ exports.Cosh = Cosh;
+ exports.CropAndResize = CropAndResize;
+ exports.Cumsum = Cumsum;
+ exports.DataStorage = DataStorage;
+ exports.DenseBincount = DenseBincount;
+ exports.DepthToSpace = DepthToSpace;
+ exports.DepthwiseConv2dNative = DepthwiseConv2dNative;
+ exports.DepthwiseConv2dNativeBackpropFilter = DepthwiseConv2dNativeBackpropFilter;
+ exports.DepthwiseConv2dNativeBackpropInput = DepthwiseConv2dNativeBackpropInput;
+ exports.Diag = Diag;
+ exports.Dilation2D = Dilation2D;
+ exports.Dilation2DBackpropFilter = Dilation2DBackpropFilter;
+ exports.Dilation2DBackpropInput = Dilation2DBackpropInput;
+ exports.Einsum = Einsum;
+ exports.Elu = Elu;
+ exports.EluGrad = EluGrad;
+ exports.Environment = Environment;
+ exports.Equal = Equal;
+ exports.Erf = Erf;
+ exports.Exp = Exp;
+ exports.ExpandDims = ExpandDims;
+ exports.Expm1 = Expm1;
+ exports.FFT = FFT;
+ exports.Fill = Fill;
+ exports.FlipLeftRight = FlipLeftRight;
+ exports.Floor = Floor;
+ exports.FloorDiv = FloorDiv;
+ exports.FromPixels = FromPixels;
+ exports.FusedBatchNorm = FusedBatchNorm;
+ exports.FusedConv2D = FusedConv2D;
+ exports.FusedDepthwiseConv2D = FusedDepthwiseConv2D;
+ exports.GatherNd = GatherNd;
+ exports.GatherV2 = GatherV2;
+ exports.Greater = Greater;
+ exports.GreaterEqual = GreaterEqual;
+ exports.IFFT = IFFT;
+ exports.Identity = Identity;
+ exports.Imag = Imag;
+ exports.IsFinite = IsFinite;
+ exports.IsInf = IsInf;
+ exports.IsNan = IsNan;
+ exports.KernelBackend = KernelBackend;
+ exports.LRN = LRN;
+ exports.LRNGrad = LRNGrad;
+ exports.LeakyRelu = LeakyRelu;
+ exports.Less = Less;
+ exports.LessEqual = LessEqual;
+ exports.LinSpace = LinSpace;
+ exports.Log = Log;
+ exports.Log1p = Log1p;
+ exports.LogSoftmax = LogSoftmax;
+ exports.LogicalAnd = LogicalAnd;
+ exports.LogicalNot = LogicalNot;
+ exports.LogicalOr = LogicalOr;
+ exports.Max = Max;
+ exports.MaxPool = MaxPool;
+ exports.MaxPool3D = MaxPool3D;
+ exports.MaxPool3DGrad = MaxPool3DGrad;
+ exports.MaxPoolGrad = MaxPoolGrad;
+ exports.MaxPoolWithArgmax = MaxPoolWithArgmax;
+ exports.Maximum = Maximum;
+ exports.Mean = Mean;
+ exports.Min = Min;
+ exports.Minimum = Minimum;
+ exports.MirrorPad = MirrorPad;
+ exports.Mod = Mod;
+ exports.MomentumOptimizer = MomentumOptimizer;
+ exports.Multinomial = Multinomial;
+ exports.Multiply = Multiply;
+ exports.Neg = Neg;
+ exports.NonMaxSuppressionV3 = NonMaxSuppressionV3;
+ exports.NonMaxSuppressionV4 = NonMaxSuppressionV4;
+ exports.NonMaxSuppressionV5 = NonMaxSuppressionV5;
+ exports.NotEqual = NotEqual;
+ exports.OP_SCOPE_SUFFIX = OP_SCOPE_SUFFIX;
+ exports.OneHot = OneHot;
+ exports.OnesLike = OnesLike;
+ exports.Optimizer = Optimizer;
+ exports.OptimizerConstructors = OptimizerConstructors;
+ exports.Pack = Pack;
+ exports.PadV2 = PadV2;
+ exports.Pool = Pool;
+ exports.Pow = Pow;
+ exports.Prelu = Prelu;
+ exports.Prod = Prod;
+ exports.RMSPropOptimizer = RMSPropOptimizer;
+ exports.Range = Range;
+ exports.Real = Real;
+ exports.RealDiv = RealDiv;
+ exports.Reciprocal = Reciprocal;
+ exports.Relu = Relu;
+ exports.Relu6 = Relu6;
+ exports.Reshape = Reshape;
+ exports.ResizeBilinear = ResizeBilinear;
+ exports.ResizeBilinearGrad = ResizeBilinearGrad;
+ exports.ResizeNearestNeighbor = ResizeNearestNeighbor;
+ exports.ResizeNearestNeighborGrad = ResizeNearestNeighborGrad;
+ exports.Reverse = Reverse;
+ exports.RotateWithOffset = RotateWithOffset;
+ exports.Round = Round;
+ exports.Rsqrt = Rsqrt;
+ exports.SGDOptimizer = SGDOptimizer;
+ exports.ScatterNd = ScatterNd;
+ exports.Select = Select;
+ exports.Selu = Selu;
+ exports.Sigmoid = Sigmoid;
+ exports.Sign = Sign;
+ exports.Sin = Sin;
+ exports.Sinh = Sinh;
+ exports.Slice = Slice;
+ exports.Softmax = Softmax;
+ exports.Softplus = Softplus;
+ exports.SpaceToBatchND = SpaceToBatchND;
+ exports.SparseFillEmptyRows = SparseFillEmptyRows;
+ exports.SparseReshape = SparseReshape;
+ exports.SparseSegmentMean = SparseSegmentMean;
+ exports.SparseSegmentSum = SparseSegmentSum;
+ exports.SparseToDense = SparseToDense;
+ exports.SplitV = SplitV;
+ exports.Sqrt = Sqrt;
+ exports.Square = Square;
+ exports.SquaredDifference = SquaredDifference;
+ exports.Step = Step;
+ exports.StridedSlice = StridedSlice;
+ exports.StringNGrams = StringNGrams;
+ exports.StringSplit = StringSplit;
+ exports.StringToHashBucketFast = StringToHashBucketFast;
+ exports.Sub = Sub;
+ exports.Sum = Sum;
+ exports.Tan = Tan;
+ exports.Tanh = Tanh;
+ exports.Tensor = Tensor;
+ exports.TensorBuffer = TensorBuffer;
+ exports.Tile = Tile;
+ exports.TopK = TopK;
+ exports.Transform = Transform;
+ exports.Transpose = Transpose;
+ exports.Unique = Unique;
+ exports.Unpack = Unpack;
+ exports.UnsortedSegmentSum = UnsortedSegmentSum;
+ exports.Variable = Variable;
+ exports.ZerosLike = ZerosLike;
+ exports._FusedMatMul = _FusedMatMul;
+ exports.abs = abs;
+ exports.acos = acos;
+ exports.acosh = acosh;
+ exports.add = add;
+ exports.addN = addN;
+ exports.all = all;
+ exports.any = any;
+ exports.argMax = argMax;
+ exports.argMin = argMin;
+ exports.asin = asin;
+ exports.asinh = asinh;
+ exports.atan = atan;
+ exports.atan2 = atan2;
+ exports.atanh = atanh;
+ exports.avgPool = avgPool;
+ exports.avgPool3d = avgPool3d;
+ exports.backend = backend;
+ exports.backend_util = backend_util;
+ exports.basicLSTMCell = basicLSTMCell;
+ exports.batchNorm = batchNorm;
+ exports.batchNorm2d = batchNorm2d;
+ exports.batchNorm3d = batchNorm3d;
+ exports.batchNorm4d = batchNorm4d;
+ exports.batchToSpaceND = batchToSpaceND;
+ exports.bincount = bincount;
+ exports.booleanMaskAsync = booleanMaskAsync;
+ exports.broadcastArgs = broadcastArgs;
+ exports.broadcastTo = broadcastTo;
+ exports.broadcast_util = broadcast_util;
+ exports.browser = browser;
+ exports.buffer = buffer;
+ exports.cast = cast;
+ exports.ceil = ceil;
+ exports.clipByValue = clipByValue;
+ exports.clone = clone;
+ exports.complex = complex;
+ exports.concat = concat;
+ exports.concat1d = concat1d;
+ exports.concat2d = concat2d;
+ exports.concat3d = concat3d;
+ exports.concat4d = concat4d;
+ exports.conv1d = conv1d;
+ exports.conv2d = conv2d$1;
+ exports.conv2dTranspose = conv2dTranspose;
+ exports.conv3d = conv3d;
+ exports.conv3dTranspose = conv3dTranspose;
+ exports.copyRegisteredKernels = copyRegisteredKernels;
+ exports.cos = cos;
+ exports.cosh = cosh;
+ exports.cosineWindow = cosineWindow;
+ exports.cumsum = cumsum;
+ exports.customGrad = customGrad;
+ exports.denseBincount = denseBincount;
+ exports.deprecationWarn = deprecationWarn;
+ exports.depthToSpace = depthToSpace;
+ exports.depthwiseConv2d = depthwiseConv2d$1;
+ exports.device_util = device_util;
+ exports.diag = diag;
+ exports.dilation2d = dilation2d;
+ exports.disableDeprecationWarnings = disableDeprecationWarnings;
+ exports.dispose = dispose;
+ exports.disposeVariables = disposeVariables;
+ exports.div = div;
+ exports.divNoNan = divNoNan;
+ exports.dot = dot;
+ exports.dropout = dropout;
+ exports.einsum = einsum;
+ exports.elu = elu;
+ exports.enableDebugMode = enableDebugMode;
+ exports.enableProdMode = enableProdMode;
+ exports.enclosingPowerOfTwo = enclosingPowerOfTwo;
+ exports.engine = engine;
+ exports.env = env;
+ exports.equal = equal;
+ exports.erf = erf;
+ exports.exp = exp;
+ exports.expandDims = expandDims;
+ exports.expm1 = expm1;
+ exports.eye = eye;
+ exports.fft = fft;
+ exports.fill = fill;
+ exports.findBackend = findBackend;
+ exports.findBackendFactory = findBackendFactory;
+ exports.floor = floor;
+ exports.floorDiv = floorDiv;
+ exports.fused = fused_ops;
+ exports.gather = gather;
+ exports.gatherND = gatherND;
+ exports.gather_util = gather_nd_util;
+ exports.getBackend = getBackend;
+ exports.getGradient = getGradient;
+ exports.getKernel = getKernel;
+ exports.getKernelsForBackend = getKernelsForBackend;
+ exports.grad = grad;
+ exports.grads = grads;
+ exports.greater = greater;
+ exports.greaterEqual = greaterEqual;
+ exports.ifft = ifft;
+ exports.imag = imag;
+ exports.image = image;
+ exports.inTopKAsync = inTopKAsync;
+ exports.io = io;
+ exports.irfft = irfft;
+ exports.isFinite = isFinite$1;
+ exports.isInf = isInf;
+ exports.isNaN = isNaN$1;
+ exports.keep = keep;
+ exports.kernel_impls = kernel_impls;
+ exports.leakyRelu = leakyRelu;
+ exports.less = less;
+ exports.lessEqual = lessEqual;
+ exports.linalg = linalg;
+ exports.linspace = linspace;
+ exports.localResponseNormalization = localResponseNormalization;
+ exports.log = log;
+ exports.log1p = log1p;
+ exports.logSigmoid = logSigmoid;
+ exports.logSoftmax = logSoftmax;
+ exports.logSumExp = logSumExp;
+ exports.logicalAnd = logicalAnd;
+ exports.logicalNot = logicalNot;
+ exports.logicalOr = logicalOr;
+ exports.logicalXor = logicalXor;
+ exports.losses = losses;
+ exports.matMul = matMul$1;
+ exports.math = math;
+ exports.max = max;
+ exports.maxPool = maxPool;
+ exports.maxPool3d = maxPool3d;
+ exports.maxPoolWithArgmax = maxPoolWithArgmax;
+ exports.maximum = maximum;
+ exports.mean = mean;
+ exports.memory = memory;
+ exports.meshgrid = meshgrid;
+ exports.min = min;
+ exports.minimum = minimum;
+ exports.mirrorPad = mirrorPad;
+ exports.mod = mod;
+ exports.moments = moments;
+ exports.movingAverage = movingAverage;
+ exports.mul = mul;
+ exports.multiRNNCell = multiRNNCell;
+ exports.multinomial = multinomial;
+ exports.neg = neg;
+ exports.nextFrame = nextFrame;
+ exports.norm = norm;
+ exports.notEqual = notEqual;
+ exports.oneHot = oneHot;
+ exports.ones = ones;
+ exports.onesLike = onesLike;
+ exports.op = op;
+ exports.outerProduct = outerProduct;
+ exports.pad = pad;
+ exports.pad1d = pad1d;
+ exports.pad2d = pad2d;
+ exports.pad3d = pad3d;
+ exports.pad4d = pad4d;
+ exports.pool = pool;
+ exports.pow = pow;
+ exports.prelu = prelu;
+ exports.print = print;
+ exports.prod = prod;
+ exports.profile = profile;
+ exports.rand = rand;
+ exports.randomGamma = randomGamma;
+ exports.randomNormal = randomNormal;
+ exports.randomUniform = randomUniform;
+ exports.range = range;
+ exports.ready = ready;
+ exports.real = real;
+ exports.reciprocal = reciprocal;
+ exports.registerBackend = registerBackend;
+ exports.registerGradient = registerGradient;
+ exports.registerKernel = registerKernel;
+ exports.relu = relu;
+ exports.relu6 = relu6;
+ exports.removeBackend = removeBackend;
+ exports.reshape = reshape;
+ exports.reverse = reverse;
+ exports.reverse1d = reverse1d;
+ exports.reverse2d = reverse2d;
+ exports.reverse3d = reverse3d;
+ exports.reverse4d = reverse4d;
+ exports.rfft = rfft;
+ exports.round = round;
+ exports.rsqrt = rsqrt;
+ exports.scalar = scalar;
+ exports.scatterND = scatterND;
+ exports.scatter_util = scatter_nd_util;
+ exports.selu = selu;
+ exports.separableConv2d = separableConv2d;
+ exports.serialization = serialization;
+ exports.setBackend = setBackend;
+ exports.setPlatform = setPlatform;
+ exports.setdiff1dAsync = setdiff1dAsync;
+ exports.sigmoid = sigmoid;
+ exports.sign = sign;
+ exports.signal = signal;
+ exports.sin = sin;
+ exports.sinh = sinh;
+ exports.slice = slice;
+ exports.slice1d = slice1d;
+ exports.slice2d = slice2d;
+ exports.slice3d = slice3d;
+ exports.slice4d = slice4d;
+ exports.slice_util = slice_util;
+ exports.softmax = softmax;
+ exports.softplus = softplus;
+ exports.spaceToBatchND = spaceToBatchND;
+ exports.sparse = sparse;
+ exports.sparseToDense = sparseToDense;
+ exports.spectral = spectral;
+ exports.split = split;
+ exports.sqrt = sqrt;
+ exports.square = square;
+ exports.squaredDifference = squaredDifference;
+ exports.squeeze = squeeze;
+ exports.stack = stack;
+ exports.step = step;
+ exports.stridedSlice = stridedSlice;
+ exports.string = string;
+ exports.sub = sub;
+ exports.sum = sum;
+ exports.sumOutType = sumOutType;
+ exports.tan = tan;
+ exports.tanh = tanh;
+ exports.tensor = tensor;
+ exports.tensor1d = tensor1d;
+ exports.tensor2d = tensor2d;
+ exports.tensor3d = tensor3d;
+ exports.tensor4d = tensor4d;
+ exports.tensor5d = tensor5d;
+ exports.tensor6d = tensor6d;
+ exports.tensor_util = tensor_util;
+ exports.test_util = test_util;
+ exports.tidy = tidy;
+ exports.tile = tile;
+ exports.time = time;
+ exports.topk = topk;
+ exports.train = train;
+ exports.transpose = transpose;
+ exports.truncatedNormal = truncatedNormal;
+ exports.unique = unique;
+ exports.unregisterGradient = unregisterGradient;
+ exports.unregisterKernel = unregisterKernel;
+ exports.unsortedSegmentSum = unsortedSegmentSum;
+ exports.unstack = unstack;
+ exports.upcastType = upcastType;
+ exports.util = util;
+ exports.valueAndGrad = valueAndGrad;
+ exports.valueAndGrads = valueAndGrads;
+ exports.variable = variable;
+ exports.variableGrads = variableGrads;
+ exports.version_core = version;
+ exports.where = where;
+ exports.whereAsync = whereAsync;
+ exports.zeros = zeros;
+ exports.zerosLike = zerosLike;
+
+ Object.defineProperty(exports, '__esModule', { value: true });
+
+})));
+//# sourceMappingURL=tf-core.js.map