diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1390,16 +1390,25 @@
 // Convolution vectorization patterns
 //===----------------------------------------------------------------------===//
 namespace {
-/// Generate a vector implementation for:
+/// Generate a vector implementation for either:
 /// ```
 ///   Op def: (     n,     w,     c,    kw,    f  )
 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
 /// ```
 /// kw is unrolled, w is unrolled iff dilationW > 1.
-struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
-  Conv1D_NWC_WCF_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
-                           int dilationW)
+///
+/// or
+///
+/// ```
+///   Op def: (     n,     w,     c,    kw )
+///    Iters: ({Par(), Par(), Par(), Red()})
+///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+/// ```
+/// kw is unrolled, w is unrolled iff dilationW > 1.
+struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
+  Conv1D_NWC_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
+                       int dilationW)
       : StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false),
         strideW(strideW), dilationW(dilationW) {
     // Determine whether `linalgOp` can be generated with this generator
@@ -1413,7 +1422,8 @@
     resShapedType = resShaped.getType().dyn_cast<ShapedType>();
     if (!lhsShapedType || !rhsShapedType || !resShapedType)
       return;
-    if (lhsShapedType.getRank() != 3 || rhsShapedType.getRank() != 3 ||
+    if (lhsShapedType.getRank() != 3 ||
+        (rhsShapedType.getRank() != 2 && rhsShapedType.getRank() != 3) ||
         resShapedType.getRank() != 3)
       return;
 
@@ -1553,12 +1563,130 @@
         /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
   }
 
+  /// Generate a vector implementation for:
+  /// ```
+  ///   Op def: (     n,     w,     c,    kw )
+  ///    Iters: ({Par(), Par(), Par(), Red()})
+  ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+  /// ```
+  /// kw is always unrolled.
+  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is > 1.
+  FailureOr<Operation *> dilated_conv() {
+    if (!valid)
+      return failure();
+
+    int nSize = lhsShapedType.getShape()[0];
+    int wSize = resShapedType.getShape()[1];
+    int cSize = lhsShapedType.getShape()[2];
+    int kwSize = rhsShapedType.getShape()[0];
+
+    vector::TransferWriteOp write;
+    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+
+    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
+    // When strideW == 1, we can batch the contiguous loads and avoid
+    // unrolling.
+    int64_t wSizeStep = strideW == 1 ? wSize : 1;
+
+    Type lhsEltType = lhsShapedType.getElementType();
+    Type rhsEltType = rhsShapedType.getElementType();
+    Type resEltType = resShapedType.getElementType();
+    VectorType lhsType = VectorType::get(
+        {nSize, (wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1,
+         cSize},
+        lhsEltType);
+    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+    VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+
+    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
+    Value lhs = builder.create<vector::TransferReadOp>(
+        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+    // Read rhs slice of size {kw, c} @ [0, 0].
+    Value rhs = builder.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                       ValueRange{zero, zero});
+    // Read res slice of size {n, w, c} @ [0, 0, 0].
+    Value res = builder.create<vector::TransferReadOp>(
+        loc, resType, resShaped, ValueRange{zero, zero, zero});
+
+    //===------------------------------------------------------------------===//
+    // Begin vector-only rewrite part
+    //===------------------------------------------------------------------===//
+    // Unroll along kw and read slices of lhs and rhs.
+    SmallVector<Value> lhsVals, rhsVals, resVals;
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      // Extract rhs slice of size {c} @ [kw].
+      rhsVals.push_back(builder.create<vector::ExtractOp>(
+          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        // Extract lhs slice of size {n, wSizeStep, c}
+        //   @ [0, sw * w + dw * kw, 0].
+        lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+            loc, lhs,
+            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
+            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+
+        // This does not depend on kw.
+        if (kw == 0) {
+          // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
+          resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+              loc, res,
+              /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+              /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+              /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+        }
+      }
+    }
+
+    auto linearIndex = [&](int64_t kw, int64_t w) {
+      return kw * (wSize / wSizeStep) + w;
+    };
+
+    // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        resVals[w] = dilatedConv1dSliceAsContraction(
+            builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
+            resVals[w]);
+      }
+    }
+
+    // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
+    // This does not depend on kw.
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      res = builder.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], res,
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
+    }
+    //===------------------------------------------------------------------===//
+    // End vector-only rewrite part
+    //===------------------------------------------------------------------===//
+
+    // Write back res slice of size {n, w, c} @ [0, 0, 0].
+    return builder
+        .create<vector::TransferWriteOp>(loc, res, resShaped,
+                                         ValueRange{zero, zero, zero})
+        .getOperation();
+  }
+
+  // Create a contraction: lhs{n, w, c} * rhs{c} -> res{n, w, c}
+  vector::ContractionOp dilatedConv1dSliceAsContraction(OpBuilder &b,
+                                                        Location loc,
+                                                        Value lhs, Value rhs,
+                                                        Value res) {
+    StringRef par = Par().strRef, red = Red().strRef;
+    AffineExpr n, w, c;
+    bindDims(ctx, n, w, c);
+    return builder.create<vector::ContractionOp>(
+        loc, lhs, rhs, res,
+        /*indexingMaps=*/MapList{{n, w, c}, {c}, {n, w, c}},
+        /*iteratorTypes=*/ArrayRef<StringRef>{par, par, red});
+  }
+
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
   FailureOr<Operation *> generateConv() {
     AffineExpr n, w, f, kw, c;
     bindDims(ctx, n, w, f, kw, c);
-
     if (!iters({Par(), Par(), Par(), Red(), Red()}))
       return failure();
 
@@ -1570,6 +1698,22 @@
       return failure();
   }
 
+  /// Entry point that transposes into the common form:
+  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+  FailureOr<Operation *> generateDilatedConv() {
+    AffineExpr n, w, c, kw;
+    bindDims(ctx, n, w, c, kw);
+    if (!iters({Par(), Par(), Par(), Red()}))
+      return failure();
+
+    // No transposition needed.
+    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
+                /*rhsIndex*/ {kw, c},
+                /*resIndex*/ {n, w, c}}))
+      return dilated_conv();
+    return failure();
+  }
+
 private:
   bool valid;
   int strideW, dilationW;
@@ -1588,8 +1732,11 @@
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
   LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
-  Conv1D_NWC_WCF_Generator e(b, linalgOp, stride, dilation);
-  return e.generateConv();
+  Conv1D_NWC_Generator e(b, linalgOp, stride, dilation);
+  auto res = e.generateConv();
+  if (succeeded(res))
+    return res;
+  return e.generateDilatedConv();
 }
 
 struct VectorizeConvolution
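
For intuition on the shapes involved, here is a minimal standalone sketch (the
helper name lhsVectorWidth is illustrative, not part of the patch) that replays
the lhs vector shape arithmetic from dilated_conv() on the depthwise test case
added below:

#include <cstdio>

// Mirrors the second dimension of lhsType in dilated_conv():
//   (wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1
// Note this is one element wider than the minimal width
// (wSize - 1) * strideW + (kwSize - 1) * dilationW + 1; the test input
// memref<3x5x4xf32> is sized to match the wider read.
constexpr int lhsVectorWidth(int wSize, int strideW, int kwSize,
                             int dilationW) {
  return (wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1;
}

int main() {
  // Test case below: wSize = 2, strideW = 1, kwSize = 2, dilationW = 2.
  static_assert(lhsVectorWidth(2, 1, 2, 2) == 5, "lhs read as vector<3x5x4>");
  std::printf("lhs width = %d\n", lhsVectorWidth(2, 1, 2, 2)); // prints 5
  return 0;
}
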
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -180,7 +180,7 @@
 // CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME:   %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
 // CHECK-SAME:   : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
-/// w == 1, kw == 1
+/// w == 0, kw == 1
 // CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
 // CHECK-SAME:     indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
 // CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
@@ -189,3 +189,52 @@
 
 // Write the result back in one shot.
 // CHECK:   vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+// -----
+
+func @depthwise_conv1d_nwc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
+  linalg.depthwise_conv1D_nw
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
+    outs(%output : memref<3x2x4xf32>)
+  return
+}
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+
+// CHECK: func @depthwise_conv1d_nwc_3x5x4_memref
+// CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
+
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// Read the whole data in one shot.
+// CHECK:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK:   %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+// CHECK:   %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+/// w == 0, kw == 0
+// CHECK:   %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
+// CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:   {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32>
+/// w == 0, kw == 1
+// CHECK:   %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
+// CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:   {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32>
+
+/// w == 0, kw == 0
+// CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
+// CHECK-SAME:     indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]],
+// CHECK-SAME:     iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME:   %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
+// CHECK-SAME:   : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32>
+/// w == 0, kw == 1
+// CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
+// CHECK-SAME:     indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]],
+// CHECK-SAME:     iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME:   %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]]
+// CHECK-SAME:   : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32>
+
+// Write the result back in one shot.
+// CHECK:   vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
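
To see how linearIndex() and the resVals[w] subscript stay consistent across
both unrolling regimes, here is a standalone sketch of the loop structure in
dilated_conv() (the function name replay is illustrative, not part of the
patch):

#include <cstdio>

// Replays the loop nests of dilated_conv(). w is batched
// (wSizeStep == wSize) iff strideW == 1, otherwise fully unrolled
// (wSizeStep == 1), so the resVals[w] subscript only ever sees w == 0 or
// consecutive w values, matching the order of the kw == 0 pushes.
static void replay(int wSize, int strideW, int kwSize, int dilationW) {
  int wSizeStep = strideW == 1 ? wSize : 1;
  std::printf("strideW=%d dilationW=%d -> wSizeStep=%d\n", strideW, dilationW,
              wSizeStep);
  for (int kw = 0; kw < kwSize; ++kw)
    for (int w = 0; w < wSize; w += wSizeStep)
      // lhsVals is filled kw-outer/w-inner, so linearIndex recovers the slot.
      std::printf("  kw=%d w=%d: lhsVals[%d] (lhs offset %d) -> resVals[%d]\n",
                  kw, w, kw * (wSize / wSizeStep) + w,
                  w * strideW + kw * dilationW, w);
}

int main() {
  replay(/*wSize=*/2, /*strideW=*/1, /*kwSize=*/2, /*dilationW=*/2);
  replay(/*wSize=*/2, /*strideW=*/2, /*kwSize=*/2, /*dilationW=*/1);
  return 0;
}

For the depthwise test case (the first call), this prints lhs offsets 0 and 2
for kw == 0 and kw == 1, matching the extract_strided_slice offsets [0, 0, 0]
and [0, 2, 0] in the CHECK lines above.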