0% found this document useful (0 votes)
2 views · 36 pages

java code

The document contains Java code for a stock portfolio management system using Spring framework, including controllers for dashboard, portfolio, and predictions. It defines entities for Portfolio, StockData, and User, along with data transfer objects (DTOs) for handling requests and responses. The system allows users to view their portfolio, add or remove stocks, and predict stock prices using LSTM models.

Uploaded by

Shiva Acharya
Copyright
© All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as PDF, TXT or read online on Scribd
0% found this document useful (0 votes)
2 views · 36 pages

java code

The document contains Java code for a stock portfolio management system using Spring framework, including controllers for dashboard, portfolio, and predictions. It defines entities for Portfolio, StockData, and User, along with data transfer objects (DTOs) for handling requests and responses. The system allows users to view their portfolio, add or remove stocks, and predict stock prices using LSTM models.

Uploaded by

Shiva Acharya
Copyright
© All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as PDF, TXT or read online on Scribd
You are on page 1/36

1. Controller
package com.nepse.controller;

import com.nepse.dto.PortfolioSummary;
import com.nepse.dto.PredictionResult;
import com.nepse.service.LstmPredictionService;
import com.nepse.service.PortfolioService;
import com.nepse.service.StockDataService;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;

/**
 * Web controller for the {@code /dashboard} page: a stock prediction, the user's portfolio
 * summary and the day's top gainers/losers.
 *
 * <p>All requests act on a single hard-coded user until authentication is wired in.
 */
@Controller
@RequestMapping("/dashboard")
public class DashboardController {

    /** Placeholder user ID used by every handler (no session/authentication yet). */
    private static final Long DEMO_USER_ID = 1L;

    /** Symbol predicted when the dashboard is first opened. */
    private static final String DEFAULT_SYMBOL = "NABIL";

    /** Row count passed to {@code ensureMinimumData} before predicting the default symbol. */
    private static final int MIN_HISTORY_ROWS = 60;

    /** Number of entries shown in each of the top-gainers and top-losers lists. */
    private static final int MOVERS_LIMIT = 5;

    private final LstmPredictionService predictionService;
    private final PortfolioService portfolioService;
    private final StockDataService stockDataService;

    public DashboardController(LstmPredictionService predictionService,
                               PortfolioService portfolioService,
                               StockDataService stockDataService) {
        this.predictionService = predictionService;
        this.portfolioService = portfolioService;
        this.stockDataService = stockDataService;
    }

    /**
     * Renders the dashboard with a prediction for {@link #DEFAULT_SYMBOL}.
     *
     * <p>A prediction failure does not fail the page; it is shown via {@code predictionError}.
     */
    @GetMapping
    public String showDashboard(Model model) {
        try {
            // NOTE(review): presumably loads/creates history so at least MIN_HISTORY_ROWS exist —
            // confirm against StockDataService.ensureMinimumData.
            stockDataService.ensureMinimumData(DEFAULT_SYMBOL, MIN_HISTORY_ROWS);
            PredictionResult prediction = predictionService.predictStock(DEFAULT_SYMBOL);
            model.addAttribute("prediction", prediction);
        } catch (Exception e) {
            model.addAttribute("predictionError", "Error predicting stock: " + e.getMessage());
        }
        return renderDashboard(model);
    }

    /**
     * Renders the dashboard with a prediction for the submitted {@code symbol}.
     *
     * <p>A prediction failure does not fail the page; it is shown via {@code predictionError}.
     */
    @PostMapping("/predict")
    public String predictStock(@RequestParam String symbol,
                               Model model) {
        try {
            PredictionResult prediction = predictionService.predictStock(symbol);
            model.addAttribute("prediction", prediction);
        } catch (Exception e) {
            model.addAttribute("predictionError", "Error predicting stock: " + e.getMessage());
        }
        return renderDashboard(model);
    }

    /** Adds the portfolio and market-movers attributes shared by both handlers; returns the view name. */
    private String renderDashboard(Model model) {
        PortfolioSummary portfolio = portfolioService.getUserPortfolio(DEMO_USER_ID);
        model.addAttribute("portfolio", portfolio);
        model.addAttribute("topGainers", stockDataService.getTopGainers(MOVERS_LIMIT));
        model.addAttribute("topLosers", stockDataService.getTopLosers(MOVERS_LIMIT));
        return "dashboard";
    }
}
package com.nepse.controller;

import com.nepse.dto.PortfolioRequest;
import com.nepse.dto.PortfolioSummary;
import com.nepse.service.PortfolioService;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;

@Controller
@RequestMapping("/portfolio")
public class PortfolioController {

private final PortfolioService portfolioService;


public PortfolioController(PortfolioService portfolioService) {
this.portfolioService = portfolioService;
}

@GetMapping
public String showPortfolio(Model model) {
// Hardcode user ID or use session-based approach
Long userId = 1L; // Example user ID
PortfolioSummary portfolio = portfolioService.getUserPortfolio(userId);

model.addAttribute("portfolio", portfolio);
model.addAttribute("portfolioRequest", new PortfolioRequest());
return "portfolio";
}

@PostMapping("/add")
public String addToPortfolio(@ModelAttribute PortfolioRequest request,
Model model) {
Long userId = 1L; // Example user ID

try {
portfolioService.addStockToPortfolio(
userId,
request.getSymbol(),
request.getQuantity(),
request.getAveragePrice()
);
model.addAttribute("success", "Stock added to portfolio successfully");
} catch (Exception e) {
model.addAttribute("error", "Failed to add stock: " + e.getMessage());
}

return "redirect:/portfolio";
}

@PostMapping("/remove")
public String removeFromPortfolio(@ModelAttribute PortfolioRequest request,
Model model) {
Long userId = 1L; // Example user ID

try {
portfolioService.removeStockFromPortfolio(
userId,
request.getSymbol(),
request.getQuantity()
);
model.addAttribute("success", "Stock removed from portfolio successfully");
} catch (Exception e) {
model.addAttribute("error", "Failed to remove stock: " + e.getMessage());
}

return "redirect:/portfolio";
}
}

package com.nepse.controller;

import com.nepse.dto.PredictionResult;
import com.nepse.service.LstmPredictionService;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;

/**
 * Web controller for the {@code /prediction} page, which shows an LSTM price prediction for
 * one symbol.
 */
@Controller
@RequestMapping("/prediction")
public class PredictionController {

    /** Symbol predicted when the request has no {@code symbol} parameter, or a blank one. */
    private static final String DEFAULT_SYMBOL = "NEPSE";

    private final LstmPredictionService predictionService;

    public PredictionController(LstmPredictionService predictionService) {
        this.predictionService = predictionService;
    }

    /**
     * Renders the {@code prediction} view.
     *
     * <p>A failed prediction is reported through the {@code error} attribute instead of failing
     * the page.
     *
     * @param symbol optional stock symbol; a missing or blank value falls back to
     *               {@link #DEFAULT_SYMBOL}, and surrounding whitespace is stripped
     */
    @GetMapping
    public String showPrediction(@RequestParam(required = false) String symbol,
                                 Model model) {
        // "?symbol=" binds an empty string, not null; treat it like an absent parameter.
        String predictionSymbol = (symbol == null || symbol.isBlank()) ? DEFAULT_SYMBOL : symbol.strip();

        // Exposed before predicting so the view can name the symbol even when prediction fails.
        model.addAttribute("symbol", predictionSymbol);
        try {
            PredictionResult prediction = predictionService.predictStock(predictionSymbol);
            model.addAttribute("prediction", prediction);
        } catch (Exception e) {
            model.addAttribute("error", "Error generating prediction: " + e.getMessage());
        }
        return "prediction";
    }
}
2. Domain
package com.nepse.domain;

import jakarta.persistence.*;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.time.LocalDate;

@Entity
@Table(name = "portfolio")
@Data
@NoArgsConstructor
@AllArgsConstructor
public class Portfolio {

@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Long id;

@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "user_id", nullable = false)
private User user;

@Column(nullable = false)
private String symbol;

@Column(nullable = false)
private int quantity;

@Column(nullable = false)
private double averagePrice;

@Column(nullable = false)
private LocalDate purchaseDate = LocalDate.now();

@Column(nullable = false)
private String transactionType; // BUY or SELL
private String notes;

@Column(nullable = false)
private double totalInvestment;

// Additional fields for tracking


private double currentValue;
private double profitLoss;
private double profitLossPercentage;
}
package com.nepse.domain;

import jakarta.persistence.*;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.time.LocalDate;

/**
 * A price record for one stock symbol on one date, persisted in the {@code stock_data} table.
 *
 * <p>Lombok supplies accessors, equals/hashCode/toString and both constructors. The all-args
 * constructor takes its parameters in field declaration order, so keep the order below.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Entity
@Table(name = "stock_data")
public class StockData {

    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;

    @Column(nullable = false) private String symbol;
    @Column(nullable = false) private LocalDate date;

    // Open / close / high / low prices and traded volume.
    @Column(nullable = false) private double openingPrice;
    @Column(nullable = false) private double closingPrice;
    @Column(nullable = false) private double highPrice;
    @Column(nullable = false) private double lowPrice;
    @Column(nullable = false) private long volume;

    // Change versus the previous record; presumably an absolute amount and a percentage.
    @Column(nullable = false) private double changeAmount;
    @Column(nullable = false) private double changePercentage;

    // Technical indicators intended as extra LSTM model features. The suffixes presumably denote
    // period lengths (e.g. 5/20/50-period moving averages, 14-period RSI) — confirm where they
    // are computed.
    private double movingAverage5;
    private double movingAverage20;
    private double movingAverage50;
    private double rsi14;
    private double macd;
    private double bollingerUpper;
    private double bollingerLower;
}
package com.nepse.domain;

import jakarta.persistence.*;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDateTime;

@Entity
@Table(name = "users")
@Data
@NoArgsConstructor
@AllArgsConstructor
public class User {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Long id;

@Column(unique = true, nullable = false)


private String username;
@Column(nullable = false)
private String password; // Added back for authentication

private String fullName;


private String email;

@Column(nullable = false)
private LocalDateTime createdAt = LocalDateTime.now();
}
3. DTO
package com.nepse.dto;

import jakarta.validation.constraints.NotBlank;

/**
 * Credentials submitted for a login attempt.
 *
 * <p>Both fields carry {@code @NotBlank}; the constraints are enforced only when the object is
 * validated.
 */
public class LoginRequest {

    @NotBlank(message = "Username is required")
    private String username;

    @NotBlank(message = "Password is required")
    private String password;

    /** @return the submitted username, possibly {@code null} before binding */
    public String getUsername() { return username; }

    /** @param username the submitted username */
    public void setUsername(String username) { this.username = username; }

    /** @return the submitted password, possibly {@code null} before binding */
    public String getPassword() { return password; }

    /** @param password the submitted password */
    public void setPassword(String password) { this.password = password; }
}
package com.nepse.dto;

import java.math.BigDecimal;

/**
 * One holding with its derived valuation figures.
 *
 * <p>{@code investmentValue}, {@code currentValue}, {@code profitLoss} and
 * {@code profitLossPercentage} are recomputed whenever the quantity or either price changes. A
 * derived value is {@code null} while any price it depends on is still unset. For example, after
 * the no-arg constructor and before both prices are assigned.
 */
public class PortfolioItem {

    private static final BigDecimal ONE_HUNDRED = BigDecimal.valueOf(100);

    private String symbol;
    private int quantity;
    private BigDecimal averagePrice;
    private BigDecimal currentPrice;
    private BigDecimal investmentValue;
    private BigDecimal currentValue;
    private BigDecimal profitLoss;
    private BigDecimal profitLossPercentage;

    /** Creates an empty item; all values are {@code null}/0 until set. */
    public PortfolioItem() {
    }

    /**
     * Creates an item and computes its derived values.
     *
     * @param symbol       stock symbol
     * @param quantity     number of shares held
     * @param averagePrice average cost per share (may be {@code null})
     * @param currentPrice latest price per share (may be {@code null})
     */
    public PortfolioItem(String symbol, int quantity, BigDecimal averagePrice, BigDecimal currentPrice) {
        this.symbol = symbol;
        this.quantity = quantity;
        this.averagePrice = averagePrice;
        this.currentPrice = currentPrice;
        calculateValues();
    }

    /**
     * Recomputes the derived fields from quantity and prices.
     *
     * <p>Null-safe, because setters call this on a partially initialised object. The percentage
     * keeps 4 decimal places of the ratio (HALF_UP) before scaling by 100, and is zero when
     * nothing was invested.
     */
    private void calculateValues() {
        BigDecimal shares = BigDecimal.valueOf(quantity);
        investmentValue = averagePrice != null ? averagePrice.multiply(shares) : null;
        currentValue = currentPrice != null ? currentPrice.multiply(shares) : null;

        if (investmentValue == null || currentValue == null) {
            profitLoss = null;
            profitLossPercentage = null;
            return;
        }

        profitLoss = currentValue.subtract(investmentValue);
        profitLossPercentage = investmentValue.signum() != 0
                ? profitLoss.divide(investmentValue, 4, java.math.RoundingMode.HALF_UP).multiply(ONE_HUNDRED)
                : BigDecimal.ZERO;
    }

    public String getSymbol() {
        return symbol;
    }

    public void setSymbol(String symbol) {
        this.symbol = symbol;
    }

    public int getQuantity() {
        return quantity;
    }

    /** Sets the share count and recomputes the derived values. */
    public void setQuantity(int quantity) {
        this.quantity = quantity;
        calculateValues();
    }

    public BigDecimal getAveragePrice() {
        return averagePrice;
    }

    /** Sets the average cost per share and recomputes the derived values. */
    public void setAveragePrice(BigDecimal averagePrice) {
        this.averagePrice = averagePrice;
        calculateValues();
    }

    public BigDecimal getCurrentPrice() {
        return currentPrice;
    }

    /** Sets the latest price per share and recomputes the derived values. */
    public void setCurrentPrice(BigDecimal currentPrice) {
        this.currentPrice = currentPrice;
        calculateValues();
    }

    /** @return averagePrice × quantity, or {@code null} if averagePrice is unset */
    public BigDecimal getInvestmentValue() {
        return investmentValue;
    }

    /** @return currentPrice × quantity, or {@code null} if currentPrice is unset */
    public BigDecimal getCurrentValue() {
        return currentValue;
    }

    /** @return currentValue − investmentValue, or {@code null} if either is unset */
    public BigDecimal getProfitLoss() {
        return profitLoss;
    }

    /** @return profit/loss as a percentage of the investment, or {@code null} if either price is unset */
    public BigDecimal getProfitLossPercentage() {
        return profitLossPercentage;
    }
}
package com.nepse.dto;

import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Positive;

/**
 * Form-backing object for the portfolio add/remove forms: a stock symbol, a share quantity and a
 * per-share price.
 *
 * <p>The Jakarta Bean Validation constraints apply only when the object is validated.
 */
public class PortfolioRequest {

    @NotBlank(message = "Stock symbol is required")
    private String symbol;

    @NotNull(message = "Quantity is required")
    @Min(value = 1, message = "Quantity must be at least 1")
    private Integer quantity;

    @NotNull(message = "Price is required")
    @Positive(message = "Price must be positive")
    private Double averagePrice;

    /** @return the submitted stock symbol */
    public String getSymbol() { return symbol; }

    /** @param symbol the submitted stock symbol */
    public void setSymbol(String symbol) { this.symbol = symbol; }

    /** @return the submitted share count, or {@code null} if the field was left empty */
    public Integer getQuantity() { return quantity; }

    /** @param quantity the submitted share count */
    public void setQuantity(Integer quantity) { this.quantity = quantity; }

    /** @return the submitted per-share price, or {@code null} if the field was left empty */
    public Double getAveragePrice() { return averagePrice; }

    /** @param averagePrice the submitted per-share price */
    public void setAveragePrice(Double averagePrice) { this.averagePrice = averagePrice; }
}
package com.nepse.dto;
import java.math.BigDecimal;
import java.util.List;

/**
 * Immutable snapshot of a user's portfolio: aggregate totals plus one entry per held symbol.
 *
 * <p>Gain figures are in price units; percentages are already multiplied by 100.
 */
public class PortfolioSummary {

    private final double totalValue;
    private final double todaysGain;
    private final double overallGain;
    private final double gainPercentage;
    private final List<PortfolioItem> items;

    /**
     * @param totalValue     market value of all holdings
     * @param todaysGain     change in value since the previous close
     * @param overallGain    totalValue minus the invested cost
     * @param gainPercentage overallGain as a percentage of the invested cost
     * @param items          per-symbol entries. The list is copied, and {@code null} is treated
     *                       as empty. Elements must be non-null.
     */
    public PortfolioSummary(double totalValue, double todaysGain,
                            double overallGain, double gainPercentage,
                            List<PortfolioItem> items) {
        this.totalValue = totalValue;
        this.todaysGain = todaysGain;
        this.overallGain = overallGain;
        this.gainPercentage = gainPercentage;
        // Snapshot the caller's list. The fields are final, but holding a shared mutable list
        // would still let outside code change the summary after construction.
        this.items = items == null ? List.of() : List.copyOf(items);
    }

    public double getTotalValue() { return totalValue; }
    public double getTodaysGain() { return todaysGain; }
    public double getOverallGain() { return overallGain; }
    public double getGainPercentage() { return gainPercentage; }

    /** @return an unmodifiable list of the per-symbol entries */
    public List<PortfolioItem> getItems() { return items; }

    /** Immutable valuation of one symbol within a {@link PortfolioSummary}. */
    public static class PortfolioItem {
        private final String symbol;
        private final int quantity;
        private final double averagePrice;
        private final double currentPrice;
        private final double gain;
        private final double gainPercentage;

        /**
         * @param symbol         stock symbol
         * @param quantity       shares currently held
         * @param averagePrice   average cost per share
         * @param currentPrice   latest price per share
         * @param gain           unrealised gain on the held shares
         * @param gainPercentage gain as a percentage of the held shares' cost
         */
        public PortfolioItem(String symbol, int quantity,
                             double averagePrice, double currentPrice,
                             double gain, double gainPercentage) {
            this.symbol = symbol;
            this.quantity = quantity;
            this.averagePrice = averagePrice;
            this.currentPrice = currentPrice;
            this.gain = gain;
            this.gainPercentage = gainPercentage;
        }

        public String getSymbol() { return symbol; }
        public int getQuantity() { return quantity; }
        public double getAveragePrice() { return averagePrice; }
        public double getCurrentPrice() { return currentPrice; }
        public double getGain() { return gain; }
        public double getGainPercentage() { return gainPercentage; }
    }
}
package com.nepse.dto;

import java.time.LocalDate;

/**
 * Outcome of a price prediction for one stock symbol.
 */
public class PredictionResult {

    private String symbol;
    private double currentPrice;
    private double predictedPrice;
    private String trend; // UP/DOWN
    private double confidence; // 0-100
    private double potentialGain; // percentage
    private LocalDate predictionDate;
    private LocalDate targetDate;

    /**
     * Creates a fully populated result.
     *
     * <p>The previous version of this constructor ignored every argument, so every result came
     * back with a null symbol, zero prices and null dates.
     *
     * @param symbol         stock symbol the prediction is for
     * @param currentPrice   latest known price
     * @param predictedPrice predicted price at {@code targetDate}
     * @param trend          {@code "UP"} or {@code "DOWN"}
     * @param confidence     confidence score, 0-100
     * @param potentialGain  predicted change as a percentage of {@code currentPrice}
     * @param predictionDate date the prediction was made
     * @param targetDate     date the predicted price refers to
     */
    public PredictionResult(String symbol, double currentPrice, double predictedPrice, String trend,
                            double confidence, double potentialGain, LocalDate predictionDate,
                            LocalDate targetDate) {
        this.symbol = symbol;
        this.currentPrice = currentPrice;
        this.predictedPrice = predictedPrice;
        this.trend = trend;
        this.confidence = confidence;
        this.potentialGain = potentialGain;
        this.predictionDate = predictionDate;
        this.targetDate = targetDate;
    }

    public String getSymbol() {
        return symbol;
    }

    public void setSymbol(String symbol) {
        this.symbol = symbol;
    }

    public double getCurrentPrice() {
        return currentPrice;
    }

    public void setCurrentPrice(double currentPrice) {
        this.currentPrice = currentPrice;
    }

    public double getPredictedPrice() {
        return predictedPrice;
    }

    public void setPredictedPrice(double predictedPrice) {
        this.predictedPrice = predictedPrice;
    }

    public String getTrend() {
        return trend;
    }

    public void setTrend(String trend) {
        this.trend = trend;
    }

    public double getConfidence() {
        return confidence;
    }

    public void setConfidence(double confidence) {
        this.confidence = confidence;
    }

    public double getPotentialGain() {
        return potentialGain;
    }

    public void setPotentialGain(double potentialGain) {
        this.potentialGain = potentialGain;
    }

    public LocalDate getPredictionDate() {
        return predictionDate;
    }

    public void setPredictionDate(LocalDate predictionDate) {
        this.predictionDate = predictionDate;
    }

    public LocalDate getTargetDate() {
        return targetDate;
    }

    public void setTargetDate(LocalDate targetDate) {
        this.targetDate = targetDate;
    }
}
package com.nepse.dto;

import jakarta.validation.constraints.Email;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Size;

/**
 * Sign-up data for a new account: username, password, full name and email.
 *
 * <p>Presumably bound from a registration form. The Jakarta Bean Validation constraints apply
 * only when the object is validated.
 */
public class RegisterRequest {

    @NotBlank(message = "Username is required")
    @Size(min = 3, max = 20, message = "Username must be between 3 and 20 characters")
    private String username;

    @NotBlank(message = "Password is required")
    @Size(min = 6, max = 40, message = "Password must be between 6 and 40 characters")
    private String password;

    @NotBlank(message = "Full name is required")
    private String fullName;

    @NotBlank(message = "Email is required")
    @Email(message = "Email should be valid")
    private String email;

    /** @return the requested username */
    public String getUsername() { return username; }

    /** @param username the requested username */
    public void setUsername(String username) { this.username = username; }

    /** @return the chosen password */
    public String getPassword() { return password; }

    /** @param password the chosen password */
    public void setPassword(String password) { this.password = password; }

    /** @return the user's full name */
    public String getFullName() { return fullName; }

    /** @param fullName the user's full name */
    public void setFullName(String fullName) { this.fullName = fullName; }

    /** @return the user's email address */
    public String getEmail() { return email; }

    /** @param email the user's email address */
    public void setEmail(String email) { this.email = email; }
}
4. Exception
package com.nepse.exception;

import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.ResponseStatus;

/**
 * Unchecked exception signalling that a client request was invalid.
 *
 * <p>Annotated with {@code @ResponseStatus(BAD_REQUEST)}, so Spring MVC answers with HTTP 400
 * when no exception handler intercepts it first.
 */
@ResponseStatus(HttpStatus.BAD_REQUEST)
public class BadRequestException extends RuntimeException {

    /**
     * @param message description of what was wrong with the request
     */
    public BadRequestException(String message) {
        super(message);
    }

    /**
     * @param message description of what was wrong with the request
     * @param cause   the underlying failure
     */
    public BadRequestException(String message, Throwable cause) {
        super(message, cause);
    }
}
package com.nepse.exception;

/**
 * JSON body returned by {@code GlobalExceptionHandler} for error responses.
 *
 * @param status    HTTP status code
 * @param message   human-readable error description
 * @param timestamp creation time in epoch milliseconds
 */
public record ErrorResponse(int status, String message, long timestamp) {
}
package com.nepse.exception;

import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.context.request.WebRequest;

@ControllerAdvice
public class GlobalExceptionHandler {

@ExceptionHandler(BadRequestException.class)
public ResponseEntity<ErrorResponse> handleBadRequest(
BadRequestException ex, WebRequest request) {
ErrorResponse response = new ErrorResponse(
HttpStatus.BAD_REQUEST.value(),
ex.getMessage(),
System.currentTimeMillis());
return ResponseEntity.badRequest().body(response);
}

@ExceptionHandler(ResourceNotFoundException.class)
public ResponseEntity<ErrorResponse> handleResourceNotFound(
ResourceNotFoundException ex, WebRequest request) {
ErrorResponse response = new ErrorResponse(
HttpStatus.NOT_FOUND.value(),
ex.getMessage(),
System.currentTimeMillis());
return ResponseEntity.status(HttpStatus.NOT_FOUND).body(response);
}

// Add generic exception handler (optional but recommended)


@ExceptionHandler(Exception.class)
public ResponseEntity<ErrorResponse> handleGlobalException(
Exception ex, WebRequest request) {
ErrorResponse response = new ErrorResponse(
HttpStatus.INTERNAL_SERVER_ERROR.value(),
"An unexpected error occurred",
System.currentTimeMillis());
return ResponseEntity.internalServerError().body(response);
}
}
package com.nepse.exception;
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.ResponseStatus;

/**
 * Unchecked exception signalling that a looked-up entity does not exist.
 *
 * <p>Annotated with {@code @ResponseStatus(NOT_FOUND)}, so Spring MVC answers with HTTP 404
 * when no exception handler intercepts it first. The lookup details are kept as fields for
 * handlers that want them.
 */
@ResponseStatus(HttpStatus.NOT_FOUND)
public class ResourceNotFoundException extends RuntimeException {

    private final String resourceName;
    private final String fieldName;
    private final Object fieldValue;

    /**
     * Builds the message {@code "<resourceName> not found with <fieldName> : '<fieldValue>'"}.
     *
     * @param resourceName kind of entity that was looked up, e.g. {@code "User"}
     * @param fieldName    attribute used for the lookup, e.g. {@code "id"}
     * @param fieldValue   value that matched nothing
     */
    public ResourceNotFoundException(String resourceName, String fieldName, Object fieldValue) {
        super("%s not found with %s : '%s'".formatted(resourceName, fieldName, fieldValue));
        this.resourceName = resourceName;
        this.fieldName = fieldName;
        this.fieldValue = fieldValue;
    }

    public String getResourceName() { return resourceName; }

    public String getFieldName() { return fieldName; }

    public Object getFieldValue() { return fieldValue; }
}
5. Init
package com.nepse.init;
import com.nepse.service.StockDataService;
import jakarta.annotation.PostConstruct;
import org.springframework.stereotype.Component;
import java.io.IOException;

@Component
public class DataLoader {

private final StockDataService stockDataService;

public DataLoader(StockDataService stockDataService) {


this.stockDataService = stockDataService;
}
@PostConstruct
public void importData() {
String filePath = "src/main/resources/static/data/NABIL.csv"; // path to your CSV file
String symbol = "NABIL"; // the stock symbol

try {
stockDataService.importStockDataFromCsv(filePath, symbol);
System.out.println("Stock data imported successfully.");
} catch (IOException e) {
e.printStackTrace();
}
}
}
Repository
package com.nepse.repository;

import com.nepse.domain.Portfolio;
import com.nepse.domain.User;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.stereotype.Repository;

import java.util.List;

@Repository
public interface PortfolioRepository extends JpaRepository<Portfolio, Long> {

List<Portfolio> findByUser(User user);

List<Portfolio> findByUserAndSymbol(User user, String symbol);

@Query("SELECT p FROM Portfolio p WHERE p.user = :user GROUP BY p.symbol")


List<Portfolio> findDistinctByUser(User user);

@Query("SELECT p.symbol FROM Portfolio p WHERE p.user = :user GROUP BY p.symbol")


List<String> findDistinctSymbolsByUser(User user);

@Query("SELECT SUM(p.quantity) FROM Portfolio p WHERE p.user = :user AND p.symbol =


:symbol AND p.transactionType = 'BUY'")
Integer sumBoughtQuantityByUserAndSymbol(User user, String symbol);

@Query("SELECT SUM(p.quantity) FROM Portfolio p WHERE p.user = :user AND p.symbol =


:symbol AND p.transactionType = 'SELL'")
Integer sumSoldQuantityByUserAndSymbol(User user, String symbol);

@Query("SELECT COALESCE(SUM(p.quantity * p.averagePrice), 0) FROM Portfolio p WHERE


p.user = :user AND p.symbol = :symbol AND p.transactionType = 'BUY'")
Double sumInvestmentByUserAndSymbol(User user, String symbol);
}
package com.nepse.repository;

import com.nepse.domain.StockData;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Modifying;
import org.springframework.data.jpa.repository.Query;
import org.springframework.stereotype.Repository;

import java.time.LocalDate;
import java.util.List;

@Repository
public interface StockDataRepository extends JpaRepository<StockData, Long> {

List<StockData> findBySymbolOrderByDateDesc(String symbol);

@Query("SELECT s FROM StockData s WHERE s.symbol = :symbol ORDER BY s.date DESC


LIMIT :limit")
List<StockData> findTopNBySymbolOrderByDateDesc(String symbol, int limit);

StockData findTopBySymbolOrderByDateDesc(String symbol);

StockData findBySymbolAndDate(String symbol, LocalDate date);

@Query("SELECT DISTINCT s.symbol FROM StockData s")


List<String> findAllDistinctSymbols();

@Query("SELECT s FROM StockData s WHERE s.date = (SELECT MAX(s2.date) FROM


StockData s2) ORDER BY s.changePercentage DESC LIMIT :count")
List<StockData> getTopGainers(int count);

@Query("SELECT s FROM StockData s WHERE s.date = (SELECT MAX(s2.date) FROM


StockData s2) ORDER BY s.changePercentage ASC LIMIT :count")
List<StockData> getTopLosers(int count);

List<StockData> findTop60BySymbolOrderByDateDesc(String symbol);

long countBySymbol(String symbol);


@Modifying
@Query("DELETE FROM StockData s WHERE s.symbol = :symbol")
void deleteBySymbol(String symbol);

List<StockData> findBySymbol(String symbol);

}
package com.nepse.repository;

import com.nepse.domain.User;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;

import java.util.Optional;

/**
 * Spring Data repository for {@link User} accounts.
 *
 * <p>Query methods are derived from their names, except {@link #findByUsernameOrEmail}, which
 * uses explicit JPQL.
 */
@Repository
public interface UserRepository extends JpaRepository<User, Long> {

    // --- lookups ---

    Optional<User> findByUsername(String username);

    Optional<User> findByEmail(String email);

    /**
     * Finds a user whose username equals {@code username} OR whose email equals {@code email}.
     *
     * <p>NOTE(review): if the username matches one user and the email another, this
     * single-result query throws IncorrectResultSizeDataAccessException instead of returning.
     */
    @Query("SELECT u FROM User u WHERE u.username = :username OR u.email = :email")
    Optional<User> findByUsernameOrEmail(@Param("username") String username,
                                         @Param("email") String email);

    // --- existence checks ---

    Boolean existsByUsername(String username);

    Boolean existsByEmail(String email);
}

6. Service
package com.nepse.service;

import com.nepse.domain.StockData;
import com.nepse.dto.PredictionResult;
import com.nepse.repository.StockDataRepository;
import com.nepse.util.ModelUtils;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.io.File;
import java.io.IOException;
import java.time.LocalDate;
import java.util.Collections;
import java.util.List;

/**
 * Predicts a stock's next closing price with a per-symbol LSTM model.
 *
 * <p>Each model is loaded from {@code <lstm.model.directory>/<symbol>.zip}.
 */
@Service
public class LstmPredictionService {

    /**
     * Number of most recent closing prices fed to the model as one input sequence. It must match
     * {@code findTop60BySymbolOrderByDateDesc}.
     */
    private static final int WINDOW_SIZE = 60;

    /** Days between the prediction date and the target date reported in the result. */
    private static final int TARGET_HORIZON_DAYS = 7;

    private final StockDataRepository stockDataRepository;

    @Value("${lstm.model.directory}")
    private String modelDirectory;

    public LstmPredictionService(StockDataRepository stockDataRepository) {
        this.stockDataRepository = stockDataRepository;
    }

    /**
     * Predicts the closing price of {@code symbol} from its last 60 closing prices.
     *
     * @param symbol stock symbol; also names the model file
     * @return the prediction, compared against the MOST RECENT closing price
     * @throws IllegalArgumentException if fewer than 60 rows of history exist
     * @throws RuntimeException         if the model file cannot be loaded
     */
    public PredictionResult predictStock(String symbol) {
        List<StockData> historicalData = stockDataRepository.findTop60BySymbolOrderByDateDesc(symbol);
        if (historicalData.size() < WINDOW_SIZE) {
            throw new IllegalArgumentException("Not enough historical data for prediction");
        }
        // The repository returns newest-first; reverse into chronological (oldest -> newest) order.
        // From here on the most recent close is the LAST element, not the first.
        Collections.reverse(historicalData);

        var closeStats = historicalData.stream()
                .mapToDouble(StockData::getClosingPrice)
                .summaryStatistics();
        double min = closeStats.getMin();
        double max = closeStats.getMax();

        double[] normalizedData = normalizeData(historicalData, min, max);
        MultiLayerNetwork model = loadModel(symbol);

        // NOTE(review): DL4J recurrent layers default to [miniBatch, features, timeSteps] (NCW),
        // whereas this feeds [1, 60, 1]. Confirm it matches how the saved model was trained.
        INDArray input = Nd4j.create(normalizedData, new int[]{1, WINDOW_SIZE, 1});
        INDArray output = model.output(input);
        double predictedValue = output.getDouble(0);

        // Undo the min-max scaling applied to the inputs.
        double predictedPrice = predictedValue * (max - min) + min;

        double currentPrice = historicalData.get(historicalData.size() - 1).getClosingPrice();
        double confidence = calculateConfidence(historicalData, predictedPrice);
        // Guard against a zero price producing an infinite/NaN percentage.
        double potentialGain = currentPrice != 0
                ? ((predictedPrice - currentPrice) / currentPrice) * 100
                : 0;

        LocalDate today = LocalDate.now();
        return new PredictionResult(
                symbol,
                currentPrice,
                predictedPrice,
                predictedPrice > currentPrice ? "UP" : "DOWN",
                confidence,
                potentialGain,
                today,
                today.plusDays(TARGET_HORIZON_DAYS)
        );
    }

    /**
     * Min-max scales each closing price into [0, 1] using the given bounds.
     *
     * <p>A flat window ({@code max == min}) maps to all zeros instead of dividing by zero and
     * feeding NaN to the model.
     */
    private double[] normalizeData(List<StockData> data, double min, double max) {
        double range = max - min;
        if (range == 0) {
            return new double[data.size()];
        }
        return data.stream()
                .mapToDouble(d -> (d.getClosingPrice() - min) / range)
                .toArray();
    }

    /** Loads the serialized network for {@code symbol}, wrapping I/O failures as unchecked. */
    private MultiLayerNetwork loadModel(String symbol) {
        try {
            return ModelUtils.loadModel(new File(modelDirectory, symbol + ".zip"));
        } catch (IOException e) {
            throw new RuntimeException("Failed to load model for symbol: " + symbol, e);
        }
    }

    /**
     * Heuristic confidence score in [0, 100] based on recent volatility.
     *
     * <p>The score is 100 when the prediction equals the latest close. It falls as the predicted
     * move grows relative to three times the average day-to-day move, and is clamped at 0.
     *
     * @param historicalData closes in chronological order (most recent last)
     * @param predictedPrice the de-normalized model prediction
     */
    private double calculateConfidence(List<StockData> historicalData, double predictedPrice) {
        int n = historicalData.size();
        if (n < 2) {
            return 0;
        }

        double totalMove = 0;
        for (int i = 1; i < n; i++) {
            totalMove += Math.abs(historicalData.get(i).getClosingPrice()
                    - historicalData.get(i - 1).getClosingPrice());
        }
        double avgChange = totalMove / (n - 1);

        double latestClose = historicalData.get(n - 1).getClosingPrice();
        double diff = Math.abs(predictedPrice - latestClose);

        if (avgChange == 0) {
            // Flat history: the formula below would divide by zero. Use its limiting values instead.
            return diff == 0 ? 100 : 0;
        }

        double score = 80 + (20 * (1 - (diff / (avgChange * 3))));
        // PredictionResult documents confidence as 0-100; the raw score goes negative for large moves.
        return Math.max(0, Math.min(100, score));
    }
}
package com.nepse.service;

import com.nepse.domain.Portfolio;
import com.nepse.domain.StockData;
import com.nepse.domain.User;
import com.nepse.dto.PortfolioRequest;
import com.nepse.dto.PortfolioSummary;
import com.nepse.exception.ResourceNotFoundException;
import com.nepse.repository.PortfolioRepository;
import com.nepse.repository.UserRepository;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import java.util.List;
import java.util.stream.Collectors;

@Service
public class PortfolioService {

private final PortfolioRepository portfolioRepository;


private final UserRepository userRepository;
private final StockDataService stockDataService;

public PortfolioService(PortfolioRepository portfolioRepository,


UserRepository userRepository,
StockDataService stockDataService) {
this.portfolioRepository = portfolioRepository;
this.userRepository = userRepository;
this.stockDataService = stockDataService;
}
@Transactional
public void addStockToPortfolio(Long userId, String symbol, int quantity, double averagePrice)
{
User user = userRepository.findById(userId)
.orElseThrow(() -> new ResourceNotFoundException("User", "id", userId));

Portfolio portfolio = new Portfolio();


portfolio.setUser(user);
portfolio.setSymbol(symbol);
portfolio.setQuantity(quantity);
portfolio.setAveragePrice(averagePrice);
portfolio.setTransactionType("BUY");
portfolio.setTotalInvestment(quantity * averagePrice);

portfolioRepository.save(portfolio);
}

@Transactional
public void removeStockFromPortfolio(Long userId, String symbol, int quantity) {
User user = userRepository.findById(userId)
.orElseThrow(() -> new ResourceNotFoundException("User", "id", userId));

int currentQuantity = getAvailableQuantity(user, symbol);


if (currentQuantity < quantity) {
throw new IllegalArgumentException("Not enough shares to sell");
}

Portfolio portfolio = new Portfolio();


portfolio.setUser(user);
portfolio.setSymbol(symbol);
portfolio.setQuantity(quantity);
portfolio.setAveragePrice(stockDataService.getCurrentPrice(symbol));
portfolio.setTransactionType("SELL");
portfolio.setTotalInvestment(quantity * portfolio.getAveragePrice());

portfolioRepository.save(portfolio);
}

@Transactional(readOnly = true)
public PortfolioSummary getUserPortfolio(Long userId) {
User user = userRepository.findById(userId)
.orElseThrow(() -> new ResourceNotFoundException("User", "id", userId));
List<String> symbols = portfolioRepository.findDistinctSymbolsByUser(user);

// Using a container object to hold our accumulators


PortfolioSummaryContainer container = new PortfolioSummaryContainer();

List<PortfolioSummary.PortfolioItem> items = symbols.stream()


.map(symbol -> processSymbol(user, symbol, container))
.collect(Collectors.toList());

double overallGain = container.totalValue - container.totalInvestment;


double gainPercentage = container.totalInvestment > 0 ?
(overallGain / container.totalInvestment) * 100 : 0;

return new PortfolioSummary(


container.totalValue,
container.todaysGain,
overallGain,
gainPercentage,
items
);
}

/**
 * Computes the portfolio line item for one symbol and adds its value, cost basis and
 * today's gain to {@code container}.
 *
 * <p>Available shares are bought minus sold. Average price is total recorded investment
 * divided by shares bought. Today's gain compares the latest close with the previous close.
 *
 * @param user      owner of the transactions
 * @param symbol    stock symbol to evaluate
 * @param container running totals, mutated in place
 * @return the line item for {@code symbol}
 */
private PortfolioSummary.PortfolioItem processSymbol(User user, String symbol,
        PortfolioSummaryContainer container) {
    // The repository sums can be null (getAvailableQuantity guards for this too), e.g.
    // presumably when a symbol has been bought but never sold; treat null as zero.
    Integer boughtSum = portfolioRepository.sumBoughtQuantityByUserAndSymbol(user, symbol);
    Integer soldSum = portfolioRepository.sumSoldQuantityByUserAndSymbol(user, symbol);
    int bought = boughtSum != null ? boughtSum : 0;
    int sold = soldSum != null ? soldSum : 0;
    int available = bought - sold;

    double currentPrice = stockDataService.getCurrentPrice(symbol);

    Double investmentSum = portfolioRepository.sumInvestmentByUserAndSymbol(user, symbol);
    double investment = investmentSum != null ? investmentSum : 0.0;
    double avgPrice = bought > 0 ? investment / bought : 0;

    double value = available * currentPrice;
    double costBasis = available * avgPrice;
    double gain = value - costBasis;
    double gainPercentage = costBasis != 0 ? (gain / costBasis) * 100 : 0;

    container.totalValue += value;
    container.totalInvestment += costBasis;

    // History comes back newest-first (findTopNBySymbolOrderByDateDesc), so index 1 is the
    // previous close. Fall back to the current price when it is unavailable.
    List<StockData> history = stockDataService.getHistoricalData(symbol, 2);
    double yesterdayPrice = currentPrice;
    if (history != null && history.size() > 1) {
        Double previousClose = history.get(1).getClosingPrice();
        if (previousClose != null) {
            yesterdayPrice = previousClose;
        }
    }
    container.todaysGain += available * (currentPrice - yesterdayPrice);

    return new PortfolioSummary.PortfolioItem(
            symbol,
            available,
            avgPrice,
            currentPrice,
            gain,
            gainPercentage);
}

/**
 * Returns how many shares of {@code symbol} the user currently holds (bought minus sold).
 * A null sum from the repository counts as zero.
 */
private int getAvailableQuantity(User user, String symbol) {
    Integer boughtTotal = portfolioRepository.sumBoughtQuantityByUserAndSymbol(user, symbol);
    Integer soldTotal = portfolioRepository.sumSoldQuantityByUserAndSymbol(user, symbol);

    int boughtShares = 0;
    if (boughtTotal != null) {
        boughtShares = boughtTotal;
    }
    int soldShares = 0;
    if (soldTotal != null) {
        soldShares = soldTotal;
    }
    return boughtShares - soldShares;
}

/**
 * Mutable running totals that processSymbol adds to while getUserPortfolio walks the
 * user's symbols. Every field starts at zero.
 */
private static final class PortfolioSummaryContainer {
    double totalValue;
    double totalInvestment;
    double todaysGain;
}
}
package com.nepse.service;

import com.nepse.domain.StockData;
import com.nepse.repository.StockDataRepository;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;

/**
 * Service for reading and maintaining daily {@link StockData} rows: repository lookups,
 * bulk updates, CSV import, and a synthetic-data fallback for symbols with too little history.
 */
@Service
public class StockDataService {

    /** Columns every CSV data row must have: date, open, close, high, low, volume, change, change %. */
    private static final int REQUIRED_CSV_COLUMNS = 8;

    /** Floor for generated sample prices so the derived low price (close - 15) stays positive. */
    private static final double MIN_SAMPLE_PRICE = 20.0;

    private final StockDataRepository stockDataRepository;

    public StockDataService(StockDataRepository stockDataRepository) {
        this.stockDataRepository = stockDataRepository;
    }

    /**
     * Returns the most recent {@code days} rows for {@code symbol}, newest first
     * (via {@code findTopNBySymbolOrderByDateDesc}).
     */
    @Transactional(readOnly = true)
    public List<StockData> getHistoricalData(String symbol, int days) {
        return stockDataRepository.findTopNBySymbolOrderByDateDesc(symbol, days);
    }

    /** Returns the newest row for {@code symbol}; may be {@code null} if the symbol has no rows. */
    @Transactional(readOnly = true)
    public StockData getLatestData(String symbol) {
        return stockDataRepository.findTopBySymbolOrderByDateDesc(symbol);
    }

    /** Returns up to {@code count} top gainers as ranked by the repository query. */
    @Transactional(readOnly = true)
    public List<StockData> getTopGainers(int count) {
        return stockDataRepository.getTopGainers(count);
    }

    /** Returns up to {@code count} top losers as ranked by the repository query. */
    @Transactional(readOnly = true)
    public List<StockData> getTopLosers(int count) {
        return stockDataRepository.getTopLosers(count);
    }

    /** Returns every distinct symbol that has stored data. */
    @Transactional(readOnly = true)
    public List<String> getAllSymbols() {
        return stockDataRepository.findAllDistinctSymbols();
    }

    /** Saves the given rows in a single transaction. */
    @Transactional
    public void updateStockData(List<StockData> stockDataList) {
        stockDataRepository.saveAll(stockDataList);
    }

    /** Returns the latest closing price for {@code symbol}, or {@code 0.0} if it has no rows. */
    @Transactional(readOnly = true)
    public Double getCurrentPrice(String symbol) {
        StockData latest = stockDataRepository.findTopBySymbolOrderByDateDesc(symbol);
        return latest != null ? latest.getClosingPrice() : 0.0;
    }

    /**
     * Imports daily rows for {@code symbol} from a comma-separated file and saves them in one batch.
     *
     * <p>Expected columns: {@code date (ISO yyyy-MM-dd), open, close, high, low, volume, change,
     * change%}, optionally followed by {@code ma5, ma20, ma50, rsi14, macd, bollingerUpper,
     * bollingerLower}. The first line is a header and is skipped. Rows with fewer than eight
     * columns (including blank lines) are skipped. Missing or blank indicator columns default to
     * {@code 0.0}. Fields are split on plain commas, so quoted values containing commas are
     * not supported.
     *
     * @param filePath path of the UTF-8 encoded CSV file
     * @param symbol   symbol assigned to every imported row
     * @throws IOException if the file cannot be read
     * @throws NumberFormatException if a required numeric column is malformed
     * @throws java.time.format.DateTimeParseException if the date column is malformed
     */
    @Transactional
    public void importStockDataFromCsv(String filePath, String symbol) throws IOException {
        List<StockData> stockDataList = new ArrayList<>();
        // Decode explicitly as UTF-8 instead of the platform default charset.
        try (BufferedReader reader = new BufferedReader(
                new FileReader(filePath, java.nio.charset.StandardCharsets.UTF_8))) {
            reader.readLine(); // skip the header row (no-op on an empty file)

            String line;
            while ((line = reader.readLine()) != null) {
                String[] tokens = line.split(",");
                if (tokens.length < REQUIRED_CSV_COLUMNS) {
                    continue; // skip incomplete rows
                }
                stockDataList.add(parseCsvRow(tokens, symbol));
            }
        }

        // Save all records in one batch.
        stockDataRepository.saveAll(stockDataList);
    }

    /** Maps one CSV row (already split into columns) to a new StockData entity. */
    private static StockData parseCsvRow(String[] tokens, String symbol) {
        StockData stockData = new StockData();
        stockData.setSymbol(symbol);
        stockData.setDate(LocalDate.parse(tokens[0].trim())); // e.g. "2025-08-05"
        stockData.setOpeningPrice(Double.parseDouble(tokens[1].trim()));
        stockData.setClosingPrice(Double.parseDouble(tokens[2].trim()));
        stockData.setHighPrice(Double.parseDouble(tokens[3].trim()));
        stockData.setLowPrice(Double.parseDouble(tokens[4].trim()));
        stockData.setVolume(Long.parseLong(tokens[5].trim()));
        stockData.setChangeAmount(Double.parseDouble(tokens[6].trim()));
        stockData.setChangePercentage(Double.parseDouble(tokens[7].trim()));

        // Optional technical indicators.
        stockData.setMovingAverage5(optionalDouble(tokens, 8));
        stockData.setMovingAverage20(optionalDouble(tokens, 9));
        stockData.setMovingAverage50(optionalDouble(tokens, 10));
        stockData.setRsi14(optionalDouble(tokens, 11));
        stockData.setMacd(optionalDouble(tokens, 12));
        stockData.setBollingerUpper(optionalDouble(tokens, 13));
        stockData.setBollingerLower(optionalDouble(tokens, 14));
        return stockData;
    }

    /** Parses {@code tokens[index]} as a double, or returns 0.0 when the column is absent or blank. */
    private static double optionalDouble(String[] tokens, int index) {
        if (index >= tokens.length) {
            return 0.0;
        }
        String raw = tokens[index].trim();
        return raw.isEmpty() ? 0.0 : Double.parseDouble(raw);
    }

    /**
     * Makes sure {@code symbol} has at least {@code minDays} rows. If it has fewer, the existing
     * rows are deleted and replaced by {@code minDays} rows of synthetic random-walk data that
     * start from a per-symbol base price. The generated values are placeholders, not market data.
     */
    @Transactional
    public void ensureMinimumData(String symbol, int minDays) {
        long count = stockDataRepository.countBySymbol(symbol);
        if (count >= minDays) {
            return;
        }

        double basePrice = switch (symbol) {
            case "NABIL" -> 2000.0;
            case "NICA" -> 1500.0;
            case "NBL" -> 1200.0;
            case "SCB" -> 1800.0;
            case "HIDCL" -> 500.0;
            case "GBIME" -> 800.0;
            default -> 1000.0;
        };

        // Drop the incomplete history so the regenerated series is the only one for this symbol.
        List<StockData> existing = stockDataRepository.findBySymbol(symbol);
        if (!existing.isEmpty()) {
            stockDataRepository.deleteAll(existing);
        }

        stockDataRepository.saveAll(generateSampleData(symbol, basePrice, minDays));
    }

    /**
     * Builds {@code days} daily rows dated from {@code today - days} up to yesterday. The close
     * follows a random walk from {@code startPrice}, and open/high/low are fixed offsets from it.
     */
    private static List<StockData> generateSampleData(String symbol, double startPrice, int days) {
        List<StockData> series = new ArrayList<>(days);
        LocalDate startDate = LocalDate.now().minusDays(days);
        double price = startPrice;

        for (int i = 0; i < days; i++) {
            double previousPrice = price;
            // Random daily move in [-25, +25), floored so prices stay positive.
            price = Math.max(MIN_SAMPLE_PRICE, previousPrice + (Math.random() - 0.5) * 50);
            double change = price - previousPrice;

            StockData row = new StockData();
            row.setSymbol(symbol);
            row.setDate(startDate.plusDays(i));
            row.setOpeningPrice(price - 10);
            row.setClosingPrice(price);
            row.setHighPrice(price + 5);
            row.setLowPrice(price - 15);
            row.setVolume(100000 + (long) (Math.random() * 50000));
            row.setChangeAmount(change);
            row.setChangePercentage((change / previousPrice) * 100);
            row.setMovingAverage5(trailingAverage(series, 5, price));
            row.setMovingAverage20(trailingAverage(series, 20, price));
            row.setMovingAverage50(trailingAverage(series, 50, price));
            series.add(row);
        }
        return series;
    }

    /**
     * Simple moving average of {@code currentPrice} and the last {@code period - 1} closes in
     * {@code previous}. Returns {@code currentPrice} until enough history exists.
     */
    private static double trailingAverage(List<StockData> previous, int period, double currentPrice) {
        if (previous.size() < period - 1) {
            return currentPrice;
        }
        double sum = currentPrice;
        for (int i = previous.size() - (period - 1); i < previous.size(); i++) {
            sum += previous.get(i).getClosingPrice();
        }
        return sum / period;
    }
}
package com.nepse.service;

import com.nepse.domain.User;
import com.nepse.dto.RegisterRequest;
import com.nepse.exception.BadRequestException;
import com.nepse.repository.UserRepository;
import com.nepse.security.UserPrincipal;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import com.nepse.exception.ResourceNotFoundException;

/**
 * Application service for registering users and looking them up.
 */
@Service
public class UserService {

    private final UserRepository userRepository;
    private final PasswordEncoder passwordEncoder;

    public UserService(UserRepository userRepository, PasswordEncoder passwordEncoder) {
        this.userRepository = userRepository;
        this.passwordEncoder = passwordEncoder;
    }

    /**
     * Registers a new user from the request, storing the password in encoded form.
     *
     * @param registerRequest registration details
     * @return the user as returned by {@code userRepository.save}
     * @throws BadRequestException if the username or the email is already taken
     */
    @Transactional
    public User createUser(RegisterRequest registerRequest) {
        boolean usernameTaken = userRepository.existsByUsername(registerRequest.getUsername());
        if (usernameTaken) {
            throw new BadRequestException("Username already in use");
        }

        boolean emailTaken = userRepository.existsByEmail(registerRequest.getEmail());
        if (emailTaken) {
            throw new BadRequestException("Email already in use");
        }

        User newUser = new User();
        newUser.setUsername(registerRequest.getUsername());
        newUser.setPassword(passwordEncoder.encode(registerRequest.getPassword()));
        newUser.setFullName(registerRequest.getFullName());
        newUser.setEmail(registerRequest.getEmail());
        return userRepository.save(newUser);
    }

    /**
     * Loads a user by id and wraps it in a {@link UserPrincipal}.
     *
     * @throws ResourceNotFoundException if no user has the given id
     */
    @Transactional(readOnly = true)
    public UserPrincipal loadUserById(Long id) {
        return UserPrincipal.create(
                userRepository.findById(id)
                        .orElseThrow(() -> new ResourceNotFoundException("User", "id", id)));
    }

    /** Returns whether a user with this username exists. */
    public boolean existsByUsername(String username) {
        return userRepository.existsByUsername(username);
    }

    /** Returns whether a user with this email exists. */
    public boolean existsByEmail(String email) {
        return userRepository.existsByEmail(email);
    }
}

7. Util
// src/main/java/com/nepse/util/DataInitializer.java
package com.nepse.util;

import com.nepse.domain.StockData;
import com.nepse.repository.StockDataRepository;
import jakarta.annotation.PostConstruct;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;

import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;

@Component
public class DataInitializer {

private final StockDataRepository stockDataRepository;

@Autowired
public DataInitializer(StockDataRepository stockDataRepository) {
this.stockDataRepository = stockDataRepository;
}

@PostConstruct
@Transactional
public void init() {
// Ensure we have data for key symbols
ensureSymbolData("NABIL", 2000.0, 60);
ensureSymbolData("NICA", 1500.0, 60);
ensureSymbolData("NBL", 1200.0, 60);
ensureSymbolData("SCB", 1800.0, 60);
ensureSymbolData("HIDCL", 500.0, 60);
ensureSymbolData("GBIME", 800.0, 60);

private void ensureSymbolData(String symbol, double startPrice, int days) {


long count = stockDataRepository.countBySymbol(symbol);
if (count < days) {
// Remove existing incomplete data
if (count > 0) {
stockDataRepository.deleteBySymbol(symbol);
}

// Add new sample data


addSampleData(symbol, startPrice, days);
System.out.println("Added " + days + " days of data for " + symbol);
}
}

private void addSampleData(String symbol, double startPrice, int days) {


List<StockData> data = new ArrayList<>();
LocalDate startDate = LocalDate.now().minusDays(days);
double price = startPrice;

for (int i = 0; i < days; i++) {


double change = (Math.random() - 0.5) * 50; // Random change between -50 to +50
price += change;

StockData stock = new StockData();


stock.setSymbol(symbol);
stock.setDate(startDate.plusDays(i));
stock.setOpeningPrice(price - 10);
stock.setClosingPrice(price);
stock.setHighPrice(price + 5);
stock.setLowPrice(price - 15);
stock.setVolume(100000 + (long)(Math.random() * 50000));
stock.setChangeAmount(change);
stock.setChangePercentage((change / (price - change)) * 100);

// Add technical indicators


stock.setMovingAverage5(calculateMovingAverage(data, 5, price));
stock.setMovingAverage20(calculateMovingAverage(data, 20, price));
stock.setMovingAverage50(calculateMovingAverage(data, 50, price));

data.add(stock);
}

stockDataRepository.saveAll(data);
}

private double calculateMovingAverage(List<StockData> data, int period, double currentPrice)


{
if (data.size() < period - 1) {
return currentPrice;
}

double sum = currentPrice;


int count = 1;
for (int i = data.size() - 1; i >= Math.max(0, data.size() - period + 1); i--) {
sum += data.get(i).getClosingPrice();
count++;
}
return sum / count;
}
}
package com.nepse.util;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;

/**
 * Static helpers for persisting and loading LSTM models together with their
 * {@link NormalizerMinMaxScaler}. A model stored as {@code <base>.zip} has its normalizer
 * stored in the same directory as {@code <base>-normalizer.bin}.
 */
public class ModelUtils {

    private static final Logger logger = LoggerFactory.getLogger(ModelUtils.class);

    /** File extension used for serialized models. */
    private static final String MODEL_EXTENSION = ".zip";

    /** Suffix appended to a model's base name to form its normalizer file name. */
    private static final String NORMALIZER_SUFFIX = "-normalizer.bin";

    /**
     * Saves the LSTM model and its normalizer to disk, creating the target directory if needed.
     *
     * @param model      the trained LSTM model
     * @param normalizer the data normalizer used with the model
     * @param modelFile  the file to save the model to
     * @throws IOException if the directory cannot be created or either file cannot be written
     */
    public static void saveModel(MultiLayerNetwork model,
                                 NormalizerMinMaxScaler normalizer,
                                 File modelFile) throws IOException {
        // Resolve against the absolute path: a bare file name has no parent and would NPE.
        File directory = modelFile.getAbsoluteFile().getParentFile();
        if (directory != null && !directory.mkdirs() && !directory.isDirectory()) {
            throw new IOException("Could not create model directory: " + directory.getAbsolutePath());
        }

        // Third argument is ModelSerializer's saveUpdater flag.
        ModelSerializer.writeModel(model, modelFile, true);
        logger.info("Saved model to: {}", modelFile.getAbsolutePath());

        // Derived from the suffix only, so it can never collide with the model file itself.
        File normalizerFile = normalizerFileFor(modelFile);
        NormalizerSerializer.getDefault().write(normalizer, normalizerFile);
        logger.info("Saved normalizer to: {}", normalizerFile.getAbsolutePath());
    }

    /**
     * Loads a trained LSTM model from disk.
     *
     * @param modelFile the file containing the saved model
     * @return the loaded MultiLayerNetwork model
     * @throws IOException if the file does not exist or cannot be read
     */
    public static MultiLayerNetwork loadModel(File modelFile) throws IOException {
        if (!modelFile.exists()) {
            throw new IOException("Model file not found: " + modelFile.getAbsolutePath());
        }

        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelFile);
        logger.info("Loaded model from: {}", modelFile.getAbsolutePath());
        return model;
    }

    /**
     * Loads the normalizer saved alongside a model by {@link #saveModel}.
     *
     * @param modelFile the model file whose normalizer should be loaded
     * @return the loaded NormalizerMinMaxScaler
     * @throws IOException if the normalizer file does not exist
     * @throws Exception   if deserialization fails (propagated from {@code NormalizerSerializer.restore})
     */
    public static NormalizerMinMaxScaler loadNormalizer(File modelFile) throws Exception {
        File normalizerFile = normalizerFileFor(modelFile);
        if (!normalizerFile.exists()) {
            throw new IOException("Normalizer file not found: " + normalizerFile.getAbsolutePath());
        }
        return NormalizerSerializer.getDefault().restore(normalizerFile);
    }

    /**
     * Checks whether both the model and its normalizer exist for a given symbol.
     *
     * @param modelDir the directory containing models
     * @param symbol   the stock symbol to check
     * @return true if both files exist, false otherwise
     */
    public static boolean modelExists(File modelDir, String symbol) {
        File modelFile = new File(modelDir, symbol + MODEL_EXTENSION);
        return modelFile.exists() && normalizerFileFor(modelFile).exists();
    }

    /**
     * Deletes the model and normalizer files for a given symbol.
     *
     * @param modelDir the directory containing models
     * @param symbol   the stock symbol to delete
     * @return true if at least one of the two files was deleted, false otherwise
     */
    public static boolean deleteModel(File modelDir, String symbol) {
        File modelFile = new File(modelDir, symbol + MODEL_EXTENSION);
        File normalizerFile = normalizerFileFor(modelFile);

        boolean modelDeleted = modelFile.exists() && modelFile.delete();
        boolean normalizerDeleted = normalizerFile.exists() && normalizerFile.delete();
        return modelDeleted || normalizerDeleted;
    }

    /**
     * Returns the normalizer file that accompanies {@code modelFile}: same directory, name with a
     * trailing {@code .zip} (if any) removed and {@code -normalizer.bin} appended.
     */
    private static File normalizerFileFor(File modelFile) {
        String name = modelFile.getName();
        String baseName = name.endsWith(MODEL_EXTENSION)
                ? name.substring(0, name.length() - MODEL_EXTENSION.length())
                : name;
        return new File(modelFile.getParentFile(), baseName + NORMALIZER_SUFFIX);
    }
}
And finally, the main class:
package com.nepse;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
/** Spring Boot entry point for the com.nepse application. */
@SpringBootApplication
public class StockPredictionApplication {

    /**
     * Boots the Spring application context.
     *
     * @param args command-line arguments forwarded to Spring Boot
     */
    public static void main(String[] args) {
        SpringApplication application = new SpringApplication(StockPredictionApplication.class);
        application.run(args);
    }
}

You might also like