跨语言编程#

本页将向您展示如何使用 Ray 的跨语言编程功能。

设置驱动程序#

我们需要在驱动程序中设置代码搜索路径(code search path)。下面依次给出 Python 驱动程序和 Java 驱动程序的设置方式。

import ray

# Start Ray with the directory that holds the user code on the code search path.
ray.init(
    job_config=ray.job_config.JobConfig(
        code_search_path=["/path/to/code"],
    ),
)
# Java driver: the ray.job.code-search-path system property carries the code search path.
java -classpath <classpath> \
    -Dray.address=<address> \
    -Dray.job.code-search-path=/path/to/code/ \
    <classname> <args>

如果您希望 worker 同时加载 Python 和 Java 代码,而它们位于不同的目录中,则可能需要指定多个目录。

import ray

# code_search_path is a list with one entry per directory or jar file. The
# colon-joined form is what the Java `ray.job.code-search-path` system
# property uses; a plain string here would be iterated character by character.
ray.init(
    job_config=ray.job_config.JobConfig(
        code_search_path=["/path/to/jars", "/path/to/pys"],
    ),
)
# Java driver: several code search paths are joined with ':' inside the system property.
java -classpath <classpath> \
    -Dray.address=<address> \
    -Dray.job.code-search-path=/path/to/jars:/path/to/pys \
    <classname> <args>

Python 调用 Java#

假设我们有一个Java静态方法和一个Java类如下:

package io.ray.demo;

/** Holder for a static method that Python can invoke as a remote function. */
public class Math {

  /**
   * Adds two integers.
   *
   * @param a the first addend
   * @param b the second addend
   * @return the sum of {@code a} and {@code b}
   */
  public static int add(int a, int b) {
    return Integer.sum(a, b);
  }
}
package io.ray.demo;

/** A plain Java class that keeps a running count; Python can create an actor from it. */
public class Counter {

  /** Current count; an int field starts at zero by default. */
  private int value;

  /**
   * Advances the counter by one.
   *
   * @return the count after the increment
   */
  public int increment() {
    return ++value;
  }
}

然后,在Python中,我们可以调用上述Java远程函数,或者从上述Java类创建一个actor。

import ray

# Start Ray with the directory that holds the Java code on the code search path.
with ray.init(
    job_config=ray.job_config.JobConfig(code_search_path=["/path/to/code"])
):
    # Look up the Java class by its fully qualified name.
    counter_class = ray.cross_language.java_actor_class("io.ray.demo.Counter")

    # Instantiate it as an actor; each increment() call returns the new count.
    counter = counter_class.remote()
    obj_ref1 = counter.increment.remote()
    assert ray.get(obj_ref1) == 1
    obj_ref2 = counter.increment.remote()
    assert ray.get(obj_ref2) == 2

    # Look up the static Java method by class name and method name.
    add_function = ray.cross_language.java_function("io.ray.demo.Math", "add")

    # Invoke the Java method as a remote function.
    obj_ref3 = add_function.remote(1, 2)
    assert ray.get(obj_ref3) == 3

Java 调用 Python#

假设我们有一个Python模块如下:

# /path/to/the_dir/ray_demo.py

import ray

@ray.remote
class Counter:
    """Counter whose value starts at 0 and grows by one per increment() call."""

    def __init__(self):
        self.value = 0

    def increment(self):
        """Add one to the counter and return the updated value."""
        self.value += 1
        return self.value

@ray.remote
def add(a, b):
    """Return ``a + b``."""
    return a + b

备注

  • 函数或类应由 @ray.remote 装饰。

然后,在Java中,我们可以调用上述Python远程函数,或者从上述Python类创建一个actor。

package io.ray.demo;

import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.Ray;
import io.ray.api.function.PyActorClass;
import io.ray.api.function.PyActorMethod;
import io.ray.api.function.PyFunction;
import org.testng.Assert;

/**
 * Java driver that creates an actor from the Python class {@code Counter} and calls the
 * Python function {@code add}, both looked up in the {@code ray_demo} module.
 */
public class JavaCallPythonDemo {

  /**
   * Initializes Ray, runs the actor and function calls, checks their results and shuts
   * Ray down.
   *
   * @param args command-line arguments (unused)
   */
  public static void main(String[] args) {
    // Set the code-search-path to the directory of your `ray_demo.py` file.
    System.setProperty("ray.job.code-search-path", "/path/to/the_dir/");
    Ray.init();

    // Define a Python class.
    // Arguments are the Python module name and the class name inside it.
    PyActorClass actorClass = PyActorClass.of(
        "ray_demo", "Counter");

    // Create a Python actor and call actor method.
    // NOTE(review): int.class presumably declares the expected return type — confirm
    // against the Ray Java API.
    PyActorHandle actor = Ray.actor(actorClass).remote();
    ObjectRef objRef1 = actor.task(
        PyActorMethod.of("increment", int.class)).remote();
    Assert.assertEquals(objRef1.get(), 1);
    // The second call goes to the same actor, so the count continues from 1.
    ObjectRef objRef2 = actor.task(
        PyActorMethod.of("increment", int.class)).remote();
    Assert.assertEquals(objRef2.get(), 2);

    // Call the Python remote function.
    ObjectRef objRef3 = Ray.task(PyFunction.of(
        "ray_demo", "add", int.class), 1, 2).remote();
    Assert.assertEquals(objRef3.get(), 3);

    Ray.shutdown();
  }
}

跨语言数据序列化#

如果 ray 调用的参数和返回值的类型是以下类型,它们可以被自动序列化和反序列化:

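  • 原始数据类型

    MessagePack | Python | Java
    ------------+--------+--------------------------------------
    nil         | None   | null
    bool        | bool   | Boolean
    int         | int    | Short / Integer / Long / BigInteger
    float       | float  | Float / Double
    str         | str    | String
    bin         | bytes  | byte[]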

  • 基本容器类型

    MessagePack | Python | Java
    ------------+--------+-----------
    array       | list   | [](数组)

  • Ray 内置类型
    • ActorHandle

备注

  • 注意 Python 和 Java 之间 float / double 的精度差异。如果 Java 使用 float 类型来接收输入参数,来自 Python 的双精度(double)数据在 Java 中将被降低为 float 精度。

  • BigInteger 可以支持的最大值为 2^64-1,请参考:msgpack/msgpack。如果值大于 2^64-1,则将其发送到 Python 时将引发异常。

以下示例展示了如何将这些类型作为参数传递,以及如何返回这些类型。

你可以编写一个返回输入数据的Python函数:

# ray_serialization.py

import ray

@ray.remote
def py_return_input(v):
    """Return the argument ``v`` unchanged."""
    return v

然后你可以将对象从Java传输到Python,然后再从Python传回Java:

package io.ray.demo;

import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.function.PyFunction;
import java.math.BigInteger;
import org.testng.Assert;

/**
 * Sends one value of each listed Java type to the Python function
 * {@code py_return_input} and checks that the value coming back equals the one sent.
 */
public class SerializationDemo {

  /**
   * Initializes Ray, round-trips every sample value through Python and shuts Ray down.
   *
   * @param args command-line arguments (unused)
   */
  public static void main(String[] args) {
    Ray.init();

    // One sample per supported type; the trailing comment names the boxed Java type.
    Object[] inputs = new Object[]{
        true,  // Boolean
        Byte.MAX_VALUE,  // Byte
        Short.MAX_VALUE,  // Short
        Integer.MAX_VALUE,  // Integer
        Long.MAX_VALUE,  // Long
        BigInteger.valueOf(Long.MAX_VALUE),  // BigInteger
        "Hello World!",  // String
        1.234f,  // Float
        1.234,  // Double
        "example binary".getBytes()};  // byte[]
    for (Object o : inputs) {
      // The input's own runtime class is passed as the third argument of PyFunction.of.
      // NOTE(review): presumably this is the expected return type — confirm against the
      // Ray Java API.
      ObjectRef res = Ray.task(
          PyFunction.of("ray_serialization", "py_return_input", o.getClass()),
          o).remote();
      Assert.assertEquals(res.get(), o);
    }

    Ray.shutdown();
  }
}

跨语言异常堆栈#

假设我们有一个Java包如下:

package io.ray.demo;

import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.function.PyFunction;

/** Java class whose static method forwards to a Python remote function. */
public class MyRayClass {

  /**
   * Runs the Python function {@code raise_exception} from module {@code ray_exception}
   * as a remote task and waits for its result.
   *
   * @return the task's {@code Integer} result, unboxed to {@code int}
   */
  public static int raiseExceptionFromPython() {
    PyFunction<Integer> pyRaise =
        PyFunction.of("ray_exception", "raise_exception", Integer.class);
    ObjectRef<Integer> resultRef = Ray.task(pyRaise).remote();
    return resultRef.get();
  }
}

以及一个Python模块如下:

# ray_exception.py

import ray

@ray.remote
def raise_exception():  # Deliberately fails: the body divides by zero.
    1 / 0

然后,运行以下代码:

# ray_exception_demo.py

import ray

with ray.init(job_config=ray.job_config.JobConfig(code_search_path=["/path/to/ray_exception"])):
  obj_ref = ray.cross_language.java_function(  # Java static method, invoked as a remote function
        "io.ray.demo.MyRayClass",
        "raiseExceptionFromPython").remote()
  ray.get(obj_ref)  # <-- raise exception from here.

异常堆栈将是:

Traceback (most recent call last):
  File "ray_exception_demo.py", line 9, in <module>
    ray.get(obj_ref)  # <-- raise exception from here.
  File "ray/python/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "ray/python/ray/_private/worker.py", line 2247, in get
    raise value
ray.exceptions.CrossLanguageError: An exception raised from JAVA:
io.ray.api.exception.RayTaskException: (pid=61894, ip=172.17.0.2) Error executing task c8ef45ccd0112571ffffffffffffffffffffffff01000000
        at io.ray.runtime.task.TaskExecutor.execute(TaskExecutor.java:186)
        at io.ray.runtime.RayNativeRuntime.nativeRunTaskExecutor(Native Method)
        at io.ray.runtime.RayNativeRuntime.run(RayNativeRuntime.java:231)
        at io.ray.runtime.runner.worker.DefaultWorker.main(DefaultWorker.java:15)
Caused by: io.ray.api.exception.CrossLanguageException: An exception raised from PYTHON:
ray.exceptions.RayTaskError: ray::raise_exception() (pid=62041, ip=172.17.0.2)
  File "ray_exception.py", line 7, in raise_exception
    1 / 0
ZeroDivisionError: division by zero