diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -43,8 +43,9 @@
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
-static FailureOr<Operation *>
-vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
+/// Try to vectorize `convOp` as a convolution.
+static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
+                                                   LinalgOp convOp);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
@@ -636,13 +637,12 @@
   SmallVector<Value> results;
   // TODO: isaConvolutionOpInterface that can also infer from generic
   // features. Will require stride/dilation attributes inference.
-  if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
-    LDBG("Vectorize as a conv: " << linalgOp);
-    FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
-    if (failed(convOr))
-      return failure();
+  FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
+  if (succeeded(convOr)) {
     llvm::append_range(results, (*convOr)->getResults());
   } else {
+    if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
+      return failure();
     LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
     if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
       return failure();
@@ -1640,40 +1640,39 @@
 };
 } // namespace
 
-/// Helper function to vectorize a `linalgOp` with convolution semantics.
+/// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *>
-vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
-  // TODO: these are legitimately part of ConvolutionOpInterface.
-  auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
-  auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
+  // The ConvolutionOpInterface gives us guarantees of existence for
+  // strides/dilations. However, we do not need to rely on those, we can simply
+  // use them if present, otherwise use the default and let the generic conv.
+  // matcher in the ConvGenerator succeed or fail.
+  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
+  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-  LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
-  Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
+  Conv1DNwcGenerator e(b, op, stride, dilation);
   auto res = e.generateConv();
   if (succeeded(res))
     return res;
   return e.generateDilatedConv();
 }
 
-struct VectorizeConvolution
-    : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
+struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
+  LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Operation *> resultOrFail =
-        vectorizeConvolution(rewriter, convOp);
+    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
     if (failed(resultOrFail))
       return failure();
     Operation *newOp = *resultOrFail;
     if (newOp->getNumResults() == 0) {
-      rewriter.eraseOp(convOp.getOperation());
+      rewriter.eraseOp(op.getOperation());
       return success();
     }
     assert(newOp->getNumResults() == 1 && "expected single result");
-    rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
+    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
     return success();
   }
 };