<h1>MELON: Reconstructing 3D objects from images with unknown poses</h1>
<span class="byline-author">Posted by Mark Matthews, Senior Software Engineer, and Dmitry Lagun, Research Scientist, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEh8LjCbKjfNXVUyCpGiZysx_pNF5BK8p5VBCJXXPaz_Bb75CW-33weoMh0YaNcn4AdmGN-Pufd_XlsRzo2MWZLQxqgtri7Nip9tXoGX0CritvRKF-63StOWxp_gVaY-MTnOk9IvJdVt_CczVR6Ip_R8Yv32MHTw2-FckCTF4UOFrgMyq3PCPCkZaZ-nyMcE/s320/MELON%20HERO.jpg" style="display: none;" />
<p>
A person's prior experience and understanding of the world generally enable them to easily infer what an object looks like as a whole, even when looking at only a few 2D pictures of it. Yet the capacity for a computer to reconstruct the shape of an object in 3D given only a few images has remained a difficult algorithmic problem for years. This fundamental computer vision task has applications ranging from the creation of e-commerce 3D models to autonomous vehicle navigation.
</p>
<p>
A key part of the problem is how to determine the exact positions from which images were taken, known as <em>pose inference</em>. If camera poses are known, a range of successful techniques — such as <a href="https://www.matthewtancik.com/nerf">neural radiance fields</a> (NeRF) or <a href="https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/">3D Gaussian Splatting</a> — can reconstruct an object in 3D. But if these poses are not available, we face a difficult “chicken and egg” problem where we could determine the poses if we knew the 3D object, but we can’t reconstruct the 3D object until we know the camera poses. The problem is made harder by pseudo-symmetries — i.e., many objects look similar when viewed from different angles. For example, square objects like chairs tend to look similar after every 90° of rotation. Pseudo-symmetries of an object can be revealed by rendering it on a turntable from various angles and plotting its photometric <a href="https://en.wikipedia.org/wiki/Self-similarity">self-similarity</a> map.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjt0nP5M8f5UodttSIPoY5t0JRXEuLosGgock3B0lyOzIn4icGF5jwVuxgX0PiRqc0kBbJ36CLiGA3KPrmaQbjKElGeHrsSRmkpDppU9abE84nuYu9MquqE3gULDzz_INDutmL2i1Wv3_tUpTh5U9UwSck9YRUeVyg-md2GByg3EQYYy7Vs_aeTEk5akpSo/s1764/image5.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="923" data-original-width="1764" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjt0nP5M8f5UodttSIPoY5t0JRXEuLosGgock3B0lyOzIn4icGF5jwVuxgX0PiRqc0kBbJ36CLiGA3KPrmaQbjKElGeHrsSRmkpDppU9abE84nuYu9MquqE3gULDzz_INDutmL2i1Wv3_tUpTh5U9UwSck9YRUeVyg-md2GByg3EQYYy7Vs_aeTEk5akpSo/s16000/image5.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Self-Similarity map of a toy truck model. <strong>Left:</strong> The model is rendered on a turntable from various <a href="https://en.wikipedia.org/wiki/Azimuth">azimuthal angles</a>, θ. <strong>Right:</strong> The average <a href="https://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm">L2</a> RGB similarity of a rendering from θ with that of θ*. The pseudo-similarities are indicated by the dashed red lines.</td></tr></tbody></table>
<p>
The diagram above visualizes only one dimension of rotation. It becomes even more complex (and difficult to visualize) when introducing more degrees of freedom. Pseudo-symmetries make the problem <em>ill-posed</em>, with naïve approaches often converging to local minima. In practice, such an approach might mistake the back view for the front view of an object, because they share a similar silhouette. Previous techniques (such as <a href="https://chenhsuanlin.bitbucket.io/bundle-adjusting-NeRF/">BARF</a> or <a href="https://arxiv.org/abs/2205.15768">SAMURAI</a>) side-step this problem by relying on an initial pose estimate that starts close to the global minimum. But how can we approach this if such an estimate isn’t available?
</p>
<p>
Methods such as <a href="https://openaccess.thecvf.com/content/ICCV2021/papers/Meng_GNeRF_GAN-Based_Neural_Radiance_Field_Without_Posed_Camera_ICCV_2021_paper.pdf">GNeRF</a> and <a href="https://dl.acm.org/doi/10.1145/3503161.3548078">VMRF</a> leverage <a href="https://en.wikipedia.org/wiki/Generative_adversarial_network">generative adversarial networks</a> (GANs) to overcome this problem. These techniques have the ability to artificially “amplify” a limited number of training views, aiding reconstruction. GAN techniques, however, often have complex, sometimes unstable, training processes, making robust and reliable convergence difficult to achieve in practice. A range of other successful methods, such as <a href="https://openaccess.thecvf.com/content/CVPR2023/html/Sinha_SparsePose_Sparse-View_Camera_Pose_Regression_and_Refinement_CVPR_2023_paper.html">SparsePose</a> or <a href="https://rust-paper.github.io/">RUST</a>, can infer poses from a limited number of views, but require pre-training on a large dataset of posed images, which isn’t always available, and can suffer from “domain-gap” issues when inferring poses for different types of images.
</p>
<p>
In “<a href="https://arxiv.org/abs/2303.08096">MELON: NeRF with Unposed Images in SO(3)</a>”, spotlighted at <a href="https://3dvconf.github.io/2024/">3DV 2024</a>, we present a technique that can determine object-centric camera poses entirely from scratch while reconstructing the object in 3D. <a href="https://melon-nerf.github.io/">MELON</a> (Modulo Equivalent Latent Optimization of NeRF) is one of the first techniques that can do this without initial camera pose estimates, complex training schemes or pre-training on labeled data. MELON is a relatively simple technique that can easily be integrated into existing NeRF methods. We demonstrate that MELON can reconstruct a NeRF from unposed images with state-of-the-art accuracy while requiring as few as 4–6 images of an object.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>MELON</h2>
<p>
We leverage two key techniques to aid convergence of this ill-posed problem. The first is a very lightweight, dynamically trained <a href="https://en.wikipedia.org/wiki/Convolutional_neural_network">convolutional neural network</a> (CNN) encoder that regresses camera poses from training images. We pass a downscaled training image to a four-layer CNN that infers the camera pose. This CNN is initialized from noise and requires no pre-training. Its capacity is so small that it forces similar-looking images to similar poses, providing an implicit regularization that greatly aids convergence.
</p>
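<p>
As a rough illustration, the sketch below shows the kind of small pose encoder described above: a few convolutional layers followed by a linear head that outputs two pose parameters. It is written in PyTorch, and the layer sizes, input resolution and pose parameterization are our own assumptions rather than the paper's exact architecture.
</p>
<pre>
# A minimal sketch of a small pose-regression CNN (PyTorch; layer sizes,
# input resolution, and pose parameterization are illustrative assumptions).
import torch
import torch.nn as nn

class PoseEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),   # 64x64 -> 32x32
            nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),  # 32x32 -> 16x16
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),  # 16x16 -> 8x8
            nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.ReLU(),  # 8x8  -> 4x4
        )
        self.head = nn.Linear(64 * 4 * 4, 2)  # outputs (azimuth, elevation)

    def forward(self, images):                # images: (batch, 3, 64, 64)
        x = self.features(images).flatten(1)
        return self.head(x)
</pre>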
<p>
The second technique is a <em>modulo loss</em> that simultaneously considers pseudo-symmetries of an object. We render the object from a fixed set of viewpoints for each training image, backpropagating the loss only through the view that best fits the training image. This effectively considers the plausibility of multiple views for each image. In practice, we find that <em>N</em>=2 views (viewing an object from the other side) are all that’s required in most cases, but we sometimes get better results with <em>N</em>=4 for square objects.
</p>
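<p>
The sketch below illustrates the idea: each predicted pose is duplicated into <em>N</em> hypotheses spread around the turntable, the object is rendered from each, and only the lowest per-view reconstruction error contributes gradients. The function and variable names here are ours, not the paper's.
</p>
<pre>
# A minimal sketch of a modulo-style loss (names and details are illustrative).
import math
import torch

def modulo_loss(render_fn, pose, image, n_views=2):
    """render_fn(pose) returns a rendered image for an (azimuth, elevation) pose."""
    per_view_losses = []
    for k in range(n_views):
        # Duplicate the predicted pose around the turntable (e.g., +180 deg for N=2).
        offset = torch.tensor([2.0 * math.pi * k / n_views, 0.0])
        rendering = render_fn(pose + offset)
        per_view_losses.append(((rendering - image) ** 2).mean())
    # Taking the minimum keeps the gradient path only through the best-fitting view.
    return torch.stack(per_view_losses).min()
</pre>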
<p>
These two techniques are integrated into standard NeRF training, except that instead of fixed camera poses, poses are inferred by the CNN and duplicated by the modulo loss. Photometric gradients back-propagate through the best-fitting cameras into the CNN. We observe that cameras generally converge quickly to globally optimal poses (see animation below). After the neural field is trained, MELON can synthesize novel views using standard NeRF rendering methods.
</p>
<p>
We simplify the problem by using the <a href="https://github.com/bmild/nerf">NeRF-Synthetic</a> dataset, a popular benchmark for NeRF research and common in the pose-inference literature. This synthetic dataset has cameras at precisely fixed distances and a consistent “up” orientation, requiring us to infer only the <a href="https://en.wikipedia.org/wiki/Spherical_coordinate_system">polar coordinates</a> of the camera. This is analogous to placing the object at the center of a globe, with the camera moving along the surface while always pointing at it. We then need only the latitude and longitude (2 degrees of freedom) to specify the camera pose.
</p>
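<p>
Concretely, under this setup a camera pose reduces to two angles. The snippet below (numpy; the radius value is an arbitrary assumption) maps those two angles to a camera position on the sphere; the camera orientation then follows from pointing at the origin.
</p>
<pre>
# A sketch of the 2-degree-of-freedom camera parameterization assumed here:
# a camera on a fixed-radius sphere that always looks at the origin.
import numpy as np

def camera_position(azimuth, elevation, radius=4.0):
    """Map (azimuth, elevation) in radians to a 3D camera center."""
    x = radius * np.cos(elevation) * np.cos(azimuth)
    y = radius * np.cos(elevation) * np.sin(azimuth)
    z = radius * np.sin(elevation)
    return np.array([x, y, z])
</pre>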
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhEjisRopoeGPgbCRa3sQ7hmBUtnfI6TRapBD7Yn96xeDA_LxzTayiw3DMijPHS0ovkLVTcQGpp2_gAyA_P5BCPwXuEcz7lApC8WQbGfMvj_aAxShjgsmcklf_-4ekgbFH6VZ92Ey3Ta4XAhZvEdc00D2o7SzPIOSnFAj8CgrdmdJunijsGaw1Zx46b94wk/s1315/image4.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="395" data-original-width="1315" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhEjisRopoeGPgbCRa3sQ7hmBUtnfI6TRapBD7Yn96xeDA_LxzTayiw3DMijPHS0ovkLVTcQGpp2_gAyA_P5BCPwXuEcz7lApC8WQbGfMvj_aAxShjgsmcklf_-4ekgbFH6VZ92Ey3Ta4XAhZvEdc00D2o7SzPIOSnFAj8CgrdmdJunijsGaw1Zx46b94wk/s16000/image4.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">MELON uses a dynamically trained lightweight CNN encoder that predicts a pose for each image. Predicted poses are replicated by the <em>modulo loss, </em>which only penalizes the smallest L2 distance from the ground truth color. At evaluation time, the neural field can be used to generate novel views.</td></tr></tbody></table>
<br />
<div style="line-height: 40%;">
<br />
</div>
<h2>Results</h2>
<p>
We compute two key metrics to evaluate MELON’s performance on the NeRF Synthetic dataset. The error in orientation between the ground truth and inferred poses can be quantified as a single angular error that we average across all training images; we call this the pose error. We then test the accuracy of MELON’s rendered objects from novel views by measuring the <a href="https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio">peak signal-to-noise ratio</a> (PSNR) against held-out test views. We see that MELON quickly converges to the approximate poses of most cameras within the first 1,000 steps of training, and achieves a competitive PSNR of 27.5 dB after 50k steps.
</p>
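<p>
For reference, simplified versions of these two metrics could look like the following (numpy). This is our paraphrase of the standard definitions, not code from the paper, and it measures the pose error between camera direction vectors as a simplification.
</p>
<pre>
# A sketch of the two evaluation metrics: mean angular error between ground-truth
# and inferred camera directions, and PSNR of rendered views.
import numpy as np

def angular_error_deg(dir_true, dir_pred):
    """Angle in degrees between two unit camera-direction vectors."""
    cosine = np.clip(np.dot(dir_true, dir_pred), -1.0, 1.0)
    return np.degrees(np.arccos(cosine))

def psnr(rendered, reference):
    """Peak signal-to-noise ratio for images with values in [0, 1]."""
    mse = np.mean((rendered - reference) ** 2)
    return 10.0 * np.log10(1.0 / mse)
</pre>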
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjU5wdw89PfwRbvZeaIWLM3rNEAo69__A-ovDwB5x8emIkAGZq05FgF-wDMNlkXPS6tOcC_0NJVD4Glq8eX02yb3CDIiqXbadI4lnvcZ_MI9sHUkz8risxP1orPA8ZnTZUq-PcRLPoEc_AmFuARCokXHQlTOv_q35TH1tivuK2PpA54hO7q7kh_M8ZynO-J/s960/image1.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="480" data-original-width="960" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjU5wdw89PfwRbvZeaIWLM3rNEAo69__A-ovDwB5x8emIkAGZq05FgF-wDMNlkXPS6tOcC_0NJVD4Glq8eX02yb3CDIiqXbadI4lnvcZ_MI9sHUkz8risxP1orPA8ZnTZUq-PcRLPoEc_AmFuARCokXHQlTOv_q35TH1tivuK2PpA54hO7q7kh_M8ZynO-J/s16000/image1.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Convergence of MELON on a toy truck model during optimization. <strong>Left</strong>: Rendering of the NeRF. <strong>Right</strong>: Polar plot of predicted (blue <em>x</em>), and ground truth (red dot) cameras.</td></tr></tbody></table>
<p>
MELON achieves similar results for other scenes in the NeRF Synthetic dataset.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhWEC7CE_iWu1QZ_jgEUHHCEdqaUBMO7cK-1DZuHaZRDq4Y59_CriUlb_aOSJP5psB6Cbs1E41mm81EsfwVM0zAUojRKToWwiDmPfaWFPr2UGqf6F4n3P8ZpgYxiqyWIgst6op3Fhsbu0nlR727zLVV38KqJvNFY_KDeoJbdOjJFpHjLZkEd95Z9TqSg4R_/s1999/image2.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="644" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhWEC7CE_iWu1QZ_jgEUHHCEdqaUBMO7cK-1DZuHaZRDq4Y59_CriUlb_aOSJP5psB6Cbs1E41mm81EsfwVM0zAUojRKToWwiDmPfaWFPr2UGqf6F4n3P8ZpgYxiqyWIgst6op3Fhsbu0nlR727zLVV38KqJvNFY_KDeoJbdOjJFpHjLZkEd95Z9TqSg4R_/s16000/image2.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Reconstruction quality comparison between ground-truth (GT) and MELON on NeRF-Synthetic scenes after 100k training steps.</td></tr></tbody></table>
<br />
<div style="line-height: 40%;">
<br />
</div>
<h3>Noisy images</h3>
<p>
MELON also works well when performing <a href="https://en.wikipedia.org/wiki/View_synthesis">novel view synthesis</a> from extremely noisy, unposed images. We add varying amounts of <a href="https://en.wikipedia.org/wiki/Additive_white_Gaussian_noise">white Gaussian noise</a> (standard deviation <em>σ</em>) to the training images. For example, the object at <em>σ</em>=1.0 below is impossible to make out, yet MELON can determine the pose and generate novel views of the object.
</p>
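<p>
Generating such training data is straightforward; a sketch is shown below (numpy). Whether and how the noisy values are clipped afterwards is our simplification and is left out here.
</p>
<pre>
# A sketch of corrupting training views with additive white Gaussian noise.
import numpy as np

def add_noise(image, sigma, rng):
    """image: float array with values in [0, 1]; sigma: noise standard deviation."""
    return image + rng.normal(scale=sigma, size=image.shape)

rng = np.random.default_rng(0)
noisy_view = add_noise(clean_view, sigma=1.0, rng=rng)  # clean_view is assumed given
</pre>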
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgHKYFcj-CKKc5kUvfsoOD5rBTp2QMnd3CdYiVzXjMClNwJrcgSrvIZngAdLgxUthE-aiXx5NapxcMx66i-Bi9RhC0zTRVkA0R8fj2A7lOnIdFDIE3YkTh_hWO2PhPa0FjYWYHuNUuae_tPhsrmVHJAkCeeI1f0ooJGe44KgpcO7jVNyLcnUvwtMX-KpJdD/s1182/image1.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="568" data-original-width="1182" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgHKYFcj-CKKc5kUvfsoOD5rBTp2QMnd3CdYiVzXjMClNwJrcgSrvIZngAdLgxUthE-aiXx5NapxcMx66i-Bi9RhC0zTRVkA0R8fj2A7lOnIdFDIE3YkTh_hWO2PhPa0FjYWYHuNUuae_tPhsrmVHJAkCeeI1f0ooJGe44KgpcO7jVNyLcnUvwtMX-KpJdD/s16000/image1.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Novel view synthesis from noisy unposed 128×128 images. Top: Example of noise level present in training views. Bottom: Reconstructed model from noisy training views and mean angular pose error.</td></tr></tbody></table>
<p>
This perhaps shouldn’t be too surprising, given that techniques like <a href="https://bmild.github.io/rawnerf/">RawNeRF</a> have demonstrated NeRF’s excellent de-noising capabilities with known camera poses. The fact that MELON works so robustly on noisy images with unknown camera poses was unexpected.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Conclusion</h2>
<p>
We present MELON, a technique that can determine object-centric camera poses to reconstruct objects in 3D without the need for approximate pose initializations, complex GAN training schemes or pre-training on labeled data. MELON is a relatively simple technique that can easily be integrated into existing NeRF methods. Though we have only demonstrated MELON on synthetic images, we are adapting our technique to work in real-world conditions. See the <a href="https://arxiv.org/abs/2303.08096">paper</a> and <a href="https://melon-nerf.github.io/">MELON site</a> to learn more.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgements</h2>
<p>
<em>We would like to thank our paper co-authors Axel Levy, Matan Sela, and Gordon Wetzstein, as well as Florian Schroff and Hartwig Adam for continuous help in building this technology. We also thank Matthew Brown, Ricardo Martin-Brualla and Frederic Poitevin for their helpful feedback on the paper draft. We also acknowledge the use of the computational resources at the SLAC Shared Scientific Data Facility (SDF).</em>
</p>
<h1>HEAL: A framework for health equity assessment of machine learning performance</h1>
<span class="byline-author">Posted by Mike Schaekermann, Research Scientist, Google Research, and Ivor Horn, Chief Health Equity Officer & Director, Google Core</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhYi3V0CsXup8WA6SSjPagoMWfkIpbr9oRWEaUM1vIWOX8_TsZs6ikqOn6qIGbqUzAPhOxwhEPNfWSkECIxRz5fJ629cRGScLraFn2CSw53Sr5_li8Fe7A9I1nMShys_15IiUZNhNiPh_ueFVcu_7f34A-A0pMXXVdDaSoSAf2h0jETJ1PemIR5I6o9pIIW/s1600/HEAL-Hero.png" style="display: none;" />
<p>
Health equity is a major societal concern worldwide, and health disparities have many causes. These include limitations in access to healthcare, differences in clinical treatment, and even fundamental differences in diagnostic technology. In dermatology, for example, skin cancer outcomes are worse for populations such as minorities, those with lower socioeconomic status, or individuals with limited healthcare access. While there is great promise in recent advances in machine learning (ML) and artificial intelligence (AI) to help improve healthcare, the transition from research to bedside must be accompanied by a careful understanding of whether and how these technologies impact health equity.
</p>
<p>
<em>Health equity</em> is defined by public health organizations as fairness of opportunity for everyone to be as healthy as possible. Importantly, equity may be different from <em>equality</em>. For example, people with greater barriers to improving their health may require more or different effort to experience this fair opportunity. Similarly, equity is not the same as <em>fairness</em> as defined in the AI for healthcare literature. Whereas AI fairness often strives for equal performance of the AI technology across different patient populations, it does not center the goal of prioritizing performance with respect to pre-existing health disparities.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi21VRS33NG-Imj1XlKXWtrwUrl4loEEywV0tO8M0JWtUFFksbTLOhilTZtMdJTgOBdXACUPQX-f5TMAFkABFhdv_cEDmFn4d-JirU78covJI32sHus6XQVJ1C1elwM_MExsQfeVCpFYlq9QZeynLNpLqmW8GqM-DKWiGSyi_18n8Xb3-8IeepHSyBZ6_2l/s1999/image2.jpg" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1999" data-original-width="1609" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi21VRS33NG-Imj1XlKXWtrwUrl4loEEywV0tO8M0JWtUFFksbTLOhilTZtMdJTgOBdXACUPQX-f5TMAFkABFhdv_cEDmFn4d-JirU78covJI32sHus6XQVJ1C1elwM_MExsQfeVCpFYlq9QZeynLNpLqmW8GqM-DKWiGSyi_18n8Xb3-8IeepHSyBZ6_2l/s16000/image2.jpg" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Health equity considerations. An intervention (e.g., an ML-based tool, indicated in dark blue) promotes health equity if it helps reduce existing disparities in health outcomes (indicated in lighter blue).</td></tr></tbody></table>
<p>
In “<a href="https://www.thelancet.com/journals/eclinm/article/PIIS2589-5370(24)00058-0/fulltext">Health Equity Assessment of machine Learning performance (HEAL): a framework and dermatology AI model case study</a>”, published in <a href="https://www.thelancet.com/journals/eclinm/home"><i>The Lancet eClinicalMedicine</i></a>, we propose a methodology to quantitatively assess whether ML-based health technologies perform equitably. In other words, does the ML model perform well for those with the worst health outcomes for the condition(s) the model is meant to address? This goal anchors on the principle that health equity should prioritize and measure model performance with respect to disparate health outcomes, which may be due to a number of factors that include structural inequities (e.g., demographic, social, cultural, political, economic, environmental and geographic).
</p>
<br />
<h2>The health equity framework (HEAL)</h2>
<p>
The HEAL framework proposes a 4-step process to estimate the likelihood that an ML-based health technology performs equitably:
</p>
<ol>
<li>
Identify factors associated with health inequities and define tool performance metrics,
</li>
<li>
Identify and quantify pre-existing health disparities,
</li>
<li>
Measure the performance of the tool for each subpopulation,
</li>
<li>
Measure the likelihood that the tool prioritizes performance with respect to health disparities.
</li>
</ol>
<p>
The final step’s output is termed the HEAL metric, which quantifies how anticorrelated the ML model’s performance is with health disparities. In other words, does the model perform better for populations that have the worst health outcomes?
</p>
<p>
This 4-step process is designed to inform improvements for making ML model performance more equitable, and is meant to be iterative and re-evaluated on a regular basis. For example, the availability of health outcomes data in step (2) can inform the choice of demographic factors and brackets in step (1), and the framework can be applied again with new datasets, models and populations.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjoGLCxn9QWS5QQpW39mJH1A_pw9wniWKIGGapN_gBC5WdxAWo4jHRS29GhNq7XBgNdZ867tMdP7TcszMz2WxUR4sYBFz0-dJ4cQZCODN2YFRjCP14QhNh_kMVGUdklbToOCYwHXV-UofhZdwZzDZudaVedOqvcC-QbW3LtMGb04FwFclbfzKHVUcqHodW_/s1999/image1.jpg" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1352" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjoGLCxn9QWS5QQpW39mJH1A_pw9wniWKIGGapN_gBC5WdxAWo4jHRS29GhNq7XBgNdZ867tMdP7TcszMz2WxUR4sYBFz0-dJ4cQZCODN2YFRjCP14QhNh_kMVGUdklbToOCYwHXV-UofhZdwZzDZudaVedOqvcC-QbW3LtMGb04FwFclbfzKHVUcqHodW_/s16000/image1.jpg" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Framework for Health Equity Assessment of machine Learning performance (HEAL). Our guiding principle is to avoid exacerbating health inequities, and these steps help us identify disparities and assess for inequitable model performance to move towards better outcomes for all.</td></tr></tbody></table>
<p>
With this work, we take a step towards encouraging explicit assessment of the health equity considerations of AI technologies, and encourage prioritization of efforts during model development to reduce health inequities for subpopulations exposed to structural inequities that can precipitate disparate outcomes. We should note that the present framework does not model causal relationships and, therefore, cannot quantify the actual impact a new technology will have on reducing health outcome disparities. However, the HEAL metric may help identify opportunities for improvement, where the current performance is not prioritized with respect to pre-existing health disparities.
</p>
<br />
<h2>Case study on a dermatology model</h2>
<p>
As an illustrative case study, we applied the framework to a dermatology model, which utilizes a convolutional neural network similar to that described in <a href="https://blog.research.google/2019/09/using-deep-learning-to-inform.html">prior work</a>. This example dermatology model was trained to classify 288 skin conditions using a development dataset of 29k cases. The input to the model consists of three photos of a skin concern along with demographic information and a brief structured medical history. The output consists of a ranked list of possible matching skin conditions.
</p>
<p>
Using the HEAL framework, we evaluated this model by assessing whether it prioritized performance with respect to pre-existing health outcomes. The model was designed to predict possible dermatologic conditions (from a list of hundreds) based on photos of a skin concern and patient metadata. Evaluation of the model is done using a top-3 agreement metric, which quantifies how often the top 3 output conditions match the most likely condition as suggested by a dermatologist panel. The HEAL metric is computed via the anticorrelation of this top-3 agreement with health outcome rankings.
</p>
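<p>
As a deliberately simplified illustration of this idea (not the paper's exact estimator, which reports a likelihood rather than a raw correlation), one could measure the rank correlation between each subgroup's health-outcome burden and the model's performance for that subgroup.
</p>
<pre>
# An illustrative, simplified sketch of the anticorrelation idea behind the
# HEAL metric; the paper's actual estimator is more involved.
from scipy.stats import spearmanr

def prioritization_score(burden, performance):
    """burden: per-subgroup health-outcome burden (e.g., DALYs per 100,000;
    higher means worse pre-existing outcomes).
    performance: per-subgroup model performance (e.g., top-3 agreement).
    Positive values mean the model performs better where outcomes are worse."""
    rho, _ = spearmanr(burden, performance)
    return rho
</pre>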
<p>
We used a dataset of 5,420 teledermatology cases, enriched for diversity in age, sex and race/ethnicity, to retrospectively evaluate the model’s HEAL metric. The dataset consisted of “store-and-forward” cases from patients of 20 years or older from primary care providers in the USA and skin cancer clinics in Australia. Based on a review of the literature, we decided to explore race/ethnicity, sex and age as potential factors of inequity, and used sampling techniques to ensure that our evaluation dataset had sufficient representation of all race/ethnicity, sex and age groups. To quantify pre-existing health outcomes for each subgroup we relied on measurements from <a href="https://www.who.int/data/gho/data/themes/mortality-and-global-health-estimates/global-health-estimates-leading-causes-of-dalys">public</a> <a href="https://www.thelancet.com/journals/lancet/article/PIIS0140-6736(20)30925-9/fulltext">databases</a> endorsed by the World Health Organization, such as <a href="https://www.who.int/data/gho/indicator-metadata-registry/imr-details/4427">Years of Life Lost</a> (YLLs) and <a href="https://www.who.int/data/gho/indicator-metadata-registry/imr-details/158">Disability-Adjusted Life Years</a> (DALYs; years of life lost plus years lived with disability).
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiSS4J8AzS5iaHYvB7RyUVEDkx1ykrC7zOEAbUvjb8ZybZRZ0C71fRlJjPYBzGYVu9D3Ok0zRdz4MUdHMX6rOqnYKoHv91QNPw0TiqHJ6MKjtgn_UIqW-xoZeihO-A-ZrPgWT8bs-t9bSZWmMQ9AJaQh85BZWHH-T0KPWMx2unNO9HpTzYXiD_24gwNYWot/s1511/Table1.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="602" data-original-width="1511" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiSS4J8AzS5iaHYvB7RyUVEDkx1ykrC7zOEAbUvjb8ZybZRZ0C71fRlJjPYBzGYVu9D3Ok0zRdz4MUdHMX6rOqnYKoHv91QNPw0TiqHJ6MKjtgn_UIqW-xoZeihO-A-ZrPgWT8bs-t9bSZWmMQ9AJaQh85BZWHH-T0KPWMx2unNO9HpTzYXiD_24gwNYWot/s16000/Table1.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">HEAL metric for all dermatologic conditions across race/ethnicity subpopulations, including health outcomes (YLLs per 100,000), model performance (top-3 agreement), and rankings for health outcomes and tool performance.<br />(* Higher is better; measures the likelihood the model performs equitably with respect to the axes in this table.)</td></tr></tbody></table>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiMAQjyuGMXvzq4FxZg5Vhlgozwwnzza-QS-mjr3i0oOnDFIeqUGTrPxX2c7ssbpCZtLUoT2lpr8bXg_nJ3ToaaVe6Grge-HcWQl8SFy1gaBCoT-6ZHtFmQV4_S2sA6eOsdMFryegLjZFwOcPiqZDfFFItxqS96ysTZZn1OXVcbQSOG5WazZGjxSkNt9JQK/s1518/Table2.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="316" data-original-width="1518" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiMAQjyuGMXvzq4FxZg5Vhlgozwwnzza-QS-mjr3i0oOnDFIeqUGTrPxX2c7ssbpCZtLUoT2lpr8bXg_nJ3ToaaVe6Grge-HcWQl8SFy1gaBCoT-6ZHtFmQV4_S2sA6eOsdMFryegLjZFwOcPiqZDfFFItxqS96ysTZZn1OXVcbQSOG5WazZGjxSkNt9JQK/s16000/Table2.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">HEAL metric for all dermatologic conditions across sexes, including health outcomes (DALYs per 100,000), model performance (top-3 agreement), and rankings for health outcomes and tool performance. (* As above.)</td></tr></tbody></table
<p>
Our analysis estimated that the model was 80.5% likely to perform equitably across race/ethnicity subgroups and 92.1% likely to perform equitably across sexes.
</p>
<p>
However, while the model was likely to perform equitably across age groups for cancer conditions specifically, we discovered that it had room for improvement across age groups for non-cancer conditions. For example, those 70+ have the poorest health outcomes related to non-cancer skin conditions, yet the model didn't prioritize performance for this subgroup.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEh4s5yfNQCksLIqP3kYuDXahUlOcJSCEtt-JkSTsecDft21uJ8JR0imnsPVGYHVQnc7OPo1WOkcwx2Yevu6su-rbqc1Fl6_NfzCKl0_vOvZA3PPnLkVWKFk7jHPJCm-x69MupVih_zct1YOXJVvSNUIsvn4rICk-_RWbOeuKj4HdRphBOakRXsiJ4lETJ_M/s1508/Table3.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="644" data-original-width="1508" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEh4s5yfNQCksLIqP3kYuDXahUlOcJSCEtt-JkSTsecDft21uJ8JR0imnsPVGYHVQnc7OPo1WOkcwx2Yevu6su-rbqc1Fl6_NfzCKl0_vOvZA3PPnLkVWKFk7jHPJCm-x69MupVih_zct1YOXJVvSNUIsvn4rICk-_RWbOeuKj4HdRphBOakRXsiJ4lETJ_M/s16000/Table3.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">HEAL metrics for all cancer and non-cancer dermatologic conditions across age groups, including health outcomes (DALYs per 100,000), model performance (top-3 agreement), and rankings for health outcomes and tool performance. (* As above.)</td></tr></tbody></table>
<br />
<h2>Putting things in context</h2>
<p>
For holistic evaluation, the HEAL metric cannot be employed in isolation. Instead, this metric should be contextualized alongside many other factors, ranging from computational efficiency and data privacy to ethical values, as well as aspects that may influence the results (e.g., selection bias or differences in representativeness of the evaluation data across demographic groups).
</p>
<p>
As an adversarial example, the HEAL metric can be artificially improved by deliberately reducing model performance for the most advantaged subpopulation until performance for that subpopulation is worse than all others. For illustrative purposes, given subpopulations A and B where A has worse health outcomes than B, consider the choice between two models: Model 1 (M1) performs 5% better for subpopulation A than for subpopulation B. Model 2 (M2) performs 5% worse on subpopulation A than B. The HEAL metric would be higher for M1 because it prioritizes performance on a subpopulation with worse outcomes. However, M1 may have absolute performances of just 75% and 70% for subpopulations A and B respectively, while M2 has absolute performances of 75% and 80% for subpopulations A and B respectively. Choosing M1 over M2 would lead to worse performance overall, because some subpopulations are worse off while no subpopulation is better off.
</p>
<p>
Accordingly, the HEAL metric should be used alongside a <a href="https://en.wikipedia.org/wiki/Pareto_efficiency">Pareto condition</a> (discussed further in the paper), which restricts model changes so that outcomes for each subpopulation are either unchanged or improved compared to the status quo, and performance does not worsen for any subpopulation.
</p>
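<p>
As a small illustration (our own naming, not from the paper), such a condition can be expressed as a simple check on per-subgroup performance before and after a proposed model change.
</p>
<pre>
# A sketch of the Pareto condition: accept a model change only if performance
# does not get worse for any subpopulation.
def satisfies_pareto_condition(perf_before, perf_after):
    """perf_before / perf_after: dicts mapping subgroup name to a performance metric."""
    return all(perf_after[group] >= perf_before[group] for group in perf_before)

# Example with the M1 vs. M2 scenario above (status quo taken to be M2):
# satisfies_pareto_condition({"A": 0.75, "B": 0.80}, {"A": 0.75, "B": 0.70})  # False
</pre>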
<p>
The HEAL framework, in its current form, assesses the likelihood that an ML-based model prioritizes performance with respect to pre-existing health disparities for specific subpopulations. This differs from the goal of understanding whether ML will reduce disparities in outcomes across subpopulations in reality. Specifically, modeling improvements in outcomes requires a causal understanding of steps in the care journey that happen both before and after use of any given model. Future research is needed to address this gap.
</p>
<br />
<h2>Conclusion</h2>
<p>
The HEAL framework enables a quantitative assessment of the likelihood that health AI technologies prioritize performance with respect to health disparities. The case study demonstrates how to apply the framework in the dermatological domain, indicating a high likelihood that model performance is prioritized with respect to health disparities across sex and race/ethnicity, but also revealing the potential for improvements for non-cancer conditions across age. The case study also illustrates limitations in the ability to apply all recommended aspects of the framework (e.g., mapping societal context, availability of data), thus highlighting the complexity of health equity considerations of ML-based tools.
</p>
<p>
This work is a proposed approach to address a grand challenge for AI and health equity, and may provide a useful evaluation framework not only during model development, but during pre-implementation and real-world monitoring stages, e.g., in the form of health equity dashboards. We hold that the strength of the HEAL framework is in its future application to various AI tools and use cases and its refinement in the process. Finally, we acknowledge that a successful approach towards understanding the impact of AI technologies on health equity needs to be more than a set of metrics. It will require a set of goals agreed upon by a community that represents those who will be most impacted by a model.
</p>
<br />
<h2>Acknowledgements</h2>
<p>
<em>The research described here is joint work across many teams at Google. We are grateful to all our co-authors: Terry Spitz, Malcolm Pyles, Heather Cole-Lewis, Ellery Wulczyn, Stephen R. Pfohl, Donald Martin, Jr., Ronnachai Jaroensri, Geoff Keeling, Yuan Liu, Stephanie Farquhar, Qinghan Xue, Jenna Lester, Cían Hughes, Patricia Strachan, Fraser Tan, Peggy Bui, Craig H. Mermel, Lily H. Peng, Yossi Matias, Greg S. Corrado, Dale R. Webster, Sunny Virmani, Christopher Semturs, Yun Liu, and Po-Hsuan Cameron Chen. We also thank Lauren Winer, Sami Lachgar, Ting-An Lin, Aaron Loh, Morgan Du, Jenny Rizk, Renee Wong, Ashley Carrick, Preeti Singh, Annisah Um'rani, Jessica Schrouff, Alexander Brown, and Anna Iurchenko for their support of this project.</em>
</p>
<h1>Cappy: Outperforming and boosting large multi-task language models with a small scorer</h1>
<span class="byline-author">Posted by Yun Zhu and Lijuan Liu, Software Engineers, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiFNlqVAnwoYdZ97LvC4-ipR6FeOc4o9udsTUtNBBWl5Y4XHclcrz3kTCibizteSBc_xsVLh-pyRiCCNfIzTDHEs7VsJcUMCk0EjUxzvKITKCncdx1y7u9JXGkXM6TyoZY5RhUt2l_up-Us0yIV-0-EUvHsjOlFNSSNgNHlpwK1PAliqcj4gSoLsYXhIi18/s320/Cappy%20hero.jpg" style="display: none;" />
<p>
Large language model (LLM) advancements have led to a new paradigm that unifies various natural language processing (NLP) tasks within an instruction-following framework. This paradigm is exemplified by recent multi-task LLMs, such as <a href="https://arxiv.org/abs/2110.08207">T0</a>, <a href="https://arxiv.org/abs/2210.11416">FLAN</a>, and <a href="https://arxiv.org/abs/2212.12017">OPT-IML</a>. First, multi-task data is gathered with each task following a task-specific template, where each labeled example is converted into an instruction (e.g., <em>"</em>Put the concepts together to form a sentence: ski, mountain, skier<em>”</em>) paired with a corresponding response (e.g., <em>"</em>Skier skis down the mountain<em>"</em>). These instruction-response pairs are used to train the LLM, resulting in a conditional generation model that takes an instruction as input and generates a response. Moreover, multi-task LLMs have exhibited remarkable task-wise generalization capabilities as they can address unseen tasks by understanding and solving brand-new instructions.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhMcacnhPA68XiEskvhExF4SGFh4997UZzwvhYfXt-ReGXtzfGTamLB3LZoYSh8WWuf1dmlBnNAUecAMhrBTOMVF6vxsw3BqY8Ld5xPgSdZY_cywScxxxQ5e6uwhawA5VYDEj6VtSyOTNGZtjdLXieeFV5OLiDk3bnB-xaz4MIbvUO-7RPadk8iQDv3206V/s640/Cappy%20instruction-following.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="177" data-original-width="640" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhMcacnhPA68XiEskvhExF4SGFh4997UZzwvhYfXt-ReGXtzfGTamLB3LZoYSh8WWuf1dmlBnNAUecAMhrBTOMVF6vxsw3BqY8Ld5xPgSdZY_cywScxxxQ5e6uwhawA5VYDEj6VtSyOTNGZtjdLXieeFV5OLiDk3bnB-xaz4MIbvUO-7RPadk8iQDv3206V/s16000/Cappy%20instruction-following.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">The demonstration of the instruction-following pre-training of multi-task LLMs, e.g., FLAN. Pre-training tasks under this paradigm improves the performance for unseen tasks.</td></tr></tbody></table>
<p>
Due to the complexity of understanding and solving various tasks solely using instructions, the size of multi-task LLMs typically spans from several billion parameters to hundreds of billions (e.g., <a href="https://arxiv.org/abs/2210.11416">FLAN-11B</a>, <a href="https://arxiv.org/abs/2110.08207">T0-11B</a> and <a href="https://arxiv.org/abs/2212.12017">OPT-IML-175B</a>). As a result, operating such sizable models poses significant challenges because they demand considerable computational power and impose substantial requirements on the memory capacities of GPUs and TPUs, making their training and inference expensive and inefficient. Extensive storage is required to maintain a unique LLM copy for each downstream task. Moreover, the most powerful multi-task LLMs (e.g., FLAN-PaLM-540B) are closed-source, making them impossible to adapt. However, in practical applications, harnessing a single multi-task LLM to manage all conceivable tasks in a zero-shot manner remains difficult, particularly when dealing with complex tasks, personalized tasks and those that cannot be succinctly defined using instructions. On the other hand, the size of downstream training data is usually insufficient to train a model well without incorporating rich prior knowledge. Hence, there is a long-standing need to adapt LLMs with downstream supervision while bypassing storage, memory, and access issues.
</p>
<p>
Certain <em>parameter-efficient tuning</em> strategies, including <a href="https://aclanthology.org/2021.acl-long.353.pdf">prompt tuning</a> and <a href="https://openreview.net/pdf?id=nZeVKeeFYf9">adapters</a>, substantially diminish storage requirements, but they still perform back-propagation through LLM parameters during the tuning process, thereby keeping their memory demands high. Additionally, some <em><a href="https://arxiv.org/pdf/2301.00234.pdf">in-context learning</a></em> techniques circumvent parameter tuning by integrating a limited number of supervised examples into the instruction. However, these techniques are constrained by the model's maximum input length, which permits only a few samples to guide task resolution.
</p>
<p>
In “<a href="https://arxiv.org/abs/2311.06720">Cappy: Outperforming and Boosting Large Multi-Task LMs with a Small Scorer</a>”, presented at <a href="https://nips.cc/virtual/2023/index.html">NeurIPS 2023</a>, we propose a novel approach that enhances the performance and efficiency of multi-task LLMs. We introduce a lightweight pre-trained scorer, Cappy, based on continual pre-training on top of <a href="https://arxiv.org/abs/1907.11692">RoBERTa</a> with merely 360 million parameters. Cappy takes in an instruction and a candidate response as input, and produces a score between 0 and 1, indicating an estimated correctness of the response with respect to the instruction. Cappy functions either independently on classification tasks or serves as an auxiliary component for LLMs, boosting their performance. Moreover, Cappy efficiently enables downstream supervision without requiring any finetuning, which avoids the need for back-propagation through LLM parameters and reduces memory requirements. Finally, adaptation with Cappy doesn’t require access to LLM parameters as it is compatible with closed-source multi-task LLMs, such as those only accessible via WebAPIs.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgKxfEf2em6vxULs9wHGOB2jU7AiGMhUiEsJENdGWjB-8AMW6T2uRUrp3k3776491wzNsQCEk2T26AmiPNaKi-mfiIRNHe7JKZuR4ETQbHrM5h1knDNDBZ-qPw6sPGhtA4v0dz9YtKbHyoXPWEgYkY6r-tv8brepN8_Qq7MjCIwGUaYw5LmJMY4KLxu28ku/s1999/Cappy%20overview.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="975" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgKxfEf2em6vxULs9wHGOB2jU7AiGMhUiEsJENdGWjB-8AMW6T2uRUrp3k3776491wzNsQCEk2T26AmiPNaKi-mfiIRNHe7JKZuR4ETQbHrM5h1knDNDBZ-qPw6sPGhtA4v0dz9YtKbHyoXPWEgYkY6r-tv8brepN8_Qq7MjCIwGUaYw5LmJMY4KLxu28ku/s16000/Cappy%20overview.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Cappy takes an instruction and response pair as input and outputs a score ranging from 0 to 1, indicating an estimation of the correctness of the response with respect to the instruction.</td></tr></tbody></table>
<br />
<div style="line-height: 40%;">
<br />
</div>
<h2>Pre-training</h2>
<p>
We begin with the same dataset collection, which includes 39 diverse datasets from <a href="https://arxiv.org/abs/2202.01279">PromptSource</a> that were used to train <a href="https://arxiv.org/abs/2110.08207">T0</a>. This collection encompasses a wide range of task types, such as question answering, sentiment analysis, and summarization. Each dataset is associated with one or more templates that convert each instance from the original datasets into an instruction paired with its ground truth response.
</p>
<p>
Cappy's regression modeling requires each pre-training data instance to include an instruction-response pair along with a correctness annotation for the response, so we produce a dataset with correctness annotations that range from 0 to 1. For every instance within a generation task, we leverage an existing multi-task LLM to generate multiple responses by sampling, conditioned on the given instruction. Subsequently, we assign an annotation to the pair formed by the instruction and every response, using the similarity between the response and the ground truth response of the instance. Specifically, we employ <a href="https://aclanthology.org/W04-1013/">Rouge-L</a>, a commonly-used metric for measuring overall multi-task performance that has demonstrated a strong alignment with human evaluation, to calculate this similarity as a form of weak supervision.
</p>
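<p>
A simplified sketch of this labeling step is shown below, using the open-source <code>rouge-score</code> package; the <code>dataset</code> and <code>llm_sample</code> names are hypothetical placeholders for the PromptSource data and the sampling LLM.
</p>
<pre>
# A sketch of constructing weakly supervised regression data for Cappy:
# each sampled response is scored against the ground truth with Rouge-L.
# `dataset` and `llm_sample` are hypothetical placeholders.
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)

def correctness_label(ground_truth, response):
    """Returns a value in [0, 1] used as Cappy's regression target."""
    return scorer.score(ground_truth, response)["rougeL"].fmeasure

regression_data = []
for instruction, ground_truth in dataset:
    for response in llm_sample(instruction, num_samples=4):
        regression_data.append(
            (instruction, response, correctness_label(ground_truth, response)))
</pre>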
<p>
As a result, we obtain an effective regression dataset of 160 million instances paired with correctness score annotations. The final Cappy model is the result of continual pre-training using the regression dataset on top of the <a href="https://arxiv.org/abs/1907.11692">RoBERTa</a> model. The pre-training of Cappy is conducted on Google's <a href="https://arxiv.org/abs/2304.01433">TPU-v4</a>, with <a href="https://arxiv.org/pdf/2310.16355.pdf">RedCoast</a>, a lightweight toolkit for automating distributed training.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEguKQnabejBzwCo7XEYZBJaaHi9_Z0Z03aofMxhmno2dKMbh2d6qVhmu7kKLN7FVExLXwYZYu1UEa1brRSC7bX3ASLyZymVyougwQqhCoE7Iio6DvIzdIK_dYT-1IGk41jZ6qdYcDynxezST6FY8u73opddwlGcGTf-3fXY4KfPo5hhfIinUl7iXRN7V6Sr/s1999/Cappy%20data%20augmentation.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="438" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEguKQnabejBzwCo7XEYZBJaaHi9_Z0Z03aofMxhmno2dKMbh2d6qVhmu7kKLN7FVExLXwYZYu1UEa1brRSC7bX3ASLyZymVyougwQqhCoE7Iio6DvIzdIK_dYT-1IGk41jZ6qdYcDynxezST6FY8u73opddwlGcGTf-3fXY4KfPo5hhfIinUl7iXRN7V6Sr/s16000/Cappy%20data%20augmentation.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Data augmentation with a multi-task LLM to construct a weakly supervised regression dataset for Cappy’s pre-training and fine-tuning.</td></tr></tbody></table>
<br />
<div style="line-height: 40%;">
<br />
</div>
<h2>Applying Cappy</h2>
<p>
Cappy solves practical tasks within a candidate-selection mechanism. More specifically, given an instruction and a set of candidate responses, Cappy produces a score for each candidate response. This is achieved by inputting the instruction alongside each individual response, and then assigning the response with the highest score as its prediction. In classification tasks, all candidate responses are inherently predefined. For example, for an instruction of a sentiment classification task (e.g., “Based on this review, would the user recommend this product?: ‘Stunning even for the non-gamer.’”), the candidate responses are “Yes” or “No”. In such scenarios, Cappy functions independently. On the other hand, in generation tasks, candidate responses are not pre-defined, requiring an existing multi-task LLM to yield the candidate responses. In this case, Cappy serves as an auxiliary component of the multi-task LLM, enhancing its decoding.
</p>
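<p>
A minimal sketch of this candidate-selection mechanism is shown below; <code>cappy_score</code> stands in for the scorer model and <code>llm_sample</code> for a multi-task LLM's sampler, both hypothetical names.
</p>
<pre>
# A sketch of Cappy's candidate selection: score every candidate response for
# an instruction and return the highest-scoring one.
def select_response(instruction, candidates, cappy_score):
    scores = [cappy_score(instruction, response) for response in candidates]
    best_index = max(range(len(candidates)), key=lambda i: scores[i])
    return candidates[best_index]

# Classification task: the candidates are predefined.
# select_response(instruction, ["Yes", "No"], cappy_score)
# Generation task: the candidates are sampled from a multi-task LLM.
# select_response(instruction, llm_sample(instruction, num_samples=8), cappy_score)
</pre>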
<div style="line-height: 40%;">
<br />
</div>
<h3>Adapting multi-task LLMs with Cappy </h3>
<p>
When there is available downstream training data, Cappy enables effective and efficient adaptation of multi-task LLMs on downstream tasks. Specifically, we fine-tune Cappy to integrate downstream task information into LLM predictions. This process involves creating a separate regression dataset specific to the downstream training data with the same data annotation process used to construct the pre-training data. As a result, the fine-tuned Cappy collaborates with a multi-task LLM, boosting the LLM's performance on the downstream task.
</p>
<p>
In contrast to other LLM tuning strategies, adapting LLMs with Cappy significantly reduces the high demand for device memory as it avoids the need for back-propagation through LLM parameters for downstream tasks. Moreover, Cappy adaptation does not rely on access to LLM parameters, making it compatible with closed-source multi-task LLMs, such as the ones only accessible via WebAPIs. Compared with in-context learning approaches, which circumvent model tuning by attaching training examples to the instruction prefix, Cappy is not restricted by the LLM's maximum input length. Thus, Cappy can incorporate an unlimited number of downstream training examples. Cappy can also be applied with other adaptation methods, such as fine-tuning and in-context learning, further boosting their overall performance.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhf1zSOPmuCuPWHPOXYRczk86xESIKANYJN7jUqjkoSQabuQrDyLEfyLCXG0eAEHG1xiYL6jrZ8iMC14a2FhQs7XNwyncRdCyfIRa3KlLx3786yfSXfP9pEwtUEJ6ax7l5J8MchxjH9cV_hKqQFanTh3kNCs_JHYw0vsMOFi09-69-anFrqJShRgYFcKvfe/s1999/Cappy%20downstream%20adaptation.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1040" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhf1zSOPmuCuPWHPOXYRczk86xESIKANYJN7jUqjkoSQabuQrDyLEfyLCXG0eAEHG1xiYL6jrZ8iMC14a2FhQs7XNwyncRdCyfIRa3KlLx3786yfSXfP9pEwtUEJ6ax7l5J8MchxjH9cV_hKqQFanTh3kNCs_JHYw0vsMOFi09-69-anFrqJShRgYFcKvfe/s16000/Cappy%20downstream%20adaptation.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Downstream adaptation comparison between Cappy and approaches that rely on an LLM’s parameters, such as fine-tuning and prompt tuning. Cappy’s application enhances multi-task LLMs.</td></tr></tbody></table>
<br />
<div style="line-height: 40%;">
<br />
</div>
<h2>Results</h2>
<p>
We assess Cappy’s performance across eleven held-out language understanding classification tasks from <a href="https://arxiv.org/abs/2202.01279">PromptSource</a>. We demonstrate that Cappy, with 360M parameters, outperforms OPT-175B and OPT-IML-30B, and matches the accuracy of the best existing multi-task LLMs (T0-11B and OPT-IML-175B). These findings highlight Cappy’s capabilities and parameter efficiency, which can be credited to its scoring-based pre-training strategy that integrates contrastive information by differentiating between high-quality and low-quality responses. In contrast, previous multi-task LLMs depend exclusively on <a href="https://en.wikipedia.org/wiki/Teacher_forcing">teacher-forcing training</a> that utilizes only the ground truth responses.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjyehdahD05Plit772klEfeGTN1GteCZcwwsyWbGgOTtgH4VD3hzPkF8PSDdYZe2EOE0nwL9xdNLZYLzvBJrm9ECTSGIWWUJ-Xo-1uVQUmN8uu0_5dLAERYPvOFfahf1ZZ2bId0tna1ch8BBXV9xKWpPKNIoAlihdNxZvlegShjI6Fjd5Twd8kv6w-axtUW/s1999/Cappy%20accuracy.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1125" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjyehdahD05Plit772klEfeGTN1GteCZcwwsyWbGgOTtgH4VD3hzPkF8PSDdYZe2EOE0nwL9xdNLZYLzvBJrm9ECTSGIWWUJ-Xo-1uVQUmN8uu0_5dLAERYPvOFfahf1ZZ2bId0tna1ch8BBXV9xKWpPKNIoAlihdNxZvlegShjI6Fjd5Twd8kv6w-axtUW/s16000/Cappy%20accuracy.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">The overall accuracy averaged over eleven test tasks from PromptSource. “RM” refers to a <a href="https://huggingface.co/OpenAssistant/reward-model-deberta-v3-large-v2">pre-trained RLHF reward model</a>. Cappy matches the best ones among existing multi-task LLMs.</td></tr></tbody></table>
<p>
We also examine the adaptation of multi-task LLMs with Cappy on complex tasks from <a href="https://arxiv.org/abs/2206.04615">BIG-Bench</a>, a set of manually curated tasks that are considered beyond the capability of many LLMs. We focus on all the 45 generation BIG-Bench tasks, specifically those that do not offer pre-established answer choices. We evaluate the performance using the Rouge-L score (representing the overall similarity between model generations and corresponding ground truths) on every test set, reporting the average score across 45 tests. In this experiment, all variants of FLAN-T5 serve as the backbone LLMs, and the foundational FLAN-T5 models are frozen. These results, shown below, suggest that Cappy enhances the performance of FLAN-T5 models by a large margin, consistently outperforming the most effective baseline achieved through sample selection using self-scoring of the LLM itself.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhUmqWX5mq_zgs3nK6TSR3sNEunAburBnwxpIaFNxTuXbhLKeuI-c71IBxZw3tEnnnOHeE7heImqnZyluCAV92_2fhhXEfus_4R0MC78e_WOOXcSNvfyiVLNqNGhYK88YfiT__Ijss-OPpCo4XDz4vLFjtJKM-Mko_n2IgMabNI5J1a3LAVlIvBvRpiZ8GZ/s1999/Cappy%20averaged%20Rouge-L%20score.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1625" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhUmqWX5mq_zgs3nK6TSR3sNEunAburBnwxpIaFNxTuXbhLKeuI-c71IBxZw3tEnnnOHeE7heImqnZyluCAV92_2fhhXEfus_4R0MC78e_WOOXcSNvfyiVLNqNGhYK88YfiT__Ijss-OPpCo4XDz4vLFjtJKM-Mko_n2IgMabNI5J1a3LAVlIvBvRpiZ8GZ/s16000/Cappy%20averaged%20Rouge-L%20score.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">The averaged Rouge-L score over 45 complex tasks within BIG-Bench. The x-axis refers to FLAN-T5 models of different sizes. Every dashed line represents an approach working on FLAN-T5s. Self-scoring refers to using the cross-entropy of LLM to select responses. Cappy enhances the performance of FLAN-T5 models by a large margin.</td></tr></tbody></table>
<br />
<div style="line-height: 40%;">
<br />
</div>
<h2>Conclusion</h2>
<p>
We introduce Cappy, a novel approach that enhances the performance and efficiency of multi-task LLMs. In our experiments, we adapt a single LLM to several domains with Cappy. In the future, Cappy as a pre-trained model can potentially be used in other creative ways beyond single LLMs.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgments</h2>
<p>
<em>Thanks to Bowen Tan, Jindong Chen, Lei Meng, Abhanshu Sharma and Ewa Dominowska for their valuable feedback. We would also like to thank Eric Xing and Zhiting Hu for their suggestions. </em>
</p>
<h1>Talk like a graph: Encoding graphs for large language models</h1>
<span class="byline-author">Posted by Bahare Fatemi and Bryan Perozzi, Research Scientists, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg8L7r_SCzFsKsWegtrn8_EOoO2imefs-V_GVHzbM0Xw7GmAxoXIIX0RtpJ2JvloeenxcKCNmhCH_VXRMpu8b5dJP39UkhMJS0wP86TUftZtUi-hfj6tZdVEn30MZAeQEx762q1vN-q4DWP2EdOBIHy_CgNFMcliaJYnzxZHjnuifbVWy52zlls20m4BkyJ/s1600/Screenshot%202024-03-12%20at%202.18.27%E2%80%AFPM.png" style="display: none;" />
<p>
Imagine all the things around you — your friends, tools in your kitchen, or even the parts of your bike. They are all connected in different ways. In computer science, the term <em><a href="https://en.wikipedia.org/wiki/Graph_(discrete_mathematics)">graph</a> </em>is used to describe connections between objects. Graphs consist of nodes (the objects themselves) and edges (connections between two nodes, indicating a relationship between them). Graphs are everywhere now. The internet itself is a giant graph of websites linked together. Even the knowledge search engines use is organized in a graph-like way.
</p>
<p>
Furthermore, consider the remarkable advancements in artificial intelligence — such as chatbots that can write stories in seconds, and even software that can interpret medical reports. This exciting progress is largely thanks to large language models (LLMs). New LLM technology is constantly being developed for different uses.
</p>
<p>
Since graphs are everywhere and LLM technology is on the rise, in “<a href="https://openreview.net/forum?id=IuXR1CCrSi">Talk like a Graph: Encoding Graphs for Large Language Models</a>”, presented at <a href="https://iclr.cc/">ICLR 2024</a>, we present a way to teach powerful LLMs how to better reason with graph information. Graphs are a useful way to organize information, but LLMs are mostly trained on regular text. The objective is to test different techniques to see what works best and gain practical insights. Translating graphs into text that LLMs can understand is a remarkably complex task. The difficulty stems from the inherent complexity of graph structures with multiple nodes and the intricate web of edges that connect them. Our work studies how to take a graph and translate it into a format that an LLM can understand. We also design a benchmark called <em><a href="https://github.com/google-research/google-research/tree/master/graphqa">GraphQA</a></em> to study different approaches on different graph reasoning problems and show how to <em>phrase</em> a graph-related problem in a way that enables the LLM to solve the graph problem. We show that LLM performance on graph reasoning tasks varies on three fundamental levels: 1) the graph encoding method, 2) the nature of the graph task itself, and 3) interestingly, the very structure of the graph considered. These findings give us clues on how to best represent graphs for LLMs. Picking the right method can make the LLM up to 60% better at graph tasks!
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjAnieWluvqGQtyh_L3a_Y7XfYUR2dBRGpQf58DpzJfIrkyM2JnwxiCOvTzDidvP-GtbtRe4NsJUEFlzpW8nQbf8WGQD6P_C2jjsRZeLiyDSO8QF8IiGCRYnSa4MxruywJt60gU8KrH6w87ZoBXsGbPmyWDx01j1nqSCaEtfFeNTmAWSLcVVcND8XuzoaHb/s1600/image7.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="400" data-original-width="1600" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjAnieWluvqGQtyh_L3a_Y7XfYUR2dBRGpQf58DpzJfIrkyM2JnwxiCOvTzDidvP-GtbtRe4NsJUEFlzpW8nQbf8WGQD6P_C2jjsRZeLiyDSO8QF8IiGCRYnSa4MxruywJt60gU8KrH6w87ZoBXsGbPmyWDx01j1nqSCaEtfFeNTmAWSLcVVcND8XuzoaHb/s16000/image7.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Pictured, the process of encoding a graph as text using two different approaches and feeding the text and a question about the graph to the LLM.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>Graphs as text</h2>
<p>
To systematically find the best way to translate a graph to text, we first design a benchmark called <em><a href="https://github.com/google-research/google-research/tree/master/graphqa">GraphQA</a></em>. Think of GraphQA as an exam designed to evaluate powerful LLMs on graph-specific problems. We want to see how well LLMs can understand and solve problems that involve graphs in different setups. To create a comprehensive and realistic exam for LLMs, we don’t use just one type of graph; we use a mix of graphs that ensures breadth in the number of connections. This matters because different graph types make solving such problems easier or harder. This way, GraphQA can help expose biases in how an LLM thinks about graphs, and the whole exam gets closer to a realistic setup that LLMs might encounter in the real world.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlHJKqbwgQJFMK4siQJH_Ggag9B8lStCQ4CcXk8iPnNPgxGPLYl_LTrIfjxuP7vKKtzJITlltZ5pcq7RElYNVQJ8PKi9Sr3ctigYfLs6SBlMAEhDHP2nV2PJ-uLhJxUkZ3MdAGV7R8rjw0u6Y8QTCwrMTyqz7tuxzb3TnIFabf4ZZbsSQ95MSboOA42i4w/s1368/image6.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="291" data-original-width="1368" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlHJKqbwgQJFMK4siQJH_Ggag9B8lStCQ4CcXk8iPnNPgxGPLYl_LTrIfjxuP7vKKtzJITlltZ5pcq7RElYNVQJ8PKi9Sr3ctigYfLs6SBlMAEhDHP2nV2PJ-uLhJxUkZ3MdAGV7R8rjw0u6Y8QTCwrMTyqz7tuxzb3TnIFabf4ZZbsSQ95MSboOA42i4w/s16000/image6.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Overview of our framework for reasoning with graphs using LLMs.</td></tr></tbody></table>
<p>
GraphQA focuses on simple tasks related to graphs, like checking if an edge exists, calculating the number of nodes or edges, finding nodes that are connected to a specific node, and checking for cycles in a graph. These tasks might seem basic, but they require understanding the relationships between nodes and edges. By covering different types of challenges, from identifying patterns to creating new connections, GraphQA helps models learn how to analyze graphs effectively. These basic tasks are crucial for more complex reasoning on graphs, like finding the shortest path between nodes, detecting communities, or identifying influential nodes. Additionally, GraphQA includes generating random graphs using various algorithms like <a href="https://en.wikipedia.org/wiki/Erd%C5%91s%E2%80%93R%C3%A9nyi_model">Erdős–Rényi</a>, <a href="https://en.wikipedia.org/wiki/Scale-free_network">scale-free networks</a>, <a href="https://en.wikipedia.org/wiki/Barab%C3%A1si%E2%80%93Albert_model">Barabási–Albert model</a>, and <a href="https://en.wikipedia.org/wiki/Stochastic_block_model">stochastic block model</a>, as well as simpler graph structures like paths, complete graphs, and star graphs, providing a diverse set of data for training.
</p>
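<p>
For readers who want a concrete picture of what such a benchmark can look like, the sketch below uses the open-source <code>networkx</code> library to generate graphs with several of the generator families mentioned above and to compute ground-truth answers for a few basic GraphQA-style tasks. The generator parameters and the exact task set here are illustrative placeholders, not the configuration used in GraphQA.
</p>
<pre>
import networkx as nx

# Illustrative graph generators; parameters are placeholder choices.
GENERATORS = {
    "er": lambda: nx.erdos_renyi_graph(n=10, p=0.3, seed=0),
    "ba": lambda: nx.barabasi_albert_graph(n=10, m=2, seed=0),
    "sbm": lambda: nx.stochastic_block_model(
        [5, 5], [[0.6, 0.1], [0.1, 0.6]], seed=0),
    "path": lambda: nx.path_graph(10),
    "star": lambda: nx.star_graph(9),
    "complete": lambda: nx.complete_graph(10),
}

def ground_truth(graph):
    """Answers for a few basic graph reasoning tasks on an undirected graph."""
    return {
        "node_count": graph.number_of_nodes(),
        "edge_count": graph.number_of_edges(),
        "edge_exists_0_1": graph.has_edge(0, 1),
        "neighbors_of_0": sorted(graph.neighbors(0)),
        # An undirected graph contains a cycle iff its cycle basis is non-empty.
        "has_cycle": bool(nx.cycle_basis(graph)),
    }

for name, make_graph in GENERATORS.items():
    print(name, ground_truth(make_graph()))
</pre>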
<p>
When working with graphs, we also need to find ways to ask graph-related questions that LLMs can understand. <em>Prompting heuristics</em> are different strategies for doing this. Let's break down the common ones:
</p>
<ul>
<li><em>Zero-shot</em>: simply describe the task ("Is there a cycle in this graph?") and tell the LLM to go for it. No examples provided.
</li><li><em>Few-shot</em>: This is like giving the LLM a mini practice test before the real deal. We provide a few example graph questions and their correct answers.
</li><li><em>Chain-of-Thought</em>: Here, we show the LLM how to break down a problem step-by-step with examples. The goal is to teach it to generate its own "thought process" when faced with new graphs.
</li><li><em>Zero-CoT</em>: Similar to CoT, but instead of training examples, we give the LLM a simple prompt, like "Let's think step-by-step," to trigger its own problem-solving breakdown.
</li><li><em>BAG (build a graph)</em>: This is specifically for graph tasks. We add the phrase "Let's build a graph..." to the description, helping the LLM focus on the graph structure.
</li>
</ul>
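<p>
As a rough sketch of how these heuristics differ in practice, the function below assembles a prompt string for a graph question under each strategy. The trigger phrases and the structure of the examples are simplified placeholders rather than the exact prompts evaluated in the paper.
</p>
<pre>
def build_prompt(graph_text, question, style, examples=()):
    """Assemble a prompt for a graph question using one prompting heuristic.

    `examples` holds (graph_text, question, answer) triples used for the
    few-shot and chain-of-thought styles; for chain-of-thought, the answer
    field would contain a worked-out reasoning trace, not just a label.
    """
    parts = []
    if style in ("few_shot", "cot"):
        for ex_graph, ex_question, ex_answer in examples:
            parts.extend([ex_graph, "Q: " + ex_question, "A: " + ex_answer])
    parts.append(graph_text)
    if style == "bag":
        parts.append("Let's build a graph with the nodes and edges first.")
    parts.append("Q: " + question)
    # Zero-CoT appends a generic "think step by step" trigger instead of examples.
    parts.append("A: Let's think step by step." if style == "zero_cot" else "A:")
    return "\n".join(parts)

print(build_prompt(
    "G describes a friendship graph: David and Sara are friends. "
    "Sara and John are friends.",
    "Is there a cycle in this graph?",
    style="zero_shot"))
</pre>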
<p>
We explored different ways to translate graphs into text that LLMs can work with. Our key questions were:
</p>
<ul>
<li><em>Node encoding</em>: How do we represent individual nodes? Options tested include simple <a href="https://en.wikipedia.org/wiki/Integer">integers</a>, common names (people, characters), and letters.
</li><li><em>Edge encoding</em>: How do we describe the relationships between nodes? Methods involved parenthesis notation, phrases like "are friends", and symbolic representations like arrows.
</li>
</ul>
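<p>
The sketch below shows, for a small random graph, two illustrative encoders in the spirit of the functions pictured in the next figure: one that names nodes with integers and lists edges in parenthesis notation, and one that names nodes after people and describes edges with an "are friends" phrase. These are stand-ins written for illustration, not the exact encoding functions from the paper.
</p>
<pre>
import networkx as nx

PEOPLE = ["David", "Sara", "John", "Emma", "Liam", "Noah", "Olivia", "Ava"]

def encode_integer_parenthesis(graph):
    """Nodes as integers, edges as (i, j) pairs."""
    nodes = ", ".join(str(n) for n in graph.nodes())
    edges = ", ".join("({}, {})".format(u, v) for u, v in graph.edges())
    return ("G describes a graph among nodes {}.\n"
            "The edges in G are: {}.".format(nodes, edges))

def encode_friendship(graph):
    """Nodes as person names, edges as 'are friends' sentences."""
    names = {node: PEOPLE[i] for i, node in enumerate(graph.nodes())}
    sentences = " ".join("{} and {} are friends.".format(names[u], names[v])
                         for u, v in graph.edges())
    return "G describes a friendship graph among {}.\n{}".format(
        ", ".join(names.values()), sentences)

g = nx.erdos_renyi_graph(n=5, p=0.4, seed=0)
print(encode_integer_parenthesis(g))
print(encode_friendship(g))
</pre>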
<p>
Various node and edge encodings were combined systematically. This led to functions like the ones in the following figure:
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhHqSBznT1daPxVFMf0ZOR0uiZpjYrTG46t71FWy4tq5IMh-Ijhbzp_toJVmvp72FGrtoQXFkhCaaDVkhCzQXzcfRUPvW7151j22mmVxejpNJdO6VcvdHOkmEye_1zEBtfvAVgSw6RPFOiCpdo9LnetLvgrS-OL7IZPRLpBaCWGny_mzk6wpZcHDY-oS1ts/s855/image1.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="404" data-original-width="855" height="302" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhHqSBznT1daPxVFMf0ZOR0uiZpjYrTG46t71FWy4tq5IMh-Ijhbzp_toJVmvp72FGrtoQXFkhCaaDVkhCzQXzcfRUPvW7151j22mmVxejpNJdO6VcvdHOkmEye_1zEBtfvAVgSw6RPFOiCpdo9LnetLvgrS-OL7IZPRLpBaCWGny_mzk6wpZcHDY-oS1ts/w640-h302/image1.png" width="640" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Examples of graph encoding functions used to encode graphs via text.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>Analysis and results</h2>
<p>
We carried out three key experiments: one to test how LLMs handle graph tasks, and two to understand how the size of the LLM and different graph shapes affected performance. We ran all our experiments on GraphQA.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>How LLMs handle graph tasks </h3>
<p>
In this experiment, we tested how well pre-trained LLMs tackle graph problems like identifying connections, cycles, and node degrees. Here is what we learned:
</p>
<ul>
<li><em>LLMs struggle:</em> On most of these basic tasks, LLMs did not do much better than a random guess.
</li><li><em>Encoding matters significantly</em>: How we represent the graph as text has a major effect on LLM performance. The "incident" encoding excelled for most tasks.
</li>
</ul>
<p>
Our results are summarized in the following chart.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhYJLoJxI1twg6uV55JaKbVVnhO-dhgcaSN_B-FK9MTT8kKI1k_xnbGCvaEpmr82U4OGQxJ-oGNYOa0izo3jD1Ssvz8BVaKgw5ObjwN6_zS54BOALM_aO6TbLf-7SfcokAqRRC9fUbdErDeuadKBuRq7ihEootiLodZoYLKtZVDAgTI1ZrxviY7SI1PcFm5/s1864/image8.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1152" data-original-width="1864" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhYJLoJxI1twg6uV55JaKbVVnhO-dhgcaSN_B-FK9MTT8kKI1k_xnbGCvaEpmr82U4OGQxJ-oGNYOa0izo3jD1Ssvz8BVaKgw5ObjwN6_zS54BOALM_aO6TbLf-7SfcokAqRRC9fUbdErDeuadKBuRq7ihEootiLodZoYLKtZVDAgTI1ZrxviY7SI1PcFm5/s16000/image8.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Comparison of various graph encoder functions based on their accuracy on different graph tasks. The main conclusion from this figure is that the graph encoding functions matter significantly.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h3>Bigger is (usually) better </h3>
<p>
In this experiment, we wanted to see if the size of the LLM (in terms of the number of parameters) affects how well it can handle graph problems. For that, we tested the same graph tasks on the XXS, XS, S, and L sizes of <a href="https://ai.google/static/documents/palm2techreport.pdf">PaLM 2</a>. Here is a summary of our findings:
</p>
<ul>
<li>In general, bigger models did better on graph reasoning tasks. It seems like the extra parameters gave them space to learn more complex patterns.
</li><li>Oddly, size didn't matter as much for the “edge existence” task (finding out if two nodes in a graph are connected).
</li><li>Even the biggest LLM couldn't consistently beat a simple baseline solution on the cycle check problem (finding out if a graph contains a cycle or not). This shows LLMs still have room to improve with certain graph tasks.
</li>
</ul>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiG-zu3s3K3iCIV5k2gakpMwQ_38a08_NrYeO3yITJc64EYiK36sksPulORuZR_BrGdmxZmCWEgIX2sWc42M4f3jpo8v17AddfoORPliE-SefptA4h4gye_g_PBKnufZ9kzTkI0f9MCKwSvuEqfcdgxNiycB2bGUQyUtXx8F7XU4qpXKZGEINZudJxlu-6L/s1227/image3.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="959" data-original-width="1227" height="500" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiG-zu3s3K3iCIV5k2gakpMwQ_38a08_NrYeO3yITJc64EYiK36sksPulORuZR_BrGdmxZmCWEgIX2sWc42M4f3jpo8v17AddfoORPliE-SefptA4h4gye_g_PBKnufZ9kzTkI0f9MCKwSvuEqfcdgxNiycB2bGUQyUtXx8F7XU4qpXKZGEINZudJxlu-6L/w640-h500/image3.png" width="640" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Effect of model capacity on graph reasoning task for PaLM 2-XXS, XS, S, and L.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h3>Do different graph shapes confuse LLMs?</h3>
<p>
We wondered if the "shape" of a graph (how nodes are connected) influences how well LLMs can solve problems on it. Think of the following figure as different examples of graph shapes.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgQ9tU8x8LvDYvwwN9XL4j64tXEq-7fGwnzYvS5zpNcEjk9yjxLH2yYmOAfKwr7_w9dHTUD1xtnI6IMAswp0pyManGDEO1ej1WeH9yByu-5ivtlfU5N-7OWJDtnR1uMeG7oWs1eqyiZFOyUpUa5GddPtECkd4ZvNPSx9rtS8fh83ahArgXtpKtVy7tQES9N/s1400/image4.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="195" data-original-width="1400" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgQ9tU8x8LvDYvwwN9XL4j64tXEq-7fGwnzYvS5zpNcEjk9yjxLH2yYmOAfKwr7_w9dHTUD1xtnI6IMAswp0pyManGDEO1ej1WeH9yByu-5ivtlfU5N-7OWJDtnR1uMeG7oWs1eqyiZFOyUpUa5GddPtECkd4ZvNPSx9rtS8fh83ahArgXtpKtVy7tQES9N/s16000/image4.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Samples of graphs generated with different graph generators from GraphQA. ER, BA, SBM, and SFN refers to <a href="https://en.wikipedia.org/wiki/Erd%C5%91s%E2%80%93R%C3%A9nyi_model">Erdős–Rényi</a>, <a href="https://en.wikipedia.org/wiki/Barab%C3%A1si%E2%80%93Albert_model">Barabási–Albert</a>, <a href="https://en.wikipedia.org/wiki/Stochastic_block_model">Stochastic Block Model</a>, and <a href="https://en.wikipedia.org/wiki/Scale-free_network">Scale-Free Network</a> respectively.</td></tr></tbody></table>
<p>
We found that graph structure has a big impact on LLM performance. For example, in a task asking if a cycle exists, LLMs did great on tightly interconnected graphs (cycles are common there) but struggled on path graphs (where cycles never happen). Interestingly, providing some mixed examples helped them adapt. For instance, for cycle check, we added some examples containing a cycle and some examples with no cycles as few-shot examples in our prompt. Similar patterns occurred with other tasks.
</p>
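<p>
A minimal sketch of how such a balanced few-shot set could be constructed is shown below; it relies on the fact that a path graph never contains a cycle, while closing the path with one extra edge creates exactly one. The graph sizes and the number of examples are arbitrary illustrative choices.
</p>
<pre>
import random
import networkx as nx

def balanced_cycle_examples(num_examples, seed=0):
    """Build few-shot examples, alternating graphs with and without a cycle."""
    rng = random.Random(seed)
    examples = []
    for i in range(num_examples):
        n = rng.randint(4, 8)
        graph = nx.path_graph(n)          # a path graph has no cycle
        if i % 2 == 0:
            graph.add_edge(0, n - 1)      # closing the path creates a cycle
            answer = "Yes, there is a cycle."
        else:
            answer = "No, there is no cycle."
        examples.append((graph, answer))
    rng.shuffle(examples)
    return examples

for graph, answer in balanced_cycle_examples(4):
    print(sorted(graph.edges()), answer)
</pre>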
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgqf5piX3TuSfL0DqpUG7ZkBgMKEqqeCIh0feFG4ddMiaHTFgLY3iPkI4UD3gZpAKeTHgfhItKeXo8P3M4sGSQRZJJsXMAVFutTDuWziSwt1CBvt7kV1VSOSHqGTu0yk7lAym4XYJERrS3FETWbj17agumgHaln1EevI_LyzqAbNFZjYNPZGKjw1fgKBydk/s1864/image5.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1152" data-original-width="1864" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgqf5piX3TuSfL0DqpUG7ZkBgMKEqqeCIh0feFG4ddMiaHTFgLY3iPkI4UD3gZpAKeTHgfhItKeXo8P3M4sGSQRZJJsXMAVFutTDuWziSwt1CBvt7kV1VSOSHqGTu0yk7lAym4XYJERrS3FETWbj17agumgHaln1EevI_LyzqAbNFZjYNPZGKjw1fgKBydk/s16000/image5.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Comparing different graph generators on different graph tasks. The main observation here is that graph structure has a significant impact on the LLM’s performance. ER, BA, SBM, and SFN refers to <a href="https://en.wikipedia.org/wiki/Erd%C5%91s%E2%80%93R%C3%A9nyi_model">Erdős–Rényi</a>, <a href="https://en.wikipedia.org/wiki/Barab%C3%A1si%E2%80%93Albert_model">Barabási–Albert</a>, <a href="https://en.wikipedia.org/wiki/Stochastic_block_model">Stochastic Block Model</a>, and <a href="https://en.wikipedia.org/wiki/Scale-free_network">Scale-Free Network</a> respectively.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>Conclusion</h2>
<p>
In short, we dug deep into how to best represent graphs as text so LLMs can understand them. We found three major factors that make a difference:
</p>
<ul>
<li><em>How to translate the graph to text</em>: How we represent the graph as text significantly influences LLM performance. The incident encoding excelled for most tasks.
</li><li><em>Task type</em>: Certain types of graph questions tend to be harder for LLMs, even with a good translation from graph to text.
</li><li><em>Graph structure</em>: Surprisingly, the "shape" of the graph on which we do inference (densely connected, sparse, etc.) influences how well an LLM does.
</li>
</ul>
<p>
This study revealed key insights about how to prepare graphs for LLMs. The right encoding techniques can significantly boost an LLM's accuracy on graph problems (ranging from around 5% to over 60% improvement). Our new benchmark, GraphQA, will help drive further research in this area.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgements</h2>
<p>
<em>We would like to express our gratitude to our co-author, Jonathan Halcrow, for his valuable contributions to this work. We express our sincere gratitude to Anton Tsitsulin, Dustin Zelle, Silvio Lattanzi, Vahab Mirrokni, and the entire graph mining team at Google Research, for their insightful comments, thorough proofreading, and constructive feedback which greatly enhanced the quality of our work. We would also like to extend special thanks to Tom Small for creating the animation used in this post.</em>
</p>Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-4708403489832809122024-03-11T12:08:00.000-07:002024-03-11T12:13:03.824-07:00Chain-of-table: Evolving tables in the reasoning chain for table understanding<span class="byline-author">Posted by Zilong Wang, Student Researcher, and Chen-Yu Lee, Research Scientist, Cloud AI Team</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg1smBN07qkS32Aop4if0AeINQQea0Grv8dw7GiRFBNoHBlgkkftynVBNjO6BckpF4vq8d0VqC1v0LoeFAVFqOLrBGlqvMNiCMUtIhHxVvsBjbPxvZLcNcD_Sa1sI_bDlqDLWn_C39MbPNm8VUjr2vhTBuaL4qCc1LUB1VH5iM0UVsswIWWq_uQg88YRWmb/s832/Chain-of-Table.png" style="display: none;" />
<p>
People use tables every day to organize and interpret complex information in a structured, easily accessible format. Due to the ubiquity of such tables, reasoning over tabular data has long been a central topic in <a href="https://en.wikipedia.org/wiki/Natural_language_processing">natural language processing</a> (NLP). Researchers in this field have aimed to leverage language models to help users answer questions, verify statements, and analyze data based on tables. However, language models are trained over large amounts of plain text, so the inherently structured nature of tabular data can be difficult for language models to fully comprehend and utilize.
</p>
<a name='more'></a>
<p>
Recently, <a href="https://en.wikipedia.org/wiki/Large_language_model">large language models</a> (LLMs) have achieved outstanding performance across diverse <a href="https://en.wikipedia.org/wiki/Natural-language_understanding">natural language understanding</a> (NLU) tasks by generating reliable reasoning chains, as shown in works like <a href="https://arxiv.org/abs/2201.11903">Chain-of-Thought</a> and <a href="https://arxiv.org/abs/2205.10625">Least-to-Most</a>. However, the most suitable way for LLMs to reason over tabular data remains an open question.
</p>
<p>
In “<a href="https://arxiv.org/abs/2401.04398">Chain-of-Table: Evolving Tables in the Reasoning Chain for Table Understanding</a>”, we propose a framework to tackle table understanding tasks, where we train LLMs to outline their reasoning step by step, updating a given table iteratively to reflect each part of a thought process, akin to how people solve the table-based problems. This enables the LLM to transform the table into simpler and more manageable segments so that it can understand and analyze each part of the table in depth. This approach has yielded significant improvements and achieved new state-of-the-art results on the <a href="https://arxiv.org/abs/1508.00305">WikiTQ</a>, <a href="https://arxiv.org/abs/1909.02164">TabFact</a>, and <a href="https://arxiv.org/abs/2104.00369">FeTaQA</a> benchmarks. The figure below shows the high-level overview of the proposed Chain-of-Table and other methods.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjKT_df1rC8nK-ULOLPjtJ8gFaDHzRi7DX92Ix7OboQhOUNvqh_Melp9SVRWEsgL1Vu6IX9RuMgX7_UIuyeuHr7H0YwJdo6om2M2rX5d9wqOWsXWVAa9o0S75bIt7qG2DiGlhYypk0KKBMSxz2Z8vgmQqxTvy3bVrmH4nSC4Nzv8fZm6mOoA5yEXN_CgC4h/s1478/image2.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="964" data-original-width="1478" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjKT_df1rC8nK-ULOLPjtJ8gFaDHzRi7DX92Ix7OboQhOUNvqh_Melp9SVRWEsgL1Vu6IX9RuMgX7_UIuyeuHr7H0YwJdo6om2M2rX5d9wqOWsXWVAa9o0S75bIt7qG2DiGlhYypk0KKBMSxz2Z8vgmQqxTvy3bVrmH4nSC4Nzv8fZm6mOoA5yEXN_CgC4h/s16000/image2.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Given a complex table where a cyclist’s nationality and name are in the same cell, (a) generic, multi-step reasoning is unable to provide the correct answer (b) program-aided reasoning generates and executes programs (e.g., SQL queries) to deliver the answer, but falls short in accurately addressing the question. In contrast, (c) Chain-of-Table iteratively samples a chain of operations that effectively transform the complex table into a version specifically tailored to the question.</td></tr></tbody></table>
<br />
<h2>Chain-of-Table</h2>
<p>
In Chain-of-Table, we guide LLMs using <a href="https://arxiv.org/abs/2005.14165">in-context learning</a> to iteratively generate operations and to update the table to represent its reasoning chain over tabular data. This enables LLMs to dynamically plan the next operation based on the results of previous ones. This continuous evolution of the table forms a chain, which provides a more structured and clear representation of the reasoning process for a given problem and enables more accurate and reliable predictions from the LLM.
</p>
<p>
For example, when asked, “Which actor has the most NAACP image awards?” the Chain-of-Table framework prompts an LLM to generate tabular operations mirroring tabular reasoning processes. It first identifies the relevant columns. Then, it aggregates rows based on shared content. Finally, it reorders the aggregated results to yield a final table that clearly answers the posed question.
</p>
<p>
These operations transform the table to align with the question presented. To balance performance with computational expense on large tables, we construct the operation chain based on a subset of tabular rows. Meanwhile, the step-by-step operations reveal the underlying reasoning process through the display of intermediate results from the tabular operations, fostering enhanced interpretability and understanding.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj8JwNNHW6SR1PSRTj79oQKqE1K48onxbcM9uwIlacEGnUqtua0jgkXQ-CfyUukJ0qiBhqsKl1_YfeJmcqkMEe5TR08eo9ZEqymWYszwNyKfZjcx0T-wYwEnHqCvdlf9lJAG8UTBN6RZQngH7sv0hQ9szR1wgjyiFSaOIqVHC08bJv6HeaXvWJMHH41wI4_/s1999/image4.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="983" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj8JwNNHW6SR1PSRTj79oQKqE1K48onxbcM9uwIlacEGnUqtua0jgkXQ-CfyUukJ0qiBhqsKl1_YfeJmcqkMEe5TR08eo9ZEqymWYszwNyKfZjcx0T-wYwEnHqCvdlf9lJAG8UTBN6RZQngH7sv0hQ9szR1wgjyiFSaOIqVHC08bJv6HeaXvWJMHH41wI4_/s16000/image4.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Illustration of the tabular reasoning process in Chain-of-Table. This iterative process involves dynamically planning an operation chain and accurately storing intermediate results in the transformed tables. These intermediate tables serve as a tabular thought process that can guide the LLM to land to the correct answer more reliably.</td></tr></tbody></table>
<br />
<p>
Chain-of-Table consists of three main stages. In the first stage, it instructs the LLM to dynamically plan the next operation by in-context learning. Specifically, the prompt involves three components as shown in the following figure:
</p>
<ol>
<li> The question <em>Q</em>: “Which country had the most cyclists finish in the top 3?”
</li><li> The operation history <em>chain</em>: <code>f_add_col(Country)</code> and <code>f_select_row(1, 2, 3)</code>.
</li><li> The latest intermediate table <em>T</em>: the transformed intermediate table.
</li>
</ol>
<p>
By providing the triplet <em>(T, Q, chain)</em> in the prompt, the LLM can observe the previous tabular reasoning process and select the next operation from the operation pool to complete the reasoning chain step by step.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjIBKUfxjF1_KB5gtaj8DRWoqLWQKe_DJXLV6-1sClG1oKutdKujDHyzYgvGlAhQDK235cBoKwNkj7cuA4kLzCt_sltdiyuZSMmEKdEoDS7_XkOFTujyekDI8gJfSLRZkT5yIdGPCVvEVQPoueDgK7dXgyAs04fK3AuwSMurECyNc3ywvzDLAyoNjobg0zk/s1958/image1.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1233" data-original-width="1958" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjIBKUfxjF1_KB5gtaj8DRWoqLWQKe_DJXLV6-1sClG1oKutdKujDHyzYgvGlAhQDK235cBoKwNkj7cuA4kLzCt_sltdiyuZSMmEKdEoDS7_XkOFTujyekDI8gJfSLRZkT5yIdGPCVvEVQPoueDgK7dXgyAs04fK3AuwSMurECyNc3ywvzDLAyoNjobg0zk/s16000/image1.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Illustration of how Chain-of-Table selects the next operation from the operation pool and generates the arguments for the operation.(a) Chain-of-Table samples the next operation from the operation pool. (b) It takes the selected operation as input and generates its arguments.</td></tr></tbody></table>
<br />
<p>
After the next operation <em>f</em> is determined, in the second stage, we need to generate the arguments. As above, Chain-of-Table considers three components in the prompt as shown in the figure: (1) the question, (2) the selected operation and its required arguments, and (3) the latest intermediate table.
</p>
<p>
For instance, when the operation <code>f_group_by</code> is selected, it requires a header name as its argument.
</p>
<p>
The LLM selects a suitable header within the table. Equipped with the selected operation and the generated arguments, Chain-of-Table executes the operation and constructs a new intermediate table for the following reasoning.
</p>
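<p>
To make the argument-generation step concrete, here is a toy sketch of what executing a selected operation might look like when the intermediate table is held as a pandas DataFrame. The table contents and the operation signature are illustrative; they are not the operation implementations from the paper.
</p>
<pre>
import pandas as pd

# A toy intermediate table (columns and values are illustrative).
table = pd.DataFrame({
    "Rank": [1, 2, 3],
    "Cyclist": ["Alejandro Valverde (ESP)",
                "Alexandr Kolobnev (RUS)",
                "Davide Rebellin (ITA)"],
    "Country": ["ESP", "RUS", "ITA"],
})

def f_group_by(df, header):
    """Group rows by a header and count how many rows share each value."""
    return df.groupby(header).size().reset_index(name="Count")

# The planner selected f_group_by; the argument generator picked "Country".
print(f_group_by(table, "Country"))
</pre>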
<p>
Chain-of-Table iterates the previous two stages to plan the next operation and generate the required arguments. During this process, we create an operation chain acting as a proxy for the tabular reasoning steps. These operations generate intermediate tables presenting the results of each step to the LLM. Consequently, the output table contains comprehensive information about the intermediate phases of tabular reasoning. In our final stage, we employ this output table in formulating the final query and prompt the LLM along with the question for the final answer.
</p>
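<p>
Putting the stages together, the following is a schematic sketch of the overall loop, assuming the LLM is wrapped as a simple text-in, text-out callable, the intermediate table is a pandas DataFrame, and the operation pool maps operation names to functions over a table and its arguments. The prompt serialization and the end-of-chain marker are placeholders, not the prompts used in the paper.
</p>
<pre>
def make_prompt(table, question, chain, request):
    """Serialize the triplet (T, Q, chain) plus a request into one prompt."""
    history = ", ".join("{}({})".format(op, args) for op, args in chain) or "(empty)"
    return "Table:\n{}\nQuestion: {}\nOperation chain: {}\n{}".format(
        table.to_string(index=False), question, history, request)

def chain_of_table(table, question, llm, operation_pool, max_steps=5):
    """Schematic Chain-of-Table loop: plan, generate arguments, execute, repeat."""
    chain = []
    for _ in range(max_steps):
        # Stage 1: plan the next operation given (T, Q, chain).
        op_name = llm(make_prompt(table, question, chain, "Next operation:")).strip()
        if op_name == "[E]":  # placeholder end-of-chain marker
            break
        # Stage 2: generate the arguments for the selected operation.
        args = llm(make_prompt(table, question, chain,
                               "Arguments for " + op_name + ":")).strip()
        # Execute the operation to obtain the next intermediate table.
        table = operation_pool[op_name](table, args)
        chain.append((op_name, args))
    # Final stage: query the LLM with the final table for the answer.
    return llm(make_prompt(table, question, chain, "Answer:"))
</pre>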
<br />
<h2>Experimental setup</h2>
<p>
We use <a href="https://ai.google/discover/palm2/">PaLM 2-S</a> and <a href="https://openai.com/blog/gpt-3-5-turbo-fine-tuning-and-api-updates">GPT 3.5</a> as the backbone LLMs and conduct the experiments on three public table understanding benchmarks: <a href="https://arxiv.org/abs/1508.00305">WikiTQ</a>, <a href="https://arxiv.org/abs/1909.02164">TabFact</a>, and <a href="https://arxiv.org/abs/2104.00369">FeTaQA</a>. WikiTQ and FeTaQA are datasets for table-based question answering. TabFact is a table-based fact verification benchmark. In this blogpost, we will focus on the results on WikiTQ and TabFact. We compare Chain-of-Table with the generic reasoning methods (e.g., End-to-End QA, Few-Shot QA, and <a href="https://arxiv.org/abs/2201.11903">Chain-of-Thought</a>) and the program-aided methods (e.g., <a href="https://arxiv.org/abs/2204.00498">Text-to-SQL</a>, <a href="https://arxiv.org/abs/2210.02875">Binder</a>, and <a href="https://arxiv.org/abs/2301.13808">Dater</a>).
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>More accurate answers</h3>
<p>
Compared to the generic reasoning methods and program-aided reasoning methods, Chain-of-Table achieves better performance across <a href="https://ai.google/discover/palm2/">PaLM 2</a> and <a href="https://openai.com/blog/gpt-3-5-turbo-fine-tuning-and-api-updates">GPT 3.5</a>. This is attributed to the dynamically sampled operations and the informative intermediate tables.
</p><table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEglv7DlLdRCDhXh2D8EE8DaOlnyYOBET9usjjD4jQkBMDH_sdWzf72QL6qo8F6wXP6ThhxggSjh-F-z0aah7Qr36ghB3muAAn2k0cjfKV9hBSRaIooRI30qkAbn9nft00DNKG0WjCfVxyNYGD3AciTo282wQDItTceKuDKo03KGTOWvm76HXK2PGgQM8h5o/s1018/ChainOfTableUnderstanding.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="755" data-original-width="1018" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEglv7DlLdRCDhXh2D8EE8DaOlnyYOBET9usjjD4jQkBMDH_sdWzf72QL6qo8F6wXP6ThhxggSjh-F-z0aah7Qr36ghB3muAAn2k0cjfKV9hBSRaIooRI30qkAbn9nft00DNKG0WjCfVxyNYGD3AciTo282wQDItTceKuDKo03KGTOWvm76HXK2PGgQM8h5o/s16000/ChainOfTableUnderstanding.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><span style="text-align: left;">Understanding results on WikiTQ and TabFact with PaLM 2 and GPT 3.5 compared with various models.</span></td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h3>Better robustness on harder questions</h3>
<p>
In Chain-of-Table, longer operation chains indicate the higher difficulty and complexity of the questions and their corresponding tables. We categorize the test samples according to their operation lengths in Chain-of-Table. We compare Chain-of-Table with Chain-of-Thought and Dater, as representative generic and program-aided reasoning methods. We illustrate this using results from <a href="https://ai.google/discover/palm2/">PaLM 2</a> on <a href="https://arxiv.org/abs/1508.00305">WikiTQ</a>.
</p><table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhONWxPX_gDzAJe0m3HLMjtdzFZ_EF_uCEvpxlMdex5KpSeo2iUzAzyETzzPEl8wbbawjtmw5JbVYXWSEjkwq-198INrSZEzXlLIly40_nr65KOcgQA96rC8Pz744FQaWdTfeIFbeBO6uhPD4NmOeU1dYUzXeoPUlNk2vZ4zd4JVB6TNIaEsHJohvlrSna7/s1548/CoTOpChainLength.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="624" data-original-width="1548" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhONWxPX_gDzAJe0m3HLMjtdzFZ_EF_uCEvpxlMdex5KpSeo2iUzAzyETzzPEl8wbbawjtmw5JbVYXWSEjkwq-198INrSZEzXlLIly40_nr65KOcgQA96rC8Pz744FQaWdTfeIFbeBO6uhPD4NmOeU1dYUzXeoPUlNk2vZ4zd4JVB6TNIaEsHJohvlrSna7/s16000/CoTOpChainLength.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Performance of Chain-of-Thought, Dater, and the proposed Chain-of-Table on WikiTQ for questions that require an operation chain of varying lengths. Our proposed atomic operations significantly improve performance over generic and program-aided reasoning counterparts.</td></tr></tbody></table>
<br />
<p>
Notably, Chain-of-Table consistently surpasses both baseline methods across all operation chain lengths, with a significant margin up to 11.6% compared with <a href="https://arxiv.org/abs/2201.11903">Chain-of-Thought</a>, and up to 7.9% compared with <a href="https://arxiv.org/abs/2301.13808">Dater</a>. Moreover, the performance of Chain-of-Table declines gracefully with an increasing number of operations compared to other baseline methods, exhibiting only a minimal decrease when the number of operations increases from four to five.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>Better robustness with larger tables</h3>
<p>
We categorize the tables from <a href="https://arxiv.org/abs/1508.00305">WikiTQ</a> into three groups based on token number: small (<2000 tokens), medium (2000 to 4000 tokens) and large (>4000 tokens). We then compare Chain-of-Table with <a href="https://arxiv.org/abs/2301.13808">Dater</a> and <a href="https://arxiv.org/abs/2210.02875">Binder</a>, the two latest and strongest baselines.
</p><table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg_SgXNNWocZbCKKXAju3cpc4r-cABNL8zsrRmXJYPTiS68R8GM3lkTdxJPXoT3niFVX1bvmL9_QHrozVdl4_vYCamVsaixakttU_-ha88xZhHSbg6M_I4VgG86iynnNwv9ywdcbh5vFtqTKAs2kMmFGZNx85WBM5-RBxI63vvMfau7WbLSkqA7yrOIguY_/s1999/image1.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1008" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg_SgXNNWocZbCKKXAju3cpc4r-cABNL8zsrRmXJYPTiS68R8GM3lkTdxJPXoT3niFVX1bvmL9_QHrozVdl4_vYCamVsaixakttU_-ha88xZhHSbg6M_I4VgG86iynnNwv9ywdcbh5vFtqTKAs2kMmFGZNx85WBM5-RBxI63vvMfau7WbLSkqA7yrOIguY_/s16000/image1.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><span style="text-align: left;">Performance of Binder, Dater, and the proposed Chain-of-Table on small (<2000 tokens), medium (2000 to 4000 tokens), and large (>4000 tokens) tables from WikiTQ. We observe that the performance decreases with larger input tables while Chain-of-Table diminishes gracefully, achieving significant improvements over competing methods. (As above, underlined text denotes the second-best performance; bold denotes the best performance.)</span></td></tr></tbody></table><br />
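<p>
As a small illustration of the grouping used in this comparison, the helper below assigns a serialized table to a size bucket by token count; the tokenizer is assumed to be any callable that returns a list of tokens, and the 2,000/4,000 thresholds follow the grouping described above.
</p>
<pre>
def size_bucket(table_text, tokenizer):
    """Assign a serialized table to the small/medium/large bucket by token count."""
    num_tokens = len(tokenizer(table_text))
    if num_tokens > 4000:
        return "large"
    if num_tokens >= 2000:
        return "medium"
    return "small"

# Example with a trivial whitespace "tokenizer" as a stand-in.
print(size_bucket("Rank Cyclist Country ...", str.split))
</pre>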
<p>
As anticipated, the performance decreases with larger input tables, as models are required to reason through longer contexts. Nevertheless, the performance of the proposed Chain-of-Table diminishes gracefully, achieving a significant 10+% improvement over the second best competing method when dealing with large tables. This demonstrates the efficacy of the reasoning chain in handling long tabular inputs.
</p>
<br />
<h2>Conclusion</h2>
<p>
Our proposed Chain-of-Table method enhances the reasoning capability of LLMs by leveraging the tabular structure to express intermediate steps for table-based reasoning. It instructs LLMs to dynamically plan an operation chain according to the input table and its associated question. This evolving table design sheds new light on how to prompt LLMs for table understanding.
</p>
<br />
<h2>Acknowledgements</h2>
<p>
<em>This research was conducted by Zilong Wang, Hao Zhang, Chun-Liang Li, Julian Martin Eisenschlos, Vincent Perot, Zifeng Wang, Lesly Miculicich, Yasuhisa Fujii, Jingbo Shang, Chen-Yu Lee, Tomas Pfister. Thanks to Chih-Kuan Yeh and Sergey Ioffe for their valuable feedback.</em>
</p>Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-11066243616495723762024-03-08T11:33:00.000-08:002024-03-13T09:18:01.747-07:00Health-specific embedding tools for dermatology and pathology<span class="byline-author">Posted by Dave Steiner, Clinical Research Scientist, Google Health, and Rory Pilgrim, Product Manager, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi9zSpggPrlQvV-c0Lc2Sd79B58CwY0kDPJjgQfh-2SR8kiZuXO9A7LWZQ80zCqDNkYHm_IyNSQXF9xUOS-vPg8eJxkPR6HHuFr2VxoaAiAeG4J4ca6Pl8s9Jx1VX3tjQR0oA3I-oS2WujNwYJ2esmlfcyu1PZp7vh5MawdQc8Iu9aLM4fkAhycOXmumoKp/s16000/Path%20+%20Derm%20hero.jpg" style="display: none;" />
<p>
There’s a worldwide shortage of access to medical imaging expert interpretation across specialties including <a href="https://www.rsna.org/news/2022/may/Global-Radiologist-Shortage">radiology</a>, <a href="https://www.aad.org/dw/monthly/2021/december/feature-running-dry">dermatology</a> and <a href="https://proscia.com/infographic-the-state-of-the-pathology-workforce-2022/">pathology</a>. Machine learning (ML) technology can help ease this burden by powering tools that enable doctors to interpret these images more accurately and efficiently. However, the development and implementation of such ML tools are often limited by the availability of high-quality data, ML expertise, and computational resources.
</p>
<a name='more'></a>
<p>
One way to catalyze the use of ML for medical imaging is via domain-specific models that utilize deep learning (DL) to capture the information in medical images as compressed numerical vectors (called embeddings). These embeddings represent a type of pre-learned understanding of the important features in an image. Identifying patterns in the embeddings reduces the amount of data, expertise, and compute needed to train performant models as compared to <a href="https://en.wikipedia.org/wiki/Curse_of_dimensionality">working with high-dimensional data</a>, such as images, directly. Indeed, these embeddings can be used to perform a variety of downstream tasks within the specialized domain (see animated graphic below). This framework of leveraging pre-learned understanding to solve related tasks is similar to that of a seasoned guitar player quickly learning a new song by ear. Because the guitar player has already built up a foundation of skill and understanding, they can quickly pick up the patterns and groove of a new song.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEglGgLglQSBBcJqiT_SsxQf9AKGyrenZw28xTiqVP9qljNyD8mhpv-m4kl27u4NLm0FGJShNOuK456JIzdQ269xBx3fBi1u2ke10iE4THphEkD9MCCGrHjhrddtAHJ27g3pyznABW3i_CxTNkONPsH-BOcoFgS4A8tscJsJ42eD5XAHJ3FVzkfmltMzUKkq/s1600/Path%20+%20Derm%20train%20LP.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="500" data-original-width="1600" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEglGgLglQSBBcJqiT_SsxQf9AKGyrenZw28xTiqVP9qljNyD8mhpv-m4kl27u4NLm0FGJShNOuK456JIzdQ269xBx3fBi1u2ke10iE4THphEkD9MCCGrHjhrddtAHJ27g3pyznABW3i_CxTNkONPsH-BOcoFgS4A8tscJsJ42eD5XAHJ3FVzkfmltMzUKkq/s16000/Path%20+%20Derm%20train%20LP.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Path Foundation is used to convert a small dataset of (image, label) pairs into (embedding, label) pairs. These pairs can then be used to train a task-specific classifier using a linear probe, (i.e., a lightweight linear classifier) as represented in this graphic, or other types of models using the embeddings as input.</td></tr></tbody></table>
<br />
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhZeZ25Ea3ZXz8hd6YWMkECnI0jCWnsorTZ0Ob97G-94OZfE3vVtq27pAAmZufyRHfRjUVag-ViN2bIchtZ0eCl5mUIHldWQ8e0lEJAQhYy_Ae3JTCh9Sjc2izTny5I1fo5QxxZTzwvvIKzXNNugSpyYVnUplnm54zRNRKf38EhDU4hEcHYuqqbHdlxQyyz/s1600/Path%20+%20Derm%20-%20evaluate%20LP.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="500" data-original-width="1600" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhZeZ25Ea3ZXz8hd6YWMkECnI0jCWnsorTZ0Ob97G-94OZfE3vVtq27pAAmZufyRHfRjUVag-ViN2bIchtZ0eCl5mUIHldWQ8e0lEJAQhYy_Ae3JTCh9Sjc2izTny5I1fo5QxxZTzwvvIKzXNNugSpyYVnUplnm54zRNRKf38EhDU4hEcHYuqqbHdlxQyyz/s16000/Path%20+%20Derm%20-%20evaluate%20LP.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Once the linear probe is trained, it can be used to make predictions on embeddings from new images. These predictions can be compared to ground truth information in order to evaluate the linear probe's performance.</td></tr></tbody></table>
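<p>
For researchers who want to see what the linear-probe workflow illustrated above might look like in code, here is a minimal sketch using scikit-learn. It assumes you have already computed an embedding vector per image and stored it with its label; the file names, shapes, and choice of classifier are placeholders, not part of the released tools.
</p>
<pre>
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

# Hypothetical inputs: one precomputed embedding vector per image, plus binary
# task labels. Shapes: embeddings (num_examples, dim), labels (num_examples,).
train_embeddings = np.load("train_embeddings.npy")
train_labels = np.load("train_labels.npy")
test_embeddings = np.load("test_embeddings.npy")
test_labels = np.load("test_labels.npy")

# A linear probe: a lightweight linear classifier trained on frozen embeddings.
probe = LogisticRegression(max_iter=1000)
probe.fit(train_embeddings, train_labels)

# Compare predictions against ground truth, here with AUROC.
scores = probe.predict_proba(test_embeddings)[:, 1]
print("AUROC:", roc_auc_score(test_labels, scores))
</pre>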
<p>
In order to make this type of embedding model available and drive further development of ML tools in medical imaging, we are excited to release two domain-specific tools for research use: <a href="https://github.com/Google-Health/imaging-research/tree/master/derm-foundation">Derm Foundation</a> and <a href="https://github.com/Google-Health/imaging-research/tree/master/path-foundation">Path Foundation</a>. This follows on the strong response we’ve already received from researchers using the <a href="https://blog.research.google/2022/07/simplified-transfer-learning-for-chest.html">CXR Foundation</a> embedding tool for chest radiographs and represents a portion of our expanding research offerings across multiple medical-specialized modalities. These embedding tools take an image as input and produce a numerical vector (the embedding) that is specialized to the domains of dermatology and digital pathology images, respectively. By running a dataset of chest X-ray, dermatology, or pathology images through the respective embedding tool, researchers can obtain embeddings for their own images, and use these embeddings to quickly develop new models for their applications.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Path Foundation</h2>
<p>
In “<a href="https://arxiv.org/abs/2310.13259">Domain-specific optimization and diverse evaluation of self-supervised models for histopathology</a>”, we showed that self-supervised learning (SSL) models for pathology images outperform traditional pre-training approaches and enable efficient training of classifiers for downstream tasks. This effort focused on <a href="https://en.wikipedia.org/wiki/H%26E_stain">hematoxylin and eosin</a> (H&E) stained slides, the principal tissue stain in diagnostic pathology that enables pathologists to visualize cellular features under a microscope. The performance of linear classifiers trained using the output of the SSL models matched that of prior DL models trained on orders of magnitude more labeled data.
</p>
<p>
Due to substantial differences between digital pathology images and “natural image” photos, this work involved several pathology-specific optimizations during model training. One key element is that <a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7522141/">whole-slide images</a> (WSIs) in pathology can be 100,000 pixels across (thousands of times larger than typical smartphone photos) and are analyzed by experts at multiple magnifications (zoom levels). As such, the WSIs are typically broken down into smaller tiles or patches for computer vision and DL applications. The resulting images are information dense with cells or tissue structures distributed throughout the frame instead of having distinct semantic objects or foreground vs. background variations, thus creating unique challenges for robust SSL and feature extraction. Additionally, physical (e.g., <a href="https://en.wikipedia.org/wiki/Microtome">cutting</a>) and chemical (e.g., <a href="https://en.wikipedia.org/wiki/Fixation_(histology)">fixing</a> and <a href="https://en.wikipedia.org/wiki/Staining">staining</a>) processes used to prepare the samples can influence image appearance dramatically.
</p>
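<p>
As a simple illustration of the tiling step (and only that step), the sketch below cuts a large image array into fixed-size patches at a single magnification. Real WSI pipelines additionally handle multiple magnifications, background filtering, and stain variation; the patch size here is an arbitrary placeholder.
</p>
<pre>
import numpy as np

def extract_patches(image, patch_size=224, stride=224):
    """Cut an H x W x 3 image array into non-overlapping square patches."""
    height, width = image.shape[0], image.shape[1]
    patches = []
    for top in range(0, height - patch_size + 1, stride):
        for left in range(0, width - patch_size + 1, stride):
            patches.append(image[top:top + patch_size, left:left + patch_size])
    return patches

# Toy example: a small random "slide region" instead of a real whole-slide image.
region = np.random.randint(0, 256, size=(1024, 1024, 3), dtype=np.uint8)
print(len(extract_patches(region)), "patches")
</pre>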
<p>
Taking these important aspects into consideration, pathology-specific SSL optimizations included helping the model learn <a href="https://arxiv.org/abs/2206.12694">stain-agnostic features</a>, generalizing the model to patches from multiple magnifications, <a href="https://blog.research.google/2020/02/generating-diverse-synthetic-medical.html">augmenting</a> the data to mimic scanning and image post processing, and custom data balancing to improve input heterogeneity for SSL training. These approaches were extensively evaluated using a broad set of benchmark tasks involving 17 different tissue types over 12 different tasks.
</p>
<p>
Utilizing the vision transformer (<a href="https://github.com/google-research/vision_transformer">ViT-S/16</a>) architecture, Path Foundation was selected as the best performing model from the optimization and evaluation process described above (and illustrated in the figure below). This model thus provides an important balance between performance and model size to enable valuable and scalable use in generating embeddings over the many individual image patches of large pathology WSIs.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhG4jlO0GRCgYA3fe6CteF9PYvm3joBGIBPXakdWWaQ7ztTTBK36dmrtRpK1xoNVub8MTMvmCzkW0wfCCkYUH3fnvKk8hJb79o4vETQq0MhqS1JDBxWgYUwFkjtpnkgx5jBiDOxwovsfgqvpNzVGpz6CY6nTJzJgSgtuE2qDRzIb9O7fbHrhdNU1-IWPSXp/s1999/Path%20+%20Derm%20SSL.jpg" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1097" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhG4jlO0GRCgYA3fe6CteF9PYvm3joBGIBPXakdWWaQ7ztTTBK36dmrtRpK1xoNVub8MTMvmCzkW0wfCCkYUH3fnvKk8hJb79o4vETQq0MhqS1JDBxWgYUwFkjtpnkgx5jBiDOxwovsfgqvpNzVGpz6CY6nTJzJgSgtuE2qDRzIb9O7fbHrhdNU1-IWPSXp/s16000/Path%20+%20Derm%20SSL.jpg" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">SSL training with pathology-specific optimizations for Path Foundation.</td></tr></tbody></table>
<p>
The value of domain-specific image representations can also be seen in the figure below, which shows the linear probing performance improvement of Path Foundation (as measured by <a href="https://en.wikipedia.org/wiki/Receiver_operating_characteristic">AUROC</a>) compared to traditional pre-training on natural images (<a href="https://arxiv.org/abs/2104.10972">ImageNet-21k</a>). This includes evaluation for tasks such as <a href="https://jamanetwork.com/journals/jama/fullarticle/2665774">metastatic breast cancer detection in lymph nodes</a>, <a href="https://jamanetwork.com/journals/jamaoncology/fullarticle/2768225">prostate cancer grading</a>, and <a href="https://www.nature.com/articles/s41523-022-00478-y">breast cancer grading</a>, among others.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjtMvTwce8mL0GYA3YTZP0Xc7ub_BYOHIvd9k4FAfnbd-XhpVFU3T9wAl7adebAGVYSWv0RraeV_NHj-0ZiVKQ94wUM9D6GzLSg-FU9ad_L5wN4lksjbWMhN_53FhuY0yGcFvYBU8AgTY7UJKm8z9vz-rH7wkr_m5TOY8gFjWh3YkxHcPMr1wLAkS4hnGkJ/s1999/Path%20+%20Derm%20embeddings.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="890" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjtMvTwce8mL0GYA3YTZP0Xc7ub_BYOHIvd9k4FAfnbd-XhpVFU3T9wAl7adebAGVYSWv0RraeV_NHj-0ZiVKQ94wUM9D6GzLSg-FU9ad_L5wN4lksjbWMhN_53FhuY0yGcFvYBU8AgTY7UJKm8z9vz-rH7wkr_m5TOY8gFjWh3YkxHcPMr1wLAkS4hnGkJ/s16000/Path%20+%20Derm%20embeddings.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Path Foundation embeddings significantly outperform traditional ImageNet embeddings as evaluated by linear probing across multiple evaluation tasks in histopathology.</td></tr></tbody></table>
<br />
<div style="line-height: 40%;">
<br />
</div>
<h2>Derm Foundation</h2>
<p>
<a href="https://github.com/Google-Health/imaging-research/tree/master/derm-foundation">Derm Foundation</a> is an embedding tool derived from our research in applying DL to <a href="https://blog.research.google/2019/09/using-deep-learning-to-inform.html">interpret images of dermatology conditions</a> and includes our recent work that adds <a href="https://arxiv.org/abs/2402.15566">improvements to generalize better to new datasets</a>. Due to its dermatology-specific pre-training it has a latent understanding of features present in images of skin conditions and can be used to quickly develop models to classify skin conditions. The model underlying the API is a <a href="https://github.com/google-research/big_transfer">BiT ResNet-101x3</a> trained in two stages. The first pre-training stage uses contrastive learning, similar to <a href="https://arxiv.org/abs/2010.00747">ConVIRT</a>, to train on a large number of image-text pairs <a href="https://blog.research.google/2017/07/revisiting-unreasonable-effectiveness.html">from the internet</a>. In the second stage, the image component of this pre-trained model is then fine-tuned for condition classification using clinical datasets, such as those from teledermatology services.
</p>
<p>
Unlike histopathology images, dermatology images more closely resemble the real-world images used to train many of today's computer vision models. However, for specialized dermatology tasks, creating a high-quality model may still require a large dataset. With Derm Foundation, researchers can use their own smaller dataset to retrieve domain-specific embeddings, and use those to build smaller models (e.g., linear classifiers or other small non-linear models) that enable them to validate their research or product ideas. To evaluate this approach, we trained models on a downstream task using teledermatology data. Model training involved varying dataset sizes (12.5%, 25%, 50%, 100%) to compare embedding-based linear classifiers against fine-tuning.
</p>
<p>
The modeling variants considered were:
</p>
<ul>
<li>A linear classifier on frozen embeddings from <a href="https://github.com/google-research/big_transfer">BiT-M</a> (a standard pre-trained image model)
</li><li>Fine-tuned version of BiT-M with an extra dense layer for the downstream task
</li><li>A linear classifier on frozen embeddings from the Derm Foundation API
</li><li>Fine-tuned version of the model underlying the Derm Foundation API with an extra layer for the downstream task
</li>
</ul>
<p>
We found that models built on top of the Derm Foundation embeddings for dermatology-related tasks achieved significantly higher quality than those built on BiT-M embeddings or fine-tuned from BiT-M. This advantage was found to be most pronounced for smaller training dataset sizes.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi3cFSDBqVdsZm4MaFhMXli6kEJazYEB4xEYPB6ebOPv24HPd57Puw1zfu85raJ0gqfpnwsLW99Wh6aShuoCKZNYLw1PiG7eIqUEm8nMvwTy2qQTNL8ptn7cqBll127x_iEIsDMjznY5pWRIYF89cvBP3uPiVfMTgJS8aQpXiOC3oCO1Xl8CxTc4LXrLnjY/s1240/Path%20+%20Derm%20task%20accuracy.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="842" data-original-width="1240" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi3cFSDBqVdsZm4MaFhMXli6kEJazYEB4xEYPB6ebOPv24HPd57Puw1zfu85raJ0gqfpnwsLW99Wh6aShuoCKZNYLw1PiG7eIqUEm8nMvwTy2qQTNL8ptn7cqBll127x_iEIsDMjznY5pWRIYF89cvBP3uPiVfMTgJS8aQpXiOC3oCO1Xl8CxTc4LXrLnjY/s16000/Path%20+%20Derm%20task%20accuracy.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">These results demonstrate that the Derm Foundation tooI can serve as a useful starting point to accelerate skin-related modeling tasks. We aim to enable other researchers to build on the underlying features and representations of dermatology that the model has learned. </td></tr></tbody></table>
<p>
However, there are limitations with this analysis. We're still exploring how well these embeddings generalize across task types, patient populations, and image settings. Downstream models built using Derm Foundation still require careful evaluation to understand their expected performance in the intended setting.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Access Path and Derm Foundation</h2>
<p>
We envision that the Derm Foundation and Path Foundation embedding tools will enable a range of use cases, including efficient development of models for diagnostic tasks, quality assurance and pre-analytical workflow improvements, image indexing and curation, and biomarker discovery and validation. We are releasing both tools to the research community so they can explore the utility of the embeddings for their own dermatology and pathology data.
</p>
<p>
To get access, please sign up to each tool's terms of service using the following Google Forms.
</p>
<ul>
<li><a href="https://docs.google.com/forms/d/e/1FAIpQLSe5icNBzU_lO2CwjLLIOwbqIcWnJC-m4Sl7MgvI9Lng3QT6Zg/viewform?resourcekey=0-dahJtiVe2CqYkNEdWPcXgw">Derm Foundation Access Form</a>
</li><li><a href="https://docs.google.com/forms/d/1auyo2VkzlzuiAXavZy1AWUyQHAqO7T3BLK-7ofKUvug/edit?resourcekey=0-Z9pRxjDI-kaDEUIiNfMAWQ#question=1168037695&field=173852432">Path Foundation Access Form</a>
</li>
</ul>
<p>
After gaining access to each tool, you can use the API to retrieve embeddings from dermatology images or digital pathology images stored in Google Cloud. Approved users who are just curious to see the model and embeddings in action can use the provided example Colab notebooks to train models using public data for classifying <a href="https://github.com/Google-Health/imaging-research/blob/master/derm-foundation/derm_foundation_demo.ipynb">six common skin conditions</a> or identifying tumors in <a href="https://github.com/Google-Health/imaging-research/blob/master/path-foundation/linear-classifier-demo.ipynb">histopathology patches</a>. We look forward to seeing the range of use-cases these tools can unlock.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgements</h2>
<p>
<em>We would like to thank the many collaborators who helped make this work possible including Yun Liu, Can Kirmizi, Fereshteh Mahvar, Bram Sterling, Arman Tajback, Kenneth Philbrik, Arnav Agharwal, Aurora Cheung, Andrew Sellergren, Boris Babenko, Basil Mustafa, Jan Freyberg, Terry Spitz, Yuan Liu, Pinal Bavishi, Ayush Jain, Amit Talreja, Rajeev Rikhye, Abbi Ward, Jeremy Lai, Faruk Ahmed, Supriya Vijay, Tiam Jaroensri, Jessica Loo, Saurabh Vyawahare, Saloni Agarwal, Ellery Wulczyn, Jonathan Krause, Fayaz Jamil, Tom Small, Annisah Um'rani, Lauren Winer, Sami Lachgar, Yossi Matias, Greg Corrado, and Dale Webster.</em>
</p>
Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-17653597190684327392024-03-07T10:15:00.000-08:002024-03-07T10:19:31.177-08:00Social learning: Collaborative learning with large language models<span class="byline-author">Posted by Amirkeivan Mohtashami, Research Intern, and Florian Hartmann, Software Engineer, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEicN2GYOp9oUj5x0F20i550WuF5KmpD8iRqrdHmJFU_HmkdFY3RBF4mfn_99q8jtEPVm56a4NfjMGFJ79y3rygqjX46h23tlzSDde7iEbp8ytHsPa5-IsNKFFituSoPmtGk666gjyypTvVhhuin8FahZfhWPyDWqF5yWBIQ-Cf_DxQ7vrmTWIkA_tAJtm4v/s1999/image2.png" style="display: none;" />
<p>
Large language models (LLMs) have significantly improved the state of the art for solving tasks specified using natural language, often reaching performance close to that of people. As these models increasingly enable assistive agents, it could be beneficial for them to learn effectively from each other, much like people do in social settings, which would allow LLM-based agents to improve each other’s performance.
</p>
<a name='more'></a>
<p>
In their work on the learning processes of humans, Bandura and Walters <a href="https://books.google.ch/books/about/Social_Learning_Theory.html?id=IXvuAAAAMAAJ&redir_esc=y">described</a> the concept of <em>social learning</em> in 1977, outlining different models of observational learning used by people. One common method of learning from others is through <em>verbal instruction</em> (e.g., from a teacher) that describes how to engage in a particular behavior. Alternatively, learning can happen through a <em>live model</em> by mimicking a live example of the behavior.
</p>
<p>
Given the success of LLMs mimicking human communication, in our paper “<a href="https://arxiv.org/abs/2312.11441">Social Learning: Towards Collaborative Learning with Large Language Models</a>”, we investigate whether LLMs are able to learn from each other using social learning. To this end, we outline a framework for social learning in which LLMs share knowledge with each other in a privacy-aware manner using natural language. We evaluate the effectiveness of our framework on various datasets, and propose quantitative methods that measure privacy in this setting. In contrast to previous approaches to collaborative learning, such as common <a href="https://blog.research.google/2017/04/federated-learning-collaborative.html">federated learning</a> approaches that often rely on gradients, in our framework, agents teach each other purely using natural language.
</p>
<br />
<h2>Social learning for LLMs</h2>
<p>
To extend social learning to language models, we consider the scenario where a student LLM should learn to solve a task from multiple teacher entities that already know that task. In our paper, we evaluate the student’s performance on a variety of tasks, such as <a href="https://dl.acm.org/doi/10.1145/2034691.2034742">spam detection</a> in short text messages (SMS), solving <a href="https://arxiv.org/abs/2110.14168">grade school math problems</a>, and <a href="https://arxiv.org/abs/1905.10044">answering questions</a> based on a given text.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgAndq_MjAVBs4j3lmxEX71nMrCLpAasklndZyE8F7yj3slyafRsNauzW4yRxI_Ncg7Sp5jllAXpItsjA-BOmdB2O1jP3Awu09-DVRHBE_Urf58yzm5tDBBpM-aibZxmgA9O6CySCCRdSMMqG7vj-OU07jHa0OU0YixCxRB0Q3APMQbn8Vz5rEBp70ZNogH/s900/image3.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="381" data-original-width="900" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgAndq_MjAVBs4j3lmxEX71nMrCLpAasklndZyE8F7yj3slyafRsNauzW4yRxI_Ncg7Sp5jllAXpItsjA-BOmdB2O1jP3Awu09-DVRHBE_Urf58yzm5tDBBpM-aibZxmgA9O6CySCCRdSMMqG7vj-OU07jHa0OU0YixCxRB0Q3APMQbn8Vz5rEBp70ZNogH/s16000/image3.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">A visualization of the social learning process: A teacher model provides instructions or few-shot examples to a student model without sharing its private data.</td></tr></tbody></table>
<p>
Language models have shown a remarkable capacity to perform tasks given only a handful of examples, a process called <a href="https://arxiv.org/abs/2005.14165">few-shot learning</a>. With this in mind, we provide the teacher model with human-labeled examples of a task so that it can teach the task to a student. One of the main use cases of social learning arises when these examples cannot be shared directly with the student, for example due to privacy concerns.
</p>
<p>
To illustrate this, let’s look at a hypothetical example for a spam detection task. A teacher model is located on the devices of users who volunteer to mark incoming messages they receive as either “spam” or “not spam”. This is useful data that could help train a student model to differentiate between spam and not spam, but sharing personal messages with other users is a breach of privacy and should be avoided. To prevent this, a social learning process can transfer the knowledge from the teacher model to the student so it learns what spam messages look like without needing to share the user’s personal text messages.
</p>
<p>
We investigate the effectiveness of this social learning approach by analogy with the established human social learning theory that we discussed above. In these experiments, we use <a href="https://blog.google/technology/ai/google-palm-2-ai-large-language-model/">PaLM 2-S</a> models for both the teacher and the student.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEicN2GYOp9oUj5x0F20i550WuF5KmpD8iRqrdHmJFU_HmkdFY3RBF4mfn_99q8jtEPVm56a4NfjMGFJ79y3rygqjX46h23tlzSDde7iEbp8ytHsPa5-IsNKFFituSoPmtGk666gjyypTvVhhuin8FahZfhWPyDWqF5yWBIQ-Cf_DxQ7vrmTWIkA_tAJtm4v/s1999/image2.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1117" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEicN2GYOp9oUj5x0F20i550WuF5KmpD8iRqrdHmJFU_HmkdFY3RBF4mfn_99q8jtEPVm56a4NfjMGFJ79y3rygqjX46h23tlzSDde7iEbp8ytHsPa5-IsNKFFituSoPmtGk666gjyypTvVhhuin8FahZfhWPyDWqF5yWBIQ-Cf_DxQ7vrmTWIkA_tAJtm4v/s16000/image2.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">A systems view of social learning: At training time, multiple teachers teach the student. At inference time, the student is using what it learned from the teachers.</td></tr></tbody></table>
<br />
<h3>Synthetic examples</h3>
<p>
As a counterpart to the live teaching model described for traditional social learning, we propose a learning method where the teachers generate new synthetic examples for the task and share them with the student. This is motivated by the idea that one can create a new example that is sufficiently different from the original one, but is just as educational. Indeed, we observe that our generated examples are sufficiently different from the real ones to preserve privacy while still enabling performance comparable to that achieved using the original examples.
</p>
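<p>
As a rough sketch of this teaching model (not the implementation used in our experiments), the teacher below turns its private examples into new synthetic ones, which the student then uses as a few-shot prompt. The <code>generate</code> helper stands in for a call to an LLM, and the prompt wording is a placeholder assumption.
</p>
<pre>
# A minimal sketch of the "synthetic examples" teaching model. generate(prompt)
# is a hypothetical helper that calls an LLM; prompts and data handling are
# illustrative, not the paper's code.
def teach_with_synthetic_examples(generate, private_examples, num_synthetic=8):
    # The teacher sees its private, human-labeled examples and produces new
    # ones that convey the task without copying the originals.
    prompt = (
        "Here are labeled examples of a task:\n"
        + "\n".join(f"Input: {x}\nLabel: {y}" for x, y in private_examples)
        + f"\n\nWrite {num_synthetic} new, different examples in the same format."
    )
    return generate(prompt)  # synthetic examples, shared with the student

def student_answer(generate, synthetic_examples, query):
    # The student never sees the private data, only the synthetic examples,
    # which it uses as a few-shot prompt.
    return generate(f"{synthetic_examples}\n\nInput: {query}\nLabel:")
</pre>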
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiBGMoLyGVpCFO2DkG61pJJwjfje3CZO9V_5YfK3FJlQrbqD8P1RnBt70-G1p0ifTVZ8hnN0upKFdnbZNkPeKpICUiYU0uoqftlq-1bvLXfwlzPFhsCf4uyD5Z4z_ML44YWVf-pjyWEbgsgKGEp_P5F7QzFH3P5TokVfw1QQhD2dSON4dDp3jXqZTHXYZSd/s1456/image5.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="880" data-original-width="1456" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiBGMoLyGVpCFO2DkG61pJJwjfje3CZO9V_5YfK3FJlQrbqD8P1RnBt70-G1p0ifTVZ8hnN0upKFdnbZNkPeKpICUiYU0uoqftlq-1bvLXfwlzPFhsCf4uyD5Z4z_ML44YWVf-pjyWEbgsgKGEp_P5F7QzFH3P5TokVfw1QQhD2dSON4dDp3jXqZTHXYZSd/s16000/image5.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">The 8 generated examples perform as well as the original data for several tasks (see our <a href="https://arxiv.org/abs/2312.11441">paper</a>).</td></tr></tbody></table>
<p>
We evaluate the efficacy of learning through synthetic examples on our task suite. Especially when the number of examples is high enough, e.g., n = 16, we observe no statistically significant difference between sharing original data and teaching with synthesized data via social learning for the majority of tasks, indicating that the privacy improvement does not have to come at the cost of model quality.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhQPNMTVzgQW7O3o7Uz0a42vnT7kBhAjqRg5ZL1UrQVs7H5b5-FGdxJFcBmCGHr8sU3WkHsPKVlsQmVnzW-YAop1plz6oxYvTQyxEirorXE2WyGVfFvdOzAw5ydoMh7WUNykMJqasBqCr3C2n_pwBlAFZLO-WBiS-yXm9ExW_NTTIW8zYvfu17cMU8Y3_tp/s1456/image4.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="880" data-original-width="1456" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhQPNMTVzgQW7O3o7Uz0a42vnT7kBhAjqRg5ZL1UrQVs7H5b5-FGdxJFcBmCGHr8sU3WkHsPKVlsQmVnzW-YAop1plz6oxYvTQyxEirorXE2WyGVfFvdOzAw5ydoMh7WUNykMJqasBqCr3C2n_pwBlAFZLO-WBiS-yXm9ExW_NTTIW8zYvfu17cMU8Y3_tp/s16000/image4.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Generating 16 instead of just 8 examples further reduces the performance gap relative to the original examples.</td></tr></tbody></table>
<br />
<p>
The one exception is spam detection, for which teaching with synthesized data yields lower accuracy. This may be because the training procedure of current models makes them biased to only generate non-spam examples. In the <a href="https://arxiv.org/abs/2312.11441">paper</a>, we additionally look into aggregation methods for selecting good subsets of examples to use.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>Synthetic instruction</h3>
<p>
Given the success of language models in following instructions, the verbal instruction model can also be naturally adapted to language models by having the teachers generate an instruction for the task. Our experiments show that providing such a generated instruction effectively improves performance over zero-shot prompting, reaching accuracies comparable to few-shot prompting with original examples. However, we did find that on certain tasks the teacher model may fail to provide a good instruction, for example when the output has a complicated formatting requirement.
</p>
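<p>
A similarly rough sketch for this teaching model: the teacher distills its private examples into a short natural-language instruction, which the student then follows zero-shot. As before, <code>generate</code> is a hypothetical LLM helper and the prompt text is a placeholder.
</p>
<pre>
# A minimal sketch of the "synthetic instruction" teaching model; not the
# paper's implementation.
def teach_with_instruction(generate, private_examples):
    # The teacher writes an instruction instead of sharing any examples.
    prompt = (
        "Here are labeled examples of a task:\n"
        + "\n".join(f"Input: {x}\nLabel: {y}" for x, y in private_examples)
        + "\n\nWrite a concise instruction explaining how to perform this task, "
          "without repeating any of the examples above."
    )
    return generate(prompt)

def student_answer_with_instruction(generate, instruction, query):
    # The student follows the generated instruction zero-shot.
    return generate(f"{instruction}\n\nInput: {query}\nLabel:")
</pre>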
<p>
For <a href="https://arxiv.org/abs/1606.06031">Lambada</a>, <a href="https://arxiv.org/abs/2110.14168">GSM8k</a>, and <a href="https://arxiv.org/abs/2005.14165">Random Insertion</a>, providing synthetic examples performs better than providing generated instructions, whereas for the other tasks the generated instruction achieves higher accuracy. This observation suggests that the choice of the teaching model depends on the task at hand, similar to how the most effective method for teaching people varies by task.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlmIYiQiqu5BGxrgWq6kklbYjnf3cEIE8lYcoIDQBYY54-ZQCTO2bm7IwpElQCD9ZX0Kt9_egKLhFjlmQFh-oJejJuLHHFDC-d_FVS9DzxGQNzEHy8nFL6BTs5D0evWbiDFjhy1p2OZ9u-QixTWFfP73SEWa2L5iax9OGFvwfuGvi5bsr2EzCSEUYONJ5r/s1451/image1.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="880" data-original-width="1451" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlmIYiQiqu5BGxrgWq6kklbYjnf3cEIE8lYcoIDQBYY54-ZQCTO2bm7IwpElQCD9ZX0Kt9_egKLhFjlmQFh-oJejJuLHHFDC-d_FVS9DzxGQNzEHy8nFL6BTs5D0evWbiDFjhy1p2OZ9u-QixTWFfP73SEWa2L5iax9OGFvwfuGvi5bsr2EzCSEUYONJ5r/s16000/image1.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Depending on the task, generating instructions can work better than generating new examples.</td></tr></tbody></table>
<br />
<h2>Memorization of the private examples</h2>
<p>
We want teachers in social learning to teach the student without revealing specifics from the original data. To quantify how prone this process is to leaking information, we used <a href="https://research.google/pubs/the-secret-sharer-evaluating-and-testing-unintended-memorization-in-neural-networks/">Secret Sharer</a>, a popular method for quantifying to what extent a model memorizes its training data, and adapted it to the social learning setting. We picked this method since it had previously been <a href="https://blog.research.google/2023/03/distributed-differential-privacy-for.html">used</a> for evaluating memorization in federated learning.
</p>
<p>
To apply the Secret Sharer method to social learning, we design “canary” data points such that we can concretely measure how much the training process memorized them. These data points are included in the datasets used by teachers to generate new examples. After the social learning process completes, we can then measure how much more confident the student is in the secret data points the teacher used, compared to similar ones that were not shared even with the teachers.
</p>
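<p>
Conceptually, the adapted measurement boils down to comparing the student’s confidence on canaries that were mixed into a teacher’s data with its confidence on similar held-out canaries. The sketch below illustrates this with a hypothetical <code>log_likelihood</code> scoring function; the paper reports more detailed, exposure-style metrics.
</p>
<pre>
# A minimal sketch of the adapted Secret Sharer measurement.
# log_likelihood(model, text) is a hypothetical helper that scores a string
# under the student model; the mean-difference statistic is illustrative.
import statistics

def memorization_gap(student, used_canaries, held_out_canaries, log_likelihood):
    # Canaries included in a teacher's data vs. similar canaries never shown
    # to any teacher. A large positive gap would indicate that specifics
    # leaked through the teaching process.
    used = statistics.mean(log_likelihood(student, c) for c in used_canaries)
    held_out = statistics.mean(log_likelihood(student, c) for c in held_out_canaries)
    return used - held_out
</pre>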
<p>
In our analysis, discussed in detail in the <a href="https://arxiv.org/abs/2312.11441">paper</a>, we use canary examples that include names and codes. Our results show that the student is only slightly more confident in the canaries the teacher used. In contrast, when the original data points are directly shared with the student, the confidence in the included canaries is much higher than in the held-out set. This supports the conclusion that the teacher does indeed use its data to teach without simply copying it over.
</p>
<br />
<h2>Conclusion and next steps</h2>
<p>
We introduced a framework for social learning that allows language models with access to private data to transfer knowledge through textual communication while maintaining the privacy of that data. In this framework, we identified sharing examples and sharing instructions as basic models and evaluated them on multiple tasks. Furthermore, we adapted the Secret Sharer metric to our framework, proposing a metric for measuring data leakage.
</p>
<p>
As next steps, we are looking for ways of improving the teaching process, for example by adding feedback loops and iteration. Furthermore, we want to investigate using social learning for modalities other than text.
</p>
<br />
<h2>Acknowledgements</h2>
<p>
<em>We would like to acknowledge and thank Matt Sharifi, Sian Gooding, Lukas Zilka, and Blaise Aguera y Arcas, who are all co-authors on the paper. Furthermore, we would like to thank Victor Cărbune, Zachary Garrett, Tautvydas Misiunas, Sofia Neata and John Platt for their feedback, which greatly improved the paper. We’d also like to thank Tom Small for creating the animated figure.</em>
</p>Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-83932932080187572842024-03-06T10:26:00.000-08:002024-03-06T14:44:03.387-08:00Croissant: a metadata format for ML-ready datasets<span class="byline-author">Posted by Omar Benjelloun, Software Engineer, Google Research, and Peter Mattson, Software Engineer, Google Core ML and President, MLCommons Association</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj09uSTHgWmPgOkD9W1nZZj5i8uW_-pgxm-T1O5PSacF-EKvHIeIwhMr7Rgft7O3A2Rk94GWe8WboO3dUlxrqt1xz9x4I2aMKJxCUtUkR2eukbsIa8xVyAAN_LJJyMABxRqJuktFkyfhoWPDMQK3O-XgbQNJXzAILlWl3su0fd-Q_uZ-8r5r_uAU2P4srnP/s1600/CroissantHero.png" style="display: none;" />
<p>
Machine learning (ML) practitioners looking to reuse existing datasets to train an ML model often spend a lot of time understanding the data, making sense of its organization, or figuring out what subset to use as features. So much time, in fact, that progress in the field of ML is hampered by a fundamental obstacle: the wide variety of data representations.
</p>
<a name='more'></a>
<p>
ML datasets cover a broad range of content types, from text and structured data to images, audio, and video. Even within datasets that cover the same types of content, every dataset has a unique <em>ad hoc</em> arrangement of files and data formats. This challenge reduces productivity throughout the entire ML development process, from finding the data to training the model. It also impedes development of badly needed tooling for working with datasets.
</p>
<p>
There are general purpose metadata formats for datasets such as <a href="http://schema.org/Dataset">schema.org</a> and <a href="https://www.w3.org/TR/vocab-dcat-3/">DCAT</a>. However, these formats were designed for data discovery rather than for the specific needs of ML data, such as the ability to extract and combine data from structured and unstructured sources, to include metadata that would enable <a href="https://ai.google/responsibility/responsible-ai-practices/">responsible use</a> of the data, or to describe ML usage characteristics such as defining training, test and validation sets.
</p>
<p>
Today, we're introducing <a href="https://mlcommons.org/croissant">Croissant</a>, a new metadata format for ML-ready datasets. Croissant was developed collaboratively by a community from industry and academia, as part of the <a href="https://mlcommons.org/">MLCommons</a> effort. The Croissant format doesn't change how the actual data is represented (e.g., image or text file formats) — it provides a standard way to describe and organize it. Croissant builds upon <a href="https://schema.org/">schema.org</a>, the de facto standard for publishing structured data on the Web, which is already used by over 40M datasets. Croissant augments it with comprehensive layers for ML relevant metadata, data resources, data organization, and default ML semantics.
</p>
<p>
In addition, we are announcing support from major tools and repositories: Today, three widely used collections of ML datasets — <a href="http://www.kaggle.com/datasets">Kaggle</a>, <a href="https://huggingface.co/datasets?other=croissant&sort=trending">Hugging Face</a>, and <a href="https://openml.org/search?type=data">OpenML</a> — will begin supporting the Croissant format for the datasets they host; the <a href="http://g.co/datasetsearch">Dataset Search</a> tool lets users search for Croissant datasets across the Web; and popular ML frameworks, including <a href="https://www.tensorflow.org/">TensorFlow</a>, <a href="https://pytorch.org/">PyTorch</a>, and <a href="https://github.com/google/jax">JAX</a>, can load Croissant datasets easily using the <a href="https://www.tensorflow.org/datasets">TensorFlow Datasets</a> (TFDS) package.
</p>
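<p>
As a quick illustration of what “loading a Croissant dataset” looks like in practice, here is a sketch using the open-source <code>mlcroissant</code> Python library. The URL is a placeholder, and exact argument names may differ between library versions, so treat this as an outline and check the library’s documentation rather than the definitive API.
</p>
<pre>
# A minimal sketch of consuming a Croissant dataset with mlcroissant.
# The metadata URL is a placeholder, and argument names (e.g., jsonld=,
# record_set=) follow the library's documented usage but may vary by version.
import mlcroissant as mlc

# Point the library at a dataset's Croissant JSON-LD metadata file.
dataset = mlc.Dataset(jsonld="https://example.com/my_dataset/croissant.json")

# Iterate over one of the record sets declared in the metadata; each record
# is a dict keyed by the fields described in the Croissant file.
for record in dataset.records(record_set="default"):
    print(record)
    break
</pre>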
<div style="line-height:40%;">
<br>
</div>
<h2>Croissant</h2>
<p>
This 1.0 release of Croissant includes a complete <a href="https://mlcommons.org/croissant/1.0">specification</a> of the format, a set of <a href="https://github.com/mlcommons/croissant/tree/main/datasets">example datasets</a>, an open source <a href="https://github.com/mlcommons/croissant/tree/main/python/mlcroissant">Python library</a> to validate, consume and generate Croissant metadata, and an open source <a href="https://github.com/mlcommons/croissant/tree/main/editor">visual editor</a> to load, inspect and create Croissant dataset descriptions in an intuitive way.
</p>
<p>
Supporting Responsible AI (RAI) was a key goal of the Croissant effort from the start. We are also releasing the first version of the <a href="https://mlcommons.org/croissant/RAI/1.0">Croissant RAI vocabulary</a> extension, which augments Croissant with key properties needed to describe important RAI use cases such as data life cycle management, data labeling, participatory data, ML safety and fairness evaluation, explainability, and compliance.
</p>
<div style="line-height:40%;">
<br>
</div>
<h2>Why a shared format for ML data?</h2>
<p>
The majority of ML work is actually data work. The training data is the “code” that determines the behavior of a model. Datasets can vary from a collection of text used to train a large language model (LLM) to a collection of driving scenarios (annotated videos) used to train a car’s collision avoidance system. However, the steps to develop an ML model typically follow the same iterative data-centric process: (1) find or collect data, (2) clean and refine the data, (3) train the model on the data, (4) test the model on more data, (5) discover the model does not work, (6) analyze the data to find out why, (7) repeat until a workable model is achieved. Many steps are made harder by the lack of a common format. This “data development burden” is especially heavy for resource-limited research and early-stage entrepreneurial efforts.
</p>
<p>
The goal of a format like Croissant is to make this entire process easier. For instance, the metadata can be leveraged by search engines and dataset repositories to make it easier to find the right dataset. The data resources and organization information make it easier to develop tools for cleaning, refining, and analyzing data. This information and the default ML semantics make it possible for ML frameworks to use the data to train and test models with a minimum of code. Together, these improvements substantially reduce the data development burden.
</p>
<p>
Additionally, dataset authors care about the discoverability and ease of use of their datasets. Adopting Croissant improves the value of their datasets, while only requiring a minimal effort, thanks to the available creation tools and support from ML data platforms.
</p>
<div style="line-height:40%;">
<br>
</div>
<h2>What can Croissant do today?</h2>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgN40ZSjgTFRIVwAwN2OXIn4vQhmshC8VhcKx-ijY-sCQBH9qDkV3nrFz_YapZ0iAD-Svkyxblt6lpJFFHa4JfDqfY6RIL0RnVhtgBlLyh-1DnH8DUz7-TUSdSUIg5V2piqjmQ5Dw9MISeeSBvnMsie8jRrXOeHXfcTGQi0AHIeOYFuHYwDFSyRmBT8BHum/s908/image1.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="540" data-original-width="908" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgN40ZSjgTFRIVwAwN2OXIn4vQhmshC8VhcKx-ijY-sCQBH9qDkV3nrFz_YapZ0iAD-Svkyxblt6lpJFFHa4JfDqfY6RIL0RnVhtgBlLyh-1DnH8DUz7-TUSdSUIg5V2piqjmQ5Dw9MISeeSBvnMsie8jRrXOeHXfcTGQi0AHIeOYFuHYwDFSyRmBT8BHum/s16000/image1.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">The Croissant ecosystem: Users can Search for Croissant datasets, download them from major repositories, and easily load them into their favorite ML frameworks. They can create, inspect and modify Croissant metadata using the Croissant editor.</td></tr></tbody></table>
<p>
Today, users can find Croissant datasets at:
</p>
<ul>
<li>Google <a href="https://datasetsearch.research.google.com/">Dataset Search</a>, which offers a Croissant filter.
</li><li><a href="https://huggingface.co/datasets?other=croissant&sort=trending">HuggingFace</a>
</li><li><a href="http://kaggle.com/datasets">Kaggle</a>
</li><li><a href="https://openml.org/search?type=data">OpenML</a>
</li>
</ul>
<p>
With a Croissant dataset, it is possible to:
</p>
<ul>
<li>Ingest data easily via <a href="https://www.tensorflow.org/datasets">TensorFlow Datasets</a> for use in popular ML frameworks like <a href="https://www.tensorflow.org/">TensorFlow</a>, <a href="https://pytorch.org/">PyTorch</a>, and <a href="https://github.com/google/jax">JAX</a>.
</li><li>Inspect and modify the metadata using the <a href="https://huggingface.co/spaces/MLCommons/croissant-editor">Croissant editor UI</a> (<a href="https://github.com/mlcommons/croissant/tree/main/editor">github</a>).
</li>
</ul>
<p>
To publish a Croissant dataset, users can:
</p>
<ul>
<li>Use the <a href="https://huggingface.co/spaces/MLCommons/croissant-editor">Croissant editor UI</a> (<a href="https://github.com/mlcommons/croissant/tree/main/editor">github</a>) to generate a large portion of Croissant metadata automatically by analyzing the data the user provides, and to fill important metadata fields such as RAI properties.
</li><li>Publish the Croissant information as part of their dataset Web page to make it discoverable and reusable.
</li><li>Publish their data in one of the repositories that support Croissant, such as Kaggle, HuggingFace and OpenML, and automatically generate Croissant metadata.
</li>
</ul>
<div style="line-height:40%;">
<br>
</div>
<h2>Future direction</h2>
<p>
We are excited about Croissant's potential to help ML practitioners, but making this format truly useful requires the support of the community. We encourage dataset creators to consider providing Croissant metadata. We encourage platforms hosting datasets to provide Croissant files for download and embed Croissant metadata in dataset Web pages so that they can be made discoverable by dataset search engines. Tools that help users work with ML datasets, such as labeling or data analysis tools, should also consider supporting Croissant datasets. Together, we can reduce the data development burden and enable a richer ecosystem of ML research and development.
</p>
<p>
We encourage the community to <a href="http://mlcommons.org/croissant">join us</a> in contributing to the effort.
</p>
<div style="line-height:40%;">
<br>
</div>
<h2>Acknowledgements</h2>
<p>
<em>Croissant was developed by the <a href="https://datasetsearch.research.google.com/">Dataset Search</a>, <a href="https://www.kaggle.com/">Kaggle</a> and <a href="https://www.tensorflow.org/datasets">TensorFlow Datasets</a> teams from Google, as part of an <a href="http://mlcommons.org">MLCommons</a> community working group, which also includes contributors from these organizations: Bayer, cTuning Foundation, DANS-KNAW, Dotphoton, Harvard, Hugging Face, Kings College London, LIST, Meta, NASA, North Carolina State University, Open Data Institute, Open University of Catalonia, Sage Bionetworks, and TU Eindhoven.</em>
</p>Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-27545267824972474972024-03-04T07:06:00.000-08:002024-03-05T08:40:45.490-08:00Google at APS 2024<span class="byline-author">Posted by Kate Weber and Shannon Leon, Google Research, Quantum AI Team</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjy22Hfq3RN4qRUJcSMUpIau4ueOIcQ219mDvfu4FNJ9kf5PBMUI0x4Uf9BhoIHtnFUhtvE72GCVYixldOZRSeePJfef0P87Pc_djQeGIZOhyxv9nKsQCc57357tr3npWdS5fyWxiGjex4NxMpOIB2JE1Z2qXdLnzLkFM075WstFJD77xVNS2T9hckWZyLf/s1600/lockup_GoogleResearch_FullColor_Hero.jpg" style="display: none;" />
<p>
Today the <a href="https://www.aps.org/meetings/meeting.cfm?name=MAR24">2024 March Meeting</a> of the <a href="https://www.aps.org/">American Physical Society</a> (APS) kicks off in Minneapolis, MN. A premier conference on topics ranging across physics and related fields, APS 2024 brings together researchers, students, and industry professionals to share their discoveries and build partnerships with the goal of realizing fundamental advances in physics-related sciences and technology.
</p>
<a name='more'></a>
<p>
This year, Google has a strong presence at APS with a booth hosted by the Google <a href="https://quantumai.google/">Quantum AI</a> team, 50+ talks throughout the conference, and participation in conference organizing activities, special sessions and events. Attending APS 2024 in person? Come visit Google’s Quantum AI booth to learn more about the exciting work we’re doing to solve some of the field’s most interesting challenges.
</p>
<p>
You can learn more about the latest cutting edge work we are presenting at the conference along with our schedule of booth events below (Googlers listed in <strong>bold</strong>).
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Organizing Committee</h2>
<div style="margin-left: 20px;">
<p>
Session Chairs include: <strong>Aaron Szasz</strong>
</p>
</div>
<div style="line-height: 40%;">
<br />
</div>
<h2>Booth Activities</h2>
<div style="margin-left: 20px;">
<p>
<em>This schedule is subject to change. Please visit the Google Quantum AI booth for more information.</em>
</p>
<p>
Crumble: A prototype interactive tool for visualizing QEC circuits
<br />
Presenter: <strong>Matt McEwen</strong>
<br />
Tue, Mar 5 | 11:00 AM CST
</p>
<p>
Qualtran: An open-source library for effective resource estimation of fault tolerant algorithms
<br />
Presenter: <strong>Tanuj Khattar</strong>
<br />
Tue, Mar 5 | 2:30 PM CST
</p>
<p>
Qualtran: An open-source library for effective resource estimation of fault tolerant algorithms
<br />
Presenter: <strong>Tanuj Khattar</strong>
<br />
Thu, Mar 7 | 11:00 AM CST
</p>
<p>
$5M XPRIZE / Google Quantum AI competition to accelerate quantum applications Q&A
<br />
Presenter: <strong>Ryan Babbush</strong>
<br />
Thu, Mar 7 | 11:00 AM CST
</p>
</div>
<div style="line-height: 40%;">
<br />
</div>
<h2>Talks</h2>
<h3>Monday</h3>
<div style="margin-left: 20px;">
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/A45.1">Certifying highly-entangled states from few single-qubit measurements</a>
<br />
Presenter: <strong>Hsin-Yuan Huang</strong>
<br />
Author: <strong>Hsin-Yuan Huang</strong>
<br />
<em>Session A45: New Frontiers in Machine Learning Quantum Physics</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/A51.2">Toward high-fidelity analog quantum simulation with superconducting qubits</a>
<br />
Presenter: <strong>Trond Andersen</strong>
<br />
Authors: <strong>Trond I Andersen</strong>, <strong>Xiao Mi</strong>, <strong>Amir H Karamlou</strong>, <strong>Nikita Astrakhantsev</strong>, <strong>Andrey Klots</strong>, <strong>Julia Berndtsson</strong>, <strong>Andre Petukhov</strong>, <strong>Dmitry Abanin</strong>, <strong>Lev B Ioffe</strong>, <strong>Yu Chen</strong>, <strong>Vadim Smelyanskiy</strong>, <strong>Pedram Roushan</strong>
<br />
<em>Session A51: Applications on Noisy Quantum Hardware I</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/B50.6">Measuring circuit errors in context for surface code circuits</a>
<br />
Presenter: <strong>Dripto M Debroy</strong>
<br />
Authors: <strong>Dripto M Debroy</strong>, <strong>Jonathan A Gross</strong>, <strong>Élie Genois</strong>, <strong>Zhang Jiang</strong>
<br />
<em>Session B50: Characterizing Noise with QCVV Techniques</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/B51.6">Quantum computation of stopping power for inertial fusion target design I: Physics overview and the limits of classical algorithms</a>
<br />
Presenter: Andrew D. Baczewski
<br />
Authors: <strong>Nicholas C. Rubin</strong>, Dominic W. Berry, Alina Kononov, <strong>Fionn D. Malone</strong>, <strong>Tanuj Khattar</strong>, Alec White, <strong>Joonho Lee</strong>, <strong>Hartmut Neven</strong>, <strong>Ryan Babbush</strong>, Andrew D. Baczewski
<br />
<em>Session B51: Heterogeneous Design for Quantum Applications</em>
<br />
<a href="https://arxiv.org/pdf/2308.12352.pdf">Link to Paper</a>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/B51.7">Quantum computation of stopping power for inertial fusion target design II: Physics overview and the limits of classical algorithms</a>
<br />
Presenter: <strong>Nicholas C. Rubin</strong>
<br />
Authors: <strong>Nicholas C. Rubin</strong>, Dominic W. Berry, Alina Kononov, <strong>Fionn D. Malone</strong>, <strong>Tanuj Khattar</strong>, Alec White, <strong>Joonho Lee</strong>, <strong>Hartmut Neven</strong>, <strong>Ryan Babbush</strong>, Andrew D. Baczewski
<br />
<em>Session B51: Heterogeneous Design for Quantum Applications</em>
<br />
<a href="https://arxiv.org/pdf/2308.12352.pdf">Link to Paper</a>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/B56.4">Calibrating Superconducting Qubits: From NISQ to Fault Tolerance</a>
<br />
Presenter: <strong>Sabrina S Hong</strong>
<br />
Author: <strong>Sabrina S Hong</strong>
<br />
<em>Session B56: From NISQ to Fault Tolerance</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/B31.9">Measurement and feedforward induced entanglement negativity transition</a>
<br />
Presenter: <strong>Ramis Movassagh</strong>
<br />
Authors: Alireza Seif, Yu-Xin Wang,<strong> Ramis Movassagh</strong>, Aashish A. Clerk
<br />
<em>Session B31: Measurement Induced Criticality in Many-Body Systems</em>
<br />
<a href="https://arxiv.org/pdf/2310.18305.pdf">Link to Paper</a>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/B52.9">Effective quantum volume, fidelity and computational cost of noisy quantum processing experiments</a>
<br />
Presenter: <strong>Salvatore Mandra</strong>
<br />
Authors: <strong>Kostyantyn Kechedzhi</strong>, <strong>Sergei V Isakov</strong>, <strong>Salvatore Mandra</strong>, <strong>Benjamin Villalonga</strong>, <strong>X. Mi</strong>, <strong>Sergio Boixo</strong>, <strong>Vadim Smelyanskiy</strong>
<br />
<em>Session B52: Quantum Algorithms and Complexity</em>
<br />
<a href="https://arxiv.org/pdf/2306.15970.pdf">Link to Paper</a>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/D60.4">Accurate thermodynamic tables for solids using Machine Learning Interaction Potentials and Covariance of Atomic Positions</a>
<br />
Presenter: Mgcini K Phuthi
<br />
Authors: Mgcini K Phuthi, Yang Huang, Michael Widom, <strong>Ekin D Cubuk</strong>, Venkat Viswanathan
<br />
<em>Session D60: Machine Learning of Molecules and Materials: Chemical Space and Dynamics</em>
</p>
</div>
<div style="line-height: 40%;">
<br />
</div>
<h3>Tuesday</h3>
<div style="margin-left: 20px;">
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/F50.4">IN-Situ Pulse Envelope Characterization Technique (INSPECT)</a>
<br />
Presenter: <strong>Zhang Jiang</strong>
<br />
Authors: <strong>Zhang Jiang</strong>, <strong>Jonathan A Gross</strong>, <strong>Élie Genois</strong>
<br />
<em>Session F50: Advanced Randomized Benchmarking and Gate Calibration</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/F50.11">Characterizing two-qubit gates with dynamical decoupling</a>
<br />
Presenter: <strong>Jonathan A Gross</strong>
<br />
Authors: <strong>Jonathan A Gross</strong>, <strong>Zhang Jiang</strong>, <strong>Élie Genois, Dripto M Debroy</strong>, Ze-Pei Cian*, <strong>Wojciech Mruczkiewicz</strong>
<br />
<em>Session F50: Advanced Randomized Benchmarking and Gate Calibration</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/EE01.2">Statistical physics of regression with quadratic models</a>
<br />
Presenter: Blake Bordelon
<br />
Authors: Blake Bordelon, Cengiz Pehlevan, <strong>Yasaman Bahri</strong>
<br />
<em>Session EE01: V: Statistical and Nonlinear Physics II</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/G51.2">Improved state preparation for first-quantized simulation of electronic structure</a>
<br />
Presenter: <strong>William J Huggins</strong>
<br />
Authors: <strong>William J Huggins</strong>, <strong>Oskar Leimkuhler</strong>, <strong>Torin F Stetina</strong>, <strong>Birgitta Whaley</strong>
<br />
<em>Session G51: Hamiltonian Simulation</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/G30.2">Controlling large superconducting quantum processors</a>
<br />
Presenter: <strong>Paul V. Klimov</strong>
<br />
Authors: <strong>Paul V. Klimov</strong>, <strong>Andreas Bengtsson</strong>, <strong>Chris Quintana</strong>, <strong>Alexandre Bourassa</strong>, <strong>Sabrina Hong</strong>, <strong>Andrew Dunsworth</strong>, <strong>Kevin J. Satzinger</strong>, <strong>William P. Livingston</strong>, <strong>Volodymyr Sivak</strong>, <strong>Murphy Y. Niu</strong>, <strong>Trond I. Andersen</strong>, <strong>Yaxing Zhang</strong>, <strong>Desmond Chik</strong>, <strong>Zijun Chen</strong>, <strong>Charles Neill</strong>, <strong>Catherine Erickson</strong>, <strong>Alejandro Grajales Dau</strong>, <strong>Anthony Megrant</strong>, <strong>Pedram Roushan</strong>, <strong>Alexander N. Korotkov</strong>, <strong>Julian Kelly</strong>, <strong>Vadim Smelyanskiy</strong>, <strong>Yu Chen</strong>, <strong>Hartmut Neven</strong>
<br />
<em>Session G30: Commercial Applications of Quantum Computing</em><br />
<a href="https://arxiv.org/pdf/2308.02321.pdf">Link to Paper</a>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/G50.5">Gaussian boson sampling: Determining quantum advantage</a>
<br />
Presenter: Peter D Drummond
<br />
Authors: Peter D Drummond, Alex Dellios, Ned Goodman, Margaret D Reid, <strong>Ben Villalonga</strong>
<br />
<em>Session G50: Quantum Characterization, Verification, and Validation II</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/G50.8">Attention to complexity III: learning the complexity of random quantum circuit states</a>
<br />
Presenter: Hyejin Kim
<br />
Authors: Hyejin Kim, Yiqing Zhou, Yichen Xu, Chao Wan, Jin Zhou, <strong>Yuri D Lensky</strong>, Jesse Hoke, <strong>Pedram Roushan</strong>, Kilian Q Weinberger, Eun-Ah Kim
<br />
<em>Session G50: Quantum Characterization, Verification, and Validation II</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/K48.10">Balanced coupling in superconducting circuits</a>
<br />
Presenter: <strong>Daniel T Sank</strong>
<br />
Authors: <strong>Daniel T Sank</strong>, <strong>Sergei V Isakov</strong>, <strong>Mostafa Khezri</strong>, <strong>Juan Atalaya</strong>
<br />
<em>Session K48: Strongly Driven Superconducting Systems</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/K49.12">Resource estimation of Fault Tolerant algorithms using Qᴜᴀʟᴛʀᴀɴ</a>
<br />
Presenter: <strong>Tanuj Khattar</strong>
<br />
Authors: <strong>Tanuj Khattar</strong>, <strong>Matthew Harrigan</strong>, <strong>Fionn D. Malone</strong>, <strong>Nour Yosri</strong>, <strong>Nicholas C. Rubin</strong><br />
<em>Session K49: Algorithms and Implementations on Near-Term Quantum Computers</em>
</p>
</div>
<div style="line-height: 40%;">
<br />
</div>
<h3>Wednesday</h3>
<div style="margin-left: 20px;">
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/M24.1">Discovering novel quantum dynamics with superconducting qubits</a>
<br />
Presenter: <strong>Pedram Roushan</strong>
<br />
Author: <strong>Pedram Roushan</strong>
<br />
<em>Session M24: Analog Quantum Simulations Across Platforms</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/M27.7">Deciphering Tumor Heterogeneity in Triple-Negative Breast Cancer: The Crucial Role of Dynamic Cell-Cell and Cell-Matrix Interactions</a>
<br />
Presenter: Susan Leggett
<br />
Authors: Susan Leggett, Ian Wong, Celeste Nelson, Molly Brennan, <strong>Mohak Patel</strong>, Christian Franck, Sophia Martinez, Joe Tien, Lena Gamboa, Thomas Valentin, Amanda Khoo, Evelyn K Williams
<br />
<em>Session M27: Mechanics of Cells and Tissues II</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/N48.2">Toward implementation of protected charge-parity qubits</a>
<br />
Presenter: Abigail Shearrow
<br />
Authors: Abigail Shearrow, Matthew Snyder, Bradley G Cole, Kenneth R Dodge, Yebin Liu, Andrey Klots, <strong>Lev B Ioffe</strong>, Britton L Plourde, Robert McDermott
<br />
<em>Session N48: Unconventional Superconducting Qubits</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/N48.3">Electronic capacitance in tunnel junctions for protected charge-parity qubits</a>
<br />
Presenter: Bradley G Cole
<br />
Authors: Bradley G Cole, Kenneth R Dodge, Yebin Liu, Abigail Shearrow, Matthew Snyder, <strong>Andrey Klots</strong>, <strong>Lev B Ioffe</strong>, Robert McDermott, B.L.T. Plourde
<br />
<em>Session N48: Unconventional Superconducting Qubits</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/N51.7">Overcoming leakage in quantum error correction</a>
<br />
Presenter: <strong>Kevin C. Miao</strong>
<br />
Authors: <strong>Kevin C. Miao</strong>, <strong>Matt McEwen</strong>, <strong>Juan Atalaya</strong>, <strong>Dvir Kafri</strong>, <strong>Leonid P. Pryadko</strong>, <strong>Andreas Bengtsson</strong>, <strong>Alex Opremcak</strong>, <strong>Kevin J. Satzinger</strong>, <strong>Zijun Chen</strong>, <strong>Paul V. Klimov</strong>, <strong>Chris Quintana</strong>, <strong>Rajeev Acharya</strong>, <strong>Kyle Anderson</strong>, <strong>Markus Ansmann</strong>, <strong>Frank Arute</strong>, <strong>Kunal Arya</strong>, <strong>Abraham Asfaw</strong>, <strong>Joseph C. Bardin</strong>, <strong>Alexandre Bourassa</strong>, <strong>Jenna Bovaird</strong>, <strong>Leon Brill</strong>, <strong>Bob B. Buckley</strong>, <strong>David A. Buell</strong>, <strong>Tim Burger</strong>, <strong>Brian Burkett</strong>, <strong>Nicholas Bushnell</strong>, <strong>Juan Campero</strong>, <strong>Ben Chiaro</strong>, <strong>Roberto Collins</strong>, <strong>Paul Conner</strong>, <strong>Alexander L. Crook</strong>, <strong>Ben Curtin</strong>, <strong>Dripto M. Debroy</strong>, <strong>Sean Demura</strong>, <strong>Andrew Dunsworth</strong>, <strong>Catherine Erickson</strong>, <strong>Reza Fatemi</strong>, <strong>Vinicius S. Ferreira</strong>, <strong>Leslie Flores Burgos</strong>, <strong>Ebrahim Forati</strong>, <strong>Austin G. Fowler</strong>, <strong>Brooks Foxen</strong>, <strong>Gonzalo Garcia</strong>, <strong>William Giang</strong>, <strong>Craig Gidney</strong>, <strong>Marissa Giustina</strong>, <strong>Raja Gosula</strong>, <strong>Alejandro Grajales Dau</strong>, <strong>Jonathan A. Gross</strong>, <strong>Michael C. Hamilton</strong>, <strong>Sean D. Harrington</strong>, <strong>Paula Heu</strong>, <strong>Jeremy Hilton</strong>, <strong>Markus R. Hoffmann</strong>, <strong>Sabrina Hong</strong>, <strong>Trent Huang</strong>, <strong>Ashley Huff</strong>, <strong>Justin Iveland</strong>, <strong>Evan Jeffrey</strong>, <strong>Zhang Jiang</strong>, <strong>Cody Jones</strong>, <strong>Julian Kelly</strong>, <strong>Seon Kim</strong>, <strong>Fedor Kostritsa</strong>, <strong>John Mark Kreikebaum</strong>, <strong>David Landhuis</strong>, <strong>Pavel Laptev</strong>, <strong>Lily Laws</strong>, <strong>Kenny Lee</strong>, <strong>Brian J. Lester</strong>, <strong>Alexander T. Lill</strong>, <strong>Wayne Liu</strong>, <strong>Aditya Locharla</strong>, <strong>Erik Lucero</strong>, <strong>Steven Martin</strong>, <strong>Anthony Megrant</strong>, <strong>Xiao Mi</strong>, <strong>Shirin Montazeri</strong>, <strong>Alexis Morvan</strong>, <strong>Ofer Naaman</strong>, <strong>Matthew Neeley</strong>, <strong>Charles Neill</strong>, <strong>Ani Nersisyan</strong>, <strong>Michael Newman</strong>, <strong>Jiun How Ng</strong>, <strong>Anthony Nguyen</strong>, <strong>Murray Nguyen</strong>, <strong>Rebecca Potter</strong>, <strong>Charles Rocque</strong>, <strong>Pedram Roushan</strong>, <strong>Kannan Sankaragomathi</strong>, <strong>Christopher Schuster</strong>, <strong>Michael J. Shearn</strong>, <strong>Aaron Shorter</strong>, <strong>Noah Shutty</strong>, <strong>Vladimir Shvarts</strong>, <strong>Jindra Skruzny</strong>, <strong>W. Clarke Smith</strong>, <strong>George Sterling</strong>, <strong>Marco Szalay</strong>, <strong>Douglas Thor</strong>, <strong>Alfredo Torres</strong>, <strong>Theodore White</strong>, <strong>Bryan W. K. Woo</strong>, <strong>Z. 
Jamie Yao</strong>, <strong>Ping Yeh</strong>, <strong>Juhwan Yoo</strong>, <strong>Grayson Young</strong>, <strong>Adam Zalcman</strong>, <strong>Ningfeng Zhu</strong>, <strong>Nicholas Zobrist</strong>, <strong>Hartmut Neven</strong>, <strong>Vadim Smelyanskiy</strong>, <strong>Andre Petukhov</strong>, <strong>Alexander N. Korotkov</strong>, <strong>Daniel Sank</strong>, <strong>Yu Chen</strong>
<br />
<em>Session N51: Quantum Error Correction Code Performance and Implementation I</em>
<br />
<a href="https://www.nature.com/articles/s41567-023-02226-w">Link to Paper</a>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/N51.11">Modeling the performance of the surface code with non-uniform error distribution: Part 1</a>
<br />
Presenter: <strong>Yuri D Lensky</strong>
<br />
Authors: <strong>Yuri D Lensky</strong>, <strong>Volodymyr Sivak</strong>, <strong>Kostyantyn Kechedzhi</strong>, <strong>Igor Aleiner</strong>
<br />
<em>Session N51: Quantum Error Correction Code Performance and Implementation I</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/N51.12">Modeling the performance of the surface code with non-uniform error distribution: Part 2</a>
<br />
Presenter: <strong>Volodymyr Sivak</strong>
<br />
Authors: <strong>Volodymyr Sivak</strong>, <strong>Michael Newman</strong>, <strong>Cody Jones</strong>, <strong>Henry Schurkus</strong>, <strong>Dvir Kafri</strong>, <strong>Yuri D Lensky</strong>, <strong>Paul Klimov</strong>, <strong>Kostyantyn Kechedzhi</strong>, <strong>Vadim Smelyanskiy</strong>
<br />
<em>Session N51: Quantum Error Correction Code Performance and Implementation I</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/Q51.7">Highly optimized tensor network contractions for the simulation of classically challenging quantum computations</a>
<br />
Presenter: <strong>Benjamin Villalonga</strong>
<br />
Author: <strong>Benjamin Villalonga</strong>
<br />
<em>Session Q51: Co-evolution of Quantum Classical Algorithms</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/Q61.7">Teaching modern quantum computing concepts using hands-on open-source software at all levels</a>
<br />
Presenter: <strong>Abraham Asfaw</strong>
<br />
Author: <strong>Abraham Asfaw</strong>
<br />
<em>Session Q61: Teaching Quantum Information at All Levels II</em>
</p>
</div>
<div style="line-height: 40%;">
<br />
</div>
<h3>Thursday</h3>
<div style="margin-left: 20px;">
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/S51.1">New circuits and an open source decoder for the color code</a>
<br />
Presenter: <strong>Craig Gidney</strong>
<br />
Authors: <strong>Craig Gidney</strong>, <strong>Cody Jones</strong>
<br />
<em>Session S51: Quantum Error Correction Code Performance and Implementation II</em>
<br />
<a href="https://arxiv.org/pdf/2312.08813.pdf">Link to Paper</a>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/S18.2">Performing Hartree-Fock many-body physics calculations with large language models</a>
<br />
Presenter: <strong>Eun-Ah Kim</strong>
<br />
Authors: <strong>Eun-Ah Kim</strong>, Haining Pan, <strong>Nayantara Mudur</strong>, William Taranto,<strong> Subhashini Venugopalan</strong>, <strong>Yasaman Bahri</strong>, <strong>Michael P Brenner</strong>
<br />
<em>Session S18: Data Science, AI and Machine Learning in Physics I</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/S51.5">New methods for reducing resource overhead in the surface code</a>
<br />
Presenter: <strong>Michael Newman</strong>
<br />
Authors: <strong>Craig M Gidney</strong>, <strong>Michael Newman</strong>, <strong>Peter Brooks</strong>, <strong>Cody Jones</strong>
<br />
<em>Session S51: Quantum Error Correction Code Performance and Implementation II</em>
<br />
<a href="https://arxiv.org/pdf/2312.04522.pdf">Link to Paper</a>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/S49.10">Challenges and opportunities for applying quantum computers to drug design</a>
<br />
Presenter: Raffaele Santagati
<br />
Authors: Raffaele Santagati, Alan Aspuru-Guzik, <strong>Ryan Babbush</strong>, Matthias Degroote, Leticia Gonzalez, Elica Kyoseva, Nikolaj Moll, Markus Oppel, Robert M. Parrish, <strong>Nicholas C. Rubin</strong>, Michael Streif, Christofer S. Tautermann, Horst Weiss, Nathan Wiebe, Clemens Utschig-Utschig
<br />
<em>Session S49: Advances in Quantum Algorithms for Near-Term Applications</em>
<br />
<a href="https://arxiv.org/pdf/2301.04114.pdf">Link to Paper</a>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/T45.1">Dispatches from Google's hunt for super-quadratic quantum advantage in new applications</a>
<br />
Presenter: <strong>Ryan Babbush</strong>
<br />
Author: <strong>Ryan Babbush</strong>
<br />
<em>Session T45: Recent Advances in Quantum Algorithms</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/T48.11">Qubit as a reflectometer</a>
<br />
Presenter: <strong>Yaxing Zhang</strong>
<br />
Authors: <strong>Yaxing Zhang</strong>, <strong>Benjamin Chiaro</strong>
<br />
<em>Session T48: Superconducting Fabrication, Packaging, & Validation</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/W14.3">Random-matrix theory of measurement-induced phase transitions in nonlocal Floquet quantum circuits</a>
<br />
Presenter: Aleksei Khindanov
<br />
Authors: Aleksei Khindanov, <strong>Lara Faoro</strong>, <strong>Lev Ioffe</strong>, <strong>Igor Aleiner</strong>
<br />
<em>Session W14: Measurement-Induced Phase Transitions</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/W58.5">Continuum limit of finite density many-body ground states with MERA</a>
<br />
Presenter: Subhayan Sahu
<br />
Authors: Subhayan Sahu, <strong>Guifré Vidal</strong>
<br />
<em>Session W58: Extreme-Scale Computational Science Discovery in Fluid Dynamics and Related Disciplines II</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/W50.8">Dynamics of magnetization at infinite temperature in a Heisenberg spin chain</a>
<br />
Presenter: <strong>Eliott Rosenberg</strong>
<br />
Authors: <strong>Eliott Rosenberg</strong>, <strong>Trond Andersen</strong>, Rhine Samajdar, <strong>Andre Petukhov</strong>, Jesse Hoke*,<strong> Dmitry Abanin</strong>, <strong>Andreas Bengtsson</strong>, <strong>Ilya Drozdov</strong>, <strong>Catherine Erickson</strong>,<strong> Paul Klimov</strong>, <strong>Xiao Mi</strong>, <strong>Alexis Morvan</strong>, <strong>Matthew Neeley</strong>, <strong>Charles Neill</strong>, <strong>Rajeev Acharya</strong>, <strong>Richard Allen</strong>, <strong>Kyle Anderson</strong>, <strong>Markus Ansmann</strong>, <strong>Frank Arute</strong>, <strong>Kunal Arya</strong>, <strong>Abraham Asfaw</strong>, <strong>Juan Atalaya</strong>, <strong>Joseph Bardin</strong>, <strong>A. Bilmes</strong>, <strong>Gina Bortoli</strong>, <strong>Alexandre Bourassa</strong>, <strong>Jenna Bovaird</strong>, <strong>Leon Brill</strong>, <strong>Michael Broughton</strong>, <strong>Bob B. Buckley</strong>, <strong>David Buell</strong>, <strong>Tim Burger</strong>, <strong>Brian Burkett</strong>, <strong>Nicholas Bushnell</strong>, <strong>Juan Campero</strong>, <strong>Hung-Shen Chang</strong>, <strong>Zijun Chen</strong>, <strong>Benjamin Chiaro</strong>, <strong>Desmond Chik</strong>, <strong>Josh Cogan</strong>, <strong>Roberto Collins</strong>, <strong>Paul Conner</strong>, <strong>William Courtney</strong>, <strong>Alexander Crook</strong>, <strong>Ben Curtin</strong>, <strong>Dripto Debroy</strong>, <strong>Alexander Del Toro Barba</strong>, <strong>Sean Demura</strong>, <strong>Agustin Di Paolo</strong>, <strong>Andrew Dunsworth</strong>, <strong>Clint Earle</strong>, <strong>E. Farhi</strong>, <strong>Reza Fatemi</strong>, <strong>Vinicius Ferreira</strong>, <strong>Leslie Flores</strong>, <strong>Ebrahim Forati</strong>, <strong>Austin Fowler</strong>, <strong>Brooks Foxen</strong>, <strong>Gonzalo Garcia</strong>, <strong>Élie Genois</strong>, <strong>William Giang</strong>, <strong>Craig Gidney</strong>, <strong>Dar Gilboa</strong>, <strong>Marissa Giustina</strong>, <strong>Raja Gosula</strong>, <strong>Alejandro Grajales Dau</strong>, <strong>Jonathan Gross</strong>, <strong>Steve Habegger</strong>, <strong>Michael Hamilton</strong>, <strong>Monica Hansen</strong>, <strong>Matthew Harrigan</strong>, <strong>Sean Harrington</strong>, <strong>Paula Heu</strong>, <strong>Gordon Hill</strong>, <strong>Markus Hoffmann</strong>, <strong>Sabrina Hong</strong>, <strong>Trent Huang</strong>, <strong>Ashley Huff</strong>, <strong>William Huggins</strong>, <strong>Lev Ioffe</strong>, <strong>Sergei Isakov</strong>, <strong>Justin Iveland</strong>, <strong>Evan Jeffrey</strong>, <strong>Zhang Jiang</strong>, <strong>Cody Jones</strong>, <strong>Pavol Juhas</strong>, <strong>D. Kafri</strong>, <strong>Tanuj Khattar</strong>, <strong>Mostafa Khezri</strong>, <strong>Mária Kieferová</strong>, <strong>Seon Kim</strong>, <strong>Alexei Kitaev</strong>, <strong>Andrey Klots</strong>, <strong>Alexander Korotkov</strong>, <strong>Fedor Kostritsa</strong>, <strong>John Mark Kreikebaum</strong>, <strong>David Landhuis</strong>, <strong>Pavel Laptev</strong>, <strong>Kim Ming Lau</strong>, <strong>Lily Laws</strong>, <strong>Joonho Lee</strong>, <strong>Kenneth Lee</strong>, <strong>Yuri Lensky</strong>, <strong>Brian Lester</strong>, <strong>Alexander Lill</strong>, <strong>Wayne Liu</strong>, <strong>William P. Livingston</strong>, <strong>A. 
Locharla</strong>, <strong>Salvatore Mandrà</strong>, <strong>Orion Martin</strong>, <strong>Steven Martin</strong>, <strong>Jarrod McClean</strong>, <strong>Matthew McEwen</strong>, <strong>Seneca Meeks</strong>, <strong>Kevin Miao</strong>, <strong>Amanda Mieszala</strong>, <strong>Shirin Montazeri</strong>, <strong>Ramis Movassagh</strong>, <strong>Wojciech Mruczkiewicz</strong>, <strong>Ani Nersisyan</strong>, <strong>Michael Newman</strong>, <strong>Jiun How Ng</strong>, <strong>Anthony Nguyen</strong>, <strong>Murray Nguyen</strong>, <strong>M. Niu</strong>, <strong>Thomas O'Brien</strong>, <strong>Seun Omonije</strong>, <strong>Alex Opremcak</strong>, <strong>Rebecca Potter</strong>, <strong>Leonid Pryadko</strong>, <strong>Chris Quintana</strong>, <strong>David Rhodes</strong>, <strong>Charles Rocque</strong>, <strong>N. Rubin</strong>, <strong>Negar Saei</strong>, <strong>Daniel Sank</strong>, <strong>Kannan Sankaragomathi</strong>, <strong>Kevin Satzinger</strong>, <strong>Henry Schurkus</strong>, <strong>Christopher Schuster</strong>, <strong>Michael Shearn</strong>, <strong>Aaron Shorter</strong>, <strong>Noah Shutty</strong>, <strong>Vladimir Shvarts</strong>, <strong>Volodymyr Sivak</strong>, <strong>Jindra Skruzny</strong>, <strong>Clarke Smith</strong>, <strong>Rolando Somma</strong>, <strong>George Sterling</strong>, <strong>Doug Strain</strong>, <strong>Marco Szalay</strong>, <strong>Douglas Thor</strong>, <strong>Alfredo Torres</strong>, <strong>Guifre Vidal</strong>, <strong>Benjamin Villalonga</strong>, <strong>Catherine Vollgraff Heidweiller</strong>, <strong>Theodore White</strong>, <strong>Bryan Woo</strong>, <strong>Cheng Xing</strong>, <strong>Jamie Yao</strong>, <strong>Ping Yeh</strong>, <strong>Juhwan Yoo</strong>, <strong>Grayson Young</strong>, <strong>Adam Zalcman</strong>, <strong>Yaxing Zhang</strong>, <strong>Ningfeng Zhu</strong>, <strong>Nicholas Zobrist</strong>, <strong>Hartmut Neven</strong>, <strong>Ryan Babbush</strong>, <strong>Dave Bacon</strong>, <strong>Sergio Boixo</strong>, <strong>Jeremy Hilton</strong>, <strong>Erik Lucero</strong>, <strong>Anthony Megrant</strong>, <strong>Julian Kelly</strong>, <strong>Yu Chen</strong>, <strong>Vadim Smelyanskiy</strong>, Vedika Khemani, Sarang Gopalakrishnan,<strong> Tomaž Prosen</strong>, <strong>Pedram Roushan</strong>
<br />
<em>Session W50: Quantum Simulation of Many-Body Physics</em>
<br />
<a href="https://arxiv.org/pdf/2306.09333.pdf">Link to Paper</a>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/W50.13">The fast multipole method on a quantum computer</a>
<br />
Presenter: Kianna Wan
<br />
Authors: Kianna Wan, Dominic W Berry, <strong>Ryan Babbush</strong>
<br />
<em>Session W50: Quantum Simulation of Many-Body Physics</em>
</p>
</div>
<div style="line-height: 40%;">
<br />
</div>
<h3>Friday</h3>
<div style="margin-left: 20px;">
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/Y43.1">The quantum computing industry and protecting national security: what tools will work?</a>
<br />
Presenter: <strong>Kate Weber</strong>
<br />
Author: <strong>Kate Weber</strong>
<br />
<em>Session Y43: Industry, Innovation, and National Security: Finding the Right Balance</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/Y46.3">Novel charging effects in the fluxonium qubit</a>
<br />
Presenter: <strong>Agustin Di Paolo</strong>
<br />
Authors: <strong>Agustin Di Paolo</strong>, Kyle Serniak, Andrew J Kerman, <strong>William D Oliver</strong>
<br />
<em>Session Y46: Fluxonium-Based Superconducting Qubits</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/Z46.3">Microwave Engineering of Parametric Interactions in Superconducting Circuits</a>
<br />
Presenter: <strong>Ofer Naaman</strong>
<br />
Author: <strong>Ofer Naaman</strong>
<br />
<em>Session Z46: Broadband Parametric Amplifiers and Circulators</em>
</p>
<p>
<a href="https://meetings.aps.org/Meeting/MAR24/Session/Z62.3">Linear spin wave theory of large magnetic unit cells using the Kernel Polynomial Method</a>
<br />
Presenter: Harry Lane
<br />
Authors: Harry Lane, Hao Zhang, David A Dahlbom, Sam Quinn, <strong>Rolando D Somma</strong>, Martin P Mourigal, Cristian D Batista, Kipton Barros
<br />
<em>Session Z62: Cooperative Phenomena, Theory</em>
</p>
</div>
<!--Footnotes-->
<hr width="80%" />
<p>
<span class="Apple-style-span" style="font-size: x-small;"><sup><b>*</b></sup>Work done while at Google</span></p>
Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-16952642776386708942024-02-22T12:05:00.000-08:002024-02-23T10:07:08.500-08:00VideoPrism: A foundational visual encoder for video understanding<span class="byline-author">Posted by Long Zhao, Senior Research Scientist, and Ting Liu, Senior Staff Software Engineer, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi4kKy9Vqp7LE__mAG3METzRxmp6Z5PCH8AyfXzxQ_mNeIgOwYitblprQbb1fOTSUDgNgdmgsm7QwyXgkBcUDs2iIkxGue1n1sxdaomCyAo_eZD1-NFJEbn0fct-gJSNNs_MXHQQCxA79hVbd2CHzg2Nkpw1RnsOQWLq4Y7A7mxXTAFjR9NEE42A6pMOaDi/s450/VideoPrismSample.gif" style="display: none;" />
<p>
An astounding number of videos are available on the Web, covering a variety of content from everyday moments people share to historical moments to scientific observations, each of which contains a unique record of the world. The right tools could help researchers analyze these videos, transforming how we understand the world around us.
</p>
<a name='more'></a>
<p>
Videos offer dynamic visual content far richer than static images, capturing movement, changes, and dynamic relationships between entities. Analyzing this complexity, along with the immense diversity of publicly available video data, demands models that go beyond traditional image understanding. Consequently, many of the approaches that perform best on video understanding still rely on specialized models tailor-made for particular tasks. Recently, there has been exciting progress in this area using video foundation models (ViFMs), such as <a href="https://arxiv.org/abs/2109.14084">VideoCLIP</a>, <a href="https://arxiv.org/abs/2212.03191">InternVideo</a>, <a href="https://arxiv.org/abs/2212.04979">VideoCoCa</a>, and <a href="https://arxiv.org/abs/2303.16058">UMT</a>. However, building a ViFM that handles the sheer diversity of video data remains a challenge.
</p>
<p>
With the goal of building a single model for general-purpose video understanding, we introduce “<a href="https://arxiv.org/abs/2402.13217">VideoPrism: A Foundational Visual Encoder for Video Understanding</a>”. VideoPrism is a ViFM designed to handle a wide spectrum of video understanding tasks, including classification, localization, retrieval, captioning, and question answering (QA). We propose innovations in both the pre-training data as well as the modeling strategy. We pre-train VideoPrism on a massive and diverse dataset: 36 million high-quality video-text pairs and 582 million video clips with noisy or machine-generated parallel text. Our pre-training approach is designed for this hybrid data, to learn both from video-text pairs and the videos themselves. VideoPrism is incredibly easy to adapt to new video understanding challenges, and achieves state-of-the-art performance using a single frozen model.
</p><p></p>
<video autoplay="" loop="" muted="" playsinline="" width="100%"> <source src="https://github.com/garyzhao/videoprism-blog/raw/main/teaser.mp4" type="video/mp4"></source> </video>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td class="tr-caption" style="text-align: center;">VideoPrism is a general-purpose video encoder that enables state-of-the-art results over a wide spectrum of video understanding tasks, including classification, localization, retrieval, captioning, and question answering, by producing video representations from a single frozen model.</td></tr></tbody></table>
<br />
<h2>Pre-training data</h2>
<p>
A powerful ViFM needs a very large collection of videos on which to train — similar to other foundation models (FMs), such as those for large language models (LLMs). Ideally, we would want the pre-training data to be a representative sample of all the videos in the world. While naturally most of these videos do not have perfect captions or descriptions, even imperfect text can provide useful information about the semantic content of the video.
</p>
<p>
To give our model the best possible starting point, we put together a massive pre-training corpus consisting of several public and private datasets, including <a href="https://rowanzellers.com/merlot/">YT-Temporal-180M</a>, <a href="https://arxiv.org/abs/2307.06942">InternVid</a>, <a href="https://arxiv.org/abs/2204.00679">VideoCC</a>, and <a href="https://arxiv.org/abs/2007.14937">WTS-70M</a>. This includes 36 million carefully selected videos with high-quality captions, along with an additional 582 million clips with varying levels of noisy text (like auto-generated transcripts). To our knowledge, this is the largest and most diverse video training corpus of its kind.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgrhfnM1Rg_xbS1b3ZtydWc0M7zOchLpi5qdj65UaR3mOYbV8SQQqKhUhltYwmkPNqrULdeVeE1nU3gnRkjR7pE-yFaiVRC1al-BxZecsO0aojXFzSDhfv45oZoOBeYA93IiNeCGdnUryh4HLc3w7Qr2PX0fy6-4qFMTKBORA_PfHspp7Nr1OW0WnAvn-S9/s1999/image18.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="779" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgrhfnM1Rg_xbS1b3ZtydWc0M7zOchLpi5qdj65UaR3mOYbV8SQQqKhUhltYwmkPNqrULdeVeE1nU3gnRkjR7pE-yFaiVRC1al-BxZecsO0aojXFzSDhfv45oZoOBeYA93IiNeCGdnUryh4HLc3w7Qr2PX0fy6-4qFMTKBORA_PfHspp7Nr1OW0WnAvn-S9/s16000/image18.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Statistics on the video-text pre-training data. The large variations of the <a href="https://arxiv.org/abs/2104.14806">CLIP similarity scores</a> (the higher, the better) demonstrate the diverse caption quality of our pre-training data, which is a byproduct of the various ways used to harvest the text.</td></tr></tbody></table>
<br />
<h2>Two-stage training</h2>
<p>
The VideoPrism model architecture stems from the standard <a href="https://arxiv.org/abs/2010.11929">vision transformer</a> (ViT) with a factorized design that sequentially encodes spatial and temporal information following <a href="https://arxiv.org/abs/2103.15691">ViViT</a>. Our training approach leverages both the high-quality video-text data and the video data with noisy text mentioned above. To start, we use <a href="https://en.wikipedia.org/wiki/Self-supervised_learning#Contrastive_self-supervised_learning">contrastive learning</a> (an approach that minimizes the distance between positive video-text pairs while maximizing the distance between negative video-text pairs) to teach our model to match videos with their own text descriptions, including imperfect ones. This builds a foundation for matching semantic language content to visual content.
</p>
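<p>
For readers who want a concrete picture of the first stage, the snippet below is a minimal, self-contained sketch of a symmetric video-text contrastive (InfoNCE-style) objective. It is illustrative only; the function names, embedding shapes, and temperature are assumptions, not VideoPrism's actual training code.
</p>
<pre>
import numpy as np

def video_text_contrastive_loss(video_emb, text_emb, temperature=0.07):
    """Symmetric InfoNCE-style loss over a batch of paired embeddings.

    video_emb, text_emb: float arrays of shape (batch, dim), where row i of
    each array comes from the same video-text pair. Illustrative sketch only.
    """
    # L2-normalize so dot products become cosine similarities.
    v = video_emb / np.linalg.norm(video_emb, axis=1, keepdims=True)
    t = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = v @ t.T / temperature  # (batch, batch); diagonal holds true pairs

    def cross_entropy_on_diagonal(lg):
        lg = lg - lg.max(axis=1, keepdims=True)  # numerical stability
        log_probs = lg - np.log(np.exp(lg).sum(axis=1, keepdims=True))
        return -np.diag(log_probs).mean()

    # Pull each video toward its own caption and each caption toward its own
    # video, pushing away all other pairings in the batch.
    return 0.5 * (cross_entropy_on_diagonal(logits)
                  + cross_entropy_on_diagonal(logits.T))
</pre>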
<p>
After video-text contrastive training, we leverage the collection of videos without text descriptions. Here, we build on the <a href="https://arxiv.org/abs/2212.04500">masked video modeling framework</a> to predict masked patches in a video, with a few improvements. We train the model to predict both the video-level global embedding and token-wise embeddings from the first-stage model to effectively leverage the knowledge acquired in that stage. We then randomly shuffle the predicted tokens to prevent the model from learning shortcuts.
</p>
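<p>
The second stage can be read as distilling the first-stage encoder into a student that only sees a masked view of the video. The sketch below captures that recipe schematically (mask most tokens, then regress the frozen first-stage model's clip-level and token-level embeddings); the shapes and callables are assumptions, and the token-shuffling detail is only noted in a comment.
</p>
<pre>
import numpy as np

rng = np.random.default_rng(0)

def masked_distillation_loss(student, teacher, video_tokens, mask_ratio=0.8):
    """Schematic second-stage objective: reconstruct the frozen stage-one
    model's embeddings from a masked view of the video. Illustrative only.

    video_tokens: (num_tokens, dim) patch tokens for one clip.
    student(visible_tokens) -> (global_emb, token_embs) predicted for the
        full clip; teacher(all_tokens) -> (global_emb, token_embs), frozen.
    """
    n = len(video_tokens)
    keep = rng.permutation(n)[: int(n * (1 - mask_ratio))]
    visible = video_tokens[keep]  # the student only sees a small visible subset

    pred_global, pred_tokens = student(visible)
    tgt_global, tgt_tokens = teacher(video_tokens)  # stage-one targets

    # Match both the clip-level embedding and the per-token embeddings.
    # (The paper additionally shuffles the predicted tokens before decoding to
    # block positional shortcuts; that detail is omitted from this sketch.)
    return (np.mean((pred_global - tgt_global) ** 2)
            + np.mean((pred_tokens - tgt_tokens) ** 2))
</pre>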
<p>
What is unique about VideoPrism’s setup is that we use two complementary pre-training signals: text descriptions and the visual content within a video. Text descriptions often focus on what things look like, while the video content provides information about movement and visual dynamics. This enables VideoPrism to excel in tasks that demand an understanding of both appearance and motion.
</p>
<br />
<h2>Results</h2>
<p>
We conduct extensive evaluation of VideoPrism across four broad categories of video understanding tasks: video classification and localization, video-text retrieval, video captioning and question answering, and scientific video understanding. VideoPrism achieves state-of-the-art performance on 30 out of 33 video understanding benchmarks — all with minimal adaptation of a single, frozen model.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgiUtXCxgEXrgAZJ2B-Mn8L0DP7VkFUfUbI1yLTgGYSbWtn_Q5AjgGRgi3yQ5PMB3fVFlHLzDP4yhlCeGaPpdXr5I1-TNYelYMUBYiXx16qNYTpqKwAqXX7-EFV-4Asn6qYFWOb6_5p71n5Zzxbt-ZeUy5yIj2aieKXl0LnFOqdhKXa56xm4ZoXbccYDz3H/s1999/image20.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1999" data-original-width="1959" height="640" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgiUtXCxgEXrgAZJ2B-Mn8L0DP7VkFUfUbI1yLTgGYSbWtn_Q5AjgGRgi3yQ5PMB3fVFlHLzDP4yhlCeGaPpdXr5I1-TNYelYMUBYiXx16qNYTpqKwAqXX7-EFV-4Asn6qYFWOb6_5p71n5Zzxbt-ZeUy5yIj2aieKXl0LnFOqdhKXa56xm4ZoXbccYDz3H/w628-h640/image20.png" width="628" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">VideoPrism compared to the previous best-performing FMs.</td></tr></tbody></table>
<div style="line-height: 40%;"><br />
</div>
<h3>Classification and localization</h3>
<p>
We evaluate VideoPrism on an existing large-scale video understanding benchmark (<a href="https://arxiv.org/abs/2307.03166">VideoGLUE</a>) covering classification and localization tasks. We find that (1) VideoPrism outperforms all of the other state-of-the-art FMs, and (2) no other single model consistently comes in second place. This tells us that VideoPrism has learned to effectively pack a variety of video signals into one encoder — from semantics at different granularities to appearance and motion cues — and it works well across a variety of video sources.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhNnyg_lnLfwDIsJElqFwLKJleb1quzOR4h7X5jBf_bAnxwo_Em-_XLtWkkyMkyMPcLGdm0F25tLmccw3eK9qt6NN4LrLvfF45Wu8J2ylCqi4hPE-rFOwzmGuV8II6Nq8hileMNrS1lMwCuOHTVNGS04Dsxc7yVztaMCu0sRvuMUHnN4u9IKEvv2g8fRYWo/s1816/image12.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="742" data-original-width="1816" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhNnyg_lnLfwDIsJElqFwLKJleb1quzOR4h7X5jBf_bAnxwo_Em-_XLtWkkyMkyMPcLGdm0F25tLmccw3eK9qt6NN4LrLvfF45Wu8J2ylCqi4hPE-rFOwzmGuV8II6Nq8hileMNrS1lMwCuOHTVNGS04Dsxc7yVztaMCu0sRvuMUHnN4u9IKEvv2g8fRYWo/s16000/image12.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">VideoPrism outperforms state-of-the-art approaches (including <a href="https://arxiv.org/abs/2103.00020">CLIP</a>, <a href="https://arxiv.org/abs/2104.11178">VATT</a>, <a href="https://arxiv.org/abs/2212.03191">InternVideo</a>, and <a href="https://arxiv.org/abs/2303.16058">UMT</a>) on the <a href="https://arxiv.org/abs/2307.03166">video understanding benchmark</a>. In this plot, we show the absolute score differences compared with the previous best model to highlight the relative improvements of VideoPrism. On <a href="http://vuchallenge.org/charades.html">Charades</a>, <a href="http://activity-net.org/">ActivityNet</a>, <a href="https://research.google.com/ava/">AVA</a>, and <a href="https://research.google.com/ava/">AVA-K</a>, we use <a href="https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision">mean average precision</a> (mAP) as the evaluation metric. On the other datasets, we report top-1 accuracy.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h3>Combining with LLMs</h3>
<p>
We further explore combining VideoPrism with LLMs to unlock its ability to handle various video-language tasks. In particular, when paired with a text encoder (following <a href="https://arxiv.org/abs/2111.07991">LiT</a>) or a language decoder (such as <a href="https://arxiv.org/abs/2305.10403">PaLM-2</a>), VideoPrism can be utilized for video-text retrieval, video captioning, and video QA tasks. We compare the combined models on a broad and challenging set of vision-language benchmarks. VideoPrism sets the new state of the art on most benchmarks. From the visual results, we find that VideoPrism is capable of understanding complex motions and appearances in videos (e.g., the model can recognize the different colors of spinning objects on the window in the visual examples below). These results demonstrate that VideoPrism is strongly compatible with language models.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjd7V86xYM18_i3s0aemjiiYxaJeBiooZrEicQ5VVkLK3QnWTR96hKVsobSO4qRiN0f253JPX4y-T_h17E2Rx80PIVtVed0q499uCv42RzxZ7crkr21nuCR0zwalkSUX9FxIbjWVmlQGb1yx9Y5J8aVT_ROkY4DB1skUkk-bc9FaCc6tc-XLumHk5P65_UR/s1028/VideoPrismResults.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="932" data-original-width="1028" height="580" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjd7V86xYM18_i3s0aemjiiYxaJeBiooZrEicQ5VVkLK3QnWTR96hKVsobSO4qRiN0f253JPX4y-T_h17E2Rx80PIVtVed0q499uCv42RzxZ7crkr21nuCR0zwalkSUX9FxIbjWVmlQGb1yx9Y5J8aVT_ROkY4DB1skUkk-bc9FaCc6tc-XLumHk5P65_UR/w640-h580/VideoPrismResults.png" width="640" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">VideoPrism achieves competitive results compared with state-of-the-art approaches (including <a href="https://arxiv.org/abs/2212.04979">VideoCoCa</a>, <a href="https://arxiv.org/abs/2303.16058">UMT</a> and <a href="https://arxiv.org/abs/2204.14198">Flamingo</a>) on multiple video-text retrieval (top) and video captioning and video QA (bottom) benchmarks. We also show the absolute score differences compared with the previous best model to highlight the relative improvements of VideoPrism. We report the Recall@1 on <a href="https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/">MASRVTT</a>, <a href="https://eric-xw.github.io/vatex-website/index.html">VATEX</a>, and <a href="https://cs.stanford.edu/people/ranjaykrishna/densevid/">ActivityNet</a>, <a href="https://arxiv.org/abs/1411.5726">CIDEr score</a> on <a href="https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/">MSRVTT-Cap</a>, <a href="https://eric-xw.github.io/vatex-website/index.html">VATEX-Cap</a>, and <a href="http://youcook2.eecs.umich.edu/">YouCook2</a>, top-1 accuracy on <a href="https://github.com/xudejing/video-question-answering">MSRVTT-QA</a> and <a href="https://github.com/xudejing/video-question-answering">MSVD-QA</a>, and <a href="https://arxiv.org/abs/cmp-lg/9406033">WUPS index</a> on <a href="https://doc-doc.github.io/docs/nextqa.html">NExT-QA</a>.</td></tr></tbody></table>
<br />
<video autoplay="" loop="" muted="" playsinline="" width="100%"> <source src="https://github.com/garyzhao/videoprism-blog/raw/main/snowball_water_bottle_drum.mp4" type="video/mp4"></source> </video>
<video autoplay="" loop="" muted="" playsinline="" width="100%"> <source src="https://github.com/garyzhao/videoprism-blog/raw/main/spin_roller_skating.mp4" type="video/mp4"></source> </video>
<video autoplay="" loop="" muted="" playsinline="" width="100%"> <source src="https://github.com/garyzhao/videoprism-blog/raw/main/making_ice_cream_ski_lifting.mp4" type="video/mp4"></source> </video>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td class="tr-caption" style="text-align: center;">We show qualitative results using VideoPrism with a text encoder for video-text retrieval (first row) and adapted to a language decoder for video QA (second and third row). For video-text retrieval examples, the blue bars indicate the embedding similarities between the videos and the text queries.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h3>Scientific applications</h3>
<p>
Finally, we test VideoPrism on datasets used by scientists across domains, including fields such as ethology, behavioral neuroscience, and ecology. Because these datasets typically require domain expertise to annotate, we leverage existing scientific datasets open-sourced by the community, including <a href="https://data.caltech.edu/records/zrznw-w7386">Fly vs. Fly</a>, <a href="https://data.caltech.edu/records/s0vdx-0k302">CalMS21</a>, <a href="https://shirleymaxx.github.io/ChimpACT/">ChimpACT</a>, and <a href="https://dirtmaxim.github.io/kabr/">KABR</a>. VideoPrism not only performs exceptionally well, but actually surpasses models designed specifically for those tasks. This suggests that tools like VideoPrism have the potential to transform how scientists analyze video data across different fields.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi3v-C36GWUp8CkaCVqFvaXYKW6-1SvCo99Ogiul-fSTkftyc-t4z5CNUgEWlJkRmzranQrYHldtBvjeJXsqdB4ZbgBkyaZv-_I9QE5U7kus_Z8QWlVqfzX0JfELSDPfGj9V4QqhUMwX_EkyPM-vG7pdYMXN0kj1-s98IZJl3U8CpvqoOHyAsuwXIVt7M4_/s1200/image5.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="742" data-original-width="1200" height="397" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi3v-C36GWUp8CkaCVqFvaXYKW6-1SvCo99Ogiul-fSTkftyc-t4z5CNUgEWlJkRmzranQrYHldtBvjeJXsqdB4ZbgBkyaZv-_I9QE5U7kus_Z8QWlVqfzX0JfELSDPfGj9V4QqhUMwX_EkyPM-vG7pdYMXN0kj1-s98IZJl3U8CpvqoOHyAsuwXIVt7M4_/w640-h397/image5.png" width="640" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">VideoPrism outperforms the domain experts on various scientific benchmarks. We show the absolute score differences to highlight the relative improvements of VideoPrism. We report mean average precision (mAP) for all datasets, except for KABR which uses class-averaged top-1 accuracy.</td></tr></tbody></table>
<br />
<h2>Conclusion</h2>
<p>
With VideoPrism, we introduce a powerful and versatile video encoder that sets a new standard for general-purpose video understanding. Our extensive evaluations validate our emphasis on both building a massive and varied pre-training dataset and on innovative modeling techniques. Not only does VideoPrism consistently outperform strong baselines, but its unique ability to generalize positions it well for tackling an array of real-world applications. Because of its potential for broad use, we are committed to continuing responsible research in this space, guided by our <a href="http://ai.google/principles">AI Principles</a>. We hope VideoPrism paves the way for future breakthroughs at the intersection of AI and video analysis, helping to realize the potential of ViFMs across domains such as scientific discovery, education, and healthcare.
</p>
<br />
<h2>Acknowledgements</h2>
<p>
<em>This blog post is made on behalf of all the VideoPrism authors: Long Zhao, Nitesh B. Gundavarapu, Liangzhe Yuan, Hao Zhou, Shen Yan, Jennifer J. Sun, Luke Friedman, Rui Qian, Tobias Weyand, Yue Zhao, Rachel Hornung, Florian Schroff, Ming-Hsuan Yang, David A. Ross, Huisheng Wang, Hartwig Adam, Mikhail Sirotenko, Ting Liu, and Boqing Gong. We sincerely thank David Hendon for their product management efforts, and Alex Siegman, Ramya Ganeshan, and Victor Gomes for their program and resource management efforts. We also thank Hassan Akbari, Sherry Ben, Yoni Ben-Meshulam, Chun-Te Chu, Sam Clearwater, Yin Cui, Ilya Figotin, Anja Hauth, Sergey Ioffe, Xuhui Jia, Yeqing Li, Lu Jiang, Zu Kim, Dan Kondratyuk, Bill Mark, Arsha Nagrani, Caroline Pantofaru, Sushant Prakash, Cordelia Schmid, Bryan Seybold, Mojtaba Seyedhosseini, Amanda Sadler, Rif A. Saurous, Rachel Stigler, Paul Voigtlaender, Pingmei Xu, Chaochao Yan, Xuan Yang, and Yukun Zhu for the discussions, support, and feedback that greatly contributed to this work. We are grateful to Jay Yagnik, Rahul Sukthankar, and Tomas Izo for their enthusiastic support for this project. Lastly, we thank Tom Small, Jennifer J. Sun, Hao Zhou, Nitesh B. Gundavarapu, Luke Friedman, and Mikhail Sirotenko for the tremendous help with making this blog post.</em>
</p><p></p>Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-43432355099090917412024-02-21T12:15:00.000-08:002024-02-21T12:15:36.694-08:00Advances in private training for production on-device language models<span class="byline-author">Posted by Zheng Xu, Research Scientist, and Yanxiang Zhang, Software Engineer, Google</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEifnCZ_XGUoUG0hESM0dF5B8Rsoqo4YrT_-uv0hlDM1iTADhtEEyEvBM4hOWT0rxgpVtZKyuFoj2xeXmkeXwGe-XTmvBuwBDJOCqgN8Ba7Wcjh_s1seWUaCRl1xNpNe_6MqxcFFZoAvhfCge5vq9UATjXG_BnTiGdQ6YLLo7AK7ABS3KLFMKmjAtA1gkcBk/s1600/GBoard%20PrivacyHero.gif" style="display: none;" />
<p>
Language models (LMs) trained to predict the next word given input text are the key technology for many applications [<a href="https://blog.google/technology/ai/google-palm-2-ai-large-language-model/">1</a>, <a href="https://blog.google/technology/ai/google-gemini-ai/">2</a>]. In <a href="https://play.google.com/store/apps/details?id=com.google.android.inputmethod.latin&hl=en_US&gl=US">Gboard</a>, LMs are used to improve users’ typing experience by supporting features like <a href="https://arxiv.org/abs/1811.03604">next word prediction</a> (NWP), <a href="https://support.google.com/gboard/answer/7068415">Smart Compose</a>, <a href="https://support.google.com/gboard/answer/7068415">smart completion</a> and <a href="https://support.google.com/gboard/answer/7068415">suggestion</a>, <a href="https://support.google.com/gboard/answer/2811346">slide to type</a>, and <a href="https://support.google.com/gboard/answer/7068415">proofread</a>. Deploying models on users’ devices rather than enterprise servers has advantages like lower latency and better privacy for model usage. While training on-device models directly from user data effectively improves the utility performance for applications such as NWP and <a href="https://blog.research.google/2021/11/predicting-text-selections-with.html">smart text selection</a>, protecting the privacy of user data for model training is important.
</p>
<a name='more'></a>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiWvaPvikHjeVBb9njeoP2z499_LU0a4VfEgI2kOVxYEoApqgZ49-Ej_TpY6pyoy9HKU2jASzSBsKhdXuOhP-ykpsK_makFmWzVF67BPS3PSpRrCIxC0hYHogBVcDM74AXmjD5hh2mP22tPmXQqEkOak9QXXLyJOCsJB94dv0P-W3IINYyah2O-nF1HLTXE/s1996/image45.gif" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1600" data-original-width="1996" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiWvaPvikHjeVBb9njeoP2z499_LU0a4VfEgI2kOVxYEoApqgZ49-Ej_TpY6pyoy9HKU2jASzSBsKhdXuOhP-ykpsK_makFmWzVF67BPS3PSpRrCIxC0hYHogBVcDM74AXmjD5hh2mP22tPmXQqEkOak9QXXLyJOCsJB94dv0P-W3IINYyah2O-nF1HLTXE/s16000/image45.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Gboard features powered by on-device language models.</td></tr></tbody></table>
<p>
In this blog we discuss how years of research advances now power the private training of Gboard LMs, since the proof-of-concept development of <a href="https://blog.research.google/2017/04/federated-learning-collaborative.html">federated learning</a> (FL) in 2017 and formal <a href="https://blog.research.google/2022/02/federated-learning-with-formal.html">differential privacy</a> (DP) guarantees in 2022. <a href="https://blog.research.google/2017/04/federated-learning-collaborative.html">FL</a> enables mobile phones to collaboratively learn a model while keeping all the training data on device, and <a href="https://en.wikipedia.org/wiki/Differential_privacy">DP</a> provides a quantifiable measure of data anonymization. Formally, DP is often characterized by (<em>ε</em>, <em>δ</em>) with smaller values representing stronger guarantees. Machine learning (ML) models are considered to have <a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html">reasonable DP guarantees for ε=10 and strong DP guarantees for ε=1</a> when <em>δ</em> is small.
</p>
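<p>
Concretely, a randomized mechanism <em>M</em> satisfies (<em>ε</em>, <em>δ</em>)-DP if, for any two datasets <em>D</em> and <em>D’</em> that differ in a single user’s data and for any set of outputs <em>S</em>, Pr[<em>M</em>(<em>D</em>) ∈ <em>S</em>] ≤ e<sup><em>ε</em></sup> · Pr[<em>M</em>(<em>D’</em>) ∈ <em>S</em>] + <em>δ</em>. The smaller <em>ε</em> and <em>δ</em> are, the harder it is to tell from the output whether any particular user’s data was included.
</p>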
<p>
As of today, all NWP neural network LMs in Gboard are trained with FL with formal DP guarantees, and all future launches of Gboard LMs trained on user data require DP. These 30+ Gboard on-device LMs are launched in 7+ languages and 15+ countries, and satisfy (<em>ɛ</em>, <em>δ</em>)-DP guarantees with a small <em>δ</em> of 10<sup>-10</sup> and <em>ɛ</em> between 0.994 and 13.69. To the best of our knowledge, this is the largest known deployment of user-level DP in production at Google or anywhere, and the first time a strong DP guarantee of <em>ɛ</em> < 1 is announced for models trained directly on user data.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Privacy principles and practices in Gboard</h2>
<p>
In “<a href="https://arxiv.org/abs/2306.14793">Private Federated Learning in Gboard</a>”, we discussed how different <a href="https://queue.acm.org/detail.cfm?id=3501293">privacy principles</a> are currently reflected in production models, including:
</p>
<ul>
<li><em>Transparency and user control</em>: We provide disclosure of what data is used, what purpose it is used for, how it is processed in various channels, and how Gboard users can easily <a href="https://support.google.com/gboard/answer/12373137">configure</a> the data usage in learning models.
</li><li><em>Data minimization</em>: FL immediately aggregates only focused updates that improve a specific model. <a href="https://eprint.iacr.org/2017/281.pdf">Secure aggregation</a> (SecAgg) is an encryption method to further guarantee that only aggregated results of the ephemeral updates can be accessed.
</li><li><em>Data anonymization</em>: DP is applied by the server to prevent models from memorizing the unique information in individual user’s training data.
</li><li><em>Auditability and verifiability</em>: We have made public the key algorithmic approaches and privacy accounting in open-sourced code (<a href="https://github.com/tensorflow/federated/blob/main/tensorflow_federated/python/aggregators/differential_privacy.py">TFF aggregator</a>, <a href="https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py">TFP DPQuery</a>, <a href="https://github.com/google-research/federated/blob/master/dp_ftrl/blogpost_supplemental_privacy_accounting.ipynb">DP accounting</a>, and <a href="https://github.com/google/federated-compute">FL system</a>).
</li>
</ul>
<div style="line-height: 40%;">
<br />
</div>
<h3>A brief history</h3>
<p>
In recent years, FL has become the default method for training <a href="https://arxiv.org/abs/1811.03604">Gboard on-device LMs</a> from user data. In 2020, a DP mechanism that <a href="https://arxiv.org/abs/1710.06963">clips and adds noise</a> to model updates was used to <a href="https://arxiv.org/abs/2009.10031">prevent memorization</a> for training the Spanish LM in Spain, which satisfies finite DP guarantees (<a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html">Tier 3</a> described in “<a href="https://arxiv.org/abs/2303.00654">How to DP-fy ML“</a> guide). In 2022, with the help of the <a href="https://arxiv.org/abs/2103.00039">DP-Follow-The-Regularized-Leader (DP-FTRL) algorithm</a>, the Spanish LM became the first production neural network trained directly on user data announced with <a href="https://blog.research.google/2022/02/federated-learning-with-formal.html">a formal DP guarantee of (ε=8.9, δ=10<sup>-10</sup>)-DP</a> (equivalent to the reported <em><a href="https://blog.research.google/2022/02/federated-learning-with-formal.html">ρ=0.81</a></em> <a href="https://arxiv.org/abs/1605.02065">zero-Concentrated-Differential-Privacy</a>), and therefore satisfies <a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html">reasonable privacy guarantees</a> (<a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html">Tier 2</a>).
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Differential privacy by default in federated learning </h2>
<p>
In “<a href="https://arxiv.org/abs/2305.18465">Federated Learning of Gboard Language Models with Differential Privacy</a>”, we announced that all the NWP neural network LMs in Gboard have DP guarantees, and all future launches of Gboard LMs trained on user data require DP guarantees. DP is enabled in FL by applying the following practices:
</p>
<ul>
<li>Pre-train the model with the <a href="https://arxiv.org/abs/2010.11934">multilingual</a> <a href="https://arxiv.org/abs/1910.10683">C4</a> dataset.
</li><li>Via simulation experiments on public datasets, find a large DP-noise-to-signal ratio that still allows for high utility. Increasing the number of clients contributing to one round of model updates improves privacy while keeping the noise ratio fixed for good utility, up to the point where the DP target is met or we reach the maximum allowed by the system and the population size.
</li><li>Configure the parameter that restricts how frequently each client can contribute (e.g., once every few days), based on the computation budget and the estimated population in <a href="https://arxiv.org/abs/1902.01046">the FL system</a>.
</li><li>Run <a href="https://arxiv.org/abs/2103.00039">DP-FTRL</a> training with limits on the magnitude of per-device updates chosen either via <a href="https://github.com/tensorflow/federated/commit/ee9d08368828ea730662e5e2b3a90e103368b6b6">adaptive clipping</a> or fixed based on experience (the generic clip-and-noise step is sketched after this list).
</li>
</ul>
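<p>
As a rough illustration of the clip-and-noise step above (and not the production DP-FTRL implementation; the clip norm and noise multiplier below are placeholder values), each device’s model delta is norm-clipped before aggregation and calibrated Gaussian noise is added to the aggregate:
</p>
<pre>
import numpy as np

def clip_and_noise_aggregate(client_updates, clip_norm=1.0, noise_multiplier=1.0,
                             rng=np.random.default_rng(0)):
    """Clip each per-device update and add Gaussian noise to the aggregate.

    Illustrative sketch of the generic DP aggregation step, not Gboard's
    DP-FTRL code. client_updates: list of 1-D numpy arrays (model deltas).
    """
    clipped = []
    for update in client_updates:
        norm = np.linalg.norm(update)
        # Scale down any update whose L2 norm exceeds the clip threshold.
        clipped.append(update * min(1.0, clip_norm / max(norm, 1e-12)))
    total = np.sum(clipped, axis=0)
    # The noise standard deviation is proportional to the per-device
    # sensitivity (clip_norm); the noise_multiplier is what privacy
    # accounting tracks across rounds.
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=total.shape)
    return (total + noise) / len(client_updates)
</pre>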
<p>
SecAgg can be additionally applied by adopting the <a href="https://blog.research.google/2023/03/distributed-differential-privacy-for.html">advances in improving computation and communication for scales and sensitivity</a>.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEht2ZweKyxBRqShB6i41lTpZmfS2gEi2rbNHFGgT-36di1HMxwV6caxFJ2lUXpznxuXYHEb928yfHwueojKlB-gxfKfT4aEv-_2mUlO5zlaWNPceMDGdnOVWp4M8T5qCzMPTuinPOtRy1WmXMtsaSpNpMLvokQKlOnWYFMJF0tXbhmc-dkpI-o7T4FBn8-N/s1600/image3.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1000" data-original-width="1600" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEht2ZweKyxBRqShB6i41lTpZmfS2gEi2rbNHFGgT-36di1HMxwV6caxFJ2lUXpznxuXYHEb928yfHwueojKlB-gxfKfT4aEv-_2mUlO5zlaWNPceMDGdnOVWp4M8T5qCzMPTuinPOtRy1WmXMtsaSpNpMLvokQKlOnWYFMJF0tXbhmc-dkpI-o7T4FBn8-N/s16000/image3.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Federated learning with differential privacy and (SecAgg).</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h3>Reporting DP guarantees</h3>
<p>
The DP guarantees of launched Gboard NWP LMs are visualized in the barplot below. The <em>x</em>-axis shows LMs labeled by language-locale and trained on corresponding populations; the <em>y</em>-axis shows the <em>ε</em> value when <em>δ</em> is fixed to a small value of 10<sup>-10</sup> for <a href="https://www.iacr.org/archive/eurocrypt2006/40040493/40040493.pdf">(ε, δ)-DP</a> (lower is better). The utility of these models is either significantly better than that of previous non-neural models in production, or comparable with previous LMs without DP, measured based on user-interaction metrics during A/B testing. For example, by applying the best practices, the DP guarantee of the Spanish model in Spain is improved from <em><a href="https://blog.research.google/2022/02/federated-learning-with-formal.html">ε=8.9</a></em> to <em>ε</em>=5.37. SecAgg is additionally used for training the Spanish model in Spain and the English model in the US. More details of the DP guarantees are reported in <a href="https://arxiv.org/abs/2305.18465">the appendix</a> following the <a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html">guidelines outlined</a> in “<a href="https://arxiv.org/abs/2303.00654">How to DP-fy ML</a>”.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Towards stronger DP guarantees</h2>
<p>
The <em>ε</em>~10 DP guarantees of many launched LMs are already considered <a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html">reasonable</a> for ML models in practice, while the journey of DP FL in Gboard continues toward improving the user typing experience while protecting data privacy. We are excited to announce that, for the first time, production LMs of Portuguese in Brazil and Spanish in Latin America are trained and launched with a DP guarantee of <em>ε</em> ≤ 1, which satisfies <a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html">Tier 1 strong privacy guarantees</a>. Specifically, the (<em>ε</em>=0.994, <em>δ</em>=10<sup>-10</sup>)-DP guarantee is achieved by running the advanced <a href="https://arxiv.org/abs/2306.08153">Matrix Factorization DP-FTRL</a> (MF-DP-FTRL) algorithm, with 12,000+ devices participating in every training round of server model updates (larger than the <a href="https://arxiv.org/abs/2305.18465">common setting of 6500+ devices</a>), and a carefully configured policy that restricts each client to participating at most twice in the 2000 total rounds of training over 14 days, leveraging the large population of Portuguese users in Brazil. Using a similar setting, the es-US Spanish LM was trained on a large population spanning multiple countries in Latin America to achieve (<em>ε</em>=0.994, <em>δ</em>=10<sup>-10</sup>)-DP. The <em>ε</em> ≤ 1 es-US model significantly improved utility in many countries, and was launched in Colombia, Ecuador, Guatemala, Mexico, and Venezuela. For the smaller population in Spain, the DP guarantee of the es-ES LM is improved from <em><a href="https://arxiv.org/abs/2305.18465">ε=5.37</a></em> to <em>ε</em>=3.42 by replacing <a href="https://arxiv.org/abs/2103.00039">DP-FTRL</a> with <a href="https://arxiv.org/abs/2306.08153">MF-DP-FTRL</a> alone, without increasing the number of devices participating in every round. More technical details are disclosed in the <a href="https://colab.sandbox.google.com/github/google-research/federated/blob/master/mf_dpftrl_matrices/privacy_accounting.ipynb">colab</a> for privacy accounting.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgp1yNOAbd8IRoisQDX-OHq-a8PUDH2V1OF7btRsUXI86-tuEXwrR8otAGEqPN8J2HGcpH9aB25s04Nybm_Vn6bpRmfD_AHnHYkGJtld7ockal6mhdRXcsA-M6rf3vM7kzQ5hXfdPbw9hk7bsQU8EV4ul5QAn3Hw4b1yXIKjnokfhrkEF0hNXGt9DbLU3yk/s1999/image1.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="709" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgp1yNOAbd8IRoisQDX-OHq-a8PUDH2V1OF7btRsUXI86-tuEXwrR8otAGEqPN8J2HGcpH9aB25s04Nybm_Vn6bpRmfD_AHnHYkGJtld7ockal6mhdRXcsA-M6rf3vM7kzQ5hXfdPbw9hk7bsQU8EV4ul5QAn3Hw4b1yXIKjnokfhrkEF0hNXGt9DbLU3yk/s16000/image1.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">DP guarantees for Gboard NWP LMs (the purple bar represents the first es-ES launch of ε=8.9; cyan bars represent privacy improvements for models trained with <a href="https://arxiv.org/abs/2306.08153">MF-DP-FTRL</a>; <a href="https://blog.research.google/2023/05/making-ml-models-differentially-private.html">tiers </a>are from “<a href="https://arxiv.org/abs/2303.00654">How to DP-fy ML</a>“ guide; en-US* and es-ES* are additionally trained with SecAgg).</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>Discussion and next steps</h2>
<p>
Our experience suggests that DP can be achieved in practice through system algorithm co-design on client participation, and that both privacy and utility can be strong when populations are large <em>and</em> a large number of devices' contributions are aggregated. Privacy-utility-computation trade-offs can be improved by <a href="https://arxiv.org/abs/2305.18465">using public data</a>, the <a href="https://arxiv.org/abs/2306.08153">new MF-DP-FTRL algorithm</a>, <a href="https://github.com/google/differential-privacy">and tightening accounting</a>. With these techniques, a strong DP guarantee of <em>ε</em> ≤ 1 is possible but still challenging. Active research on empirical privacy auditing [<a href="https://arxiv.org/abs/2302.03098">1</a>, <a href="https://arxiv.org/abs/2305.08846">2</a>] suggests that DP models are potentially more private than the worst-case DP guarantees imply. While we keep pushing the frontier of algorithms, which dimension of privacy-utility-computation should be prioritized?
</p>
<p>
We are actively working on all privacy aspects of ML, including extending DP-FTRL to <a href="https://blog.research.google/2023/03/distributed-differential-privacy-for.html">distributed DP</a> and improving <a href="https://arxiv.org/abs/2306.14793">auditability and verifiability</a>. <a href="https://en.wikipedia.org/wiki/Trusted_execution_environment">Trusted Execution Environment</a> opens the opportunity for substantially increasing the model size with verifiable privacy. The recent <a href="https://blog.google/technology/ai/google-gemini-ai/">breakthrough in large LMs</a> (LLMs) motivates us to <a href="https://arxiv.org/abs/2305.12132">rethink</a> the usage of <a href="https://arxiv.org/abs/2212.06470">public</a> information in private training and more future interactions between LLMs, on-device LMs, and Gboard production.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgments</h2>
<p>
<em>The authors would like to thank Peter Kairouz, Brendan McMahan, and Daniel Ramage for their early feedback on the blog post itself, Shaofeng Li and Tom Small for helping with the animated figures, and the teams at Google that helped with algorithm design, infrastructure implementation, and production maintenance. The collaborators below directly contribute to the presented results:</em>
</p>
<p>
<em>Research and algorithm development: Galen Andrew, Stanislav Chiknavaryan, Christopher A. Choquette-Choo, Arun Ganesh, Peter Kairouz, Ryan McKenna, H. Brendan McMahan, Jesse Rosenstock, Timon Van Overveldt, Keith Rush, Shuang Song, Thomas Steinke, Abhradeep Guha Thakurta, Om Thakkar, and Yuanbo Zhang.</em>
</p>
<p>
<em>Infrastructure, production and leadership support: Mingqing Chen, Stefan Dierauf, Billy Dou, Hubert Eichner, Zachary Garrett, Jeremy Gillula, Jianpeng Hou, Hui Li, Xu Liu, Wenzhi Mao, Brett McLarnon, Mengchen Pei, Daniel Ramage, Swaroop Ramaswamy, Haicheng Sun, Andreas Terzis, Yun Wang, Shanshan Wu, Yu Xiao, and Shumin Zhai.</em>
</p>Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-56059330332992610252024-02-14T10:32:00.000-08:002024-02-14T10:32:25.557-08:00Learning the importance of training data under concept drift<span class="byline-author">Posted by Nishant Jain, Pre-doctoral Researcher, and Pradeep Shenoy, Research Scientist, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgUeskw4YD6cFTpLaRnv7OwMsljyeipfAb1riYxIuBsiWd6TBmUXMJ4QoI9tlvUzWX9NzBbEjz3-P2Zl2kuXe5BrVclmqQFrLButoya5phiEELq1azrhsIaGaCz-ov_jXaMsFrGRDE0EjotyRQPOX3xV5MAkVJfKp9xecX4t2CoLBiZ8r2RpZ25Y5KRitFG/s1600/temporalreweightinghero.png" style="display: none;" />
<p>
The constantly changing nature of the world around us poses a significant challenge for the development of AI models. Often, models are trained on longitudinal data with the hope that the training data used will accurately represent inputs the model may receive in the future. More generally, the default assumption that all training data are equally relevant often breaks in practice. For example, the figure below shows images from the <a href="https://arxiv.org/abs/2201.06289">CLEAR</a> nonstationary learning benchmark, and it illustrates how visual features of objects evolve significantly over a 10 year span (a phenomenon we refer to as <em>slow concept drift</em>), posing a challenge for object categorization models.
</p>
<a name='more'></a>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjBAkCetRQiAPA4cmiXvtwa2SJ0pMwvRYDcuL7rQEDHxEgi9lAyU69bBeeEw-_k182BITn4w2WtdE5QfUwaF-Ny-Dkai-pLeHV23mlgAwrX_0le28l5hba9q9QUO3LeYl2jgkPGkKcLW7dtnGFMiY7PrZbpigSggAiOSrRB8X9eQZGHLE8H7TZoxYy4AD2Q/s1999/image4.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="662" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjBAkCetRQiAPA4cmiXvtwa2SJ0pMwvRYDcuL7rQEDHxEgi9lAyU69bBeeEw-_k182BITn4w2WtdE5QfUwaF-Ny-Dkai-pLeHV23mlgAwrX_0le28l5hba9q9QUO3LeYl2jgkPGkKcLW7dtnGFMiY7PrZbpigSggAiOSrRB8X9eQZGHLE8H7TZoxYy4AD2Q/s16000/image4.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Sample images from the CLEAR benchmark. (Adapted from Lin et al<a href="https://arxiv.org/abs/2201.06289">.</a>)</td></tr></tbody></table>
<br>
<p>
Alternative approaches, such as <a href="https://en.wikipedia.org/wiki/Online_machine_learning">online</a> and <a href="https://wiki.continualai.org/the-continualai-wiki/introduction-to-continual-learning">continual learning</a>, repeatedly update a model with small amounts of recent data in order to keep it current. This implicitly prioritizes recent data, as what was learned from past data is gradually erased by subsequent updates. However, in the real world, different kinds of information lose relevance at different rates, which leads to two key issues: 1) by design, these approaches focus <em>exclusively</em> on the most recent data and lose any signal from older data that is erased; and 2) contributions from data instances decay <em>uniformly over time</em> irrespective of the contents of the data.
</p>
<p>
In our recent work, “<a href="https://arxiv.org/abs/2212.05908">Instance-Conditional Timescales of Decay for Non-Stationary Learning</a>”, we propose to assign each instance an importance score during training in order to maximize model performance on future data. To accomplish this, we employ an auxiliary model that produces these scores using the training instance as well as its age. This model is jointly learned with the primary model. We address both the above challenges and achieve significant gains over other robust learning methods on a range of benchmark datasets for nonstationary learning. For instance, on a <a href="https://arxiv.org/abs/2108.09020">recent large-scale benchmark</a> for nonstationary learning (~39M photos over a 10 year period), we show up to 15% relative accuracy gains through learned reweighting of training data.
</p>
<div style="line-height:40%;">
<br>
</div>
<h2>The challenge of concept drift for supervised learning</h2>
<p>
To gain quantitative insight into slow concept drift, we built classifiers on a <a href="https://arxiv.org/abs/2108.09020">recent photo categorization task</a>, comprising roughly 39M photographs sourced from social media websites over a 10 year period. We compared offline training, which iterated over all the training data multiple times in random order, and continual training, which iterated multiple times over each month of data in sequential (temporal) order. We measured model accuracy both during the training period and during a subsequent period where both models were frozen, i.e., not updated further on new data (shown below). At the end of the training period (left panel, x-axis = 0), both approaches have seen the same amount of data, but show a large performance gap. This is due to <a href="https://www.sciencedirect.com/science/article/abs/pii/S0079742108605368">catastrophic forgetting</a>, a problem in continual learning where a model’s knowledge of data from early on in the training sequence is diminished in an uncontrolled manner. On the other hand, forgetting has its advantages — over the test period (shown on the right), the continual trained model degrades much less rapidly than the offline model because it is less dependent on older data. The decay of both models’ accuracy in the test period is confirmation that the data is indeed evolving over time, and both models become increasingly less relevant.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEizQmgaL3NNsCLWbeTndyOxPikcGKqQIrpDisMVTy-7eAIxamEv3Klpncd5B4SB19yNnPmpySlfAz_hPN8x4zV7o0LPmcLKEnyVJBctKuLF8plITBmDz3BTR2aPHqlKarPPHZHpp0EY0M3HA9l5oV_IOaQS5UzS-uMaNq3Fi1D1qHUYJ6XC-4t0_xS91fnw/s1554/image2.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="616" data-original-width="1554" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEizQmgaL3NNsCLWbeTndyOxPikcGKqQIrpDisMVTy-7eAIxamEv3Klpncd5B4SB19yNnPmpySlfAz_hPN8x4zV7o0LPmcLKEnyVJBctKuLF8plITBmDz3BTR2aPHqlKarPPHZHpp0EY0M3HA9l5oV_IOaQS5UzS-uMaNq3Fi1D1qHUYJ6XC-4t0_xS91fnw/s16000/image2.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Comparing offline and continually trained models on the photo classification task.</td></tr></tbody></table>
<br>
<div style="line-height:40%;">
<br>
</div>
<h2>Time-sensitive reweighting of training data</h2>
<p>
We design a method combining the benefits of offline learning (the flexibility of effectively reusing all available data) and continual learning (the ability to downplay older data) to address slow concept drift. We build upon offline learning, then add careful control over the influence of past data and an optimization objective, both designed to reduce model decay in the future.
</p>
<p>
Suppose we wish to train a model, <em>M</em>, given some training data collected over time. We propose to also train a helper model that assigns a weight to each point based on its contents and age. This weight scales the contribution from that data point in the training objective for <em>M</em>. The objective of the weights is to improve the performance of <em>M</em> on future data.
</p>
<p>
In <a href="https://arxiv.org/abs/2212.05908">our work</a>, we describe how the helper model can be <em>meta-learned, </em>i.e., learned alongside <em>M</em> in a manner that helps the learning of the model <em>M</em> itself. A key design choice of the helper model is that we separated out instance- and age-related contributions in a factored manner. Specifically, we set the weight by combining contributions from multiple different fixed timescales of decay, and learn an approximate “assignment” of a given instance to its most suited timescales. We find in our experiments that this form of the helper model outperforms many other alternatives we considered, ranging from unconstrained joint functions to a single timescale of decay (exponential or linear), due to its combination of simplicity and expressivity. Full details may be found in the <a href="https://arxiv.org/abs/2212.05908">paper</a>.
</p>
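<p>
One schematic reading of this design is sketched below. It is illustrative only: the particular timescales, the stand-in scoring function, and the softmax assignment are our assumptions rather than the exact parameterization in the paper.
</p>
<pre>
import numpy as np

def instance_weight(instance_features, age, assignment_logits_fn,
                    timescales=(30.0, 180.0, 720.0)):
    """Weight a training example by a learned mixture of fixed decay timescales.

    Illustrative sketch: `assignment_logits_fn` stands in for the meta-learned
    helper model that maps instance features to one score per timescale;
    `age` is how old the example is (e.g., in days) at training time.
    """
    logits = assignment_logits_fn(instance_features)   # shape: (len(timescales),)
    probs = np.exp(logits - logits.max())
    probs = probs / probs.sum()                        # soft assignment over timescales
    decays = np.exp(-age / np.asarray(timescales))     # fixed exponential decay curves
    return float(np.dot(probs, decays))                # scalar weight for this example

# Usage idea: scale each example's loss by its weight when training the main
# model M, and meta-learn `assignment_logits_fn` so that the weighted training
# objective improves accuracy on held-out future data.
</pre>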
<div style="line-height:40%;">
<br>
</div>
<h2>Instance weight scoring</h2>
<p>
The top figure below shows that our learned helper model indeed up-weights more modern-looking objects in the <a href="https://arxiv.org/abs/2201.06289">CLEAR object recognition challenge</a>; older-looking objects are correspondingly down-weighted. On closer examination (bottom figure below, gradient-based <a href="https://arxiv.org/abs/1610.02391">feature importance</a> assessment), we see that the helper model focuses on the primary object within the image, as opposed to, e.g., background features that may spuriously be correlated with instance age.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEggQImnpFiW7s3jeT9qoxQOM1kT8vIaHihnlAPusLRx8lJCaxyB7Lzhewn7J6qTiz9-qkWBJzzxLj-uHXhlB94WBMUVRsAgqZVBMBAnDaHGeCe6evZOo6hYgR5oXImP5vO9ZUNcF1q3Bpvau94hM9D71xwOGRqm9c8lJ6ixrB69w_JjneqW5JGcg_u6ZW2J/s1999/image1.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="499" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEggQImnpFiW7s3jeT9qoxQOM1kT8vIaHihnlAPusLRx8lJCaxyB7Lzhewn7J6qTiz9-qkWBJzzxLj-uHXhlB94WBMUVRsAgqZVBMBAnDaHGeCe6evZOo6hYgR5oXImP5vO9ZUNcF1q3Bpvau94hM9D71xwOGRqm9c8lJ6ixrB69w_JjneqW5JGcg_u6ZW2J/s16000/image1.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Sample images from the <a href="https://arxiv.org/abs/2201.06289">CLEAR</a> benchmark (camera & computer categories) assigned the highest and lowest weights respectively by our helper model.</td></tr></tbody></table>
<br>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhiKCafyxNrHJUkwV3KjoFMJk_v9WSPlzfMyYa-TZCODZdBNCnUOLOZogf9njyGQp_TWzCZ-a6-P5smLhSyeHVFd_jaSBbmS9soN5A5AF6oTq_OWvk-xOWgKaDCIFYz8mhe-GoVEZ56QSsIpKxDduNmCA0ORnf_kgW8ph0uZci8UBCQDBHs0j4Nq5hb5J7e/s1999/image5.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="339" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhiKCafyxNrHJUkwV3KjoFMJk_v9WSPlzfMyYa-TZCODZdBNCnUOLOZogf9njyGQp_TWzCZ-a6-P5smLhSyeHVFd_jaSBbmS9soN5A5AF6oTq_OWvk-xOWgKaDCIFYz8mhe-GoVEZ56QSsIpKxDduNmCA0ORnf_kgW8ph0uZci8UBCQDBHs0j4Nq5hb5J7e/s16000/image5.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Feature importance analysis of our helper model on sample images from the <a href="https://arxiv.org/abs/2201.06289">CLEAR</a> benchmark.</td></tr></tbody></table>
<br>
<div style="line-height:40%;">
<br>
</div>
<h2>Results</h2>
<div style="line-height:40%;">
<br>
</div>
<h3>Gains on large-scale data </h3>
<p>
We first study the large-scale <a href="https://arxiv.org/abs/2108.09020">photo categorization task</a> (PCAT) on the <a href="https://arxiv.org/abs/1503.01817">YFCC100M dataset</a> discussed earlier, using the first five years of data for training and the next five years as test data. Our method (shown in red below) improves substantially over the no-reweighting baseline (black) as well as many other robust learning techniques. Interestingly, our method deliberately trades off accuracy on the distant past (training data unlikely to reoccur in the future) in exchange for marked improvements in the test period. Also, as desired, our method degrades less than other baselines in the test period.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgLKQZo3e80Ttgw64eHAndZZc6BMXKBNLXAPTQZDP1tsFEQZpGckd6fzqG0aC1x_b5HQmiYlp6AzgbQ3gYRGVcHEZvhnPiDVsl1rxKh3vjVtqXJd20xp5og5yowR2SmyvqNdhhaSuNT5IY_rm_SJanFAsM4jt1Pf_TChyphenhyphenK8y0mNi2Jji1oDWcSiH_7vaC7b/s800/image3.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="600" data-original-width="800" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgLKQZo3e80Ttgw64eHAndZZc6BMXKBNLXAPTQZDP1tsFEQZpGckd6fzqG0aC1x_b5HQmiYlp6AzgbQ3gYRGVcHEZvhnPiDVsl1rxKh3vjVtqXJd20xp5og5yowR2SmyvqNdhhaSuNT5IY_rm_SJanFAsM4jt1Pf_TChyphenhyphenK8y0mNi2Jji1oDWcSiH_7vaC7b/s16000/image3.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Comparison of our method and relevant baselines on the PCAT dataset.</td></tr></tbody></table>
<br>
<div style="line-height:40%;">
<br>
</div>
<h3>Broad applicability</h3>
<p>
We validated our findings on a wide range of nonstationary learning challenge datasets sourced from the academic literature (see <a href="https://arxiv.org/abs/2108.09020">1</a>, <a href="https://arxiv.org/abs/2201.06289">2</a>, <a href="https://arxiv.org/abs/2211.14238">3</a>, <a href="https://proceedings.mlr.press/v206/awasthi23b/awasthi23b.pdf">4</a> for details) that spans data sources and modalities (photos, satellite images, social media text, medical records, sensor readings, tabular data) and sizes (ranging from 10k to 39M instances). We report significant gains in the test period when compared to the nearest published benchmark method for each dataset (shown below). Note that the previous best-known method may be different for each dataset. These results showcase the broad applicability of our approach.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhw95hIflfZ4eiNddWi0-YXONJYbMLT2yHp_Ekzm8v5e1WHpxeT5v7k21EYihoAqrplmlrtM76iiHjuBWtMQDbtj7TvtwIU0eZb44_QSeEe5U4k_z70y_9SsS3If8Y5xkMXKQYI5VzaTafWC7nVv5MgvNw_yL8HA6N7-gUPGGcJI2qtgKTcnqn2oN1ruBt-/s765/image7.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="552" data-original-width="765" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhw95hIflfZ4eiNddWi0-YXONJYbMLT2yHp_Ekzm8v5e1WHpxeT5v7k21EYihoAqrplmlrtM76iiHjuBWtMQDbtj7TvtwIU0eZb44_QSeEe5U4k_z70y_9SsS3If8Y5xkMXKQYI5VzaTafWC7nVv5MgvNw_yL8HA6N7-gUPGGcJI2qtgKTcnqn2oN1ruBt-/s16000/image7.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Performance gain of our method on a variety of tasks studying natural concept drift. Our reported gains are over the previous best-known method for each dataset.</td></tr></tbody></table>
<br>
<div style="line-height:40%;">
<br>
</div>
<h3>Extensions to continual learning</h3>
<p>
Finally, we consider an interesting extension of our work. The work above described how offline learning can be extended to handle concept drift using ideas inspired by continual learning. However, sometimes offline learning is infeasible — for example, if the amount of training data available is too large to maintain or process. We adapted our approach to continual learning in a straightforward manner by applying temporal reweighting <em>within the context of </em>each bucket of data being used to sequentially update the model. This proposal still retains some limitations of continual learning, e.g., model updates are performed only on most-recent data, and all optimization decisions (including our reweighting) are only made over that data. Nevertheless, our approach consistently beats regular continual learning as well as a wide range of other continual learning algorithms on the photo categorization benchmark (see below). Since our approach is complementary to the ideas in many baselines compared here, we anticipate even larger gains when combined with them.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjtWaqgT_9wt2sckjfrLbQ8LhRK5gL1yTowCf0h2nMnHhBYqfKP7VBwWfbK-5Y5zbYXiKoaF0TKve71FWrHazA4g4SPFD3leb56aZHex95MM_yovx2Y_uO4c5rOA5GzTndUGyBO4HH0gL3jYd8Jk4oPbi4HuSYDuMkKY5kPlqsb0s-re13QKfei2IrMig6S/s800/image6.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="600" data-original-width="800" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjtWaqgT_9wt2sckjfrLbQ8LhRK5gL1yTowCf0h2nMnHhBYqfKP7VBwWfbK-5Y5zbYXiKoaF0TKve71FWrHazA4g4SPFD3leb56aZHex95MM_yovx2Y_uO4c5rOA5GzTndUGyBO4HH0gL3jYd8Jk4oPbi4HuSYDuMkKY5kPlqsb0s-re13QKfei2IrMig6S/s16000/image6.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Results of our method adapted to continual learning, compared to the latest baselines.</td></tr></tbody></table>
<br>
<div style="line-height:40%;">
<br>
</div>
<h2>Conclusion</h2>
<p>
We addressed the challenge of data drift in learning by combining the strengths of previous approaches — offline learning with its effective reuse of data, and continual learning with its emphasis on more recent data. We hope that our work helps improve model robustness to concept drift in practice, and generates increased interest and new ideas in addressing the ubiquitous problem of slow concept drift.
</p>
<div style="line-height:40%;">
<br>
</div>
<h2>Acknowledgements</h2>
<p>
<em>We thank Mike Mozer for many interesting discussions in the early phase of this work, as well as very helpful advice and feedback during its development.</em>
</p>
Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-59333654601250947742024-02-13T14:11:00.000-08:002024-02-13T14:11:49.258-08:00DP-Auditorium: A flexible library for auditing differential privacy<span class="byline-author">Posted by Mónica Ribero Díaz, Research Scientist, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhNVpxjk-jj1rIYQ8AM3A-Syqxd3d8L8-wIy8NWwyobCXmTRK7mY9h94aJYgFCiC0gnehVFFoM8-in8HsOZjfhoNce03nbsrN5fxY07wADV6ULPC0POGmCc-8eL3OqA9KrDyzQxN38JKvh6xCmLV6FZ1g0UfaXtKORhtTy0WuJexlPqV6P2c9rPdg_W_5zP/s320/hero.jpg" style="display: none;" />
<p>
<a href="https://en.wikipedia.org/wiki/Differential_privacy">Differential privacy</a> (DP) is a property of randomized mechanisms that limit the influence of any individual user’s information while processing and analyzing data. DP offers a robust solution to address growing concerns about data protection, enabling technologies <a href="https://blog.research.google/2022/02/federated-learning-with-formal.html">across</a> <a href="https://www.apple.com/privacy/docs/Differential_Privacy_Overview.pdf">industries</a> and government applications (e.g., <a href="https://www.census.gov/programs-surveys/decennial-census/decade/2020/planning-management/process/disclosure-avoidance/differential-privacy.html">the US census</a>) without compromising individual user identities. As its adoption increases, it’s important to identify the potential risks of developing mechanisms with faulty implementations. Researchers have recently found errors in the mathematical proofs of private mechanisms, and their implementations. For example, <a href="https://arxiv.org/pdf/1603.01699.pdf">researchers compared</a> six sparse vector technique (SVT) variations and found that only two of the six actually met the asserted privacy guarantee. Even when mathematical proofs are correct, the code implementing the mechanism is vulnerable to human error.
</p>
<a name='more'></a>
<p>
However, practical and efficient DP auditing is challenging primarily due to the inherent randomness of the mechanisms and the probabilistic nature of the tested guarantees. In addition, a range of guarantee types exists (e.g., <a href="https://dl.acm.org/doi/10.1007/11681878_14">pure DP</a>, <a href="https://link.springer.com/chapter/10.1007/11761679_29">approximate DP</a>, <a href="https://arxiv.org/abs/1702.07476">Rényi DP</a>, and <a href="https://arxiv.org/pdf/1603.01887.pdf">concentrated DP</a>), and this diversity contributes to the complexity of formulating the auditing problem. Further, debugging mathematical proofs and code bases is an intractable task given the volume of proposed mechanisms. While <em>ad hoc</em> testing techniques exist under specific assumptions of mechanisms, few efforts have been made to develop an extensible tool for testing DP mechanisms.
</p>
<p>
To that end, in “<a href="https://arxiv.org/abs/2307.05608">DP-Auditorium: A Large Scale Library for Auditing Differential Privacy</a>”, we introduce an <a href="https://github.com/google/differential-privacy/tree/main/python/dp_auditorium">open source library</a> for auditing DP guarantees with only black-box access to a mechanism (i.e., without any knowledge of the mechanism’s internal properties). DP-Auditorium is implemented in Python and provides a flexible interface that allows contributions to continuously improve its testing capabilities. We also introduce new testing algorithms that perform divergence optimization over function spaces for Rényi DP, pure DP, and approximate DP. We demonstrate that DP-Auditorium can efficiently identify DP guarantee violations, and suggest which tests are most suitable for detecting particular bugs under various privacy guarantees.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>DP guarantees</h2>
<p>
The output of a DP mechanism is a sample drawn from a probability distribution (<em>M</em> (<em>D</em>)) that satisfies a mathematical property ensuring the privacy of user data. A DP guarantee is thus tightly related to properties between pairs of probability distributions. A mechanism is differentially private if the probability distributions determined by <i>M</i> on dataset <em>D</em> and a neighboring dataset <em>D’</em>, which differ by only one record, are <em><a href="https://en.wikipedia.org/wiki/Computational_indistinguishability">indistinguishable</a></em> under a given divergence metric.
</p>
<p>
For example, the classical <a href="https://software.imdea.org/~federico/pubs/2013.ICALP.pdf">approximate DP</a> definition states that a mechanism is approximately DP with parameters (<em>ε</em>, <em>δ</em>) if the <a href="https://arxiv.org/pdf/1508.00335.pdf">hockey-stick divergence</a> of order <em>e<sup>ε</sup></em> between <em>M</em>(<em>D</em>) and <em>M</em>(<em>D’</em>) is at most <em>δ</em>. Pure DP is a special instance of approximate DP where <em>δ = 0</em>. Finally, a mechanism is considered <a href="https://arxiv.org/abs/1702.07476">Rényi DP</a> with parameters (<em>𝛼</em>, <em>ε</em>) if the <a href="https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy">Rényi divergence</a> of order <em>𝛼</em> is at most <em>ε</em> (where <em>ε</em> is a small positive value). In these three definitions, <em>ε</em> is not interchangeable but intuitively conveys the same concept; larger values of <em>ε</em> imply larger divergences between the two distributions, i.e., less privacy, since the two distributions are easier to distinguish.
</p>
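<p>
To make these definitions concrete, the snippet below (purely illustrative, not part of DP-Auditorium) numerically integrates the hockey-stick divergence of order <em>e<sup>ε</sup></em> between the output distributions of a Gaussian mechanism on two neighboring datasets whose query values differ by the sensitivity. The result is the smallest <em>δ</em> for which that particular pair satisfies the (<em>ε</em>, <em>δ</em>) condition; the DP guarantee requires the same bound to hold for all neighboring pairs.
</p>
<pre><code>import numpy as np
from scipy import stats

def hockey_stick_divergence(p_pdf, q_pdf, eps, grid):
    # H_{e^eps}(P || Q) = integral of max(p(x) - e^eps * q(x), 0) dx,
    # approximated on a dense grid with the trapezoid rule.
    gap = np.maximum(p_pdf(grid) - np.exp(eps) * q_pdf(grid), 0.0)
    return np.trapz(gap, grid)

# Gaussian mechanism with noise sigma on a query of sensitivity 1:
# the two neighboring output distributions are N(0, sigma^2) and N(1, sigma^2).
sigma, eps = 1.0, 1.0
grid = np.linspace(-12.0, 13.0, 500001)
delta = hockey_stick_divergence(stats.norm(0.0, sigma).pdf,
                                stats.norm(1.0, sigma).pdf, eps, grid)
print(f"delta for this pair at eps={eps}: {delta:.4f}")   # roughly 0.13 for sigma = 1
</code></pre>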
<div style="line-height: 40%;">
<br />
</div>
<h2>DP-Auditorium</h2>
<p>
DP-Auditorium comprises two main components: property testers and dataset finders. Property testers take samples from a mechanism evaluated on specific datasets as input and aim to identify privacy guarantee violations in the provided datasets. Dataset finders suggest datasets where the privacy guarantee may fail. By combining both components, DP-Auditorium enables (1) automated testing of diverse mechanisms and privacy definitions and, (2) detection of bugs in privacy-preserving mechanisms. We implement various private and non-private mechanisms, including simple mechanisms that compute the mean of records and more complex mechanisms, such as different SVT and <a href="https://en.wikipedia.org/wiki/Stochastic_gradient_descent">gradient descent</a> mechanism variants.
</p>
<p>
<strong>Property testers</strong> determine if evidence exists to reject the hypothesis that a given divergence between two probability distributions, <em>P</em> and <em>Q</em>, is bounded by a prespecified budget determined by the DP guarantee being tested. They compute a lower bound from samples from <em>P</em> and <em>Q</em>, rejecting the property if the lower bound value exceeds the expected divergence. If the lower bound does not exceed the budget, the test is inconclusive and provides no guarantee that the property actually holds. To test for a range of privacy guarantees, DP-Auditorium introduces three novel testers: (1) HockeyStickPropertyTester, (2) RényiPropertyTester, and (3) MMDPropertyTester. Unlike other approaches, these testers don’t depend on explicit histogram approximations of the tested distributions. They rely on variational representations of the hockey-stick divergence, Rényi divergence, and <a href="https://jmlr.csail.mit.edu/papers/v13/gretton12a.html">maximum mean discrepancy</a> (MMD) that enable the estimation of divergences through optimization over function spaces. As a baseline, we implement <a href="https://arxiv.org/abs/1806.06427">HistogramPropertyTester</a>, a commonly used approximate DP tester. While our three testers follow a similar approach, for brevity, we focus on the HockeyStickPropertyTester in this post.
</p>
<p>
Given two neighboring datasets, <em>D</em> and <em>D’</em>, the HockeyStickPropertyTester finds a lower bound, <i><span style="bottom: 9px; left: 9px; position: relative; transform: scale(4,0.5);">^</span>δ</i>, for the hockey-stick divergence between <em>M</em>(<em>D</em>) and <em>M</em>(<em>D’</em>) that holds with high probability. A small hockey-stick divergence certifies that the two distributions <em>M</em>(<em>D</em>) and <em>M</em>(<em>D’</em>) are close, as an approximate DP guarantee requires. Therefore, if a privacy guarantee claims that the hockey-stick divergence is at most <em>δ</em>, and <i><span style="bottom: 9px; left: 9px; position: relative; transform: scale(4,0.5);">^</span>δ</i> > <em>δ</em>, then with high probability the divergence is higher than what was promised on <em>D</em> and <em>D’</em>, and the mechanism cannot satisfy the given approximate DP guarantee. The lower bound <i><span style="bottom: 9px; left: 9px; position: relative; transform: scale(4,0.5);">^</span>δ</i> is computed as an empirical and tractable counterpart of a variational formulation of the hockey-stick divergence (see <a href="https://arxiv.org/pdf/2307.05608.pdf">the paper</a> for more details). The accuracy of <i><span style="bottom: 9px; left: 9px; position: relative; transform: scale(4,0.5);">^</span>δ</i> increases with the number of samples drawn from the mechanism, but decreases as the variational formulation is simplified. We balance these factors in order to ensure that <i><span style="bottom: 9px; left: 9px; position: relative; transform: scale(4,0.5);">^</span>δ</i> is both accurate and easy to compute.
</p>
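<p>
The toy sketch below conveys the variational idea behind such a tester: any witness function with values in [0, 1] yields a lower bound E<sub><em>P</em></sub>[<em>f</em>] − <em>e<sup>ε</sup></em>E<sub><em>Q</em></sub>[<em>f</em>] on the hockey-stick divergence, so fitting a simple classifier on half the samples and evaluating it on the held-out half produces an estimated lower bound. This is a simplification for illustration only; DP-Auditorium’s actual testers optimize over richer function classes and apply concentration bounds so that the reported lower bound holds with high probability.
</p>
<pre><code>import numpy as np
from sklearn.linear_model import LogisticRegression

def hockey_stick_lower_bound(samples_p, samples_q, eps):
    """Estimates sup_f E_P[f] - e^eps * E_Q[f] over witnesses f with values in [0, 1]."""
    n_p, n_q = len(samples_p) // 2, len(samples_q) // 2
    # Fit a witness function (here: a logistic classifier) on half of the samples ...
    x_fit = np.concatenate([samples_p[:n_p], samples_q[:n_q]]).reshape(-1, 1)
    y_fit = np.concatenate([np.ones(n_p), np.zeros(n_q)])
    witness = LogisticRegression().fit(x_fit, y_fit)
    f = lambda x: witness.predict_proba(np.asarray(x).reshape(-1, 1))[:, 1]
    # ... and evaluate the bound on the held-out half (unbiased for this fixed witness).
    return f(samples_p[n_p:]).mean() - np.exp(eps) * f(samples_q[n_q:]).mean()

# A mechanism adding too little noise: the estimated lower bound is large,
# exceeding any small claimed delta for eps = 1.
rng = np.random.default_rng(0)
p = rng.normal(0.0, 0.3, size=200_000)   # samples of M(D)
q = rng.normal(1.0, 0.3, size=200_000)   # samples of M(D') on a neighboring dataset
print(hockey_stick_lower_bound(p, q, eps=1.0))
</code></pre>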
<p>
<strong>Dataset finders</strong> use <a href="https://arxiv.org/pdf/2207.13676.pdf">black-box optimization</a> to find datasets <em>D</em> and <em>D’</em> that maximize <i><span style="bottom: 9px; left: 9px; position: relative; transform: scale(4,0.5);">^</span>δ</i>, a lower bound on the divergence value <em>δ</em>. Note that black-box optimization techniques are specifically designed for settings where deriving gradients for an objective function may be impractical or even impossible. These optimization techniques oscillate between exploration and exploitation phases to estimate the shape of the objective function and predict areas where the objective can have optimal values. In contrast, a full exploration algorithm, such as the <a href="https://en.wikipedia.org/wiki/Hyperparameter_optimization#Grid_search">grid search method</a>, searches over the full space of neighboring datasets <em>D</em> and <em>D’</em>. DP-Auditorium implements different dataset finders through the open sourced black-box optimization library <a href="https://github.com/google/vizier">Vizier</a>.
</p>
<p>
Running existing components on a new mechanism only requires defining the mechanism as a Python function that takes an array of data <em>D</em> and a desired number of samples <em>n</em> to be output by the mechanism computed on <em>D</em>. In addition, we provide flexible wrappers for testers and dataset finders that allow practitioners to implement their own testing and dataset search algorithms.
</p>
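<p>
For example, a Laplace-noised mean release with the interface described above might look like the following. This is a minimal sketch of the mechanism function itself; the wrapper classes that connect such a function to the testers and dataset finders are part of the library and are not reproduced here.
</p>
<pre><code>import numpy as np

def laplace_mean_mechanism(data: np.ndarray, num_samples: int) -> np.ndarray:
    """Releases `num_samples` noisy means of `data` (assumed to lie in [0, 1])."""
    epsilon = 1.0
    sensitivity = 1.0 / len(data)   # replacing one record changes the mean by at most 1/n
    scale = sensitivity / epsilon   # Laplace scale for an epsilon-DP mean release
    return np.mean(data) + np.random.laplace(loc=0.0, scale=scale, size=num_samples)
</code></pre>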
<div style="line-height: 40%;">
<br />
</div>
<h2>Key results</h2>
<p>
We assess the effectiveness of DP-Auditorium on five private and nine non-private mechanisms with diverse output spaces. For each property tester, we repeat the test ten times on fixed datasets using different values of <em>ε</em>, and report the number of times each tester identifies privacy bugs. While no tester consistently outperforms the others, we identify bugs that would be missed by previous techniques (HistogramPropertyTester). Note that the HistogramPropertyTester is not applicable to SVT mechanisms.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlLYAUJ1cew8xCQNyNMvggKZ2c2bd5uHLzUdLx3xVdn_TW4ZBwd5tCI6zVVvVjmOWKJanJ4vP4swXOzNpZ4388x-iwISjqAzxnDAgM8F4-HL5gHLAGs3AIuqhns-gNJfA_AT9lmAMvItLRDEP5OjHPRFRA6OldJrY6Yost66LZ8Zsif8wIw6Uhkfa4PkN7/s785/image22.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="409" data-original-width="785" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhlLYAUJ1cew8xCQNyNMvggKZ2c2bd5uHLzUdLx3xVdn_TW4ZBwd5tCI6zVVvVjmOWKJanJ4vP4swXOzNpZ4388x-iwISjqAzxnDAgM8F4-HL5gHLAGs3AIuqhns-gNJfA_AT9lmAMvItLRDEP5OjHPRFRA6OldJrY6Yost66LZ8Zsif8wIw6Uhkfa4PkN7/s16000/image22.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Number of times each property tester finds the privacy violation for the tested non-private mechanisms. NonDPLaplaceMean and NonDPGaussianMean mechanisms are faulty implementations of the <a href="https://en.wikipedia.org/wiki/Additive_noise_differential_privacy_mechanisms#Laplace_Mechanism">Laplace</a> and <a href="https://en.wikipedia.org/wiki/Additive_noise_differential_privacy_mechanisms#Gaussian_Mechanism">Gaussian</a> mechanisms for computing the mean.</td></tr></tbody></table>
<br />
<p>
We also analyze the implementation of a <a href="https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py">DP gradient descent algorithm</a> (DP-GD) in TensorFlow that computes gradients of the loss function on private data. To preserve privacy, DP-GD employs a clipping mechanism to bound the <a href="https://mathworld.wolfram.com/L2-Norm.html">l2-norm</a> of the gradients by a value <em>G</em>, followed by the addition of Gaussian noise. This implementation incorrectly assumes that the noise added has a scale of <em>G</em>, while in reality, the scale is <em>sG</em>, where <em>s</em> is a positive scalar. This discrepancy leads to an approximate DP guarantee that holds only for values of <em>s</em> greater than or equal to 1.
</p>
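<p>
The snippet below is a schematic reconstruction of that bug (illustrative only, not the actual TensorFlow Privacy code): gradients are clipped to norm <em>G</em> and Gaussian noise is added, but the noise standard deviation carries an extra factor <em>s</em> that the privacy accounting does not know about, so for values of <em>s</em> below 1 the mechanism adds less noise than the claimed guarantee assumes.
</p>
<pre><code>import numpy as np

def dp_gd_step(per_example_grads, clip_norm_g, noise_multiplier, s):
    # Clip each per-example gradient to l2-norm at most G ...
    norms = np.linalg.norm(per_example_grads, axis=1, keepdims=True)
    clipped = per_example_grads * np.minimum(1.0, clip_norm_g / norms)
    # ... then add Gaussian noise. The accounting assumes stddev = noise_multiplier * G,
    # but the implementation under test actually uses stddev = s * noise_multiplier * G.
    noise = np.random.normal(scale=s * noise_multiplier * clip_norm_g,
                             size=per_example_grads.shape[1])
    return clipped.sum(axis=0) + noise
</code></pre>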
<p>
We evaluate the effectiveness of property testers in detecting this bug and show that HockeyStickPropertyTester and RényiPropertyTester exhibit superior performance in identifying privacy violations, outperforming MMDPropertyTester and HistogramPropertyTester. Notably, these testers detect the bug even for values of <em>s</em> as high as 0.6. It is worth highlighting that <em>s </em>= 0.5 corresponds to a <a href="https://github.com/tensorflow/privacy/blob/308cbda4db6ccad5d1e7d56248727274e4c0c79e/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy_lib.py#L445C1-L446C1">common error</a> in literature that involves missing a factor of two when accounting for the privacy budget <em>ε</em>. DP-Auditorium successfully captures this bug as shown below. For more details see section 5.6 <a href="https://arxiv.org/pdf/2303.00654.pdf">here</a>.
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg-pnMcLqTWv1vSIZWncvObk3acW_SkBS3Lp_KuspJPbGBSjlepwW0hTLkCgLA7yTgU35y-Kj4HC_ddRX1fXS6T_HoF5Na87cSIcdiTBAwHnQ1sQZV3pdir_SI5PuwT7HAMEYmQohCd7wI84bNjKSt4sUVdnk9dOAXtkxCUDgzd3KZs5r2G2Z4jIZR0-FJH/s836/image21.jpg" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="332" data-original-width="836" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg-pnMcLqTWv1vSIZWncvObk3acW_SkBS3Lp_KuspJPbGBSjlepwW0hTLkCgLA7yTgU35y-Kj4HC_ddRX1fXS6T_HoF5Na87cSIcdiTBAwHnQ1sQZV3pdir_SI5PuwT7HAMEYmQohCd7wI84bNjKSt4sUVdnk9dOAXtkxCUDgzd3KZs5r2G2Z4jIZR0-FJH/s16000/image21.jpg" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Estimated divergences and test thresholds for different values of <em>s</em> when testing DP-GD with the HistogramPropertyTester (<strong>left</strong>) and the HockeyStickPropertyTester (<strong>right</strong>).</td></tr></tbody></table>
<br />
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEibbce0TFnWcnJ4CoXPVVyuZrja_3JJTnBjsza7Ig-NibA14jHoh4TIuIhLRn9BgCdo_N4hSuft7Zpl3WgNjmteMUGkQ5xdjeFH2SzZlKmPR_PvXS-JeOIcwJO8J_h7SlR9_tknZ0fLbP2qOypalwVm-nZO118Oa67zgdi_VGc72tAzGKaYpGoWIl6p_ljD/s828/image20.jpg" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="333" data-original-width="828" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEibbce0TFnWcnJ4CoXPVVyuZrja_3JJTnBjsza7Ig-NibA14jHoh4TIuIhLRn9BgCdo_N4hSuft7Zpl3WgNjmteMUGkQ5xdjeFH2SzZlKmPR_PvXS-JeOIcwJO8J_h7SlR9_tknZ0fLbP2qOypalwVm-nZO118Oa67zgdi_VGc72tAzGKaYpGoWIl6p_ljD/s16000/image20.jpg" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Estimated divergences and test thresholds for different values of <em>s</em> when testing DP-GD with the RényiPropertyTester (<strong>left</strong>) and the MMDPropertyTester (<strong>right</strong>)</td></tr></tbody></table>
<br />
<p>
To test dataset finders, we compute the number of datasets explored before finding a privacy violation. On average, the majority of bugs are discovered in less than 10 calls to dataset finders. Randomized and exploration/exploitation methods are more efficient at finding datasets than grid search. For more details, see the <a href="https://arxiv.org/abs/2307.05608">paper</a>.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Conclusion</h2>
<p>
DP is one of the most powerful frameworks for data protection. However, proper implementation of DP mechanisms can be challenging and prone to errors that cannot be easily detected using traditional unit testing methods. A unified testing framework can help auditors, regulators, and academics ensure that private mechanisms are indeed private.
</p>
<p>
DP-Auditorium is a new approach to testing DP via divergence optimization over function spaces. Our results show that this type of function-based estimation consistently outperforms previous black-box access testers, discovering privacy bugs at a higher rate than histogram-based estimation. By <a href="https://github.com/google/differential-privacy/tree/main/python/dp_auditorium">open sourcing</a> DP-Auditorium, we aim to establish a standard for end-to-end testing of new differentially private algorithms.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgements</h2>
<p>
<em>The work described here was done jointly with Andrés Muñoz Medina, William Kong and Umar Syed. We thank Chris Dibak and Vadym Doroshenko for helpful engineering support and interface suggestions for our library.</em>
</p>Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-32646941557106473102024-02-06T11:17:00.000-08:002024-02-06T11:17:53.968-08:00Graph neural networks in TensorFlow<span class="byline-author">Posted by Dustin Zelle, Software Engineer, Google Research, and Arno Eigenwillig, Software Engineer, CoreML</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhcnTwrjg8cyZhVY1c-qi2ZEenIrDlkmlKlX0GsAuiKiIoxUu6i-phANh8tsCG4mUm5i-7t3zdLwuwn5DCcuQI5FKq-C3eibPnuqfoLuKFUsx-I3Ovim1Teps_JKiKZH7XqgHupnsOa2Y3peUgWcPNYG4ZIqA2_KQwxJpflo0WM6gNW8tXg5eDndiWx_dKK/s1600/TFGNN%20hero.gif" style="display: none;" />
<p>
Objects and their relationships are ubiquitous in the world around us, and relationships can be as important to understanding an object as its own attributes viewed in isolation — take for example transportation networks, production networks, knowledge graphs, or social networks. Discrete mathematics and computer science have a long history of formalizing such networks as <em><a href="https://en.wikipedia.org/wiki/Graph_(discrete_mathematics)">graphs</a></em>, consisting of <em>nodes</em> connected by <em>edges</em> in various irregular ways. Yet most machine learning (ML) algorithms allow only for regular and uniform relations between input objects, such as a grid of pixels, a sequence of words, or no relation at all.
</p>
<a name='more'></a>
<p>
<a href="https://distill.pub/2021/gnn-intro/">Graph neural networks</a>, or GNNs for short, have emerged as a powerful technique to leverage both the graph’s connectivity (as in the older algorithms <a href="http://perozzi.net/projects/deepwalk/">DeepWalk</a> and <a href="https://snap.stanford.edu/node2vec/">Node2Vec</a>) and the input features on the various nodes and edges. GNNs can make predictions for graphs as a whole (Does this molecule react in a certain way?), for individual nodes (What’s the topic of this document, given its citations?) or for potential edges (Is this product likely to be purchased together with that product?). Apart from making predictions about graphs, GNNs are a powerful tool used to bridge the chasm to more typical neural network use cases. They encode a graph's <em>discrete</em>, <em>relational</em> information in a <em>continuous</em> way so that it can be included naturally in another deep learning system.
</p>
<p>
We are excited to announce the release of <a href="https://github.com/tensorflow/gnn">TensorFlow GNN 1.0</a> (TF-GNN), a production-tested library for building GNNs at large scales. It supports both modeling and training in TensorFlow as well as the extraction of input graphs from huge data stores. TF-GNN is built from the ground up for heterogeneous graphs, where types of objects and relations are represented by distinct sets of nodes and edges. Real-world objects and their relations occur in distinct types, and TF-GNN's heterogeneous focus makes it natural to represent them.
</p>
<p>
Inside TensorFlow, such graphs are represented by objects of type <code>tfgnn.GraphTensor</code>. This is a composite tensor type (a collection of tensors in one Python class) accepted as a <a href="https://en.wikipedia.org/wiki/First-class_citizen">first-class citizen</a> in <code>tf.data.Dataset</code>, <code>tf.function</code>, etc. It stores both the graph structure and the features attached to its nodes, its edges and the graph as a whole. Trainable transformations of GraphTensors can be defined as Layer objects in the high-level <a href="https://www.tensorflow.org/guide/keras">Keras API</a>, or directly using the <code>tfgnn.GraphTensor</code> primitive.
</p>
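<p>
As a small, hand-built illustration (a minimal sketch; see the TF-GNN guides for the authoritative API and argument details), a tiny citation graph with three papers and two “cites” edges could be assembled roughly like this:
</p>
<pre><code>import tensorflow as tf
import tensorflow_gnn as tfgnn

graph = tfgnn.GraphTensor.from_pieces(
    node_sets={
        "paper": tfgnn.NodeSet.from_fields(
            sizes=tf.constant([3]),
            features={tfgnn.HIDDEN_STATE: tf.random.normal([3, 16])}),
    },
    edge_sets={
        "cites": tfgnn.EdgeSet.from_fields(
            sizes=tf.constant([2]),
            adjacency=tfgnn.Adjacency.from_indices(
                source=("paper", tf.constant([0, 2])),
                target=("paper", tf.constant([1, 1])))),
    })
</code></pre>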
<div style="line-height: 40%;">
<br />
</div>
<h2>GNNs: Making predictions for an object in context</h2>
<p>
For illustration, let’s look at one typical application of TF-GNN: predicting a property of a certain type of node in a graph defined by cross-referencing tables of a huge database. For example, consider a citation database of Computer Science (CS) arXiv papers with one-to-many cites and many-to-one cited relationships, in which we would like to predict the subject area of each paper.
</p>
<p>
Like most neural networks, a GNN is trained on a dataset of many labeled examples (~millions), but each training step consists only of a much smaller batch of training examples (say, hundreds). To scale to millions, the GNN gets trained on a stream of reasonably small subgraphs from the underlying graph. Each subgraph contains enough of the original data to compute the GNN result for the labeled node at its center and train the model. This process — typically referred to as subgraph sampling — is extremely consequential for GNN training. Most existing tooling accomplishes sampling in a batch way, producing static subgraphs for training. TF-GNN provides tooling to improve on this by sampling dynamically and interactively.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhE36FVnslwVrX4LjLgpe5NOcVgJ2WSHCaw64LT9pMhjhHOFt-1pjp1AhaXqjxfEODX04Buw93D1G36HOStu5_mWUEdNs0gZTa1c7MXJ6ir9DYOp_HCYpFMT5NZiBbHxNwvUmF-dwhN2rgKQX0CeFY25X9aFnoD0W7bzL_xtkDJFdP0guocAJDSOgBHIiZm/s800/image2.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="600" data-original-width="800" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhE36FVnslwVrX4LjLgpe5NOcVgJ2WSHCaw64LT9pMhjhHOFt-1pjp1AhaXqjxfEODX04Buw93D1G36HOStu5_mWUEdNs0gZTa1c7MXJ6ir9DYOp_HCYpFMT5NZiBbHxNwvUmF-dwhN2rgKQX0CeFY25X9aFnoD0W7bzL_xtkDJFdP0guocAJDSOgBHIiZm/s16000/image2.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Pictured, the process of subgraph sampling where small, tractable subgraphs are sampled from a larger graph to create input examples for GNN training.</td></tr></tbody></table>
<p>
TF-GNN 1.0 debuts a flexible Python API to configure dynamic or batch subgraph sampling at all relevant scales: interactively in a Colab notebook (like <a href="https://colab.research.google.com/github/tensorflow/gnn/blob/master/examples/notebooks/ogbn_mag_e2e.ipynb">this one</a>), for efficient sampling of a small dataset stored in the main memory of a single training host, or distributed by <a href="https://beam.apache.org/">Apache Beam</a> for huge datasets stored on a network filesystem (up to hundreds of millions of nodes and billions of edges). For details, please refer to our user guides for <a href="https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/inmemory_sampler.md">in-memory</a> and <a href="https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/beam_sampler.md">beam-based</a> sampling, respectively.
</p>
<p>
On those same sampled subgraphs, the GNN’s task is to compute a hidden (or latent) state at the root node; the hidden state aggregates and encodes the relevant information of the root node's neighborhood. One classical approach is <a href="https://research.google/pubs/neural-message-passing-for-quantum-chemistry/">message-passing neural networks</a>. In each round of message passing, nodes receive messages from their neighbors along incoming edges and update their own hidden state from them. After <em>n</em> rounds, the hidden state of the root node reflects the aggregate information from all nodes within <em>n</em> edges (pictured below for <em>n</em> = 2). The messages and the new hidden states are computed by hidden layers of the neural network. In a heterogeneous graph, it often makes sense to use separately trained hidden layers for the different types of nodes and edges.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjMrCrQ1SCcwhZfE33X46EifocYAmKCPXMVe1d4na1V6flQavJ_f_FKtnlQbe2vnvzbSEtx5mxJHZ2OlQbO9rsiEhiPLY1PKQOT-EwahobMIVC92PZJs8RroEuYswHCpEjjpwqPrpqzKsDgrNaiY4lM_E8NVnxVRsYn0PNxe3TghByKJpW9V_YRD0RnNnm4/s573/image1.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="511" data-original-width="573" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjMrCrQ1SCcwhZfE33X46EifocYAmKCPXMVe1d4na1V6flQavJ_f_FKtnlQbe2vnvzbSEtx5mxJHZ2OlQbO9rsiEhiPLY1PKQOT-EwahobMIVC92PZJs8RroEuYswHCpEjjpwqPrpqzKsDgrNaiY4lM_E8NVnxVRsYn0PNxe3TghByKJpW9V_YRD0RnNnm4/s16000/image1.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Pictured, a simple message-passing neural network where, at each step, the node state is propagated from outer to inner nodes where it is pooled to compute new node states. Once the root node is reached, a final prediction can be made.</td></tr></tbody></table>
<p>
The training setup is completed by placing an output layer on top of the GNN’s hidden state for the labeled nodes, computing the <em>loss </em>(to measure the prediction error), and updating model weights by backpropagation, as usual in any neural network training.
</p>
<p>
Beyond supervised training (i.e., minimizing a loss defined by labels), GNNs can also be trained in an unsupervised way (i.e., without labels). This lets us compute a <em>continuous</em> representation (or <em>embedding</em>) of the <em>discrete</em> graph structure of nodes and their features. These representations are then typically utilized in other ML systems. In this way, the discrete, relational information encoded by a graph can be included in more typical neural network use cases. TF-GNN supports a fine-grained specification of unsupervised objectives for heterogeneous graphs.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Building GNN architectures</h2>
<p>
The TF-GNN library supports building and training GNNs at various levels of abstraction.
</p>
<p>
At the highest level, users can take any of the predefined models bundled with the library that are expressed in Keras layers. Besides a small collection of models from the research literature, TF-GNN comes with a highly configurable model template that provides a curated selection of modeling choices that we have found to provide strong baselines on many of our in-house problems. The templates implement GNN layers; users need only to initialize the Keras layers.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjMfB8QoX14UU1GEAmFFOP0cAj__zxa_MKzVSiJoak9cVLNdbbhrSxbIWhqQM3OYKA5lo7zW8sWr6-9utm-rw0808rBOE4Cbw7NZxcmifenvF6DCH4opWhVQJHR-MLGcFoNu_WpET5h1PZRdXMhjcyKgBg3NchNTPq6gWVVluzcQNaO5qtonVp5KnJRgUaD/s1400/TFGNN%20code1.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="865" data-original-width="1400" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjMfB8QoX14UU1GEAmFFOP0cAj__zxa_MKzVSiJoak9cVLNdbbhrSxbIWhqQM3OYKA5lo7zW8sWr6-9utm-rw0808rBOE4Cbw7NZxcmifenvF6DCH4opWhVQJHR-MLGcFoNu_WpET5h1PZRdXMhjcyKgBg3NchNTPq6gWVVluzcQNaO5qtonVp5KnJRgUaD/s16000/TFGNN%20code1.png" /></a></td></tr></tbody></table>
<p>
At the lowest level, users can write a GNN model from scratch in terms of primitives for passing data around the graph, such as broadcasting data from a node to all its outgoing edges or pooling data into a node from all its incoming edges (e.g., computing the sum of incoming messages). TF-GNN’s graph data model treats nodes, edges and whole input graphs equally when it comes to features or hidden states, making it straightforward to express not only node-centric models like the MPNN discussed above but also more general forms of <a href="https://arxiv.org/abs/1806.01261">GraphNets</a>. This can, but need not, be done with Keras as a modeling framework on the top of core TensorFlow. For more details, and intermediate levels of modeling, see the TF-GNN <a href="https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/gnn_modeling.md">user guide</a> and <a href="https://github.com/tensorflow/gnn/tree/main/tensorflow_gnn/models">model collection</a>.
</p>
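<p>
For instance, given a <code>tfgnn.GraphTensor</code> like the one sketched earlier, the broadcast and pool primitives can be combined roughly as follows (a sketch only; argument names and defaults may differ slightly across TF-GNN versions, so please consult the user guide):
</p>
<pre><code>import tensorflow_gnn as tfgnn

def sum_neighbor_states(graph):
    # Broadcast each source node's hidden state onto its outgoing "cites" edges ...
    source_states = tfgnn.broadcast_node_to_edges(
        graph, "cites", tfgnn.SOURCE, feature_name=tfgnn.HIDDEN_STATE)
    # ... then pool (sum) those values into the nodes at the receiving end of the edges.
    return tfgnn.pool_edges_to_node(
        graph, "cites", tfgnn.TARGET, reduce_type="sum", feature_value=source_states)
</code></pre>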
<div style="line-height: 40%;">
<br />
</div>
<h2>Training orchestration</h2>
<p>
While advanced users are free to do custom model training, the <a href="https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/runner.md">TF-GNN Runner</a> also provides a succinct way to orchestrate the training of Keras models in the common cases. A simple invocation may look like this:
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgxRRMrWL-AyxpHeyAhffhApAzlq-u7FoZaDnZFlwRsoYCljzZNi0LmRDDMwZ7mkXeBK0oUFujf_TDD-zlTQcgnLGhPedfrJ2vVs-D5-RPZFWXaaRpOJIt-MH3N8Tj7NZy-SFXTjxjDrhHQY_HVUA3-_C8_xQjfRWBlO-dzcFzgUL6wynMWJhUM7z_MYKvF/s1400/TFGNN%20code2.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="508" data-original-width="1400" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgxRRMrWL-AyxpHeyAhffhApAzlq-u7FoZaDnZFlwRsoYCljzZNi0LmRDDMwZ7mkXeBK0oUFujf_TDD-zlTQcgnLGhPedfrJ2vVs-D5-RPZFWXaaRpOJIt-MH3N8Tj7NZy-SFXTjxjDrhHQY_HVUA3-_C8_xQjfRWBlO-dzcFzgUL6wynMWJhUM7z_MYKvF/s16000/TFGNN%20code2.png" /></a></td></tr></tbody></table>
<p>
The Runner provides ready-to-use solutions for ML pains like distributed training and <code>tfgnn.GraphTensor</code> padding for fixed shapes on Cloud TPUs. Beyond training on a single task (as shown above), it supports joint training on multiple (two or more) tasks in concert. For example, unsupervised tasks can be mixed with supervised ones to inform a final continuous representation (or embedding) with application-specific inductive biases. Callers need only substitute the task argument with a mapping of tasks:
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg4GGpfZib5MAUnX7BRLywJC4xMVt9Tz8kSMhgyDGN5A-aS9k-gna_t0Fo3uxMaAb8gK0ovrOO3XkeSNZ3i24leBCNsALR2NU_MWI7M_s47p2bx-aviaUKy_DxDEkzndNYMI_52jcEmNKyJrqDFye3_PHaWJZz7MAQ1lVW-YpuWPOOYpSAfbrunU5q4M2ev/s1400/TFGNN%20code3.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="392" data-original-width="1400" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg4GGpfZib5MAUnX7BRLywJC4xMVt9Tz8kSMhgyDGN5A-aS9k-gna_t0Fo3uxMaAb8gK0ovrOO3XkeSNZ3i24leBCNsALR2NU_MWI7M_s47p2bx-aviaUKy_DxDEkzndNYMI_52jcEmNKyJrqDFye3_PHaWJZz7MAQ1lVW-YpuWPOOYpSAfbrunU5q4M2ev/s16000/TFGNN%20code3.png" /></a></td></tr></tbody></table>
<p>
Additionally, the TF-GNN Runner includes an implementation of <a href="https://www.tensorflow.org/tutorials/interpretability/integrated_gradients">integrated gradients</a> for use in model attribution. The integrated gradients output is a GraphTensor with the same connectivity as the observed GraphTensor, but with its features replaced by gradient values, where larger values contribute more than smaller values to the GNN prediction. Users can inspect the gradient values to see which features their GNN uses the most.
</p>
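<p>
The core of the integrated gradients attribution is framework-agnostic. The generic sketch below (for an ordinary differentiable model, not the Runner’s GraphTensor-aware implementation) approximates the path integral of gradients between a baseline and the input with a simple Riemann average:
</p>
<pre><code>import tensorflow as tf

def integrated_gradients(model_fn, x, baseline, steps=64):
    # Interpolate between the baseline and the input along a straight path.
    alphas = tf.reshape(tf.linspace(0.0, 1.0, steps), [steps] + [1] * len(x.shape))
    interpolated = baseline[None, ...] + alphas * (x - baseline)[None, ...]
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        predictions = model_fn(interpolated)
    grads = tape.gradient(predictions, interpolated)
    # Average the gradients along the path and scale by the input difference.
    return (x - baseline) * tf.reduce_mean(grads, axis=0)

# Example usage with a toy scalar model and an all-zeros baseline:
# attributions = integrated_gradients(lambda z: tf.reduce_sum(z ** 2, axis=-1),
#                                     tf.constant([1.0, 2.0, 3.0]),
#                                     tf.zeros([3]))
</code></pre>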
<div style="line-height: 40%;">
<br />
</div>
<h2>Conclusion</h2>
<p>
In short, we hope TF-GNN will be useful to advance the application of GNNs in TensorFlow at scale and fuel further innovation in the field. If you’re curious to find out more, please try our <a href="https://colab.sandbox.google.com/github/tensorflow/gnn/blob/master/examples/notebooks/ogbn_mag_e2e.ipynb">Colab demo</a> with the popular OGBN-MAG benchmark (in your browser, no installation required), browse the rest of our <a href="https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/overview.md">user guides and Colabs</a>, or take a look at our <a href="https://arxiv.org/abs/2207.03522">paper</a>.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgements</h2>
<p>
<em>The TF-GNN release 1.0 was developed by a collaboration between Google Research: Sami Abu-El-Haija, Neslihan Bulut, Bahar Fatemi, Johannes Gasteiger, Pedro Gonnet, Jonathan Halcrow, Liangze Jiang, Silvio Lattanzi, Brandon Mayer, Vahab Mirrokni, Bryan Perozzi, Anton Tsitsulin, Dustin Zelle, Google Core ML: Arno Eigenwillig, Oleksandr Ferludin, Parth Kothari, Mihir Paradkar, Jan Pfeifer, Rachael Tamakloe, and Google DeepMind:<strong> </strong>Alvaro Sanchez-Gonzalez and Lisa Wang.</em>
</p>
Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-69567679206129147062024-02-02T11:07:00.000-08:002024-02-07T16:05:00.722-08:00A decoder-only foundation model for time-series forecasting<span class="byline-author">Posted by Rajat Sen and Yichen Zhou, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgjLAVI4q3e6yNyTPTCFiLZVQfFm71GOX1TosHg_Sb8M6tVSO1hyphenhyphenZccOlufnqSuXP1rVWHmqHcely6fgW1vex4JdxenniJcaJ7TOomZolUFut8RUdxnOFZDrbt0hrIHkcrK7rl6cq5-kUuWGrOYqIirPAKtnf4vMDauPX4lFAz2PQjiqzqHxMna7eja9gOF/s320/hero.jpg" style="display: none;" />
<p>
<a href="https://en.wikipedia.org/wiki/Time_series">Time-series</a> forecasting is ubiquitous in various domains, such as retail, finance, manufacturing, healthcare and natural sciences. In retail use cases, for example, it has been observed that <a href="https://www.mckinsey.com/featured-insights/artificial-intelligence/notes-from-the-ai-frontier-applications-and-value-of-deep-learning">improving demand forecasting accuracy</a> can meaningfully reduce inventory costs and increase revenue. Deep learning (DL) models have emerged as a popular approach for forecasting rich, multivariate, time-series data because they have proven to perform well in a variety of settings (e.g., DL models performed well in the <a href="https://www.sciencedirect.com/science/article/pii/S0169207021001874">M5 competition</a>).
</p>
<a name='more'></a>
<p>
At the same time, there has been rapid progress in large foundation language models used for natural language processing (NLP) tasks, such as <a href="https://en.wikipedia.org/wiki/Machine_translation">translation</a>, <a href="https://www.analyticsvidhya.com/blog/2023/09/retrieval-augmented-generation-rag-in-ai/">retrieval-augmented generation</a>, and <a href="https://en.wikipedia.org/wiki/Intelligent_code_completion">code completion</a>. These models are trained on massive amounts of <em>textual </em>data derived from a variety of sources like <a href="https://commoncrawl.org/">common crawl</a> and open-source code that allows them to identify patterns in languages. This makes them very powerful <a href="https://en.wikipedia.org/wiki/Zero-shot_learning">zero-shot</a> tools; for instance, <a href="https://blog.google/products/bard/google-bard-try-gemini-ai/">when paired with retrieval</a>, they can answer questions about and summarize current events.
</p>
<p>
Despite DL-based forecasters largely <a href="https://arxiv.org/abs/1704.04110">outperforming</a> traditional methods and progress being made in <a href="https://cloud.google.com/blog/products/ai-machine-learning/vertex-ai-forecasting">reducing training and inference costs</a>, they face challenges: most DL architectures require <a href="https://cloud.google.com/blog/products/ai-machine-learning/vertex-ai-forecasting">long and involved training and validation cycles</a> before a customer can test the model on a new time-series. A foundation model for time-series forecasting, in contrast, can provide decent out-of-the-box forecasts on unseen time-series data with no additional training, enabling users to focus on refining forecasts for the actual downstream task like <a href="https://en.wikipedia.org/wiki/Customer_demand_planning">retail demand planning</a>.
</p>
<p>
To that end, in “<a href="https://arxiv.org/pdf/2310.10688.pdf">A decoder-only foundation model for time-series forecasting</a>”, we introduce TimesFM, a single forecasting model pre-trained on a large time-series corpus of 100 billion real world time-points. Compared to the latest large language models (LLMs), TimesFM is much smaller (200M parameters), yet we show that even at such scales, its zero-shot performance on a variety of unseen datasets of different domains and temporal granularities come close to the state-of-the-art supervised approaches trained explicitly on these datasets. Later this year we plan to make this model available for external customers in <a href="https://cloud.google.com/vertex-ai/docs/tabular-data/forecasting/train-model#aiplatform_create_training_pipeline_tabular_forecasting_sample-python">Google Cloud Vertex AI</a>.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>A decoder-only foundation model for time-series forecasting</h2>
<p>
LLMs are usually trained in a <a href="https://arxiv.org/pdf/1801.10198.pdf">decoder-only</a> fashion that involves three steps. First, text is broken down into subwords called tokens. Then, the tokens are fed into stacked causal <a href="https://arxiv.org/abs/1706.03762">transformer</a> layers that produce an output corresponding to each input token (it cannot attend to future tokens). Finally, the output corresponding to the <em>i</em>-th token summarizes all the information from previous tokens and predicts the (<em>i</em>+1)-th token. During inference, the LLM generates the output one token at a time. For example, when prompted with “What is the capital of France?”, it might generate the token “The”, then condition on “What is the capital of France? The” to generate the next token “capital” and so on until it generates the complete answer: “The capital of France is Paris”.
</p>
<p>
A foundation model for time-series forecasting should adapt to variable context (what we observe) and horizon (what we query the model to forecast) lengths, while having enough capacity to encode all patterns from a large pretraining dataset. Similar to LLMs, we use stacked transformer layers (self-attention and <a href="https://en.wikipedia.org/wiki/Feedforward_neural_network">feedforward</a> layers) as the main building blocks for the TimesFM model. In the context of time-series forecasting, we treat a patch (a group of contiguous time-points) as a token, an approach popularized by a recent <a href="https://arxiv.org/abs/2211.14730">long-horizon forecasting work</a>. The task then is to forecast the (<em>i</em>+1)-th patch of time-points given the <em>i</em>-th output at the end of the stacked transformer layers.
</p>
<p>
However, there are several key differences from language models. Firstly, we need a <a href="https://en.wikipedia.org/wiki/Multilayer_perceptron">multilayer perceptron</a> block with residual connections to convert a patch of time-series into a token that can be input to the transformer layers along with <a href="https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/">positional encodings</a> (PE). For that, we use a residual block similar to our prior work in <a href="https://arxiv.org/abs/2304.08424">long-horizon forecasting</a>. Secondly, at the other end, an output token from the stacked transformer can be used to predict a longer length of subsequent time-points than the input patch length, i.e., the output patch length can be larger than the input patch length.
</p>
<p>
Consider a time-series of length 512 time-points being used to train a TimesFM model with input patch length 32 and output patch length 128. During training, the model is simultaneously trained to use the first 32 time-points to forecast the next 128 time-points, the first 64 time-points to forecast time-points 65 to 192, the first 96 time-points to forecast time-points 97 to 224, and so on. During inference, suppose the model is given a new time-series of length 256 and tasked with forecasting the next 256 time-points into the future. The model will first generate the future predictions for time-points 257 to 384, then condition on the initial 256-length input plus the generated output to generate time-points 385 to 512. On the other hand, if in our model the output patch length were equal to the input patch length of 32, then for the same task we would have to go through eight generation steps instead of just the two above. This increases the chance of errors accumulating; therefore, in practice, we see that a longer output patch length yields better performance for long-horizon forecasting.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj4G0lBOLUqlPIXJ3R68kjS984MBIKBPDBrCWtgmjVVTyQRqY6-rn3aHJjgxCbG-8csyBLsp0POILdeJ2VcsRy8lrip0k5DWsUpuL9LU1qOPXLW99mraNdd6HVU791NYqJeTyY7LjuMnOIo6RGmkxBQqqaPrSsC0dELrwy21QUs1Jgwxr8flmdNkDV2tZsT/s1084/image3.jpg" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="674" data-original-width="1084" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj4G0lBOLUqlPIXJ3R68kjS984MBIKBPDBrCWtgmjVVTyQRqY6-rn3aHJjgxCbG-8csyBLsp0POILdeJ2VcsRy8lrip0k5DWsUpuL9LU1qOPXLW99mraNdd6HVU791NYqJeTyY7LjuMnOIo6RGmkxBQqqaPrSsC0dELrwy21QUs1Jgwxr8flmdNkDV2tZsT/s16000/image3.jpg" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">TimesFM architecture.</td></tr></tbody></table>
<br />
<div style="line-height: 40%;">
<br />
</div>
<h2>Pretraining data</h2>
<p>
Just like LLMs get better with more tokens, TimesFM requires a large volume of legitimate time series data to learn and improve. We have spent a great amount of time creating and assessing our training datasets, and the following is what we have found works best:
</p>
<div style="margin-left: 40px;">
<p>
<strong>Synthetic data helps with the basics.</strong> Meaningful synthetic time-series data can be generated using statistical models or physical simulations. These basic temporal patterns can teach the model the grammar of time series forecasting.
</p></div>
<div style="margin-left: 40px;">
<p>
<strong>Real-world data adds real-world flavor.</strong> We comb through available public time series datasets, and selectively put together a large corpus of 100 billion time-points. Among these datasets there are <a href="https://trends.google.com/trends/">Google Trends</a> and <a href="https://meta.wikimedia.org/wiki/Research:Page_view">Wikipedia Pageviews</a>, which track what people are interested in, and that nicely mirrors trends and patterns in many other real-world time series. This helps TimesFM understand the bigger picture and generalize better when provided with domain-specific contexts not seen during training.
</p></div>
<div style="line-height: 40%;">
<br />
</div>
<h2>Zero-shot evaluation results</h2>
<p>
We evaluate TimesFM zero-shot on data not seen during training using popular time-series benchmarks. We observe that TimesFM performs better than most statistical methods like <a href="https://en.wikipedia.org/wiki/Autoregressive_integrated_moving_average">ARIMA</a> and <a href="https://en.wikipedia.org/wiki/Exponential_smoothing">ETS</a>, and can match or outperform powerful DL models like <a href="https://arxiv.org/abs/1704.04110">DeepAR</a> and <a href="https://arxiv.org/abs/2211.14730">PatchTST</a> that have been <em>explicitly trained</em> on the target time-series.
</p>
<p>
We used the <a href="https://huggingface.co/datasets/monash_tsf">Monash Forecasting Archive</a> to evaluate TimesFM’s out-of-the-box performance. This archive contains tens of thousands of time-series from various domains like traffic, weather, and demand forecasting covering frequencies ranging from few minutes to yearly data. Following existing literature, we inspect the <a href="https://en.wikipedia.org/wiki/Mean_absolute_error">mean absolute error</a> (MAE) <a href="https://arxiv.org/abs/2310.07820">appropriately scaled</a> so that it can be averaged across the datasets. We see that zero-shot (ZS) TimesFM is better than most supervised approaches, including recent deep learning models. We also compare TimesFM to <a href="https://platform.openai.com/docs/models/gpt-3-5">GPT-3.5</a> for forecasting using a specific prompting technique proposed by <a href="https://arxiv.org/abs/2310.07820">llmtime(ZS)</a>. We demonstrate that TimesFM performs better than llmtime(ZS) despite being orders of magnitude smaller.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhIeNF6GcmbUvVvYpKxNSvwlm_swz6M3G7nTDl0INa2zq8AlvjTBCVuvwOw0dx48JCk4H3S0aBUcsvqj2BypV3340cblqgD6yktoLBXzpxA2fwoM4n_KU8m0TfaESjihc3nx29RYVTpO4g09RCK-rucPulH3gqEOU9jO7EZ_VbDcFnfB_RHXmdpuZO_T_-g/s1476/image2.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="876" data-original-width="1476" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhIeNF6GcmbUvVvYpKxNSvwlm_swz6M3G7nTDl0INa2zq8AlvjTBCVuvwOw0dx48JCk4H3S0aBUcsvqj2BypV3340cblqgD6yktoLBXzpxA2fwoM4n_KU8m0TfaESjihc3nx29RYVTpO4g09RCK-rucPulH3gqEOU9jO7EZ_VbDcFnfB_RHXmdpuZO_T_-g/s16000/image2.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Scaled MAE (the lower the better) of TimesFM(ZS) against other supervised and zero-shot approaches on Monash datasets.</td></tr></tbody></table>
<br>
<p>
Most of the Monash datasets are short or medium horizon, i.e., the prediction length is not too long. We also test TimesFM on popular benchmarks for long horizon forecasting against a recent state-of-the-art baseline <a href="https://arxiv.org/abs/2211.14730">PatchTST</a> (and other long-horizon forecasting baselines). In the next figure, we plot the MAE on <a href="https://paperswithcode.com/dataset/ett">ETT</a> datasets for the task of predicting 96 and 192 time-points into the future. The metric has been calculated on the last test window of each dataset (as done by the <a href="https://arxiv.org/abs/2310.07820">llmtime</a> paper). We see that TimesFM not only surpasses the performance of llmtime(ZS) but also matches that of the supervised PatchTST model explicitly trained on the respective datasets.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj0DDM32GPO6zkmnIrObEP2OA92g45b-zSMHgCf-uNoj6Ed0M0zVsN7vmFmfgXT6Sh5p-W0xI1qj6YwXcqi3T6aD5hI9ZOJqT8Sobp43FGrtSsLUkI2poHnGml7Za4BMObSd6nEKUVL8wj7nHJDFYHbWaQOXOcfxvqXUcMxUZ3WVQW8Z5sabfFsi7M85_7I/s735/image1.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="433" data-original-width="735" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj0DDM32GPO6zkmnIrObEP2OA92g45b-zSMHgCf-uNoj6Ed0M0zVsN7vmFmfgXT6Sh5p-W0xI1qj6YwXcqi3T6aD5hI9ZOJqT8Sobp43FGrtSsLUkI2poHnGml7Za4BMObSd6nEKUVL8wj7nHJDFYHbWaQOXOcfxvqXUcMxUZ3WVQW8Z5sabfFsi7M85_7I/s16000/image1.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Last window MAE (the lower the better) of TimesFM(ZS) against llmtime(ZS) and long-horizon forecasting baselines on ETT datasets.</td></tr></tbody></table>
<br />
<div style="line-height: 40%;">
<br />
</div>
<h2>Conclusion</h2>
<p>
We train a decoder-only foundation model for time-series forecasting using a large pretraining corpus of 100B real world time-points, the majority of which was search interest time-series data derived from Google Trends and pageviews from Wikipedia. We show that even a relatively small 200M parameter pretrained model that uses our TimesFM architecture displays impressive zero-shot performance on a variety of public benchmarks from different domains and granularities.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgements</h2>
<p>
<em>This work is the result of a collaboration between several individuals across Google Research and Google Cloud, including (in alphabetical order): Abhimanyu Das, Weihao Kong, Andrew Leach, Mike Lawrence, Alex Martin, Rajat Sen, Yang Yang, Skander Hannachi, Ivan Kuznetsov and Yichen Zhou.</em>
</p>Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-22542879280407275022024-02-02T09:49:00.000-08:002024-02-02T09:49:36.211-08:00Intervening on early readouts for mitigating spurious features and simplicity bias<span class="byline-author">Posted by Rishabh Tiwari, Pre-doctoral Researcher, and Pradeep Shenoy, Research Scientist, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgdBd5rMRA2U1nd8fetuEweTgmHncn49ASMQtPlm6dfsr5V29RwsoUR8UtK4B7oSE1eiIdW-vD-gjCUK4tGZTbsY4XdO0adL2YtAjpgbF1S3mL_Jw3f31SwLKYUtCOLJ807gdXdRmD5iVsrtc_Ii-BiqQacv89vbtRbNAIINa9PhKAF_sDAZu09FLs4599T/s1600/SiFer%20Hero.png" style="display: none;" />
<p>
Machine learning models in the real world are often trained on limited data that may contain unintended <a href="https://en.wikipedia.org/wiki/Bias_(statistics)">statistical biases</a>. For example, in the <a href="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html">CELEBA</a> celebrity image dataset, a disproportionate number of female celebrities have blond hair, leading to classifiers incorrectly predicting “blond” as the hair color for most female faces — here, gender is a spurious feature for predicting hair color. Such unfair biases could have significant consequences in critical applications such as <a href="https://www.researchgate.net/publication/362524426_Addressing_fairness_in_artificial_intelligence_for_medical_imaging">medical diagnosis</a>.
</p>
<a name='more'></a>
<p>
Surprisingly, recent work has also discovered an inherent tendency of deep networks to <em>amplify such statistical biases</em>, through the so-called <a href="https://proceedings.neurips.cc/paper/2020/file/6cfe0e6127fa25df2a0ef2ae1067d915-Paper.pdf">simplicity bias</a> of deep learning. This bias is the tendency of deep networks to identify weakly predictive features early in the training, and continue to anchor on these features, failing to identify more complex and potentially more accurate features.
</p>
<p>
With the above in mind, we propose simple and effective fixes to this dual challenge of spurious features and simplicity bias by applying <em>early readouts</em> and <em>feature forgetting</em>. First, in “<a href="https://arxiv.org/abs/2310.18590">Using Early Readouts to Mediate Featural Bias in Distillation</a>”, we show that making predictions from early layers of a deep network (referred to as “early readouts”) can automatically signal issues with the quality of the learned representations. In particular, these predictions are more often wrong, and more confidently wrong, when the network is relying on spurious features. We use this erroneous confidence to improve outcomes in <a href="https://arxiv.org/pdf/1503.02531.pdf">model distillation</a>, a setting where a larger “teacher” model guides the training of a smaller “student” model. Then in “<a href="https://arxiv.org/abs/2301.13293">Overcoming Simplicity Bias in Deep Networks using a Feature Sieve</a>”, we intervene directly on these indicator signals by making the network “forget” the problematic features and consequently look for better, more predictive features. This substantially improves the model’s ability to generalize to unseen domains compared to previous approaches. Our <a href="https://ai.google/responsibility/principles">AI Principles</a> and our <a href="https://ai.google/responsibility/responsible-ai-practices/">Responsible AI practices</a> guide how we research and develop these advanced applications and help us address the challenges posed by statistical biases.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhzG_p8Re7HHeTp_Qg_GwjX5LcHsE-TZDmHr3azTSOLKl4f1J4xcL9vxo46zicAl6QoIKIrTJaI2Z51iFq2oICjeb6Ut4-W1W74bytv87pH3hKVJOotWWWDk0gwB-ak_YZRmtZyimw8b9lSJ1DRzh6uIpvIBN2pbIw-6MuN47rUjTK_RzLLfYXPrIjtpjRz/s1080/image3.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="540" data-original-width="1080" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhzG_p8Re7HHeTp_Qg_GwjX5LcHsE-TZDmHr3azTSOLKl4f1J4xcL9vxo46zicAl6QoIKIrTJaI2Z51iFq2oICjeb6Ut4-W1W74bytv87pH3hKVJOotWWWDk0gwB-ak_YZRmtZyimw8b9lSJ1DRzh6uIpvIBN2pbIw-6MuN47rUjTK_RzLLfYXPrIjtpjRz/s16000/image3.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Animation comparing hypothetical responses from two models trained with and without the feature sieve.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>Early readouts for debiasing distillation</h2>
<p>
We first illustrate the diagnostic value of <em>early readouts</em> and their application in debiased distillation, i.e., making sure that the student model inherits the teacher model’s resilience to feature bias through distillation. We start with a standard distillation framework where the student is trained with a mixture of label matching (minimizing the <a href="https://towardsdatascience.com/cross-entropy-loss-function-f38c4ec8643e">cross-entropy loss</a> between student outputs and the ground-truth labels) and teacher matching (minimizing the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">KL divergence</a> loss between student and teacher outputs for any given input).
</p>
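<p>
A minimal version of that combined objective (a standard distillation sketch; the temperature and mixing weight shown here are illustrative defaults, not the values used in our experiments) looks like this:
</p>
<pre><code>import tensorflow as tf

def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=2.0):
    # Label matching: cross-entropy between student outputs and ground-truth labels.
    ce = tf.keras.losses.sparse_categorical_crossentropy(
        labels, student_logits, from_logits=True)
    # Teacher matching: KL divergence between (temperature-softened) teacher and student outputs.
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    student_log_probs = tf.nn.log_softmax(student_logits / temperature)
    kl = tf.reduce_sum(
        teacher_probs * (tf.math.log(teacher_probs + 1e-8) - student_log_probs), axis=-1)
    # Per-instance loss; the mean over the batch is taken by the training loop.
    return (1.0 - alpha) * ce + alpha * kl
</code></pre>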
<p>
Suppose one trains a linear decoder, i.e., a small auxiliary neural network named <em>Aux</em>, on top of an intermediate representation of the student model. We refer to the output of this linear decoder as an early readout of the network representation. Our finding is that early readouts make more errors on instances that contain spurious features, and further, that the confidence on those errors is higher than the confidence associated with other errors. This suggests that confidence on errors from early readouts is a fairly strong, automated indicator of the model’s dependence on potentially spurious features.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEixpq4OhPGxL9gGW30-0kqQ_CieDj3PJcqw8L4_7fBDZOFKuQpI67ljqIItOoJ3U9-dpPd1CpofAG_ld689r0HcPTrzFeTd1ceMQ42C3CRPWWJMYknydHpJhFjQUjb-M6mx8ILQbWEBIOv-NSgTauMGgDZ8t3EMGHE3j6UN9HIF3BJmB63GhOzFwOVmswlc/s1128/image5.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="796" data-original-width="1128" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEixpq4OhPGxL9gGW30-0kqQ_CieDj3PJcqw8L4_7fBDZOFKuQpI67ljqIItOoJ3U9-dpPd1CpofAG_ld689r0HcPTrzFeTd1ceMQ42C3CRPWWJMYknydHpJhFjQUjb-M6mx8ILQbWEBIOv-NSgTauMGgDZ8t3EMGHE3j6UN9HIF3BJmB63GhOzFwOVmswlc/s16000/image5.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Illustrating the usage of early readouts (i.e., output from the auxiliary layer) in debiasing distillation. Instances that are confidently mispredicted in the early readouts are upweighted in the distillation loss.</td></tr></tbody></table>
<p>
We used this signal to modulate the contribution of the teacher in the distillation loss on a per-instance basis, and found significant improvements in the trained student model as a result.
</p>
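<p>
A rough sketch of how such a per-instance signal can be computed is shown below. The specific upweighting rule (the readout’s confidence on mispredicted instances, with a small floor elsewhere) is illustrative rather than the exact scheme from the paper.
</p>
<pre>
import numpy as np

def softmax(logits):
    z = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)

def readout_instance_weights(aux_logits, labels, floor=0.1):
    # Early-readout predictions and confidences from the auxiliary decoder.
    probs = softmax(aux_logits)
    mispredicted = probs.argmax(axis=-1) != labels
    confidence = probs.max(axis=-1)
    # Confidently mispredicted instances likely rely on spurious features,
    # so the teacher-matching term is upweighted for them.
    return np.where(mispredicted, confidence, floor)

# Usage (per_instance_kl is the KL term from the previous sketch, kept per instance):
#   weights = readout_instance_weights(aux_logits, labels)
#   loss = ce + (weights * per_instance_kl).mean()
</pre>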
<p>
We evaluated our approach on standard benchmark datasets known to contain spurious correlations (<a href="https://arxiv.org/pdf/1911.08731.pdf">Waterbirds</a>, <a href="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html">CelebA</a>, <a href="https://www.tensorflow.org/datasets/catalog/civil_comments">CivilComments</a>, <a href="https://cims.nyu.edu/~sbowman/multinli/">MNLI</a>). Each of these datasets contains groupings of data that share an attribute potentially correlated with the label in a spurious manner. As an example, the CelebA dataset mentioned above includes groups such as {blond male, blond female, non-blond male, non-blond female}, with models typically performing the worst on the {non-blond female} group when predicting hair color. Thus, a measure of model performance is its <em>worst group accuracy</em>, i.e., the lowest accuracy among all known groups present in the dataset. We improved the worst group accuracy of student models on all datasets; moreover, we also improved overall accuracy in three of the four datasets, showing that our improvement on any one group does not come at the expense of accuracy on other groups. More details are available in our <a href="https://arxiv.org/pdf/2310.18590.pdf">paper</a>.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhpQiz04rM3DMtDiusAWyWl92FMUKbafR0l2dGvrj17fX3nuvPDnyXMQaumsxDvch3ScnOCL4Duq5_O32dWbv_CTsIu5aNc-c3xrVAIXjQ3kmn0jZ_TZ5SJ7C2lq1oxLZ33-VKXSSPRa_oGUB5jJlsBTZupsHMeUtSVXLh414e1NVEgI1IamqhTA1dqU0s5/s1270/image4.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1014" data-original-width="1270" height="511" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhpQiz04rM3DMtDiusAWyWl92FMUKbafR0l2dGvrj17fX3nuvPDnyXMQaumsxDvch3ScnOCL4Duq5_O32dWbv_CTsIu5aNc-c3xrVAIXjQ3kmn0jZ_TZ5SJ7C2lq1oxLZ33-VKXSSPRa_oGUB5jJlsBTZupsHMeUtSVXLh414e1NVEgI1IamqhTA1dqU0s5/w640-h511/image4.png" width="640" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Comparison of Worst Group Accuracies of different distillation techniques relative to that of the Teacher model. Our method outperforms other methods on all datasets.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>Overcoming simplicity bias with a feature sieve</h2>
<p>
In a second, closely related project, we intervene directly on the information provided by early readouts, to improve <a href="https://en.wikipedia.org/wiki/Feature_learning">feature learning</a> and <a href="https://developers.google.com/machine-learning/crash-course/generalization/video-lecture">generalization</a>. The workflow alternates between <em>identifying </em>problematic features and <em>erasing identified features</em> from the network. Our primary hypothesis is that early features are more prone to simplicity bias, and that by erasing (“sieving”) these features, we allow richer feature representations to be learned.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEghN4NJ5vZ6jESH3koLTfGa3DpSenk5liLEg2awv2cOo1blDwwuDjLGVGxyeHSAzkLWTBUwO_swf4uGC2oShnD0WTNrebCL9KLAMOBIxR3ZZnw9eVS8g16s_lgP5kCbhZmVoTctASyDVvb3wtzIlzju01m4ADr7G21NpOWpac55hBllzYBaQVAXCjq8BIca/s1098/image6.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="604" data-original-width="1098" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEghN4NJ5vZ6jESH3koLTfGa3DpSenk5liLEg2awv2cOo1blDwwuDjLGVGxyeHSAzkLWTBUwO_swf4uGC2oShnD0WTNrebCL9KLAMOBIxR3ZZnw9eVS8g16s_lgP5kCbhZmVoTctASyDVvb3wtzIlzju01m4ADr7G21NpOWpac55hBllzYBaQVAXCjq8BIca/s16000/image6.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Training workflow with feature sieve. We alternate between identifying problematic features (using training iteration) and erasing them from the network (using forgetting iteration).</td></tr></tbody></table>
<p>
We describe the identification and erasure steps in more detail:
</p>
<ul>
<li><b>Identifying simple features</b>: We train the primary model and the readout model (<em>Aux</em> above) in the conventional fashion via forward- and back-propagation. Note that feedback from the auxiliary layer does not back-propagate to the main network. This is done to force the auxiliary layer to learn from already-available features rather than create or reinforce them in the main network.
</li><li><b>Applying the feature sieve</b>: We aim to erase the identified features in the early layers of the neural network with the use of a novel <em>forgetting loss</em>, <em>L<sub>f</sub></em>, which is simply the cross-entropy between the readout and a uniform distribution over labels (a minimal sketch of this loss appears after this list). Essentially, all information that leads to nontrivial readouts is erased from the primary network. In this step, the auxiliary network and upper layers of the main network are kept unchanged.
</li>
</ul>
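<p>
The following is a minimal numpy sketch of the forgetting loss described above, assuming logits from the auxiliary readout; the actual training code additionally restricts which layers receive its gradient, as noted in the list.
</p>
<pre>
import numpy as np

def forgetting_loss(aux_logits):
    # Cross-entropy between the readout and a uniform distribution over labels.
    # It is minimized when the readout is uniform, i.e., carries no label
    # information, which erases the corresponding features from the early layers.
    z = aux_logits - aux_logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -(log_probs.mean(axis=-1)).mean()  # -(1/C) * sum_c log p_c, averaged over the batch
</pre>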
<p>
We can control specifically how the feature sieve is applied to a given dataset through a small number of configuration parameters. By changing the position and complexity of the auxiliary network, we control the complexity of the identified and erased features. By modifying the mixing of learning and forgetting steps, we control the degree to which the model is challenged to learn more complex features. These choices, which are dataset-dependent, are made via <a href="https://en.wikipedia.org/wiki/Hyperparameter_optimization">hyperparameter search</a> to maximize validation accuracy, a standard measure of generalization. Since we include “no-forgetting” (i.e., the baseline model) in the search space, we expect to find settings that are at least as good as the baseline.
</p>
<p>
Below we show features learned by the baseline model (middle row) and our model (bottom row) on two benchmark datasets — biased activity recognition (<a href="https://github.com/alinlab/BAR">BAR</a>) and animal categorization (<a href="https://arxiv.org/pdf/1906.02899v3.pdf">NICO</a>). Feature importance was estimated using post-hoc gradient-based importance scoring (<a href="https://arxiv.org/abs/1610.02391">GRAD-CAM</a>), with the orange-red end of the spectrum indicating high importance, while green-blue indicates low importance. Shown below, our trained models focus on the primary object of interest, whereas the baseline model tends to focus on background features that are simpler and spuriously correlated with the label.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgumwu2DQ-nPeTLxt_uS6q6tIR6oQZdlWOoM4_I5kUmYfyJi8xyWIpw7WusdRAsA_YthYgO2Zz8sj7V1Id3JOTsljM9zpK2vwhokMfnZQOxbAIWtaFvFN4sfN6qF0rkOklj10y-_rLfL-WQS4zf6AWCub7aUTS7a8LyEsZ5uhQmXjTai7neuWElZBbP_5UI/s1616/image2.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="850" data-original-width="1616" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgumwu2DQ-nPeTLxt_uS6q6tIR6oQZdlWOoM4_I5kUmYfyJi8xyWIpw7WusdRAsA_YthYgO2Zz8sj7V1Id3JOTsljM9zpK2vwhokMfnZQOxbAIWtaFvFN4sfN6qF0rkOklj10y-_rLfL-WQS4zf6AWCub7aUTS7a8LyEsZ5uhQmXjTai7neuWElZBbP_5UI/s16000/image2.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Feature importance scoring using GRAD-CAM on activity recognition (BAR) and animal categorization (NICO) generalization benchmarks. Our approach (last row) focuses on the relevant objects in the image, whereas the baseline (ERM; middle row) relies on background features that are spuriously correlated with the label.</td></tr></tbody></table>
<p>
Through this ability to learn better, generalizable features, we show substantial gains over a range of relevant baselines on real-world spurious feature benchmark datasets: <a href="https://github.com/alinlab/BAR">BAR</a>, <a href="https://arxiv.org/pdf/2104.06885.pdf">CelebA Hair</a>, <a href="https://nico.thumedialab.com/">NICO</a> and <a href="https://www.tensorflow.org/datasets/catalog/imagenet_a">ImagenetA</a>, by margins up to 11% (see figure below). More details are available in <a href="https://arxiv.org/abs/2301.13293">our paper</a>.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjjuXHls8mwfL2u-TVZlDlu5UMPrank9F2ODbf6h12q9oMLNrIYyfyv4OuQriS0XzI-z0BrQOs2xUiXt53lGLQtdzmKQDtGXFtv6TZEGg4pKua8JD9AkQn0J92mTjlQAlZTUPgqIYRAFpnsRTU0szE5J90_LeGNj3PTUKrsgq3WAMAjWSy30HQtMnNzevvY/s1082/image1.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1082" data-original-width="844" height="640" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjjuXHls8mwfL2u-TVZlDlu5UMPrank9F2ODbf6h12q9oMLNrIYyfyv4OuQriS0XzI-z0BrQOs2xUiXt53lGLQtdzmKQDtGXFtv6TZEGg4pKua8JD9AkQn0J92mTjlQAlZTUPgqIYRAFpnsRTU0szE5J90_LeGNj3PTUKrsgq3WAMAjWSy30HQtMnNzevvY/w501-h640/image1.png" width="501" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Our feature sieve method improves accuracy by significant margins relative to the nearest baseline for a range of feature generalization benchmark datasets.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>Conclusion</h2>
<p>
We hope that our work on early readouts and their use in feature sieving for generalization will both spur the development of a new class of adversarial feature learning approaches and help improve the generalization capability and robustness of deep learning systems.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgements </h2>
<p>
<em>The work on applying early readouts to debiasing distillation was conducted in collaboration with our academic partners Durga Sivasubramanian, Anmol Reddy and Prof. Ganesh Ramakrishnan at <a href="https://www.iitb.ac.in/">IIT Bombay</a>. We extend our sincere gratitude to Praneeth Netrapalli and Anshul Nasery for their feedback and recommendations. We are also grateful to Nishant Jain, Shreyas Havaldar, Rachit Bansal, Kartikeya Badola, Amandeep Kaur and the whole cohort of pre-doctoral researchers at Google Research India for taking part in research discussions. Special thanks to Tom Small for creating the animation used in this post.</em>
</p>Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-59665531149676739842024-01-31T13:59:00.000-08:002024-01-31T13:59:36.056-08:00MobileDiffusion: Rapid text-to-image generation on-device<span class="byline-author">Posted by Yang Zhao, Senior Software Engineer, and Tingbo Hou, Senior Staff Software Engineer, Core ML</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgOndf55Pc7tkXJektbVBEYRsOlxbUVui2uwOdXvuHj9cNpoNw2One4-68fqFNl2_fvv11CcgYfoI1XVQIkpjA9DosaOeqdkIRj9aZZJNoDy8KqB_XCVDtDd_EvT5UGL2ZhXvL2PU3RjN8XBjI0eQe8VIJCKI0-20AG0TKGK58mO9tBZa80P58KSjTU_liK/s1600/InstantTIGO%20hero.png" style="display: none;" />
<p>
Text-to-image <a href="https://arxiv.org/abs/2006.11239">diffusion models</a> have shown exceptional capabilities in generating high-quality images from text prompts. However, leading models feature billions of parameters and are consequently expensive to run, requiring powerful desktops or servers (e.g., <a href="https://stability.ai/news/stable-diffusion-public-release">Stable Diffusion</a>, <a href="https://openai.com/research/dall-e">DALL·E</a>, and <a href="https://imagen.research.google/">Imagen</a>). While recent advancements in inference solutions on <a href="https://blog.research.google/2023/06/speed-is-all-you-need-on-device.html">Android</a> via MediaPipe and <a href="https://github.com/apple/ml-stable-diffusion">iOS</a> via Core ML have been made in the past year, rapid (sub-second) text-to-image generation on mobile devices has remained out of reach.
</p> <a name='more'></a>
<p>
To that end, in “<a href="https://arxiv.org/abs/2311.16567">MobileDiffusion: Subsecond Text-to-Image Generation on Mobile Devices</a>”, we introduce a novel approach with the potential for rapid text-to-image generation on-device. MobileDiffusion is an efficient latent diffusion model specifically designed for mobile devices. We also adopt <a href="https://arxiv.org/abs/2311.09257">DiffusionGAN</a> to achieve one-step sampling during inference, which fine-tunes a pre-trained diffusion model while leveraging a GAN to model the denoising step. We have tested MobileDiffusion on iOS and Android premium devices, and it can run in half a second to generate a 512x512 high-quality image. Its comparably small model size of just 520M parameters makes it uniquely suited for mobile deployment.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody>
<tr>
<td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgc9IegNp6IHze1sPewUyoR_WouBi8jMhiThcaavD0SXFld3788eA89uyOP6gpmdCXSZMMuacrgQMJ61ygVJsLfE51tqTmmYS0C-GI9SaF_hEGlhTp_zTFXdW_AgXIP5CLCejKQVCsPrhycF8p_Rj9qQHR0J_kTO8Md7VT5R47IMJHinO6dkHn23lUlU7rf/s800/image2.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="800" data-original-width="369" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgc9IegNp6IHze1sPewUyoR_WouBi8jMhiThcaavD0SXFld3788eA89uyOP6gpmdCXSZMMuacrgQMJ61ygVJsLfE51tqTmmYS0C-GI9SaF_hEGlhTp_zTFXdW_AgXIP5CLCejKQVCsPrhycF8p_Rj9qQHR0J_kTO8Md7VT5R47IMJHinO6dkHn23lUlU7rf/s16000/image2.gif" /></a></td>
<td> </td>
<td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhpz0XGSpMH9OVTd865uusar0AeXtu_26HD3tHzJHm2iEVeLYynBhi6pl0tidIYOoJVamc-NplnsNPCNl3vMX-qjqEZCYtndsl-9YjulMpLiDbP3Uws9cZ5ITjb0C3MNaVNC5mh-kbyKZYXn5rxBAuPLaHg_56ZAJfPOrkBfh44goI3CnEW-XZFDUvJgWAV/s800/image5.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="800" data-original-width="369" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhpz0XGSpMH9OVTd865uusar0AeXtu_26HD3tHzJHm2iEVeLYynBhi6pl0tidIYOoJVamc-NplnsNPCNl3vMX-qjqEZCYtndsl-9YjulMpLiDbP3Uws9cZ5ITjb0C3MNaVNC5mh-kbyKZYXn5rxBAuPLaHg_56ZAJfPOrkBfh44goI3CnEW-XZFDUvJgWAV/s16000/image5.gif" /></a></td>
</tr></tbody></table>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td class="tr-caption" style="text-align: center;">Rapid text-to-image generation on-device.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>Background</h2>
<p>
The relative inefficiency of text-to-image diffusion models arises from two primary challenges. First, the inherent design of diffusion models requires <a href="https://blog.research.google/2023/06/on-device-diffusion-plugins-for.html">iterative denoising</a> to generate images, necessitating multiple evaluations of the model. Second, the complexity of the network architecture in text-to-image diffusion models involves a substantial number of parameters, regularly reaching into the billions and resulting in computationally expensive evaluations. As a result, despite the potential benefits of deploying generative models on mobile devices, such as enhancing user experience and addressing emerging privacy concerns, on-device deployment remains relatively unexplored in the current literature.
</p>
<p>
The optimization of inference efficiency in text-to-image diffusion models has been an active research area. Previous studies predominantly concentrate on addressing the first challenge, seeking to reduce the number of function evaluations (NFEs). Leveraging advanced numerical solvers (e.g., <a href="https://arxiv.org/abs/2206.00927">DPM</a>) or distillation techniques (e.g., <a href="https://arxiv.org/abs/2202.00512">progressive distillation</a>, <a href="https://arxiv.org/abs/2303.01469">consistency distillation</a>), the number of necessary sampling steps has been reduced significantly, from several hundred to single digits. Some recent techniques, like <a href="https://arxiv.org/abs/2311.09257">DiffusionGAN</a> and <a href="https://arxiv.org/abs/2311.17042#:~:text=We%20introduce%20Adversarial%20Diffusion%20Distillation,while%20maintaining%20high%20image%20quality.">Adversarial Diffusion Distillation</a>, reduce this even further to a single step.
</p>
<p>
However, on mobile devices, even a small number of evaluation steps can be slow due to the complexity of model architecture. Thus far, the architectural efficiency of text-to-image diffusion models has received comparatively less attention. A handful of earlier works briefly touch upon this matter, for example by removing redundant neural network blocks (e.g., <a href="https://snap-research.github.io/SnapFusion/">SnapFusion</a>). However, these efforts lack a comprehensive analysis of each component within the model architecture, thereby falling short of providing a holistic guide for designing highly efficient architectures.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>MobileDiffusion</h2>
<p>
Effectively overcoming the challenges imposed by the limited computational power of mobile devices requires an in-depth and holistic exploration of the model's architectural efficiency. In pursuit of this objective, our research undertakes a detailed examination of each constituent and computational operation within Stable Diffusion’s <a href="https://arxiv.org/abs/2112.10752">UNet architecture</a>. We present a comprehensive guide for crafting highly efficient text-to-image diffusion models, culminating in MobileDiffusion.
</p>
<p>
The design of MobileDiffusion follows that of <a href="https://arxiv.org/abs/2112.10752">latent diffusion models</a>. It contains three components: a text encoder, a diffusion UNet, and an image decoder. For the text encoder, we use <a href="https://arxiv.org/abs/2103.00020">CLIP-ViT/L14</a>, which is a small model (125M parameters) suitable for mobile. We then turn our focus to the diffusion UNet and image decoder.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>Diffusion UNet</h3>
<p>
As illustrated in the figure below, diffusion UNets commonly interleave transformer blocks and convolution blocks. We conduct a comprehensive investigation of these two fundamental building blocks. Throughout the study, we control the training pipeline (e.g., data, optimizer) to study the effects of different architectures.
</p>
<p>
In classic text-to-image diffusion models, a transformer block consists of a self-attention layer (SA) for modeling long-range dependencies among visual features, a cross-attention layer (CA) to capture interactions between text conditioning and visual features, and a feed-forward layer (FF) to post-process the output of attention layers. These transformer blocks hold a pivotal role in text-to-image diffusion models, serving as the primary components responsible for text comprehension. However, they also pose a significant efficiency challenge, given the computational expense of the attention operation, which is quadratic in the sequence length. We follow the idea of the <a href="https://arxiv.org/abs/2301.11093">UViT</a> architecture, which places more transformer blocks at the bottleneck of the UNet. This design choice is motivated by the fact that the attention computation is less resource-intensive at the bottleneck due to its lower dimensionality.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgsshK53k6noqIbabpGMBzYIBCdviXisDoBsD3Houk-lXzN8pZQcusKYBvjWwcwA1Aq5DnWyk01YM9B2RyRZx6HcGgTP-LrW-tnwFwByzlBACN3WggyPYM0Mpyr2OVGVLFhx1uN48aR1g9P4o0joN2STli9VpA_tFMdQ-ikRXVrNpawzB793-unSENR-PIV/s915/image4.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="249" data-original-width="915" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgsshK53k6noqIbabpGMBzYIBCdviXisDoBsD3Houk-lXzN8pZQcusKYBvjWwcwA1Aq5DnWyk01YM9B2RyRZx6HcGgTP-LrW-tnwFwByzlBACN3WggyPYM0Mpyr2OVGVLFhx1uN48aR1g9P4o0joN2STli9VpA_tFMdQ-ikRXVrNpawzB793-unSENR-PIV/s16000/image4.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Our UNet architecture incorporates more transformers in the middle, and skips self-attention (SA) layers at higher resolutions.</td></tr></tbody></table>
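<p>
A back-of-the-envelope sketch of why this helps: self-attention cost grows quadratically with the number of visual tokens, so it is far cheaper at the low-resolution bottleneck. The feature resolutions and channel counts below are hypothetical, not MobileDiffusion’s actual configuration.
</p>
<pre>
def self_attention_flops(height, width, channels):
    # Roughly 2 * tokens^2 * channels for QK^T plus the attention-weighted sum of V.
    tokens = height * width
    return 2 * tokens * tokens * channels

for res, ch in [(64, 320), (32, 640), (16, 1280)]:  # hypothetical UNet levels
    print(f"{res}x{res}: {self_attention_flops(res, res, ch):.3g} FLOPs per SA layer")
</pre>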
<p>
Convolution blocks, in particular <a href="https://arxiv.org/abs/1512.03385">ResNet</a> blocks, are deployed at each level of the UNet. While these blocks are instrumental for feature extraction and information flow, the associated computational costs, especially at high-resolution levels, can be substantial. One proven approach in this context is <a href="https://arxiv.org/abs/1704.04861">separable convolution</a>. We observed that replacing regular convolution layers with lightweight separable convolution layers in the deeper segments of the UNet yields similar performance.
</p>
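<p>
The saving from this substitution can be seen from a simple parameter count; the channel width below is illustrative.
</p>
<pre>
def conv_params(c_in, c_out, k=3):
    return k * k * c_in * c_out                 # regular k x k convolution

def separable_conv_params(c_in, c_out, k=3):
    return k * k * c_in + c_in * c_out          # depthwise k x k + pointwise 1 x 1

c = 640  # illustrative channel width
print(conv_params(c, c), separable_conv_params(c, c))   # ~3.7M vs. ~0.4M parameters
</pre>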
<p>
In the figure below, we compare the UNets of several diffusion models. Our MobileDiffusion exhibits superior efficiency in terms of <a href="https://arxiv.org/pdf/2110.12894.pdf">FLOPs</a> (floating-point operations) and number of parameters.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjXYleITSssbZnLffeh3BzG3tX2qNQNeB__xc-ySks0SPnXsMb2kTLZ0PcE2KWJ4I9FX_QMP32pXd06IuV1kJJSlgp7CuV6dqkXJsiFqo_6xqWXZ1-65p_EPU9gk7G9B4-L2TaKGiD5cahwg428CTmV1dcuQQ_vBTVmP8543IJigIF0qHo8_JaB8h5EuVvl/s1200/image3.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="742" data-original-width="1200" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjXYleITSssbZnLffeh3BzG3tX2qNQNeB__xc-ySks0SPnXsMb2kTLZ0PcE2KWJ4I9FX_QMP32pXd06IuV1kJJSlgp7CuV6dqkXJsiFqo_6xqWXZ1-65p_EPU9gk7G9B4-L2TaKGiD5cahwg428CTmV1dcuQQ_vBTVmP8543IJigIF0qHo8_JaB8h5EuVvl/s16000/image3.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Comparison of some diffusion UNets.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h3>Image decoder</h3>
<p>
In addition to the UNet, we also optimized the image decoder. We trained a <a href="https://arxiv.org/abs/2012.03715">variational autoencoder</a> (VAE) to encode an <a href="https://en.wikipedia.org/wiki/RGB_color_model">RGB</a> image to an 8-channel latent variable whose spatial dimensions are 8× smaller than those of the image. The latent variable can then be decoded back to an image that is 8× larger in each spatial dimension. To further enhance efficiency, we design a lightweight decoder architecture by pruning the original’s width and depth. The resulting lightweight decoder leads to a significant performance boost, with nearly 50% latency improvement and better quality. For more details, please refer to our <a href="https://arxiv.org/abs/2311.16567">paper</a>.
</p>
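<p>
For concreteness, the shape bookkeeping looks like this (the 512×512 input size is only an example):
</p>
<pre>
import numpy as np

image = np.zeros((512, 512, 3), dtype=np.float32)              # RGB input
latent_shape = (image.shape[0] // 8, image.shape[1] // 8, 8)   # (64, 64, 8) after encoding
decoded_shape = (latent_shape[0] * 8, latent_shape[1] * 8, 3)  # (512, 512, 3) after decoding
print(latent_shape, decoded_shape)
</pre>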
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjT2Nmo7GjGdN0_2dqevJB52RogqnWFDVmFsrusHHxnVf9YQYsdbVkAQvBI3h9SzKZ0TqOQOmnxaZ6z2kdix12tei5oMpD17SY1LoBWqxD1EHgV0ygTb9TV0IFZQtv4dAix378lb8WGv5GGPQIuyStX3gWqn0pjTTXbpIlA0VzYSeiGpkO5bsHhZfjbkR07/s1124/image6.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="789" data-original-width="1124" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjT2Nmo7GjGdN0_2dqevJB52RogqnWFDVmFsrusHHxnVf9YQYsdbVkAQvBI3h9SzKZ0TqOQOmnxaZ6z2kdix12tei5oMpD17SY1LoBWqxD1EHgV0ygTb9TV0IFZQtv4dAix378lb8WGv5GGPQIuyStX3gWqn0pjTTXbpIlA0VzYSeiGpkO5bsHhZfjbkR07/s16000/image6.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">VAE reconstruction. Our VAE decoders have better visual quality than SD (Stable Diffusion).</td></tr></tbody></table>
<br />
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="text-align: center;">
<tbody><tr>
<td style="text-align: left;"><b>Decoder</b>
</td>
<td><b> #Params (M) </b>
</td>
<td><b> PSNR↑ </b>
</td>
<td><b> SSIM↑ </b>
</td>
<td><b> LPIPS↓ </b>
</td>
</tr>
<tr>
<td style="text-align: left;"><b>SD</b>
</td>
<td>49.5
</td>
<td>26.7
</td>
<td>0.76
</td>
<td>0.037
</td>
</tr>
<tr>
<td style="text-align: left;"><b>Ours</b>
</td>
<td>39.3
</td>
<td>30.0
</td>
<td>0.83
</td>
<td>0.032
</td>
</tr>
<tr>
<td style="text-align: left;"><b>Ours-Lite </b>
</td>
<td>9.8
</td>
<td>30.2
</td>
<td>0.84
</td>
<td>0.032
</td>
</tr>
</tbody></table>
<br />
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td class="tr-caption" style="text-align: center;">Quality evaluation of VAE decoders. Our lite decoder is much smaller than SD, with better quality metrics, including <a href="https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio">peak signal-to-noise ratio</a> (PSNR), <a href="https://en.wikipedia.org/wiki/Structural_similarity">structural similarity index measure</a> (SSIM), and <a href="https://arxiv.org/abs/1801.03924">Learned Perceptual Image Patch Similarity</a> (LPIPS).</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h3>One-step sampling</h3>
<p>
In addition to optimizing the model architecture, we adopt a <a href="https://arxiv.org/abs/2311.09257">DiffusionGAN hybrid</a> to achieve one-step sampling. Training DiffusionGAN hybrid models for text-to-image generation involves several intricacies. Notably, the discriminator, a classifier distinguishing real data from generated data, must make judgments based on both texture and semantics. Moreover, the cost of training text-to-image models can be extremely high, particularly in the case of GAN-based models, where the discriminator introduces additional parameters. Purely GAN-based text-to-image models (e.g., <a href="https://arxiv.org/abs/2301.09515">StyleGAN-T</a>, <a href="https://arxiv.org/abs/2303.05511">GigaGAN</a>) confront similar complexities, resulting in highly intricate and expensive training.
</p>
<p>
To overcome these challenges, we use a pre-trained diffusion UNet to initialize the generator and discriminator. This design enables seamless initialization with the pre-trained diffusion model. We postulate that the internal features within the diffusion model contain rich information of the intricate interplay between textual and visual data. This initialization strategy significantly streamlines the training.
</p>
<p>
The figure below illustrates the training procedure. After initialization, a noisy image is sent to the generator for one-step diffusion. The result is evaluated against ground truth with a reconstruction loss, similar to diffusion model training. We then add noise to the output and send it to the discriminator, whose result is evaluated with a GAN loss, effectively adopting the GAN to model a denoising step. By using pre-trained weights to initialize the generator and the discriminator, the training becomes a fine-tuning process, which converges in less than 10K iterations.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjnK7SE2-cHSlP-PDmkl_xfjp3sP-kB41r6OvC8Wg6miXnYwdES0INwN19BHWQ_uyXtcBT-872U5J6jLY8yXVtA_W96qkRRPh6Pjvw0n-ZJvjJK91kYTh7H1n4nzy8z1TyrQZlZoZrQUDTo5Qm-6a_2vIVye3aqm7o32qOOXiWXwxDzw_J6cQsOrJ-UILKw/s960/image7.jpg" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="576" data-original-width="960" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjnK7SE2-cHSlP-PDmkl_xfjp3sP-kB41r6OvC8Wg6miXnYwdES0INwN19BHWQ_uyXtcBT-872U5J6jLY8yXVtA_W96qkRRPh6Pjvw0n-ZJvjJK91kYTh7H1n4nzy8z1TyrQZlZoZrQUDTo5Qm-6a_2vIVye3aqm7o32qOOXiWXwxDzw_J6cQsOrJ-UILKw/s16000/image7.jpg" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Illustration of DiffusionGAN fine-tuning.</td></tr></tbody></table>
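<p>
A simplified, framework-agnostic sketch of this fine-tuning step is shown below. The generator and discriminator callables, the omitted noise schedule, and the particular GAN loss are placeholders; they are not MobileDiffusion’s actual implementation.
</p>
<pre>
import numpy as np

def train_step(generator, discriminator, x_real, text_emb, rng):
    # One-step diffusion: noise the real image and ask the generator to denoise it.
    noise = rng.standard_normal(x_real.shape).astype(np.float32)
    x_noisy = x_real + noise                      # (noise schedule omitted for brevity)
    x_gen = generator(x_noisy, text_emb)

    # Reconstruction loss against the ground truth, as in diffusion model training.
    recon_loss = np.mean((x_gen - x_real) ** 2)

    # Re-noise the output and let the discriminator judge it, so the GAN models a denoising step.
    x_gen_noisy = x_gen + rng.standard_normal(x_gen.shape).astype(np.float32)
    d_fake = discriminator(x_gen_noisy, text_emb)
    gan_loss = np.mean(np.log1p(np.exp(-d_fake)))  # non-saturating generator loss

    return recon_loss + gan_loss
</pre>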
<div style="line-height: 40%;">
<br />
</div>
<h2>Results</h2>
<p>
Below we show example images generated by our MobileDiffusion with DiffusionGAN one-step sampling. With such a compact model (520M parameters in total), MobileDiffusion can generate high-quality diverse images for various domains.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiyDLq1NW7Qvy4_oEqg1pHAMzeBfuei3VadIKZRNkv6ZHnzewVWQU5x76e0bm-QqWVr-_q1W4axBJeyqyCbdRFoUFBYxRxDj3qo7I4-Du6TS2Bez_-mmXzYoHLJk7y5fiKl9PPkHNk_dsvy7ezuAFavW4sYIeYTxhAPAH35FYP5YOceS8NfJey0gpvHUwza/s1728/image1.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1296" data-original-width="1728" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiyDLq1NW7Qvy4_oEqg1pHAMzeBfuei3VadIKZRNkv6ZHnzewVWQU5x76e0bm-QqWVr-_q1W4axBJeyqyCbdRFoUFBYxRxDj3qo7I4-Du6TS2Bez_-mmXzYoHLJk7y5fiKl9PPkHNk_dsvy7ezuAFavW4sYIeYTxhAPAH35FYP5YOceS8NfJey0gpvHUwza/s16000/image1.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Images generated by our MobileDiffusion</td></tr></tbody></table>
<p>
We measured the performance of our MobileDiffusion on both iOS and Android devices, using different runtime optimizers. The latency numbers are reported below. We see that MobileDiffusion is very efficient and can run within half a second to generate a 512x512 image. This lightning speed potentially enables many interesting use cases on mobile devices.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiFkcI7kibwRFhpxTsVmUkAzK38MCeBoTR6fOWyhjnqwPm7x8TwrVn_O0OipsXCbgS4qTtcbtm41Fxi7U_IJjpeuZadWO7cBKkcdrXHniAJgQP4Qk-wOBfnhtwNPxDbzxtM0uxVba3BjwzLa3Lw13-03FoRQbWwf_25KR9GLLkSqIFpnU5aE-6hnomY5IuK/s1184/image8.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="742" data-original-width="1184" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiFkcI7kibwRFhpxTsVmUkAzK38MCeBoTR6fOWyhjnqwPm7x8TwrVn_O0OipsXCbgS4qTtcbtm41Fxi7U_IJjpeuZadWO7cBKkcdrXHniAJgQP4Qk-wOBfnhtwNPxDbzxtM0uxVba3BjwzLa3Lw13-03FoRQbWwf_25KR9GLLkSqIFpnU5aE-6hnomY5IuK/s16000/image8.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Latency measurements (<b>s</b>) on mobile devices.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>Conclusion</h2>
<p>
With superior efficiency in terms of latency and size, MobileDiffusion has the potential to be a very friendly option for mobile deployment, enabling a rapid image generation experience as users type text prompts. We will ensure that any application of this technology is in line with Google’s <a href="https://ai.google/responsibility/responsible-ai-practices/">responsible AI practices</a>.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgments</h2>
<p>
<em>We would like to thank our collaborators and contributors who helped bring MobileDiffusion on-device: Zhisheng Xiao, Yanwu Xu, Jiuqiang Tang, Haolin Jia, Lutz Justen, Daniel Fenner, Ronald Wotzlaw, Jianing Wei, Raman Sarokin, Juhyun Lee, Andrei Kulik, Chuo-Ling Chang, and Matthias Grundmann.</em>
</p>
Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-51449067291092534952024-01-26T11:56:00.000-08:002024-01-26T11:56:23.553-08:00Mixed-input matrix multiplication performance optimizations<span class="byline-author">Posted by Manish Gupta, Staff Software Engineer, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhEKJJf1R773hab0veY6zffF2Nf_yfV2mk8YU9yRnuBDD3ak1o0iXecWlJw2x7bL-Ez2MX1c21MXk65VMK5IsoLpJ1H6BTC6k7BvVWl_gHJpJIOG2cm3BwP4V-HCScGHYIynuskbhvu1uorQGprHGbOFmfGI7E5UWemJcZ0xSC3tC5DolBYgyBwugl6OOLr/s1180/matrixhero.png" style="display: none;" />
<p>
AI-driven technologies are weaving themselves into the fabric of our daily routines, with the potential to enhance our access to knowledge and boost our overall productivity. The backbone of these applications lies in large language models (LLMs). LLMs are memory-intensive and typically require specialized hardware accelerators to efficiently deliver <a href="https://cloud.google.com/blog/products/compute/the-worlds-largest-distributed-llm-training-job-on-tpu-v5e">tens of exaflops</a> of computing power. This blog post shows how we can start addressing the computational challenges by utilizing memory more effectively.
</p>
<a name='more'></a>
<p>
The bulk of an LLM’s memory and compute are consumed by <a href="https://arxiv.org/pdf/2005.14165.pdf">weights</a> in <a href="https://arxiv.org/pdf/2006.16668.pdf">matrix multiplication</a> operations. Using narrower <em><a href="https://en.wikipedia.org/wiki/Primitive_data_type">data types</a></em> reduces memory consumption. For example, storing weights in the 8-bit <a href="https://en.wikipedia.org/wiki/Integer_(computer_science)">integer</a> (i.e., U8 or S8) data type reduces the memory footprint by 4× relative to <a href="https://en.wikipedia.org/wiki/Single-precision_floating-point_format">single-precision</a> (F32) and 2× relative to <a href="https://en.wikipedia.org/wiki/Half-precision_floating-point_format">half-precision</a> (F16) or <a href="https://en.wikipedia.org/wiki/Bfloat16_floating-point_format">bfloat16</a> (BF16). Furthermore, <a href="https://arxiv.org/pdf/2206.01861.pdf">previous work has</a> shown that running an LLM’s matrix multiplications with <em>weights</em> in S8 and <em>input</em> in F16 (preserving the higher precision of the user input) is an effective method for increasing efficiency with acceptable trade-offs in accuracy. This technique is known as <em>weight-only quantization</em> and requires efficient implementation of matrix multiplication with <em>mixed-inputs</em>, e.g., half-precision input multiplied with 8-bit integer weights. Hardware accelerators, including GPUs, support a fixed set of data types, and thus, mixed-input matrix multiplication requires software transformations to map to the hardware operations.
</p>
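<p>
As a concrete (if simplified) illustration of weight-only quantization, the sketch below stores weights as signed 8-bit integers with a per-column scale and converts them back to F16 just before the matmul. Symmetric per-column scaling is an assumption for this sketch, not a description of any particular serving stack.
</p>
<pre>
import numpy as np

def quantize_weights(w_f32):
    # Per-column symmetric quantization to S8, with an F16 scale per column.
    scale = np.abs(w_f32).max(axis=0) / 127.0
    w_s8 = np.clip(np.round(w_f32 / scale), -127, 127).astype(np.int8)
    return w_s8, scale.astype(np.float16)

def mixed_input_matmul(x_f16, w_s8, scale_f16):
    w_f16 = w_s8.astype(np.float16) * scale_f16   # data type conversion done in software
    return x_f16 @ w_f16                          # maps onto an F16 x F16 matmul

x = np.random.randn(4, 64).astype(np.float16)
w = np.random.randn(64, 32).astype(np.float32)
print(mixed_input_matmul(x, *quantize_weights(w)).shape)   # (4, 32)
</pre>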
<p>
To that end, in this blog we focus on mapping mixed-input matrix multiplication onto the <a href="https://developer.nvidia.com/blog/nvidia-ampere-architecture-in-depth/">NVIDIA Ampere architecture</a>. We present software techniques addressing data type conversion and layout conformance to map mixed-input matrix multiplication efficiently onto hardware-supported data types and layouts. Our results show that the overhead of additional work in software is minimal and enables performance close to the peak hardware capabilities. The software techniques described here are released in the open-source <a href="https://github.com/NVIDIA/cutlass/pull/1084">NVIDIA/CUTLASS</a> repository.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgaLaSxuLbV_5ifXLyJsTGs0WLa23prrxrhX4IKSLZw5l3oSd2SPk5AgZtNgvUY_j-IbOyjttva-XIfkRr1cDBwCXghEz-3Q0G-6236m7_TIgTrm_K2UejYnTnhAEmZtKHq1mN9HKP0xxV8nqSxzTNHG1U0j-cVj236efpR7lSgmt082QEYNwKsGMTRiWZb/s1999/image3.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1159" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgaLaSxuLbV_5ifXLyJsTGs0WLa23prrxrhX4IKSLZw5l3oSd2SPk5AgZtNgvUY_j-IbOyjttva-XIfkRr1cDBwCXghEz-3Q0G-6236m7_TIgTrm_K2UejYnTnhAEmZtKHq1mN9HKP0xxV8nqSxzTNHG1U0j-cVj236efpR7lSgmt082QEYNwKsGMTRiWZb/s16000/image3.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Memory footprint for an 175B parameter LLM model with various data types formats.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>The matrix-multiply-accumulate operation</h2>
<p>
Modern AI hardware accelerators such as <a href="https://cloud.google.com/tpu/docs/intro-to-tpu#how_a_tpu_works">Google’s TPU</a> and <a href="https://www.nvidia.com/en-us/data-center/tensor-cores/">NVIDIA’s GPU</a> multiply matrices natively in the hardware by targeting Tensor Cores, which are specialized processing elements to accelerate matrix operations, particularly for AI workloads. In this blog, we focus on NVIDIA Ampere Tensor Cores, which provide the <em>matrix-multiply-accumulate</em> (<code><a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma">mma</a></code>) operation. For the rest of this blog, <span style="color: #54863f;"><code>mma</code></span> refers to the Ampere Tensor Core operation. The supported data types, shapes, and data layout of the two input matrices (called operands) for the <span style="color: #54863f;"><code>mma</code></span> operation are fixed in hardware. This means that matrix multiplications with various data types and larger shapes are implemented in the software by tiling the problem onto hardware-supported data types, shapes, and layouts.
</p>
<p>
The Tensor Core <span style="color: #54863f;"><code>mma</code></span> operation is defined by specifying two input matrices (e.g., <em>A</em> & <em>B</em>, shown below) to produce a result matrix, <em>C</em>. The <span style="color: #54863f;"><code>mma</code></span> operation natively supports mixed-precision. <em><a href="https://developer.nvidia.com/blog/programming-tensor-cores-cuda-9/">Mixed-precision Tensor Cores</a></em> allow mixing input (<em>A</em> and <em>B</em>) data type with the result (<em>C</em>) data type. In contrast, <em>mixed-input </em>matrix multiplication involves mixing the input data types, and it is not supported by the hardware, so it needs to be implemented in the software.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiS_vu1tTxHo9Gy6Mywfx1xbQ0G6XTpOOQ04-l-Nw_rM7qOAM9kXg_qDjIakIpx-IclRmfR96cTGGExo2k9fxnVdltW4I9nb7RHloRtqWFMFeOtZ68Yr5wve9uLTIsZKA3GxB_VaNo98Gfsa7zGGP0dCrjebZ0Fq1dutfoxoy25eByHXorHCwTTiqsFzw6M/s1039/image5.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="668" data-original-width="1039" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiS_vu1tTxHo9Gy6Mywfx1xbQ0G6XTpOOQ04-l-Nw_rM7qOAM9kXg_qDjIakIpx-IclRmfR96cTGGExo2k9fxnVdltW4I9nb7RHloRtqWFMFeOtZ68Yr5wve9uLTIsZKA3GxB_VaNo98Gfsa7zGGP0dCrjebZ0Fq1dutfoxoy25eByHXorHCwTTiqsFzw6M/s16000/image5.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Tensor Core operation of M-by-N-by-K on input matrix A of M-by-K and matrix B of K-by-N produces output matrix C of M-by-N.</td></tr></tbody></table>
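<p>
Numerically, the mixed-precision semantics amount to multiplying F16 operands while accumulating in F32. The small tile shape in the sketch below is chosen only to fix dimensions; the hardware instruction operates on fixed warp-level tile shapes held in registers.
</p>
<pre>
import numpy as np

M, N, K = 16, 8, 16        # a small warp-level tile shape, for illustration
A = np.random.randn(M, K).astype(np.float16)
B = np.random.randn(K, N).astype(np.float16)
C = A.astype(np.float32) @ B.astype(np.float32)   # F16 inputs, F32 accumulation
print(C.shape)   # (16, 8)
</pre>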
<div style="line-height: 40%;">
<br />
</div>
<h2>Challenges of mixed-input matrix multiplication</h2>
<p>
To simplify the discussion, we restrict to a specific example of mixed-input matrix multiplication: F16 for user input and U8 for the model weights (written as F16 * U8). The techniques described here work for various combinations of mixed-input data types.
</p>
<p>
A GPU programmer can access a <a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-hierarchy">hierarchy of memory</a>, including global memory, shared memory, and registers, which are arranged in order of decreasing capacity but increasing speed. NVIDIA Ampere Tensor Core <span style="color: #54863f;"><code>mma</code></span> operations consume input matrices from registers. Furthermore, input and output matrices are required to conform to a layout of data within a group of 32 threads known as a <em>warp</em>. The supported data type <em>and</em> layout within a warp are fixed for an <span style="color: #54863f;"><code>mma</code></span> operation, so to implement mixed-input multiplication efficiently, it is necessary to solve the challenges of data type conversion and layout conformance in software.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>Data type conversion </h3>
<p>
The <span style="color: #54863f;"><code>mma</code></span> operation requires two input matrices with the same data type. Thus, mixed-input matrix multiplication, where one of the operands is stored in U8 in global memory and the other in F16, requires a data type conversion from U8 to F16. The conversion brings both operands to F16, mapping the <em>mixed-input</em> matrix multiplication to the hardware-supported <em>mixed-precision</em> Tensor Cores. Given the large number of weights, there are a large number of such operations, and our techniques show how to reduce their latency and improve performance.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>Layout conformance </h3>
<p>
The <span style="color: #54863f;"><code>mma</code></span> operation also requires the layout of the two input matrices, within the registers of a warp, to be conformant with the hardware specification. The layout of the input matrix <em>B</em> of U8 data type in mixed-input matrix multiplication (F16 * U8) needs to conform with the converted F16 data type. This is called <em>layout conformance</em> and needs to be achieved in the software.
</p>
<p>
The figure below shows an <span style="color: #54863f;"><code>mma</code></span> operation consuming matrix <em>A</em> and matrix <em>B</em> from registers to produce matrix <em>C</em> in registers, distributed across one warp. The thread <em>T0</em> is highlighted and zoomed in to show the weight matrix <em>B</em> goes through data type conversion and needs a layout conformance to be able to map to the hardware-supported Tensor Core operation.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiMMvieW8Uyta8c4afsNM7SgyZtlB2ra7G7aBG4z7D73rn-T7NHge0J1zfK7A_edL9tsQIthWVtEd0hZmwAjfO5C-XM6d5hNkv8IEBlpRxHilOxFgjYi27qauWFAQTl5wV8ixQ9MrfvqpuEQrdFuqDtjPJESG795s6cH3FlPJIVS4TuvKo0gmd8L1HwOJ_6/s1999/image4.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1240" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiMMvieW8Uyta8c4afsNM7SgyZtlB2ra7G7aBG4z7D73rn-T7NHge0J1zfK7A_edL9tsQIthWVtEd0hZmwAjfO5C-XM6d5hNkv8IEBlpRxHilOxFgjYi27qauWFAQTl5wV8ixQ9MrfvqpuEQrdFuqDtjPJESG795s6cH3FlPJIVS4TuvKo0gmd8L1HwOJ_6/s16000/image4.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">The mapping of mixed-input (F32 = F16 * U8) operation in software to natively supported warp-level Tensor Cores in hardware (F32 = F16 * F16). (Original figure source <a href="https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/">Developing CUDA kernels to push Tensor Cores to the Absolute Limit on NVIDIA A100</a>.)</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>Software strategies addressing challenges</h2>
<p>
A typical data type conversion involves a sequence of operations on 32-bit registers, shown below. Each rectangular block represents a register and the adjoining text are the operations. The entire sequence shows the conversion from 4xU8 to 2x(2xF16). The sequence involves roughly 10 operations.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiyJ4C214tiBhdjds0fWCV9EWh8X_UEDQlFqkpeoo6CZR3QMMrWyqi5mfRjvHLtbHH55J4hM5oRxe0HouGnbE3KuPbmh8MKk-TtDMMZv1YMKPv-Q4gYAr5l3ZXdTIPUHKs7f8wfCgr3XPe6_jUO7u12pGEmZVFiAGn_LCOlUlQQRSF7_r7jlOrPJW9Oc4V1/s947/image1.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="836" data-original-width="947" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiyJ4C214tiBhdjds0fWCV9EWh8X_UEDQlFqkpeoo6CZR3QMMrWyqi5mfRjvHLtbHH55J4hM5oRxe0HouGnbE3KuPbmh8MKk-TtDMMZv1YMKPv-Q4gYAr5l3ZXdTIPUHKs7f8wfCgr3XPe6_jUO7u12pGEmZVFiAGn_LCOlUlQQRSF7_r7jlOrPJW9Oc4V1/s16000/image1.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><code><a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L760">NumericArrayConvertor</a></code> from 4xU8 to 2x(2xF16) in 32-bit registers.</td></tr></tbody></table>
<p>
There are many ways of achieving layout conformance. Two of the existing solutions are:
</p>
<ol>
<li><em>Narrower bitwidth shared memory loads</em>: In this approach, threads issue narrow bitwidth memory loads moving the U8 data from shared memory to registers. This results in <em>two</em> 32-bit registers, with each register containing 2xF16 values (shown above for the matrix <em>B</em>’s thread <em>T0</em>). The narrower shared memory load achieves layout conformance directly into registers without needing any shuffles; however, it does not utilize the full shared memory bandwidth.
</li><li><em>Pre-processing in global memory</em>: An <a href="https://arxiv.org/pdf/2211.10017.pdf">alternative strategy</a> involves rearranging the data within the global memory (one level above the shared memory in <a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-hierarchy">memory hierarchy</a>), allowing wider shared memory loads. This approach maximizes the shared memory bandwidth utilization and ensures that the data is loaded in a conformant layout directly in the registers. Although the rearrangement process can be executed offline prior to the LLM deployment, ensuring no impact on the application performance, it introduces an additional, non-trivial hardware-specific pre-processing step that requires an extra program to rearrange the data. <a href="https://github.com/NVIDIA/FasterTransformer">NVIDIA/FasterTransformer</a> adopts this method to effectively address layout conformance challenges.
</li>
</ol>
<div style="line-height: 40%;">
<br />
</div>
<h2>Optimized software strategies</h2>
<p>
To further optimize and reduce the overhead of data type conversion and layout conformance, we have implemented <code><a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L2514">FastNumericArrayConvertor</a></code> and <code><a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h#L120">FragmentShuffler</a></code>, respectively.
</p><p>
<code>FastNumericArrayConvertor</code> operates on 4xU8 in 32-bit registers without unpacking individual 1xU8 values. Furthermore, it uses less expensive arithmetic operations, which reduces the number of instructions and increases the speed of the conversion.
</p>
<p>
The conversion sequence for <a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L2514">U8-to-F16</a> is shown below. The operations use packed 32b registers, avoiding explicit unpacking and packing. <code>FastNumericArrayConvertor</code> uses the <code><a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt">permute byte</a></code> to rearrange bytes of 4xU8 into two registers. Additionally, <code>FastNumericArrayConvertor</code> does not use expensive integer to floating-point conversion instructions and employs vectorized operations to obtain the packed results in <em>two</em> 32-bit registers containing 2x(2xF16) values. The <code>FastNumericArrayConvertor</code> for U8-to-F16 approximately uses six operations, a 1.6× reduction relative to the approach shown above.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgRhtLljZ8wfnfnyXQsYZlNMDZ-cUqCV7wPvGimtPtU3JcKJLv6lCDT_PfBBmyp0TuHRgFIZ2cbgEDeL5bqke4FGUcpGMbAhcIBJxQcpcuWZIlqG1yXOHPf5BivF26_qlDnR9W2Y3RVE36ZB7rEGZO3x2Xva7-rqBZkoI7l4gnzBWLYfIrmhFBNN8DpaoEA/s1392/image201.png" imageanchor="1" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="733" data-original-width="1392" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgRhtLljZ8wfnfnyXQsYZlNMDZ-cUqCV7wPvGimtPtU3JcKJLv6lCDT_PfBBmyp0TuHRgFIZ2cbgEDeL5bqke4FGUcpGMbAhcIBJxQcpcuWZIlqG1yXOHPf5BivF26_qlDnR9W2Y3RVE36ZB7rEGZO3x2Xva7-rqBZkoI7l4gnzBWLYfIrmhFBNN8DpaoEA/s16000/image201.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><code>FastNumericArrayConvertor</code> utilizes <code><a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt">permute bytes</a></code> and packed arithmetic, reducing the number of instructions in the data type conversion.</td></tr></tbody></table>
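<p>
One classic way to avoid integer-to-floating-point conversion instructions, in the spirit of the approach described above though not a line-by-line translation of <code>FastNumericArrayConvertor</code>, is to place each U8 value into the mantissa bits of an F16 number whose bit pattern encodes 1024.0 and then subtract 1024. The numpy sketch below illustrates the arithmetic only; the actual CUTLASS code performs the byte rearrangement with the <code>permute byte</code> instruction and packed vectorized operations, as described above.
</p>
<pre>
import numpy as np

def u8_to_f16_bit_trick(u8):
    # 0x6400 is the FP16 bit pattern for 1024.0; OR-ing a U8 value v into the
    # mantissa yields an FP16 number equal to 1024 + v, so subtracting 1024
    # recovers v exactly (all values 0..255 are representable without rounding).
    bits = np.uint16(0x6400) | u8.astype(np.uint16)
    return bits.view(np.float16) - np.float16(1024.0)

v = np.arange(256, dtype=np.uint8)
assert np.array_equal(u8_to_f16_bit_trick(v), v.astype(np.float16))
</pre>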
<p>
<code>FragmentShuffler</code> handles the layout conformance by shuffling data in a way that allows the use of wider bitwidth load operation, increasing shared memory bandwidth utilization and reducing the total number of operations.
</p>
<p>
NVIDIA Ampere architecture provides a load matrix instruction (<code><a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix">ldmatrix</a></code>). The <span style="color: #54863f;"><code>ldmatrix</code></span> is a warp-level operation, where 32 threads of a warp move the data from shared memory to registers in the <em>shape</em> and <em>layout</em> that <span style="color: #54863f;"><code>mma</code></span> matrix <em>A</em> and <em>B</em> consume. The use of <span style="color: #54863f;"><code>ldmatrix</code></span> <em>reduces</em> the number of load instructions and <em>increases</em> the memory bandwidth utilization. Since the <span style="color: #54863f;"><code>ldmatrix</code></span> instruction moves U8 data to registers, the layout after the load conforms with U8*U8 <span style="color: #54863f;"><code>mma</code></span> operation, and not with F16*F16 <span style="color: #54863f;"><code>mma</code></span> operation. We implemented <code>FragmentShuffler</code> to rearrange the data within registers using shuffle (<code><a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-shuffle-functions">shfl.sync</a>)</code> operations to achieve the layout conformance.
</p><p>
The most significant contribution of this work is to achieve layout conformance through register shuffles, avoiding offline pre-processing in global memory or narrower bitwidth shared memory loads. Furthermore, we provide implementations for <code>FastNumericArrayConvertor</code> covering data type conversion from <a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L2514">U8-to-F16</a>, <a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L2448">S8-to-F16</a>, <a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L2546">U8-to-BF16</a>, and <a href="https://github.com/NVIDIA/cutlass/blob/757275f2796bb901575c633e2a32bc76ca84ffec/include/cutlass/numeric_conversion.h#L2588">S8-to-BF16</a>.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Performance results</h2>
<p>
We measured the performance of eight mixed-input variants of <em>our method</em> (shown below in blue and red; varying the data types of matrix <em>A</em> and <em>B</em>) and two <em>mixed-precision</em> data types (shown in green) on an NVIDIA A100 SXM chip. The performance results are shown in <a href="https://en.wikipedia.org/wiki/FLOPS">FLOPS</a> (higher is better). Notably, the first eight matrix multiplications require additional operations relative to the last two, because the mixed-precision variants directly target hardware-accelerated Tensor Core operations and do not need data type conversion and layout conformance. Even so, our approach demonstrates mixed-input matrix multiplication performance only slightly below or on par with mixed-precision.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg-Dq_2LmFUlg0KlNIJvFufCUMZujNc9LcoMnSURpGQwGbM75vXuS-Nm9ZH-7ItgWmZaBSUS3yawN0u3K21tbWTdijU4fVNgEyS33jOztyGfvNvLEw6IBiJO3JSmpctQtN8tvZmagEYQNSP3mmBQnXJ8GeNlQymbeqrKjFycjkKnHL_5FC8V6WR858byfm_/s1999/image2.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1180" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg-Dq_2LmFUlg0KlNIJvFufCUMZujNc9LcoMnSURpGQwGbM75vXuS-Nm9ZH-7ItgWmZaBSUS3yawN0u3K21tbWTdijU4fVNgEyS33jOztyGfvNvLEw6IBiJO3JSmpctQtN8tvZmagEYQNSP3mmBQnXJ8GeNlQymbeqrKjFycjkKnHL_5FC8V6WR858byfm_/s16000/image2.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Mixed-input matrix multiplication performance on NVIDIA A100 40GB SMX4 chip for a compute-bound matrix problem shape <code>m=3456, n=4096, k=2048.</code></td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgements</h2>
<p>
<em>We would like to mention several folks who have contributed through technical brainstorming and improving the blog post including, Quentin Colombet, Jacques Pienaar, Allie Culp, Calin Cascaval, Ashish Gondimalla, Matt Walsh, Marek Kolodziej, and Aman Bhatia. We would like to thank our NVIDIA partners Rawn Henry, Pradeep Ramani, Vijay Thakkar, Haicheng Wu, Andrew Kerr, Matthew Nicely, and Vartika Singh.</em>
</p>
Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-14187365826019400762024-01-23T14:27:00.000-08:002024-01-23T14:27:09.785-08:00Exphormer: Scaling transformers for graph-structured data<span class="byline-author">Posted by Ameya Velingker, Research Scientist, Google Research, and Balaji Venkatachalam, Software Engineer, Google</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhbovKreBr7RlKc4L36E6rLqiZBZzJSq5GLijCkomHREon5tYXd-7C2pppMXnL5Mj2d82kZGnPlarrrMzQOfRnN8kVvqDh1GnadIJ-hbaaS8VjYzCpaD-DgYor5cKx-OhTGZk9iCy5MjtwG2Q9eTyQiipDr5ViMdl2vkxfbLzWnB3wmLb8YfvVsTJ1FnOmw/s1600/EXPHORMER%2005large.gif" style="display: none;" />
<p>
<a href="https://en.wikipedia.org/wiki/Graph_(discrete_mathematics)">Graphs</a>, in which objects and their relations are represented as nodes (or vertices) and edges (or links) between pairs of nodes, are ubiquitous in computing and machine learning (ML). For example, social networks, road networks, and molecular structure and interactions are all domains in which underlying datasets have a natural graph structure. ML can be used to learn the properties of nodes, edges, or entire graphs.
</p>
<a name='more'></a>
<p>
A common approach to learning on graphs is to use <a href="https://distill.pub/2021/gnn-intro/">graph neural networks</a> (GNNs), which operate on graph data by applying an optimizable transformation on node, edge, and global attributes. The most typical class of GNNs operates via a <a href="https://wandb.ai/graph-neural-networks/spatial/reports/An-Introduction-to-Message-Passing-Graph-Neural-Networks--VmlldzoyMDI2NTg2">message-passing</a> framework, whereby each layer aggregates the representation of a node with those of its immediate neighbors.
</p>
<p>
Recently, <a href="https://arxiv.org/abs/2012.09699">graph transformer models</a> have emerged as a popular alternative to message-passing GNNs. These models build on the success of <a href="https://en.wikipedia.org/wiki/Transformer_(machine-learning_model)">Transformer architectures</a> in natural language processing (NLP), adapting them to graph-structured data. The attention mechanism in graph transformers can be modeled by an interaction graph, in which edges represent pairs of nodes that attend to each other. Unlike message passing architectures, graph transformers have an interaction graph that is separate from the input graph. The typical interaction graph is a complete graph, which signifies a full attention mechanism<em> </em>that models direct interactions between all pairs of nodes. However, this creates quadratic computational and memory bottlenecks that limit the applicability of graph transformers to datasets on small graphs with at most a few thousand nodes. Making graph transformers scalable has been considered one of the most important research directions in the field (see <a href="https://towardsdatascience.com/graph-ml-in-2022-where-are-we-now-f7f8242599e0">the first open problem here</a>).
</p>
<p>
A natural remedy is to use a <em>sparse</em> interaction graph with fewer edges. <a href="https://dl.acm.org/doi/10.1145/3530811">Many sparse and efficient transformers have been proposed</a> to eliminate the quadratic bottleneck for sequences; however, they do not generally extend to graphs in a principled manner.
</p>
<p>
In “<a href="https://arxiv.org/abs/2303.06147">Exphormer: Sparse Transformers for Graphs</a>”, presented at <a href="https://icml.cc/Conferences/2023/Dates">ICML 2023</a>, we address the scalability challenge by introducing a sparse attention framework for transformers that is designed specifically for graph data. The Exphormer framework makes use of expander graphs, a powerful tool from <a href="https://en.wikipedia.org/wiki/Spectral_graph_theory">spectral graph theory</a>, and is able to achieve strong empirical results on a wide variety of datasets. Our implementation of Exphormer is now available on <a href="https://github.com/hamed1375/Exphormer">GitHub</a>.
</p>
<br />
<h2>Expander graphs</h2>
<p>
A key idea at the heart of Exphormer is the use of <a href="https://en.wikipedia.org/wiki/Expander_graph">expander graphs</a>, which are sparse yet well-connected graphs that have two useful properties — 1) the matrix representation of the graph has linear-algebraic properties similar to those of a complete graph, and 2) they exhibit rapid mixing of random walks, i.e., a small number of steps in a random walk from any starting node is enough to ensure convergence to a “stable” distribution on the nodes of the graph. Expanders have found applications in diverse areas, such as algorithms, pseudorandomness, complexity theory, and error-correcting codes.
</p>
<p>
A common class of expander graphs are <em>d</em>-regular expanders, in which there are <em>d</em> edges from every node (i.e., every node has degree <em>d</em>). The quality of an expander graph is measured by its <em>spectral gap</em>, an algebraic property of its <a href="https://en.wikipedia.org/wiki/Adjacency_matrix">adjacency matrix</a> (a matrix representation of the graph in which rows and columns are indexed by nodes and entries indicate whether pairs of nodes are connected by an edge). Those that maximize the spectral gap are known as <a href="https://en.wikipedia.org/wiki/Ramanujan_graph">Ramanujan graphs</a> — they achieve a gap of <em>d</em> - 2*√(<em>d</em>-1), which is essentially the best possible among <em>d</em>-regular graphs. A number of deterministic and randomized constructions of Ramanujan graphs have been proposed over the years for various values of <em>d</em>. We use a <a href="https://arxiv.org/abs/cs/0405020">randomized expander construction of Friedman</a>, which produces near-Ramanujan graphs.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg495FZQZ12yMiNhU8C7XUKEJ88H5_v2PPrzhwcDOVnSaVEtdCXaL7py-LzwZZkybKwIaePLHKpdmD6qALfskdjeaA8ML9QYHMwWkxz2ZnhWYqoV1PpnNgbRRfm0pSVYJVrtUpONyyF5PfswJ_QoxD-9vI9F3rF6VQbIRDDIbgvOFc35vTEF9uxizKNpli9/s843/image1.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="843" data-original-width="800" height="320" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg495FZQZ12yMiNhU8C7XUKEJ88H5_v2PPrzhwcDOVnSaVEtdCXaL7py-LzwZZkybKwIaePLHKpdmD6qALfskdjeaA8ML9QYHMwWkxz2ZnhWYqoV1PpnNgbRRfm0pSVYJVrtUpONyyF5PfswJ_QoxD-9vI9F3rF6VQbIRDDIbgvOFc35vTEF9uxizKNpli9/s320/image1.gif" width="304" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><span id="docs-internal-guid-2920b38b-7fff-2fa8-a3cd-06dfd3ba9968"><span face="Arial, sans-serif" style="font-size: 10pt; font-style: italic; font-variant-alternates: normal; font-variant-east-asian: normal; font-variant-numeric: normal; font-variant-position: normal; vertical-align: baseline; white-space-collapse: preserve;">Expander graphs are at the heart of Exphormer. A good expander is sparse yet exhibits rapid mixing of random walks, making its global connectivity suitable for an interaction graph in a graph transformer model.</span></span></td></tr></tbody></table>
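<p>
As a quick sanity check of this property (a sketch of our own, using a generic random regular graph from NetworkX rather than the Friedman construction referenced above), one can sample a <em>d</em>-regular graph and compare its spectral gap to the Ramanujan bound <em>d</em> - 2√(<em>d</em>-1):
</p>
<pre>
# Sample a random d-regular graph and measure its spectral gap, i.e., d minus the
# largest non-trivial eigenvalue magnitude of its adjacency matrix. Random regular
# graphs are near-Ramanujan with high probability.
import numpy as np
import networkx as nx

def spectral_gap(n: int, d: int, seed: int = 0) -> float:
    g = nx.random_regular_graph(d, n, seed=seed)
    eigs = np.sort(np.abs(np.linalg.eigvalsh(nx.to_numpy_array(g))))
    lam = eigs[-2]  # eigs[-1] equals d (the trivial eigenvalue) for a connected d-regular graph
    return d - lam

n, d = 1000, 4
print(f"spectral gap: {spectral_gap(n, d):.3f}  (Ramanujan bound: {d - 2 * np.sqrt(d - 1):.3f})")
</pre>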
<p>Exphormer replaces the dense, fully-connected interaction graph of a standard Transformer with the edges of a sparse <em>d</em>-regular expander graph. Intuitively, the spectral approximation and mixing properties of an expander graph allow distant nodes to communicate with each other once multiple attention layers are stacked in a graph transformer architecture, even though the nodes may not attend to each other directly. Furthermore, by ensuring that <em>d</em> is constant (independent of the number of nodes), we obtain a linear number of edges in the resulting interaction graph.</p>
<br />
<h2>Exphormer: Constructing a sparse interaction graph</h2>
<p>
Exphormer combines expander edges with the input graph and virtual nodes. More specifically, the sparse attention mechanism of Exphormer builds an interaction graph consisting of three types of edges:
</p>
<ul>
<li>Edges from the input graph (<em>local attention</em>)
</li><li>Edges from a constant-degree expander graph (<em>expander attention</em>)
</li><li>Edges from every node to a small set of virtual nodes (<em>global attention</em>)
</li>
</ul>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiS7VdL6OcWCmXd-wTtx-qs_nA7qTYZFJOTHS7RZNS3Io_w4km3NM4opPsQBXu1u50KjDA43CsG0hoi1l7I9gq_KGBMvwKEjlWQKBzCeytLQHujF-4K4r9E4F4Q0APvw7le4twjGbDyEiVfEzhbsovhzk2_g4Xd4jwCo66HW7xbnLvm3WPBsHaoq-hDAYX8/s800/image1.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="430" data-original-width="800" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiS7VdL6OcWCmXd-wTtx-qs_nA7qTYZFJOTHS7RZNS3Io_w4km3NM4opPsQBXu1u50KjDA43CsG0hoi1l7I9gq_KGBMvwKEjlWQKBzCeytLQHujF-4K4r9E4F4Q0APvw7le4twjGbDyEiVfEzhbsovhzk2_g4Xd4jwCo66HW7xbnLvm3WPBsHaoq-hDAYX8/s16000/image1.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><span id="docs-internal-guid-ac11d16d-7fff-62da-cf18-7ba830f677d3"><span face="Arial, sans-serif" style="font-size: 10pt; font-style: italic; font-variant-alternates: normal; font-variant-east-asian: normal; font-variant-numeric: normal; font-variant-position: normal; vertical-align: baseline; white-space-collapse: preserve;">Exphormer builds an interaction graph by combining three types of edges. The resulting graph has good connectivity properties and retains the inductive bias of the input dataset graph while still remaining sparse.</span></span></td></tr></tbody></table>
<p>
Each component serves a specific purpose: the edges from the input graph retain the inductive bias from the input graph structure (which typically gets lost in a fully-connected attention module). Meanwhile, expander edges allow good global connectivity and random walk mixing properties (which spectrally approximate the complete graph with far fewer edges). Finally, virtual nodes serve as global “memory sinks” that can directly communicate with every node. While this results in additional edges from each virtual node equal to the number of nodes in the input graph, the resulting graph is still sparse. The degree of the expander graph and the number of virtual nodes are hyperparameters to tune for improving the quality metrics.
</p>
<p>
Furthermore, since we use an expander graph of constant degree and a small constant number of virtual nodes for the global attention, the resulting sparse attention mechanism is linear in the size of the original input graph, i.e., it models a number of direct interactions on the order of the total number of nodes and edges.
</p>
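<p>
The sketch below (our own illustration with hypothetical helper names, not the authors' implementation) shows how such an interaction graph could be assembled from the three edge types, and why its size stays linear in the number of nodes:
</p>
<pre>
# Assemble Exphormer-style interaction edges: input-graph edges (local attention),
# constant-degree expander edges (expander attention), and edges from a few virtual
# nodes to every node (global attention).
import networkx as nx

def exphormer_interaction_edges(input_graph, expander_degree=4, num_virtual_nodes=1, seed=0):
    nodes = list(input_graph.nodes())
    n = len(nodes)

    local = set(input_graph.edges())
    expander = nx.random_regular_graph(expander_degree, n, seed=seed).edges()
    expander = {(nodes[u], nodes[v]) for u, v in expander}  # map 0..n-1 back to input node ids
    virtual = {(f"virtual_{i}", v) for i in range(num_virtual_nodes) for v in nodes}

    return local | expander | virtual

g = nx.cycle_graph(1000)
edges = exphormer_interaction_edges(g)
print(len(edges))  # roughly 1000 local + 2000 expander + 1000 virtual edges: linear in n
</pre>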
<p>
We additionally show that Exphormer is as expressive as the dense transformer and obeys universal approximation properties. In particular, when the sparse attention graph of Exphormer is augmented with self loops (edges connecting a node to itself), it can universally approximate continuous functions [<a href="https://arxiv.org/abs/1912.10077">1</a>, <a href="https://arxiv.org/abs/2006.04862">2</a>].
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>Relation to sparse Transformers for sequences</h3>
<p>
It is interesting to compare Exphormer to sparse attention methods for sequences. Perhaps the architecture most conceptually similar to our approach is <a href="https://blog.research.google/2021/03/constructing-transformers-for-longer.html">BigBird</a>, which builds an interaction graph by combining different components. BigBird also uses virtual nodes, but, unlike Exphormer, it uses window attention and random attention from an <a href="https://en.wikipedia.org/wiki/Erd%C5%91s%E2%80%93R%C3%A9nyi_model">Erdős-Rényi</a> random graph model for the remaining components.
</p>
<p>
Window attention in BigBird looks at the tokens surrounding a token in a sequence — the local neighborhood attention in Exphormer can be viewed as a generalization of window attention to graphs.
</p>
<p>
The Erdős-Rényi graph on <em>n</em> nodes, <em>G(n, p)</em>, which connects every pair of nodes independently with probability <em>p</em>, also functions as an expander graph for suitably high <em>p</em>. However, a superlinear number of edges (Ω(<em>n</em> log <em>n</em>)) is needed to ensure that an Erdős-Rényi graph is connected, let alone a good expander. On the other hand, the expanders used in Exphormer have only a <em>linear</em> number of edges.
</p>
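<p>
A back-of-the-envelope comparison (ours, not from the paper) makes the difference concrete: at roughly the connectivity threshold <em>p</em> = log(<em>n</em>)/<em>n</em>, an Erdős-Rényi graph has about (<em>n</em> log <em>n</em>)/2 edges in expectation, while a constant-degree expander keeps <em>nd</em>/2 edges.
</p>
<pre>
# Expected edge counts: Erdős-Rényi at the connectivity threshold vs. a d-regular expander.
import math

d = 4  # constant expander degree
for n in (1_000, 10_000, 100_000):
    er_edges = 0.5 * n * math.log(n)   # expected edges of G(n, p) at p = log(n)/n
    expander_edges = n * d / 2
    print(f"n={n:>7}: Erdos-Renyi about {er_edges:,.0f} edges, "
          f"{d}-regular expander {expander_edges:,.0f} edges")
</pre>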
<br />
<h2>Experimental results</h2>
<p>
Earlier works have applied full graph Transformer-based models only to datasets with graphs of up to 5,000 nodes. To evaluate the performance of Exphormer, we build upon the celebrated <a href="https://github.com/rampasek/GraphGPS">GraphGPS framework</a> [<a href="https://arxiv.org/abs/2205.12454">3</a>], which combines both message passing and graph transformers and achieves state-of-the-art performance on a number of datasets. We show that replacing dense attention with Exphormer for the graph attention component in the GraphGPS framework allows one to achieve models with comparable or better performance, often with fewer trainable parameters.
</p>
<p>
Furthermore, Exphormer notably allows graph transformer architectures to scale well beyond the usual graph size limits mentioned above. Exphormer can scale up to datasets of 10,000+ node graphs, such as the <a href="https://arxiv.org/abs/1811.05868">Coauthor dataset</a>, and even beyond to larger graphs such as the well-known <a href="https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv">ogbn-arxiv dataset</a>, a citation network, which consists of 170K nodes and 1.1 million edges.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi-HJWH6mqX6N9ytZPbz6wawfMLzF2ey50Ot2BcowvPbQ3FaNwhlEZ3htvDbhq1C6ckLykf0yk3A1sIG0aPGaT8G_aSLj_A-AOfl8NIZdygdkn0C26RzZS9d-9KjyP1f_Zy7suN-iqvYR4zSCgqCXrhP8hVIirUgi6VGEBGx9I_AZikzc_ACKskBMBMPoSw/s1600/ExphormerPerformance.png" style="display: block; margin-left: auto; margin-right: auto; padding: 1em 0px; text-align: center;"><img alt="" border="0" data-original-height="190" data-original-width="1522" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi-HJWH6mqX6N9ytZPbz6wawfMLzF2ey50Ot2BcowvPbQ3FaNwhlEZ3htvDbhq1C6ckLykf0yk3A1sIG0aPGaT8G_aSLj_A-AOfl8NIZdygdkn0C26RzZS9d-9KjyP1f_Zy7suN-iqvYR4zSCgqCXrhP8hVIirUgi6VGEBGx9I_AZikzc_ACKskBMBMPoSw/s1600/ExphormerPerformance.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Results comparing Exphormer to standard GraphGPS on the five <a href="https://arxiv.org/abs/2206.08164">Long Range Graph Benchmark</a> datasets. We note that Exphormer achieved state-of-the-art results on four of the five datasets (PascalVOC-SP, COCO-SP, Peptides-Struct, PCQM-Contact) at the time of the paper’s publication.</td></tr></tbody></table>
<p>
Finally, we observe that Exphormer, which creates an overlay graph of small diameter via expanders, exhibits the ability to effectively learn long-range dependencies. The <a href="https://arxiv.org/abs/2206.08164">Long Range Graph Benchmark</a> is a suite of five graph learning datasets designed to measure the ability of models to capture long-range interactions. Results show that Exphormer-based models outperform standard GraphGPS models (which were previously state-of-the-art on four out of five datasets at the time of publication).
</p>
<br />
<h2>Conclusion</h2>
<p>
Graph transformers have emerged as an important architecture for ML that adapts the highly successful sequence-based transformers used in NLP to graph-structured data. Scalability has, however, proven to be a major challenge in enabling the use of graph transformers on datasets with large graphs. In this post, we have presented Exphormer, a sparse attention framework that uses expander graphs to improve scalability of graph transformers. Exphormer is shown to have important theoretical properties and exhibit strong empirical performance, particularly on datasets where it is crucial to learn long range dependencies. For more information, we point the reader to a short presentation <a href="https://icml.cc/virtual/2023/poster/23782">video</a> from ICML 2023.
</p>
<br />
<h2>Acknowledgements</h2>
<p>
<em>We thank our research collaborators Hamed Shirzad and Danica J. Sutherland from The University of British Columbia as well as Ali Kemal Sinop from Google Research. Special thanks to Tom Small for creating the animation used in this post.</em>
</p>Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-2208495499867096932024-01-18T10:03:00.000-08:002024-01-18T10:03:43.608-08:00Introducing ASPIRE for selective prediction in LLMs<span class="byline-author">Posted by Jiefeng Chen, Student Researcher, and Jinsung Yoon, Research Scientist, Cloud AI Team</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiMaqP9dd9YDXh04PEWHSFquEToz6U5M4YUxzfExokjfteUfuGAKhqs1LV5DUMJOBoiF3GjGxg7NqezNazuTeWePMsuH_OW7NM4z4ooMPhWnR22iyzENgpmG2-xJDbRbeeyyLbG-3dIdgYjl2IxX0K-bFvpbrAJsQA7Mu70MqxEuVFJXvwnP_-o4sPK8wYe/s320/ASPIRE%20hero.jpg" style="display: none;" />
<p>
In the fast-evolving landscape of artificial intelligence, large language models (LLMs) have revolutionized the way we interact with machines, pushing the boundaries of natural language understanding and generation to unprecedented heights. Yet, the leap into high-stakes decision-making applications remains a chasm too wide, primarily due to the inherent uncertainty of model predictions. Traditional LLMs generate responses recursively, yet they lack an intrinsic mechanism to assign a confidence score to these responses. Although one can derive a confidence score by summing up the probabilities of individual tokens in the sequence, traditional approaches typically fall short in reliably distinguishing between correct and incorrect answers. But what if LLMs could gauge their own confidence and only make predictions when they're sure?
</p>
<a name='more'></a>
<p>
<a href="https://papers.nips.cc/paper_files/paper/2017/hash/4a8423d5e91fda00bb7e46540e2b0cf1-Abstract.html">Selective prediction</a> aims to do this by enabling LLMs to output an answer along with a selection score, which indicates the probability that the answer is correct. With selective prediction, one can better understand the reliability of LLMs deployed in a variety of applications. Prior research, such as <a href="https://openreview.net/pdf?id=VD-AYtP0dve">semantic uncertainty</a> and <a href="https://arxiv.org/pdf/2207.05221.pdf">self-evaluation</a>, has attempted to enable selective prediction in LLMs. A typical approach is to use heuristic prompts like “Is the proposed answer True or False?” to trigger self-evaluation in LLMs. However, this approach may not work well on challenging <a href="https://en.wikipedia.org/wiki/Question_answering">question answering</a> (QA) tasks.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjIwlr34r2t085uhaOx2IbBulTJX2FB2g8LkhTgrKgycgb8cDZaRuht0cFPqmlgSkT5jHOx-rrWywmYAEEfJ0FxlC7ammU8ewrZaVo_My7cCpBNYlfgERRKgFYnF-8LhsWhcyS3KTFdBnhyphenhyphencrenwQxBkbjM8UriPKzji8zDkXYv-5rhRXiE0SlvGmXUV-jt/s1999/image2.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1534" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjIwlr34r2t085uhaOx2IbBulTJX2FB2g8LkhTgrKgycgb8cDZaRuht0cFPqmlgSkT5jHOx-rrWywmYAEEfJ0FxlC7ammU8ewrZaVo_My7cCpBNYlfgERRKgFYnF-8LhsWhcyS3KTFdBnhyphenhyphencrenwQxBkbjM8UriPKzji8zDkXYv-5rhRXiE0SlvGmXUV-jt/s16000/image2.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">The <a href="https://arxiv.org/abs/2205.01068">OPT-2.7B</a> model incorrectly answers a question from the <a href="https://aclanthology.org/P17-1147/">TriviaQA</a> dataset: “Which vitamin helps regulate blood clotting?” with “Vitamin C”. Without selective prediction, LLMs may output the wrong answer which, in this case, could lead users to take the wrong vitamin. With selective prediction, LLMs will output an answer along with a selection score. If the selection score is low (0.1), LLMs will further output “I don’t know!” to warn users not to trust it or verify it using other sources.</td></tr></tbody></table>
<br />
<p>
In "<a href="https://aclanthology.org/2023.findings-emnlp.345.pdf">Adaptation with Self-Evaluation to Improve Selective Prediction in LLMs</a>", presented at <a href="https://2023.emnlp.org/program/accepted_findings/">Findings of EMNLP 2023</a>, we introduce ASPIRE — a novel framework meticulously designed to enhance the selective prediction capabilities of LLMs. ASPIRE fine-tunes LLMs on QA tasks via <a href="https://huggingface.co/blog/peft">parameter-efficient fine-tuning, </a>and trains them to evaluate whether their generated answers are correct. ASPIRE allows LLMs to output an answer along with a confidence score for that answer. Our experimental results demonstrate that ASPIRE significantly outperforms state-of-the-art selective prediction methods on a variety of QA datasets, such as the <a href="https://aclanthology.org/Q19-1016.pdf">CoQA benchmark</a>.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>The mechanics of ASPIRE</h2>
<p>
Imagine teaching an LLM to not only answer questions but also evaluate those answers — akin to a student verifying their answers in the back of the textbook. That's the essence of ASPIRE, which involves three stages: (1) task-specific tuning, (2) answer sampling, and (3) self-evaluation learning.
</p>
<p>
<strong>Task-specific tuning</strong>: ASPIRE performs task-specific tuning to train adaptable parameters (θ<sub>p</sub>) while freezing the LLM. Given a training dataset for a generative task, it fine-tunes the pre-trained LLM to improve its prediction performance. Towards this end, parameter-efficient tuning techniques (e.g., <a href="https://aclanthology.org/2021.emnlp-main.243/">soft prompt tuning</a> and <a href="https://openreview.net/forum?id=nZeVKeeFYf9">LoRA</a>) might be employed to adapt the pre-trained LLM on the task, given their effectiveness in obtaining strong generalization with small amounts of target task data. Specifically, the LLM parameters (θ) are frozen and adaptable parameters (θ<sub>p</sub>) are added for fine-tuning. Only θ<sub>p</sub> are updated to minimize the standard LLM training loss (e.g., <a href="https://en.wikipedia.org/wiki/Cross-entropy#Cross-entropy_minimization">cross-entropy</a>). Such fine-tuning can improve selective prediction performance because it not only improves the prediction accuracy, but also enhances the likelihood of correct output sequences.
</p>
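<p>
A minimal sketch of this stage, assuming a HuggingFace-style causal LM and soft prompt tuning (the model name, prompt length, and optimizer settings below are illustrative, not the paper's exact configuration): the base weights θ stay frozen and only the prompt embeddings θ<sub>p</sub> receive gradients.
</p>
<pre>
# Task-specific tuning with soft prompts: freeze the LLM, learn a small matrix of
# prompt embeddings prepended to the input, and minimize the usual cross-entropy loss.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")  # illustrative model choice
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
for p in model.parameters():
    p.requires_grad_(False)  # freeze theta

num_prompt_tokens = 20
dim = model.get_input_embeddings().embedding_dim
theta_p = torch.nn.Parameter(torch.randn(num_prompt_tokens, dim) * 0.02)  # adaptable parameters
optimizer = torch.optim.AdamW([theta_p], lr=1e-3)

def training_step(question, answer):
    ids = tokenizer(question + " " + answer, return_tensors="pt").input_ids
    token_embeds = model.get_input_embeddings()(ids)
    inputs_embeds = torch.cat([theta_p.unsqueeze(0), token_embeds], dim=1)
    labels = torch.cat([torch.full((1, num_prompt_tokens), -100), ids], dim=1)  # no loss on the prompt
    loss = model(inputs_embeds=inputs_embeds, labels=labels).loss  # cross-entropy
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
</pre>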
<p>
<strong>Answer sampling</strong>: After task-specific tuning, ASPIRE uses the LLM with the learned θ<sub>p</sub> to generate different answers for each training question and create a dataset for self-evaluation learning. We aim to generate output sequences that have a high likelihood. We use <a href="https://en.wikipedia.org/wiki/Beam_search">beam search</a> as the decoding algorithm to generate high-likelihood output sequences and the <a href="https://aclanthology.org/P04-1077/">Rouge-L</a> metric to determine if the generated output sequence is correct.
</p>
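<p>
The answer-sampling stage might look like the following sketch (again assuming a HuggingFace-style <code>generate</code> API and the <code>rouge_score</code> package; the beam width and correctness threshold are illustrative):
</p>
<pre>
# Generate several high-likelihood answers per question with beam search, then label
# each one as correct or incorrect using Rouge-L against the reference answer.
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)

def sample_and_label(model, tokenizer, question, reference, num_beams=5, threshold=0.7):
    inputs = tokenizer(question, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        num_beams=num_beams,
        num_return_sequences=num_beams,
        max_new_tokens=32,
    )
    labeled = []
    for seq in outputs:
        answer = tokenizer.decode(seq[inputs.input_ids.shape[1]:], skip_special_tokens=True)
        rouge_l = scorer.score(reference, answer)["rougeL"].fmeasure
        labeled.append((question, answer, rouge_l >= threshold))  # (query, sampled answer, is_correct)
    return labeled
</pre>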
<p>
<strong>Self-evaluation learning</strong>: After sampling high-likelihood outputs for each query, ASPIRE adds adaptable parameters (θ<sub>s</sub>) and only fine-tunes θ<sub>s</sub> for learning self-evaluation. Since the output sequence generation only depends on θ and θ<sub>p</sub>, freezing θ and the learned θ<sub>p</sub> can avoid changing the prediction behaviors of the LLM when learning self-evaluation. We optimize θ<sub>s</sub> such that the adapted LLM can distinguish between correct and incorrect answers on their own.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjKSO3s9PgcBr1MpkJR_PYI-ogQ79JD4A4-fJ_OgT8reSKEqIWSUPD7QUVSqIUuAhNfbgEA-XVrOne8S1oJSFaE6YIH4z43bn8jzsfG768qynW-G6lG7dwOvu15UCH6tdlIXEoe2dCUAHmT2bmNijwUvigF50W8vBCsCrjBA_FGYlnsmizHiyutHYZ1A-A2/s1999/image5.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="933" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjKSO3s9PgcBr1MpkJR_PYI-ogQ79JD4A4-fJ_OgT8reSKEqIWSUPD7QUVSqIUuAhNfbgEA-XVrOne8S1oJSFaE6YIH4z43bn8jzsfG768qynW-G6lG7dwOvu15UCH6tdlIXEoe2dCUAHmT2bmNijwUvigF50W8vBCsCrjBA_FGYlnsmizHiyutHYZ1A-A2/s16000/image5.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">The three stages of the ASPIRE framework. </td></tr></tbody></table>
<br />
<p>
In the proposed framework, θ<sub>p</sub> and θ<sub>s</sub> can be trained using any parameter-efficient tuning approach. In this work, we use <a href="https://aclanthology.org/2021.emnlp-main.243/">soft prompt tuning</a>, a simple yet effective mechanism for learning “<a href="https://blog.research.google/2022/02/guiding-frozen-language-models-with.html">soft prompts</a>” to condition frozen language models to perform specific downstream tasks more effectively than traditional discrete text prompts. The driving force behind this approach lies in the recognition that if we can develop prompts that effectively stimulate self-evaluation, it should be possible to discover these prompts through soft prompt tuning in conjunction with targeted training objectives.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhCX9MjJ_KqSauAvXOvcXCDWR1H-hoD3e6EaSCEbG5-EJoJYLmekytCRSXaXrhNGS5BH7DfwbZW7FWzUaTErsGxKSWWM8lLAOxxDX3M5U4Zv8gERXBk_uCY7OVshLexrKt5GTSwrkRdFW0dAcMaALHvrLIosv7Tn4pRd7Rh35-HGWQ13SAeHtJ5-wsNgMNt/s800/image1.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="466" data-original-width="800" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhCX9MjJ_KqSauAvXOvcXCDWR1H-hoD3e6EaSCEbG5-EJoJYLmekytCRSXaXrhNGS5BH7DfwbZW7FWzUaTErsGxKSWWM8lLAOxxDX3M5U4Zv8gERXBk_uCY7OVshLexrKt5GTSwrkRdFW0dAcMaALHvrLIosv7Tn4pRd7Rh35-HGWQ13SAeHtJ5-wsNgMNt/s16000/image1.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Implementation of the ASPIRE framework via soft prompt tuning. We first generate the answer to the question with the first soft prompt and then compute the learned self-evaluation score with the second soft prompt.</td></tr></tbody></table>
<br />
<p>
After training θ<sub>p</sub> and θ<sub>s</sub>, we obtain the prediction for the query via beam search decoding. We then define a selection score that combines the likelihood of the generated answer with the learned self-evaluation score (i.e., the likelihood of the prediction being correct for the query) to make selective predictions.
</p>
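<p>
As an illustration only (the exact combination rule and threshold in the paper may differ), a selection score could blend the length-normalized log-likelihood of the generated answer with the learned self-evaluation score, abstaining when the result is too low:
</p>
<pre>
# Illustrative selection score: combine answer likelihood with the learned
# probability that the answer is correct, and abstain below a threshold.
import numpy as np

def selection_score(answer_token_logprobs, self_eval_prob, alpha=0.5):
    norm_log_likelihood = float(np.mean(answer_token_logprobs))      # likelihood term
    self_eval_term = float(np.log(self_eval_prob + 1e-9))            # learned self-evaluation term
    return alpha * norm_log_likelihood + (1 - alpha) * self_eval_term

score = selection_score([-0.2, -0.1, -0.4], self_eval_prob=0.92)
output = "predicted answer" if score >= -0.5 else "I don't know!"    # threshold is illustrative
</pre>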
<div style="line-height: 40%;">
<br />
</div>
<h2>Results</h2>
<p>
To demonstrate ASPIRE’s efficacy, we evaluate it across three question-answering datasets — <a href="https://aclanthology.org/Q19-1016/">CoQA</a>, <a href="https://aclanthology.org/P17-1147/">TriviaQA</a>, and <a href="https://aclanthology.org/D16-1264/">SQuAD</a> — using various <a href="https://arxiv.org/abs/2205.01068">open pre-trained transformer</a> (OPT) models. By training θ<sub>p</sub> with soft prompt tuning, we observed a substantial hike in the LLMs' accuracy. For example, the <a href="https://arxiv.org/abs/2205.01068">OPT-2.7B</a> model adapted with ASPIRE demonstrated improved performance over the larger, pre-trained OPT-30B model using the CoQA and SQuAD datasets. These results suggest that with suitable adaptations, smaller LLMs might have the capability to match or potentially surpass the accuracy of larger models in some scenarios.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiT3H50FS2Ml9VoWJCuNnInzRQqtNEpyUXuwRnogIG-pDIlRduV0DmvyI9iQIHTziGdqkugV9SDjIcV8WPfnb8QyzQ3ACcusD7O1FQvyVe9U9C5iCbX-uS8xHNhbvg2uv_CPwe4UJASF_dPe8s-c-xz-1hplqXYYxw4wsLOw9dJ-vWrLz-Ei4BrrW-e9Wzc/s1999/image4.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1500" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiT3H50FS2Ml9VoWJCuNnInzRQqtNEpyUXuwRnogIG-pDIlRduV0DmvyI9iQIHTziGdqkugV9SDjIcV8WPfnb8QyzQ3ACcusD7O1FQvyVe9U9C5iCbX-uS8xHNhbvg2uv_CPwe4UJASF_dPe8s-c-xz-1hplqXYYxw4wsLOw9dJ-vWrLz-Ei4BrrW-e9Wzc/s16000/image4.png" /></a></td></tr></tbody></table>
<p>
When delving into the computation of selection scores with fixed model predictions, ASPIRE received a higher <a href="https://openreview.net/pdf?id=VD-AYtP0dve">AUROC</a> score (the probability that a randomly chosen correct output sequence has a higher selection score than a randomly chosen incorrect output sequence) than baseline methods across all datasets. For example, on the CoQA benchmark, ASPIRE improves the AUROC from 51.3% to 80.3% compared to the baselines.
</p>
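<p>
This AUROC can be computed directly from per-example selection scores and correctness labels; a minimal sketch with scikit-learn (the numbers are made up for illustration):
</p>
<pre>
# AUROC for selective prediction: the probability that a randomly chosen correct
# answer receives a higher selection score than a randomly chosen incorrect one.
from sklearn.metrics import roc_auc_score

is_correct = [1, 0, 1, 1, 0, 0, 1]                        # 1 = the model's answer was correct
selection_scores = [0.9, 0.2, 0.7, 0.8, 0.4, 0.6, 0.95]
print(roc_auc_score(is_correct, selection_scores))
</pre>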
<p>
An intriguing pattern emerged from the TriviaQA dataset evaluations. While the pre-trained OPT-30B model demonstrated higher baseline accuracy, its performance in selective prediction did not improve significantly when traditional self-evaluation methods — <a href="https://arxiv.org/abs/2207.05221">Self-eval and P(True)</a> — were applied. In contrast, the smaller OPT-2.7B model, when enhanced with ASPIRE, outperformed in this aspect. This discrepancy underscores a vital insight: larger LLMs utilizing conventional self-evaluation techniques may not be as effective in selective prediction as smaller, ASPIRE-enhanced models.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgbQLSLnyNPe8LW9W3KLt7Tkp3gmLSLHkTbICIi5j__yWNt7aG9PmWyW5bE4vSs8q2iFqE5dlb0KOdtKzcmg1JuKdzZWFUxYXiatDPB7N-q1NhkkH9hEcgqlw33BFUh_v_8DQJnY5lMrXexv0HUvTPYRS_Gb-M75Rx_TBhxyvLrI-AhZJV243sUzo2gIPoh/s1999/image3.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1599" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgbQLSLnyNPe8LW9W3KLt7Tkp3gmLSLHkTbICIi5j__yWNt7aG9PmWyW5bE4vSs8q2iFqE5dlb0KOdtKzcmg1JuKdzZWFUxYXiatDPB7N-q1NhkkH9hEcgqlw33BFUh_v_8DQJnY5lMrXexv0HUvTPYRS_Gb-M75Rx_TBhxyvLrI-AhZJV243sUzo2gIPoh/s16000/image3.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"></td></tr></tbody></table>
<p>
Our experimental journey with ASPIRE underscores a pivotal shift in the landscape of LLMs: The capacity of a language model is not the be-all and end-all of its performance. Instead, the effectiveness of models can be drastically improved through strategic adaptations, allowing for more precise, confident predictions even in smaller models. As a result, ASPIRE stands as a testament to the potential of LLMs that can judiciously ascertain their own certainty and decisively outperform larger counterparts in selective prediction tasks.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Conclusion</h2>
<p>
In conclusion, ASPIRE is not just another framework; it's a vision of a future where LLMs can be trusted partners in decision-making. By honing the selective prediction performance, we're inching closer to realizing the full potential of AI in critical applications.
</p>
<p>
Our research has opened new doors, and we invite the community to build upon this foundation. We're excited to see how ASPIRE will inspire the next generation of LLMs and beyond. To learn more about our findings, we encourage you to read our paper and join us in this thrilling journey towards creating a more reliable and self-aware AI.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgments</h2>
<p>
<em>We gratefully acknowledge the contributions of Sayna Ebrahimi, Sercan O Arik, Tomas Pfister, and Somesh Jha.</em>
</p>Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-18004301292052687062024-01-12T09:04:00.000-08:002024-01-16T10:55:15.626-08:00AMIE: A research AI system for diagnostic medical reasoning and conversations<span class="byline-author">Posted by Alan Karthikesalingam and Vivek Natarajan, Research Leads, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgr_wfWpw2CVBYDc3Mlk879CUecv4uGlq36Fxe0GEnVcK2kJnzAyRkPRb8vO5jJVqrd_zvQ6W8suHyp1xhFhNFJuUj8nTNRp3TsZP7Z5uWlqw22hZZKVJJ33X5NWmT0UTkOdC4raONlnSbR8E616Mi_lJVE3DvbWYB-19eR2wpCgwAaykkquUV3DOLRY6c/s16000/AMIE.gif" style="display: none;" />
<p>
The physician-patient conversation is a cornerstone of medicine, in which skilled and intentional communication drives diagnosis, management, empathy and trust. AI systems capable of such diagnostic dialogues could increase availability, accessibility, quality and consistency of care by being useful conversational partners to clinicians and patients alike. But approximating clinicians’ considerable expertise is a significant challenge.
</p>
<a name='more'></a>
<p>
Recent progress in large language models (LLMs) outside the medical domain has shown that they can plan, reason, and use relevant context to hold rich conversations. However, there are many aspects of good diagnostic dialogue that are unique to the medical domain. An effective clinician takes a complete “clinical history” and asks intelligent questions that help to derive a differential diagnosis. They wield considerable skill to foster an effective relationship, provide information clearly, make joint and informed decisions with the patient, respond empathically to their emotions, and support them in the next steps of care. While LLMs can accurately perform tasks such as medical summarization or answering medical questions, there has been little work specifically aimed towards developing these kinds of conversational diagnostic capabilities.
</p>
<p>
Inspired by this challenge, we developed <a href="https://arxiv.org/abs/2401.05654">Articulate Medical Intelligence Explorer (AMIE)</a>, a research AI system based on an LLM and optimized for diagnostic reasoning and conversations. We trained and evaluated AMIE along many dimensions that reflect quality in real-world clinical consultations from the perspective of both clinicians and patients. To scale AMIE across a multitude of disease conditions, specialties and scenarios, we developed a novel self-play based simulated diagnostic dialogue environment with automated feedback mechanisms to enrich and accelerate its learning process. We also introduced an inference time chain-of-reasoning strategy to improve AMIE’s diagnostic accuracy and conversation quality. Finally, we tested AMIE prospectively in real examples of multi-turn dialogue by simulating consultations with trained actors.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEh1TB1GjBHpX7Kazzao7l_ysdB7rXdxZQG7CodkfM4A7cYSJRKUEfhZL4iFJ4BI0ipp9o4rPam4ARcp0v98V_1CtcYPb9fCalxW3Y_vekZl1iDtkdfshLbAi_OSbwuaecYtMosCRUtgvAMYWSoASj7A7OgAPfzVEQbwMOmkfzNnWKot2dtyAmQmFDtYKy4/s1200/AMIE%20GIF%201%20v2.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="512" data-original-width="1200" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEh1TB1GjBHpX7Kazzao7l_ysdB7rXdxZQG7CodkfM4A7cYSJRKUEfhZL4iFJ4BI0ipp9o4rPam4ARcp0v98V_1CtcYPb9fCalxW3Y_vekZl1iDtkdfshLbAi_OSbwuaecYtMosCRUtgvAMYWSoASj7A7OgAPfzVEQbwMOmkfzNnWKot2dtyAmQmFDtYKy4/s16000/AMIE%20GIF%201%20v2.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">AMIE was optimized for diagnostic conversations, asking questions that help to reduce its uncertainty and improve diagnostic accuracy, while also balancing this with other requirements of effective clinical communication, such as empathy, fostering a relationship, and providing information clearly.</td></tr></tbody></table>
<br />
<h2>Evaluation of conversational diagnostic AI</h2>
<p>
Besides developing and optimizing AI systems themselves for diagnostic conversations, how to assess such systems is also an open question. Inspired by accepted tools used to measure consultation quality and clinical communication skills in real-world settings, we constructed a pilot evaluation rubric to assess diagnostic conversations along axes pertaining to history-taking, diagnostic accuracy, clinical management, clinical communication skills, relationship fostering and empathy.
</p>
<p>
We then designed a randomized, double-blind crossover study of text-based consultations with validated patient actors interacting either with board-certified primary care physicians (PCPs) or the AI system optimized for diagnostic dialogue. We set up our consultations in the style of an <a href="https://en.wikipedia.org/wiki/Objective_structured_clinical_examination">objective structured clinical examination</a> (OSCE), a practical assessment commonly used in the real world to examine clinicians’ skills and competencies in a standardized and objective way. In a typical OSCE, clinicians might rotate through multiple stations, each simulating a real-life clinical scenario where they perform tasks such as conducting a consultation with a standardized patient actor (trained carefully to emulate a patient with a particular condition). Consultations were performed using a synchronous text-chat tool, mimicking the interface familiar to most consumers using LLMs today.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjBTzb4Nx-dNLSwgDjJka4tYA29gFEnPna3N68cdriUmaybI1IGhqxPMLzObLg1nRh3S7pd21G7CHczMvy7tUHyETClVRu8Hv3J9gQRd_WGYwORLiylKUgNViILvFO068daetL3MzJAa2rGU7Yzjg2BkfUas0hQgP9yBwmH_Wx3wmCpGfuZ-75yejQTAg8/s1350/image4.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1200" data-original-width="1350" height="568" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjBTzb4Nx-dNLSwgDjJka4tYA29gFEnPna3N68cdriUmaybI1IGhqxPMLzObLg1nRh3S7pd21G7CHczMvy7tUHyETClVRu8Hv3J9gQRd_WGYwORLiylKUgNViILvFO068daetL3MzJAa2rGU7Yzjg2BkfUas0hQgP9yBwmH_Wx3wmCpGfuZ-75yejQTAg8/w640-h568/image4.gif" width="640" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">AMIE is a research AI system based on LLMs for diagnostic reasoning and dialogue.</td></tr></tbody></table>
<br />
<h2>AMIE: an LLM-based conversational diagnostic research AI system </h2>
<p>
We trained AMIE on real-world datasets comprising medical reasoning, medical summarization and real-world clinical conversations.
</p>
<p>
It is feasible to train LLMs using real-world dialogues developed by passively collecting and transcribing in-person clinical visits; however, two substantial challenges limit their effectiveness for training LLMs for medical conversations. First, existing real-world data often fails to capture the vast range of medical conditions and scenarios, hindering its scalability and comprehensiveness. Second, the data derived from real-world dialogue transcripts tends to be noisy, containing ambiguous language (including slang, jargon, humor and sarcasm), interruptions, ungrammatical utterances, and implicit references.
</p>
<p>
To address these limitations, we designed a self-play based simulated learning environment with automated feedback mechanisms for diagnostic medical dialogue in a virtual care setting, enabling us to scale AMIE’s knowledge and capabilities across many medical conditions and contexts. We used this environment to iteratively fine-tune AMIE with an evolving set of simulated dialogues in addition to the static corpus of real-world data described.
</p>
<p>
This process consisted of two self-play loops: (1) an “inner” self-play loop, where AMIE leveraged in-context critic feedback to refine its behavior on simulated conversations with an AI patient simulator; and (2) an “outer” self-play loop where the set of refined simulated dialogues were incorporated into subsequent fine-tuning iterations. The resulting new version of AMIE could then participate in the inner loop again, creating a virtuous continuous learning cycle.
</p>
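<p>
The structure of the two loops can be summarized in the following sketch, where <code>simulate_consultation</code>, <code>critic_feedback</code>, <code>refine_with_feedback</code>, and <code>fine_tune</code> are hypothetical placeholders for the components described above, not AMIE's actual training code:
</p>
<pre>
# Structural sketch of the self-play learning environment.
def self_play_training(model, scenarios, real_world_data, num_rounds=3):
    simulated_corpus = []
    for _ in range(num_rounds):
        for scenario in scenarios:
            # Inner loop: converse with an AI patient simulator and refine the
            # dialogue using in-context critic feedback.
            dialogue = simulate_consultation(model, scenario)
            feedback = critic_feedback(dialogue)
            dialogue = refine_with_feedback(model, dialogue, feedback)
            simulated_corpus.append(dialogue)
        # Outer loop: fold the refined simulated dialogues into the next
        # fine-tuning iteration, alongside the static real-world corpus.
        model = fine_tune(model, real_world_data + simulated_corpus)
    return model
</pre>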
<p>
Further, we also employed an inference time chain-of-reasoning strategy which enabled AMIE to progressively refine its response conditioned on the current conversation to arrive at an informed and grounded reply.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiASnZ5p1olZNNL1e-bzqA6s1WprbenNRebHvq2sXuRhYHtHBMw5s1sgAR6SXS4lSDkBD_WgsY6mepBCoLojtes3GOU3yCoOPRGLoGpkMV99TM1Ru0xpNSNWee-5xfkUGBeE9fnp_rY8t_0Dv3NIxVnj9iGPNSyoGtJ6N9MSsjFsIqDpJVhsAQbCBoyJ1-N/s1400/image1.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="800" data-original-width="1400" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiASnZ5p1olZNNL1e-bzqA6s1WprbenNRebHvq2sXuRhYHtHBMw5s1sgAR6SXS4lSDkBD_WgsY6mepBCoLojtes3GOU3yCoOPRGLoGpkMV99TM1Ru0xpNSNWee-5xfkUGBeE9fnp_rY8t_0Dv3NIxVnj9iGPNSyoGtJ6N9MSsjFsIqDpJVhsAQbCBoyJ1-N/s16000/image1.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">AMIE uses a novel self-play based simulated dialogue learning environment to improve the quality of diagnostic dialogue across a multitude of disease conditions, specialities and patient contexts.</td></tr></tbody></table>
<br />
<p>
We tested performance in consultations with simulated patients (played by trained actors), compared to those performed by 20 real PCPs using the randomized approach described above. AMIE and PCPs were assessed from the perspectives of both specialist attending physicians and our simulated patients in a randomized, blinded crossover study that included 149 case scenarios from OSCE providers in Canada, the UK and India in a diverse range of specialties and diseases.
</p>
<p>
Notably, our study was not designed to emulate either traditional in-person OSCE evaluations or the ways clinicians usually use text, email, chat or telemedicine. Instead, our experiment mirrored the most common way consumers interact with LLMs today, a potentially scalable and familiar mechanism for AI systems to engage in remote diagnostic dialogue.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgSuFcXXr8lxIA1-p428y31PqVkpEMlQAdKj-vdTmtVYqeuogSQAjWFM3Gj8akNVG-6Cyd9xZKbLKz0jUFABDU_JjwcM35qLVsSi5zlB3frgei5NAwhTQ7PyEbipXJK8gPvb0vY_VKGrZ-GFAmwtDf-pcJbKtY5DCEBb43N5rtnSa8gYdLxl5aQxpAZ1qo/s1999/image8.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1296" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgSuFcXXr8lxIA1-p428y31PqVkpEMlQAdKj-vdTmtVYqeuogSQAjWFM3Gj8akNVG-6Cyd9xZKbLKz0jUFABDU_JjwcM35qLVsSi5zlB3frgei5NAwhTQ7PyEbipXJK8gPvb0vY_VKGrZ-GFAmwtDf-pcJbKtY5DCEBb43N5rtnSa8gYdLxl5aQxpAZ1qo/s16000/image8.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Overview of the randomized study design to perform a virtual remote OSCE with simulated patients via online multi-turn synchronous text chat.</td></tr></tbody></table>
<br />
<h2>Performance of AMIE</h2>
<p>
In this setting, we observed that AMIE performed simulated diagnostic conversations at least as well as PCPs when both were evaluated along multiple clinically-meaningful axes of consultation quality. AMIE had greater diagnostic accuracy and superior performance for 28 of 32 axes from the perspective of specialist physicians, and 24 of 26 axes from the perspective of patient actors.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhvP4CCxNKLIBLIBfPnJZ-7mRYZqxINGJ8_Uos8K3Gd8PJrFURwBYYjJePGqHpa63nFQR2aahi3HcwPos9NCV-fknrdVRsrwJCI6qFub84f5g5gNo_SvuosZt7Rjm5LXOQuVvG0n_GmzL6jNhihROxls9ZQBA5aVPod_onwurffiTI12F6d4wwfbeNMdxQ/s1834/image4.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1284" data-original-width="1834" height="448" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhvP4CCxNKLIBLIBfPnJZ-7mRYZqxINGJ8_Uos8K3Gd8PJrFURwBYYjJePGqHpa63nFQR2aahi3HcwPos9NCV-fknrdVRsrwJCI6qFub84f5g5gNo_SvuosZt7Rjm5LXOQuVvG0n_GmzL6jNhihROxls9ZQBA5aVPod_onwurffiTI12F6d4wwfbeNMdxQ/w640-h448/image4.png" width="640" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">AMIE outperformed PCPs on multiple evaluation axes for diagnostic dialogue in our evaluations.</td></tr></tbody></table>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhZvkY54tqvLCbfTrEFR5e_T1eEXJZvn3V__lBts2bukKDwuJkLmDo5w-ilA8B44JwDPUv5v5hzCN9WRWttPEZ2qN1wQaGQR0SRjjVhapLDxg6Te5YLjPqgUwoDCot2sBujGLVHgIrKFXUkT3bKzL1MLHCMxEMs0pC5ZMoi-PTAhPFLgW7bsjsi5jy3pL0/s1999/image9.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="752" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhZvkY54tqvLCbfTrEFR5e_T1eEXJZvn3V__lBts2bukKDwuJkLmDo5w-ilA8B44JwDPUv5v5hzCN9WRWttPEZ2qN1wQaGQR0SRjjVhapLDxg6Te5YLjPqgUwoDCot2sBujGLVHgIrKFXUkT3bKzL1MLHCMxEMs0pC5ZMoi-PTAhPFLgW7bsjsi5jy3pL0/s16000/image9.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Specialist-rated top-k diagnostic accuracy. AMIE and PCPs top-k differential diagnosis (DDx) accuracy are compared across 149 scenarios with respect to the ground truth diagnosis (a) and all diagnoses listed within the accepted differential diagnoses (b). Bootstrapping (n=10,000) confirms all top-k differences between AMIE and PCP DDx accuracy are significant with p <0.05 after <a href="https://en.wikipedia.org/wiki/False_discovery_rate">false discovery rate</a> (FDR) correction.</td></tr></tbody></table>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi3FLScvpxSucelwHpfxEC_pMR4cXG3ioiw1lRJs1XgWbuGM8mLS635ryiJJOF7ZOuuA4t0rkj1OXWXB57GW-FQcNcYq_TKfTPyCLm-EV3Ivk5yPgYdjYKxT8-yxQnDz4mNJwKop4yS3XvyNpcUzVOrhm0MrJKs5DVfl-u8hgwhwkI6kZAvCto4Z4JGcnTb/s1999/image2.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1864" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi3FLScvpxSucelwHpfxEC_pMR4cXG3ioiw1lRJs1XgWbuGM8mLS635ryiJJOF7ZOuuA4t0rkj1OXWXB57GW-FQcNcYq_TKfTPyCLm-EV3Ivk5yPgYdjYKxT8-yxQnDz4mNJwKop4yS3XvyNpcUzVOrhm0MrJKs5DVfl-u8hgwhwkI6kZAvCto4Z4JGcnTb/s16000/image2.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Diagnostic conversation and reasoning qualities as assessed by specialist physicians. On 28 out of 32 axes, AMIE outperformed PCPs while being comparable on the rest.</td></tr></tbody></table>
<br />
<h2>Limitations</h2>
<p>
Our research has several limitations and should be interpreted with appropriate caution. Firstly, our evaluation technique likely underestimates the real-world value of human conversations, as the clinicians in our study were limited to an unfamiliar text-chat interface, which permits large-scale LLM–patient interactions but is not representative of usual clinical practice. Secondly, any research of this type must be seen as only a first exploratory step on a long journey. Transitioning from an LLM research prototype that we evaluated in this study to a safe and robust tool that could be used by people and those who provide care for them will require significant additional research. There are many important limitations to be addressed, including experimental performance under real-world constraints and dedicated exploration of such important topics as health equity and fairness, privacy, robustness, and many more, to ensure the safety and reliability of the technology.
</p>
<br />
<h2>AMIE as an aid to clinicians</h2>
<p>
In a <a href="https://arxiv.org/abs/2312.00164">recently released preprint</a>, we evaluated the ability of an earlier iteration of the AMIE system to generate a DDx alone or as an aid to clinicians. Twenty (20) generalist clinicians evaluated 303 challenging, real-world medical cases sourced from the <em><a href="https://www.nejm.org/">New England Journal of Medicine</a></em> (NEJM) <a href="https://www.nejm.org/case-challenges">ClinicoPathologic Conferences</a> (CPCs). Each case report was read by two clinicians randomized to one of two assistive conditions: either assistance from search engines and standard medical resources, or AMIE assistance in addition to these tools. All clinicians provided a baseline, unassisted DDx prior to using the respective assistive tools.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEia-nKW28SyAk0DHT21L2CsB18JKUmmt2sPtafGvRtJWrOEgfn1v_hXDtSIJsFP2m66tBA33MwMHXKQSL-nGKfvMTKASXUVZ5n_I4VytKfa0S3EN5vf2TeMHfmOtMLCJtfD3PCvMMc8PJsbIYu-iikFu4atfCOBa-a5yHTM2Tok1wjZpkmBbvioUhXz4Dc/s1999/image5.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1022" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEia-nKW28SyAk0DHT21L2CsB18JKUmmt2sPtafGvRtJWrOEgfn1v_hXDtSIJsFP2m66tBA33MwMHXKQSL-nGKfvMTKASXUVZ5n_I4VytKfa0S3EN5vf2TeMHfmOtMLCJtfD3PCvMMc8PJsbIYu-iikFu4atfCOBa-a5yHTM2Tok1wjZpkmBbvioUhXz4Dc/s16000/image5.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Assisted randomized reader study setup to investigate the assistive effect of AMIE to clinicians in solving complex diagnostic case challenges from the New England Journal of Medicine.</td></tr></tbody></table>
<p>
AMIE exhibited standalone performance that exceeded that of unassisted clinicians (top-10 accuracy 59.1% vs. 33.6%, p= 0.04). Comparing the two assisted study arms, the top-10 accuracy was higher for clinicians assisted by AMIE, compared to clinicians without AMIE assistance (24.6%, p<0.01) and clinicians with search (5.45%, p=0.02). Further, clinicians assisted by AMIE arrived at more comprehensive differential lists than those without AMIE assistance.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiO1YlFJdQMCt1ZMTQN5neWeYoljjA7Y13FP_2c7q85hSKbLCdNLJtUt1VtBFlCUBlGTIviqdr4XWnnandULaKfGlyPh89QzzaHXmb-wFxYfkwbRv5OO9Wni6Hr04jVO_W1w2cs7RQcCRWCWrW9lxM3t61BI3ZPK6hdsv7RAQAn8TN6s80nP9nwjia0xig/s1999/image7.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1424" data-original-width="1999" height="456" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiO1YlFJdQMCt1ZMTQN5neWeYoljjA7Y13FP_2c7q85hSKbLCdNLJtUt1VtBFlCUBlGTIviqdr4XWnnandULaKfGlyPh89QzzaHXmb-wFxYfkwbRv5OO9Wni6Hr04jVO_W1w2cs7RQcCRWCWrW9lxM3t61BI3ZPK6hdsv7RAQAn8TN6s80nP9nwjia0xig/w640-h456/image7.png" width="640" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">In addition to strong standalone performance, using the AMIE system led to significant assistive effect and improvements in diagnostic accuracy of the clinicians in solving these complex case challenges.</td></tr></tbody></table>
<p>
It's worth noting that NEJM CPCs are not representative of everyday clinical practice. They are unusual case reports describing only a few hundred individuals, so they offer limited scope for probing important issues like equity or fairness.
</p>
<br />
<h2>Bold and responsible research in healthcare — the art of the possible </h2>
<p>
Access to clinical expertise remains scarce around the world. While AI has shown great promise in specific clinical applications, engagement in the dynamic, conversational diagnostic journeys of clinical practice requires many capabilities not yet demonstrated by AI systems. Doctors wield not only knowledge and skill but a dedication to myriad principles, including safety and quality, communication, partnership and teamwork, trust, and professionalism. Realizing these attributes in AI systems is an inspiring challenge that should be approached responsibly and with care. AMIE is our exploration of the “art of the possible”, a research-only system for safely exploring a vision of the future where AI systems might be better aligned with attributes of the skilled clinicians entrusted with our care. It is early experimental-only work, not a product, and has several limitations that we believe merit rigorous and extensive further scientific studies in order to envision a future in which conversational, empathic and diagnostic AI systems might become safe, helpful and accessible.
</p>
<br />
<h2>Acknowledgements</h2>
<p>
<em>The research described here is joint work across many teams at Google Research and Google DeepMind. We are grateful to all our co-authors: Tao Tu, Mike Schaekermann, Anil Palepu, Daniel McDuff, Jake Sunshine, Khaled Saab, Jan Freyberg, Ryutaro Tanno, Amy Wang, Brenna Li, Mohamed Amin, Sara Mahdavi, Karan Singhal, Shekoofeh Azizi, Nenad Tomasev, Yun Liu, Yong Cheng, Le Hou, Albert Webson, Jake Garrison, Yash Sharma, Anupam Pathak, Sushant Prakash, Philip Mansfield, Shwetak Patel, Bradley Green, Ewa Dominowska, Renee Wong, Juraj Gottweis, Dale Webster, Katherine Chou, Christopher Semturs, Joelle Barral, Greg Corrado and Yossi Matias. We also thank Sami Lachgar, Lauren Winer and John Guilyard for their support with narratives and the visuals. Finally, we are grateful to Michael Howell, James Manyika, Jeff Dean, Karen DeSalvo, Zoubin Ghahramani and Demis Hassabis for their support during the course of this project</em>.
</p><br />Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-79981187857771645742024-01-11T14:42:00.000-08:002024-01-11T14:42:51.944-08:00Can large language models identify and correct their mistakes?<span class="byline-author">Posted by Gladys Tyen, Intern, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhaRCK2QC8HmH0lzm2lPjqVOxFPZDoyTAPq9icazR1vrsFUTDr5OJdTIZNMDgut_ylOOmeZmA4n0BTIgFCsksZ_xATbJDnQegxWMpdqv2kyGBWKMTV9E2k3WybhBzhL3-oQnpNUWWTKURxt5y8f7gGwUkPTExml1QD2U-UqW0hglZ-cXXCkznmufJIfyPAR/s1600/Backtracking.jpg" style="display: none;" />
<p>
LLMs are increasingly popular for reasoning tasks, such as <a href="https://hotpotqa.github.io/">multi-turn QA</a>, <a href="https://arxiv.org/abs/2207.01206">task completion</a>, <a href="https://arxiv.org/abs/2304.05128">code generation</a>, or <a href="https://github.com/openai/grade-school-math">mathematics</a>. Yet much like people, they do not always solve problems correctly on the first try, especially on tasks for which they were not trained. Therefore, for such systems to be most useful, they should be able to 1) identify where their reasoning went wrong and 2) backtrack to find another solution.
</p>
<a name='more'></a>
<p>
This has led to a surge in methods related to <em>self-correction</em>, where an LLM is used to identify problems in its own output, and then produce improved results based on the feedback. Self-correction is generally thought of as a single process, but we decided to break it down into two components, <em>mistake finding</em> and <em>output correction</em>.
</p>
<p>
In “<a href="https://arxiv.org/abs/2311.08516">LLMs cannot find reasoning errors, but can correct them!</a>”, we test state-of-the-art LLMs on mistake finding and output correction separately. We present <a href="https://github.com/WHGTyen/BIG-Bench-Mistake">BIG-Bench Mistake</a>, an evaluation benchmark dataset for mistake identification, which we use to address the following questions:
</p>
<ol>
<li>Can LLMs find logical mistakes in <a href="https://arxiv.org/abs/2201.11903">Chain-of-Thought</a> (CoT) style reasoning?
</li><li>Can mistake-finding be used as a proxy for correctness?
</li><li>Knowing where the mistake is, can LLMs then be prompted to backtrack and arrive at the correct answer?
</li><li>Can mistake finding as a skill generalize to tasks the LLMs have never seen?
</li>
</ol>
<br />
<h2>About our dataset</h2>
<p>
Mistake finding is an underexplored problem in natural language processing, with a particular lack of evaluation tasks in this domain. To best assess the ability of LLMs to find mistakes, evaluation tasks should exhibit mistakes that are non-ambiguous. To our knowledge, most current mistake-finding datasets do not go beyond the realm of <a href="https://github.com/openai/grade-school-math">mathematics</a> for this reason.
</p>
<p>
To assess the ability of LLMs to reason about mistakes outside of the math domain, we produce a new dataset for use by the research community, called <a href="https://github.com/WHGTyen/BIG-Bench-Mistake">BIG-Bench Mistake</a>. This dataset consists of Chain-of-Thought traces generated using <a href="https://ai.google/discover/palm2/">PaLM 2</a> on five tasks in <a href="https://github.com/suzgunmirac/BIG-Bench-Hard">BIG-Bench</a>. Each trace is annotated with the location of the first logical mistake.
</p>
<p>
To maximize the number of mistakes in our dataset, we sample 255 traces where the answer is incorrect (so we know there is definitely a mistake), and 45 traces where the answer is correct (so there may or may not be a mistake). We then ask human labelers to go through each trace and identify the first mistake step. Each trace has been annotated by at least three labelers, whose answers had <a href="https://en.wikipedia.org/wiki/Inter-rater_reliability">inter-rater reliability</a> levels of >0.98 (using <a href="https://en.wikipedia.org/wiki/Krippendorff%27s_alpha">Krippendorff’s α</a>). The labeling was done for all tasks except the <a href="https://github.com/google/BIG-bench/tree/main/bigbench/benchmark_tasks/dyck_languages">Dyck Languages task</a>, which involves predicting the sequence of closing parentheses for a given input sequence; we labeled this task algorithmically.
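<p>
To make the annotation format concrete, a single record in such a dataset can be pictured roughly as in the sketch below. The field names and example text are illustrative assumptions rather than the exact schema of the released files; please refer to the BIG-Bench Mistake repository linked above for the authoritative format.
</p>
<pre>
# Illustrative (not authoritative) shape of one annotated CoT trace.
# A trace is a list of reasoning steps plus the index of the first logical
# mistake, or None when the annotators found no mistake.

example_trace = {
    "task": "tracking_shuffled_objects",     # hypothetical task name
    "question": "Alice, Bob, and Claire swap books ...",
    "steps": [
        "At the start, Alice has the red book.",
        "After the first swap, Alice has the green book.",   # suppose this step is wrong
        "So at the end, Alice has the green book.",
    ],
    "first_mistake_index": 1,    # 0-based index of the first incorrect step
    "answer_is_correct": False,  # whether the trace's final answer matches the target
}

def has_mistake(trace):
    """A trace contains a mistake iff an annotated mistake index is present."""
    return trace["first_mistake_index"] is not None

print(has_mistake(example_trace))  # True
</pre>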
</p>
<p>
The logical errors made in this dataset are simple and unambiguous, providing a good benchmark for testing an LLM’s ability to find its own mistakes before using them on harder, more ambiguous tasks.</p>
<div class="separator" style="clear: both; text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgvk5zKLBvc2Ou6RpJc9l-lLqwHW6nWARuc2IAckSQ2SPYX6-UQj9Z8FyOB5emaBvXPta4MWqR1gis9FMEXeafffprNpyPmF_XaBOQ7tQpRpEylbnSlbwytNv1BFXlz5I-ulNM0ZBC7kBhx2KkdCT5MIejwdsHKpHu6rrJ4LBVd-Na_XUn5DCy0EKtj1Uy6/s1354/BBMistakes2.png" style="margin-left: 1em; margin-right: 1em;"><img border="0" data-original-height="410" data-original-width="1354" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgvk5zKLBvc2Ou6RpJc9l-lLqwHW6nWARuc2IAckSQ2SPYX6-UQj9Z8FyOB5emaBvXPta4MWqR1gis9FMEXeafffprNpyPmF_XaBOQ7tQpRpEylbnSlbwytNv1BFXlz5I-ulNM0ZBC7kBhx2KkdCT5MIejwdsHKpHu6rrJ4LBVd-Na_XUn5DCy0EKtj1Uy6/s16000/BBMistakes2.png" /></a></div>
<br />
<h2>Core questions about mistake identification</h2>
<div style="line-height: 40%;">
<br />
</div>
<h3>1. Can LLMs find logical mistakes in Chain-of-Thought style reasoning?</h3>
<p>
First, we want to find out if LLMs can identify mistakes independently of their ability to correct them. We attempt multiple prompting methods to test <a href="https://en.wikipedia.org/wiki/Generative_pre-trained_transformer">GPT</a> series models for their ability to locate mistakes (prompts <a href="https://github.com/WHGTyen/BIG-Bench-Mistake/tree/main/mistake_finding_prompts">here</a>) under the assumption that they are generally representative of modern LLM performance.
</p>
<p>
Generally, we found these state-of-the-art models perform poorly, with the best model achieving 52.9% accuracy overall. Hence, there is a need to improve LLMs’ ability in this area of reasoning.
</p>
<p>
In our experiments, we try three different prompting methods: direct (trace), direct (step), and CoT (step). In direct (trace), we provide the LLM with the full trace and ask for the step where the mistake is located, or <em>no mistake</em>. In direct (step), we prompt the LLM to ask itself this question for each step it takes. In CoT (step), we prompt the LLM to give its reasoning for whether each step is a mistake or not a mistake.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhYeNEXWh6vhy4SvTgSE5kOYYdRV3vFdUGs9zorH7bI010gD4gFziPwtig3bvlJFnzEpOgcQZZbn_2_KDEiqwFgdtimB-IYhhROTmtTKoxmmWF0jzI1IKfU3ZSeAhqEDJgLBwkmdUrbMDd9uYo3kLvK5uhygNRU2mkuRhnW3ZofDkYw-CsjKzFUQdplpFfe/s1061/image2.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="466" data-original-width="1061" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhYeNEXWh6vhy4SvTgSE5kOYYdRV3vFdUGs9zorH7bI010gD4gFziPwtig3bvlJFnzEpOgcQZZbn_2_KDEiqwFgdtimB-IYhhROTmtTKoxmmWF0jzI1IKfU3ZSeAhqEDJgLBwkmdUrbMDd9uYo3kLvK5uhygNRU2mkuRhnW3ZofDkYw-CsjKzFUQdplpFfe/s16000/image2.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">A diagram showing the three prompting methods direct (trace), direct (step) and CoT (step).</td></tr></tbody></table>
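<p>
The sketch below shows, in simplified form, how the three prompting styles differ. These templates are paraphrased for illustration; the exact prompts used in our experiments are in the repository linked above.
</p>
<pre>
# Paraphrased, simplified prompt templates for the three mistake-finding setups.

def direct_trace_prompt(steps):
    # One prompt for the whole trace: ask for the first mistaken step, or "No mistake".
    trace = "\n".join(f"Step {i}: {s}" for i, s in enumerate(steps, start=1))
    return (
        "Below is a chain-of-thought solution.\n"
        f"{trace}\n"
        "If there is a logical mistake, reply with the number of the first "
        "mistaken step; otherwise reply 'No mistake'."
    )

def direct_step_prompt(steps, i):
    # One prompt per step i: is this particular step a mistake?
    context = "\n".join(steps[: i + 1])
    return f"{context}\nIs the last step above a logical mistake? Answer Yes or No."

def cot_step_prompt(steps, i):
    # Same per-step question, but the model is asked to reason before answering.
    context = "\n".join(steps[: i + 1])
    return (
        f"{context}\nThink through whether the last step follows logically from "
        "the previous ones, then conclude with 'Mistake' or 'Not a mistake'."
    )
</pre>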
<p>
Our finding is in line with and builds upon <a href="https://arxiv.org/abs/2310.01798">prior results</a>, but goes further in showing that LLMs struggle with even simple and unambiguous mistakes (for comparison, our human raters without prior expertise solve the problem with a high degree of agreement). We hypothesize that this is a major reason why LLMs are unable to self-correct reasoning errors. See <a href="https://arxiv.org/abs/2311.08516">the paper</a> for the full results.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>2. Can mistake-finding be used as a proxy for correctness of the answer?</h3>
<p>
When we are confronted with a problem and are unsure of the answer, we can work through our solution step by step. If no error is found, we can assume that we got the right answer.
</p>
<p>
While we hypothesized that this would work similarly for LLMs, we discovered that this is a poor strategy. On our dataset of 85% incorrect traces and 15% correct traces, using this method is not much better than the naïve strategy of always labeling traces as incorrect, which gives a weighted average <a href="https://en.wikipedia.org/wiki/F-score">F1</a> of 78.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi07Eh2ZJrPHzVPJkom0V-tc51me104pXKoAmHrVaNwNgL8CW4QCIv4js4_aZzabllySdTx5vpHv_5T0NwKDB7nDcfHaNpx7C-fkoWKArltSWSWoXSTB5_4IPr2uOdjpsZKVMBfqVJUejyEuvy5SFC0y8933eBb6hxuvtbyoa-CcWfyQdDtBwBdMadk04JU/s698/Self-correcting-LLMs-Tasks.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="427" data-original-width="698" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi07Eh2ZJrPHzVPJkom0V-tc51me104pXKoAmHrVaNwNgL8CW4QCIv4js4_aZzabllySdTx5vpHv_5T0NwKDB7nDcfHaNpx7C-fkoWKArltSWSWoXSTB5_4IPr2uOdjpsZKVMBfqVJUejyEuvy5SFC0y8933eBb6hxuvtbyoa-CcWfyQdDtBwBdMadk04JU/s16000/Self-correcting-LLMs-Tasks.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">A diagram showing how well mistake-finding with LLMs can be used as a proxy for correctness of the answer on each dataset.</td></tr></tbody></table>
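<p>
As a quick sanity check of that baseline figure, the weighted-average F1 of a classifier that labels every trace as incorrect on an 85/15 split can be reproduced directly; here is a minimal sketch using scikit-learn:
</p>
<pre>
# Weighted-average F1 of the naive "always label the trace as incorrect" baseline
# on a dataset that is 85% incorrect traces and 15% correct traces.
from sklearn.metrics import f1_score

y_true = [1] * 85 + [0] * 15   # 1 = trace has an incorrect answer, 0 = correct answer
y_pred = [1] * 100             # naive baseline: predict "incorrect" every time

print(f1_score(y_true, y_pred, average="weighted", zero_division=0))  # ~0.78
</pre>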
<div style="line-height: 40%;">
<br />
</div>
<h3>3. Can LLMs backtrack knowing where the error is?</h3>
<p>
Since we’ve shown that LLMs exhibit poor performance in finding reasoning errors in CoT traces, we want to know whether LLMs can correct errors <em>at all</em>, even if they know where the error is.
</p>
<p>
Note that knowing the <em>mistake location</em> is different from knowing <em>the right answer</em>: CoT traces can contain logical mistakes even if the final answer is correct, or vice versa. In most real-world situations, we won’t know what the right answer is, but we might be able to identify logical errors in intermediate steps.
</p>
<p>
We propose the following backtracking method:
</p>
<ol>
<li>Generate CoT traces as usual, at temperature = 0. (Temperature is a parameter that controls the randomness of generated responses, with higher values producing more diverse and creative outputs, usually at the expense of quality.)
</li><li>Identify the location of the first logical mistake (for example with a classifier, or here we just use labels from our dataset).
</li><li>Re-generate the mistake step at temperature = 1 and produce a set of eight outputs. Since the original output is known to lead to incorrect results, the goal is to find an alternative generation at this step that is significantly different from the original.
</li><li>From these eight outputs, select one that is different from the original mistake step. (We just use exact matching here, but in the future this can be something more sophisticated.)
</li><li>Using the new step, generate the rest of the trace as normal at temperature = 0.
</li>
</ol>
<p>
It’s a very simple method that does not require any additional prompt crafting and avoids having to re-generate the entire trace. We test it using the mistake location data from BIG-Bench Mistake, and we find that it can correct CoT errors.
</p>
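<p>
A minimal sketch of this backtracking loop is shown below. The two helpers, <code>sample_step</code> and <code>complete_trace</code>, are hypothetical stand-ins for whatever LLM sampling interface is available; the logic simply mirrors the five numbered steps above.
</p>
<pre>
# Hedged sketch of the backtracking procedure described above.
#   sample_step(prefix, temperature, n): returns n candidate next reasoning steps
#   complete_trace(prefix, temperature): greedily generates the remaining steps

def backtrack(question, trace, mistake_idx, sample_step, complete_trace):
    """Replace the first mistaken step and re-generate the rest of the trace."""
    prefix = [question] + trace[:mistake_idx]    # everything before the mistake
    original_step = trace[mistake_idx]

    # Re-generate the mistaken step at temperature 1, producing eight samples.
    candidates = sample_step(prefix, temperature=1.0, n=8)

    # Keep a candidate that differs from the original step (exact string match).
    replacement = next((c for c in candidates if c != original_step), original_step)

    # Continue the rest of the trace as normal at temperature 0.
    rest = complete_trace(prefix + [replacement], temperature=0.0)
    return trace[:mistake_idx] + [replacement] + rest
</pre>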
<p>
<a href="https://arxiv.org/abs/2310.01798">Recent work</a> showed that self-correction methods, like <a href="https://arxiv.org/abs/2303.11366">Reflexion</a> and <a href="https://arxiv.org/abs/2303.17491">RCI</a>, cause deterioration in accuracy scores because there are more correct answers becoming incorrect than vice versa. Our method, on the other hand, produces more gains (by correcting wrong answers) than losses (by changing right answers to wrong answers).
</p>
<p>
We also compare our method with a random baseline, where we randomly designate a step as the mistake. Our results show that this random baseline does produce some gains, but not as many as backtracking with the correct mistake location, and it also incurs more losses.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi4CQ3amwJgMJ7OZToizp04vYIOj7F4kTdVO5DgthHi_HQQNa2FrkKjBA7LB249yN44kXFs0lZVuX1W4n3GLQI51Fy95ls-gK_rMLQQETYNmvqI7OS7U6xHgXx2cUnhTcwmZrpFWrS1vd01G14gWOexxZDTcpOIbGTHfXRgLWj22OteqMh_iTK2Rg_fC7xt/s744/image4.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="480" data-original-width="744" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi4CQ3amwJgMJ7OZToizp04vYIOj7F4kTdVO5DgthHi_HQQNa2FrkKjBA7LB249yN44kXFs0lZVuX1W4n3GLQI51Fy95ls-gK_rMLQQETYNmvqI7OS7U6xHgXx2cUnhTcwmZrpFWrS1vd01G14gWOexxZDTcpOIbGTHfXRgLWj22OteqMh_iTK2Rg_fC7xt/s16000/image4.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">A diagram showing the gains and losses in accuracy for our method as well as a random baseline on each dataset.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h3>4. Can mistake finding generalize to tasks the LLMs have never seen?</h3>
<p>
To answer this question, we fine-tuned a small model on four of the BIG-Bench tasks and tested it on the fifth, held-out task. We do this for every task, producing five fine-tuned models in total. Then we compare the results with just zero-shot prompting <a href="https://blog.google/technology/ai/google-palm-2-ai-large-language-model/">PaLM 2-L-Unicorn</a>, a much larger model.
</p>
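<p>
Concretely, the evaluation protocol is a leave-one-task-out loop over the five tasks; a minimal sketch is below, with <code>finetune</code> and <code>evaluate</code> as hypothetical stand-ins for the actual training and scoring code.
</p>
<pre>
# Leave-one-task-out evaluation of a small mistake-finding reward model.
# `finetune(examples)` and `evaluate(model, examples)` are hypothetical helpers.

def leave_one_task_out(dataset, finetune, evaluate):
    """Fine-tune on four of the five tasks, test on the held-out fifth, for every task."""
    tasks = sorted({ex["task"] for ex in dataset})
    results = {}
    for held_out in tasks:
        train = [ex for ex in dataset if ex["task"] != held_out]
        test = [ex for ex in dataset if ex["task"] == held_out]
        model = finetune(train)
        results[held_out] = evaluate(model, test)
    return results
</pre>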
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi0wDD7i6x7jVZpzCE3jQDkun0ros6zlSxrHP9wMEZAky3WiWaVB1U8ffguLlcl1vIrDy-8AxyZhxPlymeUas4FJaCqDQdQFW7YAXTGH6MaXwZs9SyrkE4Q4h1zlgFgblXwDmxTTR0uQugbOXK93s7uAE-Q4GbOBO5z94uby9KtOgc0rMBEU1qq4hQVYmuV/s759/image5.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="281" data-original-width="759" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi0wDD7i6x7jVZpzCE3jQDkun0ros6zlSxrHP9wMEZAky3WiWaVB1U8ffguLlcl1vIrDy-8AxyZhxPlymeUas4FJaCqDQdQFW7YAXTGH6MaXwZs9SyrkE4Q4h1zlgFgblXwDmxTTR0uQugbOXK93s7uAE-Q4GbOBO5z94uby9KtOgc0rMBEU1qq4hQVYmuV/s16000/image5.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><span style="text-align: left;">Bar chart showing the accuracy improvement of the fine-tuned small model compared to zero-shot prompting with PaLM 2-L-Unicorn.</span></td></tr></tbody></table>
<p>
Our results show that the much smaller fine-tuned reward model generally performs better than zero-shot prompting a large model, even though the reward model has never seen data from the task in the test set. The only exception is logical deduction, where it performs on par with zero-shot prompting.
</p>
<p>
This is a very promising result as we can potentially just use a small fine-tuned reward model to perform backtracking and improve accuracy on any task, even if we don’t have the data for it. This smaller reward model is completely independent of the generator LLM, and can be updated and further fine-tuned for individual use cases.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiC2HmNdzNKjZYLC_aPbRVHZmFs6_ko4Bs5CjljgTEeXipJsW0_H3HlbE-8TLCQK9GtYuUBT-liCpaZ5zi2dcbF1GkvhjouJbJBE9mVl1yUCJEZAVY8Gk8d-P_HlmeqxcPIpsKwSQeSE93LV1aimd_GuLA5VrWOYtfeLkLpEXXkrgJW5R6fV06_OBxJi6CI/s867/image3.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="446" data-original-width="867" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiC2HmNdzNKjZYLC_aPbRVHZmFs6_ko4Bs5CjljgTEeXipJsW0_H3HlbE-8TLCQK9GtYuUBT-liCpaZ5zi2dcbF1GkvhjouJbJBE9mVl1yUCJEZAVY8Gk8d-P_HlmeqxcPIpsKwSQeSE93LV1aimd_GuLA5VrWOYtfeLkLpEXXkrgJW5R6fV06_OBxJi6CI/s16000/image3.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">An illustration showing how our backtracking method works.</td></tr></tbody></table>
<br />
<h2>Conclusion</h2>
<p>
In this work, we created an evaluation benchmark dataset that the wider academic community can use to evaluate future LLMs. We further showed that LLMs currently struggle to find logical errors. However, when the mistake location is provided, backtracking is an effective strategy that delivers gains on these tasks. Finally, a smaller reward model can be trained on general mistake-finding tasks and used to improve out-of-domain mistake finding, showing that mistake finding can generalize.
</p>
<br />
<h2>Acknowledgements</h2>
<p>
<em>Thank you to Peter Chen, Tony Mak, Hassan Mansoor and Victor Cărbune for contributing ideas and helping with the experiments and data collection. We would also like to thank Sian Gooding and Vicky Zayats for their comments and suggestions on the paper.</em>
</p><br />Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-71972758764571610882024-01-08T14:07:00.000-08:002024-01-19T09:59:57.076-08:00Responsible AI at Google Research: User Experience Team<span class="byline-author">Posted by Ayça Çakmakli, UX Lead, Google Research, Responsible AI and Human Centered Technology Team</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhssQETGEGXjqvZARNVbKQATaocb0RDAkEOzrkJXtbqiJhZ0_hAAeb8zgOqbiGutvlvU1BTaE96y-0Mc6xXX-bP1-xyVa6xPsmmSwqxZlk_nn6UgmycZGYztCOgV1G3IKT9YCnmKFggXUmrEKFW1Y9NtfXNOHmSfaLoIxk8UxQked-9QDeDkSOCZMBaIYrL/s16000/hero.jpg" style="display: none;" />
<p>
Google’s Responsible AI User Experience (Responsible AI UX) team is a product-minded team embedded within Google Research. This unique positioning requires us to apply responsible AI development practices to our user-centered user experience (UX) design process. In this post, we describe the importance of UX design and responsible AI in product development, and share a few examples of how our team’s capabilities and cross-functional collaborations have led to responsible development across Google.
</p>
<a name='more'></a>
<p>
First, the UX part. We are a multi-disciplinary team of product design experts: designers, engineers, researchers, and strategists who manage the user-centered UX design process from early-phase ideation and problem framing to later-phase user-interface (UI) design, prototyping and refinement. We believe that effective product development occurs when there is clear alignment between significant unmet user needs and a product's primary value proposition, and that this alignment is reliably achieved via a thorough user-centered UX design process.
</p>
<p>
And second, recognizing generative AI’s (GenAI) potential to significantly impact society, we embrace our role as the primary user advocate as we continue to evolve our UX design process to meet the unique challenges AI poses, maximizing the benefits and minimizing the risks. As we navigate through each stage of an AI-powered product design process, we place a heightened emphasis on the ethical, societal, and long-term impact of our decisions. We contribute to the ongoing development of comprehensive <a href="https://ai.google/responsibility/ai-governance-operations">safety and inclusivity protocols</a> that define design and deployment guardrails around key issues like content curation, security, privacy, model capabilities, model access, equitability, and fairness that help mitigate GenAI risks.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgF6fWFT9CPcFgDfbPhNuCcrNiTCCSUlP1c0Dnr_sSYCnFt3J-7j3axB8sgk34-jdo6L7Xsp9XpNSz7_xp6uEZD5_GumzOt491oPcnWsbI74tkmBNh5QQ07ra2R-1CrgcnbhVexR48bt_YVRriqIGhF_qgO_PIDseseNvOYISz6bPHjYg5zBkS5FxeM08m-/s1920/image4.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1080" data-original-width="1920" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgF6fWFT9CPcFgDfbPhNuCcrNiTCCSUlP1c0Dnr_sSYCnFt3J-7j3axB8sgk34-jdo6L7Xsp9XpNSz7_xp6uEZD5_GumzOt491oPcnWsbI74tkmBNh5QQ07ra2R-1CrgcnbhVexR48bt_YVRriqIGhF_qgO_PIDseseNvOYISz6bPHjYg5zBkS5FxeM08m-/s16000/image4.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Responsible AI UX is constantly evolving its user-centered product design process to meet the needs of a GenAI-powered product landscape with greater sensitivity to the needs of users and society and an emphasis on ethical, societal, and long-term impact.</td></tr></tbody></table>
<p>
Responsibility in product design is also reflected in the user and societal problems we choose to address and the programs we resource. Thus, we encourage the <a href="https://blog.research.google/2023/11/emerging-practices-for-society-centered.html">prioritization of user problems with significant scale and severity</a> to help maximize the positive impact of GenAI technology.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjrU1niTXyrjzqlRSA8w6qyV4zKtquz0QP8faCIfMp9oPYYZjNCWe0pjW3z3AE9dLMHIR8OKVmzbUlU3oRW9POZnEpmlVGK8hws3E5sgNhj9cR7bY78gGteQN1ekl9LDz61s-WQjxTcPYnxfqO6RXs-Ax-dCOIe9vn-xdX-K9Hjta1PIFDHjlvrxbXeQSk9/s1920/image3.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1080" data-original-width="1920" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjrU1niTXyrjzqlRSA8w6qyV4zKtquz0QP8faCIfMp9oPYYZjNCWe0pjW3z3AE9dLMHIR8OKVmzbUlU3oRW9POZnEpmlVGK8hws3E5sgNhj9cR7bY78gGteQN1ekl9LDz61s-WQjxTcPYnxfqO6RXs-Ax-dCOIe9vn-xdX-K9Hjta1PIFDHjlvrxbXeQSk9/s16000/image3.png" /></a></td></tr></tbody></table>
<p>
Communication across teams and disciplines is essential to responsible product design. The seamless flow of information and insight from user research teams to product design and engineering teams, and vice versa, is essential to good product development. One of our team’s core objectives is to ensure the practical application of deep user-insight into AI-powered product design decisions at Google by bridging the communication gap between the vast technological expertise of our engineers and the user/societal expertise of our academics, research scientists, and user-centered design research experts. We’ve built a multidisciplinary team with expertise in these areas, deepening our empathy for the communication needs of our audience, and enabling us to better interface between our user & society experts and our technical experts. We create frameworks, guidebooks, prototypes, cheatsheets, and multimedia tools to help bring insights to life for the right people at the right time.</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjQrlbUi_jNRom8l7asIw8JHH12fjthtD_iPFDrWhxlItMd3L01hZpxdMx3bGEoRa3QUzHBcal0wvd3eLOKMyDZscFoIgfI4IHYdKkyaLxNhifdcl2ODH5nOV3VyYFtpCZaeze-zhCghpHY72da-OSdhvuTYrkqrYC0887_rVckCCTPCzOL-ZugV5oqHtlX/s1920/image5.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1080" data-original-width="1920" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjQrlbUi_jNRom8l7asIw8JHH12fjthtD_iPFDrWhxlItMd3L01hZpxdMx3bGEoRa3QUzHBcal0wvd3eLOKMyDZscFoIgfI4IHYdKkyaLxNhifdcl2ODH5nOV3VyYFtpCZaeze-zhCghpHY72da-OSdhvuTYrkqrYC0887_rVckCCTPCzOL-ZugV5oqHtlX/s16000/image5.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"></td></tr></tbody></table>
<br />
<div style="line-height: 40%;">
<br />
</div>
<h2>Facilitating responsible GenAI prototyping and development </h2>
<p>
During collaborations between Responsible AI UX, the <a href="https://ai.googleblog.com/2023/05/responsible-ai-at-google-research-pair.html">People + AI Research</a> (PAIR) initiative and <a href="https://labs.google/">Labs</a>, we identified that <a href="https://dl.acm.org/doi/abs/10.1145/3491101.3503564">prototyping</a> can afford a creative opportunity to engage with large language models (LLMs), and is often the first step in GenAI product development. To address the need to introduce LLMs into the prototyping process, we explored a range of different prompting designs. Then, we went out into the field, employing various external, first-person UX design research methodologies to draw out insight and gain empathy for the user’s perspective. Through user/designer co-creation sessions, iteration, and prototyping, we were able to bring internal stakeholders, product managers, engineers, writers, sales, and marketing teams along to ensure that the user point of view was well understood and to reinforce alignment across teams.
</p>
<p>
The result of this work was <a href="https://developers.googleblog.com/2023/03/announcing-palm-api-and-makersuite.html">MakerSuite</a>, a generative AI platform launched at <a href="https://blog.research.google/2023/05/google-research-at-io-2023.html">Google I/O 2023</a> that enables people, even those without any ML experience, to prototype creatively using LLMs. The team’s first-hand experience with users and understanding of the challenges they face allowed us to incorporate our <a href="https://ai.google/responsibility/principles/">AI Principles</a> into the MakerSuite product design. Product features like <a href="https://developers.generativeai.google/guide/safety_setting">safety filters</a>, for example, enable users to manage outcomes, leading to easier and more responsible product development with MakerSuite.
</p>
<p>
Because of our close collaboration with product teams, we were able to adapt text-only prototyping to support multimodal interaction with <a href="https://makersuite.google.com/app/prompts/new_freeform">Google AI Studio</a>, an evolution of MakerSuite. Now, Google AI Studio enables developers and non-developers alike to seamlessly leverage Google’s latest <a href="https://ai.google.dev/">Gemini</a> model to merge multiple modality inputs, like text and image, in product explorations. Facilitating product development in this way provides us with the opportunity to better use AI to identify <a href="https://arxiv.org/pdf/2310.15428.pdf">appropriateness of outcomes</a> and unlocks opportunities for developers and non-developers to play with AI sandboxes. Together with our partners, we continue to actively push this effort in the products we support.</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEht40iWgBGeov03IQ-OXhPgLMF8vC9NPSZag-3PyHJMRKHRoTKOAShqTJCueqijS_jdyd7kOMQa-PjmZBQg-EqyAn3f56tucPSLjvFEmaVTuS3Eo9YYq9gIV22MqFeL0qiU4AblFB46S46Y-czdfct-2dfPCh_NaH9zMHJwUlGWRYb37XLHj6_tvqAQrsV3/s1999/image6.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1406" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEht40iWgBGeov03IQ-OXhPgLMF8vC9NPSZag-3PyHJMRKHRoTKOAShqTJCueqijS_jdyd7kOMQa-PjmZBQg-EqyAn3f56tucPSLjvFEmaVTuS3Eo9YYq9gIV22MqFeL0qiU4AblFB46S46Y-czdfct-2dfPCh_NaH9zMHJwUlGWRYb37XLHj6_tvqAQrsV3/s16000/image6.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><a href="https://makersuite.google.com/app/prompts/new_freeform">Google AI studio</a> enables developers and non-developers to leverage Google Cloud infrastructure and merge multiple modality inputs in their product explorations.</td></tr></tbody></table>
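<p>
As an illustration of the kind of programmatic access that Google AI Studio exposes, the hedged sketch below uses the google-generativeai Python SDK to send a text-plus-image prompt with an explicit safety setting. The model name, safety-setting strings, and file path are assumptions based on the SDK's public documentation at the time of writing, not part of the work described in this post.
</p>
<pre>
# Hedged sketch: a multimodal Gemini API call with an explicit safety threshold.
# Model names and safety-setting strings are illustrative and may change.
import google.generativeai as genai
from PIL import Image

genai.configure(api_key="YOUR_API_KEY")  # key obtained from Google AI Studio

model = genai.GenerativeModel("gemini-pro-vision")
response = model.generate_content(
    [Image.open("product_sketch.png"), "Describe the product idea in this sketch."],
    safety_settings=[
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    ],
)
print(response.text)
</pre>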
<br />
<div style="line-height: 40%;">
<br />
</div>
<h2>Equitable speech recognition</h2>
<p>
Multiple <a href="https://www.pnas.org/doi/10.1073/pnas.1915768117">external studies</a>, as well as Google’s <a href="https://www.frontiersin.org/articles/10.3389/frai.2021.725911/full">own research,</a> have identified an unfortunate deficiency in the ability of current speech recognition technology to understand Black speakers on average, relative to White speakers. As multimodal AI tools begin to rely more heavily on speech prompts, this problem will grow and continue to alienate users. To address this problem, the Responsible AI UX team is <a href="https://blog.google/technology/research/project-elevate-black-voices-google-research/">partnering with world-renowned linguists and scientists at Howard University</a>, a prominent <a href="https://en.wikipedia.org/wiki/Historically_black_colleges_and_universities" target="_blank">HBCU</a>, to build a high quality African-American English dataset to improve the design of our speech technology products to make them more accessible. Called Project Elevate Black Voices, this effort will allow Howard University to share the dataset with those looking to improve speech technology while establishing a framework for responsible data collection, ensuring the data benefits Black communities. Howard University will retain the ownership and licensing of the dataset and serve as stewards for its responsible use. At Google, we’re providing funding support and collaborating closely with our partners at Howard University to ensure the success of this program.
</p>
<br />
<div class="separator" style="clear: both; text-align: center;">
<iframe allowfullscreen="" class="BLOG_video_class" frameborder="0" height="360" src="https://www.youtube.com/embed/t_pdlrU8qhs?si=5xY1AoGc_d2HTzQf" width="640" youtube-src-id="5xY1AoGc_d2HTzQf"></iframe>
</div>
<br />
<div style="line-height: 40%;">
<br />
</div>
<h2>Equitable computer vision</h2>
<p>
The <a href="http://gendershades.org/">Gender Shades</a> project highlighted that computer vision systems struggle to detect people with darker skin tones, and performed particularly poorly for women with darker skin tones. This is largely due to the fact that the datasets used to train these models were not inclusive to a wide range of skin tones. To address this limitation, the Responsible AI UX team has been partnering with sociologist <a href="https://www.ellismonk.com/">Dr. Ellis Monk</a> to release the <a href="https://blog.google/products/search/monk-skin-tone-scale/">Monk Skin Tone Scale</a> (MST), a skin tone scale designed to be more inclusive of the spectrum of skin tones around the world. It provides a tool to assess the inclusivity of datasets and model performance across an inclusive range of skin tones, resulting in features and products that work better for everyone.
</p>
<p>
We have integrated MST into a range of <a href="https://blog.google/products/search/monk-skin-tone-scale/">Google products</a>, such as Search, Google Photos, and others. We also open sourced MST, <a href="https://dl.acm.org/doi/pdf/10.1145/3632120">published our research</a>, <a href="https://blog.research.google/2023/05/consensus-and-subjectivity-of-skin-tone_15.html">described our annotation practices</a>, and <a href="https://skintone.google/mste-dataset">shared an example dataset</a> to encourage others to easily integrate it into their products. The Responsible AI UX team continues to collaborate with Dr. Monk, utilizing the MST across multiple product applications and continuing to do international research to ensure that it is globally inclusive.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Consulting & guidance</h2>
<p>
As teams across Google continue to develop products that leverage the capabilities of GenAI models, our team recognizes that the challenges they face are varied and that market competition is significant. To support teams, we develop actionable assets to facilitate a more streamlined and responsible product design process that considers available resources. We act as a product-focused design consultancy, identifying ways to scale services, share expertise, and apply our design principles more broadly. Our goal is to help all product teams at Google connect significant unmet user needs with technology benefits via great responsible product design.
</p>
<p>
One way we have been doing this is with the creation of the <a href="https://pair.withgoogle.com/guidebook/">People + AI Guidebook</a>, an evolving summative resource of many of the responsible design lessons we’ve learned and recommendations we’ve made for internal and external stakeholders. With its forthcoming, rolling <a href="https://medium.com/people-ai-research/updating-the-people-ai-guidebook-in-the-age-of-generative-ai-cace6c846db4">updates</a> focusing specifically on how to best design and consider user needs with GenAI, we hope that our internal teams, external stakeholders, and larger community will have useful and actionable guidance at the most critical milestones in the product development journey.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjOMslkRaMAkYwAUCOHS5tWTuCBQJ7sPNjkDbopWKKl21XD_1Q_VrK3tCctspa72hY63uOMZKipV5flHTW69S5bVjCVBvE8oMKmEax3VWNj7Wx20UlYRPZABdJhq0DJlegrtSIbQBxtkf18ygJtSxk2kY5f8L82WKEw9vLmiRDgrZiIXtsJ5RtbgvmISRw4/s1100/image1.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="622" data-original-width="1100" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjOMslkRaMAkYwAUCOHS5tWTuCBQJ7sPNjkDbopWKKl21XD_1Q_VrK3tCctspa72hY63uOMZKipV5flHTW69S5bVjCVBvE8oMKmEax3VWNj7Wx20UlYRPZABdJhq0DJlegrtSIbQBxtkf18ygJtSxk2kY5f8L82WKEw9vLmiRDgrZiIXtsJ5RtbgvmISRw4/s16000/image1.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">The People + AI Guidebook has six chapters, designed to cover different aspects of the product life cycle.</td></tr></tbody></table>
<p>
If you are interested in reading more about Responsible AI UX and how we are specifically thinking about designing responsibly with Generative AI, please check out this <a href="https://medium.com/people-ai-research/meet-ay%C3%A7a-%C3%A7akmakli-googles-new-head-of-responsible-ai-ux-d8f2700df95b">Q&A piece</a>.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgements</h2>
<p>
<em>Shout out to the Responsible AI UX team members: Aaron Donsbach, Alejandra Molina, Courtney Heldreth, Diana Akrong, Ellis Monk, Femi Olanubi, Hope Neveux, Kafayat Abdul, Key Lee, Mahima Pushkarna, Sally Limb, Sarah Post, Sures Kumar Thoddu Srinivasan, Tesh Goyal, Ursula Lauriston, and Zion Mengesha. Special thanks to Michelle Cohn for her contributions to this work. </em>
</p>Google AIhttp://www.blogger.com/profile/12098626514775266161noreply@blogger.com0tag:blogger.com,1999:blog-8474926331452026626.post-88161834733856381312023-12-22T10:37:00.000-08:002024-01-11T16:01:33.313-08:002023: A year of groundbreaking advances in AI and computing<span class="byline-author">Posted by Jeff Dean, Chief Scientist, Google DeepMind & Google Research, Demis Hassabis, CEO, Google DeepMind, and James Manyika, SVP, Google Research, Technology & Society</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiU12G_p2DZO4gQ-P95aP2IFvtKRHRG5Oik9VzJ4WtT_VznggjUrL4tTzqto-C8yx6ULgvugKgH9usiZxnKGk97pOFyHWnu-1S4sSbR2Jjq5T36tQRbZucTAP7gmXkGw77xN6s39IKxPaxbo5tw_Cq52ZfPWOCBkWL-2XTuzsIh6viRIqgGcdUdNq-pHjzl/s1100/year_in_review-hero.jpg" style="display: none;" />
<p>
This has been a year of incredible progress in the field of Artificial Intelligence (AI) research and its practical applications.
</p>
<p>
As ongoing research pushes AI even farther, we look back to our <a href="https://ai.google/static/documents/google-why-we-focus-on-ai.pdf">perspective</a> published in January of this year, titled “Why we focus on AI (and to what end),” where we noted:
</p>
<a name='more'></a>
<div style="margin-left: 40px;">
<p>
We are committed to leading and setting the standard in developing and shipping useful and beneficial applications, applying ethical principles grounded in human values, and evolving our approaches as we learn from research, experience, users, and the wider community.
</p>
<p>
We also believe that getting AI right — which to us involves innovating and delivering widely accessible benefits to people and society, while mitigating its risks — must be a collective effort involving us and others, including researchers, developers, users (individuals, businesses, and other organizations), governments, regulators, and citizens.
</p>
<p>
We are convinced that the AI-enabled innovations we are focused on developing and delivering boldly and responsibly are useful, compelling, and have the potential to assist and improve lives of people everywhere — this is what compels us.
</p>
</div>
<p>
In this Year-in-Review post we’ll go over some of Google Research's and Google DeepMind’s efforts putting these paragraphs into practice safely throughout 2023.
</p>
<br />
<h2>Advances in products & technologies </h2>
<p>
This was the year generative AI captured the world’s attention, creating imagery, music, stories, and engaging conversation about everything imaginable, at a level of creativity and a speed almost implausible a few years ago.
</p>
<p>
In February, we <a href="https://blog.google/technology/ai/bard-google-ai-search-updates/">first launched</a> <a href="https://bard.google.com">Bard</a>, a tool that you can use to explore creative ideas and explain things simply. It can generate text, translate languages, write different kinds of creative content and more.
</p>
<p>
In May, we watched the results of months and years of our foundational and applied work announced on stage <a href="https://blog.research.google/2023/05/google-research-at-io-2023.html">at Google I/O</a>. Principally, this included <a href="https://ai.google/discover/palm2/">PaLM 2</a>, a large language model (LLM) that brought together compute-optimal scaling, an improved dataset mixture, and model architecture to excel at advanced reasoning tasks.
</p>
<div class="separator" style="clear: both; text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg5R7E2f0BQIsiLZm25kxfR_Ix_Bvm6HlQQCXI12sB42siKUAf8eZvVEDU5bi8EQc22BoG6SV_H8NbC-PKd2pPv7FhC-uBR43ZWpbrgvaGJ7699j-uUctPbFBO9Bf-u81gkfU1OP4oGZhs6KKkub2znNsvcElREn5kwKh2npPJxFJdSVZZUIPZhyphenhyphen3B_jCq4/s1920/lockup_ic_PaLM-2_H_4297x745px_clr_@1x.jpg" style="margin-left: 1em; margin-right: 1em;"><img border="0" data-original-height="555" data-original-width="1920" height="116" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg5R7E2f0BQIsiLZm25kxfR_Ix_Bvm6HlQQCXI12sB42siKUAf8eZvVEDU5bi8EQc22BoG6SV_H8NbC-PKd2pPv7FhC-uBR43ZWpbrgvaGJ7699j-uUctPbFBO9Bf-u81gkfU1OP4oGZhs6KKkub2znNsvcElREn5kwKh2npPJxFJdSVZZUIPZhyphenhyphen3B_jCq4/w400-h116/lockup_ic_PaLM-2_H_4297x745px_clr_@1x.jpg" width="400" /></a></div>
<p>By fine-tuning and instruction-tuning PaLM 2 for different purposes, we were able to integrate it into numerous Google products and features, including:</p>
<ul>
<li>An update to Bard, which enabled multilingual capabilities. Since its initial launch, Bard is now available in more than <a href="https://support.google.com/bard/answer/13575153?hl=en">40 languages and over 230 countries and territories</a>, and <a href="https://blog.google/products/bard/google-bard-new-features-update-sept-2023/">with extensions</a>, Bard can find and show relevant information from Google tools used every day — like Gmail, Google Maps, YouTube, and more.
</li>
<li><a href="https://blog.google/products/search/generative-ai-search/">Search Generative Experience</a> (SGE), which uses LLMs to reimagine both how to organize information and how to help people navigate through it, creating a more fluid, conversational interaction model for our core Search product. This work extended the search engine experience from primarily focused on information retrieval into something much more — capable of retrieval, synthesis, creative generation and continuation of previous searches — while continuing to serve as a connection point between users and the web content they seek.
</li>
<li><a href="https://google-research.github.io/seanet/musiclm/examples/">MusicLM</a>, a text-to-music model powered by <a href="https://ai.googleblog.com/2022/10/audiolm-language-modeling-approach-to.html">AudioLM</a> and <a href="https://arxiv.org/abs/2208.12415">MuLAN</a>, which can make music from text, humming, images or video and musical accompaniments to singing.
</li>
<li>Duet AI, our AI-powered collaborator that provides users with assistance when they use Google Workspace and Google Cloud. <a href="https://workspace.google.com/blog/product-announcements/duet-ai">Duet AI in Google Workspace</a>, for example, helps users write, create images, analyze spreadsheets, draft and summarize emails and chat messages, and summarize meetings. <a href="https://cloud.google.com/blog/products/application-modernization/introducing-duet-ai-for-google-cloud">Duet AI in Google Cloud</a> helps users code, deploy, scale, and monitor applications, as well as identify and accelerate resolution of cybersecurity threats.
</li>
<li>And many <a href="https://blog.google/technology/developers/google-io-2023-100-announcements/">other developments</a>.
</li>
</ul>
<p>
In June, following last year’s release of our text-to-image generation model <a href="https://imagen.research.google/">Imagen</a>, we released <a href="https://blog.research.google/2023/06/imagen-editor-and-editbench-advancing.html">Imagen Editor</a>, which provides the ability to use region masks and natural language prompts to interactively edit generative images to provide much more precise control over the model output.
</p>
<div style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhfNnxCCKKLZgjajw30Ml1CiPruUIScAPwpa77LQ8Auqkf_KJh9rlJWhYRnQf1g4L0qhVDUJNHabkpL_ZW60FJs8XUWV1kT5M32YU-6oYBC5383noYqno-cYzUboAAOgXvDlWtqwu-zl94M2r02Fsid0jjgLBJl3JCVR4lc8ZVw7-c7q9OlHjzmrRXz5i5H/s1261/image4.png"><img border="0" data-original-height="306" data-original-width="1261" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhfNnxCCKKLZgjajw30Ml1CiPruUIScAPwpa77LQ8Auqkf_KJh9rlJWhYRnQf1g4L0qhVDUJNHabkpL_ZW60FJs8XUWV1kT5M32YU-6oYBC5383noYqno-cYzUboAAOgXvDlWtqwu-zl94M2r02Fsid0jjgLBJl3JCVR4lc8ZVw7-c7q9OlHjzmrRXz5i5H/s16000/image4.png" /></a></div>
<p>
Later in the year, we released Imagen 2, which improved outputs via a specialized image aesthetics model based on human preferences for qualities such as good lighting, framing, exposure, and sharpness.
</p>
<p>
In October, we launched a feature that <a href="https://blog.research.google/2023/10/google-search-can-now-help-with-english-speaking-practice.html">helps people practice speaking and improve their language skills</a>. The key technology that enabled this functionality was a novel deep learning model developed in collaboration with the Google Translate team, called Deep Aligner. This single new model has led to dramatic improvements in alignment quality across all tested language pairs, reducing average alignment error rate from 25% to 5% compared to alignment approaches based on <a href="https://aclanthology.org/C96-2141/">Hidden Markov models</a> (HMMs).
</p>
<p>
In November, in partnership with <a href="https://blog.youtube/inside-youtube/ai-and-music-experiment/">YouTube</a>, we announced <a href="https://deepmind.google/discover/blog/transforming-the-future-of-music-creation/">Lyria</a>, our most advanced AI music generation model to date. We released two experiments designed to open a new playground for creativity, DreamTrack and music AI tools, in concert with <a href="https://blog.youtube/inside-youtube/partnering-with-the-music-industry-on-ai/">YouTube’s Principles for partnering with the music industry on AI technology</a>.
</p>
<p>
Then in December, we launched <a href="https://blog.google/technology/ai/google-gemini-ai/">Gemini</a>, our most capable and general AI model. Gemini was built to be multimodal from the ground up across text, audio, images, and video. Our initial family of Gemini models comes in three different sizes: Nano, Pro, and Ultra. Nano models are our smallest and most efficient models for powering on-device experiences in products like Pixel. The Pro model is highly capable and best for scaling across a wide range of tasks. The Ultra model is our largest and most capable model for highly complex tasks.
</p>
<br />
<div class="separator" style="clear: both; text-align: center;"><a href="https://www.youtube.com/watch?v=jV1vkHv4zq8"><iframe allowfullscreen="" class="BLOG_video_class" height="360" src="https://www.youtube.com/embed/jV1vkHv4zq8" width="640" youtube-src-id="jV1vkHv4zq8"></iframe></a></div>
<br />
<p>In a <a href="https://storage.googleapis.com/deepmind-media/gemini/gemini_1_report.pdf">technical report</a> about <a href="https://deepmind.google/technologies/gemini">Gemini models</a>, we showed that Gemini Ultra’s performance exceeds current state-of-the-art results on 30 of the 32 widely-used academic benchmarks used in LLM research and development. With a score of 90.04%, Gemini Ultra was the first model to outperform human experts on <a href="https://arxiv.org/abs/2009.03300">MMLU</a>, and achieved a state-of-the-art score of 59.4% on the new <a href="https://arxiv.org/abs/2311.16502">MMMU</a> benchmark.
</p>
<p>
Building on <a href="https://deepmind.google/discover/blog/competitive-programming-with-alphacode/">AlphaCode</a>, the first AI system to perform at the level of the median competitor in competitive programming, we <a href="https://storage.googleapis.com/deepmind-media/AlphaCode2/AlphaCode2_Tech_Report.pdf">introduced AlphaCode 2</a>, powered by a specialized version of Gemini. When evaluated on the same platform as the original AlphaCode, we found that AlphaCode 2 solved 1.7x more problems, and performed better than 85% of competition participants.
</p>
<p>
At the same time, <a href="https://blog.google/products/bard/google-bard-try-gemini-ai/">Bard got its biggest upgrade</a> with its use of the Gemini Pro model, making it far more capable at things like understanding, summarizing, reasoning, coding, and planning. In six out of eight benchmarks, Gemini Pro outperformed GPT-3.5, including in MMLU, one of the key standards for measuring large AI models, and <a href="https://huggingface.co/datasets/gsm8k">GSM8K</a>, which measures grade school math reasoning. Gemini Ultra will come to Bard early next year through Bard Advanced, a new cutting-edge AI experience.
</p>
<p>
Gemini Pro is also available on <a href="https://cloud.google.com/blog/products/ai-machine-learning/gemini-support-on-vertex-ai">Vertex AI</a>, Google Cloud’s end-to-end AI platform that empowers developers to build applications that can process information across text, code, images, and video. <a href="https://blog.google/technology/ai/gemini-api-developers-cloud/">Gemini Pro was also made available in AI Studio</a> in December.
</p>
<p>
To best illustrate some of Gemini’s capabilities, we produced a <a href="https://deepmind.google/technologies/gemini/#hands-on">series of short videos</a> with explanations of how Gemini could:
</p>
<ul>
<li><a href="https://www.youtube.com/watch?v=sPiOP_CB54A">Unlock insights in scientific literature</a>
</li><li><a href="https://www.youtube.com/watch?v=LvGmVmHv69s&t=1s">Excel at competitive programming</a>
</li><li><a href="https://www.youtube.com/watch?v=D64QD7Swr3s">Process and understand raw audio</a>
</li><li><a href="https://www.youtube.com/watch?v=K4pX1VAxaAI">Explain reasoning in math and physics</a>
</li><li><a href="https://www.youtube.com/watch?v=v5tRc_5-8G4">Reason about user intent to generate bespoke experiences</a>
</li>
</ul>
<br />
<h2>ML/AI Research</h2>
<p>
In addition to our advances in products and technologies, we’ve also made a number of important advancements in the broader fields of machine learning and AI research.
</p>
<p>
At the heart of the most advanced ML models is the Transformer model architecture, <a href="https://blog.research.google/2017/08/transformer-novel-neural-network.html">developed by Google researchers in 2017</a>. Originally developed for language, it has proven useful in domains as varied as <a href="https://blog.research.google/2020/12/transformers-for-image-recognition-at.html">computer vision</a>, <a href="https://deepmind.google/discover/blog/transforming-the-future-of-music-creation/">audio</a>, <a href="https://deepmind.google/discover/blog/a-catalogue-of-genetic-mutations-to-help-pinpoint-the-cause-of-diseases/">genomics</a>, <a href="https://deepmind.google/technologies/alphafold/">protein folding</a>, and more. This year, our work on <a href="https://blog.research.google/2023/03/scaling-vision-transformers-to-22.html">scaling vision transformers</a> demonstrated state-of-the-art results across a wide variety of vision tasks, and has also been useful in building <a href="https://blog.research.google/2023/03/palm-e-embodied-multimodal-language.html">more capable robots</a>.
</p>
<p>
Expanding the versatility of models requires the ability to perform higher-level and multi-step reasoning. This year, we approached this target following several research tracks. For example, <a href="https://blog.research.google/2023/08/teaching-language-models-to-reason.html">algorithmic prompting</a> is a new method that teaches language models reasoning by demonstrating a sequence of algorithmic steps, which the model can then apply in new contexts. This approach improves accuracy on one middle-school mathematics benchmark from 25.9% to 61.1%.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhqMgWRH7DtSwqbAImqRRsW26oyKnDiinTNvtkUuvASZJSaChsNXG1-4EeDkTr22E7xjRzwcFdWCZSKFuuBoLfsZiH27pZ1d6XMef8ns6RGx619oZnHdeCVZb7EOPWigNqbGsmu4FrU2Xgampr0HIASv7ks8ha9DE8L3hmAhKU8_Aps8L_1evceD2MKyG23/s1200/image5.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="166" data-original-width="1200" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhqMgWRH7DtSwqbAImqRRsW26oyKnDiinTNvtkUuvASZJSaChsNXG1-4EeDkTr22E7xjRzwcFdWCZSKFuuBoLfsZiH27pZ1d6XMef8ns6RGx619oZnHdeCVZb7EOPWigNqbGsmu4FrU2Xgampr0HIASv7ks8ha9DE8L3hmAhKU8_Aps8L_1evceD2MKyG23/s16000/image5.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><em style="text-align: left;">By providing algorithmic prompts, we can teach a model the rules of arithmetic via in-context learning.</em></td></tr></tbody></table>
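<p>
To give a flavor of what such a prompt looks like, below is a simplified, paraphrased example of an in-context demonstration that spells out each algorithmic step of multi-digit addition; it is illustrative rather than copied from the paper.
</p>
<pre>
# Illustrative (paraphrased) algorithmic prompt for addition: every intermediate
# step, including carries, is spelled out so the model can imitate the procedure
# on the new problem appended at the end.
ALGORITHMIC_PROMPT = """\
Problem: 48 + 76
Step 1: the ones digits are 8 and 6; 8 + 6 = 14, write 4, carry 1.
Step 2: the tens digits are 4 and 7, plus the carry 1; 4 + 7 + 1 = 12, write 2, carry 1.
Step 3: no digits remain, so write the final carry 1.
Answer: 124

Problem: 29 + 57
"""
# The model is expected to continue with the same step-by-step procedure,
# e.g. "Step 1: the ones digits are 9 and 7; 9 + 7 = 16, write 6, carry 1. ..."
</pre>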
<p>
In the domain of visual question answering, in a collaboration with UC Berkeley researchers, we showed how we could <a href="https://blog.research.google/2023/07/modular-visual-question-answering-via.html">better answer complex visual questions</a> (“Is the carriage to the right of the horse?”) by combining a visual model with a language model trained to answer visual questions by synthesizing a program to perform multi-step reasoning.
</p>
<p>
We are now using a <a href="https://blog.research.google/2023/05/large-sequence-models-for-software.html">general model that understands many aspects of the software development life cycle</a> to automatically generate code review comments, respond to code review comments, make performance-improving suggestions for pieces of code (by learning from past such changes in other contexts), fix code in response to compilation errors, and more.
</p>
<p>
In a multi-year research collaboration with the Google Maps team, we were able to scale inverse reinforcement learning and apply it to the <a href="https://blog.research.google/2023/09/world-scale-inverse-reinforcement.html">world-scale problem of improving route suggestions</a> for over 1 billion users. Our work culminated in a 16–24% relative improvement in global route match rate, helping to ensure that routes are better aligned with user preferences.
</p>
<p>
We also continue to work on techniques to improve the inference performance of machine learning models. In work on <a href="https://blog.research.google/2023/08/neural-network-pruning-with.html">computationally-friendly approaches to pruning connections in neural networks</a>, we were able to devise an approximation algorithm to the computationally intractable best-subset selection problem that is able to prune 70% of the edges from an image classification model and still retain almost all of the accuracy of the original.
</p>
<p>
In work on <a href="https://blog.research.google/2023/06/speed-is-all-you-need-on-device.html">accelerating on-device diffusion models</a>, we were also able to apply a variety of optimizations to attention mechanisms, convolutional kernels, and fusion of operations to make it practical to run high quality image generation models on-device; for example, enabling “a photorealistic and high-resolution image of a cute puppy with surrounding flowers” to be generated in just 12 seconds on a smartphone.
</p>
<br />
<div class="separator" style="clear: both; text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhM5NuphyphenhyphenlH7_H5PfABZrn1a-CuFJm8XyEpAlodQqpcrTvYj3XNWarRULjbSgKqY1trCsC0hbVvHTU9l4pUyn3vFupsfyLDVgdMgsTYHtE1b8AWQwZWUgXGAAJ_F12rcOqVDjbe1q5OX0TdYEQBOt-FpKIqvfYivuAcPU3bVTZcYxUdFOdtaW5JE6Fii7ej/s522/image7.gif" style="margin-left: 1em; margin-right: 1em;"><img border="0" data-original-height="522" data-original-width="270" height="400" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhM5NuphyphenhyphenlH7_H5PfABZrn1a-CuFJm8XyEpAlodQqpcrTvYj3XNWarRULjbSgKqY1trCsC0hbVvHTU9l4pUyn3vFupsfyLDVgdMgsTYHtE1b8AWQwZWUgXGAAJ_F12rcOqVDjbe1q5OX0TdYEQBOt-FpKIqvfYivuAcPU3bVTZcYxUdFOdtaW5JE6Fii7ej/w208-h400/image7.gif" width="208" /></a></div>
<br />
<p>Advances in capable language and multimodal models have also benefited our robotics research efforts. We combined separately trained language, vision, and robotic control models into <a href="https://blog.research.google/2023/03/palm-e-embodied-multimodal-language.html">PaLM-E</a>, an embodied multi-modal model for robotics, and <a href="https://deepmind.google/discover/blog/rt-2-new-model-translates-vision-and-language-into-action/">Robotic Transformer 2</a> (RT-2), a novel vision-language-action (VLA) model that <a href="https://deepmind.google/discover/blog/robocat-a-self-improving-robotic-agent/">learns</a> from both web and robotics data, and translates this knowledge into generalized instructions for robotic control.</p>
<p>
</p><table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEijJL6HUJ6swdYZlAdRsPttHH9EmdE7TpGlK92U9hxHu29ANiHMQLC3QB1PX8HFWwatiJ6V-rjeImQ67oRUGtQMR4TDXB7mOVqsz-BouN1y29vbUU7rc-nTkj2H-V0V3VNCTMujhLSyNM3duUEVOYhm0nf8wBVrIghfEdpRsKtHn_gg2dWVxzzSXTk6ROTb/s616/image8.jpg" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="559" data-original-width="616" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEijJL6HUJ6swdYZlAdRsPttHH9EmdE7TpGlK92U9hxHu29ANiHMQLC3QB1PX8HFWwatiJ6V-rjeImQ67oRUGtQMR4TDXB7mOVqsz-BouN1y29vbUU7rc-nTkj2H-V0V3VNCTMujhLSyNM3duUEVOYhm0nf8wBVrIghfEdpRsKtHn_gg2dWVxzzSXTk6ROTb/s16000/image8.jpg" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><em style="text-align: left;">RT-2 architecture and training: We co-fine-tune a pre-trained vision-language model on robotics and web data. The resulting model takes in robot camera images and directly predicts actions for a robot to perform.</em></td></tr></tbody></table>
<p>
Furthermore, we showed how <a href="https://blog.research.google/2023/08/saytap-language-to-quadrupedal.html">language can also be used to control the gait of quadrupedal robots</a> and explored the <a href="https://blog.research.google/2023/08/language-to-rewards-for-robotic-skill.html">use of language to help formulate more explicit reward functions</a> to bridge the gap between human language and robotic actions. Then, in <a href="https://blog.research.google/2023/05/barkour-benchmarking-animal-level.html">Barkour</a> we benchmarked the agility limits of quadrupedal robots.</p>
<br />
<h2>Algorithms & optimization</h2>
<p>
Designing efficient, robust, and scalable algorithms remains a high priority. This year, our work included: applied and scalable algorithms, market algorithms, system efficiency and optimization, and privacy.
</p>
<p>
We introduced <a href="https://deepmind.google/discover/blog/alphadev-discovers-faster-sorting-algorithms/">AlphaDev</a>, an AI system that uses reinforcement learning to discover enhanced computer science algorithms. AlphaDev uncovered a faster algorithm for sorting, a method for ordering data, which led to improvements in the LLVM libc++ sorting library that were up to 70% faster for shorter sequences and about 1.7% faster for sequences exceeding 250,000 elements.
</p>
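<p>
For intuition on why specialized routines for very short sequences matter, the sketch below sorts three elements with a fixed comparator network in a branch-free style. This is only an illustration of the general technique: the AlphaDev-discovered improvements are low-level instruction sequences inside LLVM’s libc++, not this Python code.
</p>
<pre>
def sort3(a, b, c):
    # A fixed three-element sorting network: each comparator maps a pair
    # (x, y) to (min, max), applied in a predetermined order.
    b, c = min(b, c), max(b, c)
    a, c = min(a, c), max(a, c)
    a, b = min(a, b), max(a, b)
    return a, b, c

assert sort3(3, 1, 2) == (1, 2, 3)
assert sort3(2, 2, 1) == (1, 2, 2)
</pre>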
<p>
We developed a novel model to <a href="https://arxiv.org/abs/2305.12322">predict the properties of large graphs</a>, enabling estimation of performance for large programs. We released a new dataset, <a href="https://arxiv.org/abs/2308.13490">TPUGraphs</a>, to accelerate <a href="https://www.kaggle.com/competitions/predict-ai-model-runtime">open research in this area</a>, and showed how we can use <a href="https://blog.research.google/2023/12/advancements-in-machine-learning-for.html">modern ML to improve ML efficiency</a>.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjC00uLmFXRBzNXoaXirQMkAV7j91dBVwgvY4BEFuOLP4V9MquLuKt9imKFHBHsc7laEKUIuXUe_-v0DJyCauIQMrOzUaSOKl15hxBeAbgLGPWNYehM7Z8seK3P9JcjyMmeSZNyXWMYNME84KZwIygF1deRSpaZ1oBeK-uFnxEqCJcm6M0z4hA18JVGew-H/s1223/image11.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="1042" data-original-width="1223" height="341" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjC00uLmFXRBzNXoaXirQMkAV7j91dBVwgvY4BEFuOLP4V9MquLuKt9imKFHBHsc7laEKUIuXUe_-v0DJyCauIQMrOzUaSOKl15hxBeAbgLGPWNYehM7Z8seK3P9JcjyMmeSZNyXWMYNME84KZwIygF1deRSpaZ1oBeK-uFnxEqCJcm6M0z4hA18JVGew-H/w400-h341/image11.png" width="400" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><em style="text-align: left;">The TPUGraphs dataset has 44 million graphs for ML program optimization.</em></td></tr></tbody></table>
<p>We developed a new <a href="https://en.wikipedia.org/wiki/Load_balancing_(computing)">load balancing</a> algorithm for distributing queries to a server, called <a href="https://arxiv.org/abs/2312.10172">Prequal</a>, which minimizes a combination of the number of requests in flight and the estimated latency. Deployments across several systems have yielded significant savings in CPU, latency, and RAM. We also designed a new <a href="https://arxiv.org/abs/2305.02508">analysis framework</a> for the classical caching problem with capacity reservations.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgRpdeenR8s32zMYfc6hmCq1OKu_Fk8NnhymDBWweQky1o9OIqARGqtzA-SUPAX1UZVQsrvPDXXbr20ZBz70RNBS3njSwCtkGYmBbWYOV7J87xFTMXCDRoiOh4EtGqf_aKBtegTJrbyru3ompVpfYzMx6EKWX26xHs_POZJ2dd3lbvDX-JggUoXFDn-HDJa/s1896/image12.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="422" data-original-width="1896" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgRpdeenR8s32zMYfc6hmCq1OKu_Fk8NnhymDBWweQky1o9OIqARGqtzA-SUPAX1UZVQsrvPDXXbr20ZBz70RNBS3njSwCtkGYmBbWYOV7J87xFTMXCDRoiOh4EtGqf_aKBtegTJrbyru3ompVpfYzMx6EKWX26xHs_POZJ2dd3lbvDX-JggUoXFDn-HDJa/s16000/image12.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><em style="text-align: left;">Heatmaps of normalized CPU usage transitioning to <a href="https://arxiv.org/abs/2312.10172">Prequal</a> at 08:00.</em></td></tr></tbody></table>
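<p>
A simplified sketch of probe-based selection is shown below: sample a few replicas and route the query to the one with the best combined score of in-flight requests and estimated latency. The <code>Replica</code> class, <code>pick_replica</code> helper, and the linear weighting are illustrative placeholders, not the actual Prequal implementation or its selection rule.
</p>
<pre>
import random
from dataclasses import dataclass

@dataclass
class Replica:
    name: str
    in_flight: int = 0                 # requests currently being served
    latency_estimate_ms: float = 10.0  # e.g., from a recent probe

def pick_replica(replicas, num_probes=3, latency_weight=0.5):
    # Probe a few replicas and pick the one with the best combined score of
    # in-flight load and estimated latency. The weighting is illustrative;
    # the real Prequal selection rule is more involved.
    probed = random.sample(replicas, min(num_probes, len(replicas)))
    return min(probed, key=lambda r: r.in_flight + latency_weight * r.latency_estimate_ms)

# Example: route one query.
pool = [Replica("a", 4, 12.0), Replica("b", 1, 30.0), Replica("c", 0, 9.0)]
target = pick_replica(pool)
target.in_flight += 1
print("routing query to", target.name)
</pre>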
<p>
We improved the state of the art in clustering and <a href="https://en.wikipedia.org/wiki/Graph_neural_network">graph algorithms</a> by developing new techniques for <a href="https://arxiv.org/abs/2106.05513">computing minimum-cut</a>, <a href="https://arxiv.org/abs/2309.17243">approximating correlation clustering</a>, and <a href="https://arxiv.org/abs/2308.00503">massively parallel graph clustering</a>. Additionally, we introduced <a href="https://arxiv.org/abs/2308.03578">TeraHAC</a>, a novel hierarchical clustering algorithm for trillion-edge graphs, designed a <a href="https://blog.research.google/2023/11/best-of-both-worlds-achieving.html">text clustering algorithm</a> for better scalability while maintaining quality, and designed the most efficient <a href="https://arxiv.org/abs/2307.03043">algorithm for approximating the Chamfer Distance</a>, the standard similarity function for multi-embedding models, offering >50× speedups over highly optimized exact algorithms and scaling to billions of points.</p>
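<p>
For reference, the one-directional Chamfer distance between two embedding sets sums, for each point in the first set, its distance to the nearest point in the second set. The brute-force NumPy sketch below computes it exactly in quadratic time, which is precisely the cost the approximation algorithm avoids; the approximation itself is not reproduced here.
</p>
<pre>
import numpy as np

def chamfer_distance(A, B):
    # Exact (one-directional) Chamfer distance: for each embedding in A,
    # find its nearest neighbor in B and sum those distances. This is
    # O(|A|*|B|); the cited work approximates it much faster.
    diffs = A[:, None, :] - B[None, :, :]          # shape (|A|, |B|, dim)
    dists = np.linalg.norm(diffs, axis=-1)         # pairwise Euclidean distances
    return dists.min(axis=1).sum()

rng = np.random.default_rng(0)
A = rng.normal(size=(100, 64))   # 100 query embeddings, 64 dimensions
B = rng.normal(size=(500, 64))   # 500 target embeddings
print(chamfer_distance(A, B))
</pre>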
<p>
We continued optimizing Google’s large embedding models (LEMs), which power many of our core products and recommender systems. Some new techniques include <a href="https://arxiv.org/abs/2305.12102">Unified Embedding</a> for battle-tested feature representations in web-scale ML systems and <a href="https://arxiv.org/abs/2209.14881">Sequential Attention</a>, which uses attention mechanisms to discover high-quality sparse model architectures during training.
</p>
<p>
This year, we also continued our research in market algorithms to design computationally efficient marketplaces and to advance causal inference. In the rapidly growing area of ads automation, our recent work <a href="https://dl.acm.org/doi/abs/10.1145/3543507.3583416">explains the adoption of autobidding mechanisms</a> and <a href="https://dl.acm.org/doi/abs/10.1145/3580507.3597725">examines the effect of different auction formats on the incentives of advertisers</a>. In the multi-channel setting, our findings shed light on how the choice between local and global optimization affects the design of multi-channel <a href="https://dl.acm.org/doi/abs/10.1145/3580507.3597707">auction systems</a> and <a href="https://dl.acm.org/doi/10.5555/3618408.3618709">bidding systems</a>.
</p>
<p>
Beyond auto-bidding systems, we also studied auction design in other complex settings, such as <a href="https://arxiv.org/abs/2204.01962">buy-many mechanisms</a>, <a href="https://arxiv.org/abs/2207.09429">auctions for heterogeneous bidders</a>, and <a href="https://arxiv.org/abs/2309.10766">contract designs</a>, and developed <a href="https://dl.acm.org/doi/10.5555/3618408.3618478">robust online bidding algorithms</a>. Motivated by the application of generative AI in collaborative creation (e.g., a joint ad for advertisers), we proposed <a href="https://arxiv.org/abs/2310.10826">a novel token auction model</a> where LLMs bid for influence in the collaborative AI creation. Finally, we showed how to <a href="https://dl.acm.org/doi/pdf/10.1145/3580507.3597702">mitigate personalization effects in experimental design</a>, which, for example, may cause recommendations to drift over time.
</p>
<p>
The Chrome Privacy Sandbox, a multi-year collaboration between Google Research and Chrome, has publicly launched several APIs, including for <a href="https://privacysandbox.com/intl/en_us/learning-hub/#protected-audience">Protected Audience</a>, <a href="https://privacysandbox.com/intl/en_us/learning-hub/#topics">Topics</a>, and <a href="https://privacysandbox.com/intl/en_us/learning-hub/#attribution-reporting">Attribution Reporting</a>. This is a major step in protecting user privacy while supporting the open and free web ecosystem. These efforts have been facilitated by fundamental research on <a href="https://arxiv.org/abs/2304.07210">re-identification risk</a>, <a href="https://arxiv.org/abs/2301.05605">private streaming computation</a>, <a href="https://blog.research.google/2023/12/summary-report-optimization-in-privacy.html">optimization</a> of privacy caps and budgets, <a href="https://arxiv.org/pdf/2308.13510.pdf">hierarchical aggregation</a>, and training models with <a href="https://arxiv.org/pdf/2312.05659.pdf">label privacy</a>.
</p>
<br />
<h2>Science and society</h2>
<p>
In the not too distant future, there is a very real possibility that AI applied to scientific problems can accelerate the rate of discovery in certain domains by 10× or 100×, or more, and lead to major advances in diverse areas including bioengineering, <a href="https://deepmind.google/discover/blog/millions-of-new-materials-discovered-with-deep-learning">materials science</a>, <a href="https://deepmind.google/discover/blog/graphcast-ai-model-for-faster-and-more-accurate-global-weather-forecasting/">weather prediction</a>, <a href="https://blog.google/outreach-initiatives/sustainability/google-ai-climate-change-solutions/">climate forecasting</a>, <a href="https://blog.research.google/2023/09/google-research-embarks-on-effort-to.html">neuroscience</a>, <a href="https://blog.research.google/2023/04/an-ml-based-approach-to-better.html">genetic medicine</a>, and <a href="https://health.google/health-research/publications/">healthcare</a>.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>Sustainability and climate change </h3>
<p>
In <a href="https://blog.google/outreach-initiatives/sustainability/google-ai-reduce-greenhouse-emissions-project-greenlight/">Project Green Light</a>, we partnered with 13 cities around the world to help improve traffic flow at intersections and reduce stop-and-go emissions. Early numbers from these partnerships indicate a potential for up to 30% reduction in stops and up to 10% reduction in emissions.
</p>
<p>
In our <a href="https://sites.research.google/contrails/">contrails work</a>, we analyzed large-scale weather data, historical satellite images, and past flights. We <a href="https://blog.google/technology/ai/ai-airlines-contrails-climate-change/">trained an AI model</a> to predict where contrails form and to reroute airplanes accordingly. In partnership with American Airlines and Breakthrough Energy, we used this system to demonstrate a 54% reduction in contrails.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjnU7wUIVSDG3HicS_tHszwAD_RlhU_SXJO0quz5CUJ0RLT03erljh8ckLW8NLAlYtOPpX6lzsohacC7x2X-_2abBGDGWGpIN0KZpYJ0YH3FOzRDhq-zVTeXne4LjpKGPoDtNtljKkee2R4hE3ju3wYhltF5q8oaPZ1I9R39eB2uYBnFfRxjka1vCg8H5Vv/s957/image14.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="400" data-original-width="957" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjnU7wUIVSDG3HicS_tHszwAD_RlhU_SXJO0quz5CUJ0RLT03erljh8ckLW8NLAlYtOPpX6lzsohacC7x2X-_2abBGDGWGpIN0KZpYJ0YH3FOzRDhq-zVTeXne4LjpKGPoDtNtljKkee2R4hE3ju3wYhltF5q8oaPZ1I9R39eB2uYBnFfRxjka1vCg8H5Vv/s16000/image14.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><em style="text-align: left;">Contrails detected over the United States using AI and GOES-16 satellite imagery.</em></td></tr></tbody></table>
<p>
We are also developing novel technology-driven approaches to <a href="https://blog.google/outreach-initiatives/sustainability/google-ai-climate-change-solutions/">help communities with the effects of climate change</a>. For example, we have <a href="https://blog.google/outreach-initiatives/sustainability/flood-hub-ai-flood-forecasting-more-countries/">expanded our flood forecasting coverage to 80 countries</a>, which directly impacts more than 460 million people. We have initiated a <a href="https://blog.research.google/2023/10/looking-back-at-wildfire-research-in.html">number of research efforts</a> to help mitigate the increasing danger of wildfires, including <a href="https://blog.research.google/2023/02/real-time-tracking-of-wildfire.html">real-time tracking of wildfire boundaries</a> using satellite imagery, and work that <a href="https://blog.research.google/2023/10/improving-traffic-evacuations-case-study.html">improves emergency evacuation plans</a> for communities at risk from rapidly spreading wildfires. Our <a href="https://www.americanforests.org/article/american-forests-unveils-updates-for-tree-equity-score-tool-to-address-climate-justice/">partnership</a> with American Forests puts data from our <a href="https://insights.sustainability.google/places/ChIJVTPokywQkFQRmtVEaUZlJRA/trees?hl=en-US">Tree Canopy</a> project to work in their <a href="https://treeequityscore.org/">Tree Equity Score</a> platform, helping communities identify and address unequal access to trees.</p>
<p>
Finally, we continued to develop better models for weather prediction at longer time horizons. Improving on <a href="https://blog.research.google/2020/03/a-neural-weather-model-for-eight-hour.html">MetNet</a> and <a href="https://blog.research.google/2021/11/metnet-2-deep-learning-for-12-hour.html">MetNet-2</a>, in this year’s work on <a href="https://blog.research.google/2023/11/metnet-3-state-of-art-neural-weather.html">MetNet-3</a>, we now outperform traditional numerical weather simulations for forecasts up to twenty-four hours ahead. In the area of medium-range global weather forecasting, our work on <a href="https://deepmind.google/discover/blog/graphcast-ai-model-for-faster-and-more-accurate-global-weather-forecasting/">GraphCast</a> showed significantly better prediction accuracy for up to 10 days compared to <a href="https://en.wikipedia.org/wiki/Integrated_Forecast_System">HRES</a>, the most accurate operational deterministic forecast, produced by the <a href="https://www.ecmwf.int/">European Centre for Medium-Range Weather Forecasts</a> (ECMWF). In collaboration with ECMWF, we released <a href="https://blog.research.google/2023/08/weatherbench-2-benchmark-for-next.html">WeatherBench-2</a>, a benchmark for evaluating the accuracy of weather forecasts in a common framework.
</p>
<br />
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><iframe allowfullscreen="" class="BLOG_video_class" height="360" src="https://www.youtube.com/embed/Q6fOlW-Y_Ss" width="640" youtube-src-id="Q6fOlW-Y_Ss"></iframe></td></tr><tr><td class="tr-caption" style="text-align: center;">A selection of GraphCast’s predictions rolling across 10 days showing specific humidity at 700 hectopascals (about 3 km above surface), surface temperature, and surface wind speed.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h3>Health and the life sciences</h3>
<p>
The potential of AI to dramatically improve processes in healthcare is significant. Our initial <a href="https://www.nature.com/articles/s41586-023-06291-2">Med-PaLM</a> model was the first model capable of achieving a passing score on the U.S. medical licensing exam. Our more recent <a href="https://blog.google/technology/health/ai-llm-medpalm-research-thecheckup/">Med-PaLM 2 model</a> improved by a further 19%, achieving an expert-level accuracy of 86.5%. These <a href="https://sites.research.google/med-palm/">Med-PaLM models</a> are language-based, enable clinicians to ask questions and have a dialogue about complex medical conditions, and are <a href="https://cloud.google.com/blog/topics/healthcare-life-sciences/introducing-medlm-for-the-healthcare-industry">available</a> to healthcare organizations as part of <a href="https://cloud.google.com/vertex-ai/docs/generative-ai/medlm/overview">MedLM</a> through Google Cloud.
</p>
<p>
In the same way our general language models are evolving to handle multiple modalities, we have recently shown research on a <a href="https://blog.research.google/2023/08/multimodal-medical-ai.html">multimodal version of Med-PaLM</a> capable of interpreting medical images, textual data, and other modalities, <a href="https://arxiv.org/abs/2307.14334">describing a path</a> for how we can realize the exciting potential of AI models to help advance real-world clinical care.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgIHNosDcouStVqnHjoNr_b7YQQGCEXsxlCOy1ltKotOv5GQfl6JA9We90L6Ej3TZ2tiutOrASoon-BPt__Fh9jN2NiXY1W6z5emcW63JkkS7sS1LrsMz7mINkzvSuD_i4NR3GdRbsQFlVrNrC3cv1bNHETISJ_Ml0n7ddXbtHmO_AzfSd8EPq7-ud7iBNH/s1600/image2.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="900" data-original-width="1600" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgIHNosDcouStVqnHjoNr_b7YQQGCEXsxlCOy1ltKotOv5GQfl6JA9We90L6Ej3TZ2tiutOrASoon-BPt__Fh9jN2NiXY1W6z5emcW63JkkS7sS1LrsMz7mINkzvSuD_i4NR3GdRbsQFlVrNrC3cv1bNHETISJ_Ml0n7ddXbtHmO_AzfSd8EPq7-ud7iBNH/s16000/image2.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><em style="text-align: left;">Med-PaLM M is a large multimodal generative model that flexibly encodes and interprets biomedical data including clinical language, imaging, and genomics with the same model weights.</em></td></tr></tbody></table>
<p>We have also been working on <a href="https://deepmind.google/discover/blog/codoc-developing-reliable-ai-tools-for-healthcare/">how best to harness AI models in clinical workflows</a>. We have shown that <a href="https://blog.research.google/2023/03/learning-from-deep-learning-case-study.html">coupling deep learning with interpretability methods</a> can yield new insights for clinicians. We have also shown that self-supervised learning, with careful consideration of privacy, safety, fairness and ethics, <a href="https://blog.research.google/2023/04/robust-and-efficient-medical-imaging.html">can reduce the amount of de-identified data needed</a> to train clinically relevant medical imaging models by 3×–100×, reducing the barriers to adoption of models in real clinical settings. We also released an <a href="https://blog.research.google/2023/11/enabling-large-scale-health-studies-for.html">open source mobile data collection platform</a> for people with chronic disease to provide tools to the community to build their own studies.</p>
<p>
AI systems can also discover completely new signals and biomarkers in existing forms of medical data. In work on <a href="https://blog.research.google/2023/03/detecting-novel-systemic-biomarkers-in.html">novel biomarkers discovered in retinal images</a>, we demonstrated that a number of systemic biomarkers spanning several organ systems (e.g., kidney, blood, liver) can be predicted from external eye photos. In other work, we showed that combining <a href="https://blog.research.google/2023/04/developing-aging-clock-using-deep.html">retinal images and genomic information</a> helps identify some underlying factors of aging.
</p>
<p>
In the genomics space, we worked with 119 scientists across 60 institutions to create a <a href="https://blog.research.google/2023/05/building-better-pangenomes-to-improve.html">new map of the human genome</a>, or pangenome. This more equitable pangenome better represents the genomic diversity of global populations. Building on our ground-breaking <a href="https://www.nature.com/articles/s41586-021-03819-2">AlphaFold</a> work, our work on <a href="https://deepmind.google/discover/blog/a-catalogue-of-genetic-mutations-to-help-pinpoint-the-cause-of-diseases/">AlphaMissense</a> this year provides a catalog of predictions for 89% of all 71 million possible <a href="https://en.wikipedia.org/wiki/Missense_mutation">missense variants</a> as either likely pathogenic or likely benign.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiHNPYsNIjVTRyfPZODcVpp-XnQAtz2b0HcKsa8F9GyJXoKfp-9PH8N_IcXO_lJ7WfuTZ2ezeAVYCDPnqeyu-qeXahqu1lwJXb6Zq00Mt7M-WMyFff9eUL67L_QZBwK95DNlVAcFpd9cr1GW5gxqNb0mOJszQphyphenhyphen6Yi_QelNzeYnvZ1XNKZqNU2EP_x8MjW/s1070/image1.jpg" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="491" data-original-width="1070" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiHNPYsNIjVTRyfPZODcVpp-XnQAtz2b0HcKsa8F9GyJXoKfp-9PH8N_IcXO_lJ7WfuTZ2ezeAVYCDPnqeyu-qeXahqu1lwJXb6Zq00Mt7M-WMyFff9eUL67L_QZBwK95DNlVAcFpd9cr1GW5gxqNb0mOJszQphyphenhyphen6Yi_QelNzeYnvZ1XNKZqNU2EP_x8MjW/s16000/image1.jpg" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><em style="text-align: left;">Examples of AlphaMissense predictions overlaid on AlphaFold predicted structures (red – predicted as pathogenic; blue – predicted as benign; grey – uncertain). Red dots represent known pathogenic missense variants, blue dots represent known benign variants. <strong>Left:</strong> HBB protein. Variants in this protein can cause sickle cell anaemia. <strong>Right:</strong> CFTR protein. Variants in this protein can cause cystic fibrosis.</em></td></tr></tbody></table>
<p>
We also shared <a href="https://deepmind.google/discover/blog/a-glimpse-of-the-next-generation-of-alphafold/">an update</a> on progress towards the next generation of AlphaFold. Our latest model can now generate predictions for nearly all molecules in the <a href="https://www.wwpdb.org/">Protein Data Bank</a> (PDB), frequently reaching atomic accuracy. This unlocks new understanding and significantly improves accuracy in multiple key biomolecule classes, including ligands (small molecules), proteins, nucleic acids (DNA and RNA), and those containing post-translational modifications (PTMs).</p><p></p>
<p>
On the neuroscience front, we <a href="https://blog.research.google/2023/09/google-research-embarks-on-effort-to.html">announced a new collaboration</a> with Harvard, Princeton, the NIH, and others to map an entire mouse brain at synaptic resolution, beginning with a first phase that will focus on the <a href="https://en.wikipedia.org/wiki/Hippocampal_formation">hippocampal formation</a> — the area of the brain responsible for memory formation, spatial navigation, and other important functions.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>Quantum computing</h3>
<p>
Quantum computers have the potential to solve big, real-world problems across science and industry. But to realize that potential, they must be significantly larger than they are today, and they must reliably perform tasks that cannot be performed on classical computers.
</p>
<p>
This year, we took an important step towards the development of a large-scale, useful quantum computer. Our breakthrough is the first demonstration of <a href="https://blog.research.google/2023/02/suppressing-quantum-errors-by-scaling.html">quantum error correction</a>, showing that it’s possible to reduce errors while also increasing the number of qubits. To enable real-world applications, these qubit building blocks must perform more reliably, lowering the error rate from the ~1 in 10<sup>3</sup> typically seen today to ~1 in 10<sup>8</sup>.
</p>
<br />
<h2>Responsible AI research</h2>
<div style="line-height: 40%;">
<br />
</div>
<h3>Design for Responsibility </h3>
<p>
Generative AI is having a transformative impact in a wide range of fields including healthcare, education, security, energy, transportation, manufacturing, and entertainment. Given these advances, the importance of designing technologies consistent with our <a href="https://ai.google/responsibility/principles/">AI Principles</a> remains a top priority. We also recently published case studies of <a href="https://blog.research.google/2023/11/emerging-practices-for-society-centered.html">emerging practices in society-centered AI</a>. And in our annual <a href="https://storage.googleapis.com/gweb-uniblog-publish-prod/documents/2023_Google_AI_Principles_Progress_Update.pdf" target="_blank">AI Principles Progress Update</a>, we offer details on how our Responsible AI research is integrated into products and risk management processes.
</p>
<p>
Proactive design for Responsible AI begins with identifying and documenting potential harms. For example, we recently <a href="https://deepmind.google/discover/blog/evaluating-social-and-ethical-risks-from-generative-ai/">introduced</a> a <a href="https://arxiv.org/abs/2310.11986">three-layered</a> context-based framework for comprehensively evaluating the social and ethical risks of AI systems. During model design, harms can be mitigated with the use of <a href="https://blog.research.google/2023/11/responsible-ai-at-google-research_16.html">responsible datasets</a>.
</p>
<p>
We are <a href="https://blog.google/technology/research/project-elevate-black-voices-google-research/">partnering with Howard University</a> to build high quality African-American English (AAE) datasets to improve our products and make them work well for more people. Our research on <a href="https://dl.acm.org/doi/10.1145/3593013.3594016">globally inclusive cultural representation</a> and our publication of the <a href="https://skintone.google/">Monk Skin Tone scale</a> furthers our commitments to equitable representation of all people. The insights we gain and techniques we develop not only help us improve our own models, they also power <a href="https://blog.google/intl/en-in/company-news/using-ai-to-study-demographic-representation-in-indian-tv/">large-scale studies of representation in popular media</a> to inform and inspire more inclusive content creation around the world.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEirYAo6ClPM9zD8ctnoTvdyhnwW1VdVT2p677EMKGrN0oULjyK9TdS05z1OzhTTuyxp0kfuUUbHEieZXYRW6hUe3XlJM6HYOS68rXneSGQeLTy_Bo_SCvbxnjHdG2CB_8aLCz5pP-B_dciLKZYlIo9j8bEyUdKtZQhg7DukG9pJudJKU-o0DRuM5XrbvkBI/s1196/image9.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="158" data-original-width="1196" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEirYAo6ClPM9zD8ctnoTvdyhnwW1VdVT2p677EMKGrN0oULjyK9TdS05z1OzhTTuyxp0kfuUUbHEieZXYRW6hUe3XlJM6HYOS68rXneSGQeLTy_Bo_SCvbxnjHdG2CB_8aLCz5pP-B_dciLKZYlIo9j8bEyUdKtZQhg7DukG9pJudJKU-o0DRuM5XrbvkBI/s16000/image9.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><em style="text-align: left;">Monk Skin Tone (MST) Scale. See more at <a href="http://skintone.google/">skintone.google</a>.</em></td></tr></tbody></table>
<p>
With advances in generative image models, <a href="https://blog.research.google/2023/08/responsible-ai-at-google-research.html">fair and inclusive representation of people</a> remains a top priority. In the development pipeline, we are working to <a href="https://blog.research.google/2023/07/using-societal-context-knowledge-to.html">amplify underrepresented voices and to better integrate social context knowledge</a>. We proactively address potential harms and bias using <a href="https://arxiv.org/pdf/2306.06135.pdf">classifiers and filters</a>, <a href="https://arxiv.org/pdf/2311.17259.pdf">careful dataset analysis</a>, and in-model mitigations such as fine-tuning, <a href="https://arxiv.org/abs/2310.16523">reasoning</a>, <a href="https://arxiv.org/abs/2306.14308">few-shot prompting</a>, <a href="https://arxiv.org/abs/2310.16959">data augmentation</a>, and <a href="https://arxiv.org/abs/2310.17022">controlled decoding</a>. Our research also showed that generative AI enables <a href="https://arxiv.org/abs/2302.06541">higher quality safety classifiers</a> to be developed with far less data. In addition, we released <a href="https://developers.googleblog.com/2023/10/make-with-makersuite-part-2-tuning-llms.html">a powerful way to better tune models with less data</a>, giving developers more control over responsibility challenges in generative AI.</p>
<p>
We have developed new <a href="https://arxiv.org/abs/2303.08114">state-of-the-art explainability methods</a> to identify the role of training data in model behavior. By <a href="https://arxiv.org/abs/2302.06598">combining training data attribution methods with agile classifiers</a>, we found that we can identify mislabelled training examples. This makes it possible to reduce the noise in training data, leading to significant improvements in model accuracy.
</p>
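<p>
As a rough illustration of the idea, the sketch below trains a tiny logistic-regression classifier on toy data with deliberately flipped labels and flags the examples with the largest per-example gradient norm (a simple, TracIn-style self-influence score). This is a minimal stand-in rather than the method from the papers above; it only shows how attribution-style scores can surface likely mislabelled examples.
</p>
<pre>
import numpy as np

rng = np.random.default_rng(0)

# Toy data: two Gaussian blobs, with 5% of labels deliberately flipped.
X = np.vstack([rng.normal(-1, 1, (500, 2)), rng.normal(1, 1, (500, 2))])
y = np.array([0] * 500 + [1] * 500, dtype=float)
flipped = rng.choice(len(y), size=50, replace=False)
y[flipped] = 1 - y[flipped]

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# Train a small logistic-regression classifier by gradient descent.
w, b = np.zeros(2), 0.0
for _ in range(500):
    p = sigmoid(X @ w + b)
    w -= 0.1 * X.T @ (p - y) / len(y)
    b -= 0.1 * np.mean(p - y)

# Self-influence at the final checkpoint: per-example squared gradient norm.
# Mislabelled points tend to score highest.
p = sigmoid(X @ w + b)
per_example_grads = (p - y)[:, None] * np.hstack([X, np.ones((len(y), 1))])
self_influence = np.sum(per_example_grads ** 2, axis=1)
suspects = np.argsort(self_influence)[-50:]

overlap = len(set(suspects) & set(flipped))
print(f"{overlap}/50 of the flagged examples were actually mislabelled")
</pre>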
<p>
We initiated several efforts to improve safety and transparency about online content. For example, we introduced <a href="https://deepmind.google/discover/blog/identifying-ai-generated-images-with-synthid/">SynthID</a>, a tool for watermarking and identifying AI-generated images. SynthID is imperceptible to the human eye, doesn't compromise image quality, and allows the watermark to remain detectable, even after modifications like adding filters, changing colors, and saving with various lossy compression schemes.
</p>
<p>
We also launched <a href="https://blog.google/products/search/google-search-new-fact-checking-features/">About This Image</a> to help people assess the credibility of images, showing information like an image's history, how it's used on other pages, and available metadata about an image. And we <a href="https://arxiv.org/abs/2210.03535">explored safety methods</a> that have been developed in other fields, learning from established situations where there is low risk tolerance.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiBYUNXlxnFiBnnB_-EFKhdc-1W8MwG-VZKrhyKtWFmfAcqlrbSdPl7TAslOAMaH1Zon0TvGKpj23nlO7XZyg2ovFuNHpgXbsyUUPrxzf1RtJFlBPzR5Hh9KAus1l79qrBFP5JJDScQgn_5cq3ZVf7T0VuPiNLLJ-PsENlba_BA0nORQkofZYY7K1DhJGXt/s616/image6.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="346" data-original-width="616" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEiBYUNXlxnFiBnnB_-EFKhdc-1W8MwG-VZKrhyKtWFmfAcqlrbSdPl7TAslOAMaH1Zon0TvGKpj23nlO7XZyg2ovFuNHpgXbsyUUPrxzf1RtJFlBPzR5Hh9KAus1l79qrBFP5JJDScQgn_5cq3ZVf7T0VuPiNLLJ-PsENlba_BA0nORQkofZYY7K1DhJGXt/s16000/image6.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><em style="text-align: left;">SynthID generates an imperceptible digital watermark for AI-generated images.</em></td></tr></tbody></table>
<p>
Privacy remains an essential aspect of our commitment to Responsible AI. We continued improving our state-of-the-art privacy preserving learning algorithm <a href="https://arxiv.org/abs/2103.00039">DP-FTRL</a>, developed the DP-Alternating Minimization algorithm (<a href="https://arxiv.org/pdf/2310.15454.pdf">DP-AM</a>) to enable personalized recommendations with rigorous privacy protection, and defined a new <a href="https://blog.research.google/2023/09/differentially-private-median-and-more.html">general paradigm</a> to reduce the privacy costs for many aggregation and learning tasks. We also proposed a scheme for <a href="https://openreview.net/pdf?id=q15zG9CHi8">auditing differentially private machine learning systems</a>.</p>
<p>
On the applications front, we demonstrated that <a href="https://arxiv.org/pdf/2308.10888.pdf">DP-SGD offers a practical solution</a> in the large model fine-tuning regime and showed that images generated by DP diffusion models are <a href="https://arxiv.org/pdf/2302.13861.pdf">useful for a range of downstream tasks</a>. We <a href="https://blog.research.google/2023/12/sparsity-preserving-differentially.html">proposed</a> a new algorithm for DP training of large embedding models that provides efficient training on TPUs without compromising accuracy.
</p>
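<p>
For readers unfamiliar with DP-SGD, the minimal sketch below shows the core of one update: clip each example’s gradient, add calibrated Gaussian noise to the sum, average, and take a step. It omits privacy accounting and all of the fine-tuning and embedding-model specifics discussed above, and the parameter values are illustrative.
</p>
<pre>
import numpy as np

def dp_sgd_step(params, per_example_grads, lr=0.1, clip_norm=1.0,
                noise_multiplier=1.0, rng=None):
    # One DP-SGD update (minimal sketch): clip each per-example gradient to
    # `clip_norm`, sum, add Gaussian noise scaled by `noise_multiplier`,
    # average over the batch, and take a gradient step.
    rng = rng or np.random.default_rng()
    norms = np.linalg.norm(per_example_grads, axis=1, keepdims=True)
    clipped = per_example_grads * np.minimum(1.0, clip_norm / (norms + 1e-12))
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=params.shape)
    noisy_mean_grad = (clipped.sum(axis=0) + noise) / len(per_example_grads)
    return params - lr * noisy_mean_grad

# Example: one step on a batch of 32 per-example gradients for 10 parameters.
rng = np.random.default_rng(1)
params = np.zeros(10)
grads = rng.normal(size=(32, 10))
params = dp_sgd_step(params, grads, rng=rng)
print(params)
</pre>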
<p>
We also teamed up with a broad group of academic and industrial researchers to organize the <a href="https://unlearning-challenge.github.io/">first Machine Unlearning Challenge</a> to address the scenario in which selected training images must be forgotten to protect the privacy or rights of individuals. We shared work on <a href="https://arxiv.org/pdf/2311.17035.pdf">extractable memorization</a>, and on <a href="https://arxiv.org/abs/2302.03874">participatory systems</a> that give users more control over their sensitive data.
</p>
<p>
We continued to expand the world’s largest corpus of atypical speech recordings to >1M utterances in <a href="https://sites.research.google/euphonia/about/">Project Euphonia</a>, which enabled us to train a <a href="https://blog.research.google/2023/03/universal-speech-model-usm-state-of-art.html">Universal Speech Model</a> to <a href="https://blog.research.google/2023/06/responsible-ai-at-google-research-ai.html">better recognize atypical speech by 37%</a> on real-world benchmarks.
</p>
<p>
We also built an <a href="https://blog.research.google/2023/08/study-socially-aware-temporally-causal.html">audiobook recommendation system</a> for students with reading disabilities such as dyslexia.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>Adversarial testing</h3>
<p>
Our work in adversarial testing <a href="https://blog.research.google/2023/03/responsible-ai-at-google-research.html">engaged community voices</a> from historically marginalized communities. We partnered with groups such as the <a href="https://arxiv.org/abs/2303.08177">Equitable AI Research Round Table</a> (EARR) to ensure we represent the diverse communities who use our models and <a href="https://dynabench.org/tasks/adversarial-nibbler">engage with external users</a> to identify potential harms in generative model outputs.
</p>
<p>
We <a href="https://blog.research.google/2023/11/responsible-ai-at-google-research_16.html">established a dedicated Google AI Red Team</a> focused on testing AI models and products for security, privacy, and abuse risks. We showed that attacks such as “<a href="https://arxiv.org/pdf/2302.10149.pdf?isApp=1">poisoning</a>” or <a href="https://arxiv.org/pdf/2306.15447.pdf">adversarial examples</a> can be applied to production models and surface additional risks such as memorization in both <a href="https://www.usenix.org/system/files/usenixsecurity23-carlini.pdf">image</a> and <a href="https://arxiv.org/pdf/2311.17035.pdf">text generative models</a>. We also demonstrated that defending against such attacks can be challenging, as merely applying defenses can cause other <a href="https://arxiv.org/pdf/2309.05610.pdf">security and privacy leakages</a>. We also introduced model evaluation for <a href="https://arxiv.org/abs/2305.15324">extreme risks</a>, such as offensive cyber capabilities or strong manipulation skills.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>Democratizing AI through tools and education</h3>
<p>
As we advance the state-of-the-art in ML and AI, we also want to ensure people can understand and apply AI to specific problems. We released <a href="https://makersuite.google.com/">MakerSuite</a> (now <a href="https://makersuite.google.com">Google AI Studio</a>), a web-based tool that enables AI developers to quickly iterate and build lightweight AI-powered apps. To help AI engineers better understand and debug AI, we released <a href="https://pair-code.github.io/lit/">LIT 1.0</a>, a state-of-the-art, open-source debugger for machine learning models.
</p>
<p>
<a href="https://colab.google/">Colab</a>, our tool that helps developers and students access powerful computing resources right in their web browser, reached over 10 million users. We’ve just added <a href="https://blog.google/technology/ai/democratizing-access-to-ai-enabled-coding-with-colab/">AI-powered code assistance</a> to all users at no cost — making Colab an even more helpful and integrated experience in data and ML workflows.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhy5OhqeP7Rf5gTbCP-gYsmTFyc9ztUMXrcv_M8Lya5zLPzzRVrHmldGiR-rf1PnggxP_rlG2mo6YJ9NhnNnFHEbtn8f-FJk_cxOTtO9hL0ZhxV9hSGd0LW0-RymxkPVBbqXKDk4nRV7mJE6heVnn-6tYcLdpofuhqKPHnqxDTZf1hjifTg3k7K4VNcS2S6/s844/image10.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="470" data-original-width="844" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhy5OhqeP7Rf5gTbCP-gYsmTFyc9ztUMXrcv_M8Lya5zLPzzRVrHmldGiR-rf1PnggxP_rlG2mo6YJ9NhnNnFHEbtn8f-FJk_cxOTtO9hL0ZhxV9hSGd0LW0-RymxkPVBbqXKDk4nRV7mJE6heVnn-6tYcLdpofuhqKPHnqxDTZf1hjifTg3k7K4VNcS2S6/s16000/image10.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;"><em style="text-align: left;">One of the most used features is “Explain error” — whenever the user encounters an execution error in Colab, the code assistance model provides an explanation along with a potential fix.</em></td></tr></tbody></table>
<p>
To ensure AI produces accurate knowledge when put to use, we also recently introduced <a href="https://deepmind.google/discover/blog/funsearch-making-new-discoveries-in-mathematical-sciences-using-large-language-models/">FunSearch</a>, a new approach that generates verifiably true knowledge in mathematical sciences using evolutionary methods and large language models.</p>
<p>
For AI engineers and product designers, we’re updating the <a href="https://pair.withgoogle.com/guidebook/">People + AI Guidebook</a> with generative AI best practices, and we continue to design <a href="https://pair.withgoogle.com/explorables/">AI Explorables</a>, which include an exploration of <a href="https://pair.withgoogle.com/explorables/uncertainty-ood/">how and why models sometimes make incorrect predictions confidently</a>.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>Community engagement</h3>
<p>
We continue to advance the fields of AI and computer science by publishing much of our work and participating in and organizing conferences. We have published more than 500 papers so far this year, and have strong presences at conferences like ICML (see the <a href="https://blog.research.google/2023/07/google-at-icml-2023.html">Google Research</a> and <a href="https://deepmind.google/discover/blog/google-deepmind-research-at-icml-2023/">Google DeepMind</a> posts), ICLR (<a href="https://blog.research.google/2023/04/google-at-iclr-2023.html">Google Research</a>, <a href="https://deepmind.google/discover/blog/deepminds-latest-research-at-iclr-2023/">Google DeepMind</a>), NeurIPS (<a href="https://blog.research.google/2023/12/google-at-neurips-2023.html">Google Research</a>, <a href="https://deepmind.google/discover/blog/google-deepmind-at-neurips-2023/">Google DeepMind</a>), <a href="https://blog.research.google/2023/10/google-at-iccv-2023.html">ICCV</a>, <a href="https://blog.research.google/2023/06/google-at-cvpr-2023.html">CVPR</a>, <a href="https://blog.research.google/2023/07/google-at-acl-2023.html">ACL</a>, <a href="https://blog.research.google/2023/04/google-at-chi-2023.html">CHI</a>, and <a href="https://blog.research.google/2023/08/google-at-interspeech-2023.html">Interspeech</a>. We are also working to support researchers around the world, participating in events like the <a href="https://deeplearningindaba.com/2023/google-outreach-mentorship-programme/">Deep Learning Indaba</a>, <a href="https://khipu.ai/khipu2023/khipu-2023-speakers2023/">Khipu</a>, supporting <a href="https://blog.google/around-the-globe/google-latin-america/phd-fellowship-research-latin-america/">PhD Fellowships in Latin America</a>, and more. We also worked with partners from 33 academic labs to pool data from 22 different robot types and create the <a href="https://deepmind.google/discover/blog/scaling-up-learning-across-many-different-robot-types/">Open X-Embodiment dataset and RT-X model</a> to better advance responsible AI development.
</p>
<p>
Google has spearheaded an industry-wide effort to develop <a href="https://mlcommons.org/working-groups/ai-safety/ai-safety/">AI safety benchmarks</a> under the <a href="https://mlcommons.org/">MLCommons</a> standards organization with participation from several major players in the generative AI space including OpenAI, Anthropic, Microsoft, Meta, Hugging Face, and more. Along with others in the industry we also <a href="https://blog.google/outreach-initiatives/public-policy/google-microsoft-openai-anthropic-frontier-model-forum/">co-founded</a> the <a href="https://www.frontiermodelforum.org/">Frontier Model Forum</a> (FMF), which is focused on ensuring safe and responsible development of frontier AI models. With our FMF partners and other philanthropic organizations, we launched a $10 million <a href="https://blog.google/outreach-initiatives/public-policy/google-microsoft-anthropic-open-ai-frontier-model-forum-executive-director/">AI Safety Fund</a> to advance research into the ongoing development of the tools for society to effectively test and evaluate the most capable AI models.
</p>
<p>
In close partnership with <a href="http://Google.org">Google.org</a>, we <a href="https://blog.google/technology/ai/google-ai-data-un-global-goals/">worked with the United Nations</a> to build the <a href="https://unstats.un.org/UNSDWebsite/undatacommons/sdgs">UN Data Commons for the Sustainable Development Goals</a>, a tool that tracks metrics across the 17 <a href="https://sdgs.un.org/goals">Sustainable Development Goals</a>, and <a href="https://globalgoals.withgoogle.com/globalgoals/supported-organizations">supported projects</a> from NGOs, academic institutions, and social enterprises on <a href="https://blog.google/outreach-initiatives/google-org/httpsbloggoogleoutreach-initiativesgoogle-orgunited-nations-global-goals-google-ai-/">using AI to accelerate progress on the SDGs</a>.
</p>
<p>
The items highlighted in this post are a small fraction of the research work we have done throughout the last year. Find out more at the <a href="https://blog.research.google/">Google Research</a> and <a href="https://deepmind.google/discover/blog/">Google DeepMind</a> blogs, and our <a href="https://research.google/pubs/">list of publications</a>.
</p>
<br />
<h2>Future vision</h2>
<p>
As multimodal models become even more capable, they will empower people to make incredible progress in areas from science to education to entirely new areas of knowledge.
</p>
<p>
Progress continues apace, and as the year advances and our products and research advance as well, people will find more and more interesting creative uses for AI.
</p>
<p>
Ending this Year-in-Review where we began, as we say in <em><a href="https://ai.google/static/documents/google-why-we-focus-on-ai.pdf">Why We Focus on AI (and to what end)</a></em>:
</p>
<div style="margin-left: 40px;">
<p>
If pursued boldly and responsibly, we believe that AI can be a foundational technology that transforms the lives of people everywhere — this is what excites us!
</p>
</div>
<!--Footnotes-->
<hr width="80%" />
<p>
<span class="Apple-style-span" style="font-size: x-small;">This Year-in-Review is cross-posted on both the <a href="https://blog.research.google/2023/12/2023-year-of-groundbreaking-advances-in.html">Google Research Blog</a> and the <a href="https://deepmind.google/discover/blog/2023-a-year-of-groundbreaking-advances-in-ai-and-computing/">Google DeepMind Blog</a>.</span></p>
VideoPoet: A large language model for zero-shot video generation<span class="byline-author">Posted by Dan Kondratyuk and David Ross, Software Engineers, Google Research</span>
<img src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjFG8pxLk-uvJVPesDUxU7Ox7Tb0VR4lB4jNygxiiIl9gyeX-rgtCb5jhWbPDKGad4d5GkPFr7-uzdZOngFtnPbmV6IurKEd5vHTiLesCvx8RDzBy_5u31e6C93kuirOG8pKcwEBkBrw4TBTzh3WIoX1TZYzYM3M4aj4zYD2r_p5FflI7ntpqoD9fsGQhwO/s1600/videopoetpreview.gif" style="display: none;" />
<p>
A recent wave of video generation models has burst onto the scene, in many cases showcasing stunning picturesque quality. One of the current bottlenecks in video generation is in the ability to produce coherent large motions. In many cases, even the current leading models either generate small motion or, when producing larger motions, exhibit noticeable artifacts.
</p> <a name='more'></a>
<p>
To explore the application of language models in video generation, we introduce VideoPoet (<a href="http://sites.research.google/videopoet">website</a>, <a href="https://arxiv.org/abs/2312.14125">research paper</a>), a large language model (LLM) that is capable of a wide variety of video generation tasks, including text-to-video, image-to-video, video stylization, video <a href="https://en.wikipedia.org/wiki/Inpainting">inpainting</a> and <a href="https://paperswithcode.com/task/image-outpainting">outpainting</a>, and video-to-audio. One notable observation is that the leading video generation models are almost exclusively diffusion-based (for one example, see <a href="https://imagen.research.google/video/">Imagen Video</a>). On the other hand, LLMs are widely recognized as the <em>de facto</em> standard due to their exceptional learning capabilities across various modalities, including language, code, and audio (e.g., <a href="https://google-research.github.io/seanet/audiopalm/examples/">AudioPaLM</a>). In contrast to alternative models in this space, our approach seamlessly integrates many video generation capabilities within a single LLM, rather than relying on separately trained components that specialize on each task.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Overview</h2>
<p>
The diagram below illustrates VideoPoet’s capabilities. Input images can be animated to produce motion, and (optionally cropped or masked) video can be edited for inpainting or outpainting. For stylization, the model takes in a video representing the depth and optical flow, which represent the motion, and paints contents on top to produce the text-guided style.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjfgEjHPIukR1pHUT7VBU8DecJaayp8sIV1WC4HM6EOW-K1-E_Xu9-qE2hrTBrtgrks3awr8IiT9NhxVMDOR0qoFZ8nDQT_Si3LY60CWKySkaybXRW5Uf6EwZIiDRy7qkQGuoLUxoysR-fjEr0NTMqAGP2Cm8cyUQHVOOlS7MysnwHwqhdM-eA63PXy5XsV/s1999/image7.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="682" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjfgEjHPIukR1pHUT7VBU8DecJaayp8sIV1WC4HM6EOW-K1-E_Xu9-qE2hrTBrtgrks3awr8IiT9NhxVMDOR0qoFZ8nDQT_Si3LY60CWKySkaybXRW5Uf6EwZIiDRy7qkQGuoLUxoysR-fjEr0NTMqAGP2Cm8cyUQHVOOlS7MysnwHwqhdM-eA63PXy5XsV/s16000/image7.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">An overview of VideoPoet, capable of multitasking on a variety of video-centric inputs and outputs. The LLM can optionally take text as input to guide generation for text-to-video, image-to-video, video-to-audio, stylization, and outpainting tasks. Resources used: <a href="https://commons.wikimedia.org/wiki/Main_Page">Wikimedia Commons</a> and <a href="https://davischallenge.org/">DAVIS</a>.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>Language models as video generators</h2>
<p>
One key advantage of using LLMs for training is that one can reuse many of the scalable efficiency improvements that have been introduced in existing LLM training infrastructure. However, LLMs operate on discrete tokens, which can make video generation challenging. Fortunately, there exist <a href="https://magvit.cs.cmu.edu/v2/">video</a> and <a href="https://arxiv.org/abs/2107.03312">audio</a> tokenizers, which serve to encode video and audio clips as sequences of discrete tokens (i.e., integer indices), and which can also be converted back into the original representation.
</p>
<p>
VideoPoet trains an <a href="https://en.wikipedia.org/wiki/Autoregressive_model">autoregressive language model</a> to learn across video, image, audio, and text modalities through the use of multiple tokenizers (<a href="https://magvit.cs.cmu.edu/v2/">MAGVIT V2</a> for video and image and <a href="https://arxiv.org/abs/2107.03312">SoundStream</a> for audio). Once the model generates tokens conditioned on some context, these can be converted back into a viewable representation with the tokenizer decoders.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgxFblHaHRJNH7Oi2_oOTosGN9XrjgjhWmnfADchMT8WR0XAo6SxiUfpUmn5R6akciiRduaKIMdgwHZzK3xW8mErarQ_ugx41ctQAMK08O9UMVevgkk-AgFI1xYFWAomd16OcOh0R-XpyZVLQXncpk2SHf-RmPzrqBbIWZc-nUG2TH6nC2R7qyHXn8eTC-u/s2680/image21.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="824" data-original-width="2680" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgxFblHaHRJNH7Oi2_oOTosGN9XrjgjhWmnfADchMT8WR0XAo6SxiUfpUmn5R6akciiRduaKIMdgwHZzK3xW8mErarQ_ugx41ctQAMK08O9UMVevgkk-AgFI1xYFWAomd16OcOh0R-XpyZVLQXncpk2SHf-RmPzrqBbIWZc-nUG2TH6nC2R7qyHXn8eTC-u/s16000/image21.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">A detailed look at the VideoPoet task design, showing the training and inference inputs and outputs of various tasks. Modalities are converted to and from tokens using tokenizer encoder and decoders. Each modality is surrounded by boundary tokens, and a task token indicates the type of task to perform.</td></tr></tbody></table>
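<p>
Schematically, generation proceeds by encoding the conditioning signal into tokens, sampling output tokens autoregressively from the LLM, and decoding them back into pixels with the tokenizer. The sketch below illustrates this loop for text-to-video; every name in it (the tokenizer objects, <code>lm.sample_next</code>, and the boundary and task token strings) is a hypothetical placeholder rather than a real VideoPoet API.
</p>
<pre>
# A schematic sketch of the token-based generation loop described above.
# Every name here (text_tokenizer, video_tokenizer, lm, the token strings)
# is a hypothetical placeholder, not a real VideoPoet API.

BOS_VIDEO, EOS_VIDEO = "[bov]", "[eov]"      # boundary tokens around the video modality
TASK_TEXT_TO_VIDEO = "[task:text_to_video]"  # task token selecting the generation mode

def generate_video(prompt, text_tokenizer, video_tokenizer, lm, max_video_tokens=1280):
    # 1. Encode the conditioning signal (here, text) into discrete tokens.
    context = [TASK_TEXT_TO_VIDEO] + text_tokenizer.encode(prompt) + [BOS_VIDEO]

    # 2. Autoregressively sample video tokens from the language model.
    video_tokens = []
    for _ in range(max_video_tokens):
        next_token = lm.sample_next(context + video_tokens)
        if next_token == EOS_VIDEO:
            break
        video_tokens.append(next_token)

    # 3. Map the discrete tokens back to pixels with the tokenizer's decoder.
    return video_tokenizer.decode(video_tokens)
</pre>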
<div style="line-height: 40%;">
<br />
</div>
<h2>Examples generated by VideoPoet</h2>
<p>
Some examples generated by our model are shown below.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgoUJFpPd377ceVnh3Yi0Y1oM6pVPL_FSEwugxBVKfEwHV8VA-1ZPmddz1VRBtqESjjEP83EJ4HOLLSBmXOVLEODZVFZYUuDiCRyMUIvSaRsxieR-58iAwHPBf7SNeSBU2a3Pm80JOFivTSjZsqlerHxAGg_Ko_8gCDLWWtNAQffyGCNJrw2G5PPG-vSYtz/s1100/image18.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="963" data-original-width="1100" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgoUJFpPd377ceVnh3Yi0Y1oM6pVPL_FSEwugxBVKfEwHV8VA-1ZPmddz1VRBtqESjjEP83EJ4HOLLSBmXOVLEODZVFZYUuDiCRyMUIvSaRsxieR-58iAwHPBf7SNeSBU2a3Pm80JOFivTSjZsqlerHxAGg_Ko_8gCDLWWtNAQffyGCNJrw2G5PPG-vSYtz/s16000/image18.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Videos generated by VideoPoet from various text prompts. For specific text prompts refer to <a href="http://sites.research.google/videopoet">the website</a>.</td></tr></tbody></table>
<br />
<p>
For text-to-video, video outputs are variable length and can apply a range of motions and styles depending on the text content. To ensure responsible practices, we reference artworks and styles in the public domain, e.g., Van Gogh’s “Starry Night”.</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container">
<tbody>
<tr><td style="text-align: right;"><b>Text Input</b></td>
<td> </td>
<td align="center" width="20%"><em>“A Raccoon dancing in Times Square”</em></td>
<td> </td>
<td align="center" width="20%"><em>“A horse galloping through Van-Gogh’s ‘Starry Night’”</em></td>
<td> </td>
<td align="center" width="20%"><em>“Two pandas playing cards”</em></td>
<td> </td>
<td align="center" width="20%"><em>“A large blob of exploding splashing rainbow paint, with an apple emerging, 8k”</em></td>
</tr>
<tr><td style="text-align: right;"><b>Video Output</b></td>
<td> </td>
<td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhFNvzU_hzVGT7XCvA4AdBprfzwpOV_b8xs6Q54No72B88fgtHFK7eHFr3UJ9Ac0tXIemkOUR6VxNC2I8HQWznWs2x4IsB8biOEe2DXBrMxW6HveNhuiae40mF8asd3jYh2WqZMTTObSKiHb0RJIFsK_C7VzBq685OPUY_uqPC5NJGXqKs3AWJOG-Gqgk0V/s448/image6.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="448" data-original-width="256" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhFNvzU_hzVGT7XCvA4AdBprfzwpOV_b8xs6Q54No72B88fgtHFK7eHFr3UJ9Ac0tXIemkOUR6VxNC2I8HQWznWs2x4IsB8biOEe2DXBrMxW6HveNhuiae40mF8asd3jYh2WqZMTTObSKiHb0RJIFsK_C7VzBq685OPUY_uqPC5NJGXqKs3AWJOG-Gqgk0V/s16000/image6.gif" /></a></td>
<td> </td>
<td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj0FAaXVgkTSmoK90XYZq3sbW1fgRUmvOLZUpZ4dubLjJOFYID6w_q5GGIU1hy1E8B4J7hF01OOAupDxJWyD2OBMJEs6AIAbJEzH0qrx7TovnikAUXZyN_MQBu33QYe17CTtI95oI6x91qJhSSPyXNGlmkFKbWX8xwd6nT7Esd9RNE8tjiSbWwWQTKoAfjg/s448/image11.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="448" data-original-width="256" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj0FAaXVgkTSmoK90XYZq3sbW1fgRUmvOLZUpZ4dubLjJOFYID6w_q5GGIU1hy1E8B4J7hF01OOAupDxJWyD2OBMJEs6AIAbJEzH0qrx7TovnikAUXZyN_MQBu33QYe17CTtI95oI6x91qJhSSPyXNGlmkFKbWX8xwd6nT7Esd9RNE8tjiSbWwWQTKoAfjg/s16000/image11.gif" /></a></td>
<td> </td>
<td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi4yqm_nd4mtu-3d_AdfxkKAX453hyphenhyphenXE9e7ohBzZC9RvboIZWxF_5fLw4XaCU3x1aUtW4WRckKzfqB-yzHdV_uzPiCAAwv6zgv__7MZtxdwiWsMiQ59NDH3axwd0UOIFA8C2aq0XNSgcV1ieCN1fc9MXFVIszTG8z7gpcAkgqTn5HOBBJ-pks8I1tOudAR8/s448/image12.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="448" data-original-width="256" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEi4yqm_nd4mtu-3d_AdfxkKAX453hyphenhyphenXE9e7ohBzZC9RvboIZWxF_5fLw4XaCU3x1aUtW4WRckKzfqB-yzHdV_uzPiCAAwv6zgv__7MZtxdwiWsMiQ59NDH3axwd0UOIFA8C2aq0XNSgcV1ieCN1fc9MXFVIszTG8z7gpcAkgqTn5HOBBJ-pks8I1tOudAR8/s16000/image12.gif" /></a></td>
<td> </td>
<td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg2huAujTA8vjmb5rlgj6vXF7uIHr0N7KrnLFiTXyjuN0l6kxSp7gQRPES5hD30ZzN-XTzUZSC-ROT19x0wE5cjQA_FOovY-Zox_68e0Dl4Dxanqvyuos3S1TExRdg3WZANucc-DXtKUsnim9Kh0GwU6LyFhpCJgHal0bI0UbpBYrXtlNGBYWuT8XVNIt6u/s448/image17.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="448" data-original-width="256" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg2huAujTA8vjmb5rlgj6vXF7uIHr0N7KrnLFiTXyjuN0l6kxSp7gQRPES5hD30ZzN-XTzUZSC-ROT19x0wE5cjQA_FOovY-Zox_68e0Dl4Dxanqvyuos3S1TExRdg3WZANucc-DXtKUsnim9Kh0GwU6LyFhpCJgHal0bI0UbpBYrXtlNGBYWuT8XVNIt6u/s16000/image17.gif" /></a></td>
</tr>
</tbody></table>
<p>For image-to-video, VideoPoet can take an input image and animate it according to a text prompt.</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj5kAQCs7GJoJR0m_8hmYsq-Wd-KsjuF782YuV5D33BlPE8f3-AU1iTKwOVrpxnnBDHa-5AXgkXNBNil61r5eVhXa2v16VUraEt6DAa-4_v-xHJq6lJfwkJ9ATQZTGdxKsbvnfielzv_6iJyLrubPyrGhle_BUecTVJkGo_7S8sM-3yl6vVVtqNNg4nf0mw/s1536/image13.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="448" data-original-width="1536" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEj5kAQCs7GJoJR0m_8hmYsq-Wd-KsjuF782YuV5D33BlPE8f3-AU1iTKwOVrpxnnBDHa-5AXgkXNBNil61r5eVhXa2v16VUraEt6DAa-4_v-xHJq6lJfwkJ9ATQZTGdxKsbvnfielzv_6iJyLrubPyrGhle_BUecTVJkGo_7S8sM-3yl6vVVtqNNg4nf0mw/s16000/image13.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">An example of image-to-video with text prompts to guide the motion. Each video is paired with an image to its left. <strong>Left</strong>: “A ship navigating the rough seas, thunderstorm and lightning, animated oil on canvas”. <strong>Middle</strong>: “Flying through a nebula with many twinkling stars”. <strong>Right</strong>: “A wanderer on a cliff with a cane looking down at the swirling sea fog below on a windy day”. Reference: <a href="https://commons.wikimedia.org/wiki/Main_Page">Wikimedia Commons</a>, public domain**.</td></tr></tbody></table>
<p>
For video stylization, we predict optical flow and depth information from the input video, then feed these signals into VideoPoet along with additional input text.
</p>
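<p>
To make this pipeline concrete, here is a minimal sketch of how the conditioning could be wired together. The <code>estimate_depth</code>, <code>estimate_flow</code>, and <code>generate</code> callables are hypothetical placeholders supplied by the caller; VideoPoet does not expose a public API, so this illustrates the data flow rather than an actual interface.
</p>
<pre>
# Minimal sketch of the stylization pipeline (hypothetical callables).
def stylize_video(input_frames, style_prompt, estimate_depth, estimate_flow, generate):
    # Predict per-frame structure signals from the source video.
    depth_maps = [estimate_depth(frame) for frame in input_frames]
    flow_fields = [estimate_flow(prev, nxt)
                   for prev, nxt in zip(input_frames, input_frames[1:])]

    # Condition generation on structure plus the style text, so the output
    # keeps the original motion while adopting the new appearance.
    return generate(text=style_prompt, depth=depth_maps, optical_flow=flow_fields)
</pre>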
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhQ2eDU0pNOfQJF8cKPOV5QkKthIO3-98jbqyZJ7lFLk7cckq8Gg4FXrro6oikQRxDVDKKz8rg9CU6wihsfU68RnnLkHGxJUeFNGBobjsbJ4VHGFtUg-nerlt2rPiJ9bu8i2VkkXX5yEK650t4ay8F7K2zSW5-TjjkbR61TskVhsaQCw1-8lkP1gMDg96i1/s1536/image16.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="448" data-original-width="1536" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhQ2eDU0pNOfQJF8cKPOV5QkKthIO3-98jbqyZJ7lFLk7cckq8Gg4FXrro6oikQRxDVDKKz8rg9CU6wihsfU68RnnLkHGxJUeFNGBobjsbJ4VHGFtUg-nerlt2rPiJ9bu8i2VkkXX5yEK650t4ay8F7K2zSW5-TjjkbR61TskVhsaQCw1-8lkP1gMDg96i1/s16000/image16.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Examples of video stylization on top of VideoPoet text-to-video generated videos with text prompts, depth, and optical flow used as conditioning. The left video in each pair is the input video, the right is the stylized output. <b>Left</b>: “Wombat wearing sunglasses holding a beach ball on a sunny beach.” <b>Middle</b>: “Teddy bears ice skating on a crystal clear frozen lake.” <b>Right</b>: “A metal lion roaring in the light of a forge.”</td></tr></tbody></table>
<p>
VideoPoet is also capable of generating audio. Here we first generate 2-second video clips from the model and then predict the corresponding audio without any text guidance. This enables generation of both video and audio from a single model.</p>
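<p>
The sketch below illustrates this two-stage procedure. The <code>generate_video_clip</code> and <code>predict_audio_for_video</code> callables are assumed stand-ins for the model’s video and audio generation passes, not a real API.
</p>
<pre>
# Sketch of video-then-audio generation from a single model (assumed callables).
def clip_with_audio(prompt, generate_video_clip, predict_audio_for_video):
    video = generate_video_clip(prompt, duration_seconds=2)   # 2-second clip
    audio = predict_audio_for_video(video)                    # no text guidance
    return video, audio
</pre>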
<br />
<br />
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container">
<tbody>
<tr>
<td><video controls="controls" playsinline="" width="100%"> <source src="https://storage.googleapis.com/videopoet/videos/105_drums_with_audio.mp4" type="video/mp4"></source> </video></td>
<td> </td>
<td><video controls="controls" playsinline="" width="100%"> <source src="https://storage.googleapis.com/videopoet/videos/107_cat_piano_with_audio.mp4" type="video/mp4"></source> </video></td>
<td> </td>
<td><video controls="controls" playsinline="" width="100%"> <source src="https://storage.googleapis.com/videopoet/videos/108_train_with_audio.mp4" type="video/mp4"></source> </video></td>
<td> </td>
<td><video controls="controls" playsinline="" width="100%"> <source src="https://storage.googleapis.com/videopoet/videos/104_dog_popcorn_with_audio.mp4" type="video/mp4"></source> </video></td>
</tr>
</tbody></table>
<br />
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td class="tr-caption" style="text-align: center;">An example of video-to-audio, generating audio from a video example without any text input.</td></tr></tbody></table>
<p>
By default, the VideoPoet model generates videos in portrait orientation to tailor its output towards short-form content. To showcase its capabilities, we have produced a brief movie composed of many short clips generated by VideoPoet. For the script, we asked <a href="https://bard.google.com/">Bard</a> to write a short story about a traveling raccoon with a scene-by-scene breakdown and a list of accompanying prompts. We then generated video clips for each prompt and stitched all of the resulting clips together to produce the final video below.
</p>
<br />
<br />
<div class="separator" style="clear: both; text-align: center;">
<iframe allowfullscreen="" class="BLOG_video_class" frameborder="0" height="360" src="https://www.youtube.com/embed/70wZKfx6Ylk" width="640" youtube-src-id="70wZKfx6Ylk"></iframe>
</div>
<br />
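<p>
The stitching step at the end of this pipeline can be done with standard tooling. Below is a minimal sketch using ffmpeg’s concat demuxer; the clip filenames are hypothetical stand-ins for the per-prompt clips, and stream copying assumes all clips share the same codec settings.
</p>
<pre>
# Stitch per-prompt clips into one movie with ffmpeg's concat demuxer.
import subprocess

clips = ["clip_01.mp4", "clip_02.mp4", "clip_03.mp4"]  # hypothetical filenames

with open("clips.txt", "w") as f:
    for path in clips:
        f.write(f"file '{path}'\n")

subprocess.run(
    ["ffmpeg", "-f", "concat", "-safe", "0", "-i", "clips.txt",
     "-c", "copy", "final_movie.mp4"],
    check=True,
)
</pre>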
<p>
When we developed VideoPoet, we noticed some nice properties of the model’s capabilities, which we highlight below.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>Long video</h3>
<p>
We are able to generate longer videos simply by conditioning on the last second of generated video and predicting the next second. By chaining this process repeatedly, we show that the model can not only extend the video well but also faithfully preserve the appearance of all objects, even over several iterations.
</p>
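<p>
A minimal sketch of this chaining loop is shown below, assuming a hypothetical <code>generate_next_second</code> call that takes the most recent second of frames and returns the next second.
</p>
<pre>
# Autoregressive video extension: condition on the last second, predict the next.
def extend_video(seed_frames, generate_next_second, seconds_to_add, fps=24):
    frames = list(seed_frames)
    for _ in range(seconds_to_add):
        context = frames[-fps:]                        # last 1 second of video
        frames.extend(generate_next_second(context))   # next 1 second of video
    return frames
</pre>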
<p>
Here are two examples of VideoPoet generating long video from text input:
<br />
<br />
</p><table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container">
<tbody>
<tr><td style="text-align: right;"><b>Text Input</b></td>
<td> </td>
<td align="center" width="35%"><em>“An astronaut starts dancing on Mars. Colorful fireworks then explode in the background.”</em></td>
<td> </td>
<td align="center" width="35%"><em>“FPV footage of a very sharp elven city of stone in the jungle with a brilliant blue river, waterfall, and large steep vertical cliff faces.”</em></td>
<td> </td>
</tr>
<tr><td style="text-align: right;"><b>Video Output</b></td>
<td> </td>
<td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhaT8uErAOI-bKPCwsVpEQ_SSxqdgjVB9ai-Db3M9YXhGM3X9N0Jwt-UcDb6X8n2V_4-Tf76xkwlSS4ftV8TvAIV6ZbjeXK5JPtQr8Mb_ZcPKiIvOdwFXJOBEfDk1Gp1hzRBkoYwmoH3bAu6NVNBo-ficSneXgvhDF7fwMGVqMik0KWGwv_GLf7clQ2l5e5/s448/image14.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="448" data-original-width="256" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhaT8uErAOI-bKPCwsVpEQ_SSxqdgjVB9ai-Db3M9YXhGM3X9N0Jwt-UcDb6X8n2V_4-Tf76xkwlSS4ftV8TvAIV6ZbjeXK5JPtQr8Mb_ZcPKiIvOdwFXJOBEfDk1Gp1hzRBkoYwmoH3bAu6NVNBo-ficSneXgvhDF7fwMGVqMik0KWGwv_GLf7clQ2l5e5/s16000/image14.gif" /></a></td>
<td> </td>
<td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgDbroT8i5N-AdcRsTLjkLWoqzif1nJj-z1bRxcZwM_-1213gK6Or85VxKIFBMsnvAF_KLAmXWWLeDPxA5ZqWtNI5nqfp_6wzgKnqACckJMjBU2eNy8ySgvWjuFOPNarVcBwx9ZlrAdyMrWCdt29xDE7dPHhlwcxbRSyO7-ZllKs1SRGDgVm5XUc21I3Ddv/s448/image9.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="448" data-original-width="256" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEgDbroT8i5N-AdcRsTLjkLWoqzif1nJj-z1bRxcZwM_-1213gK6Or85VxKIFBMsnvAF_KLAmXWWLeDPxA5ZqWtNI5nqfp_6wzgKnqACckJMjBU2eNy8ySgvWjuFOPNarVcBwx9ZlrAdyMrWCdt29xDE7dPHhlwcxbRSyO7-ZllKs1SRGDgVm5XUc21I3Ddv/s16000/image9.gif" /></a></td>
<td> </td>
</tr>
</tbody></table>
<br />
<p>
It is also possible to interactively edit existing video clips generated by VideoPoet. If we supply an input video, we can change the motion of objects to perform different actions. The object manipulation can be centered on the first frame or the middle frames, which allows for a high degree of editing control.
</p>
<p>
For example, we can randomly generate several candidate continuations of the input video and select the desired next clip.
</p>
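<p>
In code, this selection loop could look like the sketch below, where <code>generate_continuation</code> is an assumed model call that returns one candidate continuation per sampling seed.
</p>
<pre>
# Sample several candidate continuations of an input clip, then pick one.
def propose_continuations(input_clip, prompt, generate_continuation, n_candidates=4):
    return [generate_continuation(input_clip, prompt, seed=i)
            for i in range(n_candidates)]

# A phrase such as "powering up with smoke in the background" can be appended
# to the prompt to steer the action of a particular candidate.
</pre>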
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg2NFzPadmS-8v2ShkxFaqage2MkmSopCm17wtoYnVCFufD5GKZHzM9ZUeL4EvCtVLZGMJYiUA1NVJhplymInJr4_K-G9s9263JAVRMxPb9_15zipLZIwHcmYpwyZmGRwgtbFwpONB9CrOqnDmG9bJBvr7pXNbEqLtms_7QNMbSYSspRefkisKLzeWXAIW9/s1280/image10.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="448" data-original-width="1280" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEg2NFzPadmS-8v2ShkxFaqage2MkmSopCm17wtoYnVCFufD5GKZHzM9ZUeL4EvCtVLZGMJYiUA1NVJhplymInJr4_K-G9s9263JAVRMxPb9_15zipLZIwHcmYpwyZmGRwgtbFwpONB9CrOqnDmG9bJBvr7pXNbEqLtms_7QNMbSYSspRefkisKLzeWXAIW9/s16000/image10.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">An input video on the left is used as conditioning to generate four choices given the initial prompt: “Closeup of an adorable rusty broken-down steampunk robot covered in moss moist and budding vegetation, surrounded by tall grass”. For the first three outputs we show what would happen for unprompted motions. For the last video in the list below, we add to the prompt, “powering up with smoke in the background” to guide the action.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h3>Image to video control</h3>
<p>
Similarly, we can apply motion to an input image to edit its contents towards the desired state, conditioned on a text prompt.
</p>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjlvHv0EU7YHj88knDA8pnCs5VDEsHI8OQeWDx8-dhNXcVKteNqMMrMczg9k0j-T8kevUiQua-Eet5BzT9xJ5CR-lpmFjMYyOQFZcAfXu-rlgTmes9_4-GONo7NFHYAw1q-Ivk6MVj34iyjj6KGpQ8dp6OwJH1SKGyxA2BDPvvEUUIJB6UBkvu1c1ILCgdr/s512/image5.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="448" data-original-width="512" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjlvHv0EU7YHj88knDA8pnCs5VDEsHI8OQeWDx8-dhNXcVKteNqMMrMczg9k0j-T8kevUiQua-Eet5BzT9xJ5CR-lpmFjMYyOQFZcAfXu-rlgTmes9_4-GONo7NFHYAw1q-Ivk6MVj34iyjj6KGpQ8dp6OwJH1SKGyxA2BDPvvEUUIJB6UBkvu1c1ILCgdr/s16000/image5.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Animating a painting with different prompts. <b>Left</b>: “A woman turning to look at the camera.” <b>Right</b>: “A woman yawning.” **</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h3>Camera motion</h3>
<p>
We can also accurately control camera movements by appending the type of desired camera motion to the text prompt. As an example, we generated an image with our model using the prompt, <em>“Adventure game concept art of a sunrise over a snowy mountain by a crystal clear river”</em>. The examples below append the given text suffix to apply the desired camera motion.
</p>
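<p>
Because the camera motion is controlled purely through the prompt text, sweeping over a list of suffixes is enough to produce the variants below. The sketch assumes a hypothetical <code>text_to_video</code> callable.
</p>
<pre>
# Render camera-motion variants by appending a suffix to the same base prompt.
def render_camera_variants(text_to_video, base_prompt, motions):
    return {motion: text_to_video(base_prompt + ". " + motion + ".")
            for motion in motions}

camera_motions = ["Zoom out", "Dolly zoom", "Pan left",
                  "Arc shot", "Crane shot", "FPV drone shot"]
</pre>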
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEh4w_RIRlokn88vR5u6GdRE32zOr5wKLUphlx4BwiZDQL0J6i-yhyphenhyphensc4wCsJ07izsk_MkLVT-TAfRIqU9sJ4E_cYVRszJ1bw-Ha4jYsv1oBgSMjAENhCPHIvg2aXBIULeH4UzR-0K5AHPixzxBYD7rP11xf4m2s7e1WFUyRHLv-RkKhAC9whQvwUovphkgy/s1536/image2.gif" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="448" data-original-width="1536" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEh4w_RIRlokn88vR5u6GdRE32zOr5wKLUphlx4BwiZDQL0J6i-yhyphenhyphensc4wCsJ07izsk_MkLVT-TAfRIqU9sJ4E_cYVRszJ1bw-Ha4jYsv1oBgSMjAENhCPHIvg2aXBIULeH4UzR-0K5AHPixzxBYD7rP11xf4m2s7e1WFUyRHLv-RkKhAC9whQvwUovphkgy/s16000/image2.gif" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">Prompts from left to right: “Zoom out”, “Dolly zoom”, “Pan left”, “Arc shot”, “Crane shot”, “FPV drone shot”.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h2>Evaluation results</h2>
<p>
We evaluate VideoPoet on text-to-video generation with a variety of benchmarks to compare the results to other approaches. To ensure a neutral evaluation, we ran all models on a wide variety of prompts without cherry-picking examples and asked people to rate their preferences. The figures below show, in green, the percentage of the time VideoPoet was chosen as the preferred option for each of the following questions.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h3>Text fidelity</h3>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjotT5lhyphenhyphenreDLFlq_hZRSAKKI2jWx9Pp2y1xvBTQflO-H7EQ3VZYmsRTiaTV1creTpMo0It1-IfiFh313zzhhjDPeSxSW3nnRWsQC1toPBSsRlQD0T6UmzFqSNGaeQ0CGBKmn6xyAJXTG3NaF9o0icxag6f_eRzjTvu71gNRB3lOLN4xK8iQWA7dT5FKo2F/s1999/image1.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="800" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjotT5lhyphenhyphenreDLFlq_hZRSAKKI2jWx9Pp2y1xvBTQflO-H7EQ3VZYmsRTiaTV1creTpMo0It1-IfiFh313zzhhjDPeSxSW3nnRWsQC1toPBSsRlQD0T6UmzFqSNGaeQ0CGBKmn6xyAJXTG3NaF9o0icxag6f_eRzjTvu71gNRB3lOLN4xK8iQWA7dT5FKo2F/s16000/image1.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">User preference ratings for text fidelity, i.e., what percentage of videos are preferred in terms of accurately following a prompt.</td></tr></tbody></table>
<div style="line-height: 40%;">
<br />
</div>
<h3>Motion interestingness</h3>
<table align="center" cellpadding="0" cellspacing="0" class="tr-caption-container" style="margin-left: auto; margin-right: auto;"><tbody><tr><td style="text-align: center;"><a href="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjIWoXGSpe80GopYbQLJcyISxwM7DVtB-OFr2gqCUC33D3J6aLCS9sE4LKuoXhA89YNZE9yg_VmvUwJg2N1_nKt9m5z2NJgWiM2Ylqs2_Y2nAULojUuwpNmLv7LhYv4aGs4WgffyECcQtKM3Z83bmosuuXvHw4DeekkzAIpCkF2LlN6jExQysy68Ovgmgk1/s1999/image15.png" style="margin-left: auto; margin-right: auto;"><img border="0" data-original-height="797" data-original-width="1999" src="https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjIWoXGSpe80GopYbQLJcyISxwM7DVtB-OFr2gqCUC33D3J6aLCS9sE4LKuoXhA89YNZE9yg_VmvUwJg2N1_nKt9m5z2NJgWiM2Ylqs2_Y2nAULojUuwpNmLv7LhYv4aGs4WgffyECcQtKM3Z83bmosuuXvHw4DeekkzAIpCkF2LlN6jExQysy68Ovgmgk1/s16000/image15.png" /></a></td></tr><tr><td class="tr-caption" style="text-align: center;">User preference ratings for motion interestingness, i.e., what percentage of videos are preferred in terms of producing interesting motion.</td></tr></tbody></table>
<p>
Based on the above, on average people selected 24–35% of examples from VideoPoet as following prompts better than the competing model, versus 8–11% for the competing models. Raters also preferred 41–54% of examples from VideoPoet for more interesting motion, compared to 11–21% for the other models.
</p>
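<p>
For readers curious how such preference percentages are tallied, the snippet below shows one plausible bookkeeping scheme over pairwise ratings; it is an assumption for illustration, not the actual evaluation protocol. Each comparison is recorded as a win for VideoPoet, a win for the competing model, or no preference, which is why the two sides’ percentages need not sum to 100%.
</p>
<pre>
# Illustrative tally of pairwise preference ratings (assumed bookkeeping).
from collections import Counter

def preference_shares(ratings):
    counts = Counter(ratings)
    total = len(ratings)
    return {choice: 100.0 * count / total for choice, count in counts.items()}

shares = preference_shares(
    ["videopoet", "no_preference", "videopoet", "other", "no_preference"])
print(shares)  # {'videopoet': 40.0, 'no_preference': 40.0, 'other': 20.0}
</pre>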
<div style="line-height: 40%;">
<br />
</div>
<h2>Conclusion</h2>
<p>
Through VideoPoet, we have demonstrated LLMs’ highly competitive video generation quality across a wide variety of tasks, especially in producing interesting and high-quality motion within videos. Our results suggest the promising potential of LLMs in the field of video generation. As for future directions, our framework should be able to support “any-to-any” generation, e.g., extending to text-to-audio, audio-to-video, and video captioning, among many other tasks.
</p>
<p>
To view more examples in original quality, see the <a href="http://sites.research.google/videopoet">website demo</a>.
</p>
<div style="line-height: 40%;">
<br />
</div>
<h2>Acknowledgements</h2>
<p>
<em>This research has been supported by a large body of contributors, including Dan Kondratyuk, Lijun Yu, Xiuye Gu, José Lezama, Jonathan Huang, Rachel Hornung, Hartwig Adam, Hassan Akbari, Yair Alon, Vighnesh Birodkar, Yong Cheng, Ming-Chang Chiu, Josh Dillon, Irfan Essa, Agrim Gupta, Meera Hahn, Anja Hauth, David Hendon, Alonso Martinez, David Minnen, David Ross, Grant Schindler, Mikhail Sirotenko, Kihyuk Sohn, Krishna Somandepalli, Huisheng Wang, Jimmy Yan, Ming-Hsuan Yang, Xuan Yang, Bryan Seybold, and Lu Jiang.</em>
</p>
<p>
<em>We give special thanks to Alex Siegman, Victor Gomes, and Brendan Jou for managing computing resources. We also give thanks to Aren Jansen, Marco Tagliasacchi, Neil Zeghidour, and John Hershey for audio tokenization and processing, Angad Singh for storyboarding in “Rookie the Raccoon”, Cordelia Schmid for research discussions, David Salesin, Tomas Izo, and Rahul Sukthankar for their support, and Jay Yagnik as architect of the initial concept.</em>
</p>
<br />
<p>
<em>**</em>
<br />
<em>(a) <a href="https://commons.wikimedia.org/wiki/File:Rembrandt_Christ_in_the_Storm_on_the_Lake_of_Galilee.jpg">The Storm on the Sea of Galilee</a>, by Rembrandt 1633, public domain.</em>
<br />
<em>(b) <a href="https://commons.wikimedia.org/wiki/File:Pillars_of_creation_2014_HST_WFC3-UVIS_full-res.jpg">Pillars of Creation</a>, by NASA 2014, public domain.</em>
<br />
<em>(c) <a href="https://commons.wikimedia.org/wiki/File:Caspar_David_Friedrich_-_Wanderer_above_the_Sea_of_Fog.jpeg">Wanderer above the Sea of Fog</a>, by Caspar David Friedrich, 1818, public domain</em>
<br />
<em>(d) <a href="https://commons.wikimedia.org/wiki/File:Mona_Lisa,_by_Leonardo_da_Vinci,_from_C2RMF_retouched.jpg">Mona Lisa</a>, by Leonardo Da Vinci, 1503, public domain.</em>
</p>