chore: fix distance algo
8d8d37cc
4 file(s) · +143 −15
| 1 | + | import type { Profile, DigraphAggregation, MetricStats } from './src/lib/types'; |
|
| 2 | + | ||
| 3 | + | const METRIC_KEYS = [ |
|
| 4 | + | 'holdTime1', |
|
| 5 | + | 'holdTime2', |
|
| 6 | + | 'pressPress', |
|
| 7 | + | 'releaseRelease', |
|
| 8 | + | 'pressRelease', |
|
| 9 | + | 'releasePress', |
|
| 10 | + | ] as const; |
|
| 11 | + | ||
| 12 | + | const MAX_METRIC_DISTANCE = 3.0; |
|
| 13 | + | const MIN_STD_FALLBACK = 15.0; |
|
| 14 | + | ||
/**
 * Remove outliers from `values` using Tukey's IQR fences.
 *
 * A value is kept when it lies inside `[Q1 - k * IQR, Q3 + k * IQR]`.
 * Quartiles are linearly interpolated between the two nearest order
 * statistics (position `(n - 1) * p`), so the lower and upper quartile are
 * estimated symmetrically. Picking `sorted[floor(n * p)]` instead makes Q3
 * equal to the maximum for 4 samples, which means a high outlier could
 * never be rejected.
 *
 * @param values Samples to filter; the array is not mutated.
 * @param k Fence multiplier applied to the inter-quartile range.
 * @returns The values inside the fences, in their original order. Inputs
 *   with fewer than 4 samples are returned unchanged (same array instance).
 */
function filterOutliers(values: number[], k = 1.5): number[] {
  // Too few samples to estimate quartiles meaningfully.
  if (values.length < 4) return values;

  const sorted = [...values].sort((a, b) => a - b);

  // Interpolated sample quantile at probability p (0..1).
  const quantile = (p: number): number => {
    const pos = (sorted.length - 1) * p;
    const lo = Math.floor(pos);
    const hi = Math.ceil(pos);
    if (lo === hi) return sorted[lo];
    return sorted[lo] + (sorted[hi] - sorted[lo]) * (pos - lo);
  };

  const q1 = quantile(0.25);
  const q3 = quantile(0.75);
  const iqr = q3 - q1;
  const lower = q1 - k * iqr;
  const upper = q3 + k * iqr;
  return values.filter((v) => v >= lower && v <= upper);
}
|
| 25 | + | ||
/**
 * Summarise a list of samples as mean, sample standard deviation, min, max
 * and count.
 *
 * `mean`, `std`, `min` and `max` are rounded to one decimal place. The
 * standard deviation uses the `n - 1` denominator and is 0 for a single
 * sample. An empty input yields an all-zero result.
 *
 * Min and max are found in a loop rather than with `Math.min(...values)`:
 * spreading a very large array into call arguments can exceed the engine's
 * argument/stack limit and throw a RangeError.
 *
 * @param values Samples to summarise; the array is not mutated.
 * @returns The rounded summary statistics.
 */
function computeMetricStats(values: number[]): MetricStats {
  const count = values.length;
  if (count === 0) return { mean: 0, std: 0, min: 0, max: 0, count: 0 };

  let sum = 0;
  let min = Infinity;
  let max = -Infinity;
  for (const v of values) {
    sum += v;
    // Math.min/Math.max (rather than < / >) keep NaN propagation intact.
    min = Math.min(min, v);
    max = Math.max(max, v);
  }

  const mean = sum / count;
  const variance =
    count < 2 ? 0 : values.reduce((acc, v) => acc + (v - mean) ** 2, 0) / (count - 1);

  const round1 = (x: number): number => Math.round(x * 10) / 10;
  return {
    mean: round1(mean),
    std: round1(Math.sqrt(variance)),
    min: round1(min),
    max: round1(max),
    count,
  };
}
|
| 39 | + | ||
// Side-by-side comparison of the OLD and NEW distance formulas on two stored profiles.
// The profile files only hold pre-aggregated statistics (no raw samples), so the
// aggregations are used as-is; this script only demonstrates the new comparison logic.

/**
 * Convert an average distance into a 0–100 similarity percentage.
 * A distance of MAX_METRIC_DISTANCE or more maps to 0%.
 */
function toSimilarityPct(distance: number): number {
  return Math.max(0, Math.round(100 * (1 - distance / MAX_METRIC_DISTANCE)));
}

const profile1: Profile = await Bun.file('1.json').json();
const profile2: Profile = await Bun.file('2.json').json();

console.log('=== PROFILE OVERVIEW ===');
console.log(`Profile 1: "${profile1.name}" — ${profile1.aggregations.length} digraph types`);
console.log(`Profile 2: "${profile2.name}" — ${profile2.aggregations.length} digraph types`);

// Index each profile's aggregations by digraph key (a later duplicate overwrites an earlier one).
const map1 = new Map<string, DigraphAggregation>();
for (const agg of profile1.aggregations) map1.set(agg.normalizedKeys, agg);
const map2 = new Map<string, DigraphAggregation>();
for (const agg of profile2.aggregations) map2.set(agg.normalizedKeys, agg);

// Pair up the digraphs present in both profiles, in profile 1's order.
// Profile 1 plays the role of the session, profile 2 the reference profile.
const sharedPairs: { keys: string; session: DigraphAggregation; profile: DigraphAggregation }[] = [];
for (const [keys, session] of map1) {
  const profile = map2.get(keys);
  if (profile !== undefined) sharedPairs.push({ keys, session, profile });
}
console.log(`Shared digraph types: ${sharedPairs.length}\n`);

// OLD algorithm: divide by the profile std, falling back to 10% of the mean
// (at least 1.0) when the std is effectively zero; the distance is uncapped.
let oldTotal = 0;
let oldMetricCount = 0;
for (const { session, profile } of sharedPairs) {
  for (const metric of METRIC_KEYS) {
    const s = session[metric];
    const p = profile[metric];
    const divisor = p.std < 0.001 ? Math.max(Math.abs(p.mean) * 0.1, 1.0) : p.std;
    oldTotal += Math.abs(s.mean - p.mean) / divisor;
    oldMetricCount++;
  }
}
const oldOverall = oldMetricCount > 0 ? oldTotal / oldMetricCount : 0;
const oldSim = toSimilarityPct(oldOverall);

// NEW algorithm: the divisor is floored at MIN_STD_FALLBACK and each metric's
// distance is capped at MAX_METRIC_DISTANCE.
let newTotal = 0;
const perDigraph: { keys: string; avgDist: number; details: string[] }[] = [];

for (const { keys, session, profile } of sharedPairs) {
  let digraphTotal = 0;
  const details: string[] = [];
  for (const metric of METRIC_KEYS) {
    const s = session[metric];
    const p = profile[metric];
    const divisor = Math.max(p.std, MIN_STD_FALLBACK);
    const dist = Math.min(Math.abs(s.mean - p.mean) / divisor, MAX_METRIC_DISTANCE);
    details.push(` ${metric.padEnd(16)} s=${String(s.mean).padStart(8)} p=${String(p.mean).padStart(8)} div=${divisor.toFixed(1).padStart(6)} → ${dist.toFixed(2)}${dist >= MAX_METRIC_DISTANCE ? ' (capped)' : ''}`);
    digraphTotal += dist;
  }
  newTotal += digraphTotal;
  perDigraph.push({ keys, avgDist: digraphTotal / METRIC_KEYS.length, details });
}

// Average over every (digraph, metric) pair.
const newOverall =
  sharedPairs.length > 0 ? newTotal / (sharedPairs.length * METRIC_KEYS.length) : 0;
const newSim = toSimilarityPct(newOverall);

console.log('=== COMPARISON ===');
console.log(`OLD: distance=${oldOverall.toFixed(2)}, similarity=${oldSim}%`);
console.log(`NEW: distance=${newOverall.toFixed(2)}, similarity=${newSim}%`);

// Largest average distance first.
perDigraph.sort((a, b) => b.avgDist - a.avgDist);

console.log('\nTop 10 worst (new algorithm):');
for (const d of perDigraph.slice(0, 10)) {
  console.log(`\n${d.keys} — dist: ${d.avgDist.toFixed(2)}, match: ${toSimilarityPct(d.avgDist)}%`);
  for (const line of d.details) console.log(line);
}

console.log('\n\nTop 5 best (new algorithm):');
for (const d of perDigraph.slice(-5).reverse()) {
  console.log(`${d.keys} — dist: ${d.avgDist.toFixed(2)}, match: ${toSimilarityPct(d.avgDist)}%`);
}
| 1 | 1 | import type { RawDigraph, MetricStats, DigraphAggregation } from './types'; |
|
| 2 | - | import { std } from './utils'; |
|
| 2 | + | import { std, filterOutliers } from './utils'; |
|
| 3 | 3 | ||
| 4 | 4 | function normalizeKey(key: string): string { |
|
| 5 | 5 | if (key === ' ') return '␣'; |
|
| 68 | 68 | const agg: DigraphAggregation = { |
|
| 69 | 69 | normalizedKeys, |
|
| 70 | 70 | count: group.length, |
|
| 71 | - | holdTime1: computeMetricStats(group.map((d) => d.holdTime1)), |
|
| 72 | - | holdTime2: computeMetricStats(group.map((d) => d.holdTime2)), |
|
| 73 | - | pressPress: computeMetricStats(group.map((d) => d.pressPress)), |
|
| 74 | - | releaseRelease: computeMetricStats(group.map((d) => d.releaseRelease)), |
|
| 75 | - | pressRelease: computeMetricStats(group.map((d) => d.pressRelease)), |
|
| 76 | - | releasePress: computeMetricStats(group.map((d) => d.releasePress)), |
|
| 71 | + | holdTime1: computeMetricStats(filterOutliers(group.map((d) => d.holdTime1))), |
|
| 72 | + | holdTime2: computeMetricStats(filterOutliers(group.map((d) => d.holdTime2))), |
|
| 73 | + | pressPress: computeMetricStats(filterOutliers(group.map((d) => d.pressPress))), |
|
| 74 | + | releaseRelease: computeMetricStats(filterOutliers(group.map((d) => d.releaseRelease))), |
|
| 75 | + | pressRelease: computeMetricStats(filterOutliers(group.map((d) => d.pressRelease))), |
|
| 76 | + | releasePress: computeMetricStats(filterOutliers(group.map((d) => d.releasePress))), |
|
| 77 | 77 | }; |
|
| 78 | 78 | aggregations.push(agg); |
|
| 79 | 79 | } |
|
| 9 | 9 | 'releasePress', |
|
| 10 | 10 | ] as const; |
|
| 11 | 11 | ||
| 12 | - | const EPSILON = 0.001; |
|
| 12 | + | const MAX_METRIC_DISTANCE = 3.0; |
|
| 13 | + | const MIN_STD_FALLBACK = 15.0; |
|
| 13 | 14 | ||
| 14 | 15 | export function compareSession( |
|
| 15 | 16 | sessionAggs: DigraphAggregation[], |
|
| 36 | 37 | const profileMean = profileAgg[key].mean; |
|
| 37 | 38 | const profileStd = profileAgg[key].std; |
|
| 38 | 39 | ||
| 39 | - | let distance: number; |
|
| 40 | - | if (profileStd < EPSILON) { |
|
| 41 | - | const fallback = Math.max(Math.abs(profileMean) * 0.1, 1.0); |
|
| 42 | - | distance = Math.abs(sessionMean - profileMean) / fallback; |
|
| 43 | - | } else { |
|
| 44 | - | distance = Math.abs(sessionMean - profileMean) / profileStd; |
|
| 45 | - | } |
|
| 40 | + | const divisor = Math.max(profileStd, MIN_STD_FALLBACK); |
|
| 41 | + | const distance = Math.min( |
|
| 42 | + | Math.abs(sessionMean - profileMean) / divisor, |
|
| 43 | + | MAX_METRIC_DISTANCE, |
|
| 44 | + | ); |
|
| 46 | 45 | ||
| 47 | 46 | digraphDistance += distance; |
|
| 48 | 47 | digraphMetrics++; |
|
| 13 | 13 | const variance = values.reduce((s, v) => s + (v - m) ** 2, 0) / (values.length - 1); |
|
| 14 | 14 | return Math.sqrt(variance); |
|
| 15 | 15 | } |
|
| 16 | + | ||
/**
 * Remove outliers using the IQR method (Tukey's fences) and return the
 * filtered array.
 *
 * A value is kept when it lies inside `[Q1 - k * IQR, Q3 + k * IQR]`.
 * Quartiles are linearly interpolated between the two nearest order
 * statistics (position `(n - 1) * p`), so both quartiles are estimated
 * symmetrically. Indexing `sorted[floor(n * p)]` instead makes Q3 equal to
 * the maximum for 4 samples, so a high outlier could never be rejected.
 *
 * @param values Samples to filter; the array is not mutated.
 * @param k Fence multiplier applied to the inter-quartile range.
 * @returns The values inside the fences, in their original order. Inputs
 *   with fewer than 4 samples are returned unchanged (same array instance).
 */
export function filterOutliers(values: number[], k = 1.5): number[] {
  // Too few samples to estimate quartiles meaningfully.
  if (values.length < 4) return values;

  const sorted = [...values].sort((a, b) => a - b);

  // Interpolated sample quantile at probability p (0..1).
  const quantile = (p: number): number => {
    const pos = (sorted.length - 1) * p;
    const lo = Math.floor(pos);
    const hi = Math.ceil(pos);
    if (lo === hi) return sorted[lo];
    return sorted[lo] + (sorted[hi] - sorted[lo]) * (pos - lo);
  };

  const q1 = quantile(0.25);
  const q3 = quantile(0.75);
  const iqr = q3 - q1;
  const lower = q1 - k * iqr;
  const upper = q3 + k * iqr;
  return values.filter((v) => v >= lower && v <= upper);
}