1 /*
  2  * Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
  3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  4  *
  5  * This code is free software; you can redistribute it and/or modify it
  6  * under the terms of the GNU General Public License version 2 only, as
  7  * published by the Free Software Foundation.  Oracle designates this
  8  * particular file as subject to the "Classpath" exception as provided
  9  * by Oracle in the LICENSE file that accompanied this code.
 10  *
 11  * This code is distributed in the hope that it will be useful, but WITHOUT
 12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 14  * version 2 for more details (a copy is included in the LICENSE file that
 15  * accompanied this code).
 16  *
 17  * You should have received a copy of the GNU General Public License version
 18  * 2 along with this work; if not, write to the Free Software Foundation,
 19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 20  *
 21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 22  * or visit www.oracle.com if you need additional information or have any
 23  * questions.
 24  */
 25 package jdk.internal.foreign.abi;
 26 
 27 import jdk.incubator.foreign.CSupport;
 28 import jdk.incubator.foreign.ForeignLinker;
 29 import jdk.incubator.foreign.FunctionDescriptor;
 30 import jdk.incubator.foreign.GroupLayout;
 31 import jdk.incubator.foreign.MemoryAddress;
 32 import jdk.incubator.foreign.MemoryHandles;
 33 import jdk.incubator.foreign.MemoryLayout;
 34 import jdk.incubator.foreign.MemorySegment;
 35 import jdk.incubator.foreign.SequenceLayout;
 36 import jdk.incubator.foreign.ValueLayout;
 37 import jdk.internal.foreign.MemoryAddressImpl;
 38 import jdk.internal.foreign.Utils;
 39 import jdk.internal.foreign.abi.aarch64.AArch64Linker;
 40 import jdk.internal.foreign.abi.x64.sysv.SysVVaList;
 41 import jdk.internal.foreign.abi.x64.sysv.SysVx64Linker;
 42 import jdk.internal.foreign.abi.x64.windows.Windowsx64Linker;
 43 
 44 import java.lang.invoke.MethodHandle;
 45 import java.lang.invoke.MethodHandles;
 46 import java.lang.invoke.MethodType;
 47 import java.lang.invoke.VarHandle;
 48 import java.nio.ByteOrder;
 49 import java.util.List;
 50 import java.util.function.Consumer;
 51 import java.util.stream.IntStream;
 52 
 53 import static java.lang.invoke.MethodHandles.collectArguments;
 54 import static java.lang.invoke.MethodHandles.identity;
 55 import static java.lang.invoke.MethodHandles.insertArguments;
 56 import static java.lang.invoke.MethodHandles.permuteArguments;
 57 import static java.lang.invoke.MethodType.methodType;
 58 import static jdk.incubator.foreign.CSupport.*;
 59 
 60 public class SharedUtils {
 61 
 62     private static final MethodHandle MH_ALLOC_BUFFER;
 63     private static final MethodHandle MH_BASEADDRESS;
 64     private static final MethodHandle MH_BUFFER_COPY;
 65 
 66     private static final VarHandle VH_BYTE = MemoryHandles.varHandle(byte.class, ByteOrder.nativeOrder());
 67     private static final VarHandle VH_CHAR = MemoryHandles.varHandle(char.class, ByteOrder.nativeOrder());
 68     private static final VarHandle VH_SHORT = MemoryHandles.varHandle(short.class, ByteOrder.nativeOrder());
 69     private static final VarHandle VH_INT = MemoryHandles.varHandle(int.class, ByteOrder.nativeOrder());
 70     private static final VarHandle VH_LONG = MemoryHandles.varHandle(long.class, ByteOrder.nativeOrder());
 71     private static final VarHandle VH_FLOAT = MemoryHandles.varHandle(float.class, ByteOrder.nativeOrder());
 72     private static final VarHandle VH_DOUBLE = MemoryHandles.varHandle(double.class, ByteOrder.nativeOrder());
 73 
 74     static {
 75         try {
 76             var lookup = MethodHandles.lookup();
 77             MH_ALLOC_BUFFER = lookup.findStatic(SharedUtils.class, "allocateNative",
 78                     methodType(MemorySegment.class, MemoryLayout.class));
 79             MH_BASEADDRESS = lookup.findVirtual(MemorySegment.class, "baseAddress",
 80                     methodType(MemoryAddress.class));
 81             MH_BUFFER_COPY = lookup.findStatic(SharedUtils.class, "bufferCopy",
 82                     methodType(MemoryAddress.class, MemoryAddress.class, MemorySegment.class));
 83         } catch (ReflectiveOperationException e) {
 84             throw new BootstrapMethodError(e);
 85         }
 86     }
 87 
 88     // workaround for https://bugs.openjdk.java.net/browse/JDK-8239083
 89     private static MemorySegment allocateNative(MemoryLayout layout) {
 90         return MemorySegment.allocateNative(layout);
 91     }
 92 
 93     /**
 94      * Align the specified type from a given address
 95      * @return The address the data should be at based on alignment requirement
 96      */
 97     public static long align(MemoryLayout t, boolean isVar, long addr) {
 98         return alignUp(addr, alignment(t, isVar));
 99     }
100 
101     public static long alignUp(long addr, long alignment) {
102         return ((addr - 1) | (alignment - 1)) + 1;
103     }
104 
105     /**
106      * The alignment requirement for a given type
107      * @param isVar indicate if the type is a standalone variable. This change how
108      * array is aligned. for example.
109      */
110     public static long alignment(MemoryLayout t, boolean isVar) {
111         if (t instanceof ValueLayout) {
112             return alignmentOfScalar((ValueLayout) t);
113         } else if (t instanceof SequenceLayout) {
114             // when array is used alone
115             return alignmentOfArray((SequenceLayout) t, isVar);
116         } else if (t instanceof GroupLayout) {
117             return alignmentOfContainer((GroupLayout) t);
118         } else if (t.isPadding()) {
119             return 1;
120         } else {
121             throw new IllegalArgumentException("Invalid type: " + t);
122         }
123     }
124 
125     private static long alignmentOfScalar(ValueLayout st) {
126         return st.byteSize();
127     }
128 
129     private static long alignmentOfArray(SequenceLayout ar, boolean isVar) {
130         if (ar.elementCount().orElseThrow() == 0) {
131             // VLA or incomplete
132             return 16;
133         } else if ((ar.byteSize()) >= 16 && isVar) {
134             return 16;
135         } else {
136             // align as element type
137             MemoryLayout elementType = ar.elementLayout();
138             return alignment(elementType, false);
139         }
140     }
141 
142     private static long alignmentOfContainer(GroupLayout ct) {
143         // Most strict member
144         return ct.memberLayouts().stream().mapToLong(t -> alignment(t, false)).max().orElse(1);
145     }
146 
    /**
     * Takes a MethodHandle that takes an input buffer as a first argument (a MemoryAddress), and returns nothing,
     * and adapts it to return a MemorySegment, by allocating a MemorySegment for the input
     * buffer, calling the target MethodHandle, and then returning the allocated MemorySegment.
     *
     * This allows viewing a MethodHandle that makes use of in memory return (IMR) as a MethodHandle that just returns
     * a MemorySegment without requiring a pre-allocated buffer as an explicit input.
     *
     * @param handle the target handle to adapt
     * @param cDesc the function descriptor of the native function (with actual return layout)
     * @return the adapted handle
     * @throws IllegalArgumentException if {@code handle} does not return {@code void}, if its first
     *         parameter is not a {@code MemoryAddress}, or if {@code cDesc} has no return layout
     */
    public static MethodHandle adaptDowncallForIMR(MethodHandle handle, FunctionDescriptor cDesc) {
        if (handle.type().returnType() != void.class)
            throw new IllegalArgumentException("return expected to be void for in memory returns");
        if (handle.type().parameterType(0) != MemoryAddress.class)
            throw new IllegalArgumentException("MemoryAddress expected as first param");
        if (cDesc.returnLayout().isEmpty())
            throw new IllegalArgumentException("Return layout needed: " + cDesc);

        // The trailing comments track the handle's type after each step.
        MethodHandle ret = identity(MemorySegment.class); // (MemorySegment) MemorySegment
        handle = collectArguments(ret, 1, handle); // (MemorySegment, MemoryAddress ...) MemorySegment
        handle = collectArguments(handle, 1, MH_BASEADDRESS); // (MemorySegment, MemorySegment ...) MemorySegment
        MethodType oldType = handle.type(); // (MemorySegment, MemorySegment, ...) MemorySegment
        MethodType newType = oldType.dropParameterTypes(0, 1); // (MemorySegment, ...) MemorySegment
        // range(-1, n) yields n + 1 slots; slot 0 is overwritten below so that both leading
        // MemorySegment parameters of oldType are fed from parameter 0 of newType.
        int[] reorder = IntStream.range(-1, newType.parameterCount()).toArray();
        reorder[0] = 0; // [0, 0, 1, 2, 3, ...]
        handle = permuteArguments(handle, newType, reorder); // (MemorySegment, ...) MemorySegment
        // The leading segment is now supplied by the allocator bound to the return layout.
        handle = collectArguments(handle, 0, insertArguments(MH_ALLOC_BUFFER, 0, cDesc.returnLayout().get())); // (...) MemorySegment

        return handle;
    }
179 
180     /**
181      * Takes a MethodHandle that returns a MemorySegment, and adapts it to take an input buffer as a first argument
182      * (a MemoryAddress), and upon invocation, copies the contents of the returned MemorySegment into the input buffer
183      * passed as the first argument.
184      *
185      * @param target the target handle to adapt
186      * @return the adapted handle
187      */
188     public static MethodHandle adaptUpcallForIMR(MethodHandle target) {
189         if (target.type().returnType() != MemorySegment.class)
190             throw new IllegalArgumentException("Must return MemorySegment for IMR");
191 
192         target = collectArguments(MH_BUFFER_COPY, 1, target); // (MemoryAddress, ...) MemoryAddress
193 
194         return target;
195     }
196 
197     private static MemoryAddress bufferCopy(MemoryAddress dest, MemorySegment buffer) {
198         MemoryAddressImpl.ofLongUnchecked(dest.toRawLongValue(), buffer.byteSize())
199                 .segment().copyFrom(buffer);
200         return dest;
201     }
202 
203     public static void checkCompatibleType(Class<?> carrier, MemoryLayout layout, long addressSize) {
204         if (carrier.isPrimitive()) {
205             Utils.checkPrimitiveCarrierCompat(carrier, layout);
206         } else if (carrier == MemoryAddress.class) {
207             Utils.checkLayoutType(layout, ValueLayout.class);
208             if (layout.bitSize() != addressSize)
209                 throw new IllegalArgumentException("Address size mismatch: " + addressSize + " != " + layout.bitSize());
210         } else if (carrier == MemorySegment.class) {
211             Utils.checkLayoutType(layout, GroupLayout.class);
212         } else {
213             throw new IllegalArgumentException("Unsupported carrier: " + carrier);
214         }
215     }
216 
217     public static void checkFunctionTypes(MethodType mt, FunctionDescriptor cDesc, long addressSize) {
218         if (mt.returnType() == void.class != cDesc.returnLayout().isEmpty())
219             throw new IllegalArgumentException("Return type mismatch: " + mt + " != " + cDesc);
220         List<MemoryLayout> argLayouts = cDesc.argumentLayouts();
221         if (mt.parameterCount() != argLayouts.size())
222             throw new IllegalArgumentException("Arity mismatch: " + mt + " != " + cDesc);
223 
224         int paramCount = mt.parameterCount();
225         for (int i = 0; i < paramCount; i++) {
226             checkCompatibleType(mt.parameterType(i), argLayouts.get(i), addressSize);
227         }
228         cDesc.returnLayout().ifPresent(rl -> checkCompatibleType(mt.returnType(), rl, addressSize));
229     }
230 
231     public static Class<?> primitiveCarrierForSize(long size) {
232         if (size == 1) {
233             return byte.class;
234         } else if(size == 2) {
235             return short.class;
236         } else if (size <= 4) {
237             return int.class;
238         } else if (size <= 8) {
239             return long.class;
240         }
241 
242         throw new IllegalArgumentException("Size too large: " + size);
243     }
244 
245     public static ForeignLinker getSystemLinker() {
246         String arch = System.getProperty("os.arch");
247         String os = System.getProperty("os.name");
248         if (arch.equals("amd64") || arch.equals("x86_64")) {
249             if (os.startsWith("Windows")) {
250                 return Windowsx64Linker.getInstance();
251             } else {
252                 return SysVx64Linker.getInstance();
253             }
254         } else if (arch.equals("aarch64")) {
255             return AArch64Linker.getInstance();
256         }
257         throw new UnsupportedOperationException("Unsupported os or arch: " + os + ", " + arch);
258     }
259 
260     public static VaList newVaList(Consumer<VaList.Builder> actions) {
261         String name = CSupport.getSystemLinker().name();
262         return switch(name) {
263             case Win64.NAME -> Windowsx64Linker.newVaList(actions);
264             case SysV.NAME -> SysVx64Linker.newVaList(actions);
265             case AArch64.NAME -> throw new UnsupportedOperationException("Not yet implemented for this platform");
266             default -> throw new IllegalStateException("Unknown linker name: " + name);
267         };
268     }
269 
270     public static VarHandle vhPrimitiveOrAddress(Class<?> carrier, MemoryLayout layout) {
271         return carrier == MemoryAddress.class
272             ? MemoryHandles.asAddressVarHandle(layout.varHandle(primitiveCarrierForSize(layout.byteSize())))
273             : layout.varHandle(carrier);
274     }
275 
276     public static VaList newVaListOfAddress(MemoryAddress ma) {
277         String name = CSupport.getSystemLinker().name();
278         return switch(name) {
279             case Win64.NAME -> Windowsx64Linker.newVaListOfAddress(ma);
280             case SysV.NAME -> SysVx64Linker.newVaListOfAddress(ma);
281             case AArch64.NAME -> throw new UnsupportedOperationException("Not yet implemented for this platform");
282             default -> throw new IllegalStateException("Unknown linker name: " + name);
283         };
284     }
285 
286     public static VaList emptyVaList() {
287         String name = CSupport.getSystemLinker().name();
288         return switch(name) {
289             case Win64.NAME -> Windowsx64Linker.emptyVaList();
290             case SysV.NAME -> SysVx64Linker.emptyVaList();
291             case AArch64.NAME -> throw new UnsupportedOperationException("Not yet implemented for this platform");
292             default -> throw new IllegalStateException("Unknown linker name: " + name);
293         };
294     }
295 
296     public static MethodType convertVaListCarriers(MethodType mt, Class<?> carrier) {
297         Class<?>[] params = new Class<?>[mt.parameterCount()];
298         for (int i = 0; i < params.length; i++) {
299             Class<?> pType = mt.parameterType(i);
300             params[i] = ((pType == VaList.class) ? carrier : pType);
301         }
302         return methodType(mt.returnType(), params);
303     }
304 
305     public static MethodHandle unboxVaLists(MethodType type, MethodHandle handle, MethodHandle unboxer) {
306         for (int i = 0; i < type.parameterCount(); i++) {
307             if (type.parameterType(i) == VaList.class) {
308                handle = MethodHandles.filterArguments(handle, i, unboxer);
309             }
310         }
311         return handle;
312     }
313 
314     public static MethodHandle boxVaLists(MethodHandle handle, MethodHandle boxer) {
315         MethodType type = handle.type();
316         for (int i = 0; i < type.parameterCount(); i++) {
317             if (type.parameterType(i) == VaList.class) {
318                handle = MethodHandles.filterArguments(handle, i, boxer);
319             }
320         }
321         return handle;
322     }
323 
324     static void checkType(Class<?> actualType, Class<?> expectedType) {
325         if (expectedType != actualType) {
326             throw new IllegalArgumentException(
327                     String.format("Invalid operand type: %s. %s expected", actualType, expectedType));
328         }
329     }
330 
331     public static class SimpleVaArg {
332         public final Class<?> carrier;
333         public final MemoryLayout layout;
334         public final Object value;
335 
336         public SimpleVaArg(Class<?> carrier, MemoryLayout layout, Object value) {
337             this.carrier = carrier;
338             this.layout = layout;
339             this.value = value;
340         }
341 
342         public VarHandle varHandle() {
343             return carrier == MemoryAddress.class
344                 ? MemoryHandles.asAddressVarHandle(layout.varHandle(primitiveCarrierForSize(layout.byteSize())))
345                 : layout.varHandle(carrier);
346         }
347     }
348 
349     public static class EmptyVaList implements CSupport.VaList {
350 
351         private final MemoryAddress address;
352 
353         public EmptyVaList(MemoryAddress address) {
354             this.address = address;
355         }
356 
357         private static UnsupportedOperationException uoe() {
358             return new UnsupportedOperationException("Empty VaList");
359         }
360 
361         @Override
362         public int vargAsInt(MemoryLayout layout) {
363             throw uoe();
364         }
365 
366         @Override
367         public long vargAsLong(MemoryLayout layout) {
368             throw uoe();
369         }
370 
371         @Override
372         public double vargAsDouble(MemoryLayout layout) {
373             throw uoe();
374         }
375 
376         @Override
377         public MemoryAddress vargAsAddress(MemoryLayout layout) {
378             throw uoe();
379         }
380 
381         @Override
382         public MemorySegment vargAsSegment(MemoryLayout layout) {
383             throw uoe();
384         }
385 
386         @Override
387         public void skip(MemoryLayout... layouts) {
388             throw uoe();
389         }
390 
391         @Override
392         public boolean isAlive() {
393             return true;
394         }
395 
396         @Override
397         public void close() {
398             throw uoe();
399         }
400 
401         @Override
402         public VaList copy() {
403             return this;
404         }
405 
406         @Override
407         public MemoryAddress address() {
408             return address;
409         }
410     }
411 
412     static void writeOverSized(MemoryAddress ptr, Class<?> type, Object o) {
413         // use VH_LONG for integers to zero out the whole register in the process
414         if (type == long.class) {
415             VH_LONG.set(ptr, (long) o);
416         } else if (type == int.class) {
417             VH_LONG.set(ptr, (long) (int) o);
418         } else if (type == short.class) {
419             VH_LONG.set(ptr, (long) (short) o);
420         } else if (type == char.class) {
421             VH_LONG.set(ptr, (long) (char) o);
422         } else if (type == byte.class) {
423             VH_LONG.set(ptr, (long) (byte) o);
424         } else if (type == float.class) {
425             VH_FLOAT.set(ptr, (float) o);
426         } else if (type == double.class) {
427             VH_DOUBLE.set(ptr, (double) o);
428         } else {
429             throw new IllegalArgumentException("Unsupported carrier: " + type);
430         }
431     }
432 
433     static void write(MemoryAddress ptr, Class<?> type, Object o) {
434         if (type == long.class) {
435             VH_LONG.set(ptr, (long) o);
436         } else if (type == int.class) {
437             VH_INT.set(ptr, (int) o);
438         } else if (type == short.class) {
439             VH_SHORT.set(ptr, (short) o);
440         } else if (type == char.class) {
441             VH_CHAR.set(ptr, (char) o);
442         } else if (type == byte.class) {
443             VH_BYTE.set(ptr, (byte) o);
444         } else if (type == float.class) {
445             VH_FLOAT.set(ptr, (float) o);
446         } else if (type == double.class) {
447             VH_DOUBLE.set(ptr, (double) o);
448         } else {
449             throw new IllegalArgumentException("Unsupported carrier: " + type);
450         }
451     }
452 
453     static Object read(MemoryAddress ptr, Class<?> type) {
454         if (type == long.class) {
455             return (long) VH_LONG.get(ptr);
456         } else if (type == int.class) {
457             return (int) VH_INT.get(ptr);
458         } else if (type == short.class) {
459             return (short) VH_SHORT.get(ptr);
460         } else if (type == char.class) {
461             return (char) VH_CHAR.get(ptr);
462         } else if (type == byte.class) {
463             return (byte) VH_BYTE.get(ptr);
464         } else if (type == float.class) {
465             return (float) VH_FLOAT.get(ptr);
466         } else if (type == double.class) {
467             return (double) VH_DOUBLE.get(ptr);
468         } else {
469             throw new IllegalArgumentException("Unsupported carrier: " + type);
470         }
471     }
472 }