Cooked
This commit is contained in:
10
src/client/java/org/pvpbot/goated/client/GoatedClient.java
Normal file
10
src/client/java/org/pvpbot/goated/client/GoatedClient.java
Normal file
@@ -0,0 +1,10 @@
|
||||
package org.pvpbot.goated.client;
|
||||
|
||||
import net.fabricmc.api.ClientModInitializer;
|
||||
|
||||
public class GoatedClient implements ClientModInitializer {
|
||||
|
||||
@Override
|
||||
public void onInitializeClient() {
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package org.pvpbot.goated.client;
|
||||
|
||||
import net.fabricmc.fabric.api.datagen.v1.DataGeneratorEntrypoint;
|
||||
import net.fabricmc.fabric.api.datagen.v1.FabricDataGenerator;
|
||||
|
||||
public class GoatedDataGenerator implements DataGeneratorEntrypoint {
|
||||
|
||||
@Override
|
||||
public void onInitializeDataGenerator(FabricDataGenerator fabricDataGenerator) {
|
||||
FabricDataGenerator.Pack pack = fabricDataGenerator.createPack();
|
||||
}
|
||||
}
|
||||
14
src/client/resources/goated.client.mixins.json
Normal file
14
src/client/resources/goated.client.mixins.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"required": true,
|
||||
"minVersion": "0.8",
|
||||
"package": "org.pvpbot.goated.mixin.client",
|
||||
"compatibilityLevel": "JAVA_21",
|
||||
"client": [
|
||||
],
|
||||
"injectors": {
|
||||
"defaultRequire": 1
|
||||
},
|
||||
"overwrites": {
|
||||
"requireAnnotations": true
|
||||
}
|
||||
}
|
||||
98
src/main/java/org/pvpbot/goated/Goated.java
Normal file
98
src/main/java/org/pvpbot/goated/Goated.java
Normal file
@@ -0,0 +1,98 @@
|
||||
package org.pvpbot.goated;
|
||||
|
||||
import com.mojang.brigadier.exceptions.CommandSyntaxException;
|
||||
import net.fabricmc.api.ModInitializer;
|
||||
import net.fabricmc.fabric.api.command.v2.CommandRegistrationCallback;
|
||||
import net.fabricmc.fabric.api.event.lifecycle.v1.ServerTickEvents;
|
||||
import net.fabricmc.loader.api.FabricLoader;
|
||||
import net.minecraft.server.MinecraftServer;
|
||||
import net.minecraft.server.command.CommandManager;
|
||||
import net.minecraft.server.command.ServerCommandSource;
|
||||
import net.minecraft.server.network.ServerPlayerEntity;
|
||||
import net.minecraft.text.Text;
|
||||
import net.minecraft.util.math.Vec3d;
|
||||
import org.pvpbot.goated.ai.BotController;
|
||||
import org.pvpbot.goated.ai.BotRegistry;
|
||||
|
||||
public class Goated implements ModInitializer {
|
||||
|
||||
private static final String BOT_NAME = "PvPBOT";
|
||||
|
||||
private ServerPlayerEntity bot;
|
||||
private Vec3d moveTarget = null;
|
||||
|
||||
@Override
|
||||
public void onInitialize() {
|
||||
|
||||
if (!FabricLoader.getInstance().isModLoaded("carpet")) {
|
||||
return;
|
||||
}
|
||||
|
||||
// ===== SPAWN BOT COMMAND =====
|
||||
CommandRegistrationCallback.EVENT.register((dispatcher, registryAccess, environment) ->
|
||||
dispatcher.register(
|
||||
CommandManager.literal("spawnbot")
|
||||
.requires(source -> source.hasPermissionLevel(2))
|
||||
.executes(context -> {
|
||||
ServerCommandSource source = context.getSource();
|
||||
String ownerName = source.getEntity() instanceof ServerPlayerEntity player ? player.getName().getString() : null;
|
||||
spawnBot(source);
|
||||
source.sendFeedback(() ->
|
||||
Text.literal("Spawning " + BOT_NAME + "..."),
|
||||
false
|
||||
);
|
||||
return 1;
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// ===================== BOT SPAWN =====================
|
||||
|
||||
private void spawnBot(ServerCommandSource source) {
|
||||
try {
|
||||
source.getServer().getCommandManager().getDispatcher().execute(
|
||||
"player " + BOT_NAME + " spawn",
|
||||
source
|
||||
);
|
||||
|
||||
source.getServer().execute(() -> {
|
||||
bot = source.getServer().getPlayerManager().getPlayer(BOT_NAME);
|
||||
if (bot == null) return;
|
||||
|
||||
resetPlayer(source.getServer(), bot);
|
||||
});
|
||||
|
||||
} catch (CommandSyntaxException e) {
|
||||
source.sendError(Text.literal("Failed to spawn bot: " + e.getMessage()));
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== EQUIP BOT =====================
|
||||
|
||||
public static void equipPlayer(MinecraftServer server, ServerPlayerEntity player) {
|
||||
try {
|
||||
var dispatcher = server.getCommandManager().getDispatcher();
|
||||
var source = server.getCommandSource();
|
||||
String playerName = player.getName().getString();
|
||||
|
||||
dispatcher.execute("item replace entity " + playerName + " armor.head with minecraft:diamond_helmet", source);
|
||||
dispatcher.execute("item replace entity " + playerName + " armor.chest with minecraft:diamond_chestplate", source);
|
||||
dispatcher.execute("item replace entity " + playerName + " armor.legs with minecraft:diamond_leggings", source);
|
||||
dispatcher.execute("item replace entity " + playerName + " armor.feet with minecraft:diamond_boots", source);
|
||||
dispatcher.execute("item replace entity " + playerName + " weapon.mainhand with minecraft:diamond_sword", source);
|
||||
|
||||
} catch (CommandSyntaxException e) {
|
||||
System.err.println("Failed to equip player " + player.getName().getString() + ": " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public static void resetPlayer(MinecraftServer server, ServerPlayerEntity player) {
|
||||
if (player == null) return;
|
||||
|
||||
player.setHealth(player.getMaxHealth());
|
||||
player.getHungerManager().setFoodLevel(20);
|
||||
player.getHungerManager().setSaturationLevel(20f);
|
||||
equipPlayer(server, player);
|
||||
}
|
||||
}
|
||||
31
src/main/java/org/pvpbot/goated/ai/BotAIState.java
Normal file
31
src/main/java/org/pvpbot/goated/ai/BotAIState.java
Normal file
@@ -0,0 +1,31 @@
|
||||
package org.pvpbot.goated.ai;
|
||||
|
||||
/**
 * Mutable holder for the latest action values coming from the Python AI.
 *
 * <p>All fields are {@code volatile} because, per the original note, they are updated
 * asynchronously.
 */
public class BotAIState {

    // Look angles.
    public volatile float yaw;
    public volatile float pitch;

    // Movement axes.
    public volatile float moveForward;
    public volatile float moveStrafe;

    // Action flags.
    public volatile boolean jump;
    public volatile boolean sprint;
    public volatile boolean crouch;
    public volatile boolean swing;

    // Optional debug value; reset() leaves it untouched.
    public volatile long lastUpdateTime;

    /** Zeroes the look/movement values and clears every action flag. */
    public void reset() {
        this.yaw = 0.0f;
        this.pitch = 0.0f;
        this.moveForward = 0.0f;
        this.moveStrafe = 0.0f;

        this.jump = false;
        this.sprint = false;
        this.crouch = false;
        this.swing = false;
    }
}
|
||||
204
src/main/java/org/pvpbot/goated/ai/BotBrain.java
Normal file
204
src/main/java/org/pvpbot/goated/ai/BotBrain.java
Normal file
@@ -0,0 +1,204 @@
|
||||
package org.pvpbot.goated.ai;
|
||||
|
||||
import com.google.gson.JsonObject;
|
||||
import com.google.gson.JsonParser;
|
||||
import net.minecraft.server.MinecraftServer;
|
||||
import net.minecraft.server.network.ServerPlayerEntity;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class BotBrain {
|
||||
|
||||
private final ServerPlayerEntity bot;
|
||||
private ServerPlayerEntity target;
|
||||
|
||||
private final BotInputState aiState = new BotInputState();
|
||||
|
||||
private final AtomicInteger tickCounter = new AtomicInteger(0);
|
||||
private final AtomicBoolean requestInFlight = new AtomicBoolean(false);
|
||||
|
||||
private float lastBotHealth = 20f;
|
||||
private float lastTargetHealth = 20f;
|
||||
|
||||
public BotBrain(ServerPlayerEntity bot, ServerPlayerEntity target) {
|
||||
this.bot = bot;
|
||||
this.target = target;
|
||||
}
|
||||
|
||||
public BotInputState getAIState() {
|
||||
return aiState;
|
||||
}
|
||||
|
||||
public ServerPlayerEntity getTarget() {
|
||||
return target;
|
||||
}
|
||||
|
||||
public void setTarget(ServerPlayerEntity target) {
|
||||
this.target = target;
|
||||
}
|
||||
|
||||
// ===================== MAIN TICK =====================
|
||||
|
||||
public void tick(MinecraftServer server) {
|
||||
// Prevent overlapping requests if one takes longer than a tick
|
||||
if (!requestInFlight.compareAndSet(false, true)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Build request
|
||||
String json = BotStateSerializer.toJson(bot, target);
|
||||
|
||||
// Send state to AI server every tick
|
||||
BotHttpClient.sendState(json)
|
||||
.thenAccept(response -> {
|
||||
try {
|
||||
System.out.println("[DEBUG] BotBrain received response: " + response);
|
||||
parseAndApply(response, aiState);
|
||||
sendRewardUpdate();
|
||||
} finally {
|
||||
requestInFlight.set(false);
|
||||
}
|
||||
})
|
||||
.exceptionally(err -> {
|
||||
requestInFlight.set(false);
|
||||
err.printStackTrace();
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
private void sendRewardUpdate() {
|
||||
float currentBotHealth = bot.getHealth();
|
||||
float currentTargetHealth = target.getHealth();
|
||||
float attackCooldown = bot.getAttackCooldownProgress(0.5f);
|
||||
|
||||
float damageDealt = Math.max(0, lastTargetHealth - currentTargetHealth);
|
||||
float damageTaken = Math.max(0, lastBotHealth - currentBotHealth);
|
||||
|
||||
double dx = target.getX() - bot.getX();
|
||||
double dy = target.getY() - bot.getY();
|
||||
double dz = target.getZ() - bot.getZ();
|
||||
double distance = Math.sqrt(dx * dx + dy * dy + dz * dz);
|
||||
|
||||
// Look-at reward
|
||||
double botYaw = bot.getYaw();
|
||||
double botPitch = bot.getPitch();
|
||||
|
||||
// Target direction
|
||||
double tx = target.getX() - bot.getX();
|
||||
double ty = target.getEyeY() - bot.getEyeY();
|
||||
double tz = target.getZ() - bot.getZ();
|
||||
double tDist = Math.sqrt(tx * tx + ty * ty + tz * tz);
|
||||
|
||||
double targetYaw = Math.toDegrees(Math.atan2(-tx, tz));
|
||||
double targetPitch = Math.toDegrees(Math.asin(-ty / tDist));
|
||||
|
||||
double yawDiff = Math.abs(normalizeDegrees(targetYaw - botYaw));
|
||||
double pitchDiff = Math.abs(normalizeDegrees(targetPitch - botPitch));
|
||||
|
||||
double lookReward = (1.0 - (yawDiff + pitchDiff) / 180.0) * 2.0;
|
||||
lookReward = Math.max(0, lookReward);
|
||||
|
||||
boolean done = bot.isDead() || target.isDead() || bot.isRemoved() || target.isRemoved();
|
||||
|
||||
// Calculate total reward in Java
|
||||
double totalReward = 0.0;
|
||||
|
||||
// Cooldown-weighted damage reward
|
||||
// If attackCooldown is 1.0, it's a full strength hit.
|
||||
// If it's low, it's a weak hit.
|
||||
if (damageDealt > 0) {
|
||||
if (attackCooldown < 0.85) {
|
||||
// Penalize weak hits (spamming) - significantly increased penalty
|
||||
totalReward -= 10.0 * (1.0 - attackCooldown);
|
||||
} else {
|
||||
totalReward += damageDealt * 10.0 * attackCooldown;
|
||||
totalReward += 5.0 * attackCooldown; // Bonus for landing a well-timed hit
|
||||
}
|
||||
}
|
||||
|
||||
totalReward -= damageTaken * 5.0;
|
||||
|
||||
// Sigmoid-like distance reward: peak at 3.0, lower for < 3.0, tapering off for > 3.0
|
||||
// Using: exp(-0.5 * ((dist - 3.0) / 1.5)^2) * 2.0
|
||||
double distReward = Math.exp(-0.5 * Math.pow((distance - 3.0) / 1.5, 2)) * 2.0;
|
||||
totalReward += distReward;
|
||||
|
||||
totalReward += lookReward;
|
||||
|
||||
if (currentTargetHealth <= 0 || target.isDead()) totalReward += 100.0;
|
||||
if (currentBotHealth <= 0 || bot.isDead()) totalReward -= 100.0;
|
||||
|
||||
// Reset trackers on death to avoid huge damage_dealt/taken reward spikes on respawn
|
||||
if (done) {
|
||||
currentBotHealth = 20f;
|
||||
currentTargetHealth = 20f;
|
||||
}
|
||||
|
||||
System.out.println("[DEBUG] Calculated Reward: " + totalReward + " (Dealt: " + damageDealt + ", Taken: " + damageTaken + ", Dist: " + distance + ", LookRew: " + lookReward + ", DistRew: " + distReward + ")");
|
||||
|
||||
String rewardJson = """
|
||||
{
|
||||
"bot_health": %.1f,
|
||||
"target_health": %.1f,
|
||||
"damage_dealt": %.1f,
|
||||
"damage_taken": %.1f,
|
||||
"distance": %.3f,
|
||||
"attack_cooldown": %.3f,
|
||||
"hit_success": %b,
|
||||
"done": %b,
|
||||
"total_reward": %.4f
|
||||
}
|
||||
""".formatted(
|
||||
currentBotHealth,
|
||||
currentTargetHealth,
|
||||
damageDealt,
|
||||
damageTaken,
|
||||
distance,
|
||||
attackCooldown,
|
||||
damageDealt > 0,
|
||||
done,
|
||||
totalReward
|
||||
);
|
||||
|
||||
BotHttpClient.sendReward(rewardJson)
|
||||
.thenAccept(resp -> System.out.println("[DEBUG] Reward sent. Server response: " + resp));
|
||||
|
||||
lastBotHealth = currentBotHealth;
|
||||
lastTargetHealth = currentTargetHealth;
|
||||
}
|
||||
|
||||
// ===================== SAFE JSON PARSER =====================
|
||||
|
||||
private void parseAndApply(String json, BotInputState state) {
|
||||
|
||||
JsonObject obj = JsonParser.parseString(json).getAsJsonObject();
|
||||
|
||||
state.yaw = obj.has("yaw") ? obj.get("yaw").getAsFloat() : 0;
|
||||
state.pitch = obj.has("pitch") ? obj.get("pitch").getAsFloat() : 0;
|
||||
|
||||
state.moveForward = obj.has("move_forward") ? obj.get("move_forward").getAsFloat() : 0;
|
||||
state.moveStrafe = obj.has("move_strafe") ? obj.get("move_strafe").getAsFloat() : 0;
|
||||
|
||||
state.jump = obj.has("jump") && obj.get("jump").getAsBoolean();
|
||||
state.sprint = obj.has("sprint") && obj.get("sprint").getAsBoolean();
|
||||
state.crouch = obj.has("crouch") && obj.get("crouch").getAsBoolean();
|
||||
state.swing = obj.has("swing") && obj.get("swing").getAsBoolean();
|
||||
}
|
||||
|
||||
// ===================== CHAT HELPER =====================
|
||||
|
||||
public static void botSay(MinecraftServer server, ServerPlayerEntity bot, String message) {
|
||||
server.getCommandManager().executeWithPrefix(
|
||||
bot.getCommandSource(),
|
||||
"say " + message
|
||||
);
|
||||
}
|
||||
|
||||
private double normalizeDegrees(double degrees) {
|
||||
double result = degrees % 360.0;
|
||||
if (result >= 180.0) result -= 360.0;
|
||||
if (result < -180.0) result += 360.0;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
54
src/main/java/org/pvpbot/goated/ai/BotController.java
Normal file
54
src/main/java/org/pvpbot/goated/ai/BotController.java
Normal file
@@ -0,0 +1,54 @@
|
||||
package org.pvpbot.goated.ai;
|
||||
|
||||
import net.minecraft.server.network.ServerPlayerEntity;
|
||||
|
||||
public class BotController {
|
||||
|
||||
private final ServerPlayerEntity bot;
|
||||
private String ownerName;
|
||||
|
||||
private final BotInputState input = new BotInputState();
|
||||
|
||||
private BotBrain brain;
|
||||
|
||||
public BotController(ServerPlayerEntity bot) {
|
||||
this.bot = bot;
|
||||
}
|
||||
|
||||
public ServerPlayerEntity getBot() {
|
||||
return bot;
|
||||
}
|
||||
|
||||
public BotInputState input() {
|
||||
return input;
|
||||
}
|
||||
|
||||
public BotBrain getBrain() {
|
||||
return brain;
|
||||
}
|
||||
|
||||
public void setBrain(BotBrain brain) {
|
||||
this.brain = brain;
|
||||
}
|
||||
|
||||
public String getOwnerName() {
|
||||
return ownerName;
|
||||
}
|
||||
|
||||
public void setOwnerName(String ownerName) {
|
||||
this.ownerName = ownerName;
|
||||
}
|
||||
|
||||
public void setInput(BotInputState newInput) {
|
||||
this.input.yaw = newInput.yaw;
|
||||
this.input.pitch = newInput.pitch;
|
||||
|
||||
this.input.moveForward = newInput.moveForward;
|
||||
this.input.moveStrafe = newInput.moveStrafe;
|
||||
|
||||
this.input.jump = newInput.jump;
|
||||
this.input.sprint = newInput.sprint;
|
||||
this.input.crouch = newInput.crouch;
|
||||
this.input.swing = newInput.swing;
|
||||
}
|
||||
}
|
||||
35
src/main/java/org/pvpbot/goated/ai/BotHttpClient.java
Normal file
35
src/main/java/org/pvpbot/goated/ai/BotHttpClient.java
Normal file
@@ -0,0 +1,35 @@
|
||||
package org.pvpbot.goated.ai;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.time.Duration;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
/**
 * Thin asynchronous HTTP client for the local AI server on {@code 127.0.0.1:5000}.
 *
 * <p>Both endpoints receive a JSON body via POST and the returned future completes with the
 * raw response body. Connect and request timeouts are set so that a server which is down or
 * hung fails the future instead of leaving it pending indefinitely.
 */
public class BotHttpClient {

    private static final URI PREDICT_URI = URI.create("http://127.0.0.1:5000/predict");
    private static final URI REWARD_URI = URI.create("http://127.0.0.1:5000/reward");

    /** Upper bound for establishing the connection. */
    private static final Duration CONNECT_TIMEOUT = Duration.ofSeconds(5);

    /** Upper bound for a whole request; generous because the server may be busy. */
    private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(30);

    private static final HttpClient CLIENT = HttpClient.newBuilder()
            .connectTimeout(CONNECT_TIMEOUT)
            .build();

    /**
     * Posts the serialized bot state to {@code /predict}.
     *
     * @param json request body
     * @return future completing with the response body, or exceptionally on I/O failure or timeout
     */
    public static CompletableFuture<String> sendState(String json) {
        return postJson(PREDICT_URI, json);
    }

    /**
     * Posts a reward report to {@code /reward}.
     *
     * @param json request body
     * @return future completing with the response body, or exceptionally on I/O failure or timeout
     */
    public static CompletableFuture<String> sendReward(String json) {
        return postJson(REWARD_URI, json);
    }

    /** Sends {@code json} to {@code endpoint} as an {@code application/json} POST. */
    private static CompletableFuture<String> postJson(URI endpoint, String json) {
        HttpRequest request = HttpRequest.newBuilder()
                .uri(endpoint)
                .timeout(REQUEST_TIMEOUT)
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(json))
                .build();

        return CLIENT.sendAsync(request, HttpResponse.BodyHandlers.ofString())
                .thenApply(HttpResponse::body);
    }
}
|
||||
148
src/main/java/org/pvpbot/goated/ai/BotInputApplier.java
Normal file
148
src/main/java/org/pvpbot/goated/ai/BotInputApplier.java
Normal file
@@ -0,0 +1,148 @@
|
||||
package org.pvpbot.goated.ai;
|
||||
|
||||
import net.minecraft.command.argument.EntityAnchorArgumentType;
|
||||
import net.minecraft.server.MinecraftServer;
|
||||
import net.minecraft.server.network.ServerPlayerEntity;
|
||||
import net.minecraft.util.math.Vec3d;
|
||||
|
||||
public class BotInputApplier {
|
||||
|
||||
private static boolean lastSprintState = false;
|
||||
private static boolean lastCrouchState = false;
|
||||
|
||||
public static void apply(MinecraftServer server, BotController controller) {
|
||||
|
||||
ServerPlayerEntity bot = controller.getBot();
|
||||
BotInputState input = controller.input();
|
||||
|
||||
String name = bot.getName().getString();
|
||||
|
||||
// ================= LOOK =================
|
||||
applyLook(bot, input);
|
||||
|
||||
// ================= DEADZONE INPUT =================
|
||||
float forward = applyDeadzone(input.moveForward);
|
||||
float strafe = applyDeadzone(input.moveStrafe);
|
||||
|
||||
// ================= SPRINT (NO BACKWARD SPRINT) =================
|
||||
boolean shouldSprint = input.sprint && forward >= 0;
|
||||
|
||||
if (shouldSprint != lastSprintState) {
|
||||
lastSprintState = shouldSprint;
|
||||
|
||||
server.getCommandManager().executeWithPrefix(
|
||||
server.getCommandSource(),
|
||||
"player " + name + " sprint"
|
||||
);
|
||||
}
|
||||
|
||||
// ================= CROUCH (EDGE TRIGGER) =================
|
||||
if (input.crouch != lastCrouchState) {
|
||||
lastCrouchState = input.crouch;
|
||||
|
||||
server.getCommandManager().executeWithPrefix(
|
||||
server.getCommandSource(),
|
||||
"player " + name + " crouch"
|
||||
);
|
||||
}
|
||||
|
||||
// ================= JUMP =================
|
||||
if (input.jump && bot.isOnGround()) {
|
||||
server.getCommandManager().executeWithPrefix(
|
||||
server.getCommandSource(),
|
||||
"player " + name + " jump"
|
||||
);
|
||||
}
|
||||
|
||||
// ================= MOVEMENT =================
|
||||
applyMovement(bot, forward, strafe);
|
||||
|
||||
// ================= ATTACK =================
|
||||
if (input.swing) {
|
||||
server.getCommandManager().executeWithPrefix(
|
||||
server.getCommandSource(),
|
||||
"player " + name + " attack"
|
||||
);
|
||||
bot.swingHand(net.minecraft.util.Hand.MAIN_HAND);
|
||||
}
|
||||
|
||||
input.clearTransient();
|
||||
}
|
||||
|
||||
// ================= MOVEMENT =================
|
||||
|
||||
private static void applyMovement(ServerPlayerEntity bot, float forwardInput, float strafeInput) {
|
||||
|
||||
Vec3d velocity = bot.getVelocity();
|
||||
|
||||
if (forwardInput == 0 && strafeInput == 0) return;
|
||||
|
||||
Vec3d forward = getForwardVector(bot.getYaw());
|
||||
Vec3d right = new Vec3d(-forward.z, 0, forward.x);
|
||||
|
||||
Vec3d moveDir = Vec3d.ZERO;
|
||||
|
||||
// ✔ PURE DISCRETE INTENT (no scaling)
|
||||
if (forwardInput != 0) {
|
||||
moveDir = moveDir.add(forward.multiply(forwardInput));
|
||||
}
|
||||
|
||||
if (strafeInput != 0) {
|
||||
moveDir = moveDir.add(right.multiply(strafeInput));
|
||||
}
|
||||
|
||||
if (moveDir.lengthSquared() == 0) return;
|
||||
|
||||
moveDir = moveDir.normalize();
|
||||
|
||||
double steerStrength = bot.isOnGround() ? 0.08 : 0.03;
|
||||
|
||||
Vec3d steering = moveDir.multiply(steerStrength);
|
||||
|
||||
Vec3d newVel = new Vec3d(
|
||||
velocity.x + steering.x,
|
||||
velocity.y,
|
||||
velocity.z + steering.z
|
||||
);
|
||||
|
||||
bot.setVelocity(newVel);
|
||||
bot.velocityModified = true;
|
||||
}
|
||||
|
||||
// ================= LOOK =================
|
||||
|
||||
private static void applyLook(ServerPlayerEntity bot, BotInputState input) {
|
||||
Vec3d direction = getDirectionVector(input.yaw, input.pitch);
|
||||
|
||||
Vec3d eyePos = bot.getEyePos();
|
||||
Vec3d target = eyePos.add(direction.multiply(6));
|
||||
|
||||
bot.lookAt(EntityAnchorArgumentType.EntityAnchor.EYES, target);
|
||||
}
|
||||
|
||||
// ================= DEADZONE =================
|
||||
|
||||
private static float applyDeadzone(float v) {
|
||||
if (v > 0.25f) return 1f;
|
||||
if (v < -0.25f) return -1f;
|
||||
return 0f;
|
||||
}
|
||||
|
||||
// ================= HELPERS =================
|
||||
|
||||
private static Vec3d getForwardVector(float yaw) {
|
||||
double rad = Math.toRadians(yaw);
|
||||
return new Vec3d(-Math.sin(rad), 0, Math.cos(rad)).normalize();
|
||||
}
|
||||
|
||||
private static Vec3d getDirectionVector(float yaw, float pitch) {
|
||||
double yawRad = Math.toRadians(yaw);
|
||||
double pitchRad = Math.toRadians(pitch);
|
||||
|
||||
double x = -Math.sin(yawRad) * Math.cos(pitchRad);
|
||||
double y = -Math.sin(pitchRad);
|
||||
double z = Math.cos(yawRad) * Math.cos(pitchRad);
|
||||
|
||||
return new Vec3d(x, y, z).normalize();
|
||||
}
|
||||
}
|
||||
24
src/main/java/org/pvpbot/goated/ai/BotInputState.java
Normal file
24
src/main/java/org/pvpbot/goated/ai/BotInputState.java
Normal file
@@ -0,0 +1,24 @@
|
||||
package org.pvpbot.goated.ai;
|
||||
|
||||
/**
 * Plain mutable bag of inputs for one bot: look angles, WASD-style movement axes, held
 * modifiers and one-shot actions.
 */
public class BotInputState {

    // ----- look -----
    public float yaw;
    public float pitch;

    // ----- movement (WASD style) -----
    /** -1 back, +1 forward. */
    public float moveForward;
    /** -1 left, +1 right. */
    public float moveStrafe;

    // ----- held modifiers -----
    public boolean sprint;
    public boolean crouch;

    // ----- one-shot actions -----
    public boolean jump;
    public boolean swing;

    /** Clears the one-shot actions (jump and swing); everything else is kept. */
    public void clearTransient() {
        this.jump = false;
        this.swing = false;
    }
}
|
||||
19
src/main/java/org/pvpbot/goated/ai/BotRegistry.java
Normal file
19
src/main/java/org/pvpbot/goated/ai/BotRegistry.java
Normal file
@@ -0,0 +1,19 @@
|
||||
package org.pvpbot.goated.ai;
|
||||
|
||||
import net.minecraft.server.network.ServerPlayerEntity;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class BotRegistry {
|
||||
|
||||
private static final Map<String, BotController> BOTS = new HashMap<>();
|
||||
|
||||
public static void register(BotController controller) {
|
||||
BOTS.put(controller.getBot().getName().getString(), controller);
|
||||
}
|
||||
|
||||
public static BotController get(String name) {
|
||||
return BOTS.get(name);
|
||||
}
|
||||
}
|
||||
74
src/main/java/org/pvpbot/goated/ai/BotStateSerializer.java
Normal file
74
src/main/java/org/pvpbot/goated/ai/BotStateSerializer.java
Normal file
@@ -0,0 +1,74 @@
|
||||
package org.pvpbot.goated.ai;
|
||||
|
||||
import net.minecraft.server.network.ServerPlayerEntity;
|
||||
|
||||
public class BotStateSerializer {
|
||||
|
||||
public static String toJson(ServerPlayerEntity bot, ServerPlayerEntity target) {
|
||||
|
||||
double dx = target.getX() - bot.getX();
|
||||
double dy = target.getY() - bot.getY();
|
||||
double dz = target.getZ() - bot.getZ();
|
||||
|
||||
double dist = Math.sqrt(dx * dx + dy * dy + dz * dz);
|
||||
|
||||
boolean targetIsSwinging = target.handSwinging;
|
||||
|
||||
String json = """
|
||||
{
|
||||
"bot_pos": {
|
||||
"x": %.3f,
|
||||
"y": %.3f,
|
||||
"z": %.3f
|
||||
},
|
||||
"bot_vel": {
|
||||
"x": %.3f,
|
||||
"y": %.3f,
|
||||
"z": %.3f
|
||||
},
|
||||
"target_pos": {
|
||||
"x": %.3f,
|
||||
"y": %.3f,
|
||||
"z": %.3f
|
||||
},
|
||||
"target_vel": {
|
||||
"x": %.3f,
|
||||
"y": %.3f,
|
||||
"z": %.3f
|
||||
},
|
||||
"distance": %.3f,
|
||||
"bot_health": %.1f,
|
||||
"target_health": %.1f,
|
||||
"bot_velocity_x": %.3f,
|
||||
"bot_velocity_y": %.3f,
|
||||
"bot_velocity_z": %.3f,
|
||||
"target_velocity_x": %.3f,
|
||||
"target_velocity_y": %.3f,
|
||||
"target_velocity_z": %.3f,
|
||||
"target_is_swinging": %b,
|
||||
"attack_cooldown": %.3f,
|
||||
"combat_timer": %d
|
||||
}
|
||||
""".formatted(
|
||||
bot.getX(), bot.getY(), bot.getZ(),
|
||||
bot.getVelocity().x, bot.getVelocity().y, bot.getVelocity().z,
|
||||
target.getX(), target.getY(), target.getZ(),
|
||||
target.getVelocity().x, target.getVelocity().y, target.getVelocity().z,
|
||||
dist,
|
||||
bot.getHealth(),
|
||||
target.getHealth(),
|
||||
bot.getVelocity().x,
|
||||
bot.getVelocity().y,
|
||||
bot.getVelocity().z,
|
||||
target.getVelocity().x,
|
||||
target.getVelocity().y,
|
||||
target.getVelocity().z,
|
||||
targetIsSwinging,
|
||||
bot.getAttackCooldownProgress(0.5f),
|
||||
bot.age
|
||||
);
|
||||
|
||||
System.out.println("[DEBUG] BotState JSON: " + json);
|
||||
return json;
|
||||
}
|
||||
}
|
||||
96
src/main/java/org/pvpbot/goated/mixin/ServerTickMixin.java
Normal file
96
src/main/java/org/pvpbot/goated/mixin/ServerTickMixin.java
Normal file
@@ -0,0 +1,96 @@
|
||||
package org.pvpbot.goated.mixin;
|
||||
|
||||
import net.minecraft.server.MinecraftServer;
|
||||
import net.minecraft.server.network.ServerPlayerEntity;
|
||||
import org.pvpbot.goated.ai.*;
|
||||
import org.spongepowered.asm.mixin.Mixin;
|
||||
import org.spongepowered.asm.mixin.injection.At;
|
||||
import org.spongepowered.asm.mixin.injection.Inject;
|
||||
import org.spongepowered.asm.mixin.injection.callback.CallbackInfo;
|
||||
|
||||
@Mixin(MinecraftServer.class)
|
||||
public class ServerTickMixin {
|
||||
|
||||
private String targetName = null;
|
||||
|
||||
@Inject(method = "tick", at = @At("TAIL"))
|
||||
private void onTick(CallbackInfo ci) {
|
||||
MinecraftServer server = (MinecraftServer)(Object)this;
|
||||
|
||||
ServerPlayerEntity botPlayer =
|
||||
server.getPlayerManager().getPlayer("PvPBOT");
|
||||
|
||||
if (botPlayer == null) return;
|
||||
|
||||
// =========================
|
||||
// GET OR CREATE CONTROLLER
|
||||
// =========================
|
||||
BotController controller = BotRegistry.get("PvPBOT");
|
||||
|
||||
if (controller == null) {
|
||||
controller = new BotController(botPlayer);
|
||||
BotRegistry.register(controller);
|
||||
}
|
||||
|
||||
// =========================
|
||||
// GET TARGET
|
||||
// =========================
|
||||
ServerPlayerEntity target = null;
|
||||
if (targetName != null) {
|
||||
target = server.getPlayerManager().getPlayer(targetName);
|
||||
}
|
||||
|
||||
// If current target is gone or dead, try to find a new one or the same one if it respawned
|
||||
if (target == null || target.isRemoved()) {
|
||||
target = findTarget(server, botPlayer, targetName);
|
||||
if (target != null) {
|
||||
targetName = target.getName().getString();
|
||||
}
|
||||
}
|
||||
|
||||
if (target == null) return;
|
||||
|
||||
// =========================
|
||||
// HEAL BOT AND TARGET IF EITHER DIED
|
||||
// =========================
|
||||
if (target.isDead() || target.getHealth() <= 0 || botPlayer.isDead() || botPlayer.getHealth() <= 0) {
|
||||
org.pvpbot.goated.Goated.resetPlayer(server, botPlayer);
|
||||
org.pvpbot.goated.Goated.resetPlayer(server, target);
|
||||
}
|
||||
|
||||
// =========================
|
||||
// INIT BRAIN IF NEEDED
|
||||
// =========================
|
||||
if (controller.getBrain() == null || !controller.getBrain().getTarget().equals(target)) {
|
||||
controller.setBrain(new BotBrain(botPlayer, target));
|
||||
}
|
||||
|
||||
// =========================
|
||||
// UPDATE BRAIN
|
||||
// =========================
|
||||
controller.getBrain().tick(server);
|
||||
|
||||
// =========================
|
||||
// APPLY AI → INPUT
|
||||
// =========================
|
||||
controller.setInput(controller.getBrain().getAIState());
|
||||
BotInputApplier.apply(server, controller);
|
||||
}
|
||||
|
||||
// =========================
|
||||
// SIMPLE TARGET FINDER
|
||||
// =========================
|
||||
private ServerPlayerEntity findTarget(MinecraftServer server, ServerPlayerEntity bot, String preferredName) {
|
||||
if (preferredName != null) {
|
||||
ServerPlayerEntity preferred = server.getPlayerManager().getPlayer(preferredName);
|
||||
if (preferred != null && !preferred.isRemoved()) {
|
||||
return preferred;
|
||||
}
|
||||
}
|
||||
|
||||
return server.getPlayerManager().getPlayerList().stream()
|
||||
.filter(p -> !p.getName().getString().equals("PvPBOT"))
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
}
|
||||
}
|
||||
37
src/main/proto/bot.proto
Normal file
37
src/main/proto/bot.proto
Normal file
@@ -0,0 +1,37 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package bot;
|
||||
|
||||
option java_multiple_files = true;
|
||||
|
||||
// RPC surface of the AI: one call that maps a state snapshot to an action.
service BotAI {
  // Returns the action for the given state.
  rpc Predict (BotState) returns (BotAction);
}
|
||||
|
||||
// Three double-precision components.
message Vec3 {
  double x = 1;
  double y = 2;
  double z = 3;
}
|
||||
|
||||
// State snapshot passed to Predict.
message BotState {
  // Bot position and velocity.
  Vec3 bot_pos = 1;
  Vec3 bot_vel = 2;

  // Target position.
  Vec3 target_pos = 3;

  // Distance value (presumably bot-to-target; confirm against the sender).
  double distance = 4;
}
|
||||
|
||||
// Action returned by Predict.
message BotAction {
  // Look angles.
  float yaw = 1;
  float pitch = 2;

  // Movement axes.
  float move_forward = 3;
  float move_strafe = 4;

  // Action flags.
  bool jump = 5;
  bool sprint = 6;
  bool crouch = 7;
  bool swing = 8;
}
|
||||
BIN
src/main/python/__pycache__/server.cpython-313.pyc
Normal file
BIN
src/main/python/__pycache__/server.cpython-313.pyc
Normal file
Binary file not shown.
BIN
src/main/python/ppo_model.pth
Normal file
BIN
src/main/python/ppo_model.pth
Normal file
Binary file not shown.
343
src/main/python/server.py
Normal file
343
src/main/python/server.py
Normal file
@@ -0,0 +1,343 @@
|
||||
import os
|
||||
import time
|
||||
from flask import Flask, request, jsonify
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
|
||||
# Flask application object for this module.
app = Flask(__name__)

# Model checkpoint file name; a relative path, so it resolves against the working directory.
MODEL_PATH = "ppo_model.pth"

# 5 minutes, expressed in seconds.
SAVE_INTERVAL = 5 * 60
|
||||
|
||||
# ===================== PPO NEURAL NETWORK =====================
|
||||
|
||||
class ActorCritic(nn.Module):
    """Actor-critic network for PPO with a shared trunk and two heads.

    With the default ``state_dim=15`` and ``action_dim=8`` the model has
    13,677,065 trainable parameters (about 13.7M): roughly 10.5M in the shared
    trunk and about 1.6M in each head.

    Args:
        state_dim: Size of the last dimension of the state input.
        action_dim: Number of discrete actions (size of the logits output).
    """

    def __init__(self, state_dim=15, action_dim=8):
        super().__init__()

        # Shared feature extractor (~10.5M parameters at state_dim=15)
        self.shared = nn.Sequential(
            nn.Linear(state_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU()
        )

        # Actor head for policy (~1.6M parameters)
        self.actor = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, action_dim)
        )

        # Critic head for value function (~1.6M parameters)
        self.critic = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def forward(self, state):
        """Return ``(action_logits, value)`` for ``state``.

        ``action_logits`` has ``action_dim`` entries in its last dimension and
        ``value`` has one.
        """
        features = self.shared(state)
        action_logits = self.actor(features)
        value = self.critic(features)
        return action_logits, value

    def get_action(self, state):
        """Sample one action without tracking gradients.

        ``state`` must describe a single observation, because the results are
        converted with ``.item()``.

        Returns:
            Tuple ``(action, log_prob, value)`` as Python scalars: the sampled
            action index, its log-probability and the critic's value estimate.
        """
        with torch.no_grad():
            action_logits, value = self.forward(state)
            # Build the distribution from logits directly; this avoids the
            # softmax -> probs -> log round trip.
            dist = torch.distributions.Categorical(logits=action_logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
        return action.item(), log_prob.item(), value.item()
|
||||
|
||||
|
||||
# ===================== PPO AGENT =====================
|
||||
|
||||
class PPOAgent:
    """PPO learner owning the actor-critic network, its optimizer and the rollout buffers.

    One transition is recorded in two steps: ``select_action`` stores the
    state, action, log-probability and value estimate, and ``store_reward``
    later stores the reward and done flag. The i-th reward is paired with the
    i-th recorded action.

    Args:
        state_dim: Length of the state vector fed to the network.
        action_dim: Number of discrete actions (rows of ``action_mapping``).
        lr: Adam learning rate.
        gamma: Discount factor used when accumulating returns.
        epsilon: PPO clipping range for the probability ratio.
        epochs: Number of full-batch gradient steps per ``train`` call.
        batch_size: Minimum number of complete transitions required before
            ``train`` performs an update.
    """

    def __init__(self, state_dim=15, action_dim=8, lr=3e-4, gamma=0.99,
                 epsilon=0.2, epochs=10, batch_size=64):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = ActorCritic(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        self.gamma = gamma
        self.epsilon = epsilon
        self.epochs = epochs
        self.batch_size = batch_size

        self.last_save_time = time.time()

        # Load existing model if available; a failed load is reported and the
        # freshly initialised weights are kept.
        if os.path.exists(MODEL_PATH):
            try:
                self.model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device))
                print(f"Loaded existing model from {MODEL_PATH}")
            except Exception as e:
                print(f"Failed to load model: {e}")

        # Experience buffer (parallel lists, index i describes transition i)
        self.states = []
        self.actions = []
        self.log_probs = []
        self.rewards = []
        self.values = []
        self.dones = []

        self.action_mapping = [
            # [yaw_delta, pitch_delta, forward, strafe, jump, sprint, crouch, swing]
            [0, 0, 1.0, 0, 0, 1, 0, 0],    # Forward sprint
            [15, 0, 1.0, 0, 0, 1, 0, 1],   # Forward right with swing
            [-15, 0, 1.0, 0, 0, 1, 0, 1],  # Forward left with swing
            [0, 0, 1.0, 1, 0, 1, 0, 1],    # Strafe right with swing
            [0, 0, 1.0, -1, 0, 1, 0, 1],   # Strafe left with swing
            [0, 0, 1.0, 0, 1, 1, 0, 1],    # Jump attack
            [0, 0, 0, 1, 0, 0, 1, 0],      # Crouch strafe right
            [0, 0, 0, -1, 0, 0, 1, 0],     # Crouch strafe left
        ]

    def preprocess_state(self, state_dict):
        """Convert a game-state dict into a ``(1, 15)`` float tensor on ``self.device``.

        Args:
            state_dict: Mapping that must contain ``bot_pos`` and
                ``target_pos`` (each with ``x``/``y``/``z``) and ``distance``.
                All other keys are optional and fall back to the defaults
                visible below.

        Returns:
            Tensor of shape ``(1, 15)``.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        bot_pos = state_dict["bot_pos"]
        target_pos = state_dict["target_pos"]

        dx = target_pos["x"] - bot_pos["x"]
        dy = target_pos["y"] - bot_pos["y"]
        dz = target_pos["z"] - bot_pos["z"]
        distance = state_dict["distance"]

        # Gaussian bump over distance: exp(-0.5 * ((dist - 3.0) / 1.5)^2).
        # It equals 1.0 at a distance of 3.0 and decays towards 0 on both sides.
        dist_feature = math.exp(-0.5 * math.pow((distance - 3.0) / 1.5, 2))

        # Scale positions, health and the timer; velocities and the cooldown
        # are passed through unchanged.
        state_vector = [
            dx / 100.0,
            dy / 100.0,
            dz / 100.0,
            dist_feature,
            state_dict.get("bot_health", 20.0) / 20.0,
            state_dict.get("target_health", 20.0) / 20.0,
            state_dict.get("bot_velocity_x", 0.0),
            state_dict.get("bot_velocity_y", 0.0),
            state_dict.get("bot_velocity_z", 0.0),
            state_dict.get("target_velocity_x", 0.0),
            state_dict.get("target_velocity_y", 0.0),
            state_dict.get("target_velocity_z", 0.0),
            1.0 if state_dict.get("target_is_swinging", False) else 0.0,
            state_dict.get("attack_cooldown", 1.0),
            state_dict.get("combat_timer", 0.0) / 100.0,
        ]

        return torch.FloatTensor(state_vector).unsqueeze(0).to(self.device)

    def select_action(self, state_dict):
        """Sample an action from the current policy and record it for training.

        Args:
            state_dict: Game-state mapping accepted by ``preprocess_state``.

        Returns:
            Index of the chosen action (an index into ``action_mapping``).
        """
        state_tensor = self.preprocess_state(state_dict)
        action_idx, log_prob, value = self.model.get_action(state_tensor)

        # Store for training; the state is kept on the CPU until train() runs.
        self.states.append(state_tensor.cpu())
        self.actions.append(action_idx)
        self.log_probs.append(log_prob)
        self.values.append(value)

        return action_idx

    def store_reward(self, reward, done=False):
        """Store the reward for the oldest recorded action that has none yet.

        A reward that arrives while every recorded action already has one is
        ignored: there is nothing to attach it to, and keeping it would shift
        the pairing of all later rewards by one.

        Args:
            reward: Scalar reward.
            done: Whether the episode ended with this transition.

        Returns:
            True if the reward was stored, False if it was ignored.
        """
        if len(self.rewards) >= len(self.states):
            return False
        self.rewards.append(reward)
        self.dones.append(done)
        return True

    def save_model(self):
        """Write the model weights to ``MODEL_PATH`` and reset the auto-save timer.

        The checkpoint is written to a temporary sibling file and then moved
        into place, so an interrupted write cannot leave a truncated
        ``MODEL_PATH``. Failures are reported on stdout and not raised.
        """
        try:
            tmp_path = MODEL_PATH + ".tmp"
            torch.save(self.model.state_dict(), tmp_path)
            os.replace(tmp_path, MODEL_PATH)
            self.last_save_time = time.time()
            print(f"Model saved automatically to {MODEL_PATH}")
        except Exception as e:
            print(f"Failed to save model: {e}")

    def train(self):
        """Run ``epochs`` PPO updates on the complete transitions collected so far.

        Only transitions present in every buffer are used. Nothing happens
        until at least ``batch_size`` (and at least one) of them exist. After
        the update the consumed transitions are removed; recorded actions
        still waiting for a reward stay buffered.
        """
        # Number of transitions that have every field recorded.
        n = min(len(self.states), len(self.actions), len(self.log_probs),
                len(self.values), len(self.rewards), len(self.dones))
        if n == 0 or n < self.batch_size:
            return

        # Discounted returns and advantages, accumulated from the newest
        # transition backwards; a done flag resets the running return.
        returns = []
        advantages = []
        R = 0.0
        for i in reversed(range(n)):
            R = self.rewards[i] + self.gamma * R * (1 - self.dones[i])
            returns.append(R)
            advantages.append(R - self.values[i])
        returns.reverse()
        advantages.reverse()

        # Convert to tensors
        states = torch.cat(self.states[:n]).to(self.device)
        actions = torch.LongTensor(self.actions[:n]).to(self.device)
        old_log_probs = torch.FloatTensor(self.log_probs[:n]).to(self.device)
        returns = torch.FloatTensor(returns).to(self.device)
        advantages = torch.FloatTensor(advantages).to(self.device)

        # Normalize advantages: always centre, and rescale only when the
        # spread is meaningful (std is undefined for a single sample).
        advantages = advantages - advantages.mean()
        if n > 1:
            spread = advantages.std()
            if spread >= 1e-8:
                advantages = advantages / (spread + 1e-8)

        # PPO update
        for _ in range(self.epochs):
            # Get current policy predictions
            action_logits, values = self.model(states)
            action_probs = torch.softmax(action_logits, dim=-1)
            dist = torch.distributions.Categorical(action_probs)
            new_log_probs = dist.log_prob(actions)
            entropy = dist.entropy()

            # Probability ratio between the current and the sampling policy
            ratios = torch.exp(new_log_probs - old_log_probs)

            # Clipped surrogate objective, value regression and entropy bonus
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.epsilon, 1 + self.epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()
            # squeeze(-1) keeps the batch dimension even for a single sample.
            critic_loss = 0.5 * (returns - values.squeeze(-1)).pow(2).mean()
            entropy_loss = -0.01 * entropy.mean()

            loss = actor_loss + critic_loss + entropy_loss

            # Update
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
            self.optimizer.step()

        # Drop the transitions that were just trained on.
        for buffer in (self.states, self.actions, self.log_probs,
                       self.rewards, self.values, self.dones):
            del buffer[:n]
|
||||
|
||||
|
||||
# ===================== GLOBAL PPO AGENT =====================
|
||||
|
||||
# Yaw (in degrees) most recently sent to the bot; overwritten by every /predict call.
current_yaw = 0.0

# Single process-wide agent used by the request handlers. The dimensions match
# the 15-entry state vector and the 8-row action table defined on PPOAgent.
ppo_agent = PPOAgent(state_dim=15, action_dim=8)
|
||||
|
||||
|
||||
@app.post("/predict")
def predict():
    """Choose the bot's next control inputs for the posted game state.

    The request body must be a JSON object containing ``bot_pos`` and
    ``target_pos`` (each with ``x``/``y``/``z``) and ``distance``; further
    optional fields are read by ``PPOAgent.preprocess_state``. The bot is
    aimed at the target and the sampled action's yaw offset is added on top.

    Returns:
        JSON with ``yaw``, ``pitch``, ``move_forward``, ``move_strafe``,
        ``jump``, ``sprint``, ``crouch`` and ``swing``; or a 400 response with
        an ``error`` field when the body is not a usable state object.
    """
    global current_yaw

    # silent=True yields None instead of raising on a missing or invalid body.
    state = request.get_json(silent=True)
    if not isinstance(state, dict):
        return jsonify({"error": "expected a JSON object"}), 400

    try:
        bot_pos = state["bot_pos"]
        target_pos = state["target_pos"]

        # Horizontal offset from bot to target, used to aim at the target.
        dx = target_pos["x"] - bot_pos["x"]
        dz = target_pos["z"] - bot_pos["z"]

        # Use PPO to select action (also records the transition for training).
        action_idx = ppo_agent.select_action(state)
    except (KeyError, TypeError) as exc:
        return jsonify({"error": f"malformed state: {exc!r}"}), 400

    target_yaw = math.degrees(math.atan2(-dx, dz))
    action = ppo_agent.action_mapping[action_idx]

    # Apply action
    yaw_delta, pitch_delta, forward, strafe, jump, sprint, crouch, swing = action
    current_yaw = target_yaw + yaw_delta

    print(f"[DEBUG] Predict: ActionIdx={action_idx}, Yaw={current_yaw:.1f}, Pitch={pitch_delta}, FWD={forward}, STR={strafe}, Jump={jump}, Sprint={sprint}, Swing={swing}")

    return jsonify({
        "yaw": current_yaw,
        "pitch": pitch_delta,
        "move_forward": forward,
        "move_strafe": strafe,
        "jump": bool(jump),
        "sprint": bool(sprint),
        "crouch": bool(crouch),
        "swing": bool(swing)
    })
|
||||
|
||||
|
||||
@app.post("/reward")
def receive_reward():
    """Receive reward data from Java, store it and train once enough has accumulated.

    The request body must be a JSON object. When it carries ``total_reward``
    that value is used directly; otherwise the reward is recomputed here from
    the individual fields (kept for backward compatibility). ``done`` marks
    the end of an episode.

    Returns:
        JSON with ``status``, ``total_reward`` and ``model_trained`` (True only
        when this request actually triggered a training update); or a 400
        response with an ``error`` field when the body is unusable.
    """
    # silent=True yields None instead of raising on a missing or invalid body.
    reward_data = request.get_json(silent=True)
    if not isinstance(reward_data, dict):
        return jsonify({"error": "expected a JSON object"}), 400

    try:
        if "total_reward" in reward_data:
            # Use the total reward calculated in Java if provided
            total_reward = float(reward_data["total_reward"])
        else:
            # Fallback to Python calculation if Java hasn't sent it (for backward compatibility)
            bot_health = reward_data.get("bot_health", 20.0)
            target_health = reward_data.get("target_health", 20.0)
            damage_dealt = reward_data.get("damage_dealt", 0.0)
            damage_taken = reward_data.get("damage_taken", 0.0)
            distance = reward_data.get("distance", 0.0)
            hit_success = reward_data.get("hit_success", False)

            total_reward = 0.0
            total_reward += damage_dealt * 10.0
            total_reward -= damage_taken * 5.0
            total_reward += 5.0 if hit_success else 0.0
            total_reward -= distance * 0.1 if distance > 4 else 0.0
            total_reward += 1.0 if 1.5 < distance < 3.5 else 0.0
            total_reward += 100.0 if target_health <= 0 else 0.0
            total_reward -= 100.0 if bot_health <= 0 else 0.0
    except (TypeError, ValueError) as exc:
        return jsonify({"error": f"malformed reward payload: {exc}"}), 400

    done = bool(reward_data.get("done", False))

    # Store reward and train PPO
    ppo_agent.store_reward(total_reward, done)

    # Periodic auto-save every 5 minutes
    time_since_last_save = time.time() - ppo_agent.last_save_time
    if time_since_last_save > SAVE_INTERVAL:
        ppo_agent.save_model()
    else:
        # Log every minute or so how much time is left for next save
        if int(time_since_last_save) % 60 == 0:
            print(f"[DEBUG] Next auto-save in {int(SAVE_INTERVAL - time_since_last_save)} seconds")

    print(f"[DEBUG] Reward: Total={total_reward:.4f}, Done={done}")

    # Train every batch. train() may decline to run, so an update is detected
    # by the reward buffer shrinking rather than assumed.
    model_trained = False
    buffered_before = len(ppo_agent.rewards)
    if buffered_before >= ppo_agent.batch_size:
        ppo_agent.train()
        model_trained = len(ppo_agent.rewards) < buffered_before
        if model_trained:
            print(f"PPO Model trained. Buffer cleared. Last total reward: {total_reward:.2f}")

    return jsonify({
        "status": "reward_received",
        "total_reward": total_reward,
        "model_trained": model_trained
    })
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The reloader is switched off so the module, and with it the global agent
    # and its buffers, is initialised only once.
    # NOTE(review): debug=True also enables Werkzeug's interactive debugger;
    # confirm this server is never reachable from untrusted hosts.
    run_options = {"port": 5000, "debug": True, "use_reloader": False}
    app.run(**run_options)
|
||||
30
src/main/resources/fabric.mod.json
Normal file
30
src/main/resources/fabric.mod.json
Normal file
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"schemaVersion": 1,
|
||||
"id": "goated",
|
||||
"version": "${version}",
|
||||
"name": "goated",
|
||||
"description": "",
|
||||
"authors": [],
|
||||
"contact": {},
|
||||
"license": "All-Rights-Reserved",
|
||||
"icon": "assets/goated/icon.png",
|
||||
"environment": "*",
|
||||
|
||||
"entrypoints": {
|
||||
"main": [
|
||||
"org.pvpbot.goated.Goated"
|
||||
]
|
||||
},
|
||||
|
||||
"mixins": [
|
||||
{
|
||||
"config": "goated.mixins.json"
|
||||
}
|
||||
],
|
||||
|
||||
"depends": {
|
||||
"fabricloader": ">=0.19.1",
|
||||
"fabric-api": "*",
|
||||
"minecraft": "1.21.4"
|
||||
}
|
||||
}
|
||||
15
src/main/resources/goated.mixins.json
Normal file
15
src/main/resources/goated.mixins.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"required": true,
|
||||
"minVersion": "0.8",
|
||||
"package": "org.pvpbot.goated.mixin",
|
||||
"compatibilityLevel": "JAVA_21",
|
||||
"mixins": [
|
||||
"ServerTickMixin"
|
||||
],
|
||||
"injectors": {
|
||||
"defaultRequire": 1
|
||||
},
|
||||
"overwrites": {
|
||||
"requireAnnotations": true
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user