1 /*
  2  * Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.
  8  *
  9  * This code is distributed in the hope that it will be useful, but WITHOUT
 10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 12  * version 2 for more details (a copy is included in the LICENSE file that
 13  * accompanied this code).
 14  *
 15  * You should have received a copy of the GNU General Public License version
 16  * 2 along with this work; if not, write to the Free Software Foundation,
 17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 18  *
 19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 20  * or visit www.oracle.com if you need additional information or have any
 21  * questions.
 22  */
 23 
 24 /*
 25  * @test
 26   * @modules jdk.incubator.foreign/jdk.incubator.foreign.unsafe
 27  *          jdk.incubator.foreign/jdk.internal.foreign
 28  *          jdk.incubator.foreign/jdk.internal.foreign.abi
 29  *          java.base/sun.security.action
 30  * @build NativeTestHelper StdLibTest
 31  * @run testng/othervm -Dforeign.restricted=permit StdLibTest
 32  */
 33 
 34 import java.lang.invoke.MethodHandle;
 35 import java.lang.invoke.MethodHandles;
 36 import java.lang.invoke.MethodType;
 37 import java.lang.invoke.VarHandle;
 38 import java.nio.ByteOrder;
 39 import java.time.Instant;
 40 import java.time.LocalDateTime;
 41 import java.time.ZoneOffset;
 42 import java.time.ZonedDateTime;
 43 import java.util.ArrayList;
 44 import java.util.Arrays;
 45 import java.util.Collections;
 46 import java.util.LinkedHashSet;
 47 import java.util.List;
 48 import java.util.Set;
 49 import java.util.function.Consumer;
 50 import java.util.stream.Collectors;
 51 import java.util.stream.IntStream;
 52 import java.util.stream.LongStream;
 53 import java.util.stream.Stream;
 54 
 55 import jdk.incubator.foreign.CSupport;
 56 import jdk.incubator.foreign.ForeignLinker;
 57 import jdk.incubator.foreign.FunctionDescriptor;
 58 import jdk.incubator.foreign.LibraryLookup;
 59 import jdk.incubator.foreign.MemoryAddress;
 60 import jdk.incubator.foreign.MemoryHandles;
 61 import jdk.incubator.foreign.MemoryLayout;
 62 import jdk.incubator.foreign.MemorySegment;
 63 import jdk.incubator.foreign.SequenceLayout;
 64 import org.testng.annotations.*;
 65 
 66 import static jdk.incubator.foreign.CSupport.*;
 67 import static org.testng.Assert.*;
 68 
 69 @Test
 70 public class StdLibTest extends NativeTestHelper {
 71 
 72     final static ForeignLinker abi = CSupport.getSystemLinker();
 73 
 74     final static VarHandle byteHandle = MemoryHandles.varHandle(byte.class, ByteOrder.nativeOrder());
 75     final static VarHandle intHandle = MemoryHandles.varHandle(int.class, ByteOrder.nativeOrder());
 76     final static VarHandle longHandle = MemoryHandles.varHandle(long.class, ByteOrder.nativeOrder());
 77     final static VarHandle byteArrHandle = arrayHandle(C_CHAR, byte.class);
 78     final static VarHandle intArrHandle = arrayHandle(C_INT, int.class);
 79 
 80     static VarHandle arrayHandle(MemoryLayout elemLayout, Class<?> elemCarrier) {
 81         return MemoryLayout.ofSequence(1, elemLayout)
 82                 .varHandle(elemCarrier, MemoryLayout.PathElement.sequenceElement());
 83     }
 84 
 85     private StdLibHelper stdLibHelper = new StdLibHelper();
 86 
 87     @Test(dataProvider = "stringPairs")
 88     void test_strcat(String s1, String s2) throws Throwable {
 89         assertEquals(stdLibHelper.strcat(s1, s2), s1 + s2);
 90     }
 91 
 92     @Test(dataProvider = "stringPairs")
 93     void test_strcmp(String s1, String s2) throws Throwable {
 94         assertEquals(Math.signum(stdLibHelper.strcmp(s1, s2)), Math.signum(s1.compareTo(s2)));
 95     }
 96 
 97     @Test(dataProvider = "strings")
 98     void test_puts(String s) throws Throwable {
 99         assertTrue(stdLibHelper.puts(s) >= 0);
100     }
101 
102     @Test(dataProvider = "strings")
103     void test_strlen(String s) throws Throwable {
104         assertEquals(stdLibHelper.strlen(s), s.length());
105     }
106 
107     @Test(dataProvider = "instants")
108     void test_time(Instant instant) throws Throwable {
109         StdLibHelper.Tm tm = stdLibHelper.gmtime(instant.getEpochSecond());
110         LocalDateTime localTime = LocalDateTime.ofInstant(instant, ZoneOffset.UTC);
111         assertEquals(tm.sec(), localTime.getSecond());
112         assertEquals(tm.min(), localTime.getMinute());
113         assertEquals(tm.hour(), localTime.getHour());
114         //day pf year in Java has 1-offset
115         assertEquals(tm.yday(), localTime.getDayOfYear() - 1);
116         assertEquals(tm.mday(), localTime.getDayOfMonth());
117         //days of week starts from Sunday in C, but on Monday in Java, also account for 1-offset
118         assertEquals((tm.wday() + 6) % 7, localTime.getDayOfWeek().getValue() - 1);
119         //month in Java has 1-offset
120         assertEquals(tm.mon(), localTime.getMonth().getValue() - 1);
121         assertEquals(tm.isdst(), ZoneOffset.UTC.getRules()
122                 .isDaylightSavings(Instant.ofEpochMilli(instant.getEpochSecond() * 1000)));
123     }
124 
125     @Test(dataProvider = "ints")
126     void test_qsort(List<Integer> ints) throws Throwable {
127         if (ints.size() > 0) {
128             int[] input = ints.stream().mapToInt(i -> i).toArray();
129             int[] sorted = stdLibHelper.qsort(input);
130             Arrays.sort(input);
131             assertEquals(sorted, input);
132         }
133     }
134 
135     @Test
136     void test_rand() throws Throwable {
137         int val = stdLibHelper.rand();
138         for (int i = 0 ; i < 100 ; i++) {
139             int newVal = stdLibHelper.rand();
140             if (newVal != val) {
141                 return; //ok
142             }
143             val = newVal;
144         }
145         fail("All values are the same! " + val);
146     }
147 
148     @Test(dataProvider = "printfArgs")
149     void test_printf(List<PrintfArg> args) throws Throwable {
150         String formatArgs = args.stream()
151                 .map(a -> a.format)
152                 .collect(Collectors.joining(","));
153 
154         String formatString = "hello(" + formatArgs + ")\n";
155 
156         String expected = String.format(formatString, args.stream()
157                 .map(a -> a.javaValue).toArray());
158 
159         int found = stdLibHelper.printf(formatString, args);
160         assertEquals(found, expected.length());
161     }
162 
163     @Test(dataProvider = "printfArgs")
164     void test_vprintf(List<PrintfArg> args) throws Throwable {
165         String formatArgs = args.stream()
166                 .map(a -> a.format)
167                 .collect(Collectors.joining(","));
168 
169         String formatString = "hello(" + formatArgs + ")\n";
170 
171         String expected = String.format(formatString, args.stream()
172                 .map(a -> a.javaValue).toArray());
173 
174         int found = stdLibHelper.vprintf(formatString, args);
175         assertEquals(found, expected.length());
176     }
177 
178     static class StdLibHelper {
179 
180         final static MethodHandle strcat;
181         final static MethodHandle strcmp;
182         final static MethodHandle puts;
183         final static MethodHandle strlen;
184         final static MethodHandle gmtime;
185         final static MethodHandle qsort;
186         final static MethodHandle qsortCompar;
187         final static FunctionDescriptor qsortComparFunction;
188         final static MethodHandle rand;
189         final static MethodHandle vprintf;
190         final static MemoryAddress printfAddr;
191         final static FunctionDescriptor printfBase;
192 
193         static {
194             try {
195                 LibraryLookup lookup = LibraryLookup.ofDefault();
196 
197                 strcat = abi.downcallHandle(lookup.lookup("strcat"),
198                         MethodType.methodType(MemoryAddress.class, MemoryAddress.class, MemoryAddress.class),
199                         FunctionDescriptor.of(C_POINTER, C_POINTER, C_POINTER));
200 
201                 strcmp = abi.downcallHandle(lookup.lookup("strcmp"),
202                         MethodType.methodType(int.class, MemoryAddress.class, MemoryAddress.class),
203                         FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER));
204 
205                 puts = abi.downcallHandle(lookup.lookup("puts"),
206                         MethodType.methodType(int.class, MemoryAddress.class),
207                         FunctionDescriptor.of(C_INT, C_POINTER));
208 
209                 strlen = abi.downcallHandle(lookup.lookup("strlen"),
210                         MethodType.methodType(int.class, MemoryAddress.class),
211                         FunctionDescriptor.of(C_INT, C_POINTER));
212 
213                 gmtime = abi.downcallHandle(lookup.lookup("gmtime"),
214                         MethodType.methodType(MemoryAddress.class, MemoryAddress.class),
215                         FunctionDescriptor.of(C_POINTER, C_POINTER));
216 
217                 qsortComparFunction = FunctionDescriptor.of(C_INT, C_POINTER, C_POINTER);
218 
219                 qsort = abi.downcallHandle(lookup.lookup("qsort"),
220                         MethodType.methodType(void.class, MemoryAddress.class, long.class, long.class, MemoryAddress.class),
221                         FunctionDescriptor.ofVoid(C_POINTER, C_LONGLONG, C_LONGLONG, C_POINTER));
222 
223                 //qsort upcall handle
224                 qsortCompar = MethodHandles.lookup().findStatic(StdLibTest.StdLibHelper.class, "qsortCompare",
225                         MethodType.methodType(int.class, MemorySegment.class, MemoryAddress.class, MemoryAddress.class));
226 
227                 rand = abi.downcallHandle(lookup.lookup("rand"),
228                         MethodType.methodType(int.class),
229                         FunctionDescriptor.of(C_INT));
230 
231                 vprintf = abi.downcallHandle(lookup.lookup("vprintf"),
232                         MethodType.methodType(int.class, MemoryAddress.class, VaList.class),
233                         FunctionDescriptor.of(C_INT, C_POINTER, C_VA_LIST));
234 
235                 printfAddr = lookup.lookup("printf");
236 
237                 printfBase = FunctionDescriptor.of(C_INT, C_POINTER);
238             } catch (Throwable ex) {
239                 throw new IllegalStateException(ex);
240             }
241         }
242 
243         String strcat(String s1, String s2) throws Throwable {
244             try (MemorySegment buf = MemorySegment.allocateNative(s1.length() + s2.length() + 1) ;
245                  MemorySegment other = toCString(s2)) {
246                 char[] chars = s1.toCharArray();
247                 for (long i = 0 ; i < chars.length ; i++) {
248                     byteArrHandle.set(buf.baseAddress(), i, (byte)chars[(int)i]);
249                 }
250                 byteArrHandle.set(buf.baseAddress(), (long)chars.length, (byte)'\0');
251                 return toJavaStringRestricted(((MemoryAddress)strcat.invokeExact(buf.baseAddress(), other.baseAddress())));
252             }
253         }
254 
255         int strcmp(String s1, String s2) throws Throwable {
256             try (MemorySegment ns1 = toCString(s1) ;
257                  MemorySegment ns2 = toCString(s2)) {
258                 return (int)strcmp.invokeExact(ns1.baseAddress(), ns2.baseAddress());
259             }
260         }
261 
262         int puts(String msg) throws Throwable {
263             try (MemorySegment s = toCString(msg)) {
264                 return (int)puts.invokeExact(s.baseAddress());
265             }
266         }
267 
268         int strlen(String msg) throws Throwable {
269             try (MemorySegment s = toCString(msg)) {
270                 return (int)strlen.invokeExact(s.baseAddress());
271             }
272         }
273 
274         Tm gmtime(long arg) throws Throwable {
275             try (MemorySegment time = MemorySegment.allocateNative(8)) {
276                 longHandle.set(time.baseAddress(), arg);
277                 return new Tm((MemoryAddress)gmtime.invokeExact(time.baseAddress()));
278             }
279         }
280 
281         static class Tm {
282 
283             //Tm pointer should never be freed directly, as it points to shared memory
284             private final MemoryAddress base;
285 
286             static final long SIZE = 56;
287 
288             Tm(MemoryAddress base) {
289                 this.base = MemorySegment.ofNativeRestricted(base, SIZE, Thread.currentThread(),
290                         null, null).baseAddress();
291             }
292 
293             int sec() {
294                 return (int)intHandle.get(base);
295             }
296             int min() {
297                 return (int)intHandle.get(base.addOffset(4));
298             }
299             int hour() {
300                 return (int)intHandle.get(base.addOffset(8));
301             }
302             int mday() {
303                 return (int)intHandle.get(base.addOffset(12));
304             }
305             int mon() {
306                 return (int)intHandle.get(base.addOffset(16));
307             }
308             int year() {
309                 return (int)intHandle.get(base.addOffset(20));
310             }
311             int wday() {
312                 return (int)intHandle.get(base.addOffset(24));
313             }
314             int yday() {
315                 return (int)intHandle.get(base.addOffset(28));
316             }
317             boolean isdst() {
318                 byte b = (byte)byteHandle.get(base.addOffset(32));
319                 return b == 0 ? false : true;
320             }
321         }
322 
323         int[] qsort(int[] arr) throws Throwable {
324             //init native array
325             SequenceLayout seq = MemoryLayout.ofSequence(arr.length, C_INT);
326 
327             try (MemorySegment nativeArr = MemorySegment.allocateNative(seq)) {
328 
329                 IntStream.range(0, arr.length)
330                         .forEach(i -> intArrHandle.set(nativeArr.baseAddress(), i, arr[i]));
331 
332                 //call qsort
333                 try (MemorySegment qsortUpcallStub = abi.upcallStub(qsortCompar.bindTo(nativeArr), qsortComparFunction)) {
334                     qsort.invokeExact(nativeArr.baseAddress(), seq.elementCount().getAsLong(), C_INT.byteSize(), qsortUpcallStub.baseAddress());
335                 }
336 
337                 //convert back to Java array
338                 return LongStream.range(0, arr.length)
339                         .mapToInt(i -> (int)intArrHandle.get(nativeArr.baseAddress(), i))
340                         .toArray();
341             }
342         }
343 
344         static int qsortCompare(MemorySegment base, MemoryAddress addr1, MemoryAddress addr2) {
345             return (int)intHandle.get(addr1.rebase(base)) - (int)intHandle.get(addr2.rebase(base));
346         }
347 
348         int rand() throws Throwable {
349             return (int)rand.invokeExact();
350         }
351 
352         int printf(String format, List<PrintfArg> args) throws Throwable {
353             try (MemorySegment formatStr = toCString(format)) {
354                 return (int)specializedPrintf(args).invokeExact(formatStr.baseAddress(),
355                         args.stream().map(a -> a.nativeValue).toArray());
356             }
357         }
358 
359         int vprintf(String format, List<PrintfArg> args) throws Throwable {
360             try (MemorySegment formatStr = toCString(format)) {
361                 return (int)vprintf.invokeExact(formatStr.baseAddress(),
362                         VaList.make(b -> args.forEach(a -> a.accept(b))));
363             }
364         }
365 
366         private MethodHandle specializedPrintf(List<PrintfArg> args) {
367             //method type
368             MethodType mt = MethodType.methodType(int.class, MemoryAddress.class);
369             FunctionDescriptor fd = printfBase;
370             for (PrintfArg arg : args) {
371                 mt = mt.appendParameterTypes(arg.carrier);
372                 fd = fd.appendArgumentLayouts(arg.layout);
373             }
374             MethodHandle mh = abi.downcallHandle(printfAddr, mt, fd);
375             return mh.asSpreader(1, Object[].class, args.size());
376         }
377     }
378 
379     /*** data providers ***/
380 
381     @DataProvider
382     public static Object[][] ints() {
383         return perms(0, new Integer[] { 0, 1, 2, 3, 4 }).stream()
384                 .map(l -> new Object[] { l })
385                 .toArray(Object[][]::new);
386     }
387 
388     @DataProvider
389     public static Object[][] strings() {
390         return perms(0, new String[] { "a", "b", "c" }).stream()
391                 .map(l -> new Object[] { String.join("", l) })
392                 .toArray(Object[][]::new);
393     }
394 
395     @DataProvider
396     public static Object[][] stringPairs() {
397         Object[][] strings = strings();
398         Object[][] stringPairs = new Object[strings.length * strings.length][];
399         int pos = 0;
400         for (Object[] s1 : strings) {
401             for (Object[] s2 : strings) {
402                 stringPairs[pos++] = new Object[] { s1[0], s2[0] };
403             }
404         }
405         return stringPairs;
406     }
407 
408     @DataProvider
409     public static Object[][] instants() {
410         Instant start = ZonedDateTime.of(LocalDateTime.parse("2017-01-01T00:00:00"), ZoneOffset.UTC).toInstant();
411         Instant end = ZonedDateTime.of(LocalDateTime.parse("2017-12-31T00:00:00"), ZoneOffset.UTC).toInstant();
412         Object[][] instants = new Object[100][];
413         for (int i = 0 ; i < instants.length ; i++) {
414             Instant instant = start.plusSeconds((long)(Math.random() * (end.getEpochSecond() - start.getEpochSecond())));
415             instants[i] = new Object[] { instant };
416         }
417         return instants;
418     }
419 
420     @DataProvider
421     public static Object[][] printfArgs() {
422         ArrayList<List<PrintfArg>> res = new ArrayList<>();
423         List<List<PrintfArg>> perms = new ArrayList<>(perms(0, PrintfArg.values()));
424         for (int i = 0 ; i < 100 ; i++) {
425             Collections.shuffle(perms);
426             res.addAll(perms);
427         }
428         return res.stream()
429                 .map(l -> new Object[] { l })
430                 .toArray(Object[][]::new);
431     }
432 
433     enum PrintfArg implements Consumer<VaList.Builder> {
434 
435         INTEGRAL(int.class, asVarArg(C_INT), "%d", 42, 42, VaList.Builder::vargFromInt),
436         STRING(MemoryAddress.class, asVarArg(C_POINTER), "%s", toCString("str").baseAddress(), "str", VaList.Builder::vargFromAddress),
437         CHAR(byte.class, asVarArg(C_CHAR), "%c", (byte) 'h', 'h', (builder, layout, value) -> builder.vargFromInt(C_INT, (int)value)),
438         DOUBLE(double.class, asVarArg(C_DOUBLE), "%.4f", 1.2345d, 1.2345d, VaList.Builder::vargFromDouble);
439 
440         final Class<?> carrier;
441         final MemoryLayout layout;
442         final String format;
443         final Object nativeValue;
444         final Object javaValue;
445         @SuppressWarnings("rawtypes")
446         final VaListBuilderCall builderCall;
447 
448         <Z> PrintfArg(Class<?> carrier, MemoryLayout layout, String format, Z nativeValue, Object javaValue, VaListBuilderCall<Z> builderCall) {
449             this.carrier = carrier;
450             this.layout = layout;
451             this.format = format;
452             this.nativeValue = nativeValue;
453             this.javaValue = javaValue;
454             this.builderCall = builderCall;
455         }
456 
457         @Override
458         @SuppressWarnings("unchecked")
459         public void accept(VaList.Builder builder) {
460             builderCall.build(builder, layout, nativeValue);
461         }
462 
463         interface VaListBuilderCall<V> {
464             void build(VaList.Builder builder, MemoryLayout layout, V value);
465         }
466     }
467 
468     static <Z> Set<List<Z>> perms(int count, Z[] arr) {
469         if (count == arr.length) {
470             return Set.of(List.of());
471         } else {
472             return Arrays.stream(arr)
473                     .flatMap(num -> {
474                         Set<List<Z>> perms = perms(count + 1, arr);
475                         return Stream.concat(
476                                 //take n
477                                 perms.stream().map(l -> {
478                                     List<Z> li = new ArrayList<>(l);
479                                     li.add(num);
480                                     return li;
481                                 }),
482                                 //drop n
483                                 perms.stream());
484                     }).collect(Collectors.toCollection(LinkedHashSet::new));
485         }
486     }
487 }