1.
Controller
package com.nepse.controller;
import com.nepse.dto.PortfolioSummary;
import com.nepse.dto.PredictionResult;
import com.nepse.service.LstmPredictionService;
import com.nepse.service.PortfolioService;
import com.nepse.service.StockDataService;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
/**
 * Serves the dashboard page: an LSTM price prediction, the current user's portfolio summary,
 * and the latest top gainers/losers.
 */
@Controller
@RequestMapping("/dashboard")
public class DashboardController {

    /** Placeholder user until authentication supplies the logged-in user's id. */
    private static final Long DEFAULT_USER_ID = 1L;

    /** Symbol predicted when the dashboard is first opened. */
    private static final String DEFAULT_SYMBOL = "NABIL";

    /** History requested via ensureMinimumData; LstmPredictionService needs 60 rows. */
    private static final int MIN_HISTORY_DAYS = 60;

    /** Number of entries shown in each of the gainers / losers lists. */
    private static final int MOVERS_COUNT = 5;

    private final LstmPredictionService predictionService;
    private final PortfolioService portfolioService;
    private final StockDataService stockDataService;

    public DashboardController(LstmPredictionService predictionService,
                               PortfolioService portfolioService,
                               StockDataService stockDataService) {
        this.predictionService = predictionService;
        this.portfolioService = portfolioService;
        this.stockDataService = stockDataService;
    }

    /**
     * Renders the dashboard with a prediction for the default symbol.
     *
     * <p>A failed prediction is reported through the {@code predictionError} attribute instead of
     * failing the whole request.
     */
    @GetMapping
    public String showDashboard(Model model) {
        try {
            stockDataService.ensureMinimumData(DEFAULT_SYMBOL, MIN_HISTORY_DAYS);
            PredictionResult prediction = predictionService.predictStock(DEFAULT_SYMBOL);
            model.addAttribute("prediction", prediction);
        } catch (Exception e) {
            model.addAttribute("predictionError", "Error predicting stock: " + e.getMessage());
        }
        return renderDashboard(model);
    }

    /**
     * Re-renders the dashboard with a prediction for the submitted symbol.
     *
     * @param symbol symbol typed into the form; surrounding whitespace is ignored and a blank
     *               value is reported as {@code predictionError}
     */
    @PostMapping("/predict")
    public String predictStock(@RequestParam String symbol,
                               Model model) {
        String requested = symbol.strip();
        if (requested.isEmpty()) {
            model.addAttribute("predictionError", "Error predicting stock: a stock symbol is required");
        } else {
            try {
                PredictionResult prediction = predictionService.predictStock(requested);
                model.addAttribute("prediction", prediction);
            } catch (Exception e) {
                model.addAttribute("predictionError", "Error predicting stock: " + e.getMessage());
            }
        }
        return renderDashboard(model);
    }

    /**
     * Adds the attributes every dashboard render needs (portfolio, top gainers, top losers).
     *
     * @return the dashboard view name
     */
    private String renderDashboard(Model model) {
        PortfolioSummary portfolio = portfolioService.getUserPortfolio(DEFAULT_USER_ID);
        model.addAttribute("portfolio", portfolio);
        model.addAttribute("topGainers", stockDataService.getTopGainers(MOVERS_COUNT));
        model.addAttribute("topLosers", stockDataService.getTopLosers(MOVERS_COUNT));
        return "dashboard";
    }
}
package com.nepse.controller;
import com.nepse.dto.PortfolioRequest;
import com.nepse.dto.PortfolioSummary;
import com.nepse.service.PortfolioService;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.servlet.mvc.support.RedirectAttributes;
/**
 * Shows the user's portfolio and handles the buy ("add") and sell ("remove") form submissions.
 */
@Controller
@RequestMapping("/portfolio")
public class PortfolioController {

    /** Placeholder user until authentication supplies the logged-in user's id. */
    private static final Long DEFAULT_USER_ID = 1L;

    private final PortfolioService portfolioService;

    public PortfolioController(PortfolioService portfolioService) {
        this.portfolioService = portfolioService;
    }

    /** Renders the portfolio page with an empty form object for the add/remove forms. */
    @GetMapping
    public String showPortfolio(Model model) {
        PortfolioSummary portfolio = portfolioService.getUserPortfolio(DEFAULT_USER_ID);
        model.addAttribute("portfolio", portfolio);
        model.addAttribute("portfolioRequest", new PortfolioRequest());
        return "portfolio";
    }

    /**
     * Records a purchase, then redirects back to the portfolio page.
     *
     * <p>The outcome ({@code success} or {@code error}) is stored as a flash attribute so it
     * reaches the page rendered after the redirect. Attributes put on the request's plain
     * {@link Model} are not exposed to that page.
     */
    @PostMapping("/add")
    public String addToPortfolio(@ModelAttribute PortfolioRequest request,
                                 RedirectAttributes redirectAttributes) {
        try {
            portfolioService.addStockToPortfolio(
                    DEFAULT_USER_ID,
                    request.getSymbol(),
                    request.getQuantity(),
                    request.getAveragePrice()
            );
            redirectAttributes.addFlashAttribute("success", "Stock added to portfolio successfully");
        } catch (Exception e) {
            redirectAttributes.addFlashAttribute("error", "Failed to add stock: " + e.getMessage());
        }
        return "redirect:/portfolio";
    }

    /**
     * Records a sale, then redirects back to the portfolio page.
     *
     * <p>The outcome ({@code success} or {@code error}) is passed as a flash attribute so it
     * survives the redirect.
     */
    @PostMapping("/remove")
    public String removeFromPortfolio(@ModelAttribute PortfolioRequest request,
                                      RedirectAttributes redirectAttributes) {
        try {
            portfolioService.removeStockFromPortfolio(
                    DEFAULT_USER_ID,
                    request.getSymbol(),
                    request.getQuantity()
            );
            redirectAttributes.addFlashAttribute("success", "Stock removed from portfolio successfully");
        } catch (Exception e) {
            redirectAttributes.addFlashAttribute("error", "Failed to remove stock: " + e.getMessage());
        }
        return "redirect:/portfolio";
    }
}
package com.nepse.controller;
import com.nepse.dto.PredictionResult;
import com.nepse.service.LstmPredictionService;
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
/**
 * Renders a stand-alone LSTM prediction page for a single symbol.
 */
@Controller
@RequestMapping("/prediction")
public class PredictionController {

    /** Symbol predicted when the {@code symbol} parameter is absent or blank. */
    private static final String DEFAULT_SYMBOL = "NEPSE";

    private final LstmPredictionService predictionService;

    public PredictionController(LstmPredictionService predictionService) {
        this.predictionService = predictionService;
    }

    /**
     * Shows the prediction for {@code symbol}, or for the default symbol when none is given.
     *
     * @param symbol optional query parameter; surrounding whitespace is ignored and a blank value
     *               falls back to the default
     * @return the {@code prediction} view. On failure the model carries {@code error} instead of
     *         {@code prediction}. {@code symbol} is set in both cases.
     */
    @GetMapping
    public String showPrediction(@RequestParam(required = false) String symbol,
                                 Model model) {
        String predictionSymbol = (symbol == null || symbol.isBlank()) ? DEFAULT_SYMBOL : symbol.strip();
        // Exposed before predicting so the page can name the symbol even when prediction fails.
        model.addAttribute("symbol", predictionSymbol);
        try {
            PredictionResult prediction = predictionService.predictStock(predictionSymbol);
            model.addAttribute("prediction", prediction);
        } catch (Exception e) {
            model.addAttribute("error", "Error generating prediction: " + e.getMessage());
        }
        return "prediction";
    }
}
2.domain
package com.nepse.domain;
import jakarta.persistence.*;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDate;
@Entity
@Table(name = "portfolio")
@Data
@NoArgsConstructor
@AllArgsConstructor
public class Portfolio {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Long id;
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "user_id", nullable = false)
private User user;
@Column(nullable = false)
private String symbol;
@Column(nullable = false)
private int quantity;
@Column(nullable = false)
private double averagePrice;
@Column(nullable = false)
private LocalDate purchaseDate = LocalDate.now();
@Column(nullable = false)
private String transactionType; // BUY or SELL
private String notes;
@Column(nullable = false)
private double totalInvestment;
// Additional fields for tracking
private double currentValue;
private double profitLoss;
private double profitLossPercentage;
}
package com.nepse.domain;
import jakarta.persistence.*;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDate;
/**
 * One trading day's price bar for a symbol, persisted in {@code stock_data}, together with
 * technical-indicator columns intended as extra LSTM inputs.
 *
 * <p>Field order matters: Lombok's {@code @AllArgsConstructor} takes arguments in this order.
 */
@Entity
@Table(name = "stock_data")
@Data
@NoArgsConstructor
@AllArgsConstructor
public class StockData {

    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;

    // ---- which bar this is --------------------------------------------------
    @Column(nullable = false) private String symbol;
    @Column(nullable = false) private LocalDate date;

    // ---- prices, volume and daily change (all required columns) -------------
    @Column(nullable = false) private double openingPrice;
    @Column(nullable = false) private double closingPrice;
    @Column(nullable = false) private double highPrice;
    @Column(nullable = false) private double lowPrice;
    @Column(nullable = false) private long volume;
    @Column(nullable = false) private double changeAmount;
    @Column(nullable = false) private double changePercentage;

    // ---- technical indicators (not marked nullable = false) -----------------
    private double movingAverage5;
    private double movingAverage20;
    private double movingAverage50;
    private double rsi14;
    private double macd;
    private double bollingerUpper;
    private double bollingerLower;
}
package com.nepse.domain;
import jakarta.persistence.*;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDateTime;
@Entity
@Table(name = "users")
@Data
@NoArgsConstructor
@AllArgsConstructor
public class User {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Long id;
@Column(unique = true, nullable = false)
private String username;
@Column(nullable = false)
private String password; // Added back for authentication
private String fullName;
private String email;
@Column(nullable = false)
private LocalDateTime createdAt = LocalDateTime.now();
}
3.dto
package com.nepse.dto;
import jakarta.validation.constraints.NotBlank;
/**
 * Form/body object for a login attempt; both fields carry {@code @NotBlank} constraints.
 */
public class LoginRequest {

    @NotBlank(message = "Username is required")
    private String username;

    @NotBlank(message = "Password is required")
    private String password;

    // ---- accessors ----------------------------------------------------------
    public String getUsername() { return username; }

    public String getPassword() { return password; }

    // ---- mutators -----------------------------------------------------------
    public void setUsername(String username) { this.username = username; }

    public void setPassword(String password) { this.password = password; }
}
package com.nepse.dto;
import java.math.BigDecimal;
/**
 * One holding with derived valuation figures.
 *
 * <p>{@code investmentValue}, {@code currentValue}, {@code profitLoss} and
 * {@code profitLossPercentage} are recomputed whenever the quantity or a price changes. A figure
 * whose inputs are not all set yet is {@code null}. This happens, for example, after the no-arg
 * constructor before both prices have been assigned.
 */
public class PortfolioItem {

    private static final BigDecimal ONE_HUNDRED = BigDecimal.valueOf(100);

    private String symbol;
    private int quantity;
    private BigDecimal averagePrice;
    private BigDecimal currentPrice;
    private BigDecimal investmentValue;
    private BigDecimal currentValue;
    private BigDecimal profitLoss;
    private BigDecimal profitLossPercentage;

    /** Creates an empty item; populate it through the setters. */
    public PortfolioItem() {
    }

    /**
     * Creates a fully populated item and computes its derived figures.
     *
     * @param averagePrice purchase price per share
     * @param currentPrice market price per share
     */
    public PortfolioItem(String symbol, int quantity, BigDecimal averagePrice, BigDecimal
            currentPrice) {
        this.symbol = symbol;
        this.quantity = quantity;
        this.averagePrice = averagePrice;
        this.currentPrice = currentPrice;
        calculateValues();
    }

    /**
     * Recomputes the derived figures, leaving any figure whose inputs are missing as {@code null}
     * (previously a missing price threw NullPointerException from the setters).
     */
    private void calculateValues() {
        BigDecimal qty = BigDecimal.valueOf(quantity);
        this.investmentValue = averagePrice != null ? averagePrice.multiply(qty) : null;
        this.currentValue = currentPrice != null ? currentPrice.multiply(qty) : null;
        if (investmentValue == null || currentValue == null) {
            this.profitLoss = null;
            this.profitLossPercentage = null;
            return;
        }
        this.profitLoss = currentValue.subtract(investmentValue);
        // Ratio rounded to 4 decimal places, then scaled to percent. Zero investment reports 0
        // instead of dividing by zero.
        this.profitLossPercentage = investmentValue.signum() != 0
                ? profitLoss.divide(investmentValue, 4, java.math.RoundingMode.HALF_UP).multiply(ONE_HUNDRED)
                : BigDecimal.ZERO;
    }

    public String getSymbol() {
        return symbol;
    }

    public void setSymbol(String symbol) {
        this.symbol = symbol;
    }

    public int getQuantity() {
        return quantity;
    }

    public void setQuantity(int quantity) {
        this.quantity = quantity;
        calculateValues();
    }

    public BigDecimal getAveragePrice() {
        return averagePrice;
    }

    public void setAveragePrice(BigDecimal averagePrice) {
        this.averagePrice = averagePrice;
        calculateValues();
    }

    public BigDecimal getCurrentPrice() {
        return currentPrice;
    }

    public void setCurrentPrice(BigDecimal currentPrice) {
        this.currentPrice = currentPrice;
        calculateValues();
    }

    /** @return {@code averagePrice * quantity}, or {@code null} if the average price is unset */
    public BigDecimal getInvestmentValue() {
        return investmentValue;
    }

    /** @return {@code currentPrice * quantity}, or {@code null} if the current price is unset */
    public BigDecimal getCurrentValue() {
        return currentValue;
    }

    /** @return current value minus investment, or {@code null} if either price is unset */
    public BigDecimal getProfitLoss() {
        return profitLoss;
    }

    /** @return profit/loss as a percent of the investment, or {@code null} if either price is unset */
    public BigDecimal getProfitLossPercentage() {
        return profitLossPercentage;
    }
}
package com.nepse.dto;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Positive;
/**
 * Form backing object for the portfolio add/remove actions.
 *
 * <p>NOTE(review): PortfolioController binds this with {@code @ModelAttribute} but without
 * {@code @Valid}, so these constraints are declared but not enforced there.
 */
public class PortfolioRequest {

    @NotBlank(message = "Stock symbol is required")
    private String symbol;

    @NotNull(message = "Quantity is required")
    @Min(value = 1, message = "Quantity must be at least 1")
    private Integer quantity;

    @NotNull(message = "Price is required")
    @Positive(message = "Price must be positive")
    private Double averagePrice;

    // ---- accessors ----------------------------------------------------------
    public String getSymbol() { return symbol; }

    public Integer getQuantity() { return quantity; }

    public Double getAveragePrice() { return averagePrice; }

    // ---- mutators -----------------------------------------------------------
    public void setSymbol(String symbol) { this.symbol = symbol; }

    public void setQuantity(Integer quantity) { this.quantity = quantity; }

    public void setAveragePrice(Double averagePrice) { this.averagePrice = averagePrice; }
}
package com.nepse.dto;
import java.math.BigDecimal;
import java.util.List;
/**
 * Immutable aggregate view of a user's portfolio, as rendered on the dashboard and portfolio pages.
 */
public class PortfolioSummary {

    private final double totalValue;
    private final double todaysGain;
    private final double overallGain;
    private final double gainPercentage;
    private final List<PortfolioItem> items;

    /**
     * @param totalValue     market value of all holdings
     * @param todaysGain     change in value since the previous close
     * @param overallGain    market value minus cost basis
     * @param gainPercentage {@code overallGain} as a percent of the cost basis
     * @param items          per-symbol holdings. The list is copied so later changes to the
     *                       caller's list cannot alter this summary; {@code null} means no
     *                       holdings. Elements must be non-null.
     */
    public PortfolioSummary(double totalValue, double todaysGain,
                            double overallGain, double gainPercentage,
                            List<PortfolioItem> items) {
        this.totalValue = totalValue;
        this.todaysGain = todaysGain;
        this.overallGain = overallGain;
        this.gainPercentage = gainPercentage;
        this.items = items == null ? List.of() : List.copyOf(items);
    }

    public double getTotalValue() { return totalValue; }
    public double getTodaysGain() { return todaysGain; }
    public double getOverallGain() { return overallGain; }
    public double getGainPercentage() { return gainPercentage; }

    /** @return the holdings as an unmodifiable list (never {@code null}) */
    public List<PortfolioItem> getItems() { return items; }

    /**
     * Immutable per-symbol line of the summary. This is distinct from
     * {@code com.nepse.dto.PortfolioItem}, which is a separate mutable BigDecimal-based class.
     */
    public static class PortfolioItem {
        private final String symbol;
        private final int quantity;
        private final double averagePrice;
        private final double currentPrice;
        private final double gain;
        private final double gainPercentage;

        public PortfolioItem(String symbol, int quantity,
                             double averagePrice, double currentPrice,
                             double gain, double gainPercentage) {
            this.symbol = symbol;
            this.quantity = quantity;
            this.averagePrice = averagePrice;
            this.currentPrice = currentPrice;
            this.gain = gain;
            this.gainPercentage = gainPercentage;
        }

        public String getSymbol() { return symbol; }
        public int getQuantity() { return quantity; }
        public double getAveragePrice() { return averagePrice; }
        public double getCurrentPrice() { return currentPrice; }
        public double getGain() { return gain; }
        public double getGainPercentage() { return gainPercentage; }
    }
}
package com.nepse.dto;
import java.time.LocalDate;
/**
 * Outcome of a single price prediction for one symbol.
 */
public class PredictionResult {
    private String symbol;
    private double currentPrice;
    private double predictedPrice;
    private String trend; // UP/DOWN
    private double confidence; // 0-100
    private double potentialGain; // percentage
    private LocalDate predictionDate;
    private LocalDate targetDate;

    /**
     * Creates a fully populated result. The constructor previously ignored every argument,
     * leaving all fields null/0.
     *
     * @param symbol         predicted stock symbol
     * @param currentPrice   latest known closing price
     * @param predictedPrice model's predicted price
     * @param trend          "UP" or "DOWN"
     * @param confidence     confidence score, 0-100
     * @param potentialGain  predicted change as a percentage of {@code currentPrice}
     * @param predictionDate date the prediction was made
     * @param targetDate     date the prediction refers to
     */
    public PredictionResult(String symbol, double currentPrice, double predictedPrice, String trend,
                            double confidence, double potentialGain, LocalDate predictionDate,
                            LocalDate targetDate) {
        this.symbol = symbol;
        this.currentPrice = currentPrice;
        this.predictedPrice = predictedPrice;
        this.trend = trend;
        this.confidence = confidence;
        this.potentialGain = potentialGain;
        this.predictionDate = predictionDate;
        this.targetDate = targetDate;
    }

    public String getSymbol() {
        return symbol;
    }

    public void setSymbol(String symbol) {
        this.symbol = symbol;
    }

    public double getCurrentPrice() {
        return currentPrice;
    }

    public void setCurrentPrice(double currentPrice) {
        this.currentPrice = currentPrice;
    }

    public double getPredictedPrice() {
        return predictedPrice;
    }

    public void setPredictedPrice(double predictedPrice) {
        this.predictedPrice = predictedPrice;
    }

    public String getTrend() {
        return trend;
    }

    public void setTrend(String trend) {
        this.trend = trend;
    }

    public double getConfidence() {
        return confidence;
    }

    public void setConfidence(double confidence) {
        this.confidence = confidence;
    }

    public double getPotentialGain() {
        return potentialGain;
    }

    public void setPotentialGain(double potentialGain) {
        this.potentialGain = potentialGain;
    }

    public LocalDate getPredictionDate() {
        return predictionDate;
    }

    public void setPredictionDate(LocalDate predictionDate) {
        this.predictionDate = predictionDate;
    }

    public LocalDate getTargetDate() {
        return targetDate;
    }

    public void setTargetDate(LocalDate targetDate) {
        this.targetDate = targetDate;
    }
}
package com.nepse.dto;
import jakarta.validation.constraints.Email;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Size;
/**
 * Form/body object for creating a user account. Each field carries its own Bean Validation
 * constraints.
 */
public class RegisterRequest {

    @NotBlank(message = "Username is required")
    @Size(min = 3, max = 20, message = "Username must be between 3 and 20 characters")
    private String username;

    @NotBlank(message = "Password is required")
    @Size(min = 6, max = 40, message = "Password must be between 6 and 40 characters")
    private String password;

    @NotBlank(message = "Full name is required")
    private String fullName;

    @NotBlank(message = "Email is required")
    @Email(message = "Email should be valid")
    private String email;

    // ---- accessors ----------------------------------------------------------
    public String getUsername() { return username; }

    public String getPassword() { return password; }

    public String getFullName() { return fullName; }

    public String getEmail() { return email; }

    // ---- mutators -----------------------------------------------------------
    public void setUsername(String username) { this.username = username; }

    public void setPassword(String password) { this.password = password; }

    public void setFullName(String fullName) { this.fullName = fullName; }

    public void setEmail(String email) { this.email = email; }
}
4. exception
package com.nepse.exception;
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.ResponseStatus;
/**
 * Unchecked exception for invalid client input. It is annotated with HTTP 400, and
 * GlobalExceptionHandler also has a dedicated handler for it.
 */
@ResponseStatus(HttpStatus.BAD_REQUEST)
public class BadRequestException extends RuntimeException {

    /**
     * @param message description of what was wrong with the request
     * @param cause   underlying failure, preserved for diagnostics
     */
    public BadRequestException(String message, Throwable cause) { super(message, cause); }

    /** @param message description of what was wrong with the request */
    public BadRequestException(String message) { super(message); }
}
package com.nepse.exception;
/**
 * Error payload returned in the body of the responses built by GlobalExceptionHandler.
 *
 * @param status    HTTP status code
 * @param message   human-readable error description
 * @param timestamp creation time; GlobalExceptionHandler fills it from {@code System.currentTimeMillis()}
 */
public record ErrorResponse(int status, String message, long timestamp) {
}
package com.nepse.exception;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.context.request.WebRequest;
@ControllerAdvice
public class GlobalExceptionHandler {
@ExceptionHandler(BadRequestException.class)
public ResponseEntity<ErrorResponse> handleBadRequest(
BadRequestException ex, WebRequest request) {
ErrorResponse response = new ErrorResponse(
HttpStatus.BAD_REQUEST.value(),
ex.getMessage(),
System.currentTimeMillis());
return ResponseEntity.badRequest().body(response);
}
@ExceptionHandler(ResourceNotFoundException.class)
public ResponseEntity<ErrorResponse> handleResourceNotFound(
ResourceNotFoundException ex, WebRequest request) {
ErrorResponse response = new ErrorResponse(
HttpStatus.NOT_FOUND.value(),
ex.getMessage(),
System.currentTimeMillis());
return ResponseEntity.status(HttpStatus.NOT_FOUND).body(response);
}
// Add generic exception handler (optional but recommended)
@ExceptionHandler(Exception.class)
public ResponseEntity<ErrorResponse> handleGlobalException(
Exception ex, WebRequest request) {
ErrorResponse response = new ErrorResponse(
HttpStatus.INTERNAL_SERVER_ERROR.value(),
"An unexpected error occurred",
System.currentTimeMillis());
return ResponseEntity.internalServerError().body(response);
}
}
package com.nepse.exception;
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.ResponseStatus;
/**
 * Thrown when a looked-up entity does not exist; annotated with HTTP 404.
 * The message has the form {@code User not found with id : '1'}.
 */
@ResponseStatus(HttpStatus.NOT_FOUND)
public class ResourceNotFoundException extends RuntimeException {

    private final String resourceName;
    private final String fieldName;
    private final Object fieldValue;

    /**
     * @param resourceName kind of entity that was looked up, e.g. {@code "User"}
     * @param fieldName    field the lookup was performed on, e.g. {@code "id"}
     * @param fieldValue   value that matched nothing
     */
    public ResourceNotFoundException(String resourceName, String fieldName, Object
            fieldValue) {
        super("%s not found with %s : '%s'".formatted(resourceName, fieldName, fieldValue));
        this.resourceName = resourceName;
        this.fieldName = fieldName;
        this.fieldValue = fieldValue;
    }

    public String getResourceName() { return resourceName; }

    public String getFieldName() { return fieldName; }

    public Object getFieldValue() { return fieldValue; }
}
5.init
package com.nepse.init;
import com.nepse.service.StockDataService;
import jakarta.annotation.PostConstruct;
import org.springframework.stereotype.Component;
import java.io.IOException;
/**
 * Imports the NABIL price history CSV once the bean is constructed, so predictions have data.
 *
 * <p>Best-effort: a missing, unreadable or malformed file is logged and application startup
 * continues.
 */
@Component
public class DataLoader {

    private static final System.Logger LOG = System.getLogger(DataLoader.class.getName());

    /**
     * CSV location. StockDataService opens it with {@code FileReader}, so the relative path is
     * resolved against the process working directory. It is only found when the app is started
     * from the project root.
     */
    private static final String CSV_PATH = "src/main/resources/static/data/NABIL.csv";

    /** Symbol the imported rows are stored under. */
    private static final String SYMBOL = "NABIL";

    private final StockDataService stockDataService;

    public DataLoader(StockDataService stockDataService) {
        this.stockDataService = stockDataService;
    }

    @PostConstruct
    public void importData() {
        if (!new java.io.File(CSV_PATH).isFile()) {
            LOG.log(System.Logger.Level.WARNING,
                    "Stock data CSV not found at {0}; skipping import of {1}", CSV_PATH, SYMBOL);
            return;
        }
        try {
            stockDataService.importStockDataFromCsv(CSV_PATH, SYMBOL);
            LOG.log(System.Logger.Level.INFO, "Stock data imported successfully.");
        } catch (IOException | RuntimeException e) {
            // Parse errors (bad number/date) are RuntimeExceptions; like I/O errors they should
            // not abort startup.
            LOG.log(System.Logger.Level.ERROR,
                    "Failed to import stock data for " + SYMBOL + " from " + CSV_PATH, e);
        }
    }
}
repository
package com.nepse.repository;
import com.nepse.domain.Portfolio;
import com.nepse.domain.User;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.stereotype.Repository;
import java.util.List;
@Repository
public interface PortfolioRepository extends JpaRepository<Portfolio, Long> {
List<Portfolio> findByUser(User user);
List<Portfolio> findByUserAndSymbol(User user, String symbol);
@Query("SELECT p FROM Portfolio p WHERE p.user = :user GROUP BY p.symbol")
List<Portfolio> findDistinctByUser(User user);
@Query("SELECT p.symbol FROM Portfolio p WHERE p.user = :user GROUP BY p.symbol")
List<String> findDistinctSymbolsByUser(User user);
@Query("SELECT SUM(p.quantity) FROM Portfolio p WHERE p.user = :user AND p.symbol =
:symbol AND p.transactionType = 'BUY'")
Integer sumBoughtQuantityByUserAndSymbol(User user, String symbol);
@Query("SELECT SUM(p.quantity) FROM Portfolio p WHERE p.user = :user AND p.symbol =
:symbol AND p.transactionType = 'SELL'")
Integer sumSoldQuantityByUserAndSymbol(User user, String symbol);
@Query("SELECT COALESCE(SUM(p.quantity * p.averagePrice), 0) FROM Portfolio p WHERE
p.user = :user AND p.symbol = :symbol AND p.transactionType = 'BUY'")
Double sumInvestmentByUserAndSymbol(User user, String symbol);
}
package com.nepse.repository;
import com.nepse.domain.StockData;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Modifying;
import org.springframework.data.jpa.repository.Query;
import org.springframework.stereotype.Repository;
import java.time.LocalDate;
import java.util.List;
@Repository
public interface StockDataRepository extends JpaRepository<StockData, Long> {
List<StockData> findBySymbolOrderByDateDesc(String symbol);
@Query("SELECT s FROM StockData s WHERE s.symbol = :symbol ORDER BY s.date DESC
LIMIT :limit")
List<StockData> findTopNBySymbolOrderByDateDesc(String symbol, int limit);
StockData findTopBySymbolOrderByDateDesc(String symbol);
StockData findBySymbolAndDate(String symbol, LocalDate date);
@Query("SELECT DISTINCT s.symbol FROM StockData s")
List<String> findAllDistinctSymbols();
@Query("SELECT s FROM StockData s WHERE s.date = (SELECT MAX(s2.date) FROM
StockData s2) ORDER BY s.changePercentage DESC LIMIT :count")
List<StockData> getTopGainers(int count);
@Query("SELECT s FROM StockData s WHERE s.date = (SELECT MAX(s2.date) FROM
StockData s2) ORDER BY s.changePercentage ASC LIMIT :count")
List<StockData> getTopLosers(int count);
List<StockData> findTop60BySymbolOrderByDateDesc(String symbol);
long countBySymbol(String symbol);
@Modifying
@Query("DELETE FROM StockData s WHERE s.symbol = :symbol")
void deleteBySymbol(String symbol);
List<StockData> findBySymbol(String symbol);
}
package com.nepse.repository;
import com.nepse.domain.User;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;
import java.util.Optional;
/**
 * Spring Data repository for {@link User} accounts.
 */
@Repository
public interface UserRepository extends JpaRepository<User, Long> {

    // ---- single-key lookups -------------------------------------------------
    Optional<User> findByUsername(String username);

    Optional<User> findByEmail(String email);

    // ---- existence checks ---------------------------------------------------
    Boolean existsByUsername(String username);

    Boolean existsByEmail(String email);

    /**
     * Finds a user whose username OR email matches.
     * NOTE(review): if the username belongs to one user and the email to another, two rows match
     * and a single-result Optional cannot hold them; confirm callers never hit that case.
     */
    @Query("SELECT u FROM User u WHERE u.username = :username OR u.email = :email")
    Optional<User> findByUsernameOrEmail(@Param("username") String username, @Param("email") String email);
}
6.Service
package com.nepse.service;
import com.nepse.domain.StockData;
import com.nepse.dto.PredictionResult;
import com.nepse.repository.StockDataRepository;
import com.nepse.util.ModelUtils;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.io.File;
import java.io.IOException;
import java.time.LocalDate;
import java.util.Collections;
import java.util.List;
/**
 * Predicts a symbol's next closing price with a per-symbol DL4J network loaded from
 * {@code <lstm.model.directory>/<SYMBOL>.zip}.
 *
 * <p>The model input is the last 60 closing prices, min-max scaled with the window's own
 * min/max. The network output is mapped back to a price with the same scale.
 */
@Service
public class LstmPredictionService {

    /**
     * Length of the closing-price window fed to the model. It must equal the 60 baked into
     * {@code findTop60BySymbolOrderByDateDesc} and the input length the models were trained on.
     */
    private static final int WINDOW_SIZE = 60;

    /** Days after today at which the result's target date is placed. */
    private static final int TARGET_HORIZON_DAYS = 7;

    private static final System.Logger LOG = System.getLogger(LstmPredictionService.class.getName());

    private final StockDataRepository stockDataRepository;

    @Value("${lstm.model.directory}")
    private String modelDirectory;

    public LstmPredictionService(StockDataRepository stockDataRepository) {
        this.stockDataRepository = stockDataRepository;
    }

    /**
     * Predicts the next closing price of {@code symbol}.
     *
     * @return prediction whose {@code currentPrice} is the most recent close and whose
     *         {@code potentialGain} is the predicted change in percent (0 if the current price is 0)
     * @throws IllegalArgumentException if fewer than 60 bars exist for {@code symbol}
     * @throws RuntimeException         if the model file cannot be loaded
     */
    public PredictionResult predictStock(String symbol) {
        // The repository returns newest-first. Reversing makes the list chronological, so index 0
        // is the OLDEST bar and the last index is the most recent one.
        List<StockData> historicalData = stockDataRepository
                .findTop60BySymbolOrderByDateDesc(symbol);
        Collections.reverse(historicalData);
        LOG.log(System.Logger.Level.DEBUG, "Fetched {0} rows for {1}", historicalData.size(), symbol);

        if (historicalData.size() < WINDOW_SIZE) {
            throw new IllegalArgumentException("Not enough historical data for prediction");
        }

        // Scale is computed once and shared by normalisation and de-normalisation.
        double min = historicalData.stream().mapToDouble(StockData::getClosingPrice).min().orElse(0);
        double max = historicalData.stream().mapToDouble(StockData::getClosingPrice).max().orElse(1);
        double range = max - min;

        double[] normalizedData = normalizeData(historicalData, min, range);
        MultiLayerNetwork model = loadModel(symbol);
        // NOTE(review): DL4J recurrent layers default to [miniBatch, features, timeSteps]; this
        // builds [1, WINDOW_SIZE, 1]. Confirm it matches the layout the models were trained with.
        INDArray input = Nd4j.create(normalizedData, new int[]{1, WINDOW_SIZE, 1});
        INDArray output = model.output(input);
        // NOTE(review): reads the first output element; if the network emits a full sequence the
        // last time step may be the intended prediction.
        double predictedValue = output.getDouble(0);
        double predictedPrice = predictedValue * range + min;

        double currentPrice = historicalData.get(historicalData.size() - 1).getClosingPrice();
        double confidence = calculateConfidence(historicalData, predictedPrice);
        double potentialGain = currentPrice != 0
                ? ((predictedPrice - currentPrice) / currentPrice) * 100
                : 0;
        LocalDate today = LocalDate.now();

        return new PredictionResult(
                symbol,
                currentPrice,
                predictedPrice,
                predictedPrice > currentPrice ? "UP" : "DOWN",
                confidence,
                potentialGain,
                today,
                today.plusDays(TARGET_HORIZON_DAYS)
        );
    }

    /**
     * Min-max scales closing prices into [0, 1]. A flat window ({@code range == 0}) maps to all
     * zeros instead of NaN.
     */
    private double[] normalizeData(List<StockData> data, double min, double range) {
        return data.stream()
                .mapToDouble(d -> range == 0 ? 0.0 : (d.getClosingPrice() - min) / range)
                .toArray();
    }

    /** Loads {@code <modelDirectory>/<symbol>.zip}, wrapping I/O failures in a RuntimeException. */
    private MultiLayerNetwork loadModel(String symbol) {
        try {
            return ModelUtils.loadModel(new File(modelDirectory, symbol + ".zip"));
        } catch (IOException e) {
            throw new RuntimeException("Failed to load model for symbol: " + symbol, e);
        }
    }

    /**
     * Heuristic confidence in [0, 100] based on recent volatility.
     *
     * <p>The score is 100 when the prediction equals the latest close. It decreases linearly as
     * the gap to the latest close grows, reaching 80 at three times the average day-to-day move
     * and clamped at 0. With no volatility (flat history) it is 100 for no change and 0 otherwise.
     *
     * @param historicalData chronological bars (oldest first, latest last)
     */
    private double calculateConfidence(List<StockData> historicalData, double predictedPrice) {
        double sum = 0;
        int count = 0;
        for (int i = 0; i < historicalData.size() - 1; i++) {
            sum += Math.abs(historicalData.get(i + 1).getClosingPrice()
                    - historicalData.get(i).getClosingPrice());
            count++;
        }
        double latestClose = historicalData.get(historicalData.size() - 1).getClosingPrice();
        double diff = Math.abs(predictedPrice - latestClose);
        if (count == 0 || sum == 0) {
            return diff == 0 ? 100 : 0;
        }
        double avgChange = sum / count;
        double raw = 80 + (20 * (1 - (diff / (avgChange * 3))));
        return Math.max(0, Math.min(100, raw));
    }
}
package com.nepse.service;
import com.nepse.domain.Portfolio;
import com.nepse.domain.StockData;
import com.nepse.domain.User;
import com.nepse.dto.PortfolioRequest;
import com.nepse.dto.PortfolioSummary;
import com.nepse.exception.ResourceNotFoundException;
import com.nepse.repository.PortfolioRepository;
import com.nepse.repository.UserRepository;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Records buy/sell transactions for a user and aggregates them into a {@link PortfolioSummary}.
 *
 * <p>Holdings are not stored directly. Every trade is a {@link Portfolio} row with transaction
 * type "BUY" or "SELL", and the quantity held is shares bought minus shares sold.
 */
@Service
public class PortfolioService {

    private static final String BUY = "BUY";
    private static final String SELL = "SELL";

    private final PortfolioRepository portfolioRepository;
    private final UserRepository userRepository;
    private final StockDataService stockDataService;

    public PortfolioService(PortfolioRepository portfolioRepository,
                            UserRepository userRepository,
                            StockDataService stockDataService) {
        this.portfolioRepository = portfolioRepository;
        this.userRepository = userRepository;
        this.stockDataService = stockDataService;
    }

    /**
     * Records a purchase of {@code quantity} shares at {@code averagePrice} each.
     *
     * @throws IllegalArgumentException  if {@code quantity < 1} or {@code averagePrice} is not positive
     * @throws ResourceNotFoundException if no user has {@code userId}
     */
    @Transactional
    public void addStockToPortfolio(Long userId, String symbol, int quantity, double averagePrice)
    {
        requirePositiveQuantity(quantity);
        if (!(averagePrice > 0)) { // written this way so NaN is rejected too
            throw new IllegalArgumentException("Price must be positive");
        }
        User user = findUser(userId);
        portfolioRepository.save(newTransaction(user, symbol, quantity, averagePrice, BUY));
    }

    /**
     * Records a sale of {@code quantity} shares at the symbol's current market price.
     *
     * @throws IllegalArgumentException  if {@code quantity < 1} (a negative sale would otherwise
     *                                   increase holdings) or more shares are sold than held
     * @throws ResourceNotFoundException if no user has {@code userId}
     */
    @Transactional
    public void removeStockFromPortfolio(Long userId, String symbol, int quantity) {
        requirePositiveQuantity(quantity);
        User user = findUser(userId);
        int currentQuantity = getAvailableQuantity(user, symbol);
        if (currentQuantity < quantity) {
            throw new IllegalArgumentException("Not enough shares to sell");
        }
        double salePrice = stockDataService.getCurrentPrice(symbol);
        portfolioRepository.save(newTransaction(user, symbol, quantity, salePrice, SELL));
    }

    /**
     * Builds the user's portfolio summary: one item per traded symbol plus the totals.
     *
     * @throws ResourceNotFoundException if no user has {@code userId}
     */
    @Transactional(readOnly = true)
    public PortfolioSummary getUserPortfolio(Long userId) {
        User user = findUser(userId);
        List<String> symbols = portfolioRepository.findDistinctSymbolsByUser(user);
        // processSymbol accumulates totals into the container as a side effect; this relies on
        // the stream being sequential.
        PortfolioSummaryContainer container = new PortfolioSummaryContainer();
        List<PortfolioSummary.PortfolioItem> items = symbols.stream()
                .map(symbol -> processSymbol(user, symbol, container))
                .collect(Collectors.toList());
        double overallGain = container.totalValue - container.totalInvestment;
        double gainPercentage = container.totalInvestment > 0
                ? (overallGain / container.totalInvestment) * 100
                : 0;
        return new PortfolioSummary(
                container.totalValue,
                container.todaysGain,
                overallGain,
                gainPercentage,
                items
        );
    }

    /**
     * Values one symbol's position and adds it to the running totals. The cost basis is the
     * average BUY price. "Today's gain" compares the latest close with the previous stored close.
     */
    private PortfolioSummary.PortfolioItem processSymbol(User user, String symbol,
                                                         PortfolioSummaryContainer container) {
        // The SUM queries can yield null when there are no matching rows (e.g. a symbol never
        // sold), so treat null as 0 instead of unboxing it.
        int bought = orZero(portfolioRepository.sumBoughtQuantityByUserAndSymbol(user, symbol));
        int sold = orZero(portfolioRepository.sumSoldQuantityByUserAndSymbol(user, symbol));
        int available = bought - sold;
        double currentPrice = stockDataService.getCurrentPrice(symbol);
        Double investmentSum = portfolioRepository.sumInvestmentByUserAndSymbol(user, symbol);
        double investment = investmentSum != null ? investmentSum : 0;
        double avgPrice = bought > 0 ? investment / bought : 0;
        double value = available * currentPrice;
        double costBasis = available * avgPrice;
        double gain = value - costBasis;
        double gainPercentage = costBasis != 0 ? (gain / costBasis) * 100 : 0;

        container.totalValue += value;
        container.totalInvestment += costBasis;

        // History is newest-first: index 1 is the previous close, if there is one.
        List<StockData> history = stockDataService.getHistoricalData(symbol, 2);
        double yesterdayPrice = history.size() > 1 ? history.get(1).getClosingPrice() : currentPrice;
        container.todaysGain += available * (currentPrice - yesterdayPrice);

        return new PortfolioSummary.PortfolioItem(
                symbol,
                available,
                avgPrice,
                currentPrice,
                gain,
                gainPercentage
        );
    }

    /** Shares currently held: bought minus sold, with missing sums counted as 0. */
    private int getAvailableQuantity(User user, String symbol) {
        return orZero(portfolioRepository.sumBoughtQuantityByUserAndSymbol(user, symbol))
                - orZero(portfolioRepository.sumSoldQuantityByUserAndSymbol(user, symbol));
    }

    /** Loads the user or throws {@link ResourceNotFoundException}. */
    private User findUser(Long userId) {
        return userRepository.findById(userId)
                .orElseThrow(() -> new ResourceNotFoundException("User", "id", userId));
    }

    /** Builds an unsaved transaction row whose total is {@code quantity * price}. */
    private static Portfolio newTransaction(User user, String symbol, int quantity,
                                            double price, String transactionType) {
        Portfolio portfolio = new Portfolio();
        portfolio.setUser(user);
        portfolio.setSymbol(symbol);
        portfolio.setQuantity(quantity);
        portfolio.setAveragePrice(price);
        portfolio.setTransactionType(transactionType);
        portfolio.setTotalInvestment(quantity * price);
        return portfolio;
    }

    /** Rejects non-positive share counts (mirrors PortfolioRequest's {@code @Min(1)}). */
    private static void requirePositiveQuantity(int quantity) {
        if (quantity < 1) {
            throw new IllegalArgumentException("Quantity must be at least 1");
        }
    }

    private static int orZero(Integer value) {
        return value != null ? value : 0;
    }

    /** Mutable accumulators for the totals built up across symbols. */
    private static class PortfolioSummaryContainer {
        double totalValue = 0;
        double totalInvestment = 0;
        double todaysGain = 0;
    }
}
package com.nepse.service;
import com.nepse.domain.StockData;
import com.nepse.repository.StockDataRepository;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
@Service
public class StockDataService {
private final StockDataRepository stockDataRepository;
/**
 * Wires the service to its repository via constructor injection.
 *
 * @param stockDataRepository persistence for {@link StockData} bars
 */
public StockDataService(StockDataRepository stockDataRepository)
{
    this.stockDataRepository = stockDataRepository;
}
@Transactional(readOnly = true)
public List<StockData> getHistoricalData(String symbol, int days) {
return stockDataRepository.findTopNBySymbolOrderByDateDesc(symbol, days);
}
@Transactional(readOnly = true)
public StockData getLatestData(String symbol) {
return stockDataRepository.findTopBySymbolOrderByDateDesc(symbol);
}
@Transactional(readOnly = true)
public List<StockData> getTopGainers(int count) {
return stockDataRepository.getTopGainers(count);
}
@Transactional(readOnly = true)
public List<StockData> getTopLosers(int count) {
return stockDataRepository.getTopLosers(count);
}
@Transactional(readOnly = true)
public List<String> getAllSymbols() {
return stockDataRepository.findAllDistinctSymbols();
}
@Transactional
public void updateStockData(List<StockData> stockDataList) {
stockDataRepository.saveAll(stockDataList);
}
@Transactional(readOnly = true)
public Double getCurrentPrice(String symbol) {
StockData latest = stockDataRepository.findTopBySymbolOrderByDateDesc(symbol);
return latest != null ? latest.getClosingPrice() : 0.0;
}
// === NEW METHOD: Import stock data from CSV ===
@Transactional
public void importStockDataFromCsv(String filePath, String symbol) throws IOException {
List<StockData> stockDataList = new ArrayList<>();
try (BufferedReader reader = new BufferedReader(new FileReader(filePath))) {
String line;
// Skip CSV header
reader.readLine();
while ((line = reader.readLine()) != null) {
String[] tokens = line.split(",");
// Make sure we have enough columns
if (tokens.length < 8) {
continue; // skip incomplete rows
}
// Parse data - adjust column indexes as per your CSV format
LocalDate date = LocalDate.parse(tokens[0].trim()); // e.g. "2025-08-05"
double openingPrice = Double.parseDouble(tokens[1].trim());
double closingPrice = Double.parseDouble(tokens[2].trim());
double highPrice = Double.parseDouble(tokens[3].trim());
double lowPrice = Double.parseDouble(tokens[4].trim());
long volume = Long.parseLong(tokens[5].trim());
double changeAmount = Double.parseDouble(tokens[6].trim());
double changePercentage = Double.parseDouble(tokens[7].trim());
// Optional: Parse technical indicators if present in CSV
double movingAverage5 = tokens.length > 8 ? Double.parseDouble(tokens[8].trim()) :
0.0;
double movingAverage20 = tokens.length > 9 ? Double.parseDouble(tokens[9].trim()) :
0.0;
double movingAverage50 = tokens.length > 10 ?
Double.parseDouble(tokens[10].trim()) : 0.0;
double rsi14 = tokens.length > 11 ? Double.parseDouble(tokens[11].trim()) : 0.0;
double macd = tokens.length > 12 ? Double.parseDouble(tokens[12].trim()) : 0.0;
double bollingerUpper = tokens.length > 13 ? Double.parseDouble(tokens[13].trim()) :
0.0;
double bollingerLower = tokens.length > 14 ? Double.parseDouble(tokens[14].trim()) :
0.0;
// Create StockData entity and set fields
StockData stockData = new StockData();
stockData.setSymbol(symbol);
stockData.setDate(date);
stockData.setOpeningPrice(openingPrice);
stockData.setClosingPrice(closingPrice);
stockData.setHighPrice(highPrice);
stockData.setLowPrice(lowPrice);
stockData.setVolume(volume);
stockData.setChangeAmount(changeAmount);
stockData.setChangePercentage(changePercentage);
stockData.setMovingAverage5(movingAverage5);
stockData.setMovingAverage20(movingAverage20);
stockData.setMovingAverage50(movingAverage50);
stockData.setRsi14(rsi14);
stockData.setMacd(macd);
stockData.setBollingerUpper(bollingerUpper);
stockData.setBollingerLower(bollingerLower);
stockDataList.add(stockData);
}
}
// Save all records in batch
stockDataRepository.saveAll(stockDataList);
}
@Transactional
public void ensureMinimumData(String symbol, int minDays) {
long count = stockDataRepository.countBySymbol(symbol);
if (count < minDays) {
// Add sample data if we don't have enough
double basePrice = switch (symbol) {
case "NABIL" -> 2000.0;
case "NICA" -> 1500.0;
case "NBL" -> 1200.0;
case "SCB" -> 1800.0;
case "HIDCL" -> 500.0;
case "GBIME" -> 800.0;
default -> 1000.0;
};
// Delete existing data if any
List<StockData> existing = stockDataRepository.findBySymbol(symbol);
if (!existing.isEmpty()) {
stockDataRepository.deleteAll(existing);
}
}}}
package com.nepse.service;
import com.nepse.domain.User;
import com.nepse.dto.RegisterRequest;
import com.nepse.exception.BadRequestException;
import com.nepse.repository.UserRepository;
import com.nepse.security.UserPrincipal;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import com.nepse.exception.ResourceNotFoundException;
/**
 * Application service for registering users and looking them up through {@link UserRepository}.
 */
@Service
public class UserService {

    private final UserRepository userRepository;
    private final PasswordEncoder passwordEncoder;

    public UserService(UserRepository userRepository, PasswordEncoder passwordEncoder) {
        this.userRepository = userRepository;
        this.passwordEncoder = passwordEncoder;
    }

    /**
     * Registers a new user after checking that neither the username nor the email is taken.
     * The raw password is encoded with the injected {@link PasswordEncoder} before saving.
     *
     * @param registerRequest registration data supplied by the client
     * @return the saved user as returned by the repository
     * @throws BadRequestException if the username or the email is already in use
     */
    @Transactional
    public User createUser(RegisterRequest registerRequest) {
        String requestedUsername = registerRequest.getUsername();
        if (userRepository.existsByUsername(requestedUsername)) {
            throw new BadRequestException("Username already in use");
        }

        String requestedEmail = registerRequest.getEmail();
        if (userRepository.existsByEmail(requestedEmail)) {
            throw new BadRequestException("Email already in use");
        }

        User newUser = new User();
        newUser.setUsername(requestedUsername);
        newUser.setPassword(passwordEncoder.encode(registerRequest.getPassword()));
        newUser.setFullName(registerRequest.getFullName());
        newUser.setEmail(requestedEmail);
        return userRepository.save(newUser);
    }

    /**
     * Loads the user with the given id and wraps it in a {@link UserPrincipal}.
     *
     * @throws ResourceNotFoundException if no user has that id
     */
    @Transactional(readOnly = true)
    public UserPrincipal loadUserById(Long id) {
        User found = userRepository
                .findById(id)
                .orElseThrow(() -> new ResourceNotFoundException("User", "id", id));
        return UserPrincipal.create(found);
    }

    /** Reports whether any user already has this username. */
    public boolean existsByUsername(String username) {
        return userRepository.existsByUsername(username);
    }

    /** Reports whether any user already has this email. */
    public boolean existsByEmail(String email) {
        return userRepository.existsByEmail(email);
    }
}
7. Util
// src/main/java/com/nepse/util/DataInitializer.java
package com.nepse.util;
import com.nepse.domain.StockData;
import com.nepse.repository.StockDataRepository;
import jakarta.annotation.PostConstruct;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
@Component
public class DataInitializer {
private final StockDataRepository stockDataRepository;
@Autowired
public DataInitializer(StockDataRepository stockDataRepository) {
this.stockDataRepository = stockDataRepository;
}
@PostConstruct
@Transactional
public void init() {
// Ensure we have data for key symbols
ensureSymbolData("NABIL", 2000.0, 60);
ensureSymbolData("NICA", 1500.0, 60);
ensureSymbolData("NBL", 1200.0, 60);
ensureSymbolData("SCB", 1800.0, 60);
ensureSymbolData("HIDCL", 500.0, 60);
ensureSymbolData("GBIME", 800.0, 60);
private void ensureSymbolData(String symbol, double startPrice, int days) {
long count = stockDataRepository.countBySymbol(symbol);
if (count < days) {
// Remove existing incomplete data
if (count > 0) {
stockDataRepository.deleteBySymbol(symbol);
}
// Add new sample data
addSampleData(symbol, startPrice, days);
System.out.println("Added " + days + " days of data for " + symbol);
}
}
private void addSampleData(String symbol, double startPrice, int days) {
List<StockData> data = new ArrayList<>();
LocalDate startDate = LocalDate.now().minusDays(days);
double price = startPrice;
for (int i = 0; i < days; i++) {
double change = (Math.random() - 0.5) * 50; // Random change between -50 to +50
price += change;
StockData stock = new StockData();
stock.setSymbol(symbol);
stock.setDate(startDate.plusDays(i));
stock.setOpeningPrice(price - 10);
stock.setClosingPrice(price);
stock.setHighPrice(price + 5);
stock.setLowPrice(price - 15);
stock.setVolume(100000 + (long)(Math.random() * 50000));
stock.setChangeAmount(change);
stock.setChangePercentage((change / (price - change)) * 100);
// Add technical indicators
stock.setMovingAverage5(calculateMovingAverage(data, 5, price));
stock.setMovingAverage20(calculateMovingAverage(data, 20, price));
stock.setMovingAverage50(calculateMovingAverage(data, 50, price));
data.add(stock);
}
stockDataRepository.saveAll(data);
}
private double calculateMovingAverage(List<StockData> data, int period, double currentPrice)
{
if (data.size() < period - 1) {
return currentPrice;
}
double sum = currentPrice;
int count = 1;
for (int i = data.size() - 1; i >= Math.max(0, data.size() - period + 1); i--) {
sum += data.get(i).getClosingPrice();
count++;
}
return sum / count;
}
}
package com.nepse.util;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
/**
 * Static helpers for persisting and restoring LSTM models together with their
 * {@link NormalizerMinMaxScaler}.
 *
 * <p>A model stored as {@code <name>.zip} keeps its normalizer next to it as
 * {@code <name>-normalizer.bin}.
 */
public class ModelUtils {
    private static final Logger logger = LoggerFactory.getLogger(ModelUtils.class);

    /** File extension of serialized models. */
    private static final String MODEL_EXTENSION = ".zip";

    /** Suffix that replaces {@link #MODEL_EXTENSION} to form the normalizer file name. */
    private static final String NORMALIZER_SUFFIX = "-normalizer.bin";

    /**
     * Saves the LSTM model and its normalizer to disk, creating the parent directory if needed.
     *
     * @param model      The trained LSTM model
     * @param normalizer The data normalizer used with the model
     * @param modelFile  The file to save the model to
     * @throws IOException If the parent directory cannot be created or a file cannot be written
     */
    public static void saveModel(MultiLayerNetwork model,
                                 NormalizerMinMaxScaler normalizer,
                                 File modelFile) throws IOException {
        // A bare file name has no parent directory; only create one when there is a parent.
        File parentDir = modelFile.getParentFile();
        if (parentDir != null && !parentDir.mkdirs() && !parentDir.isDirectory()) {
            throw new IOException("Could not create model directory: " + parentDir.getAbsolutePath());
        }

        ModelSerializer.writeModel(model, modelFile, true);
        logger.info("Saved model to: {}", modelFile.getAbsolutePath());

        File normalizerFile = normalizerFileFor(modelFile);
        NormalizerSerializer.getDefault().write(normalizer, normalizerFile);
        logger.info("Saved normalizer to: {}", normalizerFile.getAbsolutePath());
    }

    /**
     * Loads a trained LSTM model from disk.
     *
     * @param modelFile The file containing the saved model
     * @return The loaded MultiLayerNetwork model
     * @throws IOException If the file does not exist or cannot be read
     */
    public static MultiLayerNetwork loadModel(File modelFile) throws IOException {
        if (!modelFile.exists()) {
            throw new IOException("Model file not found: " + modelFile.getAbsolutePath());
        }
        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelFile);
        logger.info("Loaded model from: {}", modelFile.getAbsolutePath());
        return model;
    }

    /**
     * Loads the normalizer saved alongside a specific model.
     *
     * @param modelFile The model file path; the normalizer is looked up next to it
     * @return The loaded NormalizerMinMaxScaler
     * @throws IOException If the normalizer file does not exist
     * @throws Exception   If the normalizer cannot be deserialized
     */
    public static NormalizerMinMaxScaler loadNormalizer(File modelFile) throws Exception {
        File normalizerFile = normalizerFileFor(modelFile);
        if (!normalizerFile.exists()) {
            throw new IOException("Normalizer file not found: " + normalizerFile.getAbsolutePath());
        }
        return NormalizerSerializer.getDefault().restore(normalizerFile);
    }

    /**
     * Checks whether both the model and its normalizer exist for a given symbol.
     *
     * @param modelDir The directory containing models
     * @param symbol   The stock symbol to check
     * @return true if both files exist, false otherwise
     */
    public static boolean modelExists(File modelDir, String symbol) {
        File modelFile = new File(modelDir, symbol + MODEL_EXTENSION);
        File normalizerFile = new File(modelDir, symbol + NORMALIZER_SUFFIX);
        return modelFile.exists() && normalizerFile.exists();
    }

    /**
     * Deletes the model and normalizer files for a given symbol.
     *
     * @param modelDir The directory containing models
     * @param symbol   The stock symbol to delete
     * @return true if at least one of the two files was deleted, false otherwise
     */
    public static boolean deleteModel(File modelDir, String symbol) {
        File modelFile = new File(modelDir, symbol + MODEL_EXTENSION);
        File normalizerFile = new File(modelDir, symbol + NORMALIZER_SUFFIX);
        boolean modelDeleted = modelFile.exists() && modelFile.delete();
        boolean normalizerDeleted = normalizerFile.exists() && normalizerFile.delete();
        return modelDeleted || normalizerDeleted;
    }

    /**
     * Returns the normalizer file that belongs to {@code modelFile}.
     *
     * <p>Only a trailing {@code .zip} is replaced. A name without it gets the suffix appended, so
     * the normalizer can never resolve to the model file itself and overwrite it.
     */
    private static File normalizerFileFor(File modelFile) {
        String name = modelFile.getName();
        String baseName = name.endsWith(MODEL_EXTENSION)
                ? name.substring(0, name.length() - MODEL_EXTENSION.length())
                : name;
        return new File(modelFile.getParentFile(), baseName + NORMALIZER_SUFFIX);
    }
}
And finally, the main class:
package com.nepse;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
/**
 * Spring Boot entry point for the stock prediction application.
 */
@SpringBootApplication
public class StockPredictionApplication {

    /**
     * Boots the Spring application context.
     *
     * @param commandLineArgs arguments passed through to {@link SpringApplication#run}
     */
    public static void main(String[] commandLineArgs) {
        SpringApplication.run(StockPredictionApplication.class, commandLineArgs);
    }
}