共計(jì) 935 個(gè)字符,預(yù)計(jì)需要花費(fèi) 3 分鐘才能閱讀完成。
要在 Java 中調(diào)用 PyTorch 模型,可以使用 PyTorch 的 Java API,也就是 TorchScript。TorchScript 是 PyTorch 的靜態(tài)圖編譯器,它允許將 PyTorch 模型編譯為一種可序列化和可導(dǎo)入的中間表示形式。然后可以在 Java 中加載并運(yùn)行這個(gè)中間表示形式。
以下是一個(gè)簡(jiǎn)單的示例代碼,展示了如何在 Java 中加載并調(diào)用一個(gè) PyTorch 模型:
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
public class PyTorchModelExample {public static void main(String[] args) {try {// 加載 PyTorch 模型
Module module = Module.load("model.pt");
// 創(chuàng)建輸入 Tensor
float[] inputData = {1.0f, 2.0f, 3.0f, 4.0f};
Tensor inputTensor = Tensor.fromBlob(inputData, new long[]{1, 4});
// 運(yùn)行模型
IValue output = module.forward(IValue.from(inputTensor));
// 獲取輸出 Tensor
Tensor outputTensor = output.toTensor();
// 打印輸出
float[] outputData = outputTensor.getDataAsFloatArray();
for (float value : outputData) {System.out.println(value);
}
} catch (Exception e) {e.printStackTrace();
}
}
}
在這個(gè)示例中,我們加載了一個(gè)名為 model.pt 的 PyTorch 模型,并使用輸入數(shù)據(jù) {1.0, 2.0, 3.0, 4.0} 來(lái)運(yùn)行模型。最后,我們獲取輸出 Tensor 并打印出來(lái)。
請(qǐng)注意,為了在 Java 中使用 PyTorch 的 Java API,你需要在項(xiàng)目中添加 PyTorch 的 Java 庫(kù)依賴。你可以從 PyTorch 官方網(wǎng)站下載并添加到你的項(xiàng)目中。
丸趣 TV 網(wǎng) – 提供最優(yōu)質(zhì)的資源集合!
正文完