首页 诗词 字典 板报 句子 名言 友答 励志 学校 网站地图
当前位置: 首页 > 教程频道 > 开发语言 > C++ >

看看这个矩阵模板该如何优化(没分了)

2012-03-16 
看看这个矩阵模板该怎么优化(没分了) enumARRAY_STORAGE_MODE{SET_BY_COL,Set_BY_ROW}templateclassTcl

看看这个矩阵模板该怎么优化(没分了)

enum   ARRAY_STORAGE_MODE
{
        SET_BY_COL,
        Set_BY_ROW
};

template   <class   T>
class   CMatrix
{
public:
CMatrix(int   val=0);
CMatrix(int   m,   int   n);
CMatrix(CMatrix <T> &   a);

~CMatrix(void);


bool   SetMatrix(T*   array,ARRAY_STORAGE_MODE   mode);

int   GetRows(){return   m;};

int   GetCols(){return   n;};
void   operator=(CMatrix <T> &   a);
T*&   operator[](int   i);
CMatrix <T>   operator+(CMatrix <T> &   a);
CMatrix <T>   operator*(double   k);
CMatrix <T>   operator*(CMatrix <T> &   a);
CMatrix <T>   CutMatrixRows(vector <int>   iArray);//切掉指定行
CMatrix <T>   CutMatrixBlock(int   startRow,int   endRow,int   startCol,int   endCol);//切块
CMatrix <T>   Transpose();
CMatrix <T>   Inverse();
int   m;           //行
int   n;           //列
T**   p;           //数据
};

template   <class   T>
CMatrix <T> ::CMatrix(int   val)
{
        m   =   0;   n     =   0;
        p   =   NULL;
}

template   <class   T>
CMatrix <T> ::CMatrix(int   m,   int   n)
{
int   i,j;
p   =   new   T*[m];        
for   (i   =   0;   i   <   m;   i++)
{
p[i]   =   new   T[n];
for   (j   =   0;   j   <   n;   j++)
{p[i][j]   =   0;}
}
this-> m   =   m;   this-> n   =   n;
}

template   <class   T>
CMatrix <T> ::~CMatrix(void)
{    
        int   i;
        if   (p)
        {
for   (i   =   0;   i   <   m;   i++)
                    {delete[]   p[i];}
delete[]   p;
        }
}

template   <class   T>
bool   CMatrix <T> ::SetMatrix(T*   array,ARRAY_STORAGE_MODE   mode)
{
        int   i,   j=0;
switch(mode)
{
case   SET_BY_COL   :
{        
for   (i   =   0;   i   <   n;   i++)
{
for   (j   =   0;   j   <   m;   j++)
{p[j][i]   =   array[i*m+j];   }
}
                                                      break;
}
        case   Set_BY_ROW   :
{        
for   (i   =   0;   i   <   m;   i++)        
{
for   (j   =   0;   j   <   n;   j++)


{
p[i][j]   =   array[i*n+j];
}
}

break;
}
default:
{

return   false;
}
}

return   true;
}

template   <class   T>
T*&   CMatrix <T> ::operator[](int   i)
{
        return   p[i];
}


template   <class   T>
CMatrix <T>   CMatrix <T> ::Transpose()
{
        int   i,   j;

CMatrix <T>   t(n,   m);
for   (i   =   0;   i   <   m;   i++)
{
for   (j   =   0;   j   <   n;   j++)
{
t[j][i]   =   p[i][j];
}
}

return   t;

}


template   <class   T>
void   CMatrix <T> ::operator=(CMatrix <T> &   a)
{
        int   i,   j;
        if   (p)
        {
for   (i   =   0;   i   <   m;   i++)
{
delete[]   p[i];
}
delete[]   p;
        }

        m   =   a.GetRows();  
n   =   a.GetCols();

        p   =   new   T*[m];

        for   (i   =   0;   i   <   m;   i++)
        {
p[i]   =   new   T[n];
for   (j   =   0;   j   <   n;   j++)
{
p[i][j]   =   a[i][j];
}
        }
}


template   <class   T>
CMatrix <T> ::CMatrix(CMatrix <T> &   a)
{
        int   i,   j;
        m   =   a.GetRows();  
n   =   a.GetCols();
        p   =   new   T*[m];
        for   (i   =   0;   i   <   m;   i++)
        {
p[i]   =   new   T[n];
for   (j   =   0;   j   <   n;   j++)                
{
p[i][j]   =   a[i][j];              

}
        }
}

template   <class   T>
CMatrix <T>   CMatrix <T> ::operator+(CMatrix <T> &   a)
{
int   i,   j;
CMatrix <T>   t(m,   n);
if   (m   ==   a.GetRows()   &&   n   ==   a.GetCols())
{
for   (i   =   0;   i   <   m;   i++)
{
for   (j   =   0;   j   <   n;   j++)
{
t[i][j]   =   p[i][j]   +   a[i][j];
}
}
}
return   t;
}

template   <class   T>
CMatrix <T>   CMatrix <T> ::operator*(CMatrix <T> &   a)
{
int   i,   j,   k;
CMatrix <T>   t(m,   a.GetCols());  
if   (n   ==   a.GetRows())  


{
for   (i   =   0;   i   <   m;   i++)
{
for   (j   =   0;   j   <   a.GetCols();   j   ++)
{
for   (k   =   0;   k   <   n;   k++)
{
t[i][j]   +=     (p[i][k])*(a[k][j])   ;
}
}
}
}
return   t;
}

template <class   T>
CMatrix <T>   CMatrix <T> ::CutMatrixRows(vector <int>   iArray)
{
if   (iArray.size()==0)
{
return   *this;
}
CMatrix <T>   t(m-iArray.size(),n);
int   i,j,k,l,q;

k=l=q=0;
for   (i=0;i <m;i++,l++)
{
if   (k==iArray.size())
{
break;
}

if   (i!=iArray[k])
{
for   (j=0;j <n;j++,q++)
{
t[l][q]   =   p[i][j];
}

}
else
{
k++;
}

}

return   t;
}

template <class   T>
CMatrix <T>   CMatrix <T> ::Inverse()//3阶求逆
{
CMatrix <T>   t(*this);

        int   is[3];
        int   js[3];

        float   fDet   =   1.0f;
        int   f   =   1;

        for   (int   k   =   0;   k   <   3;   k   ++)
        {
                //   &micro;&Uacute;&Ograve;&raquo;&sup2;&frac12;&pound;&not;&Egrave;&laquo;&Ntilde;&iexcl;&Ouml;÷&Ocirc;&ordf;                
float   fMax   =   0.0f;
                for   (int   i   =   k;   i   <   3;   i   ++)                
{
                        for   (int   j   =   k;   j   <   3;   j   ++)
                        {
                                const   float   f   =   fabs(t[i][j]);
                                if   (f   >   fMax)
                                {
                                        fMax         =   f;
                                        is[k]         =   i;
                                        js[k]         =   j;
                                }


                        }
                }
                if   (fMax   <   0.0001f)
                        return   t;
               
                if   (is[k]   !=   k)
                {
                        f   =   -f;
                        swap(t[k][0],   t[is[k]][0]);
                        swap(t[k][1],   t[is[k]][1]);
                        swap(t[k][2],   t[is[k]][2]);
                }
                if   (js[k]   !=   k)
                {
                        f   =   -f;
                        swap(t[0][k],   t[0][js[k]]);
                        swap(t[1][k],   t[1][js[k]]);
                        swap(t[2][k],   t[2][js[k]]);
                }

                //   &frac14;&AElig;&Euml;&atilde;&ETH;&ETH;&Aacute;&ETH;&Ouml;&micro;
                fDet   *=   t[k][k];

                //   &frac14;&AElig;&Euml;&atilde;&Auml;&aelig;&frac34;&Oslash;&Otilde;ó

                //   &micro;&Uacute;&para;&thorn;&sup2;&frac12;
                t[k][k]   =   1.0f   /   t[k][k];        
                //   &micro;&Uacute;&Egrave;&yacute;&sup2;&frac12;
                for   (int   j   =   0;   j   <   3;   j   ++)
                {
                        if   (j   !=   k)
                                t[k][j]   *=   t[k][k];
                }
                //   &micro;&Uacute;&Euml;&Auml;&sup2;&frac12;
                for   (i   =   0;   i   <   3;   i   ++)
                {
                        if   (i   !=   k)


                        {
                                for         (j   =   0;   j   <   3;   j   ++)
                                {
                                        if   (j   !=   k)
                                                t[i][j]   =   t[i][j]   -   t[i][k]   *   t[k][j];
                                }
                        }
                }
                //   &micro;&Uacute;&Icirc;&aring;&sup2;&frac12;
                for   (i   =   0;   i   <   3;   i++)
                {
                        if   (i!=   k)
                                t[i][k]   *=   -t[k][k];
                }
        }

        for         (k   =   2;   k   > =   0;   k--)
        {
                if   (js[k]   !=   k)
                {
                        swap(t[k][0],   t[js[k]][0]);
                        swap(t[k][1],   t[js[k]][1]);
                        swap(t[k][2],   t[js[k]][2]);
                }
                if   (is[k]   !=   k)
                {
                        swap(t[0][k],   t[0][is[k]]);
                        swap(t[1][k],   t[1][is[k]]);
                        swap(t[2][k],   t[2][is[k]]);
                }
        }
return   t;

}

template   <class   T>
CMatrix <T>   CMatrix <T> ::CutMatrixBlock(int   startRow,int   endRow,int   startCol,int   endCol)//°&acute;&Otilde;&Otilde;iArray&Ccedil;&ETH;&micro;&ocirc;&Ouml;&cedil;&para;¨&ETH;&ETH;
{
CMatrix <T>   t(endRow-startRow+1,endCol-startCol+1);



for   (int   i=0;i <endRow   -   startRow+1;i++)
{
for   (int   j=0;j <endCol   -   startCol+1;j++)
{
t[i][j]   =   p[startRow+i][startCol+j];
}

}

return   t;
}
上面的类是小弟在用vc修改matlab程序时写的矩阵类(参考了网上的),发现效率很差,大约和matlab相差5倍的时间,考虑有可能是该类的问题,问问大家该怎么优化?谢谢!

[解决办法]
http://community.csdn.net/Expert/topic/5265/5265271.xml?temp=.7587549
去找现成的来用,不要自己写。
[解决办法]
矩阵相乘是一个有优化余地的东西, 有Strassen快速矩阵乘法,你去搜搜看

热点排行