Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
Update react-native-pytorch-core README.md
Browse files Browse the repository at this point in the history
Summary: Update the example code in the README.md for the `react-native-pytorch-core` package to use the new JSI API.

Reviewed By: chrisklaiber

Differential Revision: D36882398

fbshipit-source-id: 047aeaf18d69b8a014c6bc22b13b6bf999681f63
  • Loading branch information
raedle authored and facebook-github-bot committed Jun 2, 2022
1 parent 9777014 commit 5ab2b3a
Showing 1 changed file with 124 additions and 32 deletions.
156 changes: 124 additions & 32 deletions react-native-pytorch-core/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,47 +16,139 @@ The full documentation for PyTorch Live can be found on our [website](https://py
## Example Usage

```javascript
// Import dependencies
import * as React from 'react';
import {Button, StyleSheet, Text, View} from 'react-native';

import {MobileModel, ImageUtil} from 'react-native-pytorch-core';

// Have a look at how to prepare a model for PyTorch Live
// https://pytorch.org/live/docs/tutorials/prepare-custom-model/
const model = require('./mobilenet_v3_small.ptl');

// JSON array of classes that map the max idx to a class label
const IMAGE_CLASSES = require('./image_classes.json');

type ImageClassificationResult = {
maxIdx: number;
confidence: number;
};

export default function ImageClassificaion() {
const [topClass, setTopClass] = React.useState<string>('');
async function classifyImage() {
const image = await ImageUtil.fromURL('https://pytorch.org/example.jpg');
const {metrics, result} =
await MobileModel.execute<ImageClassificationResult>(model, {
image,
});

console.log(metrics);
if (result.confidence > 0.7) {
setTopClass(IMAGE_CLASSES[result.maxIdx]);
} else {
setTopClass('low confidence');
import { StyleSheet, Text, View } from 'react-native';
import {
Camera,
Image,
media,
MobileModel,
Module,
Tensor,
torch,
torchvision,
} from 'react-native-pytorch-core';
import { useSafeAreaInsets } from 'react-native-safe-area-context';

// Alias for torchvision transforms
const T = torchvision.transforms;

// URL to the image classification model that is used in this example
const MODEL_URL =
'https://github.com/pytorch/live/releases/download/v0.1.0/mobilenet_v3_small.ptl';

// URL to the ImageNetClasses JSON file, which is used below to map the
// processed model result to a class label
const IMAGENET_CLASSES_URL =
'https://github.com/pytorch/live/releases/download/v0.1.0/ImageNetClasses.json';

// Variable to hold a reference to the loaded ML model
let model: Module | null = null;

// Variable to hold a reference to the ImageNet classes
let imageNetClasses: string[] | null = null;

// App function to render a camera and a text
export default function App() {
// Safe area insets to compensate for notches and bottom bars
const insets = useSafeAreaInsets();
// Create a React state to store the top class returned from the
// classifyImage function
const [topClass, setTopClass] = React.useState(
"Press capture button to classify what's in the camera view!",
);

// Function to handle images whenever the user presses the capture button
async function handleImage(image: Image) {
// Get image width and height
const width = image.getWidth();
const height = image.getHeight();

// Convert image to blob, which is a byte representation of the image
// in the format height (H), width (W), and channels (C), or HWC for short
const blob = media.toBlob(image);

// Get a tensor from the image blob and also define the format
// (shape) the image blob is in.
let tensor = torch.fromBlob(blob, [height, width, 3]);

// Rearrange the tensor shape to be [CHW]
tensor = tensor.permute([2, 0, 1]);

// Divide the tensor values by 255 to get values between [0, 1]
tensor = tensor.div(255);

// Crop the image in the center to be a squared image
const centerCrop = T.centerCrop(Math.min(width, height));
tensor = centerCrop(tensor);

// Resize the image tensor to 3 x 224 x 224
const resize = T.resize(224);
tensor = resize(tensor);

// Normalize the tensor image with mean and standard deviation
const normalize = T.normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]);
tensor = normalize(tensor);

// Unsqueeze adds 1 leading dimension to the tensor
tensor = tensor.unsqueeze(0);

// If the model has not been loaded already, it will be downloaded from
// the URL and then loaded into memory.
if (model === null) {
const filePath = await MobileModel.download(MODEL_URL);
model = await torch.jit._loadForMobile(filePath);
}

// Run the ML inference with the pre-processed image tensor
const output = await model.forward<Tensor, Tensor>(tensor);

// Get the index of the value with the highest probability
const maxIdx = output.argmax().item();

if (imageNetClasses === null) {
const response = await fetch(IMAGENET_CLASSES_URL);
imageNetClasses = (await response.json()) as string[];
}

// Resolve the most likely class label and return it
const result = imageNetClasses[maxIdx];

// Set result as top class label state
setTopClass(result);

// Release the image from memory
image.release();
}

return (
<View style={StyleSheet.absoluteFill}>
<Button title="Classify Image" onPress={classifyImage} />
<Text>{topClass}</Text>
{/* Render camera and make it parent filling */}
<Camera
style={[StyleSheet.absoluteFill, { bottom: insets.bottom }]}
// Add handle image callback on the camera component
onCapture={handleImage}
/>
{/* Label container with custom render style and a text */}
<View style={styles.labelContainer}>
{/* Change the text to render the top class label */}
<Text>{topClass}</Text>
</View>
</View>
);
}

// Custom render style for label container
const styles = StyleSheet.create({
labelContainer: {
padding: 20,
margin: 20,
marginTop: 40,
borderRadius: 10,
backgroundColor: 'white',
},
});
```

## 👏 How to Contribute
Expand Down

0 comments on commit 5ab2b3a

Please sign in to comment.