diff --git a/js/Overlay.ts b/js/Overlay.ts index 0f83978..36d757e 100644 --- a/js/Overlay.ts +++ b/js/Overlay.ts @@ -16,10 +16,13 @@ export class Overlay { annotate?: (renderer: Renderer) => void; legend?: Legend; anchors?: d3.Selection; + #dataset: string; - constructor(public renderer: Renderer) { - this.figure = d3.select(renderer.gl.canvas.parentNode as HTMLElement); + constructor(public renderer: Renderer, dataset: string) { + let canvas = renderer.gl.canvas as HTMLCanvasElement; + this.figure = d3.select(canvas.parentNode as HTMLElement); let self = this; + this.#dataset = dataset; this.epochSlider = this.figure .insert("input", ":first-child") @@ -162,15 +165,17 @@ export class Overlay { } get width() { - return this.renderer.gl.canvas.clientWidth; + let canvas = this.renderer.gl.canvas as HTMLCanvasElement; + return canvas.clientWidth; } get height() { - return this.renderer.gl.canvas.clientHeight; + let canvas = this.renderer.gl.canvas as HTMLCanvasElement; + return canvas.clientHeight; } get dataset() { - return utils.getDataset() as keyof typeof utils.legendTitle; + return this.#dataset; } resize() { @@ -201,8 +206,8 @@ export class Overlay { } } - init() { - this.initLegend(); + init(legendData: [string, d3.RGBColor][]) { + this.initLegend(legendData); this.resize(); this.drawAxes(); if (this.annotate !== undefined) { @@ -295,17 +300,13 @@ export class Overlay { }); } - initLegend() { - let data = utils.zip( - utils.getLabelNames(false, this.dataset), - utils.baseColors, - ); - this.legend = new Legend(data, { + initLegend(legendData: [string, d3.RGBColor][]) { + this.legend = new Legend(legendData, { root: this.svg, - title: utils.legendTitle[this.dataset], + title: utils.legendTitle[this.dataset as keyof typeof utils.legendTitle], margin: { - left: utils.legendLeft[this.dataset], - right: utils.legendRight[this.dataset], + left: utils.legendLeft[this.dataset as keyof typeof utils.legendLeft], + right: utils.legendRight[this.dataset as keyof typeof utils.legendRight], }, }); this.legend.on("select", (classes) => { diff --git a/js/Renderer.ts b/js/Renderer.ts index c3168d6..9dcf6ca 100644 --- a/js/Renderer.ts +++ b/js/Renderer.ts @@ -18,6 +18,7 @@ interface Data { alphas: number[]; points?: number[][]; colors?: ColorRGBA[]; + hexColors: string[]; } interface RendererOptions { @@ -84,15 +85,17 @@ export class Renderer { constructor( public gl: WebGLRenderingContext, public program: WebGLProgram, + dataset: string, opts: Partial = {}, ) { - this.id = gl.canvas.id; + let canvas = gl.canvas as HTMLCanvasElement; + this.id = canvas.id; this.epochs = opts.epochs ?? [0]; this.epochIndex = opts.epochIndex ?? this.epochs[0]; this.shouldAutoNextEpoch = opts.shouldAutoNextEpoch ?? true; this.shouldPlayGrandTour = opts.shouldPlayGrandTour ?? true; this.pointSize0 = this.#pointSize = opts.pointSize ?? 6.0; - this.overlay = new Overlay(this); + this.overlay = new Overlay(this, dataset); this.sx_span = d3.scaleLinear(); this.sy_span = d3.scaleLinear(); @@ -105,53 +108,69 @@ export class Renderer { this.sz = this.sz_center; } - async initData(buffer: ArrayBuffer) { + async initData( + buffer: ArrayBuffer, + axisFields: string[], + labelField: string, + labelColors?: Object, + ) { + // Load the table let table = arrow.tableFromIPC(buffer); - let ndim = 5; - let nepoch = 1; + let ndim = axisFields.length; + + // Categories and colors + let labelColumn = table.getChild(labelField); + let categories: string[] = labelColumn ? Array.from(new Set(labelColumn.toArray())) : []; + categories = categories.sort(); + let cat_to_int = Object.fromEntries(categories.map((name, i) => [name, i])); + + let hexColors; + if (labelColors) { + categories = Object.keys(labelColors); + hexColors = Object.values(labelColors); + } else { + hexColors = d3.schemeCategory10; + } + let rgbColors = hexColors.map((c) => d3.rgb(c)!); + let legendData = utils.zip(categories, rgbColors); + // Point data let labels = []; - let arr = []; - - let fields = d3.range(ndim).map((i) => "E" + i); - let labelMapping = Object.fromEntries( - ["A0", "A1", "B0", "B1", "B2"].map((name, i) => [name, i]), - ); - + let coords = []; for (let row of utils.iterN(table, 10)) { - labels.push(labelMapping[row.name]); - for (let field of fields) arr.push(row[field]); + labels.push(cat_to_int[row[labelField]]); + for (let field of axisFields) coords.push(row[field]); } + let nepoch = 1; + let n_points = labels.length; + let shape: [number, number, number] = [nepoch, n_points, ndim]; + let dataTensor = utils.reshape(new Float32Array(coords), shape); - let npoint = labels.length; - let shape: [number, number, number] = [nepoch, npoint, ndim]; - let dataTensor = utils.reshape(new Float32Array(arr), shape); - - this.shouldRecalculateColorRect = true; - + // Store the data this.dataObj = { labels, dataTensor, dmax: 1.05 * math.max( math.abs(dataTensor[dataTensor.length - 1]), ), - dimLabels: fields, + dimLabels: axisFields, ndim, - npoint, + npoint: n_points, nepoch, - alphas: Array.from({ length: npoint + 5 * npoint }, () => 255), + alphas: Array.from({ length: n_points + 5 * n_points }, () => 255), + hexColors }; + // Initialize the display and overlay + this.shouldRecalculateColorRect = true; this.initGL(this.dataObj); - if (this.isPlaying === undefined) { // renderer.isPlaying===undefined indicates the renderer on init // otherwise it is reloading other dataset this.isPlaying = true; this.play(); - this.overlay.init(); + this.overlay.init(legendData); } - if ( (this.animId == null || this.shouldRender == false) ) { @@ -163,7 +182,7 @@ export class Renderer { setFullScreen(shouldSet: boolean) { this.isFullScreen = shouldSet; - let canvas = this.gl.canvas; + let canvas = this.gl.canvas as HTMLCanvasElement; d3.select(canvas.parentNode as HTMLElement) .classed("fullscreen", shouldSet); @@ -187,7 +206,8 @@ export class Renderer { } initGL(dataObj: Data) { - utils.resizeCanvas(this.gl.canvas); + let canvas = this.gl.canvas as HTMLCanvasElement; + utils.resizeCanvas(canvas); this.gl.viewport(0, 0, this.gl.canvas.width, this.gl.canvas.height); this.gl.clearColor(...this.#clearColor); @@ -225,8 +245,8 @@ export class Renderer { this.program, "canvasHeight", )!; - this.gl.uniform1f(this.canvasWidthLoc, this.gl.canvas.clientWidth); - this.gl.uniform1f(this.canvasHeightLoc, this.gl.canvas.clientHeight); + this.gl.uniform1f(this.canvasWidthLoc, canvas.clientWidth); + this.gl.uniform1f(this.canvasHeightLoc, canvas.clientHeight); this.modeLoc = this.gl.getUniformLocation(this.program, "mode")!; this.gl.uniform1i(this.modeLoc, 0); // "point" mode @@ -342,6 +362,7 @@ export class Renderer { render(dt: number) { if (!this.dataObj || !this.gt) return; + let canvas = this.gl.canvas as HTMLCanvasElement; let dataObj = this.dataObj; let data = this.dataObj.dataTensor[this.epochIndex]; let labels = this.dataObj.labels; @@ -365,23 +386,23 @@ export class Renderer { utils.updateScaleCenter( points, - this.gl.canvas, + canvas, this.sx_center, this.sy_center, this.sz_center, this.scaleFactor, - utils.legendLeft[this.overlay.dataset] + 15, + utils.legendLeft[this.overlay.dataset as keyof typeof utils.legendLeft] + 15, 65, ); utils.updateScaleSpan( points, - this.gl.canvas, + canvas, this.sx_span, this.sy_span, this.sz_span, this.scaleFactor, - utils.legendLeft[this.overlay.dataset] + 15, + utils.legendLeft[this.overlay.dataset as keyof typeof utils.legendLeft] + 15, 65, ); @@ -410,9 +431,10 @@ export class Renderer { dataObj.points = points; - let bgColors = labels.map((d) => utils.bgColors[d]); + let rgbColors = dataObj.hexColors.map((c) => d3.rgb(c)!); + let modifiedColors = utils.modifyColors(rgbColors); let colors: ColorRGBA[] = labels - .map((d) => utils.baseColors[d]) + .map((i) => rgbColors[i]) .concat(utils.createAxisColors(dataObj.ndim)) .map((c, i) => [c.r, c.g, c.b, dataObj.alphas[i]]); @@ -446,7 +468,7 @@ export class Renderer { this.gl.bufferData( this.gl.ARRAY_BUFFER, new Uint8Array( - bgColors.map((c) => [c.r, c.g, c.b, utils.pointAlpha]).flat(), + modifiedColors.map((c: { r: number; g: number; b: number; }) => [c.r, c.g, c.b, utils.pointAlpha]).flat(), ), this.gl.STATIC_DRAW, ); diff --git a/js/index.ts b/js/index.ts deleted file mode 100644 index 9c571a1..0000000 --- a/js/index.ts +++ /dev/null @@ -1,28 +0,0 @@ -import { Renderer } from "./Renderer"; -import * as utils from "./utils"; -import fs from "./shaders/teaser_fragment.glsl"; -import vs from "./shaders/teaser_vertex.glsl"; - -async function main() { - let canvas = document.querySelector("canvas")!; - let { gl, program } = utils.initGL(canvas, fs, vs); - let renderer = new Renderer(gl, program); - - renderer.overlay.fullScreenButton.style("top", "18px"); - renderer.overlay.epochSlider.style("top", "calc(100% - 28px)"); - renderer.overlay.playButton.style("top", "calc(100% - 34px)"); - renderer.overlay.grandtourButton.style("top", "calc(100% - 34px)"); - - { - let clearBanner = utils.createLoadingBanner(renderer.overlay.figure); - let res = await fetch(new URL("../data/eigs.arrow", import.meta.url)) - await renderer.initData(await res.arrayBuffer()); - clearBanner(); - } - - window.addEventListener("resize", () => { - renderer.setFullScreen(renderer.isFullScreen); - }); -} - -main(); diff --git a/js/utils.ts b/js/utils.ts index 7a2b984..b708365 100644 --- a/js/utils.ts +++ b/js/utils.ts @@ -156,22 +156,22 @@ export function embed(matrix: T[][], canvas: T[][]) { return canvas; } -export function getDataset() { - return dataset; -} - -export function getLabelNames(_adversarial = false, dataset?: string) { - if (dataset === undefined) { - dataset = getDataset(); - } - let res; - if (dataset == "mnist") { - res = ["A0", "A1", "B0", "B1", "B2"]; - } else { - throw new Error("Unrecognized dataset " + dataset); - } - return res; -} +// export function getDataset() { +// return dataset; +// } + +// export function getLabelNames(_adversarial = false, dataset?: string) { +// if (dataset === undefined) { +// dataset = getDataset(); +// } +// let res; +// if (dataset == "mnist") { +// res = ["A0", "A1", "B0", "B1", "B2"]; +// } else { +// throw new Error("Unrecognized dataset " + dataset); +// } +// return res; +// } export function initGL(canvas: HTMLCanvasElement, fs: string, vs: string) { let gl = canvas.getContext("webgl", { premultipliedAlpha: false })!; @@ -272,7 +272,19 @@ export const baseColors = d3.schemeCategory10.map((c) => d3.rgb(c)!); export const bgColors = numeric.add( numeric.mul(baseColors.map((c) => [c.r, c.g, c.b]), 0.6), 0.95 * 255 * 0.4, -).map((c) => d3.rgb(...c as [number, number, number])); +).map((c: [number, number, number]) => d3.rgb(...c as [number, number, number])); + +export function modifyColors(colorList: d3.RGBColor[]) { + // multiply each RGB channel by 0.6, + // then add 0.4 * (0.95 * 255) + return numeric.add( + numeric.mul( + colorList.map((c) => [c.r, c.g, c.b]), + 0.6, + ), + 0.95 * 255 * 0.4, + ); +} export function createAxisPoints(ndim: number) { let res = (math.identity(ndim) as math.Matrix).toArray(); @@ -306,7 +318,6 @@ export function orthogonalize( // make row vectors in matrix pairwise orthogonal; function proj(u: M[number], v: M[number]): M[number] { - // @ts-expect-error return numeric.mul(numeric.dot(u, v) / numeric.dot(u, u), u); } @@ -314,7 +325,6 @@ export function orthogonalize( if (numeric.norm2(v) <= 0) { return v; } else { - // @ts-expect-error return numeric.div(v, numeric.norm2(v) / unitlength); } } @@ -328,7 +338,6 @@ export function orthogonalize( matrix[0] = normalize(matrix[0]); for (let i = 1; i < matrix.length; i++) { for (let j = 0; j < i; j++) { - // @ts-expect-error matrix[i] = numeric.sub(matrix[i], proj(matrix[j], matrix[i])); } matrix[i] = normalize(matrix[i]); @@ -436,13 +445,13 @@ export function zip(a: A[], b: B[]): [A, B][] { return out; } -export function loadScript(url: string): Promise { - return new Promise((resolve, reject) => { - let script = document.createElement("script"); - script.type = "text/javascript"; - script.src = url; - script.onload = () => resolve(); - script.onerror = (err) => reject(err); - document.head.appendChild(script); - }); -} +// export function loadScript(url: string): Promise { +// return new Promise((resolve, reject) => { +// let script = document.createElement("script"); +// script.type = "text/javascript"; +// script.src = url; +// script.onload = () => resolve(); +// script.onerror = (err) => reject(err); +// document.head.appendChild(script); +// }); +// } diff --git a/js/widget.ts b/js/widget.ts index f310c13..3045ccc 100644 --- a/js/widget.ts +++ b/js/widget.ts @@ -8,33 +8,45 @@ import vs from "./shaders/teaser_vertex.glsl"; import "./widget.css"; interface Model { - data: DataView + data: DataView; + axis_fields: string[]; + label_field: string; + label_colors: string[]; } -const template = ` +const TEMPLATE = ` ` export default { async render({ model, el }: RenderProps) { - el.innerHTML = template; + el.innerHTML = TEMPLATE; let canvas = el.querySelector("canvas")!; + + // Compile the fragment and vertex shaders console.log("initGL"); let { gl, program } = utils.initGL(canvas, fs, vs); - let renderer = new Renderer(gl, program); + + // Create the renderer + console.log("Create renderer"); + let renderer = new Renderer(gl, program, "mnist"); renderer.overlay.fullScreenButton.style("top", "18px"); renderer.overlay.epochSlider.style("top", "calc(100% - 28px)"); renderer.overlay.playButton.style("top", "calc(100% - 34px)"); renderer.overlay.grandtourButton.style("top", "calc(100% - 34px)"); - console.log("model", model); + // Load the data { - // load the data let clearBanner = utils.createLoadingBanner(renderer.overlay.figure); - console.log("loading data"); - await renderer.initData(model.get("data").buffer); - console.log("data loaded"); + console.log("Loading data..."); + await renderer.initData( + model.get("data").buffer, + model.get("axis_fields"), + model.get("label_field"), + model.get("label_colors"), + ); + console.log("Data loaded"); clearBanner(); } @@ -42,6 +54,8 @@ export default { renderer.setFullScreen(renderer.isFullScreen); } window.addEventListener("resize", onResize); + + // Return a cleanup function return () => { window.removeEventListener("resize", onResize); }