Need for Speed : JAX

Les contributeurs
Jasper Van den Bossche
Ingénieur logiciel
Aucun élément trouvé.
S'abonner à la newsletter
Partager cet article

Entraînez votre réseau neuronal dix fois plus vite en utilisant Jax sur un TPU

Tous les jeunes branchés semblent s'extasier sur JAX ces jours-ci. Deepmind l'utilise largement pour ses recherches et construit même son propre écosystème. Boris Dayma et son équipe ont construit DALL-E Mini en un rien de temps en utilisant JAX et des TPU, ce qui vaut vraiment la peine d'être vérifié sur Hugging Face où l'on trouve déjà plus de 5000 modèles écrits en JAX. Mais qu'est-ce que JAX exactement et pourquoi est-il si spécial ? Selon leur site web, JAX offre une différenciation automatique, une vectorisation et une compilation juste à temps pour les GPU et les TPU via des transformations composables. Cela vous semble compliqué ? Ne vous inquiétez pas, dans ce blogpost nous allons vous faire visiter et vous montrer comment JAX fonctionne, en quoi il est différent de Tensorflow/Pytorch et pourquoi nous pensons que c'est un framework super intéressant.

Qu'est-ce que JAX ?

JAX est un cadre de calcul numérique et d'apprentissage automatique à haute performance de Google Research qui fonctionne très rapidement sur les GPU et les TPU, sans avoir à se préoccuper des détails de bas niveau. L'objectif de JAX était de construire un cadre qui combine la haute performance avec l'expressivité et la facilité d'utilisation de Python, afin que les chercheurs puissent expérimenter de nouveaux modèles et de nouvelles techniques sans avoir besoin d'implémentations C/C++ de bas niveau hautement optimisées. Il atteint cet objectif en utilisant le compilateur XLA (Accelerated Linear Algebra) de Google pour générer un code machine efficace plutôt que d'utiliser des noyaux précompilés. L'un des aspects intéressants de JAX est qu'il est indépendant de l'accélérateur, ce qui signifie que le même code Python peut être exécuté efficacement sur les GPU et les TPU.

JAX fonctionne par le biais de transformations de fonctions composables, ce qui signifie que JAX prend une fonction et produit une nouvelle fonction qui est interprétée différemment et que plusieurs transformations peuvent être enchaînées. La différenciation automatique, par exemple, est une transformation qui génère la dérivée d'une fonction, tandis que la vectorisation automatique prend une fonction qui opère sur un seul point de données et la transforme en une fonction qui opère sur un lot de points de données. Grâce à ces transformations, JAX permet au programmeur de rester dans le monde Python de haut niveau et de laisser le compilateur faire le travail difficile en générant le code hautement efficace nécessaire pour entraîner des modèles complexes. Nous allons passer en revue ces transformations et les appliquer dans un exemple où nous construisons un simple perceptron multicouche.

Vue d'ensemble schématique des fonctions composables.

Qu'est-ce qui le différencie de Tensorflow et Pytorch ?

JAX est un cadre orienté compilateur, ce qui signifie qu'un compilateur est chargé de transformer les fonctions Python en code machine efficace. Tensorflow et Pytorch, quant à eux, disposent de noyaux GPU et TPU précompilés pour chaque opération. Lors de l'exécution d'un programme TensorFlow, chaque opération est exécutée individuellement. Bien que les opérations elles-mêmes soient très bien optimisées, leur fusion nécessite un grand nombre d'opérations de mémoire, ce qui entraîne un goulot d'étranglement au niveau des performances. Le compilateur XLA peut générer du code pour l'ensemble de la fonction. Il peut utiliser toutes ces informations pour fusionner les opérations et économiser une tonne d'opérations de mémoire, ce qui permet de générer un code globalement plus rapide.

JAX est également plus léger que Tensorflow et Pytorch, car il n'est pas nécessaire d'implémenter chaque opération, fonction ou modèle séparément. Au lieu de cela, JAX met en œuvre l'API NumPy avec des opérations plus simples et de plus bas niveau qui peuvent être utilisées comme blocs de construction et fusionnées dans des modèles et des fonctions complexes par le compilateur.

Le compilateur, bien plus puissant que vous ne le pensez

La conception orientée compilateur est beaucoup plus puissante qu'il n'y paraît à première vue. Grâce au compilateur, il n'est plus nécessaire d'implémenter un code d'accélérateur de bas niveau. Il permet aux chercheurs d'améliorer considérablement leur productivité et ouvre la voie à l'expérimentation de nouvelles architectures de modèles. Les chercheurs peuvent même expérimenter les GPU et les TPU sans avoir à réécrire leur code. Mais comment cela fonctionne-t-il ?

JAX ne compile pas directement le code machine, mais plutôt une représentation intermédiaire indépendante du code Python de haut niveau et du code machine. Le compilateur est divisé en un frontend qui compile les fonctions Python en IR et un backend qui compile l'IR en code machine spécifique à la plateforme. Cette conception n'est pas nouvelle, un exemple de compilateur qui suit également cette conception est LLVM. Il existe des frontends pour C et Rust qui traduisent le code de haut niveau en IR LLVM. Le backend peut alors générer du code machine pour une variété de types de machines supportées, que le code original ait été écrit en C ou en Rust.

JAX transforme le code Python de haut niveau en une représentation intermédiaire qui est ensuite compilée par XLA en un programme qui s'exécute sur votre type de machine.

C'est énorme, car grâce à cette conception flexible, on peut construire un nouvel accélérateur, écrire un backend XLA pour lui et votre code JAX qui s'exécutait auparavant sur des GPU/TPU peut être exécuté sur le nouvel accélérateur. D'autre part, vous pouvez également créer un cadre dans un autre langage de programmation qui se compile à l'IR JAX et vous pouvez utiliser les GPU et les TPU grâce à XLA.

Si cette approche basée sur un compilateur fonctionne tellement mieux que les noyaux précompilés, pourquoi Tensorflow et Pytorch ne l'ont-ils pas utilisée dès le départ ? La réponse est assez simple, il est vraiment difficile de concevoir un bon compilateur numérique. Avec sa différenciation automatique, sa vectorisation et sa compilation jit, JAX dispose d'outils vraiment puissants. Cependant, JAX n'est pas non plus la solution miracle, tous ces avantages ont un petit prix, vous devez apprendre quelques nouveaux trucs et concepts liés à la programmation fonctionnelle.

Une note sur la programmation fonctionnelle

JAX ne peut pas transformer n'importe quelle fonction Python, il ne peut transformer que des fonctions pures. Une fonction pure peut être définie comme une fonction qui ne dépend que de ses entrées, ce qui signifie que pour une entrée x donnée, elle renverra toujours la même sortie y et qu'elle ne produit pas d'effets secondaires tels que des opérations IO ou la mutation de variables globales. Le dynamisme de Python signifie que le comportement d'une fonction change en fonction des types de ses entrées et JAX veut exploiter ce dynamisme en transformant les fonctions au moment de l'exécution. Au début d'une transformation, JAX vérifie ce que fait la fonction pour un ensemble d'entrées données et transforme la fonction sur la base de ces informations. Sous le capot, JAX trace la fonction, tout comme l'interpréteur Python. En n'autorisant que les fonctions pures, la transformation des fonctions juste à temps devient beaucoup plus facile et rapide.

Imaginez que le traceur doive traiter des effets de bord tels que les entrées-sorties, ce qui signifie qu'un comportement inattendu peut se produire, par exemple un utilisateur qui a saisi des données non valides, ce qui rend beaucoup plus difficile la génération d'un code efficace, en particulier lorsque des accélérateurs sont en jeu. Les variables globales peuvent changer entre deux appels de fonction et donc modifier complètement le comportement de la fonction dans laquelle elles sont utilisées, ce qui rend une fonction transformée invalide. Si vous vous intéressez aux compilateurs et aux détails de la façon dont fonctionne le traçage de JAX, nous vous recommandons de consulter la documentation pour plus de détails sur son fonctionnement interne.

Le seul inconvénient de JAX est qu'il ne peut pas vérifier si une fonction est une fonction pure. C'est au programmeur de s'assurer qu'il écrit des fonctions pures, sinon JAX transformera la fonction avec un comportement inattendu.

Représentation d'objets avec état à l'aide de Pytrees

Le fait de travailler avec des fonctions pures a également un impact sur la manière dont les structures de données sont utilisées. Dans d'autres cadres, les modèles d'apprentissage automatique sont souvent représentés avec un état, ce qui est contraire au paradigme de la programmation fonctionnelle car il s'agit d'une mutation d'un état global. Pour surmonter ce problème, JAX introduit les pytrees, des structures arborescentes construites à partir d'objets Python de type conteneur. Les classes de type conteneur peuvent être enregistrées dans le registre pytree, qui contient par défaut des listes, des tuples et des dicts. Les pytrees peuvent contenir d'autres pytrees et les classes non enregistrées dans le registre pytree sont considérées comme des feuilles. Les feuilles peuvent être considérées comme des entrées immuables pour une fonction pure. Pour chaque classe du registre pytree, il existe une fonction qui convertit un pytree en un tuple avec ses enfants et des métadonnées optionnelles, ainsi qu'une fonction qui reconvertit les enfants et les métadonnées en un type de conteneur. Ces fonctions peuvent être utilisées pour mettre à jour le modèle ou tout autre objet avec état que vous utilisez.

Transformons un peu de code !

Avant de nous plonger dans notre exemple de MLP, nous allons montrer les transformations les plus importantes de JAX.

Différenciation automatique

La première transformation est la différenciation automatique, où nous prenons une fonction Python en entrée et renvoyons une fonction qui représente le gradient de cette fonction. Ce qui est intéressant avec l'autodiffusion de JAX, c'est qu'elle peut différencier les fonctions Python qui utilisent des conteneurs Python, des conditionnelles, des boucles, etc. Dans l'exemple suivant, nous créons une fonction qui représente le gradient de la fonction `tanh`. Comme les transformations JAX sont composables, nous pouvons utiliser n appels imbriqués de la fonction grad pour calculer la nièmedérivée.

La différenciation automatique de JAX est un outil puissant et complet. Si vous souhaitez en savoir plus sur son fonctionnement, nous vous recommandons de lire The JAX Autodiff Cookbook.

Vectorisation automatique

Lors de l'apprentissage d'un modèle, vous propagez généralement un lot d'échantillons d'apprentissage dans votre modèle. Lors de la mise en œuvre d'un modèle, vous devez donc considérer votre fonction de prédiction comme une fonction qui prend en charge un lot d'échantillons et renvoie une prédiction pour chaque échantillon. Cela peut toutefois accroître considérablement la difficulté de la mise en œuvre et réduire la lisibilité de la fonction par rapport à une fonction qui fonctionnerait sur un seul échantillon. La deuxième transformation intervient : la vectorisation automatique. Nous écrivons notre fonction comme si nous ne traitions qu'un seul échantillon, puis vmap la transforme en une version vectorisée.

Au début, vmap peut s'avérer un peu difficile, en particulier lorsqu'il s'agit de travailler avec des dimensions supérieures, mais il s'agit d'une transformation vraiment puissante. Nous vous recommandons de consulter quelques exemples dans la documentation pour comprendre pleinement son potentiel.

Compilation Just-In-Time

La troisième transformation de fonction est la compilation juste à temps. L'objectif de cette transformation est d'améliorer les performances, de paralléliser le code et de l'exécuter sur un accélérateur. JAX ne compile pas directement en code machine mais plutôt en une représentation intermédiaire. Cette représentation intermédiaire est indépendante du code Python et du code machine de l'accélérateur. Le compilateur XLA prend alors la représentation intermédiaire et la compile en code machine efficace.

Il n'est pas toujours facile de décider quand et quel code compiler, afin d'utiliser le compilateur de manière optimale, nous vous recommandons de consulter la documentation. Plus loin dans ce blog, nous approfondirons la conception du compilateur et expliquerons pourquoi JAX est un framework si puissant.

Formation d'un MLP en 5 minutes à l'aide d'une TPU

Maintenant que nous avons appris les transformations les plus importantes, nous sommes prêts à mettre ces connaissances en pratique. Nous allons implémenter un MLP à partir de zéro pour classer les images MNIST et l'entraîner très rapidement sur une TPU. Notre réseau neuronal aura une couche d'entrée de 728 variables d'entrée, suivie de deux couches cachées de 512 et 256 neurones respectivement et d'une couche de sortie avec un nœud pour chaque classe.

Initialisation du modèle

La première chose à faire est de créer une structure qui représente notre modèle. En entrée de notre fonction d'initialisation, nous disposons d'une liste contenant le nombre de nœuds de chaque couche de notre réseau neuronal. Nous avons une couche d'entrée qui est égale au nombre de pixels d'une image, suivie de deux couches cachées avec 512 et 256 neurones respectivement et une couche de sortie qui est égale au nombre de classes. Nous utilisons des tableaux JAX numpy pour initialiser le modèle sur l'accélérateur, évitant ainsi de copier manuellement ces données.

Notez que la génération de nombres aléatoires est légèrement différente de numpy. Nous voulons pouvoir générer des nombres aléatoires sur des accélérateurs parallèles et nous avons besoin d'un générateur de nombres aléatoires qui fonctionne bien avec le paradigme de la programmation fonctionnelle. L'algorithme de Numpy pour générer des nombres aléatoires n'est pas très adapté à ces objectifs. Consultez les notes de conception et la documentation JAX pour plus d'informations.

Prédiction

Notre prochaine étape consiste à écrire une fonction de prédiction qui attribuera des étiquettes à un lot d'images. Nous utiliserons la vectorisation automatique pour transformer une fonction qui prend une seule image en entrée et produit une étiquette en une fonction qui prédit des étiquettes pour un lot d'entrées. L'écriture d'une fonction de prédiction n'est pas super difficile, nous passons par les couches cachées du réseau et appliquons des poids et des biais via une multiplication de matrice et une addition de vecteur et nous appliquons la fonction d'activation RELU. À la fin, nous calculons l'étiquette de sortie à l'aide de la fonction RealSoftMax. Une fois que nous avons notre fonction pour étiqueter une seule image, nous pouvons la transformer en utilisant vmap pour qu'elle puisse traiter un lot d'entrées.

Fonction de perte

La fonction de perte prend un lot d'images et calcule l'erreur absolue moyenne. Nous appelons nos prédictions par lot et calculons l'étiquette pour chaque image, nous la comparons aux étiquettes de la vérité de base codées à un coup et nous calculons le nombre moyen d'erreurs.

Fonction de mise à jour

Maintenant que nous disposons de notre fonction de prédiction et de perte, nous allons mettre en œuvre une fonction de mise à jour pour mettre à jour notre modèle de manière itérative à chaque étape de l'apprentissage. Notre fonction de mise à jour prend en compte un lot d'images et ses étiquettes de vérité terrain, ainsi que le modèle actuel et un taux d'apprentissage. Nous calculons à la fois la valeur de la perte et la valeur de son gradient. Nous mettons à jour le modèle en utilisant le taux d'apprentissage et les gradients de perte. Comme nous voulons compiler cette fonction, nous devons convertir le modèle mis à jour en un pytree. Nous renvoyons également la valeur de la perte pour contrôler la précision.

Maintenant que nous avons la fonction de mise à jour, nous allons la compiler pour qu'elle puisse fonctionner sur une TPU et améliorer considérablement ses performances. Les fonctions imbriquées appelées dans la fonction update seront également compilées et optimisées. La raison pour laquelle nous n'appliquons la transformation de compilation qu'à la fonction update et non à chaque fonction séparément est que nous voulons donner au compilateur autant d'informations que possible pour qu'il puisse optimiser le code autant que possible.

Former notre modèle

Nous pouvons définir une fonction de précision (et éventuellement d'autres mesures) et créer une boucle d'apprentissage en utilisant notre fonction de mise à jour et notre modèle initial en entrée. Nous sommes maintenant prêts à entraîner notre modèle à l'aide d'une TPU ou d'une GPU.

Conclusion

Nous avons beaucoup appris aujourd'hui. Tout d'abord, nous avons commencé à décrire JAX comme un cadre avec des transformations de fonctions composables. Les quatre transformations principales sont la vectorisation automatique, la parallélisation automatique sur plusieurs accélérateurs, la différenciation automatique des fonctions python et la compilation JIT des fonctions pour les exécuter sur des accélérateurs. Nous avons approfondi le fonctionnement interne de JAX et appris comment il est capable de créer des fonctions aussi efficaces qui fonctionnent à la fois sur les GPU et les TPU en compilant vers un IR qui est ensuite transformé en appels XLA. Cette approche permet aux chercheurs d'expérimenter de nouvelles techniques d'apprentissage automatique sans avoir à se soucier d'une version de bas niveau et hautement optimisée de leur code. Nous espérons que les ingénieurs en logiciel seront également enthousiastes, afin que de nouvelles bibliothèques puissent être créées à partir de JAX et que les accélérateurs potentiels puissent être rapidement adoptés.

Postes connexes

Voir tout le contenu
Aucun résultat n'a été trouvé.
Il n'y a pas de résultats correspondant à ces critères. Essayez de modifier votre recherche.
Grand modèle linguistique
Modèles de fondation
Entreprise
Personnes
Données Structurées
Chat GPT
Durabilité
Voix et son
Développement frontal
Protection des données et sécurité
IA responsable/éthique
Infrastructure
Hardware et capteurs
MLOps
IA générative
Natural Language Processing
Vision par ordinateur