diff --git a/calculate_average_yourwass.sh b/calculate_average_yourwass.sh index 07284ba76..50e31fb0b 100755 --- a/calculate_average_yourwass.sh +++ b/calculate_average_yourwass.sh @@ -1,4 +1,4 @@ -#!/bin/sh +#!/bin/bash # # Copyright 2023 The original authors # @@ -19,5 +19,8 @@ # source "$HOME/.sdkman/bin/sdkman-init.sh" # sdk use java 21.0.1-graal 1>&2 -JAVA_OPTS="--enable-preview --enable-native-access=ALL-UNNAMED --add-modules jdk.incubator.vector" -java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_yourwass +JAVA_OPTS="-Xlog:all=off -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 --enable-preview --enable-native-access=ALL-UNNAMED --add-modules jdk.incubator.vector" + +eval "exec 3< <({ java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_yourwass; })" +read <&3 result +echo -e "$result" diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java b/src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java index 0a24b0a7e..ad57b5004 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_yourwass.java @@ -16,6 +16,8 @@ package dev.morling.onebrc; import java.util.TreeMap; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import java.io.IOException; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; @@ -31,18 +33,15 @@ import sun.misc.Unsafe; public class CalculateAverage_yourwass { - static final class Record { - public String city; - public long cityAddr; - public long cityLength; - public int min; - public int max; - public int count; - public long sum; + private long cityAddr; + private long cityLength; + private int min; + private int max; + private int count; + private long sum; Record(final long cityAddr, final long cityLength) { - this.city = null; this.cityAddr = cityAddr; this.cityLength = cityLength; this.min = 1000; @@ -62,6 +61,8 @@ private Record merge(Record r) { } } + private final static Lock _mutex = new ReentrantLock(true); + private final static TreeMap aggregateResults = new TreeMap<>(); private static short lookupDecimal[]; private static byte lookupFraction[]; private static byte lookupDotPositive[]; @@ -70,6 +71,8 @@ private Record merge(Record r) { private static final VectorSpecies SPECIES = ByteVector.SPECIES_PREFERRED; private static final int MAXINDEX = (1 << 16) + 10000; // short hash + max allowed cities for collisions at the end :p private static final String FILE = "measurements.txt"; + private static long unsafeResults; + private static int RECORDSIZE = 36; private static final Unsafe UNSAFE = getUnsafe(); private static Unsafe getUnsafe() { @@ -113,11 +116,9 @@ public static void main(String[] args) throws IOException, Throwable { } // open file - final long fileSize, mmapAddr; - try (var fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ)) { - fileSize = fileChannel.size(); - mmapAddr = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); - } + final FileChannel fileChannel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ); + final long fileSize = fileChannel.size(); + final long mmapAddr = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileSize, Arena.global()).address(); // VAS: Virtual Address Space, as a MemorySegment upto and including the mmaped file. // If the mmaped MemorySegment is used for Vector creation as is, then there are two problems: // 1) fromMemorySegment takes an offset and not an address, so we have to do arithmetic @@ -127,36 +128,24 @@ public static void main(String[] args) throws IOException, Throwable { // XXX there lies the possibility for an out of bounds read at the end of file, which is not handled here. VAS = MemorySegment.ofAddress(0).reinterpret(mmapAddr + fileSize + SPECIES.length()); - // start and wait for threads to finish + // allocate memory for results final int nThreads = Runtime.getRuntime().availableProcessors(); + unsafeResults = UNSAFE.allocateMemory(RECORDSIZE * MAXINDEX * nThreads); + UNSAFE.setMemory(unsafeResults, RECORDSIZE * MAXINDEX * nThreads, (byte) 0); + + // start and wait for threads to finish Thread[] threadList = new Thread[nThreads]; - final Record[][] results = new Record[nThreads][]; final long chunkSize = fileSize / nThreads; for (int i = 0; i < nThreads; i++) { final int threadIndex = i; final long startAddr = mmapAddr + i * chunkSize; final long endAddr = (i == nThreads - 1) ? mmapAddr + fileSize : mmapAddr + (i + 1) * chunkSize; - threadList[i] = new Thread(() -> results[threadIndex] = threadMain(threadIndex, startAddr, endAddr, nThreads)); + threadList[i] = new Thread(() -> threadMain(threadIndex, startAddr, endAddr, nThreads)); threadList[i].start(); } for (int i = 0; i < nThreads; i++) threadList[i].join(); - // aggregate results and sort - // TODO have to compare with concurrent-parallel stream structures: - // * concurrent hashtable that have to sort afterwards - // * concurrent skiplist that is sorted but has O(n) insert - // * ..other? - final TreeMap aggregateResults = new TreeMap<>(); - for (int thread = 0; thread < nThreads; thread++) { - for (int index = 0; index < MAXINDEX; index++) { - Record record = results[thread][index]; - if (record == null) - continue; - aggregateResults.compute(record.city, (k, v) -> (v == null) ? record : v.merge(record)); - } - } - // prepare string and print StringBuilder sb = new StringBuilder(); sb.append("{"); @@ -167,12 +156,13 @@ public static void main(String[] args) throws IOException, Throwable { float max = record.max; max /= 10.f; double avg = Math.round((record.sum * 1.0) / record.count) / 10.; - sb.append(record.city).append("=").append(min).append("/").append(avg).append("/").append(max).append(", "); + sb.append(entry.getKey()).append("=").append(min).append("/").append(avg).append("/").append(max).append(", "); } int stringLength = sb.length(); sb.setCharAt(stringLength - 2, '}'); sb.setCharAt(stringLength - 1, '\n'); System.out.print(sb.toString()); + System.out.close(); } private static final boolean citiesDiffer(final long a, final long b, final long len) { @@ -185,7 +175,7 @@ private static final boolean citiesDiffer(final long a, final long b, final long return false; } - private static Record[] threadMain(int id, long startAddr, long endAddr, long nThreads) { + private static void threadMain(int id, long startAddr, long endAddr, long nThreads) { // snap to newlines if (id != 0) while (UNSAFE.getByte(startAddr++) != '\n') @@ -194,23 +184,24 @@ private static Record[] threadMain(int id, long startAddr, long endAddr, long nT while (UNSAFE.getByte(endAddr++) != '\n') ; + final long threadResults = unsafeResults + id * MAXINDEX * RECORDSIZE; final Record[] results = new Record[MAXINDEX]; final long VECTORBYTESIZE = SPECIES.length(); final ByteOrder BYTEORDER = ByteOrder.nativeOrder(); final ByteVector delim = ByteVector.broadcast(SPECIES, ';'); - long nextCityAddr = startAddr; // XXX from these three variables, - long cityAddr = nextCityAddr; // only two are necessary, but if one - long ptr = 0; // is eliminated, on my pc the benchmark gets worse.. - while (nextCityAddr < endAddr) { + long cityAddr = startAddr; + long ptr = 0; + while (cityAddr < endAddr) { // parse city - long mask = ByteVector.fromMemorySegment(SPECIES, VAS, nextCityAddr + ptr, BYTEORDER) - .compare(VectorOperators.EQ, delim).toLong(); - if (mask == 0) { + ByteVector parsed = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr, BYTEORDER); + long mask = parsed.compare(VectorOperators.EQ, delim).toLong(); + while (mask == 0) { ptr += VECTORBYTESIZE; - continue; + mask = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr + ptr, BYTEORDER).compare(VectorOperators.EQ, delim).toLong(); } final long cityLength = ptr + Long.numberOfTrailingZeros(mask); final long tempAddr = cityAddr + cityLength + 1; + ptr = 0; // compute hash table index int index; @@ -222,67 +213,79 @@ private static Record[] threadMain(int id, long startAddr, long endAddr, long nT & 0xFFFF; else index = (UNSAFE.getByte(cityAddr) << 8) & 0xFF00; - // resolve collisions with linear probing // use vector api here also, but only if city name fits in one vector length, for faster default case - Record record = results[index]; + long record = threadResults + index * RECORDSIZE; + long recordCityLength = UNSAFE.getLong(record); if (cityLength <= VECTORBYTESIZE) { - ByteVector parsed = ByteVector.fromMemorySegment(SPECIES, VAS, cityAddr, BYTEORDER); - while (record != null) { - if (cityLength == record.cityLength) { - long sameMask = ByteVector.fromMemorySegment(SPECIES, VAS, record.cityAddr, BYTEORDER) + while (recordCityLength > 0) { + if (cityLength == recordCityLength) { + long sameMask = ByteVector.fromMemorySegment(SPECIES, VAS, UNSAFE.getLong(record + 8), BYTEORDER) .compare(VectorOperators.EQ, parsed).toLong(); if (Long.numberOfTrailingZeros(~sameMask) >= cityLength) break; } - record = results[++index]; + index++; + record = threadResults + index * RECORDSIZE; + recordCityLength = UNSAFE.getLong(record); } } else { // slower normal case for city names with length > VECTORBYTESIZE - while (record != null && (cityLength != record.cityLength || citiesDiffer(record.cityAddr, cityAddr, cityLength))) - record = results[++index]; + while (recordCityLength > 0 && (cityLength != recordCityLength || citiesDiffer(UNSAFE.getLong(record + 8), cityAddr, cityLength))) { + index++; + record = threadResults + index * RECORDSIZE; + recordCityLength = UNSAFE.getLong(record); + } } - // add record for new keys - // TODO have to avoid memory allocations on hot path - if (record == null) { - results[index] = new Record(cityAddr, cityLength); - record = results[index]; + // add record for new key + if (recordCityLength == 0) { + UNSAFE.putLong(record, cityLength); + UNSAFE.putLong(record + 8, cityAddr); + UNSAFE.putInt(record + 16, 1000); + UNSAFE.putInt(record + 20, -1000); } // parse temp with lookup tables int temp; if (UNSAFE.getByte(tempAddr) == '-') { temp = -lookupDecimal[UNSAFE.getShort(tempAddr + 1)] - lookupFraction[UNSAFE.getShort(tempAddr + 3)]; - nextCityAddr = tempAddr + lookupDotNegative[UNSAFE.getShort(tempAddr + 3)]; + cityAddr = tempAddr + lookupDotNegative[UNSAFE.getShort(tempAddr + 3)]; } else { temp = lookupDecimal[UNSAFE.getShort(tempAddr)] + lookupFraction[UNSAFE.getShort(tempAddr + 2)]; - nextCityAddr = tempAddr + lookupDotPositive[UNSAFE.getShort(tempAddr + 2)]; + cityAddr = tempAddr + lookupDotPositive[UNSAFE.getShort(tempAddr + 2)]; } - cityAddr = nextCityAddr; - ptr = 0; - // merge record - if (temp < record.min) - record.min = temp; - if (temp > record.max) - record.max = temp; - record.sum += temp; - record.count += 1; + // merge + if (temp < UNSAFE.getInt(record + 16)) + UNSAFE.putInt(record + 16, temp); + if (temp > UNSAFE.getInt(record + 20)) + UNSAFE.putInt(record + 20, temp); + UNSAFE.putLong(record + 24, UNSAFE.getLong(record + 24) + temp); + UNSAFE.putInt(record + 32, UNSAFE.getInt(record + 32) + 1); } // create strings from raw data - // TODO should avoid this copy + // and aggregate results onto TreeMap + int idx = 0; byte b[] = new byte[100]; + _mutex.lock(); for (int i = 0; i < MAXINDEX; i++) { - Record r = results[i]; - if (r == null) + if (UNSAFE.getLong(threadResults + i * RECORDSIZE) == 0) continue; - UNSAFE.copyMemory(null, r.cityAddr, b, Unsafe.ARRAY_BYTE_BASE_OFFSET, r.cityLength); - r.city = new String(b, 0, (int) r.cityLength, StandardCharsets.UTF_8); + final long recordAddress = threadResults + i * RECORDSIZE; + + results[idx] = new Record(UNSAFE.getLong(recordAddress + 8), UNSAFE.getLong(recordAddress)); + results[idx].min = UNSAFE.getInt(recordAddress + 16); + results[idx].max = UNSAFE.getInt(recordAddress + 20); + results[idx].sum = UNSAFE.getLong(recordAddress + 24); + results[idx].count = UNSAFE.getInt(recordAddress + 32); + UNSAFE.copyMemory(null, UNSAFE.getLong(recordAddress + 8), b, Unsafe.ARRAY_BYTE_BASE_OFFSET, UNSAFE.getLong(recordAddress)); + final Record record = results[idx]; + aggregateResults.compute(new String(b, 0, (int) results[idx].cityLength, StandardCharsets.UTF_8), (k, v) -> (v == null) ? record : v.merge(record)); + idx++; } - return results; + _mutex.unlock(); } - }