/*
 *  Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
 *  DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 *  This code is free software; you can redistribute it and/or modify it
 *  under the terms of the GNU General Public License version 2 only, as
 *  published by the Free Software Foundation.
 *
 *  This code is distributed in the hope that it will be useful, but WITHOUT
 *  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 *  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 *  version 2 for more details (a copy is included in the LICENSE file that
 *  accompanied this code).
 *
 *  You should have received a copy of the GNU General Public License version
 *  2 along with this work; if not, write to the Free Software Foundation,
 *  Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 *   Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 *  or visit www.oracle.com if you need additional information or have any
 *  questions.
 *
 */
 24 
/*
 * @test
 * @run testng/othervm -Dforeign.restricted=permit VaListTest
 */
 29 
 30 import jdk.incubator.foreign.CSupport;
 31 import jdk.incubator.foreign.CSupport.VaList;
 32 import jdk.incubator.foreign.ForeignLinker;
 33 import jdk.incubator.foreign.FunctionDescriptor;
 34 import jdk.incubator.foreign.GroupLayout;
 35 import jdk.incubator.foreign.LibraryLookup;
 36 import jdk.incubator.foreign.MemoryAddress;
 37 import jdk.incubator.foreign.MemoryLayout;
 38 import jdk.incubator.foreign.MemorySegment;
 39 import org.testng.annotations.DataProvider;
 40 import org.testng.annotations.Test;
 41 
 42 import java.lang.invoke.MethodHandle;
 43 import java.lang.invoke.MethodHandles;
 44 import java.lang.invoke.MethodType;
 45 import java.lang.invoke.VarHandle;
 46 
 47 import static jdk.incubator.foreign.CSupport.C_DOUBLE;
 48 import static jdk.incubator.foreign.CSupport.C_INT;
 49 import static jdk.incubator.foreign.CSupport.C_LONGLONG;
 50 import static jdk.incubator.foreign.CSupport.C_POINTER;
 51 import static jdk.incubator.foreign.CSupport.C_VA_LIST;
 52 import static jdk.incubator.foreign.CSupport.Win64.asVarArg;
 53 import static jdk.incubator.foreign.MemoryLayout.PathElement.groupElement;
 54 import static org.testng.Assert.assertEquals;
 55 
 56 public class VaListTest {
 57 
 58     private static final ForeignLinker abi = CSupport.getSystemLinker();
 59     private static final LibraryLookup lookup = LibraryLookup.ofLibrary("VaList");
 60 
 61     private static final VarHandle VH_int = C_INT.varHandle(int.class);
 62 
 63     private static final MethodHandle MH_sumInts = link("sumInts",
 64             MethodType.methodType(int.class, int.class, VaList.class),
 65             FunctionDescriptor.of(C_INT, C_INT, CSupport.C_VA_LIST));
 66     private static final MethodHandle MH_sumDoubles = link("sumDoubles",
 67             MethodType.methodType(double.class, int.class, VaList.class),
 68             FunctionDescriptor.of(C_DOUBLE, C_INT, CSupport.C_VA_LIST));
 69     private static final MethodHandle MH_getInt = link("getInt",
 70             MethodType.methodType(int.class, VaList.class),
 71             FunctionDescriptor.of(C_INT, C_VA_LIST));
 72     private static final MethodHandle MH_sumStruct = link("sumStruct",
 73             MethodType.methodType(int.class, VaList.class),
 74             FunctionDescriptor.of(C_INT, C_VA_LIST));
 75     private static final MethodHandle MH_sumBigStruct = link("sumBigStruct",
 76             MethodType.methodType(long.class, VaList.class),
 77             FunctionDescriptor.of(C_LONGLONG, C_VA_LIST));
 78     private static final MethodHandle MH_sumStack = link("sumStack",
 79             MethodType.methodType(void.class, MemoryAddress.class, MemoryAddress.class, int.class,
 80                 long.class, long.class, long.class, long.class,
 81                 long.class, long.class, long.class, long.class,
 82                 long.class, long.class, long.class, long.class,
 83                 long.class, long.class, long.class, long.class,
 84                 double.class, double.class, double.class, double.class,
 85                 double.class, double.class, double.class, double.class,
 86                 double.class, double.class, double.class, double.class,
 87                 double.class, double.class, double.class, double.class
 88             ),
 89             FunctionDescriptor.ofVoid(C_POINTER, C_POINTER, C_INT,
 90                 asVarArg(C_LONGLONG), asVarArg(C_LONGLONG), asVarArg(C_LONGLONG), asVarArg(C_LONGLONG),
 91                 asVarArg(C_LONGLONG), asVarArg(C_LONGLONG), asVarArg(C_LONGLONG), asVarArg(C_LONGLONG),
 92                 asVarArg(C_LONGLONG), asVarArg(C_LONGLONG), asVarArg(C_LONGLONG), asVarArg(C_LONGLONG),
 93                 asVarArg(C_LONGLONG), asVarArg(C_LONGLONG), asVarArg(C_LONGLONG), asVarArg(C_LONGLONG),
 94                 asVarArg(C_DOUBLE), asVarArg(C_DOUBLE), asVarArg(C_DOUBLE), asVarArg(C_DOUBLE),
 95                 asVarArg(C_DOUBLE), asVarArg(C_DOUBLE), asVarArg(C_DOUBLE), asVarArg(C_DOUBLE),
 96                 asVarArg(C_DOUBLE), asVarArg(C_DOUBLE), asVarArg(C_DOUBLE), asVarArg(C_DOUBLE),
 97                 asVarArg(C_DOUBLE), asVarArg(C_DOUBLE), asVarArg(C_DOUBLE), asVarArg(C_DOUBLE)
 98             ));
 99 
100     private static final VarHandle VH_long = C_LONGLONG.varHandle(long.class);
101     private static final VarHandle VH_double = C_DOUBLE.varHandle(double.class);
102 
103     private static MethodHandle link(String symbol, MethodType mt, FunctionDescriptor fd) {
104         try {
105             return abi.downcallHandle(lookup.lookup(symbol), mt, fd);
106         } catch (NoSuchMethodException e) {
107             throw new NoSuchMethodError(e.getMessage());
108         }
109     }
110 
111     private static MethodHandle linkVaListCB(String symbol) {
112         return link(symbol,
113             MethodType.methodType(void.class, MemoryAddress.class),
114             FunctionDescriptor.ofVoid(C_POINTER));
115 
116     }
117 
118     private static final GroupLayout BigPoint_LAYOUT = MemoryLayout.ofStruct(
119         C_LONGLONG.withName("x"),
120         C_LONGLONG.withName("y")
121     );
122     private static final VarHandle VH_BigPoint_x = BigPoint_LAYOUT.varHandle(long.class, groupElement("x"));
123     private static final VarHandle VH_BigPoint_y = BigPoint_LAYOUT.varHandle(long.class, groupElement("y"));
124     private static final GroupLayout Point_LAYOUT = MemoryLayout.ofStruct(
125         C_INT.withName("x"),
126         C_INT.withName("y")
127     );
128     private static final VarHandle VH_Point_x = Point_LAYOUT.varHandle(int.class, groupElement("x"));
129     private static final VarHandle VH_Point_y = Point_LAYOUT.varHandle(int.class, groupElement("y"));
130 
131     @Test
132     public void testIntSum() throws Throwable {
133         try (VaList vaList = VaList.make(b ->
134                 b.vargFromInt(C_INT, 10)
135                  .vargFromInt(C_INT, 15)
136                  .vargFromInt(C_INT, 20))) {
137             int x = (int) MH_sumInts.invokeExact(3, vaList);
138             assertEquals(x, 45);
139         }
140     }
141 
142     @Test
143     public void testDoubleSum() throws Throwable {
144         try (VaList vaList = VaList.make(b ->
145                 b.vargFromDouble(C_DOUBLE, 3.0D)
146                  .vargFromDouble(C_DOUBLE, 4.0D)
147                  .vargFromDouble(C_DOUBLE, 5.0D))) {
148             double x = (double) MH_sumDoubles.invokeExact(3, vaList);
149             assertEquals(x, 12.0D);
150         }
151     }
152 
153     @Test
154     public void testVaListMemoryAddress() throws Throwable {
155         try (MemorySegment msInt = MemorySegment.allocateNative(C_INT)) {
156             VH_int.set(msInt.baseAddress(), 10);
157             try (VaList vaList = VaList.make(b -> b.vargFromAddress(C_POINTER, msInt.baseAddress()))) {
158                 int x = (int) MH_getInt.invokeExact(vaList);
159                 assertEquals(x, 10);
160             }
161         }
162     }
163 
164     @Test
165     public void testWinStructByValue() throws Throwable {
166         try (MemorySegment struct = MemorySegment.allocateNative(Point_LAYOUT)) {
167             VH_Point_x.set(struct.baseAddress(), 5);
168             VH_Point_y.set(struct.baseAddress(), 10);
169 
170             try (VaList vaList = VaList.make(b -> b.vargFromSegment(Point_LAYOUT, struct))) {
171                 int sum = (int) MH_sumStruct.invokeExact(vaList);
172                 assertEquals(sum, 15);
173             }
174         }
175     }
176 
177     @Test
178     public void testWinStructByReference() throws Throwable {
179         try (MemorySegment struct = MemorySegment.allocateNative(BigPoint_LAYOUT)) {
180             VH_BigPoint_x.set(struct.baseAddress(), 5);
181             VH_BigPoint_y.set(struct.baseAddress(), 10);
182 
183             try (VaList vaList = VaList.make(b -> b.vargFromSegment(BigPoint_LAYOUT, struct))) {
184                 long sum = (long) MH_sumBigStruct.invokeExact(vaList);
185                 assertEquals(sum, 15);
186             }
187         }
188     }
189 
190     @Test
191     public void testStack() throws Throwable {
192        try (MemorySegment longSum = MemorySegment.allocateNative(C_LONGLONG);
193             MemorySegment doubleSum = MemorySegment.allocateNative(C_DOUBLE)) {
194             VH_long.set(longSum.baseAddress(), 0L);
195             VH_double.set(doubleSum.baseAddress(), 0D);
196 
197             MH_sumStack.invokeExact(longSum.baseAddress(), doubleSum.baseAddress(), 32,
198                 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L,
199                 1D, 2D, 3D, 4D, 5D, 6D, 7D, 8D, 9D, 10D, 11D, 12D, 13D, 14D, 15D, 16D);
200 
201             long lSum = (long) VH_long.get(longSum.baseAddress());
202             double dSum = (double) VH_double.get(doubleSum.baseAddress());
203 
204             assertEquals(lSum, 136L);
205             assertEquals(dSum, 136D);
206         }
207     }
208 
209     @Test(dataProvider = "upcalls")
210     public void testUpcall(MethodHandle target, MethodHandle callback) throws Throwable {
211         FunctionDescriptor desc = FunctionDescriptor.ofVoid(C_VA_LIST);
212         try (MemorySegment stub = abi.upcallStub(callback, desc)) {
213             target.invokeExact(stub.baseAddress());
214         }
215     }
216 
217     @DataProvider
218     public static Object[][] upcalls() {
219         return new Object[][]{
220             { linkVaListCB("upcallBigStruct"), VaListConsumer.mh(vaList -> {
221                 try (MemorySegment struct = vaList.vargAsSegment(BigPoint_LAYOUT)) {
222                     assertEquals((long) VH_BigPoint_x.get(struct.baseAddress()), 8);
223                     assertEquals((long) VH_BigPoint_y.get(struct.baseAddress()), 16);
224                 }
225             })},
226             { linkVaListCB("upcallBigStruct"), VaListConsumer.mh(vaList -> {
227                 VaList copy = vaList.copy();
228                 try (MemorySegment struct = vaList.vargAsSegment(BigPoint_LAYOUT)) {
229                     assertEquals((long) VH_BigPoint_x.get(struct.baseAddress()), 8);
230                     assertEquals((long) VH_BigPoint_y.get(struct.baseAddress()), 16);
231 
232                     VH_BigPoint_x.set(struct.baseAddress(), 0);
233                     VH_BigPoint_y.set(struct.baseAddress(), 0);
234                 }
235 
236                 // should be independent
237                 try (MemorySegment struct = copy.vargAsSegment(BigPoint_LAYOUT)) {
238                     assertEquals((long) VH_BigPoint_x.get(struct.baseAddress()), 8);
239                     assertEquals((long) VH_BigPoint_y.get(struct.baseAddress()), 16);
240                 }
241             })},
242             { linkVaListCB("upcallStruct"), VaListConsumer.mh(vaList -> {
243                 try (MemorySegment struct = vaList.vargAsSegment(Point_LAYOUT)) {
244                     assertEquals((int) VH_Point_x.get(struct.baseAddress()), 5);
245                     assertEquals((int) VH_Point_y.get(struct.baseAddress()), 10);
246                 }
247             })},
248             { linkVaListCB("upcallMemoryAddress"), VaListConsumer.mh(vaList -> {
249                 MemoryAddress intPtr = vaList.vargAsAddress(C_POINTER);
250                 MemorySegment ms = MemorySegment.ofNativeRestricted(intPtr, C_INT.byteSize(),
251                                                                     Thread.currentThread(), null, null);
252                 int x = (int) VH_int.get(ms.baseAddress());
253                 assertEquals(x, 10);
254             })},
255             { linkVaListCB("upcallDoubles"), VaListConsumer.mh(vaList -> {
256                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 3.0);
257                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 4.0);
258                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 5.0);
259             })},
260             { linkVaListCB("upcallInts"), VaListConsumer.mh(vaList -> {
261                 assertEquals(vaList.vargAsInt(C_INT), 10);
262                 assertEquals(vaList.vargAsInt(C_INT), 15);
263                 assertEquals(vaList.vargAsInt(C_INT), 20);
264             })},
265             { linkVaListCB("upcallStack"), VaListConsumer.mh(vaList -> {
266                 // skip all registers
267                 assertEquals(vaList.vargAsLong(C_LONGLONG), 1L); // 1st windows arg read from shadow space
268                 assertEquals(vaList.vargAsLong(C_LONGLONG), 2L); // 2nd windows arg read from shadow space
269                 assertEquals(vaList.vargAsLong(C_LONGLONG), 3L); // windows 1st stack arg (int/float)
270                 assertEquals(vaList.vargAsLong(C_LONGLONG), 4L);
271                 assertEquals(vaList.vargAsLong(C_LONGLONG), 5L);
272                 assertEquals(vaList.vargAsLong(C_LONGLONG), 6L);
273                 assertEquals(vaList.vargAsLong(C_LONGLONG), 7L); // sysv 1st int stack arg
274                 assertEquals(vaList.vargAsLong(C_LONGLONG), 8L);
275                 assertEquals(vaList.vargAsLong(C_LONGLONG), 9L);
276                 assertEquals(vaList.vargAsLong(C_LONGLONG), 10L);
277                 assertEquals(vaList.vargAsLong(C_LONGLONG), 11L);
278                 assertEquals(vaList.vargAsLong(C_LONGLONG), 12L);
279                 assertEquals(vaList.vargAsLong(C_LONGLONG), 13L);
280                 assertEquals(vaList.vargAsLong(C_LONGLONG), 14L);
281                 assertEquals(vaList.vargAsLong(C_LONGLONG), 15L);
282                 assertEquals(vaList.vargAsLong(C_LONGLONG), 16L);
283                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 1.0D);
284                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 2.0D);
285                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 3.0D);
286                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 4.0D);
287                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 5.0D);
288                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 6.0D);
289                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 7.0D);
290                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 8.0D);
291                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 9.0D); // sysv 1st float stack arg
292                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 10.0D);
293                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 11.0D);
294                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 12.0D);
295                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 13.0D);
296                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 14.0D);
297                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 15.0D);
298                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 16.0D);
299 
300                 // test some arbitrary values on the stack
301                 assertEquals((byte) vaList.vargAsInt(C_INT), (byte) 1);
302                 assertEquals((char) vaList.vargAsInt(C_INT), 'a');
303                 assertEquals((short) vaList.vargAsInt(C_INT), (short) 3);
304                 assertEquals(vaList.vargAsInt(C_INT), 4);
305                 assertEquals(vaList.vargAsLong(C_LONGLONG), 5L);
306                 assertEquals((float) vaList.vargAsDouble(C_DOUBLE), 6.0F);
307                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 7.0D);
308                 assertEquals((byte) vaList.vargAsInt(C_INT), (byte) 8);
309                 assertEquals((char) vaList.vargAsInt(C_INT), 'b');
310                 assertEquals((short) vaList.vargAsInt(C_INT), (short) 10);
311                 assertEquals(vaList.vargAsInt(C_INT), 11);
312                 assertEquals(vaList.vargAsLong(C_LONGLONG), 12L);
313                 assertEquals((float) vaList.vargAsDouble(C_DOUBLE), 13.0F);
314                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 14.0D);
315 
316                 try (MemorySegment point = vaList.vargAsSegment(Point_LAYOUT)) {
317                     assertEquals((int) VH_Point_x.get(point.baseAddress()), 5);
318                     assertEquals((int) VH_Point_y.get(point.baseAddress()), 10);
319                 }
320 
321                 VaList copy = vaList.copy();
322                 try (MemorySegment bigPoint = vaList.vargAsSegment(BigPoint_LAYOUT)) {
323                     assertEquals((long) VH_BigPoint_x.get(bigPoint.baseAddress()), 15);
324                     assertEquals((long) VH_BigPoint_y.get(bigPoint.baseAddress()), 20);
325 
326                     VH_BigPoint_x.set(bigPoint.baseAddress(), 0);
327                     VH_BigPoint_y.set(bigPoint.baseAddress(), 0);
328                 }
329 
330                 // should be independent
331                 try (MemorySegment struct = copy.vargAsSegment(BigPoint_LAYOUT)) {
332                     assertEquals((long) VH_BigPoint_x.get(struct.baseAddress()), 15);
333                     assertEquals((long) VH_BigPoint_y.get(struct.baseAddress()), 20);
334                 }
335             })},
336             // test skip
337             { linkVaListCB("upcallStack"), VaListConsumer.mh(vaList -> {
338                 vaList.skip(C_LONGLONG, C_LONGLONG, C_LONGLONG, C_LONGLONG);
339                 assertEquals(vaList.vargAsLong(C_LONGLONG), 5L);
340                 vaList.skip(C_LONGLONG, C_LONGLONG, C_LONGLONG, C_LONGLONG);
341                 assertEquals(vaList.vargAsLong(C_LONGLONG), 10L);
342                 vaList.skip(C_LONGLONG, C_LONGLONG, C_LONGLONG, C_LONGLONG, C_LONGLONG, C_LONGLONG);
343                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 1.0D);
344                 vaList.skip(C_DOUBLE, C_DOUBLE, C_DOUBLE, C_DOUBLE);
345                 assertEquals(vaList.vargAsDouble(C_DOUBLE), 6.0D);
346             })},
347         };
348     }
349 
350     interface VaListConsumer {
351         void accept(CSupport.VaList list);
352 
353         static MethodHandle mh(VaListConsumer instance) {
354             try {
355                 return MethodHandles.lookup().findVirtual(VaListConsumer.class, "accept",
356                     MethodType.methodType(void.class, VaList.class)).bindTo(instance);
357             } catch (ReflectiveOperationException e) {
358                 throw new InternalError(e);
359             }
360         }
361     }
362 
363 }