1 /* 2 * Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 /* 25 * @test 26 * @modules jdk.incubator.foreign/jdk.incubator.foreign.unsafe 27 * jdk.incubator.foreign/jdk.internal.foreign 28 * jdk.incubator.foreign/jdk.internal.foreign.abi 29 * java.base/sun.security.action 30 * @build NativeTestHelper StdLibTest 31 * @run testng/othervm -Dforeign.restricted=permit StdLibTest 32 */ 33 34 import java.lang.invoke.MethodHandle; 35 import java.lang.invoke.MethodHandles; 36 import java.lang.invoke.MethodType; 37 import java.lang.invoke.VarHandle; 38 import java.nio.ByteOrder; 39 import java.time.Instant; 40 import java.time.LocalDateTime; 41 import java.time.ZoneOffset; 42 import java.time.ZonedDateTime; 43 import java.util.ArrayList; 44 import java.util.Arrays; 45 import java.util.Collections; 46 import java.util.LinkedHashSet; 47 import java.util.List; 48 import java.util.Set; 49 import java.util.function.Consumer; 50 import java.util.stream.Collectors; 51 import java.util.stream.IntStream; 52 import java.util.stream.LongStream; 53 import java.util.stream.Stream; 54 55 import jdk.incubator.foreign.CSupport; 56 import jdk.incubator.foreign.ForeignLinker; 57 import jdk.incubator.foreign.FunctionDescriptor; 58 import jdk.incubator.foreign.LibraryLookup; 59 import jdk.incubator.foreign.MemoryAddress; 60 import jdk.incubator.foreign.MemoryHandles; 61 import jdk.incubator.foreign.MemoryLayout; 62 import jdk.incubator.foreign.MemorySegment; 63 import jdk.incubator.foreign.SequenceLayout; 64 import org.testng.annotations.*; 65 66 import static jdk.incubator.foreign.CSupport.*; 67 import static org.testng.Assert.*; 68 69 @Test 70 public class StdLibTest extends NativeTestHelper { 71 72 final static ForeignLinker abi = CSupport.getSystemLinker(); 73 74 final static VarHandle byteHandle = MemoryHandles.varHandle(byte.class, ByteOrder.nativeOrder()); 75 final static VarHandle intHandle = MemoryHandles.varHandle(int.class, ByteOrder.nativeOrder()); 76 final static VarHandle longHandle = MemoryHandles.varHandle(long.class, ByteOrder.nativeOrder()); 77 final static VarHandle byteArrHandle = arrayHandle(C_CHAR, byte.class); 78 final static VarHandle intArrHandle = arrayHandle(C_INT, int.class); 79 80 static VarHandle arrayHandle(MemoryLayout elemLayout, Class<?> elemCarrier) { 81 return MemoryLayout.ofSequence(1, elemLayout) 82 .varHandle(elemCarrier, MemoryLayout.PathElement.sequenceElement()); 83 } 84 85 private StdLibHelper stdLibHelper = new StdLibHelper(); 86 87 @Test(dataProvider = "stringPairs") 88 void test_strcat(String s1, String s2) throws Throwable { 89 assertEquals(stdLibHelper.strcat(s1, s2), s1 + s2); 90 } 91 92 @Test(dataProvider = "stringPairs") 93 void test_strcmp(String s1, String s2) throws Throwable { 94 assertEquals(Math.signum(stdLibHelper.strcmp(s1, s2)), Math.signum(s1.compareTo(s2))); 95 } 96 97 @Test(dataProvider = "strings") 98 void test_puts(String s) throws Throwable { 99 assertTrue(stdLibHelper.puts(s) >= 0); 100 } 101 102 @Test(dataProvider = "strings") 103 void test_strlen(String s) throws Throwable { 104 assertEquals(stdLibHelper.strlen(s), s.length()); 105 } 106 107 @Test(dataProvider = "instants") 108 void test_time(Instant instant) throws Throwable { 109 StdLibHelper.Tm tm = stdLibHelper.gmtime(instant.getEpochSecond()); 110 LocalDateTime localTime = LocalDateTime.ofInstant(instant, ZoneOffset.UTC); 111 assertEquals(tm.sec(), localTime.getSecond()); 112 assertEquals(tm.min(), localTime.getMinute()); 113 assertEquals(tm.hour(), localTime.getHour()); 114 //day pf year in Java has 1-offset 115 assertEquals(tm.yday(), localTime.getDayOfYear() - 1); 116 assertEquals(tm.mday(), localTime.getDayOfMonth()); 117 //days of week starts from Sunday in C, but on Monday in Java, also account for 1-offset 118 assertEquals((tm.wday() + 6) % 7, localTime.getDayOfWeek().getValue() - 1); 119 //month in Java has 1-offset 120 assertEquals(tm.mon(), localTime.getMonth().getValue() - 1); 121 assertEquals(tm.isdst(), ZoneOffset.UTC.getRules() 122 .isDaylightSavings(Instant.ofEpochMilli(instant.getEpochSecond() * 1000))); 123 } 124 125 @Test(dataProvider = "ints") 126 void test_qsort(List<Integer> ints) throws Throwable { 127 if (ints.size() > 0) { 128 int[] input = ints.stream().mapToInt(i -> i).toArray(); 129 int[] sorted = stdLibHelper.qsort(input); 130 Arrays.sort(input); 131 assertEquals(sorted, input); 132 } 133 } 134 135 @Test 136 void test_rand() throws Throwable { 137 int val = stdLibHelper.rand(); 138 for (int i = 0 ; i < 100 ; i++) { 139 int newVal = stdLibHelper.rand(); 140 if (newVal != val) { 141 return; //ok 142 } 143 val = newVal; 144 } 145 fail("All values are the same! " + val); 146 } 147 148 @Test(dataProvider = "printfArgs") 149 void test_printf(List<PrintfArg> args) throws Throwable { 150 String formatArgs = args.stream() 151 .map(a -> a.format) 152 .collect(Collectors.joining(",")); 153 154 String formatString = "hello(" + formatArgs + ")\n"; 155 156 String expected = String.format(formatString, args.stream() 157 .map(a -> a.javaValue).toArray()); 158 159 int found = stdLibHelper.printf(formatString, args); 160 assertEquals(found, expected.length()); 161 } 162 163 @Test(dataProvider = "printfArgs") 164 void test_vprintf(List<PrintfArg> args) throws Throwable { 165 String formatArgs = args.stream() 166 .map(a -> a.format) 167 .collect(Collectors.joining(",")); 168 169 String formatString = "hello(" + formatArgs + ")\n"; 170 171 String expected = String.format(formatString, args.stream() 172 .map(a -> a.javaValue).toArray()); 173 174 int found = stdLibHelper.vprintf(formatString, args); 175 assertEquals(found, expected.length()); 176 } 177 178 static class StdLibHelper { 179 180 final static MethodHandle strcat; 181 final static MethodHandle strcmp; 182 final static MethodHandle puts; 183 final static MethodHandle strlen; 184 final static MethodHandle gmtime; 185 final static MethodHandle qsort; 186 final static MethodHandle qsortCompar; 187 final static FunctionDescriptor qsortComparFunction; 188 final static MethodHandle rand; 189 final static MethodHandle vprintf; 190 final static MemoryAddress printfAddr; 191 final static FunctionDescriptor printfBase; 192 193 static { 194 try { 195 LibraryLookup lookup = LibraryLookup.ofDefault(); 196 197 strcat = abi.downcallHandle(lookup.lookup("strcat"), 198 MethodType.methodType(MemoryAddress.class, MemoryAddress.class, MemoryAddress.class), 199 FunctionDescriptor.of(C_POINTER, C_POINTER, C_POINTER)); 200 201 strcmp = abi.downcallHandle(lookup.lookup("strcmp"), 202 MethodType.methodType(int.class, MemoryAddress.class, MemoryAddress.class), 203 FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER)); 204 205 puts = abi.downcallHandle(lookup.lookup("puts"), 206 MethodType.methodType(int.class, MemoryAddress.class), 207 FunctionDescriptor.of(C_INT, C_POINTER)); 208 209 strlen = abi.downcallHandle(lookup.lookup("strlen"), 210 MethodType.methodType(int.class, MemoryAddress.class), 211 FunctionDescriptor.of(C_INT, C_POINTER)); 212 213 gmtime = abi.downcallHandle(lookup.lookup("gmtime"), 214 MethodType.methodType(MemoryAddress.class, MemoryAddress.class), 215 FunctionDescriptor.of(C_POINTER, C_POINTER)); 216 217 qsortComparFunction = FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER); 218 219 qsort = abi.downcallHandle(lookup.lookup("qsort"), 220 MethodType.methodType(void.class, MemoryAddress.class, long.class, long.class, MemoryAddress.class), 221 FunctionDescriptor.ofVoid(C_POINTER, C_LONGLONG, C_LONGLONG, C_POINTER)); 222 223 //qsort upcall handle 224 qsortCompar = MethodHandles.lookup().findStatic(StdLibTest.StdLibHelper.class, "qsortCompare", 225 MethodType.methodType(int.class, MemorySegment.class, MemoryAddress.class, MemoryAddress.class)); 226 227 rand = abi.downcallHandle(lookup.lookup("rand"), 228 MethodType.methodType(int.class), 229 FunctionDescriptor.of(C_INT)); 230 231 vprintf = abi.downcallHandle(lookup.lookup("vprintf"), 232 MethodType.methodType(int.class, MemoryAddress.class, VaList.class), 233 FunctionDescriptor.of(C_INT, C_POINTER, C_VA_LIST)); 234 235 printfAddr = lookup.lookup("printf"); 236 237 printfBase = FunctionDescriptor.of(C_INT, C_POINTER); 238 } catch (Throwable ex) { 239 throw new IllegalStateException(ex); 240 } 241 } 242 243 String strcat(String s1, String s2) throws Throwable { 244 try (MemorySegment buf = MemorySegment.allocateNative(s1.length() + s2.length() + 1) ; 245 MemorySegment other = toCString(s2)) { 246 char[] chars = s1.toCharArray(); 247 for (long i = 0 ; i < chars.length ; i++) { 248 byteArrHandle.set(buf.baseAddress(), i, (byte)chars[(int)i]); 249 } 250 byteArrHandle.set(buf.baseAddress(), (long)chars.length, (byte)'\0'); 251 return toJavaStringRestricted(((MemoryAddress)strcat.invokeExact(buf.baseAddress(), other.baseAddress()))); 252 } 253 } 254 255 int strcmp(String s1, String s2) throws Throwable { 256 try (MemorySegment ns1 = toCString(s1) ; 257 MemorySegment ns2 = toCString(s2)) { 258 return (int)strcmp.invokeExact(ns1.baseAddress(), ns2.baseAddress()); 259 } 260 } 261 262 int puts(String msg) throws Throwable { 263 try (MemorySegment s = toCString(msg)) { 264 return (int)puts.invokeExact(s.baseAddress()); 265 } 266 } 267 268 int strlen(String msg) throws Throwable { 269 try (MemorySegment s = toCString(msg)) { 270 return (int)strlen.invokeExact(s.baseAddress()); 271 } 272 } 273 274 Tm gmtime(long arg) throws Throwable { 275 try (MemorySegment time = MemorySegment.allocateNative(8)) { 276 longHandle.set(time.baseAddress(), arg); 277 return new Tm((MemoryAddress)gmtime.invokeExact(time.baseAddress())); 278 } 279 } 280 281 static class Tm { 282 283 //Tm pointer should never be freed directly, as it points to shared memory 284 private final MemoryAddress base; 285 286 static final long SIZE = 56; 287 288 Tm(MemoryAddress base) { 289 this.base = MemorySegment.ofNativeRestricted(base, SIZE, Thread.currentThread(), 290 null, null).baseAddress(); 291 } 292 293 int sec() { 294 return (int)intHandle.get(base); 295 } 296 int min() { 297 return (int)intHandle.get(base.addOffset(4)); 298 } 299 int hour() { 300 return (int)intHandle.get(base.addOffset(8)); 301 } 302 int mday() { 303 return (int)intHandle.get(base.addOffset(12)); 304 } 305 int mon() { 306 return (int)intHandle.get(base.addOffset(16)); 307 } 308 int year() { 309 return (int)intHandle.get(base.addOffset(20)); 310 } 311 int wday() { 312 return (int)intHandle.get(base.addOffset(24)); 313 } 314 int yday() { 315 return (int)intHandle.get(base.addOffset(28)); 316 } 317 boolean isdst() { 318 byte b = (byte)byteHandle.get(base.addOffset(32)); 319 return b == 0 ? false : true; 320 } 321 } 322 323 int[] qsort(int[] arr) throws Throwable { 324 //init native array 325 SequenceLayout seq = MemoryLayout.ofSequence(arr.length, C_INT); 326 327 try (MemorySegment nativeArr = MemorySegment.allocateNative(seq)) { 328 329 IntStream.range(0, arr.length) 330 .forEach(i -> intArrHandle.set(nativeArr.baseAddress(), i, arr[i])); 331 332 //call qsort 333 try (MemorySegment qsortUpcallStub = abi.upcallStub(qsortCompar.bindTo(nativeArr), qsortComparFunction)) { 334 qsort.invokeExact(nativeArr.baseAddress(), seq.elementCount().getAsLong(), C_INT.byteSize(), qsortUpcallStub.baseAddress()); 335 } 336 337 //convert back to Java array 338 return LongStream.range(0, arr.length) 339 .mapToInt(i -> (int)intArrHandle.get(nativeArr.baseAddress(), i)) 340 .toArray(); 341 } 342 } 343 344 static int qsortCompare(MemorySegment base, MemoryAddress addr1, MemoryAddress addr2) { 345 return (int)intHandle.get(addr1.rebase(base)) - (int)intHandle.get(addr2.rebase(base)); 346 } 347 348 int rand() throws Throwable { 349 return (int)rand.invokeExact(); 350 } 351 352 int printf(String format, List<PrintfArg> args) throws Throwable { 353 try (MemorySegment formatStr = toCString(format)) { 354 return (int)specializedPrintf(args).invokeExact(formatStr.baseAddress(), 355 args.stream().map(a -> a.nativeValue).toArray()); 356 } 357 } 358 359 int vprintf(String format, List<PrintfArg> args) throws Throwable { 360 try (MemorySegment formatStr = toCString(format)) { 361 return (int)vprintf.invokeExact(formatStr.baseAddress(), 362 VaList.make(b -> args.forEach(a -> a.accept(b)))); 363 } 364 } 365 366 private MethodHandle specializedPrintf(List<PrintfArg> args) { 367 //method type 368 MethodType mt = MethodType.methodType(int.class, MemoryAddress.class); 369 FunctionDescriptor fd = printfBase; 370 for (PrintfArg arg : args) { 371 mt = mt.appendParameterTypes(arg.carrier); 372 fd = fd.appendArgumentLayouts(arg.layout); 373 } 374 MethodHandle mh = abi.downcallHandle(printfAddr, mt, fd); 375 return mh.asSpreader(1, Object[].class, args.size()); 376 } 377 } 378 379 /*** data providers ***/ 380 381 @DataProvider 382 public static Object[][] ints() { 383 return perms(0, new Integer[] { 0, 1, 2, 3, 4 }).stream() 384 .map(l -> new Object[] { l }) 385 .toArray(Object[][]::new); 386 } 387 388 @DataProvider 389 public static Object[][] strings() { 390 return perms(0, new String[] { "a", "b", "c" }).stream() 391 .map(l -> new Object[] { String.join("", l) }) 392 .toArray(Object[][]::new); 393 } 394 395 @DataProvider 396 public static Object[][] stringPairs() { 397 Object[][] strings = strings(); 398 Object[][] stringPairs = new Object[strings.length * strings.length][]; 399 int pos = 0; 400 for (Object[] s1 : strings) { 401 for (Object[] s2 : strings) { 402 stringPairs[pos++] = new Object[] { s1[0], s2[0] }; 403 } 404 } 405 return stringPairs; 406 } 407 408 @DataProvider 409 public static Object[][] instants() { 410 Instant start = ZonedDateTime.of(LocalDateTime.parse("2017-01-01T00:00:00"), ZoneOffset.UTC).toInstant(); 411 Instant end = ZonedDateTime.of(LocalDateTime.parse("2017-12-31T00:00:00"), ZoneOffset.UTC).toInstant(); 412 Object[][] instants = new Object[100][]; 413 for (int i = 0 ; i < instants.length ; i++) { 414 Instant instant = start.plusSeconds((long)(Math.random() * (end.getEpochSecond() - start.getEpochSecond()))); 415 instants[i] = new Object[] { instant }; 416 } 417 return instants; 418 } 419 420 @DataProvider 421 public static Object[][] printfArgs() { 422 ArrayList<List<PrintfArg>> res = new ArrayList<>(); 423 List<List<PrintfArg>> perms = new ArrayList<>(perms(0, PrintfArg.values())); 424 for (int i = 0 ; i < 100 ; i++) { 425 Collections.shuffle(perms); 426 res.addAll(perms); 427 } 428 return res.stream() 429 .map(l -> new Object[] { l }) 430 .toArray(Object[][]::new); 431 } 432 433 enum PrintfArg implements Consumer<VaList.Builder> { 434 435 INTEGRAL(int.class, asVarArg(C_INT), "%d", 42, 42, VaList.Builder::vargFromInt), 436 STRING(MemoryAddress.class, asVarArg(C_POINTER), "%s", toCString("str").baseAddress(), "str", VaList.Builder::vargFromAddress), 437 CHAR(byte.class, asVarArg(C_CHAR), "%c", (byte) 'h', 'h', (builder, layout, value) -> builder.vargFromInt(C_INT, (int)value)), 438 DOUBLE(double.class, asVarArg(C_DOUBLE), "%.4f", 1.2345d, 1.2345d, VaList.Builder::vargFromDouble); 439 440 final Class<?> carrier; 441 final MemoryLayout layout; 442 final String format; 443 final Object nativeValue; 444 final Object javaValue; 445 @SuppressWarnings("rawtypes") 446 final VaListBuilderCall builderCall; 447 448 <Z> PrintfArg(Class<?> carrier, MemoryLayout layout, String format, Z nativeValue, Object javaValue, VaListBuilderCall<Z> builderCall) { 449 this.carrier = carrier; 450 this.layout = layout; 451 this.format = format; 452 this.nativeValue = nativeValue; 453 this.javaValue = javaValue; 454 this.builderCall = builderCall; 455 } 456 457 @Override 458 @SuppressWarnings("unchecked") 459 public void accept(VaList.Builder builder) { 460 builderCall.build(builder, layout, nativeValue); 461 } 462 463 interface VaListBuilderCall<V> { 464 void build(VaList.Builder builder, MemoryLayout layout, V value); 465 } 466 } 467 468 static <Z> Set<List<Z>> perms(int count, Z[] arr) { 469 if (count == arr.length) { 470 return Set.of(List.of()); 471 } else { 472 return Arrays.stream(arr) 473 .flatMap(num -> { 474 Set<List<Z>> perms = perms(count + 1, arr); 475 return Stream.concat( 476 //take n 477 perms.stream().map(l -> { 478 List<Z> li = new ArrayList<>(l); 479 li.add(num); 480 return li; 481 }), 482 //drop n 483 perms.stream()); 484 }).collect(Collectors.toCollection(LinkedHashSet::new)); 485 } 486 } 487 }