< prev index next >

src/jdk.incubator.foreign/share/classes/jdk/internal/foreign/abi/ProgrammableInvoker.java

Print this page
@@ -23,13 +23,11 @@
  package jdk.internal.foreign.abi;
  
  import jdk.incubator.foreign.MemoryAddress;
  import jdk.incubator.foreign.MemoryHandles;
  import jdk.incubator.foreign.MemorySegment;
- import jdk.incubator.foreign.NativeAllocationScope;
- import jdk.internal.access.JavaLangInvokeAccess;
- import jdk.internal.access.SharedSecrets;
+ import jdk.incubator.foreign.NativeScope;
  import jdk.internal.foreign.MemoryAddressImpl;
  import jdk.internal.foreign.Utils;
  
  import java.lang.invoke.MethodHandle;
  import java.lang.invoke.MethodHandles;

@@ -69,41 +67,26 @@
      private static final VarHandle VH_LONG = MemoryHandles.varHandle(long.class, ByteOrder.nativeOrder());
  
      private static final MethodHandle MH_INVOKE_MOVES;
      private static final MethodHandle MH_INVOKE_INTERP_BINDINGS;
  
-     private static final MethodHandle MH_UNBOX_ADDRESS;
-     private static final MethodHandle MH_BOX_ADDRESS;
-     private static final MethodHandle MH_BASE_ADDRESS;
-     private static final MethodHandle MH_COPY_BUFFER;
      private static final MethodHandle MH_MAKE_ALLOCATOR;
      private static final MethodHandle MH_CLOSE_ALLOCATOR;
-     private static final MethodHandle MH_ALLOCATE_BUFFER;
  
      private static final Map<ABIDescriptor, Long> adapterStubs = new ConcurrentHashMap<>();
  
      static {
          try {
              MethodHandles.Lookup lookup = MethodHandles.lookup();
              MH_INVOKE_MOVES = lookup.findVirtual(ProgrammableInvoker.class, "invokeMoves",
                      methodType(Object.class, Object[].class, Binding.Move[].class, Binding.Move[].class));
              MH_INVOKE_INTERP_BINDINGS = lookup.findVirtual(ProgrammableInvoker.class, "invokeInterpBindings",
                      methodType(Object.class, Object[].class, MethodHandle.class, Map.class, Map.class));
-             MH_UNBOX_ADDRESS = lookup.findStatic(ProgrammableInvoker.class, "toRawLongValue",
-                     methodType(long.class, MemoryAddress.class));
-             MH_BOX_ADDRESS = lookup.findStatic(ProgrammableInvoker.class, "ofLong",
-                     methodType(MemoryAddress.class, long.class));
-             MH_BASE_ADDRESS = lookup.findVirtual(MemorySegment.class, "baseAddress",
-                     methodType(MemoryAddress.class));
-             MH_COPY_BUFFER = lookup.findStatic(ProgrammableInvoker.class, "copyBuffer",
-                     methodType(MemorySegment.class, MemorySegment.class, long.class, long.class, NativeAllocationScope.class));
-             MH_MAKE_ALLOCATOR = lookup.findStatic(NativeAllocationScope.class, "boundedScope",
-                     methodType(NativeAllocationScope.class, long.class));
-             MH_CLOSE_ALLOCATOR = lookup.findVirtual(NativeAllocationScope.class, "close",
+             MH_MAKE_ALLOCATOR = lookup.findStatic(NativeScope.class, "boundedScope",
+                     methodType(NativeScope.class, long.class));
+             MH_CLOSE_ALLOCATOR = lookup.findVirtual(NativeScope.class, "close",
                      methodType(void.class));
-             MH_ALLOCATE_BUFFER = lookup.findStatic(MemorySegment.class, "allocateNative",
-                     methodType(MemorySegment.class, long.class, long.class));
          } catch (ReflectiveOperationException e) {
              throw new RuntimeException(e);
          }
      }
  

@@ -161,15 +144,15 @@
                  ? void.class
                  : retMoves.length == 1
                      ? retMoves[0].type()
                      : Object[].class;
  
-         MethodType intrinsicType = methodType(returnType, argMoveTypes);
+         MethodType leafType = methodType(returnType, argMoveTypes);
  
          MethodHandle handle = insertArguments(MH_INVOKE_MOVES.bindTo(this), 1, argMoves, retMoves)
-                                             .asCollector(Object[].class, intrinsicType.parameterCount())
-                                             .asType(intrinsicType);
+                                             .asCollector(Object[].class, leafType.parameterCount())
+                                             .asType(leafType);
  
          if (NO_SPEC || retMoves.length > 1) {
              Map<VMStorage, Integer> argIndexMap = indexMap(argMoves);
              Map<VMStorage, Integer> retIndexMap = indexMap(retMoves);
  

@@ -181,133 +164,57 @@
           }
  
          return handle;
      }
  
-     private MethodHandle specialize(MethodHandle intrinsicHandle) {
-         MethodType type = callingSequence.methodType();
-         MethodType intrinsicType = intrinsicHandle.type();
+     private MethodHandle specialize(MethodHandle leafHandle) {
+         MethodType highLevelType = callingSequence.methodType();
+         MethodType leafType = leafHandle.type();
+ 
+         MethodHandle specializedHandle = leafHandle; // initial
  
          int insertPos = -1;
          if (bufferCopySize > 0) {
-             intrinsicHandle = dropArguments(intrinsicHandle, 0, NativeAllocationScope.class);
+             specializedHandle = dropArguments(specializedHandle, 0, NativeScope.class);
              insertPos++;
          }
-         for (int i = 0; i < type.parameterCount(); i++) {
+         for (int i = 0; i < highLevelType.parameterCount(); i++) {
              List<Binding> bindings = callingSequence.argumentBindings(i);
              insertPos += bindings.stream().filter(Binding.Move.class::isInstance).count() + 1;
              // We interpret the bindings in reverse since we have to construct a MethodHandle from the bottom up
              for (int j = bindings.size() - 1; j >= 0; j--) {
                  Binding binding = bindings.get(j);
-                 switch (binding.tag()) {
-                     case MOVE -> insertPos--; // handled by fallback
-                     case DUP ->
-                         intrinsicHandle = mergeArguments(intrinsicHandle, insertPos, insertPos + 1);
-                     case CONVERT_ADDRESS ->
-                         intrinsicHandle = filterArguments(intrinsicHandle, insertPos, MH_UNBOX_ADDRESS);
-                     case BASE_ADDRESS ->
-                         intrinsicHandle = filterArguments(intrinsicHandle, insertPos, MH_BASE_ADDRESS);
-                     case DEREFERENCE -> {
-                         Binding.Dereference deref = (Binding.Dereference) binding;
-                         MethodHandle filter = filterArguments(
-                             deref.varHandle()
-                             .toMethodHandle(VarHandle.AccessMode.GET)
-                             .asType(methodType(deref.type(), MemoryAddress.class)), 0, MH_BASE_ADDRESS);
-                         intrinsicHandle = filterArguments(intrinsicHandle, insertPos, filter);
-                     }
-                     case COPY_BUFFER -> {
-                         Binding.Copy copy = (Binding.Copy) binding;
-                         MethodHandle filter = insertArguments(MH_COPY_BUFFER, 1, copy.size(), copy.alignment());
-                         intrinsicHandle = collectArguments(intrinsicHandle, insertPos, filter);
-                         intrinsicHandle = mergeArguments(intrinsicHandle, 0, insertPos + 1);
-                     }
-                     default -> throw new IllegalArgumentException("Illegal tag: " + binding.tag());
+                 if (binding.tag() == Binding.Tag.MOVE) {
+                     insertPos--;
+                 } else {
+                     specializedHandle = binding.specializeUnbox(specializedHandle, insertPos);
                  }
              }
          }
  
-         if (type.returnType() != void.class) {
-             MethodHandle returnFilter = identity(type.returnType());
+         if (highLevelType.returnType() != void.class) {
+             MethodHandle returnFilter = identity(highLevelType.returnType());
              List<Binding> bindings = callingSequence.returnBindings();
              for (int j = bindings.size() - 1; j >= 0; j--) {
                  Binding binding = bindings.get(j);
-                 switch (binding.tag()) {
-                     case MOVE -> { /* handled by fallback */ }
-                     case CONVERT_ADDRESS ->
-                         returnFilter = filterArguments(returnFilter, 0, MH_BOX_ADDRESS);
-                     case DEREFERENCE -> {
-                         Binding.Dereference deref = (Binding.Dereference) binding;
-                         MethodHandle setter = deref.varHandle().toMethodHandle(VarHandle.AccessMode.SET);
-                         setter = filterArguments(
-                             setter.asType(methodType(void.class, MemoryAddress.class, deref.type())),
-                             0, MH_BASE_ADDRESS);
-                         returnFilter = collectArguments(returnFilter, returnFilter.type().parameterCount(), setter);
-                     }
-                     case DUP ->
-                         // FIXME assumes shape like: (MS, ..., MS, T) R, is that good enough?
-                         returnFilter = mergeArguments(returnFilter, 0, returnFilter.type().parameterCount() - 2);
-                     case ALLOC_BUFFER -> {
-                         Binding.Allocate alloc = (Binding.Allocate) binding;
-                         returnFilter = collectArguments(returnFilter, 0,
-                                 insertArguments(MH_ALLOCATE_BUFFER, 0, alloc.size(), alloc.alignment()));
-                     }
-                     default ->
-                         throw new IllegalArgumentException("Illegal tag: " + binding.tag());
-                 }
+                 returnFilter = binding.specializeBox(returnFilter);
              }
- 
-             intrinsicHandle = MethodHandles.filterReturnValue(intrinsicHandle, returnFilter);
+             specializedHandle = MethodHandles.filterReturnValue(specializedHandle, returnFilter);
          }
  
          if (bufferCopySize > 0) {
-             MethodHandle closer = intrinsicType.returnType() == void.class
-                   // (Throwable, NativeAllocationScope) -> void
+             // insert try-finally to close the NativeScope used for Binding.Copy
+             MethodHandle closer = leafType.returnType() == void.class
+                   // (Throwable, NativeScope) -> void
                  ? collectArguments(empty(methodType(void.class, Throwable.class)), 1, MH_CLOSE_ALLOCATOR)
-                   // (Throwable, V, NativeAllocationScope) -> V
-                 : collectArguments(dropArguments(identity(intrinsicHandle.type().returnType()), 0, Throwable.class),
+                   // (Throwable, V, NativeScope) -> V
+                 : collectArguments(dropArguments(identity(specializedHandle.type().returnType()), 0, Throwable.class),
                                     2, MH_CLOSE_ALLOCATOR);
-             intrinsicHandle = tryFinally(intrinsicHandle, closer);
-             intrinsicHandle = collectArguments(intrinsicHandle, 0, insertArguments(MH_MAKE_ALLOCATOR, 0, bufferCopySize));
-         }
-         return intrinsicHandle;
-     }
- 
-     private static MethodHandle mergeArguments(MethodHandle mh, int sourceIndex, int destIndex) {
-         MethodType oldType = mh.type();
-         Class<?> sourceType = oldType.parameterType(sourceIndex);
-         Class<?> destType = oldType.parameterType(destIndex);
-         if (sourceType != destType) {
-             // TODO meet?
-             throw new IllegalArgumentException("Parameter types differ: " + sourceType + " != " + destType);
-         }
-         MethodType newType = oldType.dropParameterTypes(destIndex, destIndex + 1);
-         int[] reorder = new int[oldType.parameterCount()];
-         assert destIndex > sourceIndex;
-         for (int i = 0, index = 0; i < reorder.length; i++) {
-             if (i != destIndex) {
-                 reorder[i] = index++;
-             } else {
-                 reorder[i] = sourceIndex;
-             }
+             specializedHandle = tryFinally(specializedHandle, closer);
+             specializedHandle = collectArguments(specializedHandle, 0, insertArguments(MH_MAKE_ALLOCATOR, 0, bufferCopySize));
          }
-         return permuteArguments(mh, newType, reorder);
-     }
- 
-     private static MemorySegment copyBuffer(MemorySegment operand, long size, long alignment,
-                                     NativeAllocationScope allocator) {
-         assert operand.byteSize() == size : "operand size mismatch";
-         MemorySegment copy = allocator.allocate(size, alignment).segment();
-         copy.copyFrom(operand.asSlice(0, size));
-         return copy;
-     }
- 
-     private static long toRawLongValue(MemoryAddress address) {
-         return address.toRawLongValue(); // Workaround for JDK-8239083
-     }
- 
-     private static MemoryAddress ofLong(long address) {
-         return MemoryAddress.ofLong(address); // Workaround for JDK-8239083
+         return specializedHandle;
      }
  
      private Map<VMStorage, Integer> indexMap(Binding.Move[] moves) {
          return IntStream.range(0, moves.length)
                          .boxed()

@@ -383,20 +290,20 @@
      }
  
      Object invokeInterpBindings(Object[] args, MethodHandle leaf,
                                  Map<VMStorage, Integer> argIndexMap,
                                  Map<VMStorage, Integer> retIndexMap) throws Throwable {
-         List<MemorySegment> tempBuffers = new ArrayList<>();
+         NativeScope scope = bufferCopySize != 0 ? NativeScope.boundedScope(bufferCopySize) : null;
          try {
              // do argument processing, get Object[] as result
              Object[] moves = new Object[leaf.type().parameterCount()];
              for (int i = 0; i < args.length; i++) {
                  Object arg = args[i];
                  BindingInterpreter.unbox(arg, callingSequence.argumentBindings(i),
                          (storage, type, value) -> {
                              moves[argIndexMap.get(storage)] = value;
-                         }, tempBuffers);
+                         }, scope);
              }
  
              // call leaf
              Object o = leaf.invokeWithArguments(moves);
  

@@ -409,11 +316,13 @@
                          (storage, type) -> oArr[retIndexMap.get(storage)]);
              } else {
                  return BindingInterpreter.box(callingSequence.returnBindings(), (storage, type) -> o);
              }
          } finally {
-             tempBuffers.forEach(MemorySegment::close);
+             if (scope != null) {
+                 scope.close();
+             }
          }
      }
  
      //natives
  
< prev index next >