Skip to content

Commit

Permalink
Update code to use ui.Image
Browse files Browse the repository at this point in the history
  • Loading branch information
vishnukvmd committed Nov 30, 2023
1 parent 4d0ad75 commit 67200a9
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 34 deletions.
236 changes: 223 additions & 13 deletions example/lib/clip_image_encoder.dart
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import 'dart:io';
import 'dart:math';
import 'dart:typed_data';
import 'dart:ui';

import 'package:flutter/services.dart';
import 'package:onnxruntime/onnxruntime.dart';
import 'package:image/image.dart' as img;
import 'package:onnxruntime_example/processed_image.dart';
import 'package:flutter/painting.dart' as paint show decodeImageFromList;

class ClipImageEncoder {
OrtSessionOptions? _sessionOptions;
Expand All @@ -29,7 +33,7 @@ class ClipImageEncoder {
..setInterOpNumThreads(1)
..setIntraOpNumThreads(1)
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
const assetFileName = 'assets/models/visual.onnx';
const assetFileName = 'assets/models/clip_visual.onnx';
final rawAssetFile = await rootBundle.load(assetFileName);
final bytes = rawAssetFile.buffer.asUint8List();
try {
Expand All @@ -54,27 +58,106 @@ class ClipImageEncoder {
_session?.release();
}

inferByImage(String imagePath) {
inferByImage(String imagePath) async {
final runOptions = OrtRunOptions();
final startTime = DateTime.now();

// Change this with path
//final rgb8 = img.Image(width: 784, height: 890, format: img.Format.float32);
final rgb = img.decodeJpg(File(imagePath).readAsBytesSync()) as img.Image;
final inputImage = img.copyResize(rgb,
width: 224, height: 224, interpolation: img.Interpolation.linear);
// Code from Satan
// final rgb8 = img.Image(width: 784, height: 890, format: img.Format.float32);
// final rgb = img.decodePng(File(imagePath).readAsBytesSync()) as img.Image;
// final inputImage = img.copyResize(rgb,
// width: 224, height: 224, interpolation: img.Interpolation.linear);

// Hard coded Input
// final processedImage = getProcessedImage();

final image =
await paint.decodeImageFromList(File(imagePath).readAsBytesSync());
final resizedImage =
await resizeImage(image, 224, 224, maintainAspectRatio: true);
final croppedImage = await cropImage(
resizedImage.$1,
x: 0,
y: 0,
width: 224,
height: 224,
);
final mean = [0.48145466, 0.4578275, 0.40821073];
final std = [0.26862954, 0.26130258, 0.27577711];
final processedImage = imageToByteListFloat32(inputImage, 224, mean, std);
final ByteData imgByteData = await getByteDataFromImage(image);
final processedImage =
imageToByteListFloat32(croppedImage, imgByteData, 224, mean, std);

final inputOrt = OrtValueTensor.createTensorWithDataList(
processedImage, [1, 3, 224, 224]);
final inputs = {'input': inputOrt};
final outputs = _session?.run(runOptions, inputs);
print((outputs?[0]?.value as List<List<double>>)[0]);
final result = (outputs?[0]?.value as List<List<double>>)[0];
final endTime = DateTime.now();
print((endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
.toString() +
"ms");
print(result.toString());
}

Future<(Image, Size)> resizeImage(
Image image,
int width,
int height, {
FilterQuality quality = FilterQuality.medium,
bool maintainAspectRatio = false,
}) async {
if (image.width == width && image.height == height) {
return (image, Size(width.toDouble(), height.toDouble()));
}
final recorder = PictureRecorder();
final canvas = Canvas(
recorder,
Rect.fromPoints(
const Offset(0, 0),
Offset(width.toDouble(), height.toDouble()),
),
);

double scaleW = width / image.width;
double scaleH = height / image.height;
if (maintainAspectRatio) {
final scale = min(width / image.width, height / image.height);
scaleW = scale;
scaleH = scale;
}
final scaledWidth = (image.width * scaleW).round();
final scaledHeight = (image.height * scaleH).round();

canvas.drawImageRect(
image,
Rect.fromPoints(
const Offset(0, 0),
Offset(image.width.toDouble(), image.height.toDouble()),
),
Rect.fromPoints(
const Offset(0, 0),
Offset(scaledWidth.toDouble(), scaledHeight.toDouble()),
),
Paint()..filterQuality = quality,
);

final picture = recorder.endRecording();
final resizedImage = await picture.toImage(width, height);
return (
resizedImage,
Size(scaledWidth.toDouble(), scaledHeight.toDouble())
);
}

Float32List imageToByteListFloat32(
img.Image image, int inputSize, List<double> mean, List<double> std) {
Image image,
ByteData data,
int inputSize,
List<double> mean,
List<double> std,
) {
var convertedBytes = Float32List(1 * inputSize * inputSize * 3);
var buffer = Float32List.view(convertedBytes.buffer);
int pixelIndex = 0;
Expand All @@ -83,12 +166,139 @@ class ClipImageEncoder {

for (var i = 0; i < inputSize; i++) {
for (var j = 0; j < inputSize; j++) {
var pixel = image.getPixel(j, i);
buffer[pixelIndex++] = (pixel.r - mean[0]) / std[0];
buffer[pixelIndex++] = (pixel.g - mean[1]) / std[1];
buffer[pixelIndex++] = (pixel.b - mean[2]) / std[2];
var pixel = readPixelColor(image, data, j, i);
buffer[pixelIndex++] = (pixel.red - mean[0]) / std[0];
buffer[pixelIndex++] = (pixel.green - mean[1]) / std[1];
buffer[pixelIndex++] = (pixel.blue - mean[2]) / std[2];
}
}
return convertedBytes.buffer.asFloat32List();
}

Color readPixelColor(
Image image,
ByteData byteData,
int x,
int y,
) {
if (x < 0 || x >= image.width || y < 0 || y >= image.height) {
// throw ArgumentError('Invalid pixel coordinates.');
return const Color(0x00000000);
}
assert(byteData.lengthInBytes == 4 * image.width * image.height);

final int byteOffset = 4 * (image.width * y + x);
return Color(_rgbaToArgb(byteData.getUint32(byteOffset)));
}

int _rgbaToArgb(int rgbaColor) {
final int a = rgbaColor & 0xFF;
final int rgb = rgbaColor >> 8;
return rgb + (a << 24);
}

Future<Image> cropImage(
Image image, {
required double x,
required double y,
required double width,
required double height,
Size? maxSize,
Size? minSize,
double rotation = 0.0, // rotation in radians
FilterQuality quality = FilterQuality.medium,
}) async {
// Calculate the scale for resizing based on maxSize and minSize
double scaleX = 1.0;
double scaleY = 1.0;
if (maxSize != null) {
final minScale = min(maxSize.width / width, maxSize.height / height);
if (minScale < 1.0) {
scaleX = minScale;
scaleY = minScale;
}
}
if (minSize != null) {
final maxScale = max(minSize.width / width, minSize.height / height);
if (maxScale > 1.0) {
scaleX = maxScale;
scaleY = maxScale;
}
}

// Calculate the final dimensions
final targetWidth = (width * scaleX).round();
final targetHeight = (height * scaleY).round();

// Create the canvas
final recorder = PictureRecorder();
final canvas = Canvas(
recorder,
Rect.fromPoints(
const Offset(0, 0),
Offset(targetWidth.toDouble(), targetHeight.toDouble()),
),
);

// Apply rotation
final center = Offset(targetWidth / 2, targetHeight / 2);
canvas.translate(center.dx, center.dy);
canvas.rotate(rotation);

// Enlarge both the source and destination boxes to account for the rotation (i.e. avoid cropping the corners of the image)
final List<double> enlargedSrc =
getEnlargedAbsoluteBox([x, y, x + width, y + height], 1.5);
final List<double> enlargedDst = getEnlargedAbsoluteBox(
[
-center.dx,
-center.dy,
-center.dx + targetWidth,
-center.dy + targetHeight,
],
1.5,
);

canvas.drawImageRect(
image,
Rect.fromPoints(
Offset(enlargedSrc[0], enlargedSrc[1]),
Offset(enlargedSrc[2], enlargedSrc[3]),
),
Rect.fromPoints(
Offset(enlargedDst[0], enlargedDst[1]),
Offset(enlargedDst[2], enlargedDst[3]),
),
Paint()..filterQuality = quality,
);

final picture = recorder.endRecording();

return picture.toImage(targetWidth, targetHeight);
}

List<double> getEnlargedAbsoluteBox(List<double> box, [double factor = 2]) {
final boxCopy = List<double>.from(box, growable: false);
// The four values of the box in order are: [xMinBox, yMinBox, xMaxBox, yMaxBox].

final width = boxCopy[2] - boxCopy[0];
final height = boxCopy[3] - boxCopy[1];

boxCopy[0] -= width * (factor - 1) / 2;
boxCopy[1] -= height * (factor - 1) / 2;
boxCopy[2] += width * (factor - 1) / 2;
boxCopy[3] += height * (factor - 1) / 2;

return boxCopy;
}

Future<ByteData> getByteDataFromImage(
Image image, {
ImageByteFormat format = ImageByteFormat.rawRgba,
}) async {
final ByteData? byteDataRgba = await image.toByteData(format: format);
if (byteDataRgba == null) {
throw Exception('Could not convert image to ByteData');
}
return byteDataRgba;
}
}
25 changes: 4 additions & 21 deletions example/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class _MyAppState extends State<MyApp> {

_clipImageEncoder = ClipImageEncoder();
_clipImageEncoder?.initModel();
_clipTextEncoder = ClipTextEncoder();
_clipTextEncoder?.initModel();
// _clipTextEncoder = ClipTextEncoder();
// _clipTextEncoder?.initModel();
}

@override
Expand Down Expand Up @@ -127,26 +127,9 @@ class _MyAppState extends State<MyApp> {
// final endTime = DateTime.now().millisecondsSinceEpoch;
// print('infer cost time=${endTime - startTime}ms');
//_clipTextEncoder?.infer();
const imgPath = "assets/images/cycle.jpg";
final path = await getAccessiblePathForAsset(imgPath, "test.jpg");
const imgPath = "assets/images/astro.png";
final path = await getAccessiblePathForAsset(imgPath, "test.png");
_clipImageEncoder?.inferByImage(path);
const windowByteCount = frameSize * 2 * RecordManager.sampleRate ~/ 1000;
final bytes = await File(_pcmPath!).readAsBytes();
var start = 0;
var end = start + windowByteCount;
List<int> frameBuffer;
final startTime = DateTime.now().millisecondsSinceEpoch;
while (end <= bytes.length) {
frameBuffer = bytes.sublist(start, end).toList();
final floatBuffer =
_transformBuffer(frameBuffer).map((e) => e / 32768).toList();
await _vadIterator?.predict(Float32List.fromList(floatBuffer));
start += windowByteCount;
end = start + windowByteCount;
}
_vadIterator?.reset();
final endTime = DateTime.now().millisecondsSinceEpoch;
print('vad cost time=${endTime - startTime}ms');
}

Int16List _transformBuffer(List<int> buffer) {
Expand Down
12 changes: 12 additions & 0 deletions example/lib/processed_image.dart

Large diffs are not rendered by default.

0 comments on commit 67200a9

Please sign in to comment.