diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1495,6 +1495,98 @@
     const ControlSplitReductionFn &controlSplitReductionFn,
     bool useAlloc = false);
 
+template <typename Conv2DOp, typename Conv1DOp>
+FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
+    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, convOp)))
+    return failure();
+  if (convOp.hasBufferSemantics())
+    return failure(); // To be implemented.
+
+  Value input = convOp.getInputs().front();
+  Value kernel = convOp.getInputs().back();
+  Value output = convOp.getOutputs().front();
+
+  auto inputType = input.getType().dyn_cast<RankedTensorType>();
+  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+  auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+  auto kernelShape = kernelType.getShape();
+  auto outputShape = outputType.getShape();
+
+  // Get domain indices based on conv2D layout.
+  int khIndex, kwIndex, ohIndex, owIndex;
+
+  TypeSwitch<Operation *>(convOp)
+      .Case([&](linalg::Conv2DNhwcHwcfOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::Conv2DNchwFchwOp op) {
+        khIndex = 2;
+        kwIndex = 3;
+        ohIndex = 2;
+        owIndex = 3;
+      })
+      .Default([&](Operation *op) {
+        llvm_unreachable("unexpected conv2d operation.");
+      });
+
+  // Only handle the case where at least one of the window dimensions is
+  // of size 1. Other cases can rely on tiling to reduce to such cases.
+  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
+  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
+  bool removeH = (khSize == 1 && ohSize == 1);
+  bool removeW = (kwSize == 1 && owSize == 1);
+  if (!removeH && !removeW)
+    return failure();
+
+  // Get new shapes and types for all operands by removing the size-1
+  // dimension.
+  using RTTBuilder = RankedTensorType::Builder;
+  RankedTensorType newInputType =
+      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
+  RankedTensorType newKernelType =
+      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
+  RankedTensorType newOutputType =
+      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
+
+  // Rank-reduce operands.
+  Location loc = convOp.getLoc();
+  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, input, newInputType);
+  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, kernel, newKernelType);
+  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, output, newOutputType);
+
+  // Rank-reduce strides and dilations too.
+  // TODO: dropDim 1-liner helper.
+  auto strides =
+      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
+  strides.erase(strides.begin() + (removeH ? 0 : 1));
+  auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+  auto dilations =
+      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
+  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+
+  auto conv1DOp = rewriter.create<Conv1DOp>(
+      loc, newOutputType, ValueRange{newInput, newKernel},
+      ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+  // Insert back.
+  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, conv1DOp.getResult(0), output);
+  rewriter.replaceOp(convOp, inserted);
+
+  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
+  return conv1DOp;
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -820,106 +820,14 @@
   return success();
 }
 
-// The following are patterns for downscaling convolution ops with size-1
-// window dimensions.
+// The following, along with DownscaleSizeOneWindowed2DConvolution, are
+// patterns for downscaling convolution ops with size-1 window dimensions.
 //
 // Note that we'd eventually want to write such transformations in a generic
 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
 // and then turning back to named ops. But for now it's fine to have a few
 // patterns matching special ops to get started.
 
-template <typename Conv2DOp, typename Conv1DOp>
-FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
-    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
-  if (failed(filter.checkAndNotify(rewriter, convOp)))
-    return failure();
-  if (convOp.hasBufferSemantics())
-    return failure(); // To be implemented.
-
-  Value input = convOp.getInputs().front();
-  Value kernel = convOp.getInputs().back();
-  Value output = convOp.getOutputs().front();
-
-  auto inputType = input.getType().dyn_cast<RankedTensorType>();
-  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
-  auto outputType = output.getType().dyn_cast<RankedTensorType>();
-
-  auto kernelShape = kernelType.getShape();
-  auto outputShape = outputType.getShape();
-
-  // Get domain indices based on conv2D layout.
-  int khIndex, kwIndex, ohIndex, owIndex;
-
-  TypeSwitch<Operation *>(convOp)
-      .Case([&](linalg::Conv2DNhwcHwcfOp op) {
-        khIndex = 0;
-        kwIndex = 1;
-        ohIndex = 1;
-        owIndex = 2;
-      })
-      .Case([&](linalg::Conv2DNchwFchwOp op) {
-        khIndex = 2;
-        kwIndex = 3;
-        ohIndex = 2;
-        owIndex = 3;
-      })
-      .Default([&](Operation *op) {
-        llvm_unreachable("unexpected conv2d operation.");
-      });
-
-  // Only handle the case where at least one of the window dimensions is
-  // of size 1. Other cases can rely on tiling to reduce to such cases.
-  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
-  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
-  bool removeH = (khSize == 1 && ohSize == 1);
-  bool removeW = (kwSize == 1 && owSize == 1);
-  if (!removeH && !removeW)
-    return failure();
-
-  // Get new shapes and types for all operands by removing the size-1
-  // dimension.
-  using RTTBuilder = RankedTensorType::Builder;
-  RankedTensorType newInputType =
-      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
-  RankedTensorType newKernelType =
-      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
-  RankedTensorType newOutputType =
-      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
-
-  // Rank-reduce operands.
-  Location loc = convOp.getLoc();
-  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, input, newInputType);
-  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, kernel, newKernelType);
-  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, output, newOutputType);
-
-  // Rank-reduce strides and dilations too.
-  // TODO: dropDim 1-liner helper.
-  auto strides =
-      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
-  strides.erase(strides.begin() + (removeH ? 0 : 1));
-  auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
-  auto dilations =
-      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
-  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
-  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
-  auto conv1DOp = rewriter.create<Conv1DOp>(
-      loc, newOutputType, ValueRange{newInput, newKernel},
-      ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
-  // Insert back.
-  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
-      rewriter, loc, conv1DOp.getResult(0), output);
-  rewriter.replaceOp(convOp, inserted);
-
-  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
-  return conv1DOp;
-}
-
 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
                                                               Conv1DNwcWcfOp>;
 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
                                                               Conv1DNcwFcwOp>;
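
To make the pattern's index bookkeeping concrete, here is a hand-worked example for the NHWC/HWCF case, assuming unit strides and dilations; the shapes are illustrative choices of mine, not taken from this patch or its tests.

// Worked example for linalg::Conv2DNhwcHwcfOp (khIndex = 0, ohIndex = 1):
//   input  : tensor<1x1x16x4xf32>  (N=1, H=1,  W=16, C=4)
//   kernel : tensor<1x3x4x8xf32>   (KH=1, KW=3, C=4, F=8)
//   output : tensor<1x1x14x8xf32>  (N=1, OH=1, OW=14, F=8)
// khSize = kernelShape[0] = 1 and ohSize = outputShape[1] = 1, so removeH
// is true (removeW is false since kwSize = 3). dropDim removes dim 1 from
// the input and output types and dim 0 from the kernel type, and element 0
// (the H entry) is erased from strides and dilations:
//   newInput  : tensor<1x16x4xf32>
//   newKernel : tensor<3x4x8xf32>
//   newOutput : tensor<1x14x8xf32>
// These rank-reduced operands feed the Conv1DOp (here Conv1DNwcWcfOp), and
// its result is re-inserted into the original 4-D output tensor.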
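
Since the templated implementation is now visible from Transforms.h, downstream code can instantiate the pattern for any matching Conv2D/Conv1D pair without editing Transforms.cpp. A minimal sketch follows; populateDownscalePatterns and runDownscale are names of my own choosing, and it assumes the pattern constructor's LinalgTransformationFilter and PatternBenefit parameters are defaulted, as in this revision's declaration.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static void populateDownscalePatterns(RewritePatternSet &patterns) {
  // The same two instantiations the upstream file provides; a downstream
  // user could append another layout pair here.
  patterns.add<linalg::DownscaleSizeOneWindowed2DConvolution<
                   linalg::Conv2DNhwcHwcfOp, linalg::Conv1DNwcWcfOp>,
               linalg::DownscaleSizeOneWindowed2DConvolution<
                   linalg::Conv2DNchwFchwOp, linalg::Conv1DNcwFcwOp>>(
      patterns.getContext());
}

static LogicalResult runDownscale(Operation *root) {
  // Greedily apply the downscaling patterns under `root`.
  RewritePatternSet patterns(root->getContext());
  populateDownscalePatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}

The explicit instantiations kept in Transforms.cpp continue to serve the two upstream layout pairs; the header definition is what allows a third pair to be instantiated out of tree.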