1 /*
2 * Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 /*
25 * @test
26 * @modules jdk.incubator.foreign/jdk.incubator.foreign.unsafe
27 * jdk.incubator.foreign/jdk.internal.foreign
28 * jdk.incubator.foreign/jdk.internal.foreign.abi
29 * java.base/sun.security.action
30 * @build NativeTestHelper StdLibTest
31 * @run testng/othervm -Dforeign.restricted=permit StdLibTest
32 */
33
34 import java.lang.invoke.MethodHandle;
35 import java.lang.invoke.MethodHandles;
36 import java.lang.invoke.MethodType;
37 import java.lang.invoke.VarHandle;
38 import java.nio.ByteOrder;
39 import java.time.Instant;
40 import java.time.LocalDateTime;
41 import java.time.ZoneOffset;
42 import java.time.ZonedDateTime;
43 import java.util.ArrayList;
44 import java.util.Arrays;
45 import java.util.Collections;
46 import java.util.LinkedHashSet;
47 import java.util.List;
48 import java.util.Set;
49 import java.util.function.Consumer;
50 import java.util.stream.Collectors;
51 import java.util.stream.IntStream;
52 import java.util.stream.LongStream;
53 import java.util.stream.Stream;
54
55 import jdk.incubator.foreign.CSupport;
56 import jdk.incubator.foreign.ForeignLinker;
57 import jdk.incubator.foreign.FunctionDescriptor;
58 import jdk.incubator.foreign.LibraryLookup;
59 import jdk.incubator.foreign.MemoryAddress;
60 import jdk.incubator.foreign.MemoryHandles;
61 import jdk.incubator.foreign.MemoryLayout;
62 import jdk.incubator.foreign.MemorySegment;
63 import jdk.incubator.foreign.SequenceLayout;
64 import org.testng.annotations.*;
65
66 import static jdk.incubator.foreign.CSupport.*;
67 import static org.testng.Assert.*;
68
69 @Test
70 public class StdLibTest extends NativeTestHelper {
71
72 final static ForeignLinker abi = CSupport.getSystemLinker();
73
74 final static VarHandle byteHandle = MemoryHandles.varHandle(byte.class, ByteOrder.nativeOrder());
75 final static VarHandle intHandle = MemoryHandles.varHandle(int.class, ByteOrder.nativeOrder());
76 final static VarHandle longHandle = MemoryHandles.varHandle(long.class, ByteOrder.nativeOrder());
77 final static VarHandle byteArrHandle = arrayHandle(C_CHAR, byte.class);
78 final static VarHandle intArrHandle = arrayHandle(C_INT, int.class);
79
80 static VarHandle arrayHandle(MemoryLayout elemLayout, Class<?> elemCarrier) {
81 return MemoryLayout.ofSequence(1, elemLayout)
82 .varHandle(elemCarrier, MemoryLayout.PathElement.sequenceElement());
83 }
84
85 private StdLibHelper stdLibHelper = new StdLibHelper();
86
87 @Test(dataProvider = "stringPairs")
88 void test_strcat(String s1, String s2) throws Throwable {
89 assertEquals(stdLibHelper.strcat(s1, s2), s1 + s2);
90 }
91
92 @Test(dataProvider = "stringPairs")
93 void test_strcmp(String s1, String s2) throws Throwable {
94 assertEquals(Math.signum(stdLibHelper.strcmp(s1, s2)), Math.signum(s1.compareTo(s2)));
95 }
96
97 @Test(dataProvider = "strings")
98 void test_puts(String s) throws Throwable {
99 assertTrue(stdLibHelper.puts(s) >= 0);
100 }
101
102 @Test(dataProvider = "strings")
103 void test_strlen(String s) throws Throwable {
104 assertEquals(stdLibHelper.strlen(s), s.length());
105 }
106
107 @Test(dataProvider = "instants")
108 void test_time(Instant instant) throws Throwable {
109 StdLibHelper.Tm tm = stdLibHelper.gmtime(instant.getEpochSecond());
110 LocalDateTime localTime = LocalDateTime.ofInstant(instant, ZoneOffset.UTC);
111 assertEquals(tm.sec(), localTime.getSecond());
112 assertEquals(tm.min(), localTime.getMinute());
113 assertEquals(tm.hour(), localTime.getHour());
114 //day pf year in Java has 1-offset
115 assertEquals(tm.yday(), localTime.getDayOfYear() - 1);
116 assertEquals(tm.mday(), localTime.getDayOfMonth());
117 //days of week starts from Sunday in C, but on Monday in Java, also account for 1-offset
118 assertEquals((tm.wday() + 6) % 7, localTime.getDayOfWeek().getValue() - 1);
119 //month in Java has 1-offset
120 assertEquals(tm.mon(), localTime.getMonth().getValue() - 1);
121 assertEquals(tm.isdst(), ZoneOffset.UTC.getRules()
122 .isDaylightSavings(Instant.ofEpochMilli(instant.getEpochSecond() * 1000)));
123 }
124
125 @Test(dataProvider = "ints")
126 void test_qsort(List<Integer> ints) throws Throwable {
127 if (ints.size() > 0) {
128 int[] input = ints.stream().mapToInt(i -> i).toArray();
129 int[] sorted = stdLibHelper.qsort(input);
130 Arrays.sort(input);
131 assertEquals(sorted, input);
132 }
133 }
134
135 @Test
136 void test_rand() throws Throwable {
137 int val = stdLibHelper.rand();
138 for (int i = 0 ; i < 100 ; i++) {
139 int newVal = stdLibHelper.rand();
140 if (newVal != val) {
141 return; //ok
142 }
143 val = newVal;
144 }
145 fail("All values are the same! " + val);
146 }
147
148 @Test(dataProvider = "printfArgs")
149 void test_printf(List<PrintfArg> args) throws Throwable {
150 String formatArgs = args.stream()
151 .map(a -> a.format)
152 .collect(Collectors.joining(","));
153
154 String formatString = "hello(" + formatArgs + ")\n";
155
156 String expected = String.format(formatString, args.stream()
157 .map(a -> a.javaValue).toArray());
158
159 int found = stdLibHelper.printf(formatString, args);
160 assertEquals(found, expected.length());
161 }
162
163 @Test(dataProvider = "printfArgs")
164 void test_vprintf(List<PrintfArg> args) throws Throwable {
165 String formatArgs = args.stream()
166 .map(a -> a.format)
167 .collect(Collectors.joining(","));
168
169 String formatString = "hello(" + formatArgs + ")\n";
170
171 String expected = String.format(formatString, args.stream()
172 .map(a -> a.javaValue).toArray());
173
174 int found = stdLibHelper.vprintf(formatString, args);
175 assertEquals(found, expected.length());
176 }
177
178 static class StdLibHelper {
179
180 final static MethodHandle strcat;
181 final static MethodHandle strcmp;
182 final static MethodHandle puts;
183 final static MethodHandle strlen;
184 final static MethodHandle gmtime;
185 final static MethodHandle qsort;
186 final static MethodHandle qsortCompar;
187 final static FunctionDescriptor qsortComparFunction;
188 final static MethodHandle rand;
189 final static MethodHandle vprintf;
190 final static MemoryAddress printfAddr;
191 final static FunctionDescriptor printfBase;
192
193 static {
194 try {
195 LibraryLookup lookup = LibraryLookup.ofDefault();
196
197 strcat = abi.downcallHandle(lookup.lookup("strcat"),
198 MethodType.methodType(MemoryAddress.class, MemoryAddress.class, MemoryAddress.class),
199 FunctionDescriptor.of(C_POINTER, C_POINTER, C_POINTER));
200
201 strcmp = abi.downcallHandle(lookup.lookup("strcmp"),
202 MethodType.methodType(int.class, MemoryAddress.class, MemoryAddress.class),
203 FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER));
204
205 puts = abi.downcallHandle(lookup.lookup("puts"),
206 MethodType.methodType(int.class, MemoryAddress.class),
207 FunctionDescriptor.of(C_INT, C_POINTER));
208
209 strlen = abi.downcallHandle(lookup.lookup("strlen"),
210 MethodType.methodType(int.class, MemoryAddress.class),
211 FunctionDescriptor.of(C_INT, C_POINTER));
212
213 gmtime = abi.downcallHandle(lookup.lookup("gmtime"),
214 MethodType.methodType(MemoryAddress.class, MemoryAddress.class),
215 FunctionDescriptor.of(C_POINTER, C_POINTER));
216
217 qsortComparFunction = FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER);
218
219 qsort = abi.downcallHandle(lookup.lookup("qsort"),
220 MethodType.methodType(void.class, MemoryAddress.class, long.class, long.class, MemoryAddress.class),
221 FunctionDescriptor.ofVoid(C_POINTER, C_LONGLONG, C_LONGLONG, C_POINTER));
222
223 //qsort upcall handle
224 qsortCompar = MethodHandles.lookup().findStatic(StdLibTest.StdLibHelper.class, "qsortCompare",
225 MethodType.methodType(int.class, MemorySegment.class, MemoryAddress.class, MemoryAddress.class));
226
227 rand = abi.downcallHandle(lookup.lookup("rand"),
228 MethodType.methodType(int.class),
229 FunctionDescriptor.of(C_INT));
230
231 vprintf = abi.downcallHandle(lookup.lookup("vprintf"),
232 MethodType.methodType(int.class, MemoryAddress.class, VaList.class),
233 FunctionDescriptor.of(C_INT, C_POINTER, C_VA_LIST));
234
235 printfAddr = lookup.lookup("printf");
236
237 printfBase = FunctionDescriptor.of(C_INT, C_POINTER);
238 } catch (Throwable ex) {
239 throw new IllegalStateException(ex);
240 }
241 }
242
243 String strcat(String s1, String s2) throws Throwable {
244 try (MemorySegment buf = MemorySegment.allocateNative(s1.length() + s2.length() + 1) ;
245 MemorySegment other = toCString(s2)) {
246 char[] chars = s1.toCharArray();
247 for (long i = 0 ; i < chars.length ; i++) {
248 byteArrHandle.set(buf.baseAddress(), i, (byte)chars[(int)i]);
249 }
250 byteArrHandle.set(buf.baseAddress(), (long)chars.length, (byte)'\0');
251 return toJavaStringRestricted(((MemoryAddress)strcat.invokeExact(buf.baseAddress(), other.baseAddress())));
252 }
253 }
254
255 int strcmp(String s1, String s2) throws Throwable {
256 try (MemorySegment ns1 = toCString(s1) ;
257 MemorySegment ns2 = toCString(s2)) {
258 return (int)strcmp.invokeExact(ns1.baseAddress(), ns2.baseAddress());
259 }
260 }
261
262 int puts(String msg) throws Throwable {
263 try (MemorySegment s = toCString(msg)) {
264 return (int)puts.invokeExact(s.baseAddress());
265 }
266 }
267
268 int strlen(String msg) throws Throwable {
269 try (MemorySegment s = toCString(msg)) {
270 return (int)strlen.invokeExact(s.baseAddress());
271 }
272 }
273
274 Tm gmtime(long arg) throws Throwable {
275 try (MemorySegment time = MemorySegment.allocateNative(8)) {
276 longHandle.set(time.baseAddress(), arg);
277 return new Tm((MemoryAddress)gmtime.invokeExact(time.baseAddress()));
278 }
279 }
280
281 static class Tm {
282
283 //Tm pointer should never be freed directly, as it points to shared memory
284 private final MemoryAddress base;
285
286 static final long SIZE = 56;
287
288 Tm(MemoryAddress base) {
289 this.base = MemorySegment.ofNativeRestricted(base, SIZE, Thread.currentThread(),
290 null, null).baseAddress();
291 }
292
293 int sec() {
294 return (int)intHandle.get(base);
295 }
296 int min() {
297 return (int)intHandle.get(base.addOffset(4));
298 }
299 int hour() {
300 return (int)intHandle.get(base.addOffset(8));
301 }
302 int mday() {
303 return (int)intHandle.get(base.addOffset(12));
304 }
305 int mon() {
306 return (int)intHandle.get(base.addOffset(16));
307 }
308 int year() {
309 return (int)intHandle.get(base.addOffset(20));
310 }
311 int wday() {
312 return (int)intHandle.get(base.addOffset(24));
313 }
314 int yday() {
315 return (int)intHandle.get(base.addOffset(28));
316 }
317 boolean isdst() {
318 byte b = (byte)byteHandle.get(base.addOffset(32));
319 return b == 0 ? false : true;
320 }
321 }
322
323 int[] qsort(int[] arr) throws Throwable {
324 //init native array
325 SequenceLayout seq = MemoryLayout.ofSequence(arr.length, C_INT);
326
327 try (MemorySegment nativeArr = MemorySegment.allocateNative(seq)) {
328
329 IntStream.range(0, arr.length)
330 .forEach(i -> intArrHandle.set(nativeArr.baseAddress(), i, arr[i]));
331
332 //call qsort
333 try (MemorySegment qsortUpcallStub = abi.upcallStub(qsortCompar.bindTo(nativeArr), qsortComparFunction)) {
334 qsort.invokeExact(nativeArr.baseAddress(), seq.elementCount().getAsLong(), C_INT.byteSize(), qsortUpcallStub.baseAddress());
335 }
336
337 //convert back to Java array
338 return LongStream.range(0, arr.length)
339 .mapToInt(i -> (int)intArrHandle.get(nativeArr.baseAddress(), i))
340 .toArray();
341 }
342 }
343
344 static int qsortCompare(MemorySegment base, MemoryAddress addr1, MemoryAddress addr2) {
345 return (int)intHandle.get(addr1.rebase(base)) - (int)intHandle.get(addr2.rebase(base));
346 }
347
348 int rand() throws Throwable {
349 return (int)rand.invokeExact();
350 }
351
352 int printf(String format, List<PrintfArg> args) throws Throwable {
353 try (MemorySegment formatStr = toCString(format)) {
354 return (int)specializedPrintf(args).invokeExact(formatStr.baseAddress(),
355 args.stream().map(a -> a.nativeValue).toArray());
356 }
357 }
358
359 int vprintf(String format, List<PrintfArg> args) throws Throwable {
360 try (MemorySegment formatStr = toCString(format)) {
361 return (int)vprintf.invokeExact(formatStr.baseAddress(),
362 VaList.make(b -> args.forEach(a -> a.accept(b))));
363 }
364 }
365
366 private MethodHandle specializedPrintf(List<PrintfArg> args) {
367 //method type
368 MethodType mt = MethodType.methodType(int.class, MemoryAddress.class);
369 FunctionDescriptor fd = printfBase;
370 for (PrintfArg arg : args) {
371 mt = mt.appendParameterTypes(arg.carrier);
372 fd = fd.appendArgumentLayouts(arg.layout);
373 }
374 MethodHandle mh = abi.downcallHandle(printfAddr, mt, fd);
375 return mh.asSpreader(1, Object[].class, args.size());
376 }
377 }
378
379 /*** data providers ***/
380
381 @DataProvider
382 public static Object[][] ints() {
383 return perms(0, new Integer[] { 0, 1, 2, 3, 4 }).stream()
384 .map(l -> new Object[] { l })
385 .toArray(Object[][]::new);
386 }
387
388 @DataProvider
389 public static Object[][] strings() {
390 return perms(0, new String[] { "a", "b", "c" }).stream()
391 .map(l -> new Object[] { String.join("", l) })
392 .toArray(Object[][]::new);
393 }
394
395 @DataProvider
396 public static Object[][] stringPairs() {
397 Object[][] strings = strings();
398 Object[][] stringPairs = new Object[strings.length * strings.length][];
399 int pos = 0;
400 for (Object[] s1 : strings) {
401 for (Object[] s2 : strings) {
402 stringPairs[pos++] = new Object[] { s1[0], s2[0] };
403 }
404 }
405 return stringPairs;
406 }
407
408 @DataProvider
409 public static Object[][] instants() {
410 Instant start = ZonedDateTime.of(LocalDateTime.parse("2017-01-01T00:00:00"), ZoneOffset.UTC).toInstant();
411 Instant end = ZonedDateTime.of(LocalDateTime.parse("2017-12-31T00:00:00"), ZoneOffset.UTC).toInstant();
412 Object[][] instants = new Object[100][];
413 for (int i = 0 ; i < instants.length ; i++) {
414 Instant instant = start.plusSeconds((long)(Math.random() * (end.getEpochSecond() - start.getEpochSecond())));
415 instants[i] = new Object[] { instant };
416 }
417 return instants;
418 }
419
420 @DataProvider
421 public static Object[][] printfArgs() {
422 ArrayList<List<PrintfArg>> res = new ArrayList<>();
423 List<List<PrintfArg>> perms = new ArrayList<>(perms(0, PrintfArg.values()));
424 for (int i = 0 ; i < 100 ; i++) {
425 Collections.shuffle(perms);
426 res.addAll(perms);
427 }
428 return res.stream()
429 .map(l -> new Object[] { l })
430 .toArray(Object[][]::new);
431 }
432
433 enum PrintfArg implements Consumer<VaList.Builder> {
434
435 INTEGRAL(int.class, asVarArg(C_INT), "%d", 42, 42, VaList.Builder::vargFromInt),
436 STRING(MemoryAddress.class, asVarArg(C_POINTER), "%s", toCString("str").baseAddress(), "str", VaList.Builder::vargFromAddress),
437 CHAR(byte.class, asVarArg(C_CHAR), "%c", (byte) 'h', 'h', (builder, layout, value) -> builder.vargFromInt(C_INT, (int)value)),
438 DOUBLE(double.class, asVarArg(C_DOUBLE), "%.4f", 1.2345d, 1.2345d, VaList.Builder::vargFromDouble);
439
440 final Class<?> carrier;
441 final MemoryLayout layout;
442 final String format;
443 final Object nativeValue;
444 final Object javaValue;
445 @SuppressWarnings("rawtypes")
446 final VaListBuilderCall builderCall;
447
448 <Z> PrintfArg(Class<?> carrier, MemoryLayout layout, String format, Z nativeValue, Object javaValue, VaListBuilderCall<Z> builderCall) {
449 this.carrier = carrier;
450 this.layout = layout;
451 this.format = format;
452 this.nativeValue = nativeValue;
453 this.javaValue = javaValue;
454 this.builderCall = builderCall;
455 }
456
457 @Override
458 @SuppressWarnings("unchecked")
459 public void accept(VaList.Builder builder) {
460 builderCall.build(builder, layout, nativeValue);
461 }
462
463 interface VaListBuilderCall<V> {
464 void build(VaList.Builder builder, MemoryLayout layout, V value);
465 }
466 }
467
468 static <Z> Set<List<Z>> perms(int count, Z[] arr) {
469 if (count == arr.length) {
470 return Set.of(List.of());
471 } else {
472 return Arrays.stream(arr)
473 .flatMap(num -> {
474 Set<List<Z>> perms = perms(count + 1, arr);
475 return Stream.concat(
476 //take n
477 perms.stream().map(l -> {
478 List<Z> li = new ArrayList<>(l);
479 li.add(num);
480 return li;
481 }),
482 //drop n
483 perms.stream());
484 }).collect(Collectors.toCollection(LinkedHashSet::new));
485 }
486 }
487 }