From 6d282e155e528b7ab17181a7d95f51eb723f6c99 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 26 May 2026 09:41:07 +0800 Subject: [PATCH 01/11] update ML KEDF output --- .../source_estate/module_pot/pot_ml_exx.cpp | 41 +++++++++++++------ source/source_estate/module_pot/pot_ml_exx.h | 11 ++--- .../module_pot/pot_ml_exx_label.cpp | 3 +- .../source_io/module_ctrl/ctrl_output_pw.cpp | 3 +- .../module_ml/cal_mlkedf_descriptors.cpp | 5 ++- .../module_ml/cal_mlkedf_descriptors.h | 3 +- .../source_pw/module_ofdft/kedf_manager.cpp | 40 +++++++++++++++--- source/source_pw/module_ofdft/kedf_ml.cpp | 20 +++++---- source/source_pw/module_ofdft/kedf_ml.h | 6 ++- .../source_pw/module_ofdft/kedf_ml_label.cpp | 7 ++-- source/source_pw/module_ofdft/ml_base.cpp | 10 ++--- source/source_pw/module_ofdft/ml_base.h | 2 +- source/source_pw/module_ofdft/nn_of.cpp | 6 +-- source/source_pw/module_ofdft/nn_of.h | 3 +- 14 files changed, 108 insertions(+), 52 deletions(-) diff --git a/source/source_estate/module_pot/pot_ml_exx.cpp b/source/source_estate/module_pot/pot_ml_exx.cpp index 53393b1e33f..b5c1fbcd598 100644 --- a/source/source_estate/module_pot/pot_ml_exx.cpp +++ b/source/source_estate/module_pot/pot_ml_exx.cpp @@ -17,13 +17,13 @@ ML_EXX::ML_EXX() ML_EXX::~ML_EXX(){} -void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in) +void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in, std::ostream& ofs_running) { torch::set_default_dtype(caffe2::TypeMeta::fromScalarType(torch::kDouble)); auto output = torch::get_default_dtype(); - std::cout << "Default type: " << output << std::endl; + ofs_running << " Default type: " << output << std::endl; - this->set_device(inp.of_ml_device); + this->set_device(inp.of_ml_device, ofs_running); this->nx = rho_basis_in->nrxx; this->nx_tot = rho_basis_in->nrxx; @@ -46,17 +46,18 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod inp.of_ml_tanh_pnl, inp.of_ml_tanh_qnl, inp.of_ml_tanhp_nl, - inp.of_ml_tanhq_nl); + inp.of_ml_tanhq_nl, + ofs_running); - std::cout << "ninput = " << this->ninput << std::endl; + ofs_running << "ninput = " << this->ninput << std::endl; if (PARAM.inp.ml_exx) { int nnode = 100; int nlayer = 3; - this->nn = std::make_shared(this->nx, 0, this->ninput, nnode, nlayer, this->device); + this->nn = std::make_shared(this->nx, 0, this->ninput, nnode, nlayer, this->device, ofs_running); torch::load(this->nn, "net.pt", this->device_type); - std::cout << "load net done" << std::endl; + ofs_running << "load net done" << std::endl; if (PARAM.inp.of_ml_feg != 0) { torch::Tensor feg_inpt = torch::zeros(this->ninput, this->device_type); @@ -74,7 +75,7 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod this->feg_net_F = this->nn->forward(feg_inpt).to(this->device_CPU).contiguous().data_ptr()[0]; } - std::cout << "feg_net_F = " << this->feg_net_F << std::endl; + ofs_running << "feg_net_F = " << this->feg_net_F << std::endl; } } @@ -88,8 +89,24 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod this->chi_pnl = inp.of_ml_chi_pnl; this->chi_qnl = inp.of_ml_chi_qnl; - this->cal_tool->set_para(this->nx, inp.nelec, inp.of_tf_weight, inp.of_vw_weight, this->chi_p, this->chi_q, - this->chi_xi, this->chi_pnl, this->chi_qnl, this->nkernel, inp.of_ml_kernel, inp.of_ml_kernel_scaling, inp.of_ml_yukawa_alpha, inp.of_ml_kernel_file, this->dV * rho_basis_in->nxyz, rho_basis_in); + this->cal_tool->set_para( + this->nx, + inp.nelec, + inp.of_tf_weight, + inp.of_vw_weight, + this->chi_p, + this->chi_q, + this->chi_xi, + this->chi_pnl, + this->chi_qnl, + this->nkernel, + inp.of_ml_kernel, + inp.of_ml_kernel_scaling, + inp.of_ml_yukawa_alpha, + inp.of_ml_kernel_file, + this->dV * rho_basis_in->nxyz, + rho_basis_in, + ofs_running); } } @@ -177,7 +194,7 @@ void ML_EXX::generateTrainData(const double * const *prho, const ModulePW::PW_Ba * @param prho charge density * @param pw_rho PW_Basis */ -void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw_rho) +void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw_rho, std::ostream& ofs_running) { // for test ===================== std::vector cshape = {(long unsigned) this->nx}; @@ -192,7 +209,7 @@ void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw for (int ir = 0; ir < this->nx; ++ir) { if (prho[0][ir] == 0.){ - std::cout << "WARNING: rho = 0" << std::endl; + ofs_running << "WARNING: rho = 0" << std::endl; } }; // ============================== diff --git a/source/source_estate/module_pot/pot_ml_exx.h b/source/source_estate/module_pot/pot_ml_exx.h index 5936add9050..a2a866801c2 100644 --- a/source/source_estate/module_pot/pot_ml_exx.h +++ b/source/source_estate/module_pot/pot_ml_exx.h @@ -16,13 +16,13 @@ class ML_EXX : public ML_Base ML_EXX(); virtual ~ML_EXX(); - void set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in); + void set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in, std::ostream& ofs_running); void ml_potential(const double * const * prho, const ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential); // output all parameters void generateTrainData(const double * const *prho, const ModulePW::PW_Basis *pw_rho, const double *veff); - void localTest(const double * const *prho, const ModulePW::PW_Basis *pw_rho); + void localTest(const double * const *prho, const ModulePW::PW_Basis *pw_rho, std::ostream& ofs_running); void init_data( const int &nkernel, @@ -40,7 +40,8 @@ class ML_EXX : public ML_Base const std::vector &of_ml_tanh_pnl, const std::vector &of_ml_tanh_qnl, const std::vector &of_ml_tanhp_nl, - const std::vector &of_ml_tanhq_nl + const std::vector &of_ml_tanhq_nl, + std::ostream& ofs_running ); double ml_exx_energy = 0.0; @@ -56,13 +57,13 @@ class PotML_EXX : public PotBase this->dynamic_mode = true; this->fixed_mode = false; - this->ml_exx.set_para(PARAM.inp, ucell_in, rho_basis_in); + this->ml_exx.set_para(PARAM.inp, ucell_in, rho_basis_in, GlobalV::ofs_running); } ~PotML_EXX() {}; void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff) override { - if (PARAM.inp.of_ml_local_test) this->ml_exx.localTest(chg->rho, this->rho_basis_); + if (PARAM.inp.of_ml_local_test) this->ml_exx.localTest(chg->rho, this->rho_basis_, GlobalV::ofs_running); this->ml_exx.ml_potential(chg->rho, this->rho_basis_, v_eff); } diff --git a/source/source_estate/module_pot/pot_ml_exx_label.cpp b/source/source_estate/module_pot/pot_ml_exx_label.cpp index 3908b7c5ef9..a073ce58832 100644 --- a/source/source_estate/module_pot/pot_ml_exx_label.cpp +++ b/source/source_estate/module_pot/pot_ml_exx_label.cpp @@ -40,7 +40,8 @@ void ML_EXX::init_data( const std::vector &of_ml_tanh_pnl, const std::vector &of_ml_tanh_qnl, const std::vector &of_ml_tanhp_nl, - const std::vector &of_ml_tanhq_nl + const std::vector &of_ml_tanhq_nl, + std::ostream& ofs_running ) { diff --git a/source/source_io/module_ctrl/ctrl_output_pw.cpp b/source/source_io/module_ctrl/ctrl_output_pw.cpp index f084854238b..580bc314efb 100644 --- a/source/source_io/module_ctrl/ctrl_output_pw.cpp +++ b/source/source_io/module_ctrl/ctrl_output_pw.cpp @@ -361,7 +361,8 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, inp.of_ml_yukawa_alpha, inp.of_ml_kernel_file, ucell.omega, - pw_rho); + pw_rho, + GlobalV::ofs_running); write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir, stp.template get_psi_t(), diff --git a/source/source_io/module_ml/cal_mlkedf_descriptors.cpp b/source/source_io/module_ml/cal_mlkedf_descriptors.cpp index b0b627fec23..4a1e3c86da6 100644 --- a/source/source_io/module_ml/cal_mlkedf_descriptors.cpp +++ b/source/source_io/module_ml/cal_mlkedf_descriptors.cpp @@ -19,7 +19,8 @@ void Cal_MLKEDF_Descriptors::set_para( const std::vector &yukawa_alpha, const std::vector &kernel_file, const double &omega, - const ModulePW::PW_Basis *pw_rho + const ModulePW::PW_Basis *pw_rho, + std::ostream& ofs_running ) { this->nx = nx; @@ -34,7 +35,7 @@ void Cal_MLKEDF_Descriptors::set_para( this->kernel_scaling = kernel_scaling; this->yukawa_alpha = yukawa_alpha; this->kernel_file = kernel_file; - std::cout << "nkernel = " << nkernel << std::endl; + ofs_running << "nkernel = " << nkernel << std::endl; if (PARAM.inp.of_wt_rho0 != 0) { diff --git a/source/source_io/module_ml/cal_mlkedf_descriptors.h b/source/source_io/module_ml/cal_mlkedf_descriptors.h index 2569d091c8a..7a0b70b69f3 100644 --- a/source/source_io/module_ml/cal_mlkedf_descriptors.h +++ b/source/source_io/module_ml/cal_mlkedf_descriptors.h @@ -38,7 +38,8 @@ class Cal_MLKEDF_Descriptors const std::vector &yukawa_alpha, const std::vector &kernel_file, const double &omega, - const ModulePW::PW_Basis *pw_rho); + const ModulePW::PW_Basis *pw_rho, + std::ostream& ofs_running); // get input parameters void getGamma(const double * const *prho, std::vector &rgamma); void getP(const double * const *prho, const ModulePW::PW_Basis *pw_rho, std::vector> &pnablaRho, std::vector &rp); diff --git a/source/source_pw/module_ofdft/kedf_manager.cpp b/source/source_pw/module_ofdft/kedf_manager.cpp index 8aaec2a8a4f..76a8f5ecbb4 100644 --- a/source/source_pw/module_ofdft/kedf_manager.cpp +++ b/source/source_pw/module_ofdft/kedf_manager.cpp @@ -109,12 +109,40 @@ void KEDF_Manager::init( { if (this->ml_ == nullptr) this->ml_ = new KEDF_ML(); - this->ml_->set_para(pw_rho->nrxx, dV, nelec, inp.of_tf_weight, inp.of_vw_weight, - inp.of_ml_chi_p, inp.of_ml_chi_q, inp.of_ml_chi_xi, inp.of_ml_chi_pnl, inp.of_ml_chi_qnl, - inp.of_ml_nkernel, inp.of_ml_kernel, inp.of_ml_kernel_scaling, - inp.of_ml_yukawa_alpha, inp.of_ml_kernel_file, inp.of_ml_gamma, inp.of_ml_p, inp.of_ml_q, inp.of_ml_tanhp, inp.of_ml_tanhq, - inp.of_ml_gammanl, inp.of_ml_pnl, inp.of_ml_qnl, inp.of_ml_xi, inp.of_ml_tanhxi, - inp.of_ml_tanhxi_nl, inp.of_ml_tanh_pnl, inp.of_ml_tanh_qnl, inp.of_ml_tanhp_nl, inp.of_ml_tanhq_nl, inp.of_ml_device, pw_rho); + this->ml_->set_para( + pw_rho->nrxx, + dV, + nelec, + inp.of_tf_weight, + inp.of_vw_weight, + inp.of_ml_chi_p, + inp.of_ml_chi_q, + inp.of_ml_chi_xi, + inp.of_ml_chi_pnl, + inp.of_ml_chi_qnl, + inp.of_ml_nkernel, + inp.of_ml_kernel, + inp.of_ml_kernel_scaling, + inp.of_ml_yukawa_alpha, + inp.of_ml_kernel_file, + inp.of_ml_gamma, + inp.of_ml_p, + inp.of_ml_q, + inp.of_ml_tanhp, + inp.of_ml_tanhq, + inp.of_ml_gammanl, + inp.of_ml_pnl, + inp.of_ml_qnl, + inp.of_ml_xi, + inp.of_ml_tanhxi, + inp.of_ml_tanhxi_nl, + inp.of_ml_tanh_pnl, + inp.of_ml_tanh_qnl, + inp.of_ml_tanhp_nl, + inp.of_ml_tanhq_nl, + inp.of_ml_device, + pw_rho, + GlobalV::ofs_running); } #endif } diff --git a/source/source_pw/module_ofdft/kedf_ml.cpp b/source/source_pw/module_ofdft/kedf_ml.cpp index 6b2643a7b58..ec3f9c1ac1d 100644 --- a/source/source_pw/module_ofdft/kedf_ml.cpp +++ b/source/source_pw/module_ofdft/kedf_ml.cpp @@ -38,14 +38,15 @@ void KEDF_ML::set_para( const std::vector &of_ml_tanhp_nl, const std::vector &of_ml_tanhq_nl, const std::string device_inpt, - ModulePW::PW_Basis *pw_rho + ModulePW::PW_Basis *pw_rho, + std::ostream& ofs_running ) { torch::set_default_dtype(caffe2::TypeMeta::fromScalarType(torch::kDouble)); auto output = torch::get_default_dtype(); - std::cout << "Default type: " << output << std::endl; + ofs_running << " Default type: " << output << std::endl; - this->set_device(device_inpt); + this->set_device(device_inpt, ofs_running); this->nx = nx; this->nx_tot = nx; @@ -68,17 +69,18 @@ void KEDF_ML::set_para( of_ml_tanh_pnl, of_ml_tanh_qnl, of_ml_tanhp_nl, - of_ml_tanhq_nl); + of_ml_tanhq_nl, + ofs_running); - std::cout << "ninput = " << ninput << std::endl; + ofs_running << "ninput = " << ninput << std::endl; if (PARAM.inp.of_kinetic == "ml") { int nnode = 100; int nlayer = 3; - this->nn = std::make_shared(this->nx, 0, this->ninput, nnode, nlayer, this->device); + this->nn = std::make_shared(this->nx, 0, this->ninput, nnode, nlayer, this->device, ofs_running); torch::load(this->nn, "net.pt", this->device_type); - std::cout << "load net done" << std::endl; + ofs_running << "load net done" << std::endl; if (PARAM.inp.of_ml_feg != 0) { torch::Tensor feg_inpt = torch::zeros(this->ninput, this->device_type); @@ -96,7 +98,7 @@ void KEDF_ML::set_para( this->feg_net_F = this->nn->forward(feg_inpt).to(this->device_CPU).contiguous().data_ptr()[0]; } - std::cout << "feg_net_F = " << this->feg_net_F << std::endl; + ofs_running << "feg_net_F = " << this->feg_net_F << std::endl; } } @@ -111,7 +113,7 @@ void KEDF_ML::set_para( this->chi_qnl = chi_qnl; this->cal_tool->set_para(nx, nelec, tf_weight, vw_weight, chi_p, chi_q, - chi_xi, chi_pnl, chi_qnl, nkernel, kernel_type, kernel_scaling, yukawa_alpha, kernel_file, this->dV * pw_rho->nxyz, pw_rho); + chi_xi, chi_pnl, chi_qnl, nkernel, kernel_type, kernel_scaling, yukawa_alpha, kernel_file, this->dV * pw_rho->nxyz, pw_rho, ofs_running); } } diff --git a/source/source_pw/module_ofdft/kedf_ml.h b/source/source_pw/module_ofdft/kedf_ml.h index 202c6958f0c..d7acfb734ba 100644 --- a/source/source_pw/module_ofdft/kedf_ml.h +++ b/source/source_pw/module_ofdft/kedf_ml.h @@ -47,7 +47,8 @@ class KEDF_ML : public ML_Base const std::vector &of_ml_tanhp_nl, const std::vector &of_ml_tanhq_nl, const std::string device_inpt, - ModulePW::PW_Basis *pw_rho); + ModulePW::PW_Basis *pw_rho, + std::ostream& ofs_running); double get_energy(const double * const * prho, ModulePW::PW_Basis *pw_rho); // double get_energy_density(const double * const *prho, int is, int ir, ModulePW::PW_Basis *pw_rho); @@ -78,7 +79,8 @@ class KEDF_ML : public ML_Base const std::vector &of_ml_tanh_pnl, const std::vector &of_ml_tanh_qnl, const std::vector &of_ml_tanhp_nl, - const std::vector &of_ml_tanhq_nl + const std::vector &of_ml_tanhq_nl, + std::ostream& ofs_running ); }; diff --git a/source/source_pw/module_ofdft/kedf_ml_label.cpp b/source/source_pw/module_ofdft/kedf_ml_label.cpp index 100ad4c387e..0cce3d0cec3 100644 --- a/source/source_pw/module_ofdft/kedf_ml_label.cpp +++ b/source/source_pw/module_ofdft/kedf_ml_label.cpp @@ -38,7 +38,8 @@ void KEDF_ML::init_data( const std::vector &of_ml_tanh_pnl, const std::vector &of_ml_tanh_qnl, const std::vector &of_ml_tanhp_nl, - const std::vector &of_ml_tanhq_nl + const std::vector &of_ml_tanhq_nl, + std::ostream& ofs_running ) { @@ -152,8 +153,8 @@ void KEDF_ML::init_data( this->descriptor2kernel[descriptor_type[i]].push_back(kernel_index[i]); this->descriptor2index[descriptor_type[i]].push_back(i); } - std::cout << "descriptor2index " << descriptor2index << std::endl; - std::cout << "descriptor2kernel " << descriptor2kernel << std::endl; + ofs_running << "descriptor2index " << descriptor2index << std::endl; + ofs_running << "descriptor2kernel " << descriptor2kernel << std::endl; this->ml_gamma = this->descriptor2index["gamma"].size() > 0; this->ml_p = this->descriptor2index["p"].size() > 0; diff --git a/source/source_pw/module_ofdft/ml_base.cpp b/source/source_pw/module_ofdft/ml_base.cpp index 0bb09b441ac..54feecb261a 100644 --- a/source/source_pw/module_ofdft/ml_base.cpp +++ b/source/source_pw/module_ofdft/ml_base.cpp @@ -10,24 +10,24 @@ ML_Base::~ML_Base() if (this->cal_tool) delete this->cal_tool; } -void ML_Base::set_device(std::string device_inpt) +void ML_Base::set_device(const std::string& device_inpt, std::ostream& ofs_running) { if (device_inpt == "cpu") { - std::cout << "------------------- Running NN on CPU -------------------" << std::endl; + ofs_running << "------------------- Running Neural Network on CPU -------------------" << std::endl; this->device_type = torch::kCPU; } else if (device_inpt == "gpu") { if (torch::cuda::cudnn_is_available()) { - std::cout << "------------------- Running NN on GPU -------------------" << std::endl; + ofs_running << "------------------- Running Neural Network on GPU -------------------" << std::endl; this->device_type = torch::kCUDA; } else { - std::cout << "--------------- Warning: GPU is unaviable ---------------" << std::endl; - std::cout << "------------------- Running NN on CPU -------------------" << std::endl; + ofs_running << "--------------- Warning: GPU is unaviable ---------------" << std::endl; + ofs_running << "------------------- Running Neural Network on CPU -------------------" << std::endl; this->device_type = torch::kCPU; } } diff --git a/source/source_pw/module_ofdft/ml_base.h b/source/source_pw/module_ofdft/ml_base.h index 41c08d22327..3f7b04d582c 100644 --- a/source/source_pw/module_ofdft/ml_base.h +++ b/source/source_pw/module_ofdft/ml_base.h @@ -20,7 +20,7 @@ class ML_Base ~ML_Base(); // Common Interface - void set_device(std::string device_inpt); + void set_device(const std::string& device_inpt, std::ostream& ofs_running); // Tools void loadVector(std::string filename, std::vector &data); diff --git a/source/source_pw/module_ofdft/nn_of.cpp b/source/source_pw/module_ofdft/nn_of.cpp index 5aa81069587..cfe48bf1d86 100644 --- a/source/source_pw/module_ofdft/nn_of.cpp +++ b/source/source_pw/module_ofdft/nn_of.cpp @@ -1,14 +1,14 @@ #include "nn_of.h" -NN_OFImpl::NN_OFImpl(int nrxx, int nrxx_vali, int ninpt, int nnode, int nlayer, torch::Device device) +NN_OFImpl::NN_OFImpl(int nrxx, int nrxx_vali, int ninpt, int nnode, int nlayer, torch::Device device, std::ostream& ofs_running) { this->nrxx = nrxx; this->nrxx_vali = nrxx_vali; this->ninpt = ninpt; this->nnode = nnode; - std::cout << "nnode = " << this->nnode << std::endl; + ofs_running << "nnode = " << this->nnode << std::endl; this->nlayer = nlayer; - std::cout << "nlayer = " << this->nlayer << std::endl; + ofs_running << "nlayer = " << this->nlayer << std::endl; this->nfc = nlayer + 1; this->inputs = torch::zeros({this->nrxx, this->ninpt}).to(device); diff --git a/source/source_pw/module_ofdft/nn_of.h b/source/source_pw/module_ofdft/nn_of.h index 6623ddcfc80..6566dcae191 100644 --- a/source/source_pw/module_ofdft/nn_of.h +++ b/source/source_pw/module_ofdft/nn_of.h @@ -11,7 +11,8 @@ struct NN_OFImpl:torch::nn::Module{ int ninpt, int nnode, int nlayer, - torch::Device device + torch::Device device, + std::ostream& ofs_running ); ~NN_OFImpl() { From 4eda8d782c70325e9fc79e52f1f114acb3509534 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Tue, 26 May 2026 09:42:16 +0800 Subject: [PATCH 02/11] Refactor OFDFT ML KEDF logging output to use ofs_running stream Summary of changes: 1. Modified ML_Base::set_device() to accept std::ostream& ofs_running parameter instead of using std::cout directly 2. Updated KEDF_ML::set_para() to pass ofs_running through the call chain 3. Modified KEDF_ML::init_data() to accept ofs_running parameter 4. Updated NN_OFImpl constructor to accept ofs_running parameter for logging nnode/nlayer 5. Modified Cal_MLKEDF_Descriptors::set_para() to accept ofs_running parameter for logging nkernel 6. Updated ML_EXX class methods (set_para, init_data, localTest) to use ofs_running 7. Updated all call sites to pass GlobalV::ofs_running 8. Changed 'NN' to 'Neural Network' in device initialization messages 9. Fixed 'WARNING: ML >= TF' message in KEDF_Manager::get_energy() to use ofs_running 10. Reformatted KEDF_ML::set_para() and cal_tool->set_para() calls with one parameter per line All ML KEDF related output messages now write to the running log file instead of stdout. --- source/source_pw/module_ofdft/kedf_manager.cpp | 4 ++-- tests/integrate/Autotest.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/source/source_pw/module_ofdft/kedf_manager.cpp b/source/source_pw/module_ofdft/kedf_manager.cpp index 76a8f5ecbb4..a2f270da50a 100644 --- a/source/source_pw/module_ofdft/kedf_manager.cpp +++ b/source/source_pw/module_ofdft/kedf_manager.cpp @@ -267,8 +267,8 @@ double KEDF_Manager::get_energy() const kinetic_energy += this->ml_->ml_energy; if (this->ml_->ml_energy >= this->tf_->tf_energy) { - std::cout << "WARNING: ML >= TF" << std::endl; - std::cout << "ML Term = " << this->ml_->ml_energy << " Ry, TF Term = " << this->tf_->tf_energy << " Ry." << std::endl; + GlobalV::ofs_running << "WARNING: ML >= TF" << std::endl; + GlobalV::ofs_running << "ML Term = " << this->ml_->ml_energy << " Ry, TF Term = " << this->tf_->tf_energy << " Ry." << std::endl; } } #endif diff --git a/tests/integrate/Autotest.sh b/tests/integrate/Autotest.sh index dc466a096fc..b7e260afcd9 100755 --- a/tests/integrate/Autotest.sh +++ b/tests/integrate/Autotest.sh @@ -1,7 +1,7 @@ #!/bin/bash # ABACUS executable path -abacus=abacus +abacus=/home/510Group/2_abacus/abacus-mc/build_ml_para/abacus_ml_para # number of MPI processes np=4 nt=$OMP_NUM_THREADS # number of OpenMP threads, default is $OMP_NUM_THREADS From 17c76950fe40f87fe5b306d05ef1cd595bdd1dfa Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Fri, 29 May 2026 17:28:57 +0800 Subject: [PATCH 03/11] fix --- tests/integrate/Autotest.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integrate/Autotest.sh b/tests/integrate/Autotest.sh index b7e260afcd9..6eeb68cb720 100755 --- a/tests/integrate/Autotest.sh +++ b/tests/integrate/Autotest.sh @@ -1,7 +1,7 @@ #!/bin/bash # ABACUS executable path -abacus=/home/510Group/2_abacus/abacus-mc/build_ml_para/abacus_ml_para +abacus= # number of MPI processes np=4 nt=$OMP_NUM_THREADS # number of OpenMP threads, default is $OMP_NUM_THREADS From 10c7c819223e2f1689db7d13227743162d16de22 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 30 May 2026 06:02:53 +0800 Subject: [PATCH 04/11] fix --- tests/integrate/Autotest.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integrate/Autotest.sh b/tests/integrate/Autotest.sh index 6eeb68cb720..dc466a096fc 100755 --- a/tests/integrate/Autotest.sh +++ b/tests/integrate/Autotest.sh @@ -1,7 +1,7 @@ #!/bin/bash # ABACUS executable path -abacus= +abacus=abacus # number of MPI processes np=4 nt=$OMP_NUM_THREADS # number of OpenMP threads, default is $OMP_NUM_THREADS From 3478047bd53afbd32a31de6e7ddc067f519a2bf6 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 4 Jun 2026 08:51:54 +0800 Subject: [PATCH 05/11] update the output formats --- source/source_estate/module_pot/pot_ml_exx.cpp | 3 +-- source/source_estate/module_pot/pot_ml_exx.h | 3 +-- .../source_estate/module_pot/pot_ml_exx_label.cpp | 3 +-- source/source_pw/module_ofdft/kedf_ml.cpp | 12 ++++++++---- source/source_pw/module_ofdft/kedf_ml.h | 13 +++++++------ source/source_pw/module_ofdft/ml_base.cpp | 2 +- 6 files changed, 19 insertions(+), 17 deletions(-) diff --git a/source/source_estate/module_pot/pot_ml_exx.cpp b/source/source_estate/module_pot/pot_ml_exx.cpp index b5c1fbcd598..df4b905f618 100644 --- a/source/source_estate/module_pot/pot_ml_exx.cpp +++ b/source/source_estate/module_pot/pot_ml_exx.cpp @@ -46,8 +46,7 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod inp.of_ml_tanh_pnl, inp.of_ml_tanh_qnl, inp.of_ml_tanhp_nl, - inp.of_ml_tanhq_nl, - ofs_running); + inp.of_ml_tanhq_nl); ofs_running << "ninput = " << this->ninput << std::endl; diff --git a/source/source_estate/module_pot/pot_ml_exx.h b/source/source_estate/module_pot/pot_ml_exx.h index a2a866801c2..2dff9d922f2 100644 --- a/source/source_estate/module_pot/pot_ml_exx.h +++ b/source/source_estate/module_pot/pot_ml_exx.h @@ -40,8 +40,7 @@ class ML_EXX : public ML_Base const std::vector &of_ml_tanh_pnl, const std::vector &of_ml_tanh_qnl, const std::vector &of_ml_tanhp_nl, - const std::vector &of_ml_tanhq_nl, - std::ostream& ofs_running + const std::vector &of_ml_tanhq_nl ); double ml_exx_energy = 0.0; diff --git a/source/source_estate/module_pot/pot_ml_exx_label.cpp b/source/source_estate/module_pot/pot_ml_exx_label.cpp index a073ce58832..3908b7c5ef9 100644 --- a/source/source_estate/module_pot/pot_ml_exx_label.cpp +++ b/source/source_estate/module_pot/pot_ml_exx_label.cpp @@ -40,8 +40,7 @@ void ML_EXX::init_data( const std::vector &of_ml_tanh_pnl, const std::vector &of_ml_tanh_qnl, const std::vector &of_ml_tanhp_nl, - const std::vector &of_ml_tanhq_nl, - std::ostream& ofs_running + const std::vector &of_ml_tanhq_nl ) { diff --git a/source/source_pw/module_ofdft/kedf_ml.cpp b/source/source_pw/module_ofdft/kedf_ml.cpp index ec3f9c1ac1d..e62fd3f78f2 100644 --- a/source/source_pw/module_ofdft/kedf_ml.cpp +++ b/source/source_pw/module_ofdft/kedf_ml.cpp @@ -72,7 +72,7 @@ void KEDF_ML::set_para( of_ml_tanhq_nl, ofs_running); - ofs_running << "ninput = " << ninput << std::endl; + ofs_running << " ninput = " << ninput << std::endl; if (PARAM.inp.of_kinetic == "ml") { @@ -98,7 +98,7 @@ void KEDF_ML::set_para( this->feg_net_F = this->nn->forward(feg_inpt).to(this->device_CPU).contiguous().data_ptr()[0]; } - ofs_running << "feg_net_F = " << this->feg_net_F << std::endl; + ofs_running << " feg_net_F = " << this->feg_net_F << std::endl; } } @@ -113,7 +113,9 @@ void KEDF_ML::set_para( this->chi_qnl = chi_qnl; this->cal_tool->set_para(nx, nelec, tf_weight, vw_weight, chi_p, chi_q, - chi_xi, chi_pnl, chi_qnl, nkernel, kernel_type, kernel_scaling, yukawa_alpha, kernel_file, this->dV * pw_rho->nxyz, pw_rho, ofs_running); + chi_xi, chi_pnl, chi_qnl, nkernel, kernel_type, + kernel_scaling, yukawa_alpha, kernel_file, + this->dV * pw_rho->nxyz, pw_rho, ofs_running); } } @@ -139,7 +141,6 @@ double KEDF_ML::get_energy(const double * const * prho, ModulePW::PW_Basis *pw_r { energy += enhancement_cpu_ptr[ir] * std::pow(prho[0][ir], this->energy_exponent); } - std::cout << "energy" << energy << std::endl; energy *= this->dV * this->energy_prefactor; this->ml_energy = energy; Parallel_Reduce::reduce_all(this->ml_energy); @@ -160,8 +161,11 @@ void KEDF_ML::ml_potential(const double * const * prho, ModulePW::PW_Basis *pw_r this->NN_forward(prho, pw_rho, true); torch::Tensor enhancement_cpu_tensor = this->nn->F.to(this->device_CPU).contiguous(); + this->enhancement_cpu_ptr = enhancement_cpu_tensor.data_ptr(); + torch::Tensor gradient_cpu_tensor = this->nn->inputs.grad().to(this->device_CPU).contiguous(); + this->gradient_cpu_ptr = gradient_cpu_tensor.data_ptr(); this->get_potential_(prho, pw_rho, rpotential); diff --git a/source/source_pw/module_ofdft/kedf_ml.h b/source/source_pw/module_ofdft/kedf_ml.h index d7acfb734ba..631bb413736 100644 --- a/source/source_pw/module_ofdft/kedf_ml.h +++ b/source/source_pw/module_ofdft/kedf_ml.h @@ -7,12 +7,14 @@ class KEDF_ML : public ML_Base { + public: + KEDF_ML() { - this->energy_prefactor = 3. /10. * std::pow(3*std::pow(M_PI, 2.0), 2.0/3.0) * 2; // 10/3*(3*pi^2)^{2/3}, multiply by 2 to convert unit from Hartree to Ry, finally in Ry*Bohr^(-2) + this->energy_prefactor = 3. /10. * std::pow(3*std::pow(M_PI, 2.0), 2.0/3.0) * 2; + // 10/3*(3*pi^2)^{2/3}, multiply by 2 to convert unit from Hartree to Ry, finally in Ry*Bohr^(-2) this->energy_exponent = 5. / 3.; - // this->stress.create(3,3); } void set_para( @@ -51,16 +53,15 @@ class KEDF_ML : public ML_Base std::ostream& ofs_running); double get_energy(const double * const * prho, ModulePW::PW_Basis *pw_rho); - // double get_energy_density(const double * const *prho, int is, int ir, ModulePW::PW_Basis *pw_rho); + void ml_potential(const double * const * prho, ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential); - // void get_stress(double cellVol, const double * const * prho, ModulePW::PW_Basis *pw_rho, double vw_weight); // output all parameters void generateTrainData(const double * const *prho, ModulePW::PW_Basis *pw_rho, const double *veff); + void localTest(const double * const *prho, ModulePW::PW_Basis *pw_rho); double ml_energy = 0.; - // ModuleBase::matrix stress; // maps void init_data( @@ -85,4 +86,4 @@ class KEDF_ML : public ML_Base }; #endif -#endif \ No newline at end of file +#endif diff --git a/source/source_pw/module_ofdft/ml_base.cpp b/source/source_pw/module_ofdft/ml_base.cpp index 54feecb261a..8a3daeab4ba 100644 --- a/source/source_pw/module_ofdft/ml_base.cpp +++ b/source/source_pw/module_ofdft/ml_base.cpp @@ -26,7 +26,7 @@ void ML_Base::set_device(const std::string& device_inpt, std::ostream& ofs_runni } else { - ofs_running << "--------------- Warning: GPU is unaviable ---------------" << std::endl; + ofs_running << "--------------- Warning: GPU is unavailable ---------------" << std::endl; ofs_running << "------------------- Running Neural Network on CPU -------------------" << std::endl; this->device_type = torch::kCPU; } From 809e7cb1b6b28bbe609c56801934e762ccc9cbd4 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 4 Jun 2026 10:02:06 +0800 Subject: [PATCH 06/11] update KEDF --- source/source_pw/module_ofdft/kedf_extwt.cpp | 1 + source/source_pw/module_ofdft/kedf_lkt.cpp | 5 +- .../source_pw/module_ofdft/kedf_manager.cpp | 90 ++++++++--- source/source_pw/module_ofdft/kedf_ml.cpp | 8 +- .../source_pw/module_ofdft/kedf_ml_label.cpp | 145 ++++++++++++++---- source/source_pw/module_ofdft/kedf_tf.cpp | 1 + source/source_pw/module_ofdft/kedf_vw.cpp | 1 + source/source_pw/module_ofdft/kedf_wt.cpp | 1 + source/source_pw/module_ofdft/kedf_xwm.cpp | 1 + source/source_pw/module_ofdft/ml_base.cpp | 79 +++++++--- source/source_pw/module_ofdft/ml_base.h | 17 +- tests/07_OFDFT/06_OF_KE_MPN/INPUT | 8 +- 12 files changed, 268 insertions(+), 89 deletions(-) diff --git a/source/source_pw/module_ofdft/kedf_extwt.cpp b/source/source_pw/module_ofdft/kedf_extwt.cpp index be46416a2b4..924f8eb0123 100644 --- a/source/source_pw/module_ofdft/kedf_extwt.cpp +++ b/source/source_pw/module_ofdft/kedf_extwt.cpp @@ -265,6 +265,7 @@ void KEDF_ExtWT::tau_extwt(const double* const* prho, ModulePW::PW_Basis* pw_rho */ void KEDF_ExtWT::extwt_potential(const double* const* prho, ModulePW::PW_Basis* pw_rho, ModuleBase::matrix& rpotential) { + ModuleBase::TITLE("KEDF_ExtWT", "extwt_potential"); ModuleBase::timer::start("KEDF_ExtWT", "extwt_potential"); // 1. WT potential diff --git a/source/source_pw/module_ofdft/kedf_lkt.cpp b/source/source_pw/module_ofdft/kedf_lkt.cpp index e35dca6d905..ec1299fd45f 100644 --- a/source/source_pw/module_ofdft/kedf_lkt.cpp +++ b/source/source_pw/module_ofdft/kedf_lkt.cpp @@ -132,7 +132,8 @@ void KEDF_LKT::tau_lkt(const double* const* prho, ModulePW::PW_Basis* pw_rho, do */ void KEDF_LKT::lkt_potential(const double* const* prho, ModulePW::PW_Basis* pw_rho, ModuleBase::matrix& rpotential) { - ModuleBase::timer::start("KEDF_LKT", "LKT_potential"); + ModuleBase::TITLE("KEDF_LKT", "lkt_potential"); + ModuleBase::timer::start("KEDF_LKT", "lkt_potential"); this->lkt_energy = 0.; double* as = new double[pw_rho->nrxx]; // a*s double** nabla_rho = new double*[3]; @@ -193,7 +194,7 @@ void KEDF_LKT::lkt_potential(const double* const* prho, ModulePW::PW_Basis* pw_r delete[] nabla_rho; delete[] nabla_term; - ModuleBase::timer::end("KEDF_LKT", "LKT_potential"); + ModuleBase::timer::end("KEDF_LKT", "lkt_potential"); } /** diff --git a/source/source_pw/module_ofdft/kedf_manager.cpp b/source/source_pw/module_ofdft/kedf_manager.cpp index a2f270da50a..6caedf433bb 100644 --- a/source/source_pw/module_ofdft/kedf_manager.cpp +++ b/source/source_pw/module_ofdft/kedf_manager.cpp @@ -3,7 +3,7 @@ #include "source_io/module_parameter/parameter.h" /** - * @brief [Interface to kedf] + * @brief [Interface to KEDF] * Initialize the KEDFs. * * @param inp @@ -20,11 +20,11 @@ void KEDF_Manager::init( { this->of_kinetic_ = inp.of_kinetic; - //! Thomas-Fermi (TF) KEDF, TF+ KEDF, Want-Teter (WT) KEDF, and XWM KEDF + //! Thomas-Fermi (TF) KEDF, TF+ KEDF, Wang-Teter (WT) KEDF, and XWM KEDF if (this->of_kinetic_ == "tf" || this->of_kinetic_ == "tf+" || this->of_kinetic_ == "wt" - || this->of_kinetic_ == "ext-wt" + || this->of_kinetic_ == "ext-wt" || this->of_kinetic_ == "ml" || this->of_kinetic_ == "xwm") { @@ -108,7 +108,9 @@ void KEDF_Manager::init( if (this->of_kinetic_ == "ml") { if (this->ml_ == nullptr) + { this->ml_ = new KEDF_ML(); + } this->ml_->set_para( pw_rho->nrxx, dV, @@ -163,14 +165,19 @@ void KEDF_Manager::get_potential( ModuleBase::matrix& rpot ) { + ModuleBase::TITLE("KEDF_Manager", "get_potential"); + ModuleBase::timer::start("KEDF_Manager", "get_potential"); #ifdef __MLALGO // for ML KEDF test if (PARAM.inp.of_ml_local_test) this->ml_->localTest(prho, pw_rho); #endif - if (this->of_kinetic_ == "tf" || this->of_kinetic_ == "tf+" || this->of_kinetic_ == "wt" - || this->of_kinetic_ == "ext-wt" || this->of_kinetic_ == "xwm") + if (this->of_kinetic_ == "tf" + || this->of_kinetic_ == "tf+" + || this->of_kinetic_ == "wt" + || this->of_kinetic_ == "ext-wt" + || this->of_kinetic_ == "xwm") { this->tf_->tf_potential(prho, rpot); } @@ -203,9 +210,13 @@ void KEDF_Manager::get_potential( } } - if (this->of_kinetic_ == "vw" || this->of_kinetic_ == "tf+" || this->of_kinetic_ == "wt" + if (this->of_kinetic_ == "vw" + || this->of_kinetic_ == "tf+" + || this->of_kinetic_ == "wt" || this->of_kinetic_ == "ext-wt" - || this->of_kinetic_ == "xwm" || this->of_kinetic_ == "lkt" || this->of_kinetic_ == "ml") + || this->of_kinetic_ == "xwm" + || this->of_kinetic_ == "lkt" + || this->of_kinetic_ == "ml") { this->vw_->vw_potential(pphi, pw_rho, rpot); } @@ -217,6 +228,8 @@ void KEDF_Manager::get_potential( this->extwt_->update_dkernel_deta(PARAM.inp.of_vw_weight, pw_rho); this->extwt_->extwt_potential(prho, pw_rho, rpot); } + + ModuleBase::timer::end("KEDF_Manager", "get_potential"); } /** @@ -227,17 +240,27 @@ void KEDF_Manager::get_potential( */ double KEDF_Manager::get_energy() const { + ModuleBase::TITLE("KEDF_Manager", "get_energy"); + ModuleBase::timer::start("KEDF_Manager", "get_energy"); + double kinetic_energy = 0.0; - if (this->of_kinetic_ == "tf" || this->of_kinetic_ == "tf+" || this->of_kinetic_ == "wt" - || this->of_kinetic_ == "ext-wt" || this->of_kinetic_ == "xwm") + if (this->of_kinetic_ == "tf" + || this->of_kinetic_ == "tf+" + || this->of_kinetic_ == "wt" + || this->of_kinetic_ == "ext-wt" + || this->of_kinetic_ == "xwm") { kinetic_energy += this->tf_->tf_energy; } - if (this->of_kinetic_ == "vw" || this->of_kinetic_ == "tf+" || this->of_kinetic_ == "wt" + if (this->of_kinetic_ == "vw" + || this->of_kinetic_ == "tf+" + || this->of_kinetic_ == "wt" || this->of_kinetic_ == "ext-wt" - || this->of_kinetic_ == "xwm" || this->of_kinetic_ == "lkt" || this->of_kinetic_ == "ml") + || this->of_kinetic_ == "xwm" + || this->of_kinetic_ == "lkt" + || this->of_kinetic_ == "ml") { kinetic_energy += this->vw_->vw_energy; } @@ -267,12 +290,15 @@ double KEDF_Manager::get_energy() const kinetic_energy += this->ml_->ml_energy; if (this->ml_->ml_energy >= this->tf_->tf_energy) { - GlobalV::ofs_running << "WARNING: ML >= TF" << std::endl; - GlobalV::ofs_running << "ML Term = " << this->ml_->ml_energy << " Ry, TF Term = " << this->tf_->tf_energy << " Ry." << std::endl; + GlobalV::ofs_running << " WARNING: ML >= TF" << std::endl; + GlobalV::ofs_running << " ML Term = " << this->ml_->ml_energy + << " Ry, TF Term = " << this->tf_->tf_energy << " Ry." << std::endl; } } #endif + ModuleBase::timer::end("KEDF_Manager", "get_energy"); + return kinetic_energy; } @@ -292,19 +318,28 @@ void KEDF_Manager::get_energy_density( double** rtau ) { + ModuleBase::TITLE("KEDF_Manager", "get_energy_density"); + ModuleBase::timer::start("KEDF_Manager", "get_energy_density"); + for (int ir = 0; ir < pw_rho->nrxx; ++ir) { rtau[0][ir] = 0.0; } - if (this->of_kinetic_ == "tf" || this->of_kinetic_ == "tf+" || this->of_kinetic_ == "wt" - || this->of_kinetic_ == "ext-wt" || this->of_kinetic_ == "xwm") + if (this->of_kinetic_ == "tf" + || this->of_kinetic_ == "tf+" + || this->of_kinetic_ == "wt" + || this->of_kinetic_ == "ext-wt" + || this->of_kinetic_ == "xwm") { this->tf_->tau_tf(prho, rtau[0]); } - if (this->of_kinetic_ == "vw" || this->of_kinetic_ == "tf+" || this->of_kinetic_ == "wt" + if (this->of_kinetic_ == "vw" + || this->of_kinetic_ == "tf+" + || this->of_kinetic_ == "wt" || this->of_kinetic_ == "ext-wt" - || this->of_kinetic_ == "xwm" || this->of_kinetic_ == "lkt") + || this->of_kinetic_ == "xwm" + || this->of_kinetic_ == "lkt") { this->vw_->tau_vw(pphi, pw_rho, rtau[0]); } @@ -324,6 +359,8 @@ void KEDF_Manager::get_energy_density( { this->lkt_->tau_lkt(prho, pw_rho, rtau[0]); } + + ModuleBase::timer::end("KEDF_Manager", "get_energy_density"); } /** @@ -344,6 +381,9 @@ void KEDF_Manager::get_stress( ModuleBase::matrix& kinetic_stress_ ) { + ModuleBase::TITLE("KEDF_Manager", "get_stress"); + ModuleBase::timer::start("KEDF_Manager", "get_stress"); + for (int i = 0; i < 3; ++i) { for (int j = 0; j < 3; ++j) @@ -352,16 +392,22 @@ void KEDF_Manager::get_stress( } } - if (this->of_kinetic_ == "tf" || this->of_kinetic_ == "tf+" || this->of_kinetic_ == "wt" - || this->of_kinetic_ == "ext-wt" || this->of_kinetic_ == "xwm") + if (this->of_kinetic_ == "tf" + || this->of_kinetic_ == "tf+" + || this->of_kinetic_ == "wt" + || this->of_kinetic_ == "ext-wt" + || this->of_kinetic_ == "xwm") { this->tf_->get_stress(omega); kinetic_stress_ += this->tf_->stress; } - if (this->of_kinetic_ == "vw" || this->of_kinetic_ == "tf+" || this->of_kinetic_ == "wt" + if (this->of_kinetic_ == "vw" + || this->of_kinetic_ == "tf+" + || this->of_kinetic_ == "wt" || this->of_kinetic_ == "ext-wt" - || this->of_kinetic_ == "xwm" || this->of_kinetic_ == "lkt") + || this->of_kinetic_ == "xwm" + || this->of_kinetic_ == "lkt") { this->vw_->get_stress(pphi, pw_rho); kinetic_stress_ += this->vw_->stress; @@ -394,6 +440,8 @@ void KEDF_Manager::get_stress( { std::cout << "Sorry, the stress of MPN KEDF is not yet supported." << std::endl; } + + ModuleBase::timer::end("KEDF_Manager", "get_stress"); } void KEDF_Manager::record_energy( diff --git a/source/source_pw/module_ofdft/kedf_ml.cpp b/source/source_pw/module_ofdft/kedf_ml.cpp index e62fd3f78f2..9ee0edfc9de 100644 --- a/source/source_pw/module_ofdft/kedf_ml.cpp +++ b/source/source_pw/module_ofdft/kedf_ml.cpp @@ -156,6 +156,9 @@ double KEDF_ML::get_energy(const double * const * prho, ModulePW::PW_Basis *pw_r */ void KEDF_ML::ml_potential(const double * const * prho, ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential) { + ModuleBase::TITLE("KEDF_ML", "ml_potential"); + ModuleBase::timer::start("KEDF_ML", "pauli_energy"); + this->updateInput(prho, pw_rho); this->NN_forward(prho, pw_rho, true); @@ -170,8 +173,6 @@ void KEDF_ML::ml_potential(const double * const * prho, ModulePW::PW_Basis *pw_r this->get_potential_(prho, pw_rho, rpotential); - // get energy - ModuleBase::timer::start("KEDF_ML", "Pauli Energy"); double energy = 0.; for (int ir = 0; ir < this->nx; ++ir) { @@ -180,7 +181,8 @@ void KEDF_ML::ml_potential(const double * const * prho, ModulePW::PW_Basis *pw_r energy *= this->dV * this->energy_prefactor; this->ml_energy = energy; Parallel_Reduce::reduce_all(this->ml_energy); - ModuleBase::timer::end("KEDF_ML", "Pauli Energy"); + + ModuleBase::timer::end("KEDF_ML", "pauli_energy"); } /** diff --git a/source/source_pw/module_ofdft/kedf_ml_label.cpp b/source/source_pw/module_ofdft/kedf_ml_label.cpp index 0cce3d0cec3..3baf3fb1a83 100644 --- a/source/source_pw/module_ofdft/kedf_ml_label.cpp +++ b/source/source_pw/module_ofdft/kedf_ml_label.cpp @@ -1,6 +1,7 @@ #ifdef __MLALGO #include "kedf_ml.h" +#include /** * @brief Initialize the data for ML KEDF, and generate the mapping between descriptor and kernel @@ -42,21 +43,26 @@ void KEDF_ML::init_data( std::ostream& ofs_running ) { + ModuleBase::TITLE("KEDF_ML", "init_data"); + ModuleBase::timer::start("KEDF_ML", "init_data"); this->ninput = 0; // --------- semi-local descriptors --------- - if (of_ml_gamma){ + if (of_ml_gamma) + { this->descriptor_type.push_back("gamma"); this->kernel_index.push_back(-1); ninput++; } - if (of_ml_p){ + if (of_ml_p) + { this->descriptor_type.push_back("p"); this->kernel_index.push_back(-1); ninput++; } - if (of_ml_q){ + if (of_ml_q) + { this->descriptor_type.push_back("q"); this->kernel_index.push_back(-1); ninput++; @@ -64,44 +70,52 @@ void KEDF_ML::init_data( // --------- non-local descriptors --------- for (int ik = 0; ik < nkernel; ++ik) { - if (of_ml_gammanl[ik]){ + if (of_ml_gammanl[ik]) + { this->descriptor_type.push_back("gammanl"); this->kernel_index.push_back(ik); this->ninput++; } - if (of_ml_pnl[ik]){ + if (of_ml_pnl[ik]) + { this->descriptor_type.push_back("pnl"); this->kernel_index.push_back(ik); this->ninput++; } - if (of_ml_qnl[ik]){ + if (of_ml_qnl[ik]) + { this->descriptor_type.push_back("qnl"); this->kernel_index.push_back(ik); this->ninput++; } - if (of_ml_xi[ik]){ + if (of_ml_xi[ik]) + { this->descriptor_type.push_back("xi"); this->kernel_index.push_back(ik); this->ninput++; } - if (of_ml_tanhxi[ik]){ + if (of_ml_tanhxi[ik]) + { this->descriptor_type.push_back("tanhxi"); this->kernel_index.push_back(ik); this->ninput++; } - if (of_ml_tanhxi_nl[ik]){ + if (of_ml_tanhxi_nl[ik]) + { this->descriptor_type.push_back("tanhxi_nl"); this->kernel_index.push_back(ik); this->ninput++; } } // --------- semi-local descriptors --------- - if (of_ml_tanhp){ + if (of_ml_tanhp) + { this->descriptor_type.push_back("tanhp"); this->kernel_index.push_back(-1); ninput++; } - if (of_ml_tanhq){ + if (of_ml_tanhq) + { this->descriptor_type.push_back("tanhq"); this->kernel_index.push_back(-1); ninput++; @@ -109,22 +123,26 @@ void KEDF_ML::init_data( // --------- non-local descriptors --------- for (int ik = 0; ik < nkernel; ++ik) { - if (of_ml_tanh_pnl[ik]){ + if (of_ml_tanh_pnl[ik]) + { this->descriptor_type.push_back("tanh_pnl"); this->kernel_index.push_back(ik); this->ninput++; } - if (of_ml_tanh_qnl[ik]){ + if (of_ml_tanh_qnl[ik]) + { this->descriptor_type.push_back("tanh_qnl"); this->kernel_index.push_back(ik); this->ninput++; } - if (of_ml_tanhp_nl[ik]){ + if (of_ml_tanhp_nl[ik]) + { this->descriptor_type.push_back("tanhp_nl"); this->kernel_index.push_back(ik); this->ninput++; } - if (of_ml_tanhq_nl[ik]){ + if (of_ml_tanhq_nl[ik]) + { this->descriptor_type.push_back("tanhq_nl"); this->kernel_index.push_back(ik); this->ninput++; @@ -153,8 +171,52 @@ void KEDF_ML::init_data( this->descriptor2kernel[descriptor_type[i]].push_back(kernel_index[i]); this->descriptor2index[descriptor_type[i]].push_back(i); } - ofs_running << "descriptor2index " << descriptor2index << std::endl; - ofs_running << "descriptor2kernel " << descriptor2kernel << std::endl; + + ofs_running << "\n ------------------- ML-KEDF Reference -------------------\n"; + ofs_running << " Liang Sun and Mohan Chen*, \"Multi-channel machine learning based \n"; + ofs_running << " nonlocal kinetic energy density functional for semiconductors,\"\n"; + ofs_running << " Electronic Structure, 6, 045006 (2024).\n"; + ofs_running << " ---------------------------------------------------------\n"; + ofs_running << "\n ------------------- Descriptor Mapping -------------------\n"; + ofs_running << " Legend:\n"; + ofs_running << " descriptor2index: [indices] in neural network input vector\n"; + ofs_running << " descriptor2kernel: [-1]=no kernel (semi-local), [N]=use kernel N (non-local)\n"; + ofs_running << " Semi-local descriptors: gamma, p, q, tanhp, tanhq (no kernel needed)\n"; + ofs_running << " Non-local descriptors: gammanl, pnl, qnl, xi, tanhxi, etc. (need kernel)\n"; + ofs_running << " ---------------------------------------------------------\n"; + ofs_running << " descriptor2index (input vector positions):\n"; + + for (const auto& pair : this->descriptor2index) + { + ofs_running << " " << std::setw(15) << std::left << pair.first << " : ["; + for (size_t i = 0; i < pair.second.size(); ++i) + { + ofs_running << pair.second[i]; + if (i < pair.second.size() - 1) + { + ofs_running << ", "; + } + } + ofs_running << "]\n"; + } + + ofs_running << " ---------------------------------------------------------\n"; + ofs_running << " descriptor2kernel (kernel indices):\n"; + + for (const auto& pair : this->descriptor2kernel) + { + ofs_running << " " << std::setw(15) << std::left << pair.first << " : ["; + for (size_t i = 0; i < pair.second.size(); ++i) + { + ofs_running << pair.second[i]; + if (i < pair.second.size() - 1) + { + ofs_running << ", "; + } + } + ofs_running << "]\n"; + } + ofs_running << " ---------------------------------------------------------\n"; this->ml_gamma = this->descriptor2index["gamma"].size() > 0; this->ml_p = this->descriptor2index["p"].size() > 0; @@ -249,20 +311,25 @@ void KEDF_ML::init_data( this->gene_data_label["q"][0] = of_ml_q || this->gene_data_label["tanhq"][0] || gene_qnl_tot; - if (this->gene_data_label["gamma"][0]){ + if (this->gene_data_label["gamma"][0]) + { this->gamma = std::vector(this->nx, 0.); } - if (this->gene_data_label["p"][0]){ + if (this->gene_data_label["p"][0]) + { this->nablaRho = std::vector >(3, std::vector(this->nx, 0.)); this->p = std::vector(this->nx, 0.); } - if (this->gene_data_label["q"][0]){ + if (this->gene_data_label["q"][0]) + { this->q = std::vector(this->nx, 0.); } - if (this->gene_data_label["tanhp"][0]){ + if (this->gene_data_label["tanhp"][0]) + { this->tanhp = std::vector(this->nx, 0.); } - if (this->gene_data_label["tanhq"][0]){ + if (this->gene_data_label["tanhq"][0]) + { this->tanhq = std::vector(this->nx, 0.); } @@ -279,36 +346,48 @@ void KEDF_ML::init_data( this->tanhp_nl.push_back({}); this->tanhq_nl.push_back({}); - if (this->gene_data_label["gammanl"][ik]){ + if (this->gene_data_label["gammanl"][ik]) + { this->gammanl[ik] = std::vector(this->nx, 0.); } - if (this->gene_data_label["pnl"][ik]){ + if (this->gene_data_label["pnl"][ik]) + { this->pnl[ik] = std::vector(this->nx, 0.); } - if (this->gene_data_label["qnl"][ik]){ + if (this->gene_data_label["qnl"][ik]) + { this->qnl[ik] = std::vector(this->nx, 0.); } - if (this->gene_data_label["xi"][ik]){ + if (this->gene_data_label["xi"][ik]) + { this->xi[ik] = std::vector(this->nx, 0.); } - if (this->gene_data_label["tanhxi"][ik]){ + if (this->gene_data_label["tanhxi"][ik]) + { this->tanhxi[ik] = std::vector(this->nx, 0.); } - if (this->gene_data_label["tanhxi_nl"][ik]){ + if (this->gene_data_label["tanhxi_nl"][ik]) + { this->tanhxi_nl[ik] = std::vector(this->nx, 0.); } - if (this->gene_data_label["tanh_pnl"][ik]){ + if (this->gene_data_label["tanh_pnl"][ik]) + { this->tanh_pnl[ik] = std::vector(this->nx, 0.); } - if (this->gene_data_label["tanh_qnl"][ik]){ + if (this->gene_data_label["tanh_qnl"][ik]) + { this->tanh_qnl[ik] = std::vector(this->nx, 0.); } - if (this->gene_data_label["tanhp_nl"][ik]){ + if (this->gene_data_label["tanhp_nl"][ik]) + { this->tanhp_nl[ik] = std::vector(this->nx, 0.); } - if (this->gene_data_label["tanhq_nl"][ik]){ + if (this->gene_data_label["tanhq_nl"][ik]) + { this->tanhq_nl[ik] = std::vector(this->nx, 0.); } } + + ModuleBase::timer::end("KEDF_ML", "init_data"); } -#endif \ No newline at end of file +#endif diff --git a/source/source_pw/module_ofdft/kedf_tf.cpp b/source/source_pw/module_ofdft/kedf_tf.cpp index a84be35887d..c5bf62173a5 100644 --- a/source/source_pw/module_ofdft/kedf_tf.cpp +++ b/source/source_pw/module_ofdft/kedf_tf.cpp @@ -100,6 +100,7 @@ void KEDF_TF::tau_tf(const double* const* prho, double* rtau_tf) */ void KEDF_TF::tf_potential(const double* const* prho, ModuleBase::matrix& rpotential) { + ModuleBase::TITLE("KEDF_TF", "tf_potential"); ModuleBase::timer::start("KEDF_TF", "tf_potential"); if (PARAM.inp.nspin == 1) { diff --git a/source/source_pw/module_ofdft/kedf_vw.cpp b/source/source_pw/module_ofdft/kedf_vw.cpp index 33bd3bb5a12..3f272bf1e15 100644 --- a/source/source_pw/module_ofdft/kedf_vw.cpp +++ b/source/source_pw/module_ofdft/kedf_vw.cpp @@ -171,6 +171,7 @@ void KEDF_vW::tau_vw(const double* const* pphi, ModulePW::PW_Basis* pw_rho, doub */ void KEDF_vW::vw_potential(const double* const* pphi, ModulePW::PW_Basis* pw_rho, ModuleBase::matrix& rpotential) { + ModuleBase::TITLE("KEDF_vW", "vw_potential"); ModuleBase::timer::start("KEDF_vW", "vw_potential"); // since pphi may contain minus element, we define tempPhi = std::abs(phi), which is true sqrt(rho) diff --git a/source/source_pw/module_ofdft/kedf_wt.cpp b/source/source_pw/module_ofdft/kedf_wt.cpp index f9d99aef09e..21f4f1f2b48 100644 --- a/source/source_pw/module_ofdft/kedf_wt.cpp +++ b/source/source_pw/module_ofdft/kedf_wt.cpp @@ -188,6 +188,7 @@ void KEDF_WT::tau_wt(const double* const* prho, ModulePW::PW_Basis* pw_rho, doub */ void KEDF_WT::wt_potential(const double* const* prho, ModulePW::PW_Basis* pw_rho, ModuleBase::matrix& rpotential) { + ModuleBase::TITLE("KEDF_WT", "wt_potential"); ModuleBase::timer::start("KEDF_WT", "wt_potential"); double** kernelRhoBeta = new double*[PARAM.inp.nspin]; diff --git a/source/source_pw/module_ofdft/kedf_xwm.cpp b/source/source_pw/module_ofdft/kedf_xwm.cpp index cacd3cdf23b..4f550af0037 100644 --- a/source/source_pw/module_ofdft/kedf_xwm.cpp +++ b/source/source_pw/module_ofdft/kedf_xwm.cpp @@ -202,6 +202,7 @@ void KEDF_XWM::tau_xwm(const double* const* prho, ModulePW::PW_Basis* pw_rho, do */ void KEDF_XWM::xwm_potential(const double* const* prho, ModulePW::PW_Basis* pw_rho, ModuleBase::matrix& rpotential) { + ModuleBase::TITLE("KEDF_XWM", "xwm_potential"); ModuleBase::timer::start("KEDF_XWM", "xwm_potential"); double** w1Rho5_6 = new double*[PARAM.inp.nspin]; for (int is = 0; is < PARAM.inp.nspin; ++is) diff --git a/source/source_pw/module_ofdft/ml_base.cpp b/source/source_pw/module_ofdft/ml_base.cpp index 8a3daeab4ba..acf75249a40 100644 --- a/source/source_pw/module_ofdft/ml_base.cpp +++ b/source/source_pw/module_ofdft/ml_base.cpp @@ -12,6 +12,7 @@ ML_Base::~ML_Base() void ML_Base::set_device(const std::string& device_inpt, std::ostream& ofs_running) { + ModuleBase::TITLE("ML_Base", "set_device"); if (device_inpt == "cpu") { ofs_running << "------------------- Running Neural Network on CPU -------------------" << std::endl; @@ -36,7 +37,8 @@ void ML_Base::set_device(const std::string& device_inpt, std::ostream& ofs_runni void ML_Base::updateInput(const double * const * prho, const ModulePW::PW_Basis *pw_rho) { - ModuleBase::timer::start("ML_Base", "updateInput"); + ModuleBase::TITLE("ML_Base", "update_input"); + ModuleBase::timer::start("ML_Base", "update_input"); if (this->gene_data_label["gamma"][0]) { this->cal_tool->getGamma(prho, this->gamma); @@ -92,12 +94,13 @@ void ML_Base::updateInput(const double * const * prho, const ModulePW::PW_Basis this->cal_tool->getTanhQ_nl(ik, this->tanhq, pw_rho, this->tanhq_nl[ik]); } } - ModuleBase::timer::end("ML_Base", "updateInput"); + ModuleBase::timer::end("ML_Base", "update_input"); } void ML_Base::NN_forward(const double * const * prho, const ModulePW::PW_Basis *pw_rho, bool cal_grad) { - ModuleBase::timer::start("ML_Base", "Forward"); + ModuleBase::TITLE("ML_Base", "nn_forward"); + ModuleBase::timer::start("ML_Base", "forward"); this->nn->zero_grad(); this->nn->inputs.requires_grad_(false); @@ -122,13 +125,13 @@ void ML_Base::NN_forward(const double * const * prho, const ModulePW::PW_Basis * { this->nn->F = torch::softplus(this->nn->F - this->feg_net_F + this->feg3_correct); } - ModuleBase::timer::end("ML_Base", "Forward"); + ModuleBase::timer::end("ML_Base", "forward"); if (cal_grad) { - ModuleBase::timer::start("ML_Base", "Backward"); + ModuleBase::timer::start("ML_Base", "backward"); this->nn->F.backward(torch::ones({this->nx, 1}, this->device_type)); - ModuleBase::timer::end("ML_Base", "Backward"); + ModuleBase::timer::end("ML_Base", "backward"); } } @@ -154,7 +157,8 @@ torch::Tensor ML_Base::get_data(std::string parameter, const int ikernel) const void ML_Base::get_potential_(const double * const * prho, const ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential) { - ModuleBase::timer::start("ML_Base", "Pauli Potential"); + ModuleBase::TITLE("ML_Base", "get_potential_"); + ModuleBase::timer::start("ML_Base", "pauli_potential"); std::vector pauli_potential(this->nx, 0.); std::vector tau_lda(this->nx, 0.); // Dummy or calculated inside @@ -163,26 +167,61 @@ void ML_Base::get_potential_(const double * const * prho, const ModulePW::PW_Bas tau_lda[ir] = this->energy_prefactor * std::pow(prho[0][ir], this->energy_exponent); } - if (this->ml_gammanl) this->potGammanlTerm(prho, tau_lda, pw_rho, pauli_potential); - if (this->ml_xi) this->potXinlTerm(prho, tau_lda, pw_rho, pauli_potential); - if (this->ml_tanhxi) this->potTanhxinlTerm(prho, tau_lda, pw_rho, pauli_potential); - if (this->ml_tanhxi_nl) this->potTanhxi_nlTerm(prho, tau_lda, pw_rho, pauli_potential); - if (this->ml_p || this->ml_pnl) this->potPPnlTerm(prho, tau_lda, pw_rho, pauli_potential); - if (this->ml_q || this->ml_qnl) this->potQQnlTerm(prho, tau_lda, pw_rho, pauli_potential); - if (this->ml_tanh_pnl) this->potTanhpTanh_pnlTerm(prho, tau_lda, pw_rho, pauli_potential); - if (this->ml_tanh_qnl) this->potTanhqTanh_qnlTerm(prho, tau_lda, pw_rho, pauli_potential); - if ((this->ml_tanhp || this->ml_tanhp_nl) && !this->ml_tanh_pnl) this->potTanhpTanhp_nlTerm(prho, tau_lda, pw_rho, pauli_potential); - if ((this->ml_tanhq || this->ml_tanhq_nl) && !this->ml_tanh_qnl) this->potTanhqTanhq_nlTerm(prho, tau_lda, pw_rho, pauli_potential); + if (this->ml_gammanl) + { + this->potGammanlTerm(prho, tau_lda, pw_rho, pauli_potential); + } + if (this->ml_xi) + { + this->potXinlTerm(prho, tau_lda, pw_rho, pauli_potential); + } + if (this->ml_tanhxi) + { + this->potTanhxinlTerm(prho, tau_lda, pw_rho, pauli_potential); + } + if (this->ml_tanhxi_nl) + { + this->potTanhxi_nlTerm(prho, tau_lda, pw_rho, pauli_potential); + } + if (this->ml_p || this->ml_pnl) + { + this->potPPnlTerm(prho, tau_lda, pw_rho, pauli_potential); + } + if (this->ml_q || this->ml_qnl) + { + this->potQQnlTerm(prho, tau_lda, pw_rho, pauli_potential); + } + if (this->ml_tanh_pnl) + { + this->potTanhpTanh_pnlTerm(prho, tau_lda, pw_rho, pauli_potential); + } + if (this->ml_tanh_qnl) + { + this->potTanhqTanh_qnlTerm(prho, tau_lda, pw_rho, pauli_potential); + } + if ((this->ml_tanhp || this->ml_tanhp_nl) && !this->ml_tanh_pnl) + { + this->potTanhpTanhp_nlTerm(prho, tau_lda, pw_rho, pauli_potential); + } + if ((this->ml_tanhq || this->ml_tanhq_nl) && !this->ml_tanh_qnl) + { + this->potTanhqTanhq_nlTerm(prho, tau_lda, pw_rho, pauli_potential); + } for (int ir = 0; ir < this->nx; ++ir) { double factor = tau_lda[ir] / prho[0][ir]; + pauli_potential[ir] += factor * - (this->energy_exponent * this->enhancement_cpu_ptr[ir] + this->potGammaTerm(ir) + this->potPTerm1(ir) + this->potQTerm1(ir) - + this->potXiTerm1(ir) + this->potTanhxiTerm1(ir) + this->potTanhpTerm1(ir) + this->potTanhqTerm1(ir)); + (this->energy_exponent * this->enhancement_cpu_ptr[ir] + + this->potGammaTerm(ir) + this->potPTerm1(ir) + this->potQTerm1(ir) + + this->potXiTerm1(ir) + this->potTanhxiTerm1(ir) + + this->potTanhpTerm1(ir) + this->potTanhqTerm1(ir)); + rpotential(0, ir) += pauli_potential[ir]; } - ModuleBase::timer::end("ML_Base", "Pauli Potential"); + + ModuleBase::timer::end("ML_Base", "pauli_potential"); } // IO tools diff --git a/source/source_pw/module_ofdft/ml_base.h b/source/source_pw/module_ofdft/ml_base.h index 3f7b04d582c..8f6cb093f91 100644 --- a/source/source_pw/module_ofdft/ml_base.h +++ b/source/source_pw/module_ofdft/ml_base.h @@ -132,12 +132,17 @@ class ML_Base bool ml_tanhp_nl = false; bool ml_tanhq_nl = false; - // Maps - std::vector descriptor_type; - std::vector kernel_index; - std::map> descriptor2kernel; - std::map> descriptor2index; - std::map> gene_data_label; + // Maps for descriptor management + std::vector descriptor_type; // List of enabled descriptors (e.g., "gamma", "pnl", "tanhxi") + std::vector kernel_index; // Kernel index for each descriptor (-1 = no kernel for semi-local) + std::map> descriptor2kernel; // Maps descriptor name to its kernel index(s) + // - []: descriptor not enabled + // - [-1]: semi-local descriptor (no kernel needed) + // - [N]: non-local descriptor using kernel N + std::map> descriptor2index; // Maps descriptor name to its position(s) in NN input vector + // - []: descriptor not enabled + // - [0, 1, ...]: indices in input vector + std::map> gene_data_label; // Flags indicating whether to compute each descriptor }; #endif // __MLALGO diff --git a/tests/07_OFDFT/06_OF_KE_MPN/INPUT b/tests/07_OFDFT/06_OF_KE_MPN/INPUT index e846e2ce174..e4602237b7d 100644 --- a/tests/07_OFDFT/06_OF_KE_MPN/INPUT +++ b/tests/07_OFDFT/06_OF_KE_MPN/INPUT @@ -28,7 +28,7 @@ of_ml_device cpu #of_ml_chi_qnl 0.1 # #of_ml_tanhxi 1 -#of_ml_tanhxi_nl 1 -#of_ml_tanhp 1 -#of_ml_tanhp_nl 1 -#of_ml_feg 3 +#of_ml_tanhxi_nl 1 +#of_ml_tanhp 1 +#of_ml_tanhp_nl 1 +#of_ml_feg 3 From f860b2ba94e6ab7e32e34f6a3b07041aac19f99b Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 4 Jun 2026 10:12:09 +0800 Subject: [PATCH 07/11] output format update --- source/source_io/module_ml/cal_mlkedf_descriptors.cpp | 1 - source/source_pw/module_ofdft/kedf_ml.cpp | 7 ++++--- source/source_pw/module_ofdft/nn_of.cpp | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/source/source_io/module_ml/cal_mlkedf_descriptors.cpp b/source/source_io/module_ml/cal_mlkedf_descriptors.cpp index 4a1e3c86da6..22a7b794d43 100644 --- a/source/source_io/module_ml/cal_mlkedf_descriptors.cpp +++ b/source/source_io/module_ml/cal_mlkedf_descriptors.cpp @@ -35,7 +35,6 @@ void Cal_MLKEDF_Descriptors::set_para( this->kernel_scaling = kernel_scaling; this->yukawa_alpha = yukawa_alpha; this->kernel_file = kernel_file; - ofs_running << "nkernel = " << nkernel << std::endl; if (PARAM.inp.of_wt_rho0 != 0) { diff --git a/source/source_pw/module_ofdft/kedf_ml.cpp b/source/source_pw/module_ofdft/kedf_ml.cpp index 9ee0edfc9de..ab20fdb2348 100644 --- a/source/source_pw/module_ofdft/kedf_ml.cpp +++ b/source/source_pw/module_ofdft/kedf_ml.cpp @@ -72,7 +72,8 @@ void KEDF_ML::set_para( of_ml_tanhq_nl, ofs_running); - ofs_running << " ninput = " << ninput << std::endl; + ofs_running << " ninput = " << ninput << " (number of descriptors)" << std::endl; + ofs_running << " nkernel = " << this->nkernel << " (number of kernel functions)" << std::endl; if (PARAM.inp.of_kinetic == "ml") { @@ -80,7 +81,7 @@ void KEDF_ML::set_para( int nlayer = 3; this->nn = std::make_shared(this->nx, 0, this->ninput, nnode, nlayer, this->device, ofs_running); torch::load(this->nn, "net.pt", this->device_type); - ofs_running << "load net done" << std::endl; + ofs_running << " load net done (neural network loaded successfully)" << std::endl; if (PARAM.inp.of_ml_feg != 0) { torch::Tensor feg_inpt = torch::zeros(this->ninput, this->device_type); @@ -98,7 +99,7 @@ void KEDF_ML::set_para( this->feg_net_F = this->nn->forward(feg_inpt).to(this->device_CPU).contiguous().data_ptr()[0]; } - ofs_running << " feg_net_F = " << this->feg_net_F << std::endl; + ofs_running << " feg_net_F = " << this->feg_net_F << " (Fermi energy guess factor)" << std::endl << std::endl; } } diff --git a/source/source_pw/module_ofdft/nn_of.cpp b/source/source_pw/module_ofdft/nn_of.cpp index cfe48bf1d86..1015c1af8a2 100644 --- a/source/source_pw/module_ofdft/nn_of.cpp +++ b/source/source_pw/module_ofdft/nn_of.cpp @@ -6,9 +6,9 @@ NN_OFImpl::NN_OFImpl(int nrxx, int nrxx_vali, int ninpt, int nnode, int nlayer, this->nrxx_vali = nrxx_vali; this->ninpt = ninpt; this->nnode = nnode; - ofs_running << "nnode = " << this->nnode << std::endl; + ofs_running << " nnode = " << this->nnode << " (number of nodes per hidden layer)" << std::endl; this->nlayer = nlayer; - ofs_running << "nlayer = " << this->nlayer << std::endl; + ofs_running << " nlayer = " << this->nlayer << " (number of hidden layers)" << std::endl; this->nfc = nlayer + 1; this->inputs = torch::zeros({this->nrxx, this->ninpt}).to(device); From f9e6c87a328298239779f2e9346ec4b79735e6ba Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 4 Jun 2026 10:36:12 +0800 Subject: [PATCH 08/11] update --- source/source_pw/module_ofdft/kedf_ml.cpp | 26 ++-- source/source_pw/module_ofdft/ml_base.cpp | 130 ++++++++++++------ source/source_pw/module_ofdft/ml_base.h | 130 ++++++++++++++---- source/source_pw/module_ofdft/ml_base_pot.cpp | 34 ++--- 4 files changed, 221 insertions(+), 99 deletions(-) diff --git a/source/source_pw/module_ofdft/kedf_ml.cpp b/source/source_pw/module_ofdft/kedf_ml.cpp index ab20fdb2348..2689000f285 100644 --- a/source/source_pw/module_ofdft/kedf_ml.cpp +++ b/source/source_pw/module_ofdft/kedf_ml.cpp @@ -130,9 +130,9 @@ void KEDF_ML::set_para( */ double KEDF_ML::get_energy(const double * const * prho, ModulePW::PW_Basis *pw_rho) { - this->updateInput(prho, pw_rho); + this->update_input(prho, pw_rho); - this->NN_forward(prho, pw_rho, false); + this->nn_forward(prho, pw_rho, false); torch::Tensor enhancement_cpu_tensor = this->nn->F.to(this->device_CPU).contiguous(); this->enhancement_cpu_ptr = enhancement_cpu_tensor.data_ptr(); @@ -160,9 +160,9 @@ void KEDF_ML::ml_potential(const double * const * prho, ModulePW::PW_Basis *pw_r ModuleBase::TITLE("KEDF_ML", "ml_potential"); ModuleBase::timer::start("KEDF_ML", "pauli_energy"); - this->updateInput(prho, pw_rho); + this->update_input(prho, pw_rho); - this->NN_forward(prho, pw_rho, true); + this->nn_forward(prho, pw_rho, true); torch::Tensor enhancement_cpu_tensor = this->nn->F.to(this->device_CPU).contiguous(); @@ -200,9 +200,9 @@ void KEDF_ML::generateTrainData(const double * const *prho, ModulePW::PW_Basis * // this->cal_tool->generateTrainData_WT(prho, wt, tf, pw_rho, veff); // Will be fixed in next pr if (PARAM.inp.of_kinetic == "ml") { - this->updateInput(prho, pw_rho); + this->update_input(prho, pw_rho); - this->NN_forward(prho, pw_rho, true); + this->nn_forward(prho, pw_rho, true); torch::Tensor enhancement_cpu_tensor = this->nn->F.to(this->device_CPU).contiguous(); this->enhancement_cpu_ptr = enhancement_cpu_tensor.data_ptr(); @@ -214,8 +214,8 @@ void KEDF_ML::generateTrainData(const double * const *prho, ModulePW::PW_Basis * this->get_potential_(prho, pw_rho, potential); - this->dumpTensor("enhancement.npy", enhancement); - this->dumpMatrix("potential.npy", potential); + this->dump_tensor("enhancement.npy", enhancement); + this->dump_matrix("potential.npy", potential); } } @@ -232,7 +232,7 @@ void KEDF_ML::localTest(const double * const *pprho, ModulePW::PW_Basis *pw_rho) bool fortran_order = false; std::vector temp_prho(this->nx); - this->loadVector("dir_of_input_rho", temp_prho); + this->load_vector("dir_of_input_rho", temp_prho); double ** prho = new double *[1]; prho[0] = new double[this->nx]; for (int ir = 0; ir < this->nx; ++ir) prho[0][ir] = temp_prho[ir]; @@ -244,9 +244,9 @@ void KEDF_ML::localTest(const double * const *pprho, ModulePW::PW_Basis *pw_rho) }; // ============================== - this->updateInput(prho, pw_rho); + this->update_input(prho, pw_rho); - this->NN_forward(prho, pw_rho, true); + this->nn_forward(prho, pw_rho, true); torch::Tensor enhancement_cpu_tensor = this->nn->F.to(this->device_CPU).contiguous(); this->enhancement_cpu_ptr = enhancement_cpu_tensor.data_ptr(); @@ -258,8 +258,8 @@ void KEDF_ML::localTest(const double * const *pprho, ModulePW::PW_Basis *pw_rho) this->get_potential_(prho, pw_rho, potential); - this->dumpTensor("enhancement-abacus.npy", enhancement); - this->dumpMatrix("potential-abacus.npy", potential); + this->dump_tensor("enhancement-abacus.npy", enhancement); + this->dump_matrix("potential-abacus.npy", potential); exit(0); } #endif diff --git a/source/source_pw/module_ofdft/ml_base.cpp b/source/source_pw/module_ofdft/ml_base.cpp index acf75249a40..5eefec0be26 100644 --- a/source/source_pw/module_ofdft/ml_base.cpp +++ b/source/source_pw/module_ofdft/ml_base.cpp @@ -35,7 +35,7 @@ void ML_Base::set_device(const std::string& device_inpt, std::ostream& ofs_runni this->device = torch::Device(this->device_type); } -void ML_Base::updateInput(const double * const * prho, const ModulePW::PW_Basis *pw_rho) +void ML_Base::update_input(const double * const * prho, const ModulePW::PW_Basis *pw_rho) { ModuleBase::TITLE("ML_Base", "update_input"); ModuleBase::timer::start("ML_Base", "update_input"); @@ -97,10 +97,10 @@ void ML_Base::updateInput(const double * const * prho, const ModulePW::PW_Basis ModuleBase::timer::end("ML_Base", "update_input"); } -void ML_Base::NN_forward(const double * const * prho, const ModulePW::PW_Basis *pw_rho, bool cal_grad) +void ML_Base::nn_forward(const double * const * prho, const ModulePW::PW_Basis *pw_rho, bool cal_grad) { ModuleBase::TITLE("ML_Base", "nn_forward"); - ModuleBase::timer::start("ML_Base", "forward"); + ModuleBase::timer::start("ML_Base", "nn_forward"); this->nn->zero_grad(); this->nn->inputs.requires_grad_(false); @@ -125,7 +125,7 @@ void ML_Base::NN_forward(const double * const * prho, const ModulePW::PW_Basis * { this->nn->F = torch::softplus(this->nn->F - this->feg_net_F + this->feg3_correct); } - ModuleBase::timer::end("ML_Base", "forward"); + ModuleBase::timer::end("ML_Base", "nn_forward"); if (cal_grad) { @@ -135,24 +135,72 @@ void ML_Base::NN_forward(const double * const * prho, const ModulePW::PW_Basis * } } -torch::Tensor ML_Base::get_data(std::string parameter, const int ikernel) const { - - if (parameter == "gamma") return torch::tensor(this->gamma, this->device_type); - if (parameter == "p") return torch::tensor(this->p, this->device_type); - if (parameter == "q") return torch::tensor(this->q, this->device_type); - if (parameter == "tanhp") return torch::tensor(this->tanhp, this->device_type); - if (parameter == "tanhq") return torch::tensor(this->tanhq, this->device_type); - if (parameter == "gammanl") return torch::tensor(this->gammanl[ikernel], this->device_type); - if (parameter == "pnl") return torch::tensor(this->pnl[ikernel], this->device_type); - if (parameter == "qnl") return torch::tensor(this->qnl[ikernel], this->device_type); - if (parameter == "xi") return torch::tensor(this->xi[ikernel], this->device_type); - if (parameter == "tanhxi") return torch::tensor(this->tanhxi[ikernel], this->device_type); - if (parameter == "tanhxi_nl") return torch::tensor(this->tanhxi_nl[ikernel], this->device_type); - if (parameter == "tanh_pnl") return torch::tensor(this->tanh_pnl[ikernel], this->device_type); - if (parameter == "tanh_qnl") return torch::tensor(this->tanh_qnl[ikernel], this->device_type); - if (parameter == "tanhp_nl") return torch::tensor(this->tanhp_nl[ikernel], this->device_type); - if (parameter == "tanhq_nl") return torch::tensor(this->tanhq_nl[ikernel], this->device_type); - return torch::zeros({}); +torch::Tensor ML_Base::get_data(std::string parameter, const int ikernel) const +{ + if (parameter == "gamma") + { + return torch::tensor(this->gamma, this->device_type); + } + else if (parameter == "p") + { + return torch::tensor(this->p, this->device_type); + } + else if (parameter == "q") + { + return torch::tensor(this->q, this->device_type); + } + else if (parameter == "tanhp") + { + return torch::tensor(this->tanhp, this->device_type); + } + else if (parameter == "tanhq") + { + return torch::tensor(this->tanhq, this->device_type); + } + else if (parameter == "gammanl") + { + return torch::tensor(this->gammanl[ikernel], this->device_type); + } + else if (parameter == "pnl") + { + return torch::tensor(this->pnl[ikernel], this->device_type); + } + else if (parameter == "qnl") + { + return torch::tensor(this->qnl[ikernel], this->device_type); + } + else if (parameter == "xi") + { + return torch::tensor(this->xi[ikernel], this->device_type); + } + else if (parameter == "tanhxi") + { + return torch::tensor(this->tanhxi[ikernel], this->device_type); + } + else if (parameter == "tanhxi_nl") + { + return torch::tensor(this->tanhxi_nl[ikernel], this->device_type); + } + else if (parameter == "tanh_pnl") + { + return torch::tensor(this->tanh_pnl[ikernel], this->device_type); + } + else if (parameter == "tanh_qnl") + { + return torch::tensor(this->tanh_qnl[ikernel], this->device_type); + } + else if (parameter == "tanhp_nl") + { + return torch::tensor(this->tanhp_nl[ikernel], this->device_type); + } + else if (parameter == "tanhq_nl") + { + return torch::tensor(this->tanhq_nl[ikernel], this->device_type); + } + else + { + return torch::zeros({}); + } } void ML_Base::get_potential_(const double * const * prho, const ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential) @@ -169,43 +217,43 @@ void ML_Base::get_potential_(const double * const * prho, const ModulePW::PW_Bas if (this->ml_gammanl) { - this->potGammanlTerm(prho, tau_lda, pw_rho, pauli_potential); + this->pot_gammanl_term(prho, tau_lda, pw_rho, pauli_potential); } if (this->ml_xi) { - this->potXinlTerm(prho, tau_lda, pw_rho, pauli_potential); + this->pot_xi_nl_term(prho, tau_lda, pw_rho, pauli_potential); } if (this->ml_tanhxi) { - this->potTanhxinlTerm(prho, tau_lda, pw_rho, pauli_potential); + this->pot_tanhxi_nl_term(prho, tau_lda, pw_rho, pauli_potential); } if (this->ml_tanhxi_nl) { - this->potTanhxi_nlTerm(prho, tau_lda, pw_rho, pauli_potential); + this->pot_tanhxi_nl_nl_term(prho, tau_lda, pw_rho, pauli_potential); } if (this->ml_p || this->ml_pnl) { - this->potPPnlTerm(prho, tau_lda, pw_rho, pauli_potential); + this->pot_p_pnl_term(prho, tau_lda, pw_rho, pauli_potential); } if (this->ml_q || this->ml_qnl) { - this->potQQnlTerm(prho, tau_lda, pw_rho, pauli_potential); + this->pot_q_qnl_term(prho, tau_lda, pw_rho, pauli_potential); } if (this->ml_tanh_pnl) { - this->potTanhpTanh_pnlTerm(prho, tau_lda, pw_rho, pauli_potential); + this->pot_tanhp_tanh_pnl_term(prho, tau_lda, pw_rho, pauli_potential); } if (this->ml_tanh_qnl) { - this->potTanhqTanh_qnlTerm(prho, tau_lda, pw_rho, pauli_potential); + this->pot_tanhq_tanh_qnl_term(prho, tau_lda, pw_rho, pauli_potential); } if ((this->ml_tanhp || this->ml_tanhp_nl) && !this->ml_tanh_pnl) { - this->potTanhpTanhp_nlTerm(prho, tau_lda, pw_rho, pauli_potential); + this->pot_tanhp_tanhp_nl_term(prho, tau_lda, pw_rho, pauli_potential); } if ((this->ml_tanhq || this->ml_tanhq_nl) && !this->ml_tanh_qnl) { - this->potTanhqTanhq_nlTerm(prho, tau_lda, pw_rho, pauli_potential); + this->pot_tanhq_tanhq_nl_term(prho, tau_lda, pw_rho, pauli_potential); } for (int ir = 0; ir < this->nx; ++ir) @@ -214,9 +262,9 @@ void ML_Base::get_potential_(const double * const * prho, const ModulePW::PW_Bas pauli_potential[ir] += factor * (this->energy_exponent * this->enhancement_cpu_ptr[ir] - + this->potGammaTerm(ir) + this->potPTerm1(ir) + this->potQTerm1(ir) - + this->potXiTerm1(ir) + this->potTanhxiTerm1(ir) - + this->potTanhpTerm1(ir) + this->potTanhqTerm1(ir)); + + this->pot_gamma_term(ir) + this->pot_p_term_1(ir) + this->pot_q_term_1(ir) + + this->pot_xi_term_1(ir) + this->pot_tanhxi_term_1(ir) + + this->pot_tanhp_term_1(ir) + this->pot_tanhq_term_1(ir)); rpotential(0, ir) += pauli_potential[ir]; } @@ -225,13 +273,13 @@ void ML_Base::get_potential_(const double * const * prho, const ModulePW::PW_Bas } // IO tools -void ML_Base::loadVector(std::string filename, std::vector &data) +void ML_Base::load_vector(std::string filename, std::vector &data) { npy::npy_data d = npy::read_npy(filename); data = d.data; } -void ML_Base::dumpVector(std::string filename, const std::vector &data) +void ML_Base::dump_vector(std::string filename, const std::vector &data) { npy::npy_data_ptr d; d.data_ptr = data.data(); @@ -240,19 +288,19 @@ void ML_Base::dumpVector(std::string filename, const std::vector &data) npy::write_npy(filename, d); } -void ML_Base::dumpTensor(std::string filename, const torch::Tensor &data) +void ML_Base::dump_tensor(std::string filename, const torch::Tensor &data) { std::cout << "Dumping " << filename << std::endl; torch::Tensor data_cpu = data.to(this->device_CPU).contiguous(); std::vector v(data_cpu.data_ptr(), data_cpu.data_ptr() + data_cpu.numel()); - this->dumpVector(filename, v); + this->dump_vector(filename, v); } -void ML_Base::dumpMatrix(std::string filename, const ModuleBase::matrix &data) +void ML_Base::dump_matrix(std::string filename, const ModuleBase::matrix &data) { std::cout << "Dumping " << filename << std::endl; std::vector v(data.c, data.c + this->nx); - this->dumpVector(filename, v); + this->dump_vector(filename, v); } #endif diff --git a/source/source_pw/module_ofdft/ml_base.h b/source/source_pw/module_ofdft/ml_base.h index 8f6cb093f91..758f08add53 100644 --- a/source/source_pw/module_ofdft/ml_base.h +++ b/source/source_pw/module_ofdft/ml_base.h @@ -20,45 +20,119 @@ class ML_Base ~ML_Base(); // Common Interface - void set_device(const std::string& device_inpt, std::ostream& ofs_running); - + void set_device( + const std::string& device_inpt, + std::ostream& ofs_running); + // Tools - void loadVector(std::string filename, std::vector &data); - void dumpVector(std::string filename, const std::vector &data); - void dumpTensor(std::string filename, const torch::Tensor &data); - void dumpMatrix(std::string filename, const ModuleBase::matrix &data); + void load_vector( + std::string filename, + std::vector &data); + + void dump_vector( + std::string filename, + const std::vector &data); + + void dump_tensor( + std::string filename, + const torch::Tensor &data); + + void dump_matrix( + std::string filename, + const ModuleBase::matrix &data); int nx_tot = 0; // equal to nx (called by NN) - torch::Tensor get_data(std::string parameter, const int ikernel) const; + torch::Tensor get_data( + std::string parameter, + const int ikernel) const; protected: - void updateInput(const double * const * prho, const ModulePW::PW_Basis *pw_rho); - void NN_forward(const double * const * prho, const ModulePW::PW_Basis *pw_rho, bool cal_grad); - void get_potential_(const double * const * prho, const ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential); + void update_input( + const double * const * prho, + const ModulePW::PW_Basis *pw_rho); + + void nn_forward( + const double * const * prho, + const ModulePW::PW_Basis *pw_rho, + bool cal_grad); + + void get_potential_( + const double * const * prho, + const ModulePW::PW_Basis *pw_rho, + ModuleBase::matrix &rpotential); // Potential Terms - these appear identical in both classes or are intended to be shared - double potGammaTerm(int ir); - double potPTerm1(int ir); - double potQTerm1(int ir); - double potXiTerm1(int ir); - double potTanhxiTerm1(int ir); - double potTanhpTerm1(int ir); - double potTanhqTerm1(int ir); + double pot_gamma_term(int ir); + double pot_p_term_1(int ir); + double pot_q_term_1(int ir); + double pot_xi_term_1(int ir); + double pot_tanhxi_term_1(int ir); + double pot_tanhp_term_1(int ir); + double pot_tanhq_term_1(int ir); // Derived classes should ensure they can work with these signatures. // Note: ML_EXX originally passed tau_lda for some of these. // If tau_lda is needed, derived classes can override or we can add it to member variables. // For now, keeping signatures compatible with member access. - void potGammanlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rGammanlTerm); - void potXinlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rXinlTerm); - void potTanhxinlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhxinlTerm); - void potTanhxi_nlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhxi_nlTerm); - void potPPnlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rPPnlTerm); - void potQQnlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rQQnlTerm); - void potTanhpTanh_pnlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhpTanh_pnlTerm); - void potTanhqTanh_qnlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhqTanh_qnlTerm); - void potTanhpTanhp_nlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhpTanhp_nlTerm); - void potTanhqTanhq_nlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhqTanhq_nlTerm); + void pot_gammanl_term( + const double * const *prho, + const std::vector &tau_lda, + const ModulePW::PW_Basis *pw_rho, + std::vector &rGammanlTerm); + + void pot_xi_nl_term( + const double * const *prho, + const std::vector &tau_lda, + const ModulePW::PW_Basis *pw_rho, + std::vector &rXinlTerm); + + void pot_tanhxi_nl_term( + const double * const *prho, + const std::vector &tau_lda, + const ModulePW::PW_Basis *pw_rho, + std::vector &rTanhxinlTerm); + + void pot_tanhxi_nl_nl_term( + const double * const *prho, + const std::vector &tau_lda, + const ModulePW::PW_Basis *pw_rho, + std::vector &rTanhxi_nlTerm); + + void pot_p_pnl_term( + const double * const *prho, + const std::vector &tau_lda, + const ModulePW::PW_Basis *pw_rho, + std::vector &rPPnlTerm); + + void pot_q_qnl_term( + const double * const *prho, + const std::vector &tau_lda, + const ModulePW::PW_Basis *pw_rho, + std::vector &rQQnlTerm); + + void pot_tanhp_tanh_pnl_term( + const double * const *prho, + const std::vector &tau_lda, + const ModulePW::PW_Basis *pw_rho, + std::vector &rTanhpTanh_pnlTerm); + + void pot_tanhq_tanh_qnl_term( + const double * const *prho, + const std::vector &tau_lda, + const ModulePW::PW_Basis *pw_rho, + std::vector &rTanhqTanh_qnlTerm); + + void pot_tanhp_tanhp_nl_term( + const double * const *prho, + const std::vector &tau_lda, + const ModulePW::PW_Basis *pw_rho, + std::vector &rTanhpTanhp_nlTerm); + + void pot_tanhq_tanhq_nl_term( + const double * const *prho, + const std::vector &tau_lda, + const ModulePW::PW_Basis *pw_rho, + std::vector &rTanhqTanhq_nlTerm); protected: // --- Member Variables (Common) --- @@ -146,4 +220,4 @@ class ML_Base }; #endif // __MLALGO -#endif // ML_BASE_H +#endif // ML_BASE_H \ No newline at end of file diff --git a/source/source_pw/module_ofdft/ml_base_pot.cpp b/source/source_pw/module_ofdft/ml_base_pot.cpp index b13e5552a7f..20481541cfd 100644 --- a/source/source_pw/module_ofdft/ml_base_pot.cpp +++ b/source/source_pw/module_ofdft/ml_base_pot.cpp @@ -2,19 +2,19 @@ #ifdef __MLALGO -double ML_Base::potGammaTerm(int ir) +double ML_Base::pot_gamma_term(int ir) { return (this->ml_gamma) ? 1./3. * gamma[ir] * this->gradient_cpu_ptr[ir * this->ninput + this->descriptor2index["gamma"][0]] : 0.; } -double ML_Base::potPTerm1(int ir) +double ML_Base::pot_p_term_1(int ir) { return (this->ml_p) ? - 8./3. * p[ir] * this->gradient_cpu_ptr[ir * this->ninput + this->descriptor2index["p"][0]] : 0.; } -double ML_Base::potQTerm1(int ir) +double ML_Base::pot_q_term_1(int ir) { return (this->ml_q) ? - 5./3. * q[ir] * this->gradient_cpu_ptr[ir * this->ninput + this->descriptor2index["q"][0]] : 0.; } -double ML_Base::potXiTerm1(int ir) +double ML_Base::pot_xi_term_1(int ir) { double result = 0.; for (int ik = 0; ik < this->descriptor2kernel["xi"].size(); ++ik) @@ -25,7 +25,7 @@ double ML_Base::potXiTerm1(int ir) } return result; } -double ML_Base::potTanhxiTerm1(int ir) +double ML_Base::pot_tanhxi_term_1(int ir) { double result = 0.; for (int ik = 0; ik < this->descriptor2kernel["tanhxi"].size(); ++ik) @@ -37,19 +37,19 @@ double ML_Base::potTanhxiTerm1(int ir) } return result; } -double ML_Base::potTanhpTerm1(int ir) +double ML_Base::pot_tanhp_term_1(int ir) { return (this->ml_tanhp) ? - 8./3. * p[ir] * this->cal_tool->dtanh(this->tanhp[ir], this->chi_p) * this->gradient_cpu_ptr[ir * this->ninput + this->descriptor2index["tanhp"][0]] : 0.; } -double ML_Base::potTanhqTerm1(int ir) +double ML_Base::pot_tanhq_term_1(int ir) { return (this->ml_tanhq) ? - 5./3. * q[ir] * this->cal_tool->dtanh(this->tanhq[ir], this->chi_q) * this->gradient_cpu_ptr[ir * this->ninput + this->descriptor2index["tanhq"][0]] : 0.; } // Implementations of nl terms using energy_prefactor/exponent logic -void ML_Base::potGammanlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rGammanlTerm) +void ML_Base::pot_gammanl_term(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rGammanlTerm) { double *dFdgammanl = new double[this->nx]; for (int ik = 0; ik < this->descriptor2kernel["gammanl"].size(); ++ik) @@ -69,7 +69,7 @@ void ML_Base::potGammanlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rXinlTerm) +void ML_Base::pot_xi_nl_term(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rXinlTerm) { double *dFdxi = new double[this->nx]; for (int ik = 0; ik < this->descriptor2kernel["xi"].size(); ++ik) @@ -89,7 +89,7 @@ void ML_Base::potXinlTerm(const double * const *prho, const std::vector delete[] dFdxi; } -void ML_Base::potTanhxinlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhxinlTerm) +void ML_Base::pot_tanhxi_nl_term(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhxinlTerm) { double *dFdtanhxi = new double[this->nx]; for (int ik = 0; ik < this->descriptor2kernel["tanhxi"].size(); ++ik) @@ -110,7 +110,7 @@ void ML_Base::potTanhxinlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhxi_nlTerm) +void ML_Base::pot_tanhxi_nl_nl_term(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhxi_nlTerm) { double *dFdtanhxi_nl = new double[this->nx]; double *dFdtanhxi_nl_nl = new double[this->nx]; @@ -143,7 +143,7 @@ void ML_Base::potTanhxi_nlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rPPnlTerm) +void ML_Base::pot_p_pnl_term(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rPPnlTerm) { double *dFdpnl = new double[this->nx]; std::vector dFdpnl_tot(this->nx, 0.); @@ -201,7 +201,7 @@ void ML_Base::potPPnlTerm(const double * const *prho, const std::vector } -void ML_Base::potQQnlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rQQnlTerm) +void ML_Base::pot_q_qnl_term(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rQQnlTerm) { double *dFdqnl = new double[this->nx]; std::vector dFdqnl_tot(this->nx, 0.); @@ -249,7 +249,7 @@ void ML_Base::potQQnlTerm(const double * const *prho, const std::vector } -void ML_Base::potTanhpTanh_pnlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhpTanh_pnlTerm) +void ML_Base::pot_tanhp_tanh_pnl_term(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhpTanh_pnlTerm) { // Note we assume that tanhp_nl and tanh_pnl will NOT be used together. double *dFdpnl = new double[this->nx]; @@ -308,7 +308,7 @@ void ML_Base::potTanhpTanh_pnlTerm(const double * const *prho, const std::vector delete[] tempP; } -void ML_Base::potTanhqTanh_qnlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhqTanh_qnlTerm) +void ML_Base::pot_tanhq_tanh_qnl_term(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhqTanh_qnlTerm) { // Note we assume that tanhq_nl and tanh_qnl will NOT be used together. double *dFdqnl = new double[this->nx]; @@ -359,7 +359,7 @@ void ML_Base::potTanhqTanh_qnlTerm(const double * const *prho, const std::vector } // Note we assume that tanhp_nl and tanh_pnl will NOT be used together. -void ML_Base::potTanhpTanhp_nlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhpTanhp_nlTerm) +void ML_Base::pot_tanhp_tanhp_nl_term(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhpTanhp_nlTerm) { double *dFdpnl = new double[this->nx]; std::vector dFdpnl_tot(this->nx, 0.); @@ -417,7 +417,7 @@ void ML_Base::potTanhpTanhp_nlTerm(const double * const *prho, const std::vector delete[] tempP; } -void ML_Base::potTanhqTanhq_nlTerm(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhqTanhq_nlTerm) +void ML_Base::pot_tanhq_tanhq_nl_term(const double * const *prho, const std::vector &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector &rTanhqTanhq_nlTerm) { double *dFdqnl = new double[this->nx]; std::vector dFdqnl_tot(this->nx, 0.); From b89b832c0e05ab6caeaf55bf38bacf2fb90d8d75 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 4 Jun 2026 10:57:00 +0800 Subject: [PATCH 09/11] fix a potential bug when the net.pt model cannot be found --- .../source_estate/module_pot/pot_ml_exx.cpp | 34 ++++++++++++------- source/source_pw/module_ofdft/kedf_ml.cpp | 16 +++++++-- 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/source/source_estate/module_pot/pot_ml_exx.cpp b/source/source_estate/module_pot/pot_ml_exx.cpp index df4b905f618..41c63ab43f8 100644 --- a/source/source_estate/module_pot/pot_ml_exx.cpp +++ b/source/source_estate/module_pot/pot_ml_exx.cpp @@ -55,8 +55,16 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod int nnode = 100; int nlayer = 3; this->nn = std::make_shared(this->nx, 0, this->ninput, nnode, nlayer, this->device, ofs_running); - torch::load(this->nn, "net.pt", this->device_type); - ofs_running << "load net done" << std::endl; + try + { + torch::load(this->nn, "net.pt", this->device_type); + } + catch (const std::exception& e) + { + ModuleBase::WARNING_QUIT("ML_EXX::set_para", + "Failed to load neural network model from net.pt: " + std::string(e.what())); + } + ofs_running << "load net done (ML EXX neural network functional model loaded successfully)" << std::endl; if (PARAM.inp.of_ml_feg != 0) { torch::Tensor feg_inpt = torch::zeros(this->ninput, this->device_type); @@ -128,9 +136,9 @@ void ML_EXX::ml_potential(const double * const * prho, const ModulePW::PW_Basis rho_data[ir] = std::abs(prho[0][ir]); } - this->updateInput(prho_mod, pw_rho); + this->update_input(prho_mod, pw_rho); - this->NN_forward(prho_mod, pw_rho, true); + this->nn_forward(prho_mod, pw_rho, true); torch::Tensor enhancement_cpu_tensor = this->nn->F.to(this->device_CPU).contiguous(); this->enhancement_cpu_ptr = enhancement_cpu_tensor.data_ptr(); @@ -168,9 +176,9 @@ void ML_EXX::generateTrainData(const double * const *prho, const ModulePW::PW_Ba { if (PARAM.inp.of_kinetic == "ml") { - this->updateInput(prho, pw_rho); + this->update_input(prho, pw_rho); - this->NN_forward(prho, pw_rho, true); + this->nn_forward(prho, pw_rho, true); torch::Tensor enhancement_cpu_tensor = this->nn->F.to(this->device_CPU).contiguous(); this->enhancement_cpu_ptr = enhancement_cpu_tensor.data_ptr(); @@ -182,8 +190,8 @@ void ML_EXX::generateTrainData(const double * const *prho, const ModulePW::PW_Ba this->get_potential_(prho, pw_rho, potential); - this->dumpTensor("enhancement.npy", enhancement); - this->dumpMatrix("potential.npy", potential); + this->dump_tensor("enhancement.npy", enhancement); + this->dump_matrix("potential.npy", potential); } } @@ -200,7 +208,7 @@ void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw bool fortran_order = false; std::vector temp_prho(this->nx); - this->loadVector("path_to_rho_file", temp_prho); + this->load_vector("path_to_rho_file", temp_prho); double ** prho = new double *[1]; prho[0] = new double[this->nx]; @@ -212,9 +220,9 @@ void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw } }; // ============================== - this->updateInput(prho, pw_rho); + this->update_input(prho, pw_rho); - this->NN_forward(prho, pw_rho, true); + this->nn_forward(prho, pw_rho, true); torch::Tensor enhancement_cpu_tensor = this->nn->F.to(this->device_CPU).contiguous(); this->enhancement_cpu_ptr = enhancement_cpu_tensor.data_ptr(); @@ -226,8 +234,8 @@ void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw this->get_potential_(prho, pw_rho, potential); - this->dumpTensor("enhancement-abacus.npy", enhancement); - this->dumpMatrix("potential-abacus.npy", potential); + this->dump_tensor("enhancement-abacus.npy", enhancement); + this->dump_matrix("potential-abacus.npy", potential); exit(0); } diff --git a/source/source_pw/module_ofdft/kedf_ml.cpp b/source/source_pw/module_ofdft/kedf_ml.cpp index 2689000f285..914adff55ba 100644 --- a/source/source_pw/module_ofdft/kedf_ml.cpp +++ b/source/source_pw/module_ofdft/kedf_ml.cpp @@ -42,6 +42,7 @@ void KEDF_ML::set_para( std::ostream& ofs_running ) { + ModuleBase::TITLE("KEDF_ML", "set_para"); torch::set_default_dtype(caffe2::TypeMeta::fromScalarType(torch::kDouble)); auto output = torch::get_default_dtype(); ofs_running << " Default type: " << output << std::endl; @@ -80,8 +81,16 @@ void KEDF_ML::set_para( int nnode = 100; int nlayer = 3; this->nn = std::make_shared(this->nx, 0, this->ninput, nnode, nlayer, this->device, ofs_running); - torch::load(this->nn, "net.pt", this->device_type); - ofs_running << " load net done (neural network loaded successfully)" << std::endl; + try + { + torch::load(this->nn, "net.pt", this->device_type); + } + catch (const std::exception& e) + { + ModuleBase::WARNING_QUIT("KEDF_ML::set_para", + "Failed to load neural network model from net.pt: " + std::string(e.what())); + } + ofs_running << " load net done (ML KEDF neural network model loaded successfully)" << std::endl; if (PARAM.inp.of_ml_feg != 0) { torch::Tensor feg_inpt = torch::zeros(this->ninput, this->device_type); @@ -130,6 +139,7 @@ void KEDF_ML::set_para( */ double KEDF_ML::get_energy(const double * const * prho, ModulePW::PW_Basis *pw_rho) { + ModuleBase::TITLE("KEDF_ML", "get_energy"); this->update_input(prho, pw_rho); this->nn_forward(prho, pw_rho, false); @@ -197,6 +207,7 @@ void KEDF_ML::ml_potential(const double * const * prho, ModulePW::PW_Basis *pw_r */ void KEDF_ML::generateTrainData(const double * const *prho, ModulePW::PW_Basis *pw_rho, const double *veff) { + ModuleBase::TITLE("KEDF_ML", "generate_train_data"); // this->cal_tool->generateTrainData_WT(prho, wt, tf, pw_rho, veff); // Will be fixed in next pr if (PARAM.inp.of_kinetic == "ml") { @@ -227,6 +238,7 @@ void KEDF_ML::generateTrainData(const double * const *prho, ModulePW::PW_Basis * */ void KEDF_ML::localTest(const double * const *pprho, ModulePW::PW_Basis *pw_rho) { + ModuleBase::TITLE("KEDF_ML", "local_test"); // for test ===================== std::vector cshape = {(long unsigned) this->nx}; bool fortran_order = false; From ca9b1d7d057e2742eebe29faf8dc0f53245fd456 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 4 Jun 2026 11:34:57 +0800 Subject: [PATCH 10/11] update kedf and exx --- .../source_estate/module_pot/pot_ml_exx.cpp | 2 +- source/source_estate/module_pot/pot_ml_exx.h | 3 +- .../source_pw/module_ofdft/kedf_manager.cpp | 2 +- source/source_pw/module_ofdft/kedf_ml.cpp | 33 +++++++++++++++---- source/source_pw/module_ofdft/kedf_ml.h | 3 +- 5 files changed, 30 insertions(+), 13 deletions(-) diff --git a/source/source_estate/module_pot/pot_ml_exx.cpp b/source/source_estate/module_pot/pot_ml_exx.cpp index 41c63ab43f8..d371ca8424f 100644 --- a/source/source_estate/module_pot/pot_ml_exx.cpp +++ b/source/source_estate/module_pot/pot_ml_exx.cpp @@ -172,7 +172,7 @@ void ML_EXX::ml_potential(const double * const * prho, const ModulePW::PW_Basis * @param pw_rho PW_Basis * @param veff effective potential */ -void ML_EXX::generateTrainData(const double * const *prho, const ModulePW::PW_Basis *pw_rho, const double *veff) +void ML_EXX::gen_training_data(const double * const *prho, const ModulePW::PW_Basis *pw_rho, const double *veff) { if (PARAM.inp.of_kinetic == "ml") { diff --git a/source/source_estate/module_pot/pot_ml_exx.h b/source/source_estate/module_pot/pot_ml_exx.h index 2dff9d922f2..b3e9d82b941 100644 --- a/source/source_estate/module_pot/pot_ml_exx.h +++ b/source/source_estate/module_pot/pot_ml_exx.h @@ -20,8 +20,7 @@ class ML_EXX : public ML_Base void ml_potential(const double * const * prho, const ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential); - // output all parameters - void generateTrainData(const double * const *prho, const ModulePW::PW_Basis *pw_rho, const double *veff); + void gen_training_data(const double * const *prho, const ModulePW::PW_Basis *pw_rho, const double *veff); void localTest(const double * const *prho, const ModulePW::PW_Basis *pw_rho, std::ostream& ofs_running); void init_data( diff --git a/source/source_pw/module_ofdft/kedf_manager.cpp b/source/source_pw/module_ofdft/kedf_manager.cpp index 6caedf433bb..337158703b2 100644 --- a/source/source_pw/module_ofdft/kedf_manager.cpp +++ b/source/source_pw/module_ofdft/kedf_manager.cpp @@ -499,6 +499,6 @@ void KEDF_Manager::generate_ml_target( ) { #ifdef __MLALGO - this->ml_->generateTrainData(prho, pw_rho, veff); + this->ml_->gen_training_data(prho, pw_rho, veff); #endif } diff --git a/source/source_pw/module_ofdft/kedf_ml.cpp b/source/source_pw/module_ofdft/kedf_ml.cpp index 914adff55ba..b280d61900a 100644 --- a/source/source_pw/module_ofdft/kedf_ml.cpp +++ b/source/source_pw/module_ofdft/kedf_ml.cpp @@ -110,7 +110,11 @@ void KEDF_ML::set_para( ofs_running << " feg_net_F = " << this->feg_net_F << " (Fermi energy guess factor)" << std::endl << std::endl; } - } + } + else + { + ofs_running << " ML KEDF not enabled (of_kinetic != \"ml\")" << std::endl; + } if (PARAM.inp.of_kinetic == "ml" || PARAM.inp.of_ml_gene_data == 1) { @@ -127,6 +131,10 @@ void KEDF_ML::set_para( kernel_scaling, yukawa_alpha, kernel_file, this->dV * pw_rho->nxyz, pw_rho, ofs_running); } + else + { + ofs_running << " ML descriptor calculator not initialized (neither ml kinetic nor gene_data enabled)" << std::endl; + } } /** @@ -168,7 +176,7 @@ double KEDF_ML::get_energy(const double * const * prho, ModulePW::PW_Basis *pw_r void KEDF_ML::ml_potential(const double * const * prho, ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential) { ModuleBase::TITLE("KEDF_ML", "ml_potential"); - ModuleBase::timer::start("KEDF_ML", "pauli_energy"); + ModuleBase::timer::start("KEDF_ML", "ml_potential"); this->update_input(prho, pw_rho); @@ -184,6 +192,8 @@ void KEDF_ML::ml_potential(const double * const * prho, ModulePW::PW_Basis *pw_r this->get_potential_(prho, pw_rho, rpotential); + // Calculate Pauli energy (ml_energy) from enhancement factor + // E_pauli = c_TF * ∫ F(ρ) * ρ^(5/3) dr double energy = 0.; for (int ir = 0; ir < this->nx; ++ir) { @@ -193,7 +203,7 @@ void KEDF_ML::ml_potential(const double * const * prho, ModulePW::PW_Basis *pw_r this->ml_energy = energy; Parallel_Reduce::reduce_all(this->ml_energy); - ModuleBase::timer::end("KEDF_ML", "pauli_energy"); + ModuleBase::timer::end("KEDF_ML", "ml_potential"); } /** @@ -205,9 +215,9 @@ void KEDF_ML::ml_potential(const double * const * prho, ModulePW::PW_Basis *pw_r * @param pw_rho PW_Basis * @param veff effective potential */ -void KEDF_ML::generateTrainData(const double * const *prho, ModulePW::PW_Basis *pw_rho, const double *veff) +void KEDF_ML::gen_training_data(const double * const *prho, ModulePW::PW_Basis *pw_rho, const double *veff) { - ModuleBase::TITLE("KEDF_ML", "generate_train_data"); + ModuleBase::TITLE("KEDF_ML", "gen_training_data"); // this->cal_tool->generateTrainData_WT(prho, wt, tf, pw_rho, veff); // Will be fixed in next pr if (PARAM.inp.of_kinetic == "ml") { @@ -228,6 +238,10 @@ void KEDF_ML::generateTrainData(const double * const *prho, ModulePW::PW_Basis * this->dump_tensor("enhancement.npy", enhancement); this->dump_matrix("potential.npy", potential); } + else + { + std::cout << " Warning: gen_training_data skipped (of_kinetic != \"ml\")" << std::endl; + } } /** @@ -250,8 +264,13 @@ void KEDF_ML::localTest(const double * const *pprho, ModulePW::PW_Basis *pw_rho) for (int ir = 0; ir < this->nx; ++ir) prho[0][ir] = temp_prho[ir]; for (int ir = 0; ir < this->nx; ++ir) { - if (prho[0][ir] == 0.){ - std::cout << "WARNING: rho = 0" << std::endl; + if (prho[0][ir] == 0.) + { + std::cout << "WARNING: rho = 0 at grid point " << ir << std::endl; + } + else + { + // Normal case: non-zero density } }; // ============================== diff --git a/source/source_pw/module_ofdft/kedf_ml.h b/source/source_pw/module_ofdft/kedf_ml.h index 631bb413736..9ac90a2bf64 100644 --- a/source/source_pw/module_ofdft/kedf_ml.h +++ b/source/source_pw/module_ofdft/kedf_ml.h @@ -56,8 +56,7 @@ class KEDF_ML : public ML_Base void ml_potential(const double * const * prho, ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential); - // output all parameters - void generateTrainData(const double * const *prho, ModulePW::PW_Basis *pw_rho, const double *veff); + void gen_training_data(const double * const *prho, ModulePW::PW_Basis *pw_rho, const double *veff); void localTest(const double * const *prho, ModulePW::PW_Basis *pw_rho); From d63f55b2dc0529c22f093c6be9144dcc2fa2de97 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Thu, 4 Jun 2026 15:02:58 +0800 Subject: [PATCH 11/11] update --- source/source_pw/module_ofdft/kedf_ml.cpp | 3 ++- source/source_pw/module_ofdft/ml_base.cpp | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/source/source_pw/module_ofdft/kedf_ml.cpp b/source/source_pw/module_ofdft/kedf_ml.cpp index b280d61900a..4c439b8f30f 100644 --- a/source/source_pw/module_ofdft/kedf_ml.cpp +++ b/source/source_pw/module_ofdft/kedf_ml.cpp @@ -108,7 +108,8 @@ void KEDF_ML::set_para( this->feg_net_F = this->nn->forward(feg_inpt).to(this->device_CPU).contiguous().data_ptr()[0]; } - ofs_running << " feg_net_F = " << this->feg_net_F << " (Fermi energy guess factor)" << std::endl << std::endl; + ofs_running << " feg_net_F = " << this->feg_net_F + << " (Pauli energy enhancement factor in free electron gas)" << std::endl << std::endl; } } else diff --git a/source/source_pw/module_ofdft/ml_base.cpp b/source/source_pw/module_ofdft/ml_base.cpp index 5eefec0be26..eb16ff5d940 100644 --- a/source/source_pw/module_ofdft/ml_base.cpp +++ b/source/source_pw/module_ofdft/ml_base.cpp @@ -27,6 +27,8 @@ void ML_Base::set_device(const std::string& device_inpt, std::ostream& ofs_runni } else { + std::cout << "--------------- Warning: GPU is unavailable ---------------" << std::endl; + ofs_running << "--------------- Warning: GPU is unavailable ---------------" << std::endl; ofs_running << "------------------- Running Neural Network on CPU -------------------" << std::endl; this->device_type = torch::kCPU;