Strassen's matrix multiplication algorithm in Java






 



















import java.util.ArrayList;


public class work {

public static void main(String[] args) {

System.out.println("note that the order of both matrix have to be same \n and in the power of (2X2) order \n");

int a[][] = {

{5, 7, 9, 10},

{2, 3, 3, 8},

{8, 10, 2, 3},

{3, 3, 4, 8}

};

int b[][] = {

{3, 10, 12, 18},

{12, 1, 4, 9},

{9, 10, 12, 2},

{3, 12, 4, 10}

};

int [][]aa = {

{1, 2},

{3, 4}

};

int [][]ba = {

{2, 0},

{1, 2}

};

matrix m1 = new matrix(a);

matrix m2 = new matrix(b);

System.out.println("first matrix is here ");


display screen = new display(m1, a.length);

System.out.println("second matrix is here ");

screen = new display(m2, b.length);

multipication obj = new multipication(m1, m2);

matrix toprint = obj.getR();

System.out.println("result matrix is here ");


screen = new display(toprint,  a.length);


}

}

/**
 * A square integer matrix whose order is a power of two (at least 2), stored as a quadtree.
 *
 * <p>A leaf node is a 2x2 block: its entries live in {@code values} and {@code marr} is
 * {@code null}. An inner node holds its four quadrants in {@code marr} and {@code values} is
 * {@code null}. Quadrants are indexed as {@code marr[row][col]}: {@code [0][0]} top-left,
 * {@code [0][1]} top-right, {@code [1][0]} bottom-left, {@code [1][1]} bottom-right.
 */
class matrix {

    /** The four quadrants of an inner node, or {@code null} for a 2x2 leaf. */
    matrix[][] marr;

    /** The 2x2 entries of a leaf, or {@code null} for an inner node. */
    int[][] values;

    /**
     * Shallow copy constructor.
     *
     * <p>A leaf copy shares {@code obj}'s {@code values} array. An inner-node copy gets a new
     * 2x2 quadrant array that references the same four sub-matrices. If {@code obj} holds
     * neither, the copy is left empty as well.
     *
     * @param obj the matrix to copy
     */
    matrix(matrix obj) {
        if (obj.values != null) {
            values = obj.values;
        } else if (obj.marr != null) {
            marr = new matrix[][] {
                {obj.marr[0][0], obj.marr[0][1]},
                {obj.marr[1][0], obj.marr[1][1]}
            };
        }
    }

    /**
     * Builds an inner node from four quadrants; the references are stored, not copied.
     *
     * @param m1 top-left quadrant
     * @param m2 top-right quadrant
     * @param m3 bottom-left quadrant
     * @param m4 bottom-right quadrant
     */
    matrix(matrix m1, matrix m2, matrix m3, matrix m4) {
        marr = new matrix[][] {{m1, m2}, {m3, m4}};
    }

    /**
     * Builds a 2x2 leaf from its entries in row-major order.
     *
     * @param v1 entry (0, 0)
     * @param v2 entry (0, 1)
     * @param v3 entry (1, 0)
     * @param v4 entry (1, 1)
     */
    matrix(int v1, int v2, int v3, int v4) {
        values = new int[][] {{v1, v2}, {v3, v4}};
    }

    /**
     * Builds a quadtree from a dense square array; the input array is copied, not retained.
     *
     * @param arr a square array whose order is a power of two and at least 2
     * @throws IllegalArgumentException if {@code arr} is not square, or its order is not a power
     *     of two that is at least 2. Such inputs previously caused infinite recursion or silently
     *     wrong quadrants.
     */
    matrix(int[][] arr) {
        checkShape(arr);
        int order = arr.length;
        if (order == 2) {
            values = new int[][] {{arr[0][0], arr[0][1]}, {arr[1][0], arr[1][1]}};
            return;
        }
        int half = order / 2;
        marr = new matrix[2][2];
        for (int qr = 0; qr < 2; qr++) {
            for (int qc = 0; qc < 2; qc++) {
                // Each quadrant gets its own fresh array, so no buffer is shared between them.
                marr[qr][qc] = new matrix(quadrant(arr, qr * half, qc * half, half));
            }
        }
    }

    /** Copies the {@code half} x {@code half} sub-array whose top-left corner is (row0, col0). */
    private static int[][] quadrant(int[][] arr, int row0, int col0, int half) {
        int[][] block = new int[half][half];
        for (int r = 0; r < half; r++) {
            System.arraycopy(arr[row0 + r], col0, block[r], 0, half);
        }
        return block;
    }

    /** Rejects arrays the quadtree cannot represent: non-square, or order not a power of two >= 2. */
    private static void checkShape(int[][] arr) {
        int order = arr.length;
        // A power of two has exactly one bit set, so clearing the lowest set bit leaves zero.
        if (order < 2 || (order & (order - 1)) != 0) {
            throw new IllegalArgumentException(
                    "matrix order must be a power of two and at least 2, got " + order);
        }
        for (int r = 0; r < order; r++) {
            if (arr[r] == null || arr[r].length != order) {
                throw new IllegalArgumentException(
                        "matrix must be square: row " + r + " does not have " + order + " entries");
            }
        }
    }
}

/**
 * Prints a {@link matrix} to standard output as a grid, one matrix row per line.
 *
 * <p>All work happens in the constructor. The matrix is rendered into {@link #lines}, the lines
 * are printed, and then a blank line is printed. Each entry is preceded by one space, so a row
 * renders as e.g. {@code " 5 7 9 10"}.
 */
class display {

    /** Rendered rows, top to bottom, printed by {@link #show()}. */
    ArrayList<String> lines = new ArrayList<>();

    /** Order (rows = columns) of the matrix being displayed, as supplied by the caller. */
    private int size = 0;

    /**
     * Renders and prints {@code obj}, followed by an empty line.
     *
     * @param obj the matrix to print
     * @param s   the order of {@code obj}; must match the matrix's actual order
     */
    display(matrix obj, int s) {
        this.size = s;
        print(obj);
        show();
        System.out.println();
    }

    /**
     * Appends one line per row of {@code obj} to {@link #lines}.
     *
     * <p>Each entry is looked up by its (row, column) position in the quadtree. Rows therefore
     * come out in the correct order for every matrix order. The previous depth-first leaf walk
     * interleaved quadrants incorrectly for orders of 8 and above. A matrix with neither leaf
     * values nor quadrants adds nothing.
     *
     * @param obj the matrix to render; its order is taken to be {@code size}
     */
    void print(matrix obj) {
        if (obj.values == null && obj.marr == null) {
            return;
        }
        for (int r = 0; r < size; r++) {
            StringBuilder row = new StringBuilder();
            for (int c = 0; c < size; c++) {
                row.append(' ').append(entryAt(obj, size, r, c));
            }
            lines.add(row.toString());
        }
    }

    /**
     * Returns the entry at ({@code row}, {@code col}) of {@code node}, which represents an
     * {@code order} x {@code order} block, by descending into the quadrant containing it.
     */
    private static int entryAt(matrix node, int order, int row, int col) {
        matrix current = node;
        int dim = order;
        int r = row;
        int c = col;
        while (current.values == null) {
            int half = dim / 2;
            current = current.marr[r / half][c / half];
            r %= half;
            c %= half;
            dim = half;
        }
        return current.values[r][c];
    }

    /** Prints every rendered line to standard output. */
    void show() {
        for (String line : lines) {
            System.out.println(line);
        }
    }
}

class multipication {

matrix newResult;

multipication(matrix x, matrix y) {

if (x.values != null || x.marr != null) {

newResult = mult(x, y);


}


}

matrix getR() {

return newResult;

}


matrix mult(matrix x, matrix y) {


if (x.values != null && y.values != null) {

int m1, m2, m3, m4, m5, m6, m7, i, j, k, l;

int  a = x.values[0][0];

int  b = x.values[0][1];

int  c = x.values[1][0];

int  d = x.values[1][1];

int  e = y.values[0][0];

int  f = y.values[0][1];

int  g = y.values[1][0];

int  h = y.values[1][1];


m1 = (a + c) * (e + f);

m2 = (b + d) * (g + h);

m3 = (a - d) * (e + h);

m4 = (a) * (f - h);

m5 = (c + d) * (e);

m6 = (a + b) * (h);

m7 = (d) * (g - e);




i = m2 + m3 - m6 - m7;

j = m4 + m6;

k = m5 + m7;

l = m1 - m3 - m4 - m5;

matrix result = new matrix(new int[][] {{i, j}, { k, l}});


return result;



}


matrix m1 = null, m2 = null, m3 = null, m4 = null, m5 = null, m6 = null, m7 = null, i = null, j = null, k = null, l = null;

matrix a = x.marr[0][0];

matrix  b = x.marr[0][1];

matrix c = x.marr[1][0];

matrix d = x.marr[1][1];

matrix e = y.marr[0][0];

matrix f = y.marr[0][1];

matrix g = y.marr[1][0];

matrix  h = y.marr[1][1];



m1 = mult(add(a, c), add(e, f));


m2 = mult(add(b, d), add(g, h));


m3 = mult(sub(a, d), add(e, h));


m4 = mult(a, sub(f, h));


m5 = mult(add(c, d), e);


m6 = mult(add(a, b), h);


m7 = mult(d, sub(g, e));


i = add(sub(m2, m6), sub(m3, m7));

j = add(m4, m6);

k = add(m5, m7);

l = sub(sub(sub(m1, m3), m4), m5);


matrix result = new matrix(i, j, k, l);


return result;



}



matrix add(matrix x, matrix y) {


if (x.values != null && y.values != null) {



int a = x.values[0][0] + y.values[0][0];



int b = x.values[0][1] + y.values[0][1];


int c = x.values[1][0] + y.values[1][0];



int d = x.values[1][1] + y.values[1][1];



matrix result = new matrix(new int[][] {{a, b}, {c, d}});


return result;

}


matrix a = add(x.marr[0][0], y.marr[0][0]);

matrix b = add(x.marr[0][1], y.marr[0][1]);

matrix c = add(x.marr[1][0], y.marr[1][0]);

matrix d = add(x.marr[1][1], y.marr[1][1]);


matrix result = new matrix(a, b, c, d);


return result;



}


matrix sub(matrix x, matrix y) {

if (x.values != null && y.values != null) {

int a = x.values[0][0] - y.values[0][0];

int b = x.values[0][1] - y.values[0][1];

int c = x.values[1][0] - y.values[1][0];

int d = x.values[1][1] - y.values[1][1];


matrix result = new matrix(new int[][] {{a, b}, {c, d}});


return result;

}


matrix a = sub(x.marr[0][0], y.marr[0][0]);

matrix b = sub(x.marr[0][1], y.marr[0][1]);

matrix c = sub(x.marr[1][0], y.marr[1][0]);

matrix d = sub(x.marr[1][1], y.marr[1][1]);


matrix result = new matrix(a, b, c, d);

return result;




}

}


Comments

  1. Purushottam Singram, 14 October 2023 at 10:11

    I can't believe that I wrote this: very sophisticated, very elegant 🫡

    ReplyDelete

Post a Comment