diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -30,6 +30,7 @@ class OpBuilder; class TypeRange; class ValueRange; +class RewriterBase; /// Tests whether the given maps describe a row major matmul. The test is /// permutation-invariant. Note that this only checks the affine maps from an @@ -81,8 +82,8 @@ Red() : IteratorType(IteratorTypeT::reduction) {} }; - StructuredGenerator(OpBuilder &builder, StructuredOpInterface op) - : builder(builder), ctx(op.getContext()), loc(op.getLoc()), + StructuredGenerator(RewriterBase &rewriter, StructuredOpInterface op) + : rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()), iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()), op(op) {} @@ -102,7 +103,7 @@ } protected: - OpBuilder &builder; + RewriterBase &rewriter; MLIRContext *ctx; Location loc; SmallVector iterators; @@ -112,10 +113,12 @@ // Clone the current operation with the operands. This is used to abstract away // the optional underlying region creation. +// Note: this is a true builder that notifies the OpBuilder listener. Operation *clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands); // Clone the current operation with the operands but leave the regions empty. +// Note: this is a true builder that notifies the OpBuilder listener. Operation *cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -46,7 +46,7 @@ #define LDBG(X) LLVM_DEBUG(DBGS() << X) /// Try to vectorize `convOp` as a convolution. 
-static FailureOr vectorizeConvolution(OpBuilder &b, +static FailureOr vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp); /// Return the unique instance of OpType in `block` if it is indeed unique. @@ -174,14 +174,18 @@ vector::BroadcastableToResult::Success) return value; Location loc = b.getInsertionPoint()->getLoc(); - return b.createOrFold(loc, targetVectorType, value); + return b.createOrFold(loc, targetVectorType, + value); } /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This /// assumes that `reductionOp` has two operands and one of them is the reduction -/// initial value. -static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, - Value valueToReduce, Value acc, +/// initial value. +// Note: this is a true builder that notifies the OpBuilder listener. +// TODO: Consider moving as a static helper on the ReduceOp. +static Operation *buildMultiDimReduce(OpBuilder &b, + Operation *reduceOp, Value valueToReduce, + Value acc, const SmallVector &reductionMask) { auto maybeKind = getCombinerOpKind(reduceOp); assert(maybeKind && "Failed precondition: could not get reduction kind"); @@ -198,6 +202,8 @@ /// to all `0`; where `outputOperand` is an output operand of the LinalgOp /// currently being vectorized. If `dest` has null rank, build an memref.store. /// Return the produced value or null if no value is produced. +// Note: this is a true builder that notifies the OpBuilder listener. +// TODO: Consider moving as a static helper on the TransferWriteOp. 
static Value buildVectorWrite(OpBuilder &b, Value value, OpOperand *outputOperand) { Operation *write; @@ -217,14 +223,14 @@ SmallVector indices(linalgOp.getRank(outputOperand), b.create(loc, 0)); value = broadcastIfNeeded(b, value, vectorType.getShape()); - write = b.create(loc, value, outputOperand->get(), - indices, map); + write = b.create( + loc, value, outputOperand->get(), indices, map); } else { if (!value.getType().isa()) value = b.create(loc, vectorType, value); assert(value.getType() == vectorType && "incorrect type"); - write = b.create(loc, value, outputOperand->get(), - ValueRange{}); + write = b.create( + loc, value, outputOperand->get(), ValueRange{}); } LDBG("vectorized op: " << *write); if (!write->getResults().empty()) @@ -233,7 +239,7 @@ } // Custom vectorization precondition function type. This is intented to be used -// with CustomVectorizationHook. Returns success if the correpsonding custom +// with CustomVectorizationHook. Returns success if the corresponding custom // hook can vectorize the op. using CustomVectorizationPrecondition = std::function; @@ -248,11 +254,11 @@ /// vector values are appended to `newResults`. Return /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it /// should not try to map produced operations and instead return the results -/// using the `newResults` vector making them available to the -/// vectorization algorithm for RAUW. This function is meant to be used as a +/// using the `newResults` vector making them available to the vectorization +/// algorithm for RAUW. This function is meant to be used as a /// CustomVectorizationHook. static VectorizationResult -vectorizeLinalgYield(OpBuilder &b, Operation *op, +vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const BlockAndValueMapping &bvm, LinalgOp linalgOp, SmallVectorImpl &newResults) { auto yieldOp = dyn_cast(op); @@ -263,7 +269,7 @@ // TODO: use a map. 
Value vectorValue = bvm.lookup(outputs.value()); Value newResult = buildVectorWrite( - b, vectorValue, linalgOp.getDpsInitOperand(outputs.index())); + rewriter, vectorValue, linalgOp.getDpsInitOperand(outputs.index())); if (newResult) newResults.push_back(newResult); } @@ -274,8 +280,8 @@ /// VectorizationStatus::NewOp to signal the vectorization algorithm that it /// should map the produced operations. This function is meant to be used as a /// CustomVectorizationHook. -static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op, - LinalgOp linalgOp) { +static VectorizationResult +vectorizeLinalgIndex(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp) { IndexOp indexOp = dyn_cast(op); if (!indexOp) return VectorizationResult{VectorizationStatus::Failure, nullptr}; @@ -285,8 +291,8 @@ // Compute a one-dimensional index vector for the index op dimension. SmallVector constantSeq = llvm::to_vector<16>(llvm::seq(0, targetShape[indexOp.getDim()])); - auto constantOp = - b.create(loc, b.getIndexVectorAttr(constantSeq)); + auto constantOp = rewriter.create( + loc, rewriter.getIndexVectorAttr(constantSeq)); // Return the one-dimensional index vector if it lives in the trailing // dimension of the iteration space since the vectorization algorithm in this // case can handle the broadcast. @@ -296,13 +302,13 @@ // broadcast the one-dimensional index vector to the permuted shape, and // finally transpose the broadcasted index vector to undo the permutation. 
std::swap(targetShape[indexOp.getDim()], targetShape.back()); - auto broadCastOp = b.create( - loc, VectorType::get(targetShape, b.getIndexType()), constantOp); + auto broadCastOp = rewriter.create( + loc, VectorType::get(targetShape, rewriter.getIndexType()), constantOp); SmallVector transposition = llvm::to_vector<16>(llvm::seq(0, linalgOp.getNumLoops())); std::swap(transposition.back(), transposition[indexOp.getDim()]); auto transposeOp = - b.create(loc, broadCastOp, transposition); + rewriter.create(loc, broadCastOp, transposition); return VectorizationResult{VectorizationStatus::NewOp, transposeOp}; } @@ -334,7 +340,7 @@ /// should map the produced operations. This function is meant to be used as a /// CustomVectorizationHook. static VectorizationResult -vectorizeTensorExtract(OpBuilder &b, Operation *op, LinalgOp linalgOp, +vectorizeTensorExtract(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp, const BlockAndValueMapping &bvm) { tensor::ExtractOp extractOp = dyn_cast(op); if (!extractOp) @@ -350,19 +356,19 @@ auto targetShape = linalgOp.computeStaticLoopSizes(); SmallVector gatherIndices; - gatherIndices.push_back(b.create(loc, 0)); + gatherIndices.push_back(rewriter.create(loc, 0)); - auto maskConstantOp = b.create( - loc, - DenseIntElementsAttr::get(VectorType::get(targetShape, b.getI1Type()), - /*value=*/true)); + auto maskConstantOp = rewriter.create( + loc, DenseIntElementsAttr::get( + VectorType::get(targetShape, rewriter.getI1Type()), + /*value=*/true)); auto resultType = VectorType::get(targetShape, extractOp.getResult().getType()); auto passThruConstantOp = - b.create(loc, b.getZeroAttr(resultType)); + rewriter.create(loc, rewriter.getZeroAttr(resultType)); - auto gatherOp = b.create( + auto gatherOp = rewriter.create( loc, resultType, extractOp.getTensor(), gatherIndices, indexVec, maskConstantOp, passThruConstantOp); @@ -371,8 +377,11 @@ /// Emit reduction operations if the shapes of the value to reduce is different /// that the result 
shape. -static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, - Value reduceValue, Value initialValue, +// Note: this is a true builder that notifies the OpBuilder listener. +// TODO: Consider moving as a static helper on the ReduceOp. +static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, + Operation *op, Value reduceValue, + Value initialValue, const BlockAndValueMapping &bvm) { Value reduceVec = bvm.lookup(reduceValue); Value outputVec = bvm.lookup(initialValue); @@ -402,12 +411,12 @@ /// otherwise, it means one of the `customVectorizationHooks` is incorrect. /// /// This function assumes all operands of `op` have been vectorized and are in -/// the `bvm` mapping. As a consequence, this function is meant to be called on +/// the `bvm` mapping. As a consequence, this function is meant to be called on /// a topologically-sorted list of ops. /// This function does not update `bvm` but returns a VectorizationStatus that /// instructs the caller what `bvm` update needs to occur. static VectorizationResult -vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op, +vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op, const BlockAndValueMapping &bvm, ArrayRef customVectorizationHooks) { LDBG("vectorize op " << *op); @@ -425,7 +434,7 @@ // 2. Constant ops don't get vectorized but rather broadcasted at their users. // Clone so that the constant is not confined to the linalgOp block . if (isa(op)) - return VectorizationResult{VectorizationStatus::NewOp, b.clone(*op)}; + return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)}; // 3. Only ElementwiseMappable are allowed in the generic vectorization. 
if (!OpTrait::hasElementwiseMappableTraits(op)) @@ -448,7 +457,7 @@ if (!reductionOperands.empty()) { assert(reductionOperands.size() == 1); Operation *reduceOp = - reduceIfNeeded(b, linalgOp, op, reductionOperands[0].first, + reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first, reductionOperands[0].second, bvm); if (reduceOp) return VectorizationResult{VectorizationStatus::NewOp, reduceOp}; @@ -462,11 +471,12 @@ if (vt && firstMaxRankedShape.size() < vt.getShape().size()) firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end()); } - // b. broadcast each op if needed. + // b. broadcast each op if needed. auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) { return firstMaxRankedShape.empty() ? bvm.lookup(v) - : broadcastIfNeeded(b, bvm.lookup(v), firstMaxRankedShape); + : broadcastIfNeeded(rewriter, bvm.lookup(v), + firstMaxRankedShape); }); // c. for elementwise, the result is the vector with the firstMaxRankedShape auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) { @@ -478,9 +488,9 @@ // Build and return the new op. return VectorizationResult{ VectorizationStatus::NewOp, - b.create(op->getLoc(), op->getName().getIdentifier(), - llvm::to_vector<4>(vectorizedOperands), - llvm::to_vector<4>(returnTypes), op->getAttrs())}; + rewriter.create(op->getLoc(), op->getName().getIdentifier(), + llvm::to_vector<4>(vectorizedOperands), + llvm::to_vector<4>(returnTypes), op->getAttrs())}; } /// Generic vectorization function that rewrites the body of a `linalgOp` into @@ -492,8 +502,8 @@ /// load). /// TODO: Reuse opportunities for RAR dependencies. /// 4a. Register CustomVectorizationHook for YieldOp to capture the results. -/// 4b. Register CustomVectorizationHook for IndexOp to access the iteration -/// indices. +/// 4b. Register CustomVectorizationHook for IndexOp to access the +/// iteration indices. /// 5. Iteratively call vectorizeOneOp on the region operations. 
/// /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is @@ -506,7 +516,7 @@ /// This is not deemed a problem as we expect canonicalizations and foldings to /// aggressively clean up the useless work. static LogicalResult -vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp, +vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp, SmallVectorImpl &newResults) { Block *block = linalgOp.getBlock(); @@ -527,7 +537,7 @@ // 3. Turn all BBArgs into vector.transfer_read / load. Location loc = linalgOp.getLoc(); - Value zero = b.create(loc, 0); + Value zero = rewriter.create(loc, 0); for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) { BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand); if (linalgOp.isScalar(opOperand)) { @@ -555,12 +565,12 @@ auto shape = linalgOp.getShape(opOperand); SmallVector indices(shape.size(), zero); - Value readValue = b.create( + Value readValue = rewriter.create( loc, readType, opOperand->get(), indices, map); // Not all ops support 0-d vectors, extract the scalar for now. // TODO: remove this. if (readValue.getType().cast().getRank() == 0) - readValue = b.create(loc, readValue); + readValue = rewriter.create(loc, readValue); LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue); bvm.map(bbarg, readValue); @@ -572,15 +582,15 @@ CustomVectorizationHook vectorizeYield = [&](Operation *op, const BlockAndValueMapping &bvm) -> VectorizationResult { - return vectorizeLinalgYield(b, op, bvm, linalgOp, newResults); + return vectorizeLinalgYield(rewriter, op, bvm, linalgOp, newResults); }; hooks.push_back(vectorizeYield); - // 4b. Register CustomVectorizationHook for indexOp. + // 4b. Register CustomVectorizationHook for indexOp. 
CustomVectorizationHook vectorizeIndex = [&](Operation *op, const BlockAndValueMapping &bvm) -> VectorizationResult { - return vectorizeLinalgIndex(b, op, linalgOp); + return vectorizeLinalgIndex(rewriter, op, linalgOp); }; hooks.push_back(vectorizeIndex); @@ -588,13 +598,14 @@ CustomVectorizationHook vectorizeExtract = [&](Operation *op, const BlockAndValueMapping &bvm) -> VectorizationResult { - return vectorizeTensorExtract(b, op, linalgOp, bvm); + return vectorizeTensorExtract(rewriter, op, linalgOp, bvm); }; hooks.push_back(vectorizeExtract); // 5. Iteratively call `vectorizeOneOp` to each op in the slice. for (Operation &op : block->getOperations()) { - VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks); + VectorizationResult result = + vectorizeOneOp(rewriter, linalgOp, &op, bvm, hooks); if (result.status == VectorizationStatus::Failure) { LDBG("failed to vectorize: " << op); return failure(); @@ -760,14 +771,14 @@ /// Given an ArrayRef of OpFoldResults, return a vector of Values. /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are /// not supported. -static SmallVector ofrToIndexValues(OpBuilder &builder, Location loc, +static SmallVector ofrToIndexValues(RewriterBase &rewriter, Location loc, ArrayRef ofrs) { SmallVector result; for (auto o : ofrs) { if (auto val = o.template dyn_cast()) { result.push_back(val); } else { - result.push_back(builder.create( + result.push_back(rewriter.create( loc, getIntFromAttr(o.template get()))); } } @@ -1415,9 +1426,9 @@ /// kw is unrolled, w is unrolled iff dilationW > 1. 
struct Conv1DGenerator : public StructuredGenerator { - Conv1DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW, + Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW, int dilationW) - : StructuredGenerator(builder, linalgOp), + : StructuredGenerator(rewriter, linalgOp), strideW(strideW), dilationW(dilationW) { // Determine whether `linalgOp` can be generated with this generator if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1) @@ -1481,8 +1492,7 @@ /// > 1. FailureOr conv(Conv1DOpOrder conv1DOpOrder) { if (!valid) - return IRRewriter(builder).notifyMatchFailure(op, - "unvectorizable 1-D conv"); + return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv"); int64_t nSize, wSize, cSize, kwSize, fSize; SmallVector lhsShape, rhsShape, resShape; @@ -1519,7 +1529,7 @@ } vector::TransferWriteOp write; - Value zero = builder.create(loc, 0); + Value zero = rewriter.create(loc, 0); // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. // When strideW == 1, we can batch the contiguous loads and avoid @@ -1534,13 +1544,13 @@ auto resType = VectorType::get(resShape, resEltType); // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0, // 0]. - Value lhs = builder.create( + Value lhs = rewriter.create( loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}); // Read rhs slice of size {kw, c, f} @ [0, 0, 0]. - Value rhs = builder.create( + Value rhs = rewriter.create( loc, rhsType, rhsShaped, ValueRange{zero, zero, zero}); // Read res slice of size {n, w, f} @ [0, 0, 0]. - Value res = builder.create( + Value res = rewriter.create( loc, resType, resShaped, ValueRange{zero, zero, zero}); // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output: @@ -1554,13 +1564,13 @@ // To match base vectorization case, we pre-transpose current case. 
// ncw -> nwc static constexpr std::array permLhs = {0, 2, 1}; - lhs = builder.create(loc, lhs, permLhs); + lhs = rewriter.create(loc, lhs, permLhs); // fcw -> wcf static constexpr std::array permRhs = {2, 1, 0}; - rhs = builder.create(loc, rhs, permRhs); + rhs = rewriter.create(loc, rhs, permRhs); // nfw -> nwf static constexpr std::array permRes = {0, 2, 1}; - res = builder.create(loc, res, permRes); + res = rewriter.create(loc, res, permRes); break; } } @@ -1573,7 +1583,7 @@ // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]. for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - lhsVals.push_back(builder.create( + lhsVals.push_back(rewriter.create( loc, lhs, /*offsets=*/ArrayRef{0, w * strideW + kw * dilationW, 0}, /*sizes=*/ArrayRef{nSize, wSizeStep, cSize}, @@ -1582,12 +1592,12 @@ } // Extract rhs slice of size {c, f} @ [kw]. for (int64_t kw = 0; kw < kwSize; ++kw) { - rhsVals.push_back(builder.create( + rhsVals.push_back(rewriter.create( loc, rhs, /*offsets=*/ArrayRef{kw})); } // Extract res slice: {n, wSizeStep, f} @ [0, w, 0]. for (int64_t w = 0; w < wSize; w += wSizeStep) { - resVals.push_back(builder.create( + resVals.push_back(rewriter.create( loc, res, /*offsets=*/ArrayRef{0, w, 0}, /*sizes=*/ArrayRef{nSize, wSizeStep, fSize}, @@ -1602,14 +1612,14 @@ for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { resVals[w] = conv1dSliceAsContraction( - builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); + rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); } } // Write back res slice: {n, wSizeStep, f} @ [0, w, 0]. // This does not depend on kw. 
for (int64_t w = 0; w < wSize; w += wSizeStep) { - res = builder.create( + res = rewriter.create( loc, resVals[w], res, /*offsets=*/ArrayRef{0, w, 0}, /*strides=*/ArrayRef{1, 1, 1}); @@ -1628,26 +1638,26 @@ case Conv1DOpOrder::Ncw: { // nwf -> nfw static constexpr std::array perm = {0, 2, 1}; - res = builder.create(loc, res, perm); + res = rewriter.create(loc, res, perm); break; } } // Write back res slice of size {n, w, f} @ [0, 0, 0]. - return builder + return rewriter .create(loc, res, resShaped, ValueRange{zero, zero, zero}) .getOperation(); } // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f} - Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs, - Value rhs, Value res) { + Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, + Value lhs, Value rhs, Value res) { vector::IteratorType par = vector::IteratorType::parallel; vector::IteratorType red = vector::IteratorType::reduction; AffineExpr n, w, f, c; bindDims(ctx, n, w, f, c); - return builder.create( + return rewriter.create( loc, lhs, rhs, res, /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}}, /*iteratorTypes=*/ArrayRef{par, par, par, red}); @@ -1664,8 +1674,7 @@ /// > 1. FailureOr depthwiseConv() { if (!valid) - return IRRewriter(builder).notifyMatchFailure( - op, "unvectorizable depthwise conv"); + return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv"); int64_t nSize, wSize, cSize, kwSize; // kernel{kw, c} @@ -1674,7 +1683,7 @@ bindShapeDims(resShapedType, nSize, wSize); vector::TransferWriteOp write; - Value zero = builder.create(loc, 0); + Value zero = rewriter.create(loc, 0); // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. // When strideW == 1, we can batch the contiguous loads and avoid @@ -1696,13 +1705,13 @@ // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, // 0]. 
- Value lhs = builder.create( + Value lhs = rewriter.create( loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}); // Read rhs slice of size {kw, c} @ [0, 0]. - Value rhs = builder.create(loc, rhsType, rhsShaped, - ValueRange{zero, zero}); + Value rhs = rewriter.create(loc, rhsType, rhsShaped, + ValueRange{zero, zero}); // Read res slice of size {n, w, c} @ [0, 0, 0]. - Value res = builder.create( + Value res = rewriter.create( loc, resType, resShaped, ValueRange{zero, zero, zero}); //===------------------------------------------------------------------===// @@ -1714,7 +1723,7 @@ // @ [0, sw * w + dw * kw, 0]. for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - lhsVals.push_back(builder.create( + lhsVals.push_back(rewriter.create( loc, lhs, /*offsets=*/ArrayRef{0, w * strideW + kw * dilationW, 0}, /*sizes=*/ArrayRef{nSize, wSizeStep, cSize}, @@ -1723,12 +1732,12 @@ } // Extract rhs slice of size {c} @ [kw]. for (int64_t kw = 0; kw < kwSize; ++kw) { - rhsVals.push_back(builder.create( + rhsVals.push_back(rewriter.create( loc, rhs, /*offsets=*/ArrayRef{kw})); } // Extract res slice: {n, wSizeStep, c} @ [0, w, 0]. for (int64_t w = 0; w < wSize; w += wSizeStep) { - resVals.push_back(builder.create( + resVals.push_back(rewriter.create( loc, res, /*offsets=*/ArrayRef{0, w, 0}, /*sizes=*/ArrayRef{nSize, wSizeStep, cSize}, @@ -1743,18 +1752,23 @@ for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { resVals[w] = depthwiseConv1dSliceAsMulAcc( - builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); + rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); } } - // Its possible we failed to create the Fma - if (!llvm::all_of(resVals, [](Value v) { return v; })) - return IRRewriter(builder).notifyMatchFailure(op, "failed to create FMA"); + // It's possible we failed to create the Fma. 
+ if (!llvm::all_of(resVals, [](Value v) { return v; })) { + // Manually revert (in reverse order) to avoid leaving a bad IR state. + for (auto &collection : {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}}) + for (Value v : collection) + rewriter.eraseOp(v.getDefiningOp()); + return rewriter.notifyMatchFailure(op, "failed to create FMA"); + } // Write back res slice: {n, wSizeStep, c} @ [0, w, 0]. // This does not depend on kw. for (int64_t w = 0; w < wSize; w += wSizeStep) { - res = builder.create( + res = rewriter.create( loc, resVals[w], res, /*offsets=*/ArrayRef{0, w, 0}, /*strides=*/ArrayRef{1, 1, 1}); @@ -1764,14 +1778,14 @@ //===------------------------------------------------------------------===// // Write back res slice of size {n, w, c} @ [0, 0, 0]. - return builder + return rewriter .create(loc, res, resShaped, ValueRange{zero, zero, zero}) .getOperation(); } // Take a value of element type T and widen to the destination type. - Value promote(OpBuilder &b, Location loc, Value val, Type ty) { + Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) { if (val.getType() == ty) return val; @@ -1780,35 +1794,35 @@ const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth(); if (getElementTypeOrSelf(ty).isa() && srcWidth < destWidth) - return builder.create(loc, ty, val); + return rewriter.create(loc, ty, val); if (getElementTypeOrSelf(ty).isa() && srcWidth < destWidth) - return builder.create(loc, ty, val); + return rewriter.create(loc, ty, val); return nullptr; } /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc - Value depthwiseConv1dSliceAsMulAcc(OpBuilder &b, Location loc, Value lhs, - Value rhs, Value res) { + Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc, + Value lhs, Value rhs, Value res) { auto rhsTy = rhs.getType().cast(); auto resTy = res.getType().cast(); // TODO(suderman): Change this to use a vector.ima intrinsic. 
- lhs = promote(b, loc, lhs, resTy); + lhs = promote(rewriter, loc, lhs, resTy); - rhs = builder.create( + rhs = rewriter.create( loc, resTy.clone(rhsTy.getElementType()), rhs); - rhs = promote(b, loc, rhs, resTy); + rhs = promote(rewriter, loc, rhs, resTy); if (!lhs || !rhs) return nullptr; if (resTy.getElementType().isa()) - return b.create(loc, lhs, rhs, res); + return rewriter.create(loc, lhs, rhs, res); - auto mul = b.create(loc, lhs, rhs); - return b.create(loc, mul, res); + auto mul = rewriter.create(loc, lhs, rhs); + return rewriter.create(loc, mul, res); } /// Entry point that transposes into the common form: @@ -1817,7 +1831,7 @@ AffineExpr n, w, f, kw, c; bindDims(ctx, n, w, f, kw, c); if (!iters({Par(), Par(), Par(), Red(), Red()})) - return IRRewriter(builder).notifyMatchFailure( + return rewriter.notifyMatchFailure( op, "failed to match conv::Nwc 3-par 2-red"); // No transposition needed. @@ -1825,7 +1839,7 @@ /*rhsIndex*/ {kw, c, f}, /*resIndex*/ {n, w, f}})) return conv(Conv1DOpOrder::Nwc); - return IRRewriter(builder).notifyMatchFailure(op, "not a conv::Nwc layout"); + return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout"); } /// Entry point that transposes into the common form: @@ -1834,7 +1848,7 @@ AffineExpr n, w, f, kw, c; bindDims(ctx, n, f, w, c, kw); if (!iters({Par(), Par(), Par(), Red(), Red()})) - return IRRewriter(builder).notifyMatchFailure( + return rewriter.notifyMatchFailure( op, "failed to match conv::Ncw 3-par 2-red"); if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw}, @@ -1842,7 +1856,7 @@ /*resIndex*/ {n, f, w}})) return conv(Conv1DOpOrder::Ncw); - return IRRewriter(builder).notifyMatchFailure(op, "not a conv::Ncw layout"); + return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout"); } /// Entry point that transposes into the common form: @@ -1851,7 +1865,7 @@ AffineExpr n, w, c, kw; bindDims(ctx, n, w, c, kw); if (!iters({Par(), Par(), Par(), Red()})) - return IRRewriter(builder).notifyMatchFailure( 
+ return rewriter.notifyMatchFailure( op, "failed to match depthwise::Nwc conv 3-par 1-red"); // No transposition needed. @@ -1860,8 +1874,7 @@ /*resIndex*/ {n, w, c}})) return depthwiseConv(); - return IRRewriter(builder).notifyMatchFailure( - op, "not a depthwise::Nwc layout"); + return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout"); } private: @@ -1874,7 +1887,8 @@ /// Helper function to vectorize a LinalgOp with convolution semantics. // TODO: extend the generic vectorization to support windows and drop this. -static FailureOr vectorizeConvolution(OpBuilder &b, LinalgOp op) { +static FailureOr vectorizeConvolution(RewriterBase &rewriter, + LinalgOp op) { // The ConvolutionOpInterface gives us guarantees of existence for // strides/dilations. However, we do not need to rely on those, we can simply // use them if present, otherwise use the default and let the generic conv. @@ -1883,7 +1897,7 @@ auto dilations = op->getAttrOfType("dilations"); auto stride = strides ? *strides.getValues().begin() : 1; auto dilation = dilations ? *dilations.getValues().begin() : 1; - Conv1DGenerator e(b, op, stride, dilation); + Conv1DGenerator e(rewriter, op, stride, dilation); auto res = e.generateNwcConv(); if (succeeded(res)) return res; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1522,15 +1522,14 @@ /// This unrolls outer-products along the reduction dimension. 
struct UnrolledOuterProductGenerator : public StructuredGenerator { - UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op) - : StructuredGenerator( - builder, op), + UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op) + : StructuredGenerator(b, op), kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()), res(op.getAcc()), lhsType(op.getLhsType()) {} Value t(Value v) { static constexpr std::array perm = {1, 0}; - return builder.create(loc, v, perm); + return rewriter.create(loc, v, perm); } Value promote(Value v, Type dstElementType) { @@ -1544,20 +1543,20 @@ if (vecType) promotedType = VectorType::get(vecType.getShape(), promotedType); if (dstElementType.isa()) - return builder.create(loc, promotedType, v); - return builder.create(loc, promotedType, v); + return rewriter.create(loc, promotedType, v); + return rewriter.create(loc, promotedType, v); } Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) { assert(reductionSize > 0); Type resElementType = res.getType().cast().getElementType(); for (int64_t k = 0; k < reductionSize; ++k) { - Value a = builder.create(loc, lhs, k); - Value b = builder.create(loc, rhs, k); - a = promote(a, resElementType); - b = promote(b, resElementType); - res = builder.create(loc, res.getType(), a, b, - res, kind); + Value extractA = rewriter.create(loc, lhs, k); + Value extractB = rewriter.create(loc, rhs, k); + extractA = promote(extractA, resElementType); + extractB = promote(extractB, resElementType); + res = rewriter.create(loc, res.getType(), extractA, + extractB, res, kind); } return res; } @@ -1568,7 +1567,7 @@ return failure(); // Set up the parallel/reduction structure in the right form. AffineExpr m, n, k; - bindDims(builder.getContext(), m, n, k); + bindDims(rewriter.getContext(), m, n, k); // Classical row-major matmul: Just permute the lhs. 
if (layout({{m, k}, {k, n}, {m, n}})) return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); @@ -1604,7 +1603,7 @@ if (!iters({Par(), Red()})) return failure(); AffineExpr m, k; - bindDims(builder.getContext(), m, k); + bindDims(rewriter.getContext(), m, k); // Case mat-vec: transpose. if (layout({{m, k}, {k}, {m}})) @@ -1628,7 +1627,7 @@ if (!iters({Red(), Par()})) return failure(); AffineExpr k, m; - bindDims(builder.getContext(), k, m); + bindDims(rewriter.getContext(), k, m); // Case mat-vec: transpose. if (layout({{m, k}, {k}, {m}}))