diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -20,52 +20,50 @@ ods_def: def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) { - O(w) = std_addf(O(w), std_mulf(I(w + kw), K(kw))); + O(w) = std_addf(std_mulf(I(w + kw), K(kw))); } ods_def: def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) { - O(n, w, f) = std_addf(O(n, w, f), - std_mulf(I(n, w + kw, c), K(f, kw, c))); + O(n, w, f) = std_addf(std_mulf(I(n, w + kw, c), K(f, kw, c))); } ods_def: def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) { - O(n, f, w) = std_addf(O(n, f, w), - std_mulf(I(n, c, w + kw), K(f, c, kw))); + O(n, f, w) = std_addf(std_mulf(I(n, c, w + kw), K(f, c, kw))); } ods_def: def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) { - O(h, w) = std_addf(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw))); + O(h, w) = std_addf(std_mulf(I(h + kh, w + kw), K(kh, kw))); } ods_def: def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) { - O(n, h, w, f) = std_addf(O(n, h, w, f), - std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c))); + O(n, h, w, f) = std_addf(std_mulf( + I(n, h + kh, w + kw, c), K(f, kh, kw, c))); } ods_def: def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) { - O(n, f, h, w) = std_addf(O(n, f, h, w), - std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw))); + O(n, f, h, w) = std_addf(std_mulf( + I(n, c, h + kh, w + kw), K(f, c, kh, kw))); } ods_def: def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) { - O(d, h, w) = std_addf(O(d, h, w), - std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw))); + O(d, h, w) = std_addf(std_mulf( + I(d + kd, h + kh, w + kw), K(kd, kh, kw))); } ods_def: def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) { - O(n, d, h, w, f) = std_addf(O(n, d, h, w, f), - std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c))); + O(n, d, h, w, f) = std_addf(std_mulf( + I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c))); } ods_def: def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) { - O(n, f, d, h, w) = std_addf(O(n, f, d, h, w), - std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw))); + O(n, f, d, h, w) = std_addf(std_mulf( + I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw))); } \ No newline at end of file diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -1318,14 +1318,15 @@ // CHECKPARALLEL: %[[c1:.*]] = constant 1 : index // CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref // CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref -// CHECKPARALLEL: scf.parallel (%[[b:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim1]], %[[dim0]]) step (%[[c1]], %[[c1]]) { -// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]]) -// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref -// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref -// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref -// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32 -// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 -// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[b]]] : memref +// CHECKPARALLEL: scf.parallel (%[[b:.*]]) = (%[[c0]]) to (%[[dim1]]) step (%[[c1]]) { +// CHECKPARALLEL: scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] { +// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]]) +// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref +// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref +// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref +// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32 +// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[b]]] : memref func @conv2d_no_symbols(%in : memref, %filter : memref, %out : memref) -> () { @@ -1367,15 +1368,17 @@ // CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref // CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref // CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref -// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]], %[[dim0]], %[[dim1]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]]) { -// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]]) -// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]]) -// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref -// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref -// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref -// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32 -// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 -// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref +// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]]) step (%[[c1]], %[[c1]]) { +// CHECKPARALLEL: scf.for %[[arg5:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] { +// CHECKPARALLEL: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] { +// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]]) +// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]]) +// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref +// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref +// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref +// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32 +// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref func @conv3d_no_symbols(%in : memref, %filter : memref, %out : memref) -> () { @@ -1427,13 +1430,16 @@ // CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref // CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref // CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref -// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]], %[[arg7:.*]], %[[arg8:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]], %[[dim0]], %[[dim1]], %[[dim2]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]]) { -// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]]) -// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]]) -// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]]) -// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref -// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref -// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref -// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32 -// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 -// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref +// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]]) step (%[[c1]], %[[c1]], %[[c1]]) { +// CHECKPARALLEL: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] { +// CHECKPARALLEL: scf.for %[[arg7:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] { +// CHECKPARALLEL: scf.for %[[arg8:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] { +// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]]) +// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]]) +// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]]) +// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref +// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref +// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref +// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32 +// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 +// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref