@@ -4,7 +4,7 @@ import { Cross2Icon } from '@radix-ui/react-icons';
44import type * as ort from 'onnxruntime-web' ;
55
66import './CanvasBoard.css' ;
7- import { argMax , initOnnx , runInference , MNIST_IMAGE_SIDE_SIZE , getNumberColor } from 'utils/mnist' ;
7+ import { initOnnx , runInference , MNIST_IMAGE_SIDE_SIZE , getNumberColor } from 'utils/mnist' ;
88import DonutChart from 'components/charts/DonutChart' ;
99import type { PieSeriesOption } from 'echarts' ;
1010
@@ -44,7 +44,7 @@ function CanvasBoard() {
4444 }
4545
4646 ctx . beginPath ( ) ;
47- ctx . lineWidth = 30 ;
47+ ctx . lineWidth = 36 ;
4848 ctx . lineCap = 'round' ;
4949 ctx . strokeStyle = 'white' ;
5050 ctx . moveTo ( pos . x , pos . y ) ;
@@ -108,13 +108,11 @@ function CanvasBoard() {
108108 } ;
109109
110110 const donutChartData : PieSeriesOption [ 'data' ] = useMemo ( ( ) => {
111- const highestIndex = argMax ( inferenceList ) ;
112111 const data : PieSeriesOption [ 'data' ] = Array . from ( inferenceList )
113112 . map ( ( d , i ) => ( {
114113 name : i . toString ( ) ,
115- value : 10 - Math . abs ( d ) ,
114+ value : Math . exp ( d ) , // pytorch/mnist.py uses log_softmax, so convert it back to linear from log
116115 itemStyle : { color : getNumberColor ( i ) } ,
117- selected : i == highestIndex ,
118116 } ) )
119117 . sort ( ( a , b ) => b . value - a . value ) ;
120118 return data ;
0 commit comments