This commit is contained in:
2026-04-27 18:41:34 -05:00
commit 7496fd3781
25 changed files with 1294 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
package org.pvpbot.goated.client;
import net.fabricmc.api.ClientModInitializer;
/**
 * {@link ClientModInitializer} implementation for the mod.
 *
 * <p>Currently a placeholder: no client-only setup is performed.
 */
public class GoatedClient implements ClientModInitializer {

    /** Intentionally empty; there is nothing to set up on the client yet. */
    @Override
    public void onInitializeClient() {
        // no-op
    }
}

View File

@@ -0,0 +1,12 @@
package org.pvpbot.goated.client;
import net.fabricmc.fabric.api.datagen.v1.DataGeneratorEntrypoint;
import net.fabricmc.fabric.api.datagen.v1.FabricDataGenerator;
/**
 * {@link DataGeneratorEntrypoint} implementation for the mod.
 *
 * <p>Only asks the generator to create a pack; no providers are added to it.
 */
public class GoatedDataGenerator implements DataGeneratorEntrypoint {

    /**
     * Creates the generator pack.
     *
     * @param fabricDataGenerator generator handed in by the data-generation run
     */
    @Override
    public void onInitializeDataGenerator(FabricDataGenerator fabricDataGenerator) {
        // The returned pack is not used yet, so it is not kept in a variable.
        fabricDataGenerator.createPack();
    }
}

View File

@@ -0,0 +1,14 @@
{
"required": true,
"minVersion": "0.8",
"package": "org.pvpbot.goated.mixin.client",
"compatibilityLevel": "JAVA_21",
"client": [
],
"injectors": {
"defaultRequire": 1
},
"overwrites": {
"requireAnnotations": true
}
}

View File

@@ -0,0 +1,98 @@
package org.pvpbot.goated;
import com.mojang.brigadier.exceptions.CommandSyntaxException;
import net.fabricmc.api.ModInitializer;
import net.fabricmc.fabric.api.command.v2.CommandRegistrationCallback;
import net.fabricmc.fabric.api.event.lifecycle.v1.ServerTickEvents;
import net.fabricmc.loader.api.FabricLoader;
import net.minecraft.server.MinecraftServer;
import net.minecraft.server.command.CommandManager;
import net.minecraft.server.command.ServerCommandSource;
import net.minecraft.server.network.ServerPlayerEntity;
import net.minecraft.text.Text;
import net.minecraft.util.math.Vec3d;
import org.pvpbot.goated.ai.BotController;
import org.pvpbot.goated.ai.BotRegistry;
/**
 * Main mod entry point: registers the {@code /spawnbot} command and offers static helpers
 * that heal and equip a player.
 *
 * <p>The bot is created by delegating to the {@code player <name> spawn} command, so the
 * command is only registered when a mod with id {@code carpet} is loaded.
 */
public class Goated implements ModInitializer {
    private static final String BOT_NAME = "PvPBOT";

    /** Slot / item pairs applied, in this order, by {@link #equipPlayer}. */
    private static final String[][] LOADOUT = {
            {"armor.head", "minecraft:diamond_helmet"},
            {"armor.chest", "minecraft:diamond_chestplate"},
            {"armor.legs", "minecraft:diamond_leggings"},
            {"armor.feet", "minecraft:diamond_boots"},
            {"weapon.mainhand", "minecraft:diamond_sword"},
    };

    /** Bot entity found after the most recent {@code /spawnbot}; {@code null} until then. */
    private ServerPlayerEntity bot;

    /**
     * Registers {@code /spawnbot} (permission level 2) when the {@code carpet} mod is present.
     */
    @Override
    public void onInitialize() {
        if (!FabricLoader.getInstance().isModLoaded("carpet")) {
            // Say why the command is missing instead of disabling the feature silently.
            System.err.println("[goated] Mod 'carpet' is not loaded; /spawnbot will not be registered.");
            return;
        }

        // ===== SPAWN BOT COMMAND =====
        CommandRegistrationCallback.EVENT.register((dispatcher, registryAccess, environment) ->
                dispatcher.register(
                        CommandManager.literal("spawnbot")
                                .requires(source -> source.hasPermissionLevel(2))
                                .executes(context -> {
                                    ServerCommandSource source = context.getSource();
                                    if (!spawnBot(source)) {
                                        // The error was already reported to the source.
                                        return 0;
                                    }
                                    source.sendFeedback(
                                            () -> Text.literal("Spawning " + BOT_NAME + "..."),
                                            false
                                    );
                                    return 1;
                                })
                )
        );
    }

    // ===================== BOT SPAWN =====================

    /**
     * Runs {@code player PvPBOT spawn} as {@code source} and schedules a task on the server
     * that looks the bot up and resets it.
     *
     * @param source command source used both to run the command and to report failures
     * @return {@code true} if the spawn command was dispatched without a syntax error
     */
    private boolean spawnBot(ServerCommandSource source) {
        try {
            source.getServer().getCommandManager().getDispatcher().execute(
                    "player " + BOT_NAME + " spawn",
                    source
            );
        } catch (CommandSyntaxException e) {
            source.sendError(Text.literal("Failed to spawn bot: " + e.getMessage()));
            return false;
        }

        source.getServer().execute(() -> {
            bot = source.getServer().getPlayerManager().getPlayer(BOT_NAME);
            if (bot == null) return;
            resetPlayer(source.getServer(), bot);
        });
        return true;
    }

    // ===================== EQUIP BOT =====================

    /**
     * Gives {@code player} the loadout in {@link #LOADOUT} by running one
     * {@code item replace entity ...} command per slot as the server command source.
     *
     * <p>A {@link CommandSyntaxException} aborts the remaining slots and is reported on
     * standard error.
     *
     * @param server server whose dispatcher and command source are used
     * @param player player to equip; must not be {@code null}
     */
    public static void equipPlayer(MinecraftServer server, ServerPlayerEntity player) {
        String playerName = player.getName().getString();
        try {
            var dispatcher = server.getCommandManager().getDispatcher();
            var source = server.getCommandSource();
            for (String[] entry : LOADOUT) {
                dispatcher.execute(
                        "item replace entity " + playerName + " " + entry[0] + " with " + entry[1],
                        source
                );
            }
        } catch (CommandSyntaxException e) {
            System.err.println("Failed to equip player " + playerName + ": " + e.getMessage());
        }
    }

    /**
     * Restores health to the maximum, food and saturation to 20, and re-equips the player.
     *
     * @param server server used for the equip commands
     * @param player player to reset; {@code null} is ignored
     */
    public static void resetPlayer(MinecraftServer server, ServerPlayerEntity player) {
        if (player == null) return;
        player.setHealth(player.getMaxHealth());
        player.getHungerManager().setFoodLevel(20);
        player.getHungerManager().setSaturationLevel(20f);
        equipPlayer(server, player);
    }
}

View File

@@ -0,0 +1,31 @@
package org.pvpbot.goated.ai;
/**
 * Mutable holder for a set of action values (look angles, movement axes and action flags).
 *
 * <p>Every field is {@code volatile}, so a value written by one thread is visible to readers
 * on other threads; the fields are not updated as a group.
 */
public class BotAIState {
    // Latest values from the Python AI (updated async)
    public volatile float yaw;
    public volatile float pitch;
    public volatile float moveForward;
    public volatile float moveStrafe;
    public volatile boolean jump;
    public volatile boolean sprint;
    public volatile boolean crouch;
    public volatile boolean swing;

    // Optional debug value; reset() leaves it untouched.
    public volatile long lastUpdateTime;

    /** Puts every action field back to its neutral value (zero / {@code false}). */
    public void reset() {
        this.yaw = 0f;
        this.pitch = 0f;
        this.moveForward = 0f;
        this.moveStrafe = 0f;
        this.jump = false;
        this.sprint = false;
        this.crouch = false;
        this.swing = false;
    }
}

View File

@@ -0,0 +1,204 @@
package org.pvpbot.goated.ai;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import net.minecraft.server.MinecraftServer;
import net.minecraft.server.network.ServerPlayerEntity;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Drives one bot against one target by exchanging state, actions and rewards with the
 * external AI server through {@link BotHttpClient}.
 *
 * <p>At most one prediction request is outstanding at a time. The HTTP response is handed
 * back to the server via {@code server.execute(...)} before any entity is read or any input
 * field is written, so that work no longer runs on the HTTP client's callback thread.
 */
public class BotBrain {
    private final ServerPlayerEntity bot;
    private ServerPlayerEntity target;
    private final BotInputState aiState = new BotInputState();

    /** {@code true} from the moment a prediction is requested until its response is handled. */
    private final AtomicBoolean requestInFlight = new AtomicBoolean(false);

    /** Health values seen by the previous reward update; used to derive damage deltas. */
    private float lastBotHealth = 20f;
    private float lastTargetHealth = 20f;

    public BotBrain(ServerPlayerEntity bot, ServerPlayerEntity target) {
        this.bot = bot;
        this.target = target;
    }

    /** Returns the input state that the latest AI response was written into. */
    public BotInputState getAIState() {
        return aiState;
    }

    public ServerPlayerEntity getTarget() {
        return target;
    }

    public void setTarget(ServerPlayerEntity target) {
        this.target = target;
    }

    // ===================== MAIN TICK =====================

    /**
     * Sends the current bot/target state to the AI server unless a request is still pending.
     *
     * <p>When the response arrives it is parsed into {@link #getAIState()} and a reward update
     * is sent, both from a task queued on {@code server}.
     *
     * @param server server used to hop back from the HTTP callback thread
     */
    public void tick(MinecraftServer server) {
        // Prevent overlapping requests if one takes longer than a tick
        if (!requestInFlight.compareAndSet(false, true)) {
            return;
        }

        try {
            String json = BotStateSerializer.toJson(bot, target);

            BotHttpClient.sendState(json)
                    .thenAccept(response -> server.execute(() -> {
                        try {
                            System.out.println("[DEBUG] BotBrain received response: " + response);
                            parseAndApply(response, aiState);
                            sendRewardUpdate();
                        } catch (RuntimeException e) {
                            // A malformed response must not take the scheduled task down with it.
                            e.printStackTrace();
                        } finally {
                            requestInFlight.set(false);
                        }
                    }))
                    .exceptionally(err -> {
                        requestInFlight.set(false);
                        err.printStackTrace();
                        return null;
                    });
        } catch (RuntimeException e) {
            // Without this, a failure while building the request would block every later tick.
            requestInFlight.set(false);
            throw e;
        }
    }

    /**
     * Computes the reward for the step that just finished and posts it to the AI server.
     *
     * <p>The reward combines cooldown-weighted damage dealt, damage taken, a distance term
     * that peaks at 3 blocks, a look-at term, and +/-100 for a kill or a death.
     */
    private void sendRewardUpdate() {
        final ServerPlayerEntity opponent = target;
        if (opponent == null) {
            return;
        }

        float currentBotHealth = bot.getHealth();
        float currentTargetHealth = opponent.getHealth();
        float attackCooldown = bot.getAttackCooldownProgress(0.5f);

        float damageDealt = Math.max(0, lastTargetHealth - currentTargetHealth);
        float damageTaken = Math.max(0, lastBotHealth - currentBotHealth);

        double dx = opponent.getX() - bot.getX();
        double dy = opponent.getY() - bot.getY();
        double dz = opponent.getZ() - bot.getZ();
        double distance = Math.sqrt(dx * dx + dy * dy + dz * dz);

        double lookReward = computeLookReward(opponent);

        boolean done = bot.isDead() || opponent.isDead() || bot.isRemoved() || opponent.isRemoved();

        double totalReward = 0.0;

        // Cooldown-weighted damage reward: a cooldown of 1.0 is a full-strength hit.
        if (damageDealt > 0) {
            if (attackCooldown < 0.85) {
                // Penalize weak hits (spamming)
                totalReward -= 10.0 * (1.0 - attackCooldown);
            } else {
                totalReward += damageDealt * 10.0 * attackCooldown;
                totalReward += 5.0 * attackCooldown; // Bonus for landing a well-timed hit
            }
        }

        totalReward -= damageTaken * 5.0;

        // Bell-shaped distance reward: exp(-0.5 * ((dist - 3.0) / 1.5)^2) * 2.0, peak at 3.0
        double distReward = Math.exp(-0.5 * Math.pow((distance - 3.0) / 1.5, 2)) * 2.0;
        totalReward += distReward;
        totalReward += lookReward;

        if (currentTargetHealth <= 0 || opponent.isDead()) totalReward += 100.0;
        if (currentBotHealth <= 0 || bot.isDead()) totalReward -= 100.0;

        System.out.println("[DEBUG] Calculated Reward: " + totalReward + " (Dealt: " + damageDealt + ", Taken: " + damageTaken + ", Dist: " + distance + ", LookRew: " + lookReward + ", DistRew: " + distReward + ")");

        // Locale.ROOT keeps '.' as the decimal separator; a default locale that uses ','
        // would make the payload invalid JSON.
        String rewardJson = String.format(java.util.Locale.ROOT, """
                {
                "bot_health": %.1f,
                "target_health": %.1f,
                "damage_dealt": %.1f,
                "damage_taken": %.1f,
                "distance": %.3f,
                "attack_cooldown": %.3f,
                "hit_success": %b,
                "done": %b,
                "total_reward": %.4f
                }
                """,
                currentBotHealth,
                currentTargetHealth,
                damageDealt,
                damageTaken,
                distance,
                attackCooldown,
                damageDealt > 0,
                done,
                totalReward
        );

        BotHttpClient.sendReward(rewardJson)
                .thenAccept(resp -> System.out.println("[DEBUG] Reward sent. Server response: " + resp))
                .exceptionally(err -> {
                    err.printStackTrace();
                    return null;
                });

        // After an episode ends, restart the trackers from full health so the following
        // heal/respawn is not read as a damage spike. Only the trackers are reset; the values
        // reported above are the real ones.
        lastBotHealth = done ? 20f : currentBotHealth;
        lastTargetHealth = done ? 20f : currentTargetHealth;
    }

    /**
     * Scores how closely the bot's yaw/pitch point at the opponent's eyes.
     *
     * @return a value in {@code [0, 2]}; {@code 0} when both eye positions coincide and no
     *         direction can be derived
     */
    private double computeLookReward(ServerPlayerEntity opponent) {
        double tx = opponent.getX() - bot.getX();
        double ty = opponent.getEyeY() - bot.getEyeY();
        double tz = opponent.getZ() - bot.getZ();
        double tDist = Math.sqrt(tx * tx + ty * ty + tz * tz);
        if (tDist < 1.0e-6) {
            // Dividing by ~0 would yield NaN and poison the total reward.
            return 0.0;
        }

        double targetYaw = Math.toDegrees(Math.atan2(-tx, tz));
        // Clamp against rounding so asin never sees a value just outside [-1, 1].
        double sinPitch = Math.max(-1.0, Math.min(1.0, -ty / tDist));
        double targetPitch = Math.toDegrees(Math.asin(sinPitch));

        double yawDiff = Math.abs(normalizeDegrees(targetYaw - bot.getYaw()));
        double pitchDiff = Math.abs(normalizeDegrees(targetPitch - bot.getPitch()));

        return Math.max(0, (1.0 - (yawDiff + pitchDiff) / 180.0) * 2.0);
    }

    // ===================== SAFE JSON PARSER =====================

    /**
     * Copies the action fields of {@code json} into {@code state}; absent keys fall back to
     * {@code 0} / {@code false}.
     */
    private void parseAndApply(String json, BotInputState state) {
        JsonObject obj = JsonParser.parseString(json).getAsJsonObject();

        state.yaw = obj.has("yaw") ? obj.get("yaw").getAsFloat() : 0;
        state.pitch = obj.has("pitch") ? obj.get("pitch").getAsFloat() : 0;

        state.moveForward = obj.has("move_forward") ? obj.get("move_forward").getAsFloat() : 0;
        state.moveStrafe = obj.has("move_strafe") ? obj.get("move_strafe").getAsFloat() : 0;

        state.jump = obj.has("jump") && obj.get("jump").getAsBoolean();
        state.sprint = obj.has("sprint") && obj.get("sprint").getAsBoolean();
        state.crouch = obj.has("crouch") && obj.get("crouch").getAsBoolean();
        state.swing = obj.has("swing") && obj.get("swing").getAsBoolean();
    }

    // ===================== CHAT HELPER =====================

    /** Runs {@code say <message>} with the bot as command source. */
    public static void botSay(MinecraftServer server, ServerPlayerEntity bot, String message) {
        server.getCommandManager().executeWithPrefix(
                bot.getCommandSource(),
                "say " + message
        );
    }

    /** Wraps an angle in degrees into the range {@code [-180, 180)}. */
    private double normalizeDegrees(double degrees) {
        double result = degrees % 360.0;
        if (result >= 180.0) result -= 360.0;
        if (result < -180.0) result += 360.0;
        return result;
    }
}

View File

@@ -0,0 +1,54 @@
package org.pvpbot.goated.ai;
import net.minecraft.server.network.ServerPlayerEntity;
/**
 * Per-bot bundle of the bot entity, its current input state, its optional brain and an
 * optional owner name.
 */
public class BotController {
    private final ServerPlayerEntity bot;
    private final BotInputState input = new BotInputState();
    private BotBrain brain;
    private String ownerName;

    /**
     * @param bot entity this controller steers
     */
    public BotController(ServerPlayerEntity bot) {
        this.bot = bot;
    }

    public ServerPlayerEntity getBot() {
        return this.bot;
    }

    /** Returns the live input object (not a copy). */
    public BotInputState input() {
        return this.input;
    }

    public BotBrain getBrain() {
        return this.brain;
    }

    public void setBrain(BotBrain brain) {
        this.brain = brain;
    }

    public String getOwnerName() {
        return this.ownerName;
    }

    public void setOwnerName(String ownerName) {
        this.ownerName = ownerName;
    }

    /**
     * Copies every field of {@code newInput} into this controller's own input object.
     *
     * @param newInput source of the values; it is only read
     */
    public void setInput(BotInputState newInput) {
        final BotInputState dst = this.input;
        dst.yaw = newInput.yaw;
        dst.pitch = newInput.pitch;
        dst.moveForward = newInput.moveForward;
        dst.moveStrafe = newInput.moveStrafe;
        dst.jump = newInput.jump;
        dst.sprint = newInput.sprint;
        dst.crouch = newInput.crouch;
        dst.swing = newInput.swing;
    }
}

View File

@@ -0,0 +1,35 @@
package org.pvpbot.goated.ai;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
/**
 * Minimal asynchronous JSON-over-HTTP client for the local AI server.
 *
 * <p>Both calls are bounded by a connect timeout and a per-request timeout, so a server that
 * hangs completes the returned future exceptionally instead of leaving it pending forever.
 */
public class BotHttpClient {
    private static final String BASE_URL = "http://127.0.0.1:5000";
    private static final Duration CONNECT_TIMEOUT = Duration.ofSeconds(2);
    private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(10);

    private static final HttpClient CLIENT = HttpClient.newBuilder()
            .connectTimeout(CONNECT_TIMEOUT)
            .build();

    /**
     * POSTs the bot state to {@code /predict}.
     *
     * @param json request body
     * @return future completing with the response body, or exceptionally on I/O error/timeout
     */
    public static CompletableFuture<String> sendState(String json) {
        return post("/predict", json);
    }

    /**
     * POSTs a reward record to {@code /reward}.
     *
     * @param json request body
     * @return future completing with the response body, or exceptionally on I/O error/timeout
     */
    public static CompletableFuture<String> sendReward(String json) {
        return post("/reward", json);
    }

    /** Sends {@code json} to {@code BASE_URL + path} and maps the response to its body. */
    private static CompletableFuture<String> post(String path, String json) {
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create(BASE_URL + path))
                .timeout(REQUEST_TIMEOUT)
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(json))
                .build();

        return CLIENT.sendAsync(request, HttpResponse.BodyHandlers.ofString())
                .thenApply(HttpResponse::body);
    }
}

View File

@@ -0,0 +1,148 @@
package org.pvpbot.goated.ai;
import net.minecraft.command.argument.EntityAnchorArgumentType;
import net.minecraft.server.MinecraftServer;
import net.minecraft.server.network.ServerPlayerEntity;
import net.minecraft.util.math.Vec3d;
/**
 * Turns a {@link BotInputState} into actions on the bot: look direction and a velocity nudge
 * are applied directly, while sprint, sneak, jump and attack are issued through the
 * {@code /player <name> ...} command.
 *
 * <p>The sprint/crouch edge detection uses static flags, i.e. it is shared by every bot that
 * goes through this class.
 */
public class BotInputApplier {
    private static boolean lastSprintState = false;
    private static boolean lastCrouchState = false;

    /**
     * Applies the controller's current input for one tick and then clears its one-shot flags.
     *
     * @param server     server used to run the {@code /player} commands
     * @param controller source of the bot entity and its input
     */
    public static void apply(MinecraftServer server, BotController controller) {
        ServerPlayerEntity bot = controller.getBot();
        BotInputState input = controller.input();
        String name = bot.getName().getString();

        // ================= LOOK =================
        applyLook(bot, input);

        // ================= DEADZONE INPUT =================
        float forward = applyDeadzone(input.moveForward);
        float strafe = applyDeadzone(input.moveStrafe);

        // ================= SPRINT (NO BACKWARD SPRINT) =================
        boolean shouldSprint = input.sprint && forward >= 0;
        if (shouldSprint != lastSprintState) {
            lastSprintState = shouldSprint;
            // Carpet's /player has separate on/off actions; re-sending "sprint" never stops it.
            runPlayerAction(server, name, shouldSprint ? "sprint" : "unsprint");
        }

        // ================= CROUCH (EDGE TRIGGER) =================
        if (input.crouch != lastCrouchState) {
            lastCrouchState = input.crouch;
            // Carpet's /player names this action sneak/unsneak; it has no "crouch" action.
            runPlayerAction(server, name, input.crouch ? "sneak" : "unsneak");
        }

        // ================= JUMP =================
        if (input.jump && bot.isOnGround()) {
            runPlayerAction(server, name, "jump");
        }

        // ================= MOVEMENT =================
        applyMovement(bot, forward, strafe);

        // ================= ATTACK =================
        if (input.swing) {
            runPlayerAction(server, name, "attack");
            bot.swingHand(net.minecraft.util.Hand.MAIN_HAND);
        }

        input.clearTransient();
    }

    /** Runs {@code player <name> <action>} as the server command source. */
    private static void runPlayerAction(MinecraftServer server, String name, String action) {
        server.getCommandManager().executeWithPrefix(
                server.getCommandSource(),
                "player " + name + " " + action
        );
    }

    // ================= MOVEMENT =================

    /**
     * Adds a fixed-size horizontal push in the requested direction to the bot's velocity
     * (0.08 on the ground, 0.03 in the air); the vertical component is left as is.
     */
    private static void applyMovement(ServerPlayerEntity bot, float forwardInput, float strafeInput) {
        if (forwardInput == 0 && strafeInput == 0) return;

        Vec3d forward = getForwardVector(bot.getYaw());
        Vec3d right = new Vec3d(-forward.z, 0, forward.x);

        Vec3d moveDir = Vec3d.ZERO;
        // Inputs are already -1/0/+1 after the deadzone, so no extra scaling is applied.
        if (forwardInput != 0) {
            moveDir = moveDir.add(forward.multiply(forwardInput));
        }
        if (strafeInput != 0) {
            moveDir = moveDir.add(right.multiply(strafeInput));
        }
        if (moveDir.lengthSquared() == 0) return;
        moveDir = moveDir.normalize();

        double steerStrength = bot.isOnGround() ? 0.08 : 0.03;
        Vec3d steering = moveDir.multiply(steerStrength);

        Vec3d velocity = bot.getVelocity();
        Vec3d newVel = new Vec3d(
                velocity.x + steering.x,
                velocity.y,
                velocity.z + steering.z
        );
        bot.setVelocity(newVel);
        bot.velocityModified = true;
    }

    // ================= LOOK =================

    /** Points the bot's eyes at a spot 6 blocks away along the requested yaw/pitch. */
    private static void applyLook(ServerPlayerEntity bot, BotInputState input) {
        Vec3d direction = getDirectionVector(input.yaw, input.pitch);
        Vec3d eyePos = bot.getEyePos();
        Vec3d target = eyePos.add(direction.multiply(6));
        bot.lookAt(EntityAnchorArgumentType.EntityAnchor.EYES, target);
    }

    // ================= DEADZONE =================

    /** Quantises an analog axis to -1, 0 or +1 using a +/-0.25 deadzone. */
    private static float applyDeadzone(float v) {
        if (v > 0.25f) return 1f;
        if (v < -0.25f) return -1f;
        return 0f;
    }

    // ================= HELPERS =================

    /** Horizontal unit vector for {@code yaw} in degrees: {@code (-sin, 0, cos)}. */
    private static Vec3d getForwardVector(float yaw) {
        double rad = Math.toRadians(yaw);
        return new Vec3d(-Math.sin(rad), 0, Math.cos(rad)).normalize();
    }

    /** Unit look vector for {@code yaw}/{@code pitch} in degrees. */
    private static Vec3d getDirectionVector(float yaw, float pitch) {
        double yawRad = Math.toRadians(yaw);
        double pitchRad = Math.toRadians(pitch);
        double x = -Math.sin(yawRad) * Math.cos(pitchRad);
        double y = -Math.sin(pitchRad);
        double z = Math.cos(yawRad) * Math.cos(pitchRad);
        return new Vec3d(x, y, z).normalize();
    }
}

View File

@@ -0,0 +1,24 @@
package org.pvpbot.goated.ai;
/**
 * Plain mutable record of what the bot should do: where to look, how to move and which
 * actions to perform. All fields start at zero / {@code false}.
 */
public class BotInputState {
    // ----- look -----
    public float yaw;
    public float pitch;

    // ----- movement (WASD style) -----
    /** -1 back, +1 forward. */
    public float moveForward;
    /** -1 left, +1 right. */
    public float moveStrafe;
    public boolean sprint;
    public boolean crouch;

    // ----- actions -----
    public boolean jump;
    public boolean swing;

    /** Clears the one-shot actions ({@code jump}, {@code swing}); everything else is kept. */
    public void clearTransient() {
        this.jump = false;
        this.swing = false;
    }
}

View File

@@ -0,0 +1,19 @@
package org.pvpbot.goated.ai;
import net.minecraft.server.network.ServerPlayerEntity;
import java.util.HashMap;
import java.util.Map;
/**
 * Static lookup table from a bot's player name to its {@link BotController}.
 *
 * <p>Backed by a plain {@link HashMap}; no synchronisation is performed here.
 */
public class BotRegistry {
    private static final Map<String, BotController> BOTS = new HashMap<>();

    /**
     * Stores {@code controller} under its bot's name, replacing any previous entry.
     *
     * @param controller controller to register
     */
    public static void register(BotController controller) {
        final String key = controller.getBot().getName().getString();
        BOTS.put(key, controller);
    }

    /**
     * @param name bot player name
     * @return the registered controller, or {@code null} if there is none
     */
    public static BotController get(String name) {
        return BOTS.get(name);
    }
}

View File

@@ -0,0 +1,74 @@
package org.pvpbot.goated.ai;
import net.minecraft.server.network.ServerPlayerEntity;
/**
 * Builds the JSON document that describes the bot and its target for the AI server.
 */
public class BotStateSerializer {

    /**
     * Serialises positions, velocities, distance, health, the target's swing flag, the bot's
     * attack cooldown and the bot's age (sent as {@code combat_timer}).
     *
     * <p>Velocities are written twice: nested ({@code bot_vel}/{@code target_vel}) and as flat
     * {@code *_velocity_*} keys.
     *
     * @param bot    bot entity; must not be {@code null}
     * @param target target entity; must not be {@code null}
     * @return the JSON text, which is also printed as a debug line
     */
    public static String toJson(ServerPlayerEntity bot, ServerPlayerEntity target) {
        var botVel = bot.getVelocity();
        var targetVel = target.getVelocity();

        double dx = target.getX() - bot.getX();
        double dy = target.getY() - bot.getY();
        double dz = target.getZ() - bot.getZ();
        double dist = Math.sqrt(dx * dx + dy * dy + dz * dz);

        boolean targetIsSwinging = target.handSwinging;

        // Locale.ROOT keeps '.' as the decimal separator; formatting with a default locale
        // that uses ',' would produce invalid JSON.
        String json = String.format(java.util.Locale.ROOT, """
                {
                "bot_pos": {
                "x": %.3f,
                "y": %.3f,
                "z": %.3f
                },
                "bot_vel": {
                "x": %.3f,
                "y": %.3f,
                "z": %.3f
                },
                "target_pos": {
                "x": %.3f,
                "y": %.3f,
                "z": %.3f
                },
                "target_vel": {
                "x": %.3f,
                "y": %.3f,
                "z": %.3f
                },
                "distance": %.3f,
                "bot_health": %.1f,
                "target_health": %.1f,
                "bot_velocity_x": %.3f,
                "bot_velocity_y": %.3f,
                "bot_velocity_z": %.3f,
                "target_velocity_x": %.3f,
                "target_velocity_y": %.3f,
                "target_velocity_z": %.3f,
                "target_is_swinging": %b,
                "attack_cooldown": %.3f,
                "combat_timer": %d
                }
                """,
                bot.getX(), bot.getY(), bot.getZ(),
                botVel.x, botVel.y, botVel.z,
                target.getX(), target.getY(), target.getZ(),
                targetVel.x, targetVel.y, targetVel.z,
                dist,
                bot.getHealth(),
                target.getHealth(),
                botVel.x,
                botVel.y,
                botVel.z,
                targetVel.x,
                targetVel.y,
                targetVel.z,
                targetIsSwinging,
                bot.getAttackCooldownProgress(0.5f),
                bot.age
        );

        System.out.println("[DEBUG] BotState JSON: " + json);
        return json;
    }
}

View File

@@ -0,0 +1,96 @@
package org.pvpbot.goated.mixin;
import net.minecraft.server.MinecraftServer;
import net.minecraft.server.network.ServerPlayerEntity;
import org.pvpbot.goated.ai.*;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfo;
@Mixin(MinecraftServer.class)
/**
 * Hooks the end of {@code MinecraftServer.tick} to run the bot: finds the bot and a target,
 * heals both when either has died, ticks the brain and applies the resulting input.
 */
public class ServerTickMixin {
    /** Name of the player currently targeted; {@code null} until a target has been found. */
    private String targetName = null;

    @Inject(method = "tick", at = @At("TAIL"))
    private void onTick(CallbackInfo ci) {
        MinecraftServer server = (MinecraftServer)(Object)this;

        ServerPlayerEntity botPlayer =
                server.getPlayerManager().getPlayer("PvPBOT");
        if (botPlayer == null) return;

        // =========================
        // GET OR CREATE CONTROLLER
        // =========================
        BotController controller = BotRegistry.get("PvPBOT");
        // If the lookup now returns a different entity object than the cached controller wraps
        // (e.g. the bot was spawned again), start over so look/velocity updates and the brain
        // act on the live entity instead of the stale one.
        if (controller == null || controller.getBot() != botPlayer) {
            controller = new BotController(botPlayer);
            BotRegistry.register(controller);
        }

        // =========================
        // GET TARGET
        // =========================
        ServerPlayerEntity target = null;
        if (targetName != null) {
            target = server.getPlayerManager().getPlayer(targetName);
        }

        // If the current target is gone, try the same name again or fall back to another player
        if (target == null || target.isRemoved()) {
            target = findTarget(server, botPlayer, targetName);
            if (target != null) {
                targetName = target.getName().getString();
            }
        }
        if (target == null) return;

        // =========================
        // HEAL BOT AND TARGET IF EITHER DIED
        // =========================
        if (target.isDead() || target.getHealth() <= 0 || botPlayer.isDead() || botPlayer.getHealth() <= 0) {
            org.pvpbot.goated.Goated.resetPlayer(server, botPlayer);
            org.pvpbot.goated.Goated.resetPlayer(server, target);
        }

        // =========================
        // INIT BRAIN IF NEEDED
        // =========================
        // target is non-null here, so calling equals on it also copes with a brain whose
        // target has been cleared.
        if (controller.getBrain() == null || !target.equals(controller.getBrain().getTarget())) {
            controller.setBrain(new BotBrain(botPlayer, target));
        }

        // =========================
        // UPDATE BRAIN
        // =========================
        controller.getBrain().tick(server);

        // =========================
        // APPLY AI -> INPUT
        // =========================
        controller.setInput(controller.getBrain().getAIState());
        BotInputApplier.apply(server, controller);
    }

    // =========================
    // SIMPLE TARGET FINDER
    // =========================

    /**
     * Returns the player called {@code preferredName} if present and not removed, otherwise the
     * first listed player whose name is not {@code PvPBOT}, otherwise {@code null}.
     */
    private ServerPlayerEntity findTarget(MinecraftServer server, ServerPlayerEntity bot, String preferredName) {
        if (preferredName != null) {
            ServerPlayerEntity preferred = server.getPlayerManager().getPlayer(preferredName);
            if (preferred != null && !preferred.isRemoved()) {
                return preferred;
            }
        }
        return server.getPlayerManager().getPlayerList().stream()
                .filter(p -> !p.getName().getString().equals("PvPBOT"))
                .findFirst()
                .orElse(null);
    }
}

37
src/main/proto/bot.proto Normal file
View File

@@ -0,0 +1,37 @@
// Protobuf description of a prediction service: a BotState goes in, a BotAction comes out.
// NOTE(review): the Java and Python code in this commit exchange JSON over HTTP; whether this
// service definition is wired up anywhere is not visible here.
syntax = "proto3";
package bot;
option java_multiple_files = true;
// Single-RPC service that maps the current state to the next action.
service BotAI {
rpc Predict (BotState) returns (BotAction);
}
// Three-component vector of doubles.
message Vec3 {
double x = 1;
double y = 2;
double z = 3;
}
// Snapshot sent to the predictor. It carries fewer fields than the JSON built by
// BotStateSerializer (no target velocity, health, cooldown or timer).
message BotState {
Vec3 bot_pos = 1;
Vec3 bot_vel = 2;
Vec3 target_pos = 3;
double distance = 4;
}
// Action returned by the predictor. The field names match the JSON keys read by
// BotBrain.parseAndApply.
message BotAction {
float yaw = 1;
float pitch = 2;
float move_forward = 3;
float move_strafe = 4;
bool jump = 5;
bool sprint = 6;
bool crouch = 7;
bool swing = 8;
}

Binary file not shown.

Binary file not shown.

343
src/main/python/server.py Normal file
View File

@@ -0,0 +1,343 @@
import os
import time
from flask import Flask, request, jsonify
import math
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
# Flask application object; the /predict and /reward handlers are registered on it.
app = Flask(__name__)
# File the model weights are loaded from at start-up and saved to.
MODEL_PATH = "ppo_model.pth"
# Time between automatic saves of the model.
SAVE_INTERVAL = 300 # 5 minutes in seconds
# ===================== PPO NEURAL NETWORK =====================
class ActorCritic(nn.Module):
    """Actor-critic network with a shared trunk, a policy head and a value head.

    With the default sizes (``state_dim=15``, ``action_dim=8``) the model has
    13,677,065 trainable parameters: about 10.5M in the shared trunk and about
    1.6M in each head.

    Args:
        state_dim: length of the state vector fed to the trunk.
        action_dim: number of discrete actions (size of the policy output).
    """

    def __init__(self, state_dim=15, action_dim=8):
        super().__init__()

        # Shared feature extractor (~10.5M parameters with the defaults)
        self.shared = nn.Sequential(
            nn.Linear(state_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
        )

        # Actor head: action logits (~1.6M parameters)
        self.actor = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, action_dim),
        )

        # Critic head: scalar state value (~1.6M parameters)
        self.critic = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, state):
        """Return ``(action_logits, value)`` for a batch of states."""
        features = self.shared(state)
        action_logits = self.actor(features)
        value = self.critic(features)
        return action_logits, value

    def get_action(self, state):
        """Sample one action without tracking gradients.

        Args:
            state: tensor holding a single state (the results are read with
                ``.item()``, so a batch of more than one state is not supported).

        Returns:
            ``(action_index, log_prob, value)`` as Python scalars.
        """
        with torch.no_grad():
            action_logits, value = self.forward(state)
            # Building the distribution from logits avoids the separate softmax
            # and is the numerically safer of the two equivalent forms.
            dist = torch.distributions.Categorical(logits=action_logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
        return action.item(), log_prob.item(), value.item()
# ===================== PPO AGENT =====================
class PPOAgent:
    """PPO agent that owns the model, the optimizer and the experience buffers.

    ``select_action`` appends a state/action/log-prob/value entry and
    ``store_reward`` appends a reward/done entry; ``train`` pairs them up by
    position, runs the clipped-objective update and clears every buffer.

    Args:
        state_dim: length of the state vector.
        action_dim: number of discrete actions.
        lr: Adam learning rate.
        gamma: discount factor used for the returns.
        epsilon: PPO clipping range.
        epochs: optimisation passes per ``train`` call.
        batch_size: minimum number of complete transitions before training.
    """

    def __init__(self, state_dim=15, action_dim=8, lr=3e-4, gamma=0.99,
                 epsilon=0.2, epochs=10, batch_size=64):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = ActorCritic(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        self.gamma = gamma
        self.epsilon = epsilon
        self.epochs = epochs
        self.batch_size = batch_size

        self.last_save_time = time.time()

        # Load existing model if available
        if os.path.exists(MODEL_PATH):
            try:
                # NOTE(review): torch.load unpickles the file; only point MODEL_PATH
                # at checkpoints you trust.
                self.model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device))
                print(f"Loaded existing model from {MODEL_PATH}")
            except Exception as e:
                print(f"Failed to load model: {e}")

        # Experience buffer
        self.states = []
        self.actions = []
        self.log_probs = []
        self.rewards = []
        self.values = []
        self.dones = []

        self.action_mapping = [
            # [yaw_delta, pitch_delta, forward, strafe, jump, sprint, crouch, swing]
            [0, 0, 1.0, 0, 0, 1, 0, 0],    # Forward sprint
            [15, 0, 1.0, 0, 0, 1, 0, 1],   # Forward right with swing
            [-15, 0, 1.0, 0, 0, 1, 0, 1],  # Forward left with swing
            [0, 0, 1.0, 1, 0, 1, 0, 1],    # Strafe right with swing
            [0, 0, 1.0, -1, 0, 1, 0, 1],   # Strafe left with swing
            [0, 0, 1.0, 0, 1, 1, 0, 1],    # Jump attack
            [0, 0, 0, 1, 0, 0, 1, 0],      # Crouch strafe right
            [0, 0, 0, -1, 0, 0, 1, 0],     # Crouch strafe left
        ]

    def preprocess_state(self, state_dict):
        """Convert a game-state dict into a ``(1, 15)`` float tensor on ``self.device``.

        ``bot_pos``, ``target_pos`` and ``distance`` are required keys; every
        other value falls back to a default when absent.
        """
        bot_pos = state_dict["bot_pos"]
        target_pos = state_dict["target_pos"]

        dx = target_pos["x"] - bot_pos["x"]
        dy = target_pos["y"] - bot_pos["y"]
        dz = target_pos["z"] - bot_pos["z"]
        distance = state_dict["distance"]

        # Bell-shaped distance input: exp(-0.5 * ((dist - 3.0) / 1.5)^2), peak at 3.0
        dist_feature = math.exp(-0.5 * math.pow((distance - 3.0) / 1.5, 2))

        # Normalize values
        state_vector = [
            dx / 100.0,
            dy / 100.0,
            dz / 100.0,
            dist_feature,
            state_dict.get("bot_health", 20.0) / 20.0,
            state_dict.get("target_health", 20.0) / 20.0,
            state_dict.get("bot_velocity_x", 0.0),
            state_dict.get("bot_velocity_y", 0.0),
            state_dict.get("bot_velocity_z", 0.0),
            state_dict.get("target_velocity_x", 0.0),
            state_dict.get("target_velocity_y", 0.0),
            state_dict.get("target_velocity_z", 0.0),
            1.0 if state_dict.get("target_is_swinging", False) else 0.0,
            state_dict.get("attack_cooldown", 1.0),
            state_dict.get("combat_timer", 0.0) / 100.0,
        ]

        return torch.FloatTensor(state_vector).unsqueeze(0).to(self.device)

    def select_action(self, state_dict):
        """Sample an action index from the current policy and record it for training."""
        state_tensor = self.preprocess_state(state_dict)
        action_idx, log_prob, value = self.model.get_action(state_tensor)

        # Store for training
        self.states.append(state_tensor.cpu())
        self.actions.append(action_idx)
        self.log_probs.append(log_prob)
        self.values.append(value)

        return action_idx

    def store_reward(self, reward, done=False):
        """Record the reward and episode-end flag for the most recent action."""
        self.rewards.append(reward)
        self.dones.append(done)

    def save_model(self):
        """Save the model weights to ``MODEL_PATH``; failures are printed, not raised."""
        try:
            torch.save(self.model.state_dict(), MODEL_PATH)
            self.last_save_time = time.time()
            print(f"Model saved automatically to {MODEL_PATH}")
        except Exception as e:
            print(f"Failed to save model: {e}")

    def train(self):
        """Run one PPO update on the buffered transitions, then clear the buffers.

        Returns without touching the buffers when fewer than ``batch_size``
        complete transitions (action *and* reward) are available.
        """
        # Actions and rewards arrive through separate calls, so the buffers can
        # differ in length (e.g. a lost reward). Only the common prefix forms
        # complete transitions; using mismatched lengths would fail below.
        n = min(len(self.states), len(self.actions), len(self.log_probs),
                len(self.values), len(self.rewards), len(self.dones))
        if n < self.batch_size:
            return

        # Discounted returns and advantages, filled back to front.
        returns = [0.0] * n
        advantages = [0.0] * n
        running = 0.0
        for i in reversed(range(n)):
            running = self.rewards[i] + self.gamma * running * (1.0 - float(self.dones[i]))
            returns[i] = running
            advantages[i] = running - self.values[i]

        # Convert to tensors
        states = torch.cat(self.states[:n]).to(self.device)
        actions = torch.tensor(self.actions[:n], dtype=torch.long, device=self.device)
        old_log_probs = torch.tensor(self.log_probs[:n], dtype=torch.float32, device=self.device)
        returns = torch.tensor(returns, dtype=torch.float32, device=self.device)
        advantages = torch.tensor(advantages, dtype=torch.float32, device=self.device)

        # Normalize advantages: always centre, scale only when the spread is
        # meaningful (std of a single sample is NaN, near-zero std would blow up).
        advantages = advantages - advantages.mean()
        if n > 1:
            std = advantages.std()
            if std >= 1e-8:
                advantages = advantages / (std + 1e-8)

        # PPO update
        for _ in range(self.epochs):
            action_logits, values = self.model(states)
            dist = torch.distributions.Categorical(logits=action_logits)
            new_log_probs = dist.log_prob(actions)
            entropy = dist.entropy()

            # Probability ratio between the new and the sampling policy
            ratios = torch.exp(new_log_probs - old_log_probs)

            # Clipped surrogate objective, value loss and entropy bonus
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.epsilon, 1 + self.epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()
            # squeeze(-1) only drops the value column, never the batch dimension
            critic_loss = 0.5 * (returns - values.squeeze(-1)).pow(2).mean()
            entropy_loss = -0.01 * entropy.mean()

            loss = actor_loss + critic_loss + entropy_loss

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
            self.optimizer.step()

        # Clear buffers (including any unmatched leftovers)
        self.states.clear()
        self.actions.clear()
        self.log_probs.clear()
        self.rewards.clear()
        self.values.clear()
        self.dones.clear()
# ===================== GLOBAL PPO AGENT =====================
# Single agent instance shared by both request handlers.
ppo_agent = PPOAgent(state_dim=15, action_dim=8)
# Yaw most recently returned by /predict (degrees).
current_yaw = 0.0
@app.post("/predict")
def predict():
    """Pick the next action for the posted game state.

    The returned yaw is the bearing from the bot to the target plus the yaw
    offset of the sampled action; every other value is taken from the action
    table as is.
    """
    global current_yaw

    state = request.json
    bot_pos = state["bot_pos"]
    target_pos = state["target_pos"]
    # Looked up like the other required keys; the value itself is not used here.
    distance = state["distance"]

    # Bearing towards the target on the horizontal plane
    delta_x = target_pos["x"] - bot_pos["x"]
    delta_z = target_pos["z"] - bot_pos["z"]
    bearing = math.degrees(math.atan2(-delta_x, delta_z))

    # Let the policy choose one row of the action table
    action_idx = ppo_agent.select_action(state)
    (yaw_delta, pitch_delta, forward, strafe,
     jump, sprint, crouch, swing) = ppo_agent.action_mapping[action_idx]

    current_yaw = bearing + yaw_delta

    print(f"[DEBUG] Predict: ActionIdx={action_idx}, Yaw={current_yaw:.1f}, Pitch={pitch_delta}, FWD={forward}, STR={strafe}, Jump={jump}, Sprint={sprint}, Swing={swing}")

    response = {
        "yaw": current_yaw,
        "pitch": pitch_delta,
        "move_forward": forward,
        "move_strafe": strafe,
        "jump": bool(jump),
        "sprint": bool(sprint),
        "crouch": bool(crouch),
        "swing": bool(swing),
    }
    return jsonify(response)
@app.post("/reward")
def receive_reward():
    """Receive reward data from Java, store it and train once a batch is full.

    Uses ``total_reward`` from the payload when present; otherwise derives a
    reward from the individual fields. The response reports whether this call
    actually trained the model.
    """
    reward_data = request.json

    # Use the total reward calculated in Java if provided
    if "total_reward" in reward_data:
        total_reward = reward_data["total_reward"]
    else:
        # Fallback to Python calculation if Java hasn't sent it (for backward compatibility)
        bot_health = reward_data.get("bot_health", 20.0)
        target_health = reward_data.get("target_health", 20.0)
        damage_dealt = reward_data.get("damage_dealt", 0.0)
        damage_taken = reward_data.get("damage_taken", 0.0)
        distance = reward_data.get("distance", 0.0)
        hit_success = reward_data.get("hit_success", False)

        total_reward = 0.0
        total_reward += damage_dealt * 10.0
        total_reward -= damage_taken * 5.0
        total_reward += 5.0 if hit_success else 0.0
        total_reward -= distance * 0.1 if distance > 4 else 0.0
        total_reward += 1.0 if 1.5 < distance < 3.5 else 0.0
        total_reward += 100.0 if target_health <= 0 else 0.0
        total_reward -= 100.0 if bot_health <= 0 else 0.0

    done = reward_data.get("done", False)

    # Store reward for the PPO update
    ppo_agent.store_reward(total_reward, done)

    # Periodic auto-save every 5 minutes
    time_since_last_save = time.time() - ppo_agent.last_save_time
    if time_since_last_save > SAVE_INTERVAL:
        ppo_agent.save_model()
    else:
        # Log every minute or so how much time is left for next save
        if int(time_since_last_save) % 60 == 0:
            print(f"[DEBUG] Next auto-save in {int(SAVE_INTERVAL - time_since_last_save)} seconds")

    print(f"[DEBUG] Reward: Total={total_reward:.4f}, Done={done}")

    # Train every batch
    trained = False
    if len(ppo_agent.rewards) >= ppo_agent.batch_size:
        ppo_agent.train()
        # train() returns early, leaving the buffers alone, when it lacks enough
        # recorded actions; an emptied reward buffer is the sign that it ran.
        trained = len(ppo_agent.rewards) == 0
        if trained:
            print(f"PPO Model trained. Buffer cleared. Last total reward: {total_reward:.2f}")

    return jsonify({
        "status": "reward_received",
        "total_reward": total_reward,
        "model_trained": trained
    })
if __name__ == "__main__":
    # Disable reloader to prevent multiple initializations and potential state loss.
    # threaded=False makes the development server handle one request at a time:
    # /predict and /reward both mutate the shared agent's buffers without any
    # locking, so they must not run concurrently.
    app.run(port=5000, debug=True, use_reloader=False, threaded=False)

View File

@@ -0,0 +1,30 @@
{
"schemaVersion": 1,
"id": "goated",
"version": "${version}",
"name": "goated",
"description": "",
"authors": [],
"contact": {},
"license": "All-Rights-Reserved",
"icon": "assets/goated/icon.png",
"environment": "*",
"entrypoints": {
"main": [
"org.pvpbot.goated.Goated"
]
},
"mixins": [
{
"config": "goated.mixins.json"
}
],
"depends": {
"fabricloader": ">=0.19.1",
"fabric-api": "*",
"minecraft": "1.21.4"
}
}

View File

@@ -0,0 +1,15 @@
{
"required": true,
"minVersion": "0.8",
"package": "org.pvpbot.goated.mixin",
"compatibilityLevel": "JAVA_21",
"mixins": [
"ServerTickMixin"
],
"injectors": {
"defaultRequire": 1
},
"overwrites": {
"requireAnnotations": true
}
}