컴퓨터/PS

[PS] 백준 11049 행렬 곱셈 순서/최대한 자세히 설명해보기 - 자바 풀이

도도새 도 2023. 12. 13. 14:25

행렬 곱셈 순서

 

행렬 곱셈 순서에 대한 문제이다. 행렬은 곱 연산시 순서에 상관없이 같은 결과값을 보장하나, 곱셈 순서에 따라 곱셈 연산 횟수가 달라지게 된다. 해당 연산을 최소화 했을 때 몇 번인지 구하는 문제이다.

행렬의 곱셈 순서를 정하는 것은 대표적인 DP알고리즘 사용 사례라고 한다. 이 문제를 혼자 풀기에 실패하였기에 풀이를 보고 이해한 바를 정리한다. 접근 과정을 최대한 자세히 기술하였다.


문제

크기가 N×M인 행렬 A와 M×K인 B를 곱할 때 필요한 곱셈 연산의 수는 총 N×M×K번이다. 행렬 N개를 곱하는데 필요한 곱셈 연산의 수는 행렬을 곱하는 순서에 따라 달라지게 된다.

예를 들어, A의 크기가 5×3이고, B의 크기가 3×2, C의 크기가 2×6인 경우에 행렬의 곱 ABC를 구하는 경우를 생각해보자.

  • AB를 먼저 곱하고 C를 곱하는 경우 (AB)C에 필요한 곱셈 연산의 수는 5×3×2 + 5×2×6 = 30 + 60 = 90번이다.
  • BC를 먼저 곱하고 A를 곱하는 경우 A(BC)에 필요한 곱셈 연산의 수는 3×2×6 + 5×3×6 = 36 + 90 = 126번이다.

같은 곱셈이지만, 곱셈을 하는 순서에 따라서 곱셈 연산의 수가 달라진다.

행렬 N개의 크기가 주어졌을 때, 모든 행렬을 곱하는데 필요한 곱셈 연산 횟수의 최솟값을 구하는 프로그램을 작성하시오. 입력으로 주어진 행렬의 순서를 바꾸면 안 된다.

 

입력

첫째 줄에 행렬의 개수 N(1 ≤ N ≤ 500)이 주어진다.

둘째 줄부터 N개 줄에는 행렬의 크기 r과 c가 주어진다. (1 ≤ r, c ≤ 500)

항상 순서대로 곱셈을 할 수 있는 크기만 입력으로 주어진다.

 

출력

첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 231-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 231-1보다 작거나 같다.

 

예제 입력 1 

3
5 3
3 2
2 6

 

예제 출력 1

90

 

풀이

 

Metric 배열

우선 기본적인 발상은 아래와 같다. 매트릭스 A B C가 주어질 때,

metric[inx][0]은 inx 매트릭스의 행 수이다.

metric[inx][1]은 inx 매트릭스의 열 수이다.

 

이를테면, 첫 번째인 5*3 행렬의 경우 아래와 같이 나타낼 수 있다. (인덱스는 0부터 시작하도록 한다.)

metric[0][0] = 5;

metric[0][1] = 3;

 

DP 배열

다음으로 dp 배열을 정의한다. dp 배열은 특정 범위의 최소값을 뜻한다. 예를 들어

dp[0][2]의 경우 0번 매트릭스부터 2번 매트릭스까지의 곱셈 연산의 최소값이다.

 

dp[i][j] // i번째 행렬에서 j번째 행렬까지의 곱 연산시 최소 횟수 

 

예제 입력을 보면 각 매트릭스는 아래와 같다.

metric = [{5, 3}, {3, 2}, {2, 6}};

각 매트릭스는 A B C라는 별칭을 부여하자.

 

접근

 

A B C의 경우(3개의 매트릭스)

접근 발상은 범위 별로 계산을 해서 적당히 끼워 맞추는 것이다. 이를테면 ABC의 경우를 보자. 이는 두 가지 케이스로 이루어진다.

Case 1: (A*B)*C //A와 B를 먼저 곱하고 C를 곱하는 것이다.

Case 2: A*(B*C) //B와 C를 먼저 곱하고 C를 곱하는 것이다.

이 결과 dp[0][2]은 Case1과 Case2중 최소값을 선택한다.

 

A B C D의 경우(4개의 매트릭스)

만약 매트릭스가 A B C D로 주어진다면?

Case 1 :(A * B * C) * D //A와 B와 C를 곱한 최소 값에 D를 곱하는 것이다.

Sub Case 1 : (A * B) * C // (A * B * C)의 최소를 찾기 위한 서브 케이스

Sub Case 2: A * (B * C) // (A * B * C)의 최소를 찾기 위하 서브 케이스

즉, Case 1은 min(Sub Case1, Sub Case2)에 그 결과 배열과 D를 곱한 연산의 수가 된다.

 

Case 2 :(A * B) * (C * D) 

Sub Case 1 :A * B

Sub Case 2: C * D

즉, Case 2은 Sub Case1, Sub Case2를 곱한 연산의 수가 된다.

 

Case3 : A*(B * C * D)

Sub Case 1 : (B * C) * D //(B * C * D)의 최소를 찾기 위한 서브 케이스

Sub Case 1 : B * (C * D) //(B * C * D)의 최소를 찾기 위한 서브 케이스

즉, Case 3은 A매트릭스에 min(Sub Case1, Sub Case2)를 곱한 연산의 수가 된다.

 

이 결과 dp[0][3]은 Case1과 Case2, Case3중 최소값을 선택한다.

 

이전에 했던 연산을 이후의 연산에서 지속적으로 사용하게 된다. 바텀 업 방식으로 구현 할 수 있겠다. 또한 큰 경우의 수는 n-1개만큼 나오는 것을 확인 할 수 있다.

 

그렇다면 어떻게 케이스를 나눌까? 그것은 범위를 기준으로 나누는 것이다.


매트릭스 A B C

metric = [{5, 3}, {3, 2}, {2, 6}}; 를 기준으로 보자.

 

범위 0인 경우

앞서 dp[i][j]는 i부터 j까지의 행렬 곱 연산의 횟수의 최소값이라고 정의하였다.

dp의 범위(ex: dp[0][1]은 범위가 1이다) 0이라면 값은 0이다. 자기 자신뿐이기 때문이다. 이를테면 dp[1][1]은 매트릭스 1에서 1번까지의 곱셈의 수를 의미하며 이는 0이다.

 

범위 1인 경우

그렇담 범위가 1이라면? 행렬 A B C에서는 두 가지 케이스가 존재하게 된다.

dp[0][1] // A * B

이는 매트릭스 A와 매트릭스 B의 곱의 형태이다.

dp[0][1] = {5, 3} * {3, 2} = 5 * 3 * 2 = 30

 

dp[1][2] // B * C

이는 매트릭스 B와 매트릭스 C의 곱의 형태이다.

dp[1][2] = {3, 2} * {2, 6} = 3 * 2 * 6 = 36

 

범위 2인 경우

범위가 2라면 여기서는 전범위를 나타내게 된다. 즉 정답을 나타낸다.

dp[0][2] : 0에서 2번까지 곱셈 연산의 최소값

이는 아래 두 형태 중 최소값을 나타낸다.

(A * B) * C// 케이스 1번

A * (B * C) // 케이스 2번

 

케이스 1번

A * B의 결과 배열은 {5, 3} * {3, 2} = {5, 2}의 형태이다.// 매트릭스 결과는 앞 매트릭스의 행과 뒷 매트릭스의 열의 형태

 

그렇담 곱셈 연산 횟수는

(A * B의 연산수) + ((A * B 결과 행렬) * C행렬의 연산 수)이다. 즉,

30+ (5* 2 * 6)= 90이 된다.

 

케이스 2번

B * C의 결과 배열은 {3, 2} * {2, 6} = {3, 6}의 형태이다.

그렇담 곱셈 연산 횟수는

(A행렬 * (B * C 결과 행렬)의 연산 수) * (B * C의 연산 수) 이다. 즉,

(5* 3 * 6) + 36 = 126이 된다.

 

dp[0][2]의 값은 위 케이스 1과 2 중 최소값이 90이 되며, 이는 테스트 케이스의 정답이 된다.

 

표로 나타내기

이를 표로 나타내면 아래와 같다.

아래 표는 dp배열을 나타낸다. 예를 들어 0행 1열은 dp[0][1]을 나타낸다. 즉 0번부터 1번까지의 최소 곱연산 횟수이다.

DP표
index 0 1 2
0 0 30 90
1   0 36
2     0

위 표를 보면 알다시피 계산은 왼쪽 위부터 오른쪽 아래로 대각선 형태로 진행된다. 다음 범위로 나아가기 위해서는 이전에 범위의 모든 연산(대각 범위)가 모두 필요한 것이다.

 

즉, 위 표에서 대각선은 같은 범위를 나타내게 된다. 즉, 각 범위의 최소값을 채워나가며 범위를 점점 넓혀나가는 식으로 문제를 풀게 된다.

 

이를 일반화 하면 아래 식을 얻을 수 있다.(이 일반화해서 식을 찾는 것을 내가 해결하지 못했었다.)

for 1 ≤ i ≤ j ≤ n,

 if(i = j): dp[i][j] = 0

 if(i < j): dp[i][j] = Minumum(dp[i][k] + dp[k+1][j] + (metric[i][0] * metric[k][1] * metric[j][1])(단, i ≤ k ≤ j-1)

*dp[i][k] + dp[k+1][j]는 피벗 k를 기준으로 앞 행렬과 뒤 행렬의 경우의 수를 더한 것이다.

*metric[i][0] * metric[k][1] * metric[j][1]는 피벗 k를 기준으로 앞 행렬과 뒷 행렬를 곱하는 경우의 수이다. 즉 두 행렬을 이을 때 나오는 곱의 횟수이다.

 

이 식에서 i는 시작, j는 종료 위치, k는 나누는 위치를 나타낸다.

(A * B) * C에서 i는 0, j는 2, k는 1이다.

 


 

식을 사용한 접근

이 식을 이용해 범위 1부터 범위 2까지 접근해보자.

metric = [{5, 3}, {3, 2}, {2, 6}};

 

범위 1 :

i = 0, j = 1, k = 0

dp[0][1] = dp[0][0] + dp[1][1] + (metric[0][0] * metric[0][1] * metric[1][1])

= 0 + 0 + (5 * 3 * 2) = 30

 

i = 1, j = 2, k = 1

dp[1][2] = dp[1][1] + dp[2][2] + (metric[1][0] * metric[1][1] * metric[2][1])

= 0 + 0 + (3 * 2 * 6) = 36

 

범위 2:

i = 0, j = 2, k = 0

dp[0][2] = dp[0][0] + dp[1][2] + (metric[0][0] * metric[0][1] * metric[2][1])

= 0 + 36 + (5 * 3 * 6) = 126

 

i = 0 , j =2, k = 1

dp[0][2] = dp[0][1] + dp[2][2] + (metric[0][0] * metric[1][1] * metric[2][1])

= 30 + 0 + (5 * 2 * 6) = 90

dp[0][2]는 이 중 최소값인 90이 된다.

 

정답 코드

더보기
import java.util.*;
import java.math.*;

public class Main {
	static int INF = Integer.MAX_VALUE;
	public static void main(String[] args) {
		Scanner sc = new Scanner(System.in);
		int metric[][] = new int[501][2];
		int n = sc.nextInt();
		
		int dp[][] = new int[501][501];
		
		for(int i = 0; i < n; i++) {
			int r = sc.nextInt();
			int c = sc.nextInt();
			metric[i][0] = r;
			metric[i][1] = c;
		}
		
		for(int gap = 1; gap < n; gap++) { //범위(0은 모두 0으로 초기화하기에 1부터 시작)
			for(int i = 0; i < n-gap; i++) {//metric number
				dp[i][i+gap] = INF;
				for(int pivot = i; pivot < i+gap; pivot++) {//(시작범위 ~ 끝범위까지)
					int val = dp[i][pivot] + dp[pivot+1][i+gap] + (metric[i][0] * metric[pivot][1] * metric[i+gap][1]);
					dp[i][i+gap] = Math.min(dp[i][i+gap], val);
					
				}
			}
		}
		System.out.println(dp[0][n-1]);
	}

}

 


내가 이해하는 데 약간 시간이 걸린 문제기에 풀이를 최대한 자세히 작성해 보았다. 틀린 설명을 발견할 시 댓글로 정정해 주기를 부탁드립니다.

* 아래 강의가 아주 큰 도움이 되었습니다. 이해가 안 된다면 꼭 한 번 보시길 추천합니다.

ref:

https://www.youtube.com/watch?v=5MXOUix_Ud4&t=1659s