diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1396,8 +1396,7 @@ /// Iters: ({Par(), Par(), Par(), Red(), Red()}) /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} /// ``` -/// w and kw are unrolled. -/// TODO: do not unroll w (resp. kw) when the strideW ( resp. dilationW) is > 1. +/// kw is unrolled, w is unrolled iff strideW > 1. struct Conv1D_NWC_WCF_Generator : public StructuredGenerator { Conv1D_NWC_WCF_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW, int dilationW) @@ -1455,52 +1454,58 @@ vector::TransferWriteOp write; Value zero = builder.create(loc, 0); + // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. + // When strideW == 1, we can batch the contiguous loads and avoid unrolling int64_t wSizeStep = strideW == 1 ? wSize : 1; + VectorType lhsType = VectorType::get({nSize, wSizeStep, cSize}, + lhsShapedType.getElementType()); + VectorType rhsType = + VectorType::get({cSize, fSize}, rhsShapedType.getElementType()); + VectorType resType = VectorType::get({nSize, wSizeStep, fSize}, + resShapedType.getElementType()); + + SmallVector lhsVals, rhsVals, resVals; // Unroll along kw and read slices of lhs and rhs. // Alternatively we could preload both 3-d slices and extract smaller slices // iteratively without touching memory. But this will quickly spill. for (int64_t kw = 0; kw < kwSize; ++kw) { // Read rhs slice of size {c, f} @ [kw, 0, 0]. 
Value kwVal = builder.create(loc, kw); - VectorType rhsType = - VectorType::get({cSize, fSize}, rhsShapedType.getElementType()); - Value rhs = builder.create( - loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero}); + rhsVals.push_back(builder.create( + loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero})); for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) { // Read lhs slice of size {n, wSizeStep, c} // @ [0, sw * w + dw * kw, 0]. Value lhsStridedIdx = builder.create( loc, strideW * w_iv + dilationW * kw); - VectorType lhsType = VectorType::get({nSize, wSizeStep, cSize}, - lhsShapedType.getElementType()); - Value lhs = builder.create( - loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero}); + lhsVals.push_back(builder.create( + loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero})); // Read res slice: {n, wSizeStep, f} @ [0, w, 0]. Value wVal = builder.create(loc, w_iv); - VectorType resType = VectorType::get({nSize, wSizeStep, fSize}, - resShapedType.getElementType()); // When operating on tensors, reading from the updated value is required // for vector.transfer_read/write hoisting to function as expected. 
- Value res = builder.create( - loc, resType, resShaped, ValueRange{zero, wVal, zero}); - + resVals.push_back(builder.create( + loc, resType, resShaped, ValueRange{zero, wVal, zero})); + } + } + for (int64_t kw = 0; kw < kwSize; ++kw) { + for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) { // Compute contraction: I{n, w, c} * F{c, f} -> O{n, w, f} - StringRef par = Par().strRef, red = Red().strRef; - AffineExpr n, w, f, c; - bindDims(ctx, n, w, f, c); - // clang-format off - res = builder.create( - loc, lhs, rhs, res, - /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}}, - /*iteratorTypes=*/ArrayRef{par, par, par, red}); - // clang-format on - + resVals[kw * (wSize / wSizeStep) + w_iv] = conv1dSliceAsContraction( + builder, loc, lhsVals[kw * (wSize / wSizeStep) + w_iv], rhsVals[kw], + resVals[kw * (wSize / wSizeStep) + w_iv]); + } + } + for (int64_t kw = 0; kw < kwSize; ++kw) { + for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) { + Value wVal = builder.create(loc, w_iv); // Write back res slice: {n, wSizeStep, f} @ [0, w, 0]. 
write = builder.create( - loc, res, resShaped, ValueRange{zero, wVal, zero}); + loc, resVals[kw * (wSize / wSizeStep) + w_iv], resShaped, + ValueRange{zero, wVal, zero}); if (write.getNumResults() == 1) resShaped = write->getResult(0); } @@ -1509,6 +1514,19 @@ return write.getOperation(); } + // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f} + vector::ContractionOp conv1dSliceAsContraction(OpBuilder &b, Location loc, + Value lhs, Value rhs, + Value res) { + StringRef par = Par().strRef, red = Red().strRef; + AffineExpr n, w, f, c; + bindDims(ctx, n, w, f, c); + return builder.create( + loc, lhs, rhs, res, + /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}}, + /*iteratorTypes=*/ArrayRef{par, par, par, red}); + } + /// Entry point that transposes into the common form: /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} FailureOr generateConv() { diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir --- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir @@ -24,21 +24,26 @@ // CHECK: %[[V_FILTER:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] // CHECK: %[[V_INPUT0:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] // CHECK: %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +/// w == 1, kw == 0 +// CHECK: %[[V_INPUT3:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]] +// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]] // CHECK: %[[CONTRACT0:.+]] = vector.contract { + +/// w == 0, kw == 0 // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT0]], %[[V_FILTER]], %[[V_OUTPUT_0]] // CHECK-SAME: : vector<4x1x3xf32>, 
vector<3x8xf32> into vector<4x1x8xf32> -// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] - /// w == 1, kw == 0 -// CHECK: %[[V_INPUT3:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]] -// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]] // CHECK: %[[CONTRACT1:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT3]], %[[V_FILTER]], %[[V_OUTPUT_1]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> + +/// w == 0, kw == 0 +// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] +/// w == 1, kw == 0 // CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]] // ----- @@ -69,48 +74,53 @@ // CHECK: %[[V_FILTER_A:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] // CHECK: %[[V_INPUT0_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] // CHECK: %[[V_OUTPUT_0_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +/// w == 0, kw == 1 +// CHECK: %[[V_INPUT3_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]] +// CHECK: %[[V_OUTPUT_1_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]] +/// w == 1, kw == 0 +// CHECK: %[[V_FILTER_B:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]] +// CHECK: %[[V_INPUT0_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]] +// CHECK: %[[V_OUTPUT_0_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] +/// w == 1, kw == 1 +// CHECK: %[[V_INPUT3_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C5]], %[[C0]]], %[[F0]] +// CHECK: %[[V_OUTPUT_1_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], 
%[[C1]], %[[C0]]], %[[F0]] + +/// w == 0, kw == 0 // CHECK: %[[CONTRACT0_A:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT0_A]], %[[V_FILTER_A]], %[[V_OUTPUT_0_A]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> -// CHECK: vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] - /// w == 0, kw == 1 -// CHECK: %[[V_INPUT3_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]] -// CHECK: %[[V_OUTPUT_1_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]] // CHECK: %[[CONTRACT1_A:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT3_A]], %[[V_FILTER_A]], %[[V_OUTPUT_1_A]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> -// CHECK: vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]] - /// w == 1, kw == 0 -// CHECK: %[[V_FILTER_B:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]] -// CHECK: %[[V_INPUT0_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]] -// CHECK: %[[V_OUTPUT_0_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]] // CHECK: %[[CONTRACT0_B:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT0_B]], %[[V_FILTER_B]], %[[V_OUTPUT_0_B]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> -// CHECK: vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] - /// w == 1, kw == 1 -// CHECK: %[[V_INPUT3_B:.+]] = 
vector.transfer_read %[[INPUT]][%[[C0]], %[[C5]], %[[C0]]], %[[F0]] -// CHECK: %[[V_OUTPUT_1_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]] // CHECK: %[[CONTRACT1_B:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT3_B]], %[[V_FILTER_B]], %[[V_OUTPUT_1_B]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> + +/// w == 0, kw == 0 +// CHECK: vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] +/// w == 0, kw == 1 +// CHECK: vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]] +/// w == 1, kw == 0 +// CHECK: vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] +/// w == 1, kw == 1 // CHECK: vector.transfer_write %[[CONTRACT1_B]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]] // ----- - - // CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> // CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> @@ -127,22 +137,27 @@ // CHECK: %[[V_FILTER_000:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32> // CHECK: %[[V_INPUT_000:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32> // CHECK: %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32> +/// w == 0, kw == 1 +// CHECK: %[[V_FILTER_100:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32> +// CHECK: %[[V_INPUT_020:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32> +// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32> 
+ +/// w == 0, kw == 0 // CHECK: %[[CONTRACT0:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT_000]], %[[V_FILTER_000]], %[[V_OUTPUT_0]] // CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32> -// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] - /// w == 0, kw == 1 -// CHECK: %[[V_FILTER_100:.+]] = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<3x8xf32> -// CHECK: %[[V_INPUT_020:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x3xf32> -// CHECK: %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]{{.*}} vector<4x2x8xf32> // CHECK: %[[CONTRACT1:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]} // CHECK-SAME: %[[V_INPUT_020]], %[[V_FILTER_100]], %[[V_OUTPUT_1]] // CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32> + +/// w == 0, kw == 0 +// CHECK: vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] +/// w == 0, kw == 1 // CHECK: vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] linalg.conv_1d_nwc_wcf {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}