diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -27,7 +27,7 @@ namespace mlir { -class OpBuilder; +class RewriterBase; /// Tests whether the given maps describe a row major matmul. The test is /// permutation-invariant. Note that this only checks the affine maps from an @@ -79,8 +79,8 @@ Red() : IteratorType(IteratorTypeT::reduction) {} }; - StructuredGenerator(OpBuilder &builder, StructuredOpInterface op) - : builder(builder), ctx(op.getContext()), loc(op.getLoc()), + StructuredGenerator(RewriterBase &b, StructuredOpInterface op) + : b(b), ctx(op.getContext()), loc(op.getLoc()), iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()), op(op) {} @@ -100,7 +100,7 @@ } protected: - OpBuilder &builder; + RewriterBase &b; MLIRContext *ctx; Location loc; SmallVector iterators; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -46,7 +46,7 @@ #define LDBG(X) LLVM_DEBUG(DBGS() << X) /// Try to vectorize `convOp` as a convolution. -static FailureOr vectorizeConvolution(OpBuilder &b, +static FailureOr vectorizeConvolution(RewriterBase &b, LinalgOp convOp); /// Return the unique instance of OpType in `block` if it is indeed unique. @@ -163,7 +163,7 @@ /// Broadcast `value` to a vector of `shape` if possible. Return value /// otherwise. -static Value broadcastIfNeeded(OpBuilder &b, Value value, +static Value broadcastIfNeeded(RewriterBase &b, Value value, ArrayRef shape) { // If no shape to broadcast to, just return `value`. if (shape.empty()) @@ -180,7 +180,7 @@ /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This /// assumes that `reductionOp` has two operands and one of them is the reduction /// initial value. -static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, +static Operation *buildMultiDimReduce(RewriterBase &b, Operation *reduceOp, Value valueToReduce, Value acc, const SmallVector &reductionMask) { auto maybeKind = getCombinerOpKind(reduceOp); @@ -198,7 +198,7 @@ /// to all `0`; where `outputOperand` is an output operand of the LinalgOp /// currently being vectorized. If `dest` has null rank, build an memref.store. /// Return the produced value or null if no value is produced. -static Value buildVectorWrite(OpBuilder &b, Value value, +static Value buildVectorWrite(RewriterBase &b, Value value, OpOperand *outputOperand) { Operation *write; Location loc = value.getLoc(); @@ -233,7 +233,7 @@ } // Custom vectorization precondition function type. This is intented to be used -// with CustomVectorizationHook. Returns success if the correpsonding custom +// with CustomVectorizationHook. Returns success if the corresponding custom // hook can vectorize the op. using CustomVectorizationPrecondition = std::function; @@ -248,11 +248,11 @@ /// vector values are appended to `newResults`. Return /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it /// should not try to map produced operations and instead return the results -/// using the `newResults` vector making them available to the -/// vectorization algorithm for RAUW. This function is meant to be used as a +/// using the `newResults` vector making them available to the vectorization +/// algorithm for RAUW. This function is meant to be used as a /// CustomVectorizationHook. static VectorizationResult -vectorizeLinalgYield(OpBuilder &b, Operation *op, +vectorizeLinalgYield(RewriterBase &b, Operation *op, const BlockAndValueMapping &bvm, LinalgOp linalgOp, SmallVectorImpl &newResults) { auto yieldOp = dyn_cast(op); @@ -274,7 +274,7 @@ /// VectorizationStatus::NewOp to signal the vectorization algorithm that it /// should map the produced operations. This function is meant to be used as a /// CustomVectorizationHook. -static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op, +static VectorizationResult vectorizeLinalgIndex(RewriterBase &b, Operation *op, LinalgOp linalgOp) { IndexOp indexOp = dyn_cast(op); if (!indexOp) @@ -334,7 +334,7 @@ /// should map the produced operations. This function is meant to be used as a /// CustomVectorizationHook. static VectorizationResult -vectorizeTensorExtract(OpBuilder &b, Operation *op, LinalgOp linalgOp, +vectorizeTensorExtract(RewriterBase &b, Operation *op, LinalgOp linalgOp, const BlockAndValueMapping &bvm) { tensor::ExtractOp extractOp = dyn_cast(op); if (!extractOp) @@ -371,8 +371,9 @@ /// Emit reduction operations if the shapes of the value to reduce is different /// that the result shape. -static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, - Value reduceValue, Value initialValue, +static Operation *reduceIfNeeded(RewriterBase &b, LinalgOp linalgOp, + Operation *op, Value reduceValue, + Value initialValue, const BlockAndValueMapping &bvm) { Value reduceVec = bvm.lookup(reduceValue); Value outputVec = bvm.lookup(initialValue); @@ -402,12 +403,12 @@ /// otherwise, it means one of the `customVectorizationHooks` is incorrect. /// /// This function assumes all operands of `op` have been vectorized and are in -/// the `bvm` mapping. As a consequence, this function is meant to be called on +/// the `bvm` mapping. As a consequence, this function is meant to be called on /// a topologically-sorted list of ops. /// This function does not update `bvm` but returns a VectorizationStatus that /// instructs the caller what `bvm` update needs to occur. static VectorizationResult -vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op, +vectorizeOneOp(RewriterBase &b, LinalgOp linalgOp, Operation *op, const BlockAndValueMapping &bvm, ArrayRef customVectorizationHooks) { LDBG("vectorize op " << *op); @@ -506,7 +507,7 @@ /// This is not deemed a problem as we expect canonicalizations and foldings to /// aggressively clean up the useless work. static LogicalResult -vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp, +vectorizeAsLinalgGeneric(RewriterBase &b, LinalgOp linalgOp, SmallVectorImpl &newResults) { Block *block = linalgOp.getBlock(); @@ -760,14 +761,14 @@ /// Given an ArrayRef of OpFoldResults, return a vector of Values. /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are /// not supported. -static SmallVector ofrToIndexValues(OpBuilder &builder, Location loc, +static SmallVector ofrToIndexValues(RewriterBase &b, Location loc, ArrayRef ofrs) { SmallVector result; for (auto o : ofrs) { if (auto val = o.template dyn_cast()) { result.push_back(val); } else { - result.push_back(builder.create( + result.push_back(b.create( loc, getIntFromAttr(o.template get()))); } } @@ -1422,9 +1423,9 @@ /// kw is unrolled, w is unrolled iff dilationW > 1. struct Conv1DGenerator : public StructuredGenerator { - Conv1DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW, + Conv1DGenerator(RewriterBase &b, LinalgOp linalgOp, int strideW, int dilationW) - : StructuredGenerator(builder, linalgOp), + : StructuredGenerator(b, linalgOp), strideW(strideW), dilationW(dilationW) { // Determine whether `linalgOp` can be generated with this generator if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1) @@ -1525,7 +1526,7 @@ } vector::TransferWriteOp write; - Value zero = builder.create(loc, 0); + Value zero = b.create(loc, 0); // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. // When strideW == 1, we can batch the contiguous loads and avoid @@ -1540,14 +1541,14 @@ auto resType = VectorType::get(resShape, resEltType); // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0, // 0]. - Value lhs = builder.create( - loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}); + Value lhs = b.create(loc, lhsType, lhsShaped, + ValueRange{zero, zero, zero}); // Read rhs slice of size {kw, c, f} @ [0, 0, 0]. - Value rhs = builder.create( - loc, rhsType, rhsShaped, ValueRange{zero, zero, zero}); + Value rhs = b.create(loc, rhsType, rhsShaped, + ValueRange{zero, zero, zero}); // Read res slice of size {n, w, f} @ [0, 0, 0]. - Value res = builder.create( - loc, resType, resShaped, ValueRange{zero, zero, zero}); + Value res = b.create(loc, resType, resShaped, + ValueRange{zero, zero, zero}); // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output: // {n,w,f}. To reuse the base pattern vectorization case, we do pre @@ -1560,13 +1561,13 @@ // To match base vectorization case, we pre-transpose current case. // ncw -> nwc static constexpr std::array permLhs = {0, 2, 1}; - lhs = builder.create(loc, lhs, permLhs); + lhs = b.create(loc, lhs, permLhs); // fcw -> wcf static constexpr std::array permRhs = {2, 1, 0}; - rhs = builder.create(loc, rhs, permRhs); + rhs = b.create(loc, rhs, permRhs); // nfw -> nwf static constexpr std::array permRes = {0, 2, 1}; - res = builder.create(loc, res, permRes); + res = b.create(loc, res, permRes); break; } } @@ -1579,7 +1580,7 @@ // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]. for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - lhsVals.push_back(builder.create( + lhsVals.push_back(b.create( loc, lhs, /*offsets=*/ArrayRef{0, w * strideW + kw * dilationW, 0}, /*sizes=*/ArrayRef{nSize, wSizeStep, cSize}, @@ -1588,12 +1589,12 @@ } // Extract rhs slice of size {c, f} @ [kw]. for (int64_t kw = 0; kw < kwSize; ++kw) { - rhsVals.push_back(builder.create( + rhsVals.push_back(b.create( loc, rhs, /*offsets=*/ArrayRef{kw})); } // Extract res slice: {n, wSizeStep, f} @ [0, w, 0]. for (int64_t w = 0; w < wSize; w += wSizeStep) { - resVals.push_back(builder.create( + resVals.push_back(b.create( loc, res, /*offsets=*/ArrayRef{0, w, 0}, /*sizes=*/ArrayRef{nSize, wSizeStep, fSize}, @@ -1608,14 +1609,14 @@ for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { resVals[w] = conv1dSliceAsContraction( - builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); + b, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); } } // Write back res slice: {n, wSizeStep, f} @ [0, w, 0]. // This does not depend on kw. for (int64_t w = 0; w < wSize; w += wSizeStep) { - res = builder.create( + res = b.create( loc, resVals[w], res, /*offsets=*/ArrayRef{0, w, 0}, /*strides=*/ArrayRef{1, 1, 1}); @@ -1634,26 +1635,26 @@ case Conv1DOpOrder::Ncw: { // nwf -> nfw static constexpr std::array perm = {0, 2, 1}; - res = builder.create(loc, res, perm); + res = b.create(loc, res, perm); break; } } // Write back res slice of size {n, w, f} @ [0, 0, 0]. - return builder + return b .create(loc, res, resShaped, ValueRange{zero, zero, zero}) .getOperation(); } // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f} - Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs, + Value conv1dSliceAsContraction(RewriterBase &b, Location loc, Value lhs, Value rhs, Value res) { vector::IteratorType par = vector::IteratorType::parallel; vector::IteratorType red = vector::IteratorType::reduction; AffineExpr n, w, f, c; bindDims(ctx, n, w, f, c); - return builder.create( + return b.create( loc, lhs, rhs, res, /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}}, /*iteratorTypes=*/ArrayRef{par, par, par, red}); @@ -1679,7 +1680,7 @@ bindShapeDims(resShapedType, nSize, wSize); vector::TransferWriteOp write; - Value zero = builder.create(loc, 0); + Value zero = b.create(loc, 0); // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. // When strideW == 1, we can batch the contiguous loads and avoid @@ -1701,14 +1702,14 @@ // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, // 0]. - Value lhs = builder.create( - loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}); + Value lhs = b.create(loc, lhsType, lhsShaped, + ValueRange{zero, zero, zero}); // Read rhs slice of size {kw, c} @ [0, 0]. - Value rhs = builder.create(loc, rhsType, rhsShaped, - ValueRange{zero, zero}); + Value rhs = b.create(loc, rhsType, rhsShaped, + ValueRange{zero, zero}); // Read res slice of size {n, w, c} @ [0, 0, 0]. - Value res = builder.create( - loc, resType, resShaped, ValueRange{zero, zero, zero}); + Value res = b.create(loc, resType, resShaped, + ValueRange{zero, zero, zero}); //===------------------------------------------------------------------===// // Begin vector-only rewrite part @@ -1719,7 +1720,7 @@ // @ [0, sw * w + dw * kw, 0]. for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - lhsVals.push_back(builder.create( + lhsVals.push_back(b.create( loc, lhs, /*offsets=*/ArrayRef{0, w * strideW + kw * dilationW, 0}, /*sizes=*/ArrayRef{nSize, wSizeStep, cSize}, @@ -1728,12 +1729,12 @@ } // Extract rhs slice of size {c} @ [kw]. for (int64_t kw = 0; kw < kwSize; ++kw) { - rhsVals.push_back(builder.create( + rhsVals.push_back(b.create( loc, rhs, /*offsets=*/ArrayRef{kw})); } // Extract res slice: {n, wSizeStep, c} @ [0, w, 0]. for (int64_t w = 0; w < wSize; w += wSizeStep) { - resVals.push_back(builder.create( + resVals.push_back(b.create( loc, res, /*offsets=*/ArrayRef{0, w, 0}, /*sizes=*/ArrayRef{nSize, wSizeStep, cSize}, @@ -1748,7 +1749,7 @@ for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { resVals[w] = depthwiseConv1dSliceAsMulAcc( - builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); + b, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]); } } @@ -1761,7 +1762,7 @@ // Write back res slice: {n, wSizeStep, c} @ [0, w, 0]. // This does not depend on kw. for (int64_t w = 0; w < wSize; w += wSizeStep) { - res = builder.create( + res = b.create( loc, resVals[w], res, /*offsets=*/ArrayRef{0, w, 0}, /*strides=*/ArrayRef{1, 1, 1}); @@ -1771,14 +1772,14 @@ //===------------------------------------------------------------------===// // Write back res slice of size {n, w, c} @ [0, 0, 0]. - return builder + return b .create(loc, res, resShaped, ValueRange{zero, zero, zero}) .getOperation(); } // Take a value of element type T and widen to the destination type. - Value promote(OpBuilder &b, Location loc, Value val, Type ty) { + Value promote(RewriterBase &b, Location loc, Value val, Type ty) { if (val.getType() == ty) return val; @@ -1787,16 +1788,16 @@ const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth(); if (getElementTypeOrSelf(ty).isa() && srcWidth < destWidth) - return builder.create(loc, ty, val); + return b.create(loc, ty, val); if (getElementTypeOrSelf(ty).isa() && srcWidth < destWidth) - return builder.create(loc, ty, val); + return b.create(loc, ty, val); return nullptr; } /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc - Value depthwiseConv1dSliceAsMulAcc(OpBuilder &b, Location loc, Value lhs, + Value depthwiseConv1dSliceAsMulAcc(RewriterBase &b, Location loc, Value lhs, Value rhs, Value res) { auto rhsTy = rhs.getType().cast(); auto resTy = res.getType().cast(); @@ -1804,7 +1805,7 @@ // TODO(suderman): Change this to use a vector.ima intrinsic. lhs = promote(b, loc, lhs, resTy); - rhs = builder.create( + rhs = b.create( loc, resTy.clone(rhsTy.getElementType()), rhs); rhs = promote(b, loc, rhs, resTy); @@ -1876,7 +1877,8 @@ /// Helper function to vectorize a LinalgOp with convolution semantics. // TODO: extend the generic vectorization to support windows and drop this. -static FailureOr vectorizeConvolution(OpBuilder &b, LinalgOp op) { +static FailureOr vectorizeConvolution(RewriterBase &b, + LinalgOp op) { // The ConvolutionOpInterface gives us guarantees of existence for // strides/dilations. However, we do not need to rely on those, we can simply // use them if present, otherwise use the default and let the generic conv. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1523,15 +1523,14 @@ /// This unrolls outer-products along the reduction dimension. struct UnrolledOuterProductGenerator : public StructuredGenerator { - UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op) - : StructuredGenerator( - builder, op), + UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op) + : StructuredGenerator(b, op), kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()), res(op.getAcc()), lhsType(op.getLhsType()) {} Value t(Value v) { static constexpr std::array perm = {1, 0}; - return builder.create(loc, v, perm); + return b.create(loc, v, perm); } Value promote(Value v, Type dstElementType) { @@ -1545,20 +1544,20 @@ if (vecType) promotedType = VectorType::get(vecType.getShape(), promotedType); if (dstElementType.isa()) - return builder.create(loc, promotedType, v); - return builder.create(loc, promotedType, v); + return b.create(loc, promotedType, v); + return b.create(loc, promotedType, v); } Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) { assert(reductionSize > 0); Type resElementType = res.getType().cast().getElementType(); for (int64_t k = 0; k < reductionSize; ++k) { - Value a = builder.create(loc, lhs, k); - Value b = builder.create(loc, rhs, k); - a = promote(a, resElementType); - b = promote(b, resElementType); - res = builder.create(loc, res.getType(), a, b, - res, kind); + Value extractA = b.create(loc, lhs, k); + Value extractB = b.create(loc, rhs, k); + extractA = promote(extractA, resElementType); + extractB = promote(extractB, resElementType); + res = b.create(loc, res.getType(), extractA, + extractB, res, kind); } return res; } @@ -1569,7 +1568,7 @@ return failure(); // Set up the parallel/reduction structure in the right form. AffineExpr m, n, k; - bindDims(builder.getContext(), m, n, k); + bindDims(b.getContext(), m, n, k); // Classical row-major matmul: Just permute the lhs. if (layout({{m, k}, {k, n}, {m, n}})) return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); @@ -1605,7 +1604,7 @@ if (!iters({Par(), Red()})) return failure(); AffineExpr m, k; - bindDims(builder.getContext(), m, k); + bindDims(b.getContext(), m, k); // Case mat-vec: transpose. if (layout({{m, k}, {k}, {m}})) @@ -1629,7 +1628,7 @@ if (!iters({Red(), Par()})) return failure(); AffineExpr k, m; - bindDims(builder.getContext(), k, m); + bindDims(b.getContext(), k, m); // Case mat-vec: transpose. if (layout({{m, k}, {k}, {m}}))