1 /*
  2  * Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package jdk.internal.foreign.abi;
 26 
 27 import jdk.incubator.foreign.CSupport;
 28 import jdk.incubator.foreign.ForeignLinker;
 29 import jdk.incubator.foreign.FunctionDescriptor;
 30 import jdk.incubator.foreign.GroupLayout;
 31 import jdk.incubator.foreign.MemoryAddress;
 32 import jdk.incubator.foreign.MemoryHandles;
 33 import jdk.incubator.foreign.MemoryLayout;
 34 import jdk.incubator.foreign.MemorySegment;
 35 import jdk.incubator.foreign.SequenceLayout;
 36 import jdk.incubator.foreign.ValueLayout;
 37 import jdk.internal.foreign.MemoryAddressImpl;
 38 import jdk.internal.foreign.Utils;
 39 import jdk.internal.foreign.abi.aarch64.AArch64Linker;
 40 import jdk.internal.foreign.abi.x64.sysv.SysVVaList;
 41 import jdk.internal.foreign.abi.x64.sysv.SysVx64Linker;
 42 import jdk.internal.foreign.abi.x64.windows.Windowsx64Linker;
 43 
 44 import java.lang.invoke.MethodHandle;
 45 import java.lang.invoke.MethodHandles;
 46 import java.lang.invoke.MethodType;
 47 import java.lang.invoke.VarHandle;
 48 import java.nio.ByteOrder;
 49 import java.util.List;
 50 import java.util.function.Consumer;
 51 import java.util.stream.IntStream;
 52 
 53 import static java.lang.invoke.MethodHandles.collectArguments;
 54 import static java.lang.invoke.MethodHandles.identity;
 55 import static java.lang.invoke.MethodHandles.insertArguments;
 56 import static java.lang.invoke.MethodHandles.permuteArguments;
 57 import static java.lang.invoke.MethodType.methodType;
 58 import static jdk.incubator.foreign.CSupport.*;
 59 
 60 public class SharedUtils {
 61 
 62     private static final MethodHandle MH_ALLOC_BUFFER;
 63     private static final MethodHandle MH_BASEADDRESS;
 64     private static final MethodHandle MH_BUFFER_COPY;
 65 
 66     private static final VarHandle VH_BYTE = MemoryHandles.varHandle(byte.class, ByteOrder.nativeOrder());
 67     private static final VarHandle VH_CHAR = MemoryHandles.varHandle(char.class, ByteOrder.nativeOrder());
 68     private static final VarHandle VH_SHORT = MemoryHandles.varHandle(short.class, ByteOrder.nativeOrder());
 69     private static final VarHandle VH_INT = MemoryHandles.varHandle(int.class, ByteOrder.nativeOrder());
 70     private static final VarHandle VH_LONG = MemoryHandles.varHandle(long.class, ByteOrder.nativeOrder());
 71     private static final VarHandle VH_FLOAT = MemoryHandles.varHandle(float.class, ByteOrder.nativeOrder());
 72     private static final VarHandle VH_DOUBLE = MemoryHandles.varHandle(double.class, ByteOrder.nativeOrder());
 73 
 74     static {
 75         try {
 76             var lookup = MethodHandles.lookup();
 77             MH_ALLOC_BUFFER = lookup.findStatic(SharedUtils.class, "allocateNative",
 78                     methodType(MemorySegment.class, MemoryLayout.class));
 79             MH_BASEADDRESS = lookup.findVirtual(MemorySegment.class, "baseAddress",
 80                     methodType(MemoryAddress.class));
 81             MH_BUFFER_COPY = lookup.findStatic(SharedUtils.class, "bufferCopy",
 82                     methodType(MemoryAddress.class, MemoryAddress.class, MemorySegment.class));
 83         } catch (ReflectiveOperationException e) {
 84             throw new BootstrapMethodError(e);
 85         }
 86     }
 87 
 88     // workaround for https://bugs.openjdk.java.net/browse/JDK-8239083
 89     private static MemorySegment allocateNative(MemoryLayout layout) {
 90         return MemorySegment.allocateNative(layout);
 91     }
 92 
 93     /**
 94      * Align the specified type from a given address
 95      * @return The address the data should be at based on alignment requirement
 96      */
 97     public static long align(MemoryLayout t, boolean isVar, long addr) {
 98         return alignUp(addr, alignment(t, isVar));
 99     }
100 
101     public static long alignUp(long addr, long alignment) {
102         return ((addr - 1) | (alignment - 1)) + 1;
103     }
104 
105     /**
106      * The alignment requirement for a given type
107      * @param isVar indicate if the type is a standalone variable. This change how
108      * array is aligned. for example.
109      */
110     public static long alignment(MemoryLayout t, boolean isVar) {
111         if (t instanceof ValueLayout) {
112             return alignmentOfScalar((ValueLayout) t);
113         } else if (t instanceof SequenceLayout) {
114             // when array is used alone
115             return alignmentOfArray((SequenceLayout) t, isVar);
116         } else if (t instanceof GroupLayout) {
117             return alignmentOfContainer((GroupLayout) t);
118         } else if (t.isPadding()) {
119             return 1;
120         } else {
121             throw new IllegalArgumentException("Invalid type: " + t);
122         }
123     }
124 
125     private static long alignmentOfScalar(ValueLayout st) {
126         return st.byteSize();
127     }
128 
129     private static long alignmentOfArray(SequenceLayout ar, boolean isVar) {
130         if (ar.elementCount().orElseThrow() == 0) {
131             // VLA or incomplete
132             return 16;
133         } else if ((ar.byteSize()) >= 16 && isVar) {
134             return 16;
135         } else {
136             // align as element type
137             MemoryLayout elementType = ar.elementLayout();
138             return alignment(elementType, false);
139         }
140     }
141 
142     private static long alignmentOfContainer(GroupLayout ct) {
143         // Most strict member
144         return ct.memberLayouts().stream().mapToLong(t -> alignment(t, false)).max().orElse(1);
145     }
146 
147     /**
148      * Takes a MethodHandle that takes an input buffer as a first argument (a MemoryAddress), and returns nothing,
149      * and adapts it to return a MemorySegment, by allocating a MemorySegment for the input
150      * buffer, calling the target MethodHandle, and then returning the allocated MemorySegment.
151      *
152      * This allows viewing a MethodHandle that makes use of in memory return (IMR) as a MethodHandle that just returns
153      * a MemorySegment without requiring a pre-allocated buffer as an explicit input.
154      *
155      * @param handle the target handle to adapt
156      * @param cDesc the function descriptor of the native function (with actual return layout)
157      * @return the adapted handle
158      */
159     public static MethodHandle adaptDowncallForIMR(MethodHandle handle, FunctionDescriptor cDesc) {
160         if (handle.type().returnType() != void.class)
161             throw new IllegalArgumentException("return expected to be void for in memory returns");
162         if (handle.type().parameterType(0) != MemoryAddress.class)
163             throw new IllegalArgumentException("MemoryAddress expected as first param");
164         if (cDesc.returnLayout().isEmpty())
165             throw new IllegalArgumentException("Return layout needed: " + cDesc);
166 
167         MethodHandle ret = identity(MemorySegment.class); // (MemorySegment) MemorySegment
168         handle = collectArguments(ret, 1, handle); // (MemorySegment, MemoryAddress ...) MemorySegment
169         handle = collectArguments(handle, 1, MH_BASEADDRESS); // (MemorySegment, MemorySegment ...) MemorySegment
170         MethodType oldType = handle.type(); // (MemorySegment, MemorySegment, ...) MemorySegment
171         MethodType newType = oldType.dropParameterTypes(0, 1); // (MemorySegment, ...) MemorySegment
172         int[] reorder = IntStream.range(-1, newType.parameterCount()).toArray();
173         reorder[0] = 0; // [0, 0, 1, 2, 3, ...]
174         handle = permuteArguments(handle, newType, reorder); // (MemorySegment, ...) MemoryAddress
175         handle = collectArguments(handle, 0, insertArguments(MH_ALLOC_BUFFER, 0, cDesc.returnLayout().get())); // (...) MemoryAddress
176 
177         return handle;
178     }
179 
180     /**
181      * Takes a MethodHandle that returns a MemorySegment, and adapts it to take an input buffer as a first argument
182      * (a MemoryAddress), and upon invocation, copies the contents of the returned MemorySegment into the input buffer
183      * passed as the first argument.
184      *
185      * @param target the target handle to adapt
186      * @return the adapted handle
187      */
188     public static MethodHandle adaptUpcallForIMR(MethodHandle target) {
189         if (target.type().returnType() != MemorySegment.class)
190             throw new IllegalArgumentException("Must return MemorySegment for IMR");
191 
192         target = collectArguments(MH_BUFFER_COPY, 1, target); // (MemoryAddress, ...) MemoryAddress
193 
194         return target;
195     }
196 
197     private static MemoryAddress bufferCopy(MemoryAddress dest, MemorySegment buffer) {
198         MemoryAddressImpl.ofLongUnchecked(dest.toRawLongValue(), buffer.byteSize())
199                 .segment().copyFrom(buffer);
200         return dest;
201     }
202 
203     public static void checkCompatibleType(Class<?> carrier, MemoryLayout layout, long addressSize) {
204         if (carrier.isPrimitive()) {
205             Utils.checkPrimitiveCarrierCompat(carrier, layout);
206         } else if (carrier == MemoryAddress.class) {
207             Utils.checkLayoutType(layout, ValueLayout.class);
208             if (layout.bitSize() != addressSize)
209                 throw new IllegalArgumentException("Address size mismatch: " + addressSize + " != " + layout.bitSize());
210         } else if (carrier == MemorySegment.class) {
211             Utils.checkLayoutType(layout, GroupLayout.class);
212         } else {
213             throw new IllegalArgumentException("Unsupported carrier: " + carrier);
214         }
215     }
216 
217     public static void checkFunctionTypes(MethodType mt, FunctionDescriptor cDesc, long addressSize) {
218         if (mt.returnType() == void.class != cDesc.returnLayout().isEmpty())
219             throw new IllegalArgumentException("Return type mismatch: " + mt + " != " + cDesc);
220         List<MemoryLayout> argLayouts = cDesc.argumentLayouts();
221         if (mt.parameterCount() != argLayouts.size())
222             throw new IllegalArgumentException("Arity mismatch: " + mt + " != " + cDesc);
223 
224         int paramCount = mt.parameterCount();
225         for (int i = 0; i < paramCount; i++) {
226             checkCompatibleType(mt.parameterType(i), argLayouts.get(i), addressSize);
227         }
228         cDesc.returnLayout().ifPresent(rl -> checkCompatibleType(mt.returnType(), rl, addressSize));
229     }
230 
231     public static Class<?> primitiveCarrierForSize(long size) {
232         if (size == 1) {
233             return byte.class;
234         } else if(size == 2) {
235             return short.class;
236         } else if (size <= 4) {
237             return int.class;
238         } else if (size <= 8) {
239             return long.class;
240         }
241 
242         throw new IllegalArgumentException("Size too large: " + size);
243     }
244 
245     public static ForeignLinker getSystemLinker() {
246         String arch = System.getProperty("os.arch");
247         String os = System.getProperty("os.name");
248         if (arch.equals("amd64") || arch.equals("x86_64")) {
249             if (os.startsWith("Windows")) {
250                 return Windowsx64Linker.getInstance();
251             } else {
252                 return SysVx64Linker.getInstance();
253             }
254         } else if (arch.equals("aarch64")) {
255             return AArch64Linker.getInstance();
256         }
257         throw new UnsupportedOperationException("Unsupported os or arch: " + os + ", " + arch);
258     }
259 
260     public static VaList newVaList(Consumer<VaList.Builder> actions) {
261         String name = CSupport.getSystemLinker().name();
262         return switch(name) {
263             case Win64.NAME -> Windowsx64Linker.newVaList(actions);
264             case SysV.NAME -> SysVx64Linker.newVaList(actions);
265             case AArch64.NAME -> throw new UnsupportedOperationException("Not yet implemented for this platform");
266             default -> throw new IllegalStateException("Unknown linker name: " + name);
267         };
268     }
269 
270     public static VarHandle vhPrimitiveOrAddress(Class<?> carrier, MemoryLayout layout) {
271         return carrier == MemoryAddress.class
272             ? MemoryHandles.asAddressVarHandle(layout.varHandle(primitiveCarrierForSize(layout.byteSize())))
273             : layout.varHandle(carrier);
274     }
275 
276     public static VaList newVaListOfAddress(MemoryAddress ma) {
277         String name = CSupport.getSystemLinker().name();
278         return switch(name) {
279             case Win64.NAME -> Windowsx64Linker.newVaListOfAddress(ma);
280             case SysV.NAME -> SysVx64Linker.newVaListOfAddress(ma);
281             case AArch64.NAME -> throw new UnsupportedOperationException("Not yet implemented for this platform");
282             default -> throw new IllegalStateException("Unknown linker name: " + name);
283         };
284     }
285 
286     public static VaList emptyVaList() {
287         String name = CSupport.getSystemLinker().name();
288         return switch(name) {
289             case Win64.NAME -> Windowsx64Linker.emptyVaList();
290             case SysV.NAME -> SysVx64Linker.emptyVaList();
291             case AArch64.NAME -> throw new UnsupportedOperationException("Not yet implemented for this platform");
292             default -> throw new IllegalStateException("Unknown linker name: " + name);
293         };
294     }
295 
296     public static MethodType convertVaListCarriers(MethodType mt, Class<?> carrier) {
297         Class<?>[] params = new Class<?>[mt.parameterCount()];
298         for (int i = 0; i < params.length; i++) {
299             Class<?> pType = mt.parameterType(i);
300             params[i] = ((pType == VaList.class) ? carrier : pType);
301         }
302         return methodType(mt.returnType(), params);
303     }
304 
305     public static MethodHandle unboxVaLists(MethodType type, MethodHandle handle, MethodHandle unboxer) {
306         for (int i = 0; i < type.parameterCount(); i++) {
307             if (type.parameterType(i) == VaList.class) {
308                handle = MethodHandles.filterArguments(handle, i, unboxer);
309             }
310         }
311         return handle;
312     }
313 
314     public static MethodHandle boxVaLists(MethodHandle handle, MethodHandle boxer) {
315         MethodType type = handle.type();
316         for (int i = 0; i < type.parameterCount(); i++) {
317             if (type.parameterType(i) == VaList.class) {
318                handle = MethodHandles.filterArguments(handle, i, boxer);
319             }
320         }
321         return handle;
322     }
323 
324     public static class SimpleVaArg {
325         public final Class<?> carrier;
326         public final MemoryLayout layout;
327         public final Object value;
328 
329         public SimpleVaArg(Class<?> carrier, MemoryLayout layout, Object value) {
330             this.carrier = carrier;
331             this.layout = layout;
332             this.value = value;
333         }
334 
335         public VarHandle varHandle() {
336             return carrier == MemoryAddress.class
337                 ? MemoryHandles.asAddressVarHandle(layout.varHandle(primitiveCarrierForSize(layout.byteSize())))
338                 : layout.varHandle(carrier);
339         }
340     }
341 
342     public static class EmptyVaList implements CSupport.VaList {
343 
344         private final MemoryAddress address;
345 
346         public EmptyVaList(MemoryAddress address) {
347             this.address = address;
348         }
349 
350         private static UnsupportedOperationException uoe() {
351             return new UnsupportedOperationException("Empty VaList");
352         }
353 
354         @Override
355         public int vargAsInt(MemoryLayout layout) {
356             throw uoe();
357         }
358 
359         @Override
360         public long vargAsLong(MemoryLayout layout) {
361             throw uoe();
362         }
363 
364         @Override
365         public double vargAsDouble(MemoryLayout layout) {
366             throw uoe();
367         }
368 
369         @Override
370         public MemoryAddress vargAsAddress(MemoryLayout layout) {
371             throw uoe();
372         }
373 
374         @Override
375         public MemorySegment vargAsSegment(MemoryLayout layout) {
376             throw uoe();
377         }
378 
379         @Override
380         public void skip(MemoryLayout... layouts) {
381             throw uoe();
382         }
383 
384         @Override
385         public boolean isAlive() {
386             return true;
387         }
388 
389         @Override
390         public void close() {
391             throw uoe();
392         }
393 
394         @Override
395         public VaList copy() {
396             return this;
397         }
398 
399         @Override
400         public MemoryAddress address() {
401             return address;
402         }
403     }
404 
405     static void writeOverSized(MemoryAddress ptr, Class<?> type, Object o) {
406         // use VH_LONG for integers to zero out the whole register in the process
407         if (type == long.class) {
408             VH_LONG.set(ptr, (long) o);
409         } else if (type == int.class) {
410             VH_LONG.set(ptr, (long) (int) o);
411         } else if (type == short.class) {
412             VH_LONG.set(ptr, (long) (short) o);
413         } else if (type == char.class) {
414             VH_LONG.set(ptr, (long) (char) o);
415         } else if (type == byte.class) {
416             VH_LONG.set(ptr, (long) (byte) o);
417         } else if (type == float.class) {
418             VH_FLOAT.set(ptr, (float) o);
419         } else if (type == double.class) {
420             VH_DOUBLE.set(ptr, (double) o);
421         } else {
422             throw new IllegalArgumentException("Unsupported carrier: " + type);
423         }
424     }
425 
426     static void write(MemoryAddress ptr, Class<?> type, Object o) {
427         if (type == long.class) {
428             VH_LONG.set(ptr, (long) o);
429         } else if (type == int.class) {
430             VH_INT.set(ptr, (int) o);
431         } else if (type == short.class) {
432             VH_SHORT.set(ptr, (short) o);
433         } else if (type == char.class) {
434             VH_CHAR.set(ptr, (char) o);
435         } else if (type == byte.class) {
436             VH_BYTE.set(ptr, (byte) o);
437         } else if (type == float.class) {
438             VH_FLOAT.set(ptr, (float) o);
439         } else if (type == double.class) {
440             VH_DOUBLE.set(ptr, (double) o);
441         } else {
442             throw new IllegalArgumentException("Unsupported carrier: " + type);
443         }
444     }
445 
446     static Object read(MemoryAddress ptr, Class<?> type) {
447         if (type == long.class) {
448             return (long) VH_LONG.get(ptr);
449         } else if (type == int.class) {
450             return (int) VH_INT.get(ptr);
451         } else if (type == short.class) {
452             return (short) VH_SHORT.get(ptr);
453         } else if (type == char.class) {
454             return (char) VH_CHAR.get(ptr);
455         } else if (type == byte.class) {
456             return (byte) VH_BYTE.get(ptr);
457         } else if (type == float.class) {
458             return (float) VH_FLOAT.get(ptr);
459         } else if (type == double.class) {
460             return (double) VH_DOUBLE.get(ptr);
461         } else {
462             throw new IllegalArgumentException("Unsupported carrier: " + type);
463         }
464     }
465 }