diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -53,7 +53,8 @@ /// Return the unique instance of OpType in `block` if it is indeed unique. /// Return null if none or more than 1 instances exist. -template static OpType getSingleOpOfType(Block &block) { +template +static OpType getSingleOpOfType(Block &block) { OpType res; block.walk([&](OpType op) { if (res) { @@ -1504,7 +1505,7 @@ /// Iters: ({Par(), Par(), Par(), Red(), Red()}) /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} /// ``` - /// w and kw are unrolled. + /// kw is always unrolled. /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1. FailureOr conv() { if (!valid) @@ -1519,47 +1520,50 @@ vector::TransferWriteOp write; Value zero = builder.create(loc, 0); + int64_t wSizeStep = strideW == 1 ? wSize : 1; + // Unroll along kw and read slices of lhs and rhs. // Alternatively we could preload both 3-d slices and extract smaller slices // iteratively without touching memory. But this will quickly spill. for (int64_t kw = 0; kw < kwSize; ++kw) { - // Read rhs slice of size {1, c, f} @ [kw, 0, 0]. + // Read rhs slice of size {c, f} @ [kw, 0, 0]. Value kwVal = builder.create(loc, kw); VectorType rhsType = - VectorType::get({1, cSize, fSize}, rhsShapedType.getElementType()); + VectorType::get({cSize, fSize}, rhsShapedType.getElementType()); Value rhs = builder.create( loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero}); - for (int64_t w = 0; w < wSize; ++w) { - // Read lhs slice of size {n, 1, c} @ [0, sw * w + dw * kw, 0]. + for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) { + // Read lhs slice of size {n, wSizeStep, c} + // @ [0, sw * w + dw * kw, 0]. Value lhsStridedIdx = builder.create( - loc, strideW * w + dilationW * kw); - VectorType lhsType = - VectorType::get({nSize, 1, cSize}, lhsShapedType.getElementType()); + loc, strideW * w_iv + dilationW * kw); + VectorType lhsType = VectorType::get({nSize, wSizeStep, cSize}, + lhsShapedType.getElementType()); Value lhs = builder.create( loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero}); - // Read res slice: {n, 1, f} @ [0, w, 0]. - Value wVal = builder.create(loc, w); - VectorType resType = - VectorType::get({nSize, 1, fSize}, resShapedType.getElementType()); + // Read res slice: {n, wSizeStep, f} @ [0, w, 0]. + Value wVal = builder.create(loc, w_iv); + VectorType resType = VectorType::get({nSize, wSizeStep, fSize}, + resShapedType.getElementType()); // When operating on tensors, reading from the updated value is required // for vector.transfer_read/write hoisting to function as expected. Value res = builder.create( loc, resType, resShaped, ValueRange{zero, wVal, zero}); - // Compute contraction: I{n, 1, c} * F{1, c, f} -> O{n, 1, f} + // Compute contraction: I{n, w, c} * F{c, f} -> O{n, w, f} StringRef par = Par().strRef, red = Red().strRef; - AffineExpr n, one, f, c; - bindDims(ctx, n, one, f, c); + AffineExpr n, w, f, c; + bindDims(ctx, n, w, f, c); // clang-format off res = builder.create( loc, lhs, rhs, res, - /*indexingMaps=*/MapList{{n, one, c}, {one, c, f}, {n, one, f}}, + /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}}, /*iteratorTypes=*/ArrayRef{par, par, par, red}); // clang-format on - // Write back res slice: {n, 1, f} @ [0, w, 0]. + // Write back res slice: {n, wSizeStep, f} @ [0, w, 0]. write = builder.create( loc, res, resShaped, ValueRange{zero, wVal, zero}); if (write.getNumResults() == 1) diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir --- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir @@ -9,7 +9,7 @@ } // CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3, d2)> +// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> // CHECK: func @conv1d_nwc_4x2x8_memref @@ -28,7 +28,7 @@ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT0]], %[[V_FILTER]], %[[V_OUTPUT_0]] -// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32> +// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> // CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] /// w == 1, kw == 0 @@ -38,7 +38,7 @@ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT3]], %[[V_FILTER]], %[[V_OUTPUT_1]] -// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32> +// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> // CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]] // ----- @@ -52,7 +52,7 @@ } // CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3, d2)> +// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> // CHECK: func @conv1d_nwc_4x2x8_memref @@ -73,7 +73,7 @@ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT0_A]], %[[V_FILTER_A]], %[[V_OUTPUT_0_A]] -// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32> +// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> // CHECK: vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] /// w == 0, kw == 1 @@ -83,7 +83,7 @@ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT3_A]], %[[V_FILTER_A]], %[[V_OUTPUT_1_A]] -// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32> +// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> // CHECK: vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]] /// w == 1, kw == 0 @@ -94,7 +94,7 @@ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT0_B]], %[[V_FILTER_B]], %[[V_OUTPUT_0_B]] -// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32> +// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> // CHECK: vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] /// w == 1, kw == 1 @@ -104,5 +104,49 @@ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT3_B]], %[[V_FILTER_B]], %[[V_OUTPUT_1_B]] -// CHECK-SAME: : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32> +// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> // CHECK: vector.transfer_write %[[CONTRACT1_B]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]] + +// ----- + + + +// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK: func @conv1d_nwc_4x2x8_memref +// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>) +func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) { +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 + +/// w == 0, kw == 0 +// CHECK: %[[V_FILTER_000:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32> +// CHECK: %[[V_INPUT_000:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32> +// CHECK: %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32> +// CHECK: %[[CONTRACT0:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} +// CHECK-SAME: %[[V_INPUT_000]], %[[V_FILTER_000]], %[[V_OUTPUT_0]] +// CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32> +// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + +/// w == 0, kw == 1 +// CHECK: %[[V_FILTER_100:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32> +// CHECK: %[[V_INPUT_020:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32> +// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32> +// CHECK: %[[CONTRACT1:.+]] = vector.contract { +// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} +// CHECK-SAME: %[[V_INPUT_020]], %[[V_FILTER_100]], %[[V_OUTPUT_1]] +// CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32> +// CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + linalg.conv_1d_nwc_wcf + {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>) + outs(%output : memref<4x2x8xf32>) + return +}