diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -582,20 +582,13 @@
 /// equal than their counterpart iteration space sizes, if static.
 /// `inputVectorShapes` also allows the vectorization of operations with dynamic
 /// shapes.
-LogicalResult vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
                         ArrayRef<int64_t> inputVectorSizes = {},
                         bool vectorizeNDExtract = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
 
-/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
-/// and (3) all-zero lowPad to
-/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
-FailureOr<vector::TransferWriteOp>
-maskedVectorize(RewriterBase &rewriter, tensor::PadOp padOp,
-                ArrayRef<int64_t> inputVectorSizes);
-
 /// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
 FailureOr<LinalgLoops> linalgOpToLoops(RewriterBase &rewriter,
                                        LinalgOp linalgOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2927,26 +2927,15 @@
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
-    if (auto padOp = dyn_cast<tensor::PadOp>(target)) {
-      FailureOr<vector::TransferWriteOp> maybeWriteOp =
-          maskedVectorize(rewriter, padOp, vectorSizes);
-      if (failed(maybeWriteOp)) {
-        return mlir::emitSilenceableFailure(target->getLoc())
-               << "failed to vectorize padOp";
-      }
-      continue;
-    }
-
-    auto linalgOp = dyn_cast<LinalgOp>(target);
-    if (!linalgOp) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
-             << "cannot vectorize non-Linalg op";
+             << "cannot vectorize op";
     }
-    if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes,
+    if (failed(linalg::vectorize(rewriter, target, vectorSizes,
                                  getVectorizeNdExtract()))) {
       return mlir::emitSilenceableFailure(target->getLoc())
-             << "failed to vectorize linalg op";
+             << "failed to vectorize linalg/pad op";
     }
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1280,6 +1280,49 @@
   return success();
 }
 
+/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
+/// and (3) all-zero lowPad to
+/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
+static LogicalResult
+vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
+                       ArrayRef<int64_t> inputVectorSizes,
+                       SmallVectorImpl<Value> &newResults) {
+  auto padValue = padOp.getConstantPaddingValue();
+  Location loc = padOp.getLoc();
+  int64_t rank = inputVectorSizes.size();
+  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+
+  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(padOp);
+  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto emptyOp =
+      rewriter.create<tensor::EmptyOp>(loc, padOp.getResultType(),
+                                       /*dynamicSizes=*/ValueRange{});
+  SmallVector<OpFoldResult> mixedSourceDims =
+      getMixedDimensions(rewriter, loc, padOp.getSource());
+  Value mask =
+      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/padOp.getSource(),
+      /*indices=*/SmallVector<Value>(rank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(rank, true));
+  auto maskedOp = cast<vector::MaskOp>(
+      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
+  auto transferWriteOp = rewriter.create<vector::TransferWriteOp>(
+      loc,
+      /*vector=*/maskedOp->getResult(0),
+      /*source=*/emptyOp,
+      /*indices=*/SmallVector<Value>(rank, zero),
+      /*inBounds=*/SmallVector<bool>(rank, true));
+  newResults.push_back(transferWriteOp.getResult());
+  return success();
+}
+
 // TODO: probably need some extra checks for reduction followed by consumer
 // ops that may not commute (e.g. linear reduction + non-linear instructions).
 static LogicalResult reductionPreconditions(LinalgOp op) {
@@ -1392,78 +1435,46 @@
   return success();
 }
 
-/// Converts affine.apply Ops to arithmetic operations.
-static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
-  OpBuilder::InsertionGuard g(rewriter);
-  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
-
-  for (auto op : make_early_inc_range(toReplace)) {
-    rewriter.setInsertionPoint(op);
-    auto expanded = affine::expandAffineExpr(
-        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
-        op.getOperands().take_front(op.getAffineMap().getNumDims()),
-        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
-    rewriter.replaceOp(op, expanded);
-  }
-}
-
-FailureOr<vector::TransferWriteOp>
-mlir::linalg::maskedVectorize(RewriterBase &rewriter, tensor::PadOp padOp,
-                              ArrayRef<int64_t> inputVectorSizes) {
+static LogicalResult
+vectorizePadOpPrecondition(tensor::PadOp padOp,
+                           ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = padOp.getConstantPaddingValue();
   if (!padValue) {
     LDBG("pad value is not constant: " << padOp << "\n");
-    return rewriter.notifyMatchFailure(padOp, "pad value is not constant");
+    return failure();
   }
 
   ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
   if (!(resultTensorShape == inputVectorSizes)) {
     LDBG("result tensor shape must match input vector sizes: " << padOp
                                                                 << "\n");
-    return rewriter.notifyMatchFailure(
-        padOp, "result tensor shape must match input vector sizes");
+    return failure();
   }
+
   if (llvm::any_of(padOp.getLow(), [](Value v) {
         std::optional<int64_t> res = getConstantIntValue(v);
         return !res.has_value() || res.value() != 0;
       })) {
     LDBG("low pad must all be zero: " << padOp << "\n");
-    return rewriter.notifyMatchFailure(padOp, "low pad must all be zero");
+    return failure();
  }
 
-  Location loc = padOp.getLoc();
-  int64_t rank = inputVectorSizes.size();
-  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+  return success();
+}
 
-  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
+/// Converts affine.apply Ops to arithmetic operations.
+static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(padOp);
-  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto emptyOp =
-      rewriter.create<tensor::EmptyOp>(loc, padOp.getResultType(),
-                                       /*dynamicSizes=*/ValueRange{});
-  SmallVector<OpFoldResult> mixedSourceDims =
-      getMixedDimensions(rewriter, loc, padOp.getSource());
-  Value mask =
-      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
-  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc,
-      /*vectorType=*/vectorType,
-      /*source=*/padOp.getSource(),
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  auto maskedOp = cast<vector::MaskOp>(
-      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
-  auto transferWriteOp = rewriter.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/maskedOp->getResult(0),
-      /*source=*/emptyOp,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  rewriter.replaceOp(padOp, transferWriteOp->getResults());
-  return transferWriteOp;
+  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
+
+  for (auto op : make_early_inc_range(toReplace)) {
+    rewriter.setInsertionPoint(op);
+    auto expanded = affine::expandAffineExpr(
+        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
+        op.getOperands().take_front(op.getAffineMap().getNumDims()),
+        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
+    rewriter.replaceOp(op, expanded);
+  }
 }
 
 /// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
@@ -1472,55 +1483,79 @@
 /// greater than or equal to their counterpart iteration space sizes, if static.
 /// `inputVectorShapes` also allows the vectorization of operations with dynamic
 /// shapes.
-LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes,
                                       bool vectorizeNDExtract) {
-  LDBG("Attempting to vectorize:\n" << linalgOp << "\n");
+  LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << "\n");
 
-  if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                           vectorizeNDExtract))) {
+  auto precondResult =
+      TypeSwitch<Operation *, LogicalResult>(op)
+          .Case<linalg::LinalgOp>([&](auto linalgOp) {
+            return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
+                                                 vectorizeNDExtract);
+          })
+          .Case<tensor::PadOp>([&](auto padOp) {
+            return vectorizePadOpPrecondition(padOp, inputVectorSizes);
+          })
+          .Default([](auto) { return failure(); });
+  if (failed(precondResult)) {
    LDBG("Vectorization pre-conditions failed\n");
    return failure();
  }
 
   // Initialize vectorization state.
   VectorizationState state(rewriter);
-  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes))) {
-    LDBG("Vectorization state couldn't be initialized\n");
-    return failure();
+  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes))) {
+      LDBG("Vectorization state couldn't be initialized\n");
+      return failure();
+    }
   }
 
   SmallVector<Value> results;
-  // TODO: isaConvolutionOpInterface that can also infer from generic
-  // features. Will require stride/dilation attributes inference.
-  FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
-  if (succeeded(convOr)) {
-    llvm::append_range(results, (*convOr)->getResults());
-  } else {
-    if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                             vectorizeNDExtract)))
-      return failure();
-    LDBG("Vectorize generic by broadcasting to the canonical vector shape\n");
-
-    // Pre-process before proceeding.
-    convertAffineApply(rewriter, linalgOp);
-
-    // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to
-    // 'OpBuilder' when it is passed over to some methods like
-    // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an op
-    // within these methods, the actual rewriter won't be notified and we will
-    // end up with read-after-free issues!
-    if (failed(vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results)))
-      return failure();
+  auto vectorizeResult =
+      TypeSwitch<Operation *, LogicalResult>(op)
+          .Case<linalg::LinalgOp>([&](auto linalgOp) {
+            // TODO: isaConvolutionOpInterface that can also infer from generic
+            // features. Will require stride/dilation attributes inference.
+            FailureOr<Operation *> convOr =
+                vectorizeConvolution(rewriter, linalgOp);
+            if (succeeded(convOr)) {
+              llvm::append_range(results, (*convOr)->getResults());
+              return success();
+            }
+
+            LDBG("Vectorize generic by broadcasting to the canonical vector "
+                 "shape\n");
+
+            // Pre-process before proceeding.
+            convertAffineApply(rewriter, linalgOp);
+
+            // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
+            // to 'OpBuilder' when it is passed over to some methods like
+            // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
+            // erase an op within these methods, the actual rewriter won't be
+            // notified and we will end up with read-after-free issues!
+            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
+                                            results);
+          })
+          .Case<tensor::PadOp>([&](auto padOp) {
+            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
+                                          results);
+          })
+          .Default([](auto) { return failure(); });
+
+  if (failed(vectorizeResult)) {
+    LDBG("Vectorization failed\n");
+    return failure();
   }
 
   if (!results.empty())
-    rewriter.replaceOp(linalgOp, results);
+    rewriter.replaceOp(op, results);
   else
-    rewriter.eraseOp(linalgOp);
+    rewriter.eraseOp(op);
 
   return success();
 }
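
Usage sketch (not part of the patch): after this change, a caller vectorizes a statically shaped tensor.pad through the same linalg::vectorize() entry point it already uses for Linalg ops, so the separate maskedVectorize() call disappears. The helper name vectorizeStaticPad below is illustrative only; it simply reuses the pad's static result shape as the input vector sizes.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"

// Vectorize a statically shaped tensor.pad by using its result shape as the
// input vector sizes. The pad-specific preconditions (constant padding value,
// all-zero low padding, shapes matching the vector sizes) are checked inside
// vectorize() via vectorizePadOpPrecondition().
static mlir::LogicalResult vectorizeStaticPad(mlir::RewriterBase &rewriter,
                                              mlir::tensor::PadOp padOp) {
  mlir::RankedTensorType resultType = padOp.getResultType();
  if (!resultType.hasStaticShape())
    return mlir::failure();
  // Same call shape as for LinalgOp targets; the op kind is dispatched
  // internally through the TypeSwitch added in this patch.
  return mlir::linalg::vectorize(rewriter, padOp.getOperation(),
                                 /*inputVectorSizes=*/resultType.getShape(),
                                 /*vectorizeNDExtract=*/false);
}

The transform-op path changed in LinalgTransformOps.cpp above does the same thing: it now forwards any LinalgOp or tensor::PadOp target to this single entry point instead of special-casing pads.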