diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -754,20 +754,19 @@
 /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
 /// convolution ops.
+template <typename Conv2DOp, typename Conv1DOp>
 struct DownscaleSizeOneWindowed2DConvolution final
-    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
+    : public OpRewritePattern<Conv2DOp> {
   DownscaleSizeOneWindowed2DConvolution(
       MLIRContext *context,
       LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
-      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
-        filter(std::move(f)) {}
+      : OpRewritePattern<Conv2DOp>(context, benefit), filter(std::move(f)) {}
 
-  FailureOr<Conv1DNwcWcfOp>
-  returningMatchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
-                           PatternRewriter &rewriter) const;
+  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+                                               PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+  LogicalResult matchAndRewrite(Conv2DOp convOp,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -76,10 +76,18 @@
 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
                                    SmallVectorImpl<Operation *> &results,
                                    transform::TransformState &state) {
-  FailureOr<LinalgOp> windowed =
-      tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
-  if (succeeded(windowed)) {
-    results.push_back(*windowed);
+  FailureOr<LinalgOp> windowedNhwc =
+      tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
+                                                     Conv1DNwcWcfOp>>(target);
+  if (succeeded(windowedNhwc)) {
+    results.push_back(*windowedNhwc);
+    return DiagnosedSilenceableFailure(success());
+  }
+  FailureOr<LinalgOp> windowedNchw =
+      tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
+                                                     Conv1DNcwFcwOp>>(target);
+  if (succeeded(windowedNchw)) {
+    results.push_back(*windowedNchw);
     return DiagnosedSilenceableFailure(success());
   }
   FailureOr<LinalgOp> depthwise =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -828,9 +828,9 @@
 // and then turning back to named ops. But for now it's fine to have a few
 // patterns matching special ops to get started.
 
-FailureOr<Conv1DNwcWcfOp>
-DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
-    linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
+template <typename Conv2DOp, typename Conv1DOp>
+FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
+    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
   if (failed(filter.checkAndNotify(rewriter, convOp)))
     return failure();
   if (convOp.hasBufferSemantics())
@@ -847,10 +847,30 @@
   auto kernelShape = kernelType.getShape();
   auto outputShape = outputType.getShape();
 
+  // Get domain indices based on conv2D layout.
+  int khIndex, kwIndex, ohIndex, owIndex;
+
+  TypeSwitch<Operation *>(convOp)
+      .Case([&](linalg::Conv2DNhwcHwcfOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::Conv2DNchwFchwOp op) {
+        khIndex = 2;
+        kwIndex = 3;
+        ohIndex = 2;
+        owIndex = 3;
+      })
+      .Default([&](Operation *op) {
+        llvm_unreachable("unexpected conv2d operation.");
+      });
+
   // Only handle the case where at least one of the window dimensions is
   // of size 1. Other cases can rely on tiling to reduce to such cases.
-  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
-  int64_t ohSize = outputShape[1], owSize = outputShape[2];
+  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
+  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
   bool removeH = (khSize == 1 && ohSize == 1);
   bool removeW = (kwSize == 1 && owSize == 1);
   if (!removeH && !removeW)
@@ -860,11 +880,11 @@
   // dimension.
   using RTTBuilder = RankedTensorType::Builder;
   RankedTensorType newInputType =
-      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
   RankedTensorType newKernelType =
-      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
   RankedTensorType newOutputType =
-      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
 
   // Rank-reduce operands.
   Location loc = convOp.getLoc();
@@ -877,16 +897,17 @@
   // Rank-reduce strides and dilations too.
   // TODO: dropDim 1-liner helper.
-  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
+  auto strides =
+      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
   strides.erase(strides.begin() + (removeH ? 0 : 1));
   auto stridesAttr = rewriter.getI64VectorAttr(strides);
 
   auto dilations =
-      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
+      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
-  auto conv1DOp = rewriter.create<Conv1DNwcWcfOp>(
+  auto conv1DOp = rewriter.create<Conv1DOp>(
       loc, newOutputType, ValueRange{newInput, newKernel},
       ValueRange{newOutput}, stridesAttr, dilationsAttr);
@@ -973,7 +994,10 @@
 void linalg::populateDecomposeConvolutionPatterns(
     RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
     PatternBenefit benefit) {
-  patterns.add<DownscaleSizeOneWindowed2DConvolution,
+  patterns.add<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
+                                                     Conv1DNwcWcfOp>,
+               DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
+                                                     Conv1DNcwFcwOp>,
                DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
                                                   filter, benefit);
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -18,6 +18,24 @@
   return %0 : tensor<?x1x?x?xf32>
 }
 
+// CHECK-LABEL: @conv_2d_nchw_fchw
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
+func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_ncw_fcw
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
+                                 strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
+    outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x?x1x?xf32>
+}
+
 // CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
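Note (reviewer sketch, not part of the patch): with this change both the
NHWC/HWCF and NCHW/FCHW downscaling rewrites are instantiations of one
templated pattern, so downstream code can pick them up through the existing
populateDecomposeConvolutionPatterns entry point. A minimal C++ sketch of
such a driver follows, assuming the defaulted filter/benefit parameters
declared in Transforms.h; the helper name decomposeConv2DsTo1D is
hypothetical:

    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Hypothetical helper: rewrite every size-1-window 2-D convolution
    // (NHWC/HWCF and, with this patch, NCHW/FCHW) under `op` into its 1-D
    // counterpart. Convolutions whose window is not statically size 1 are
    // left alone; tiling can shrink them to a decomposable form first.
    static mlir::LogicalResult decomposeConv2DsTo1D(mlir::Operation *op) {
      mlir::RewritePatternSet patterns(op->getContext());
      // Registers both DownscaleSizeOneWindowed2DConvolution instantiations
      // plus DownscaleDepthwiseConv2DNhwcHwcOp, as wired up above.
      mlir::linalg::populateDecomposeConvolutionPatterns(patterns);
      // Apply greedily until a fixed point is reached.
      return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
    }

The same decomposition is exercised declaratively in the test above via
transform.structured.decompose, which tries the NHWC instantiation first and
falls back to the NCHW one, matching the order in DecomposeOp::applyToOne.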