diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -27,7 +27,7 @@
 
 namespace mlir {
 
-class OpBuilder;
+class RewriterBase;
 
 /// Tests whether the given maps describe a row major matmul. The test is
 /// permutation-invariant. Note that this only checks the affine maps from an
@@ -79,8 +79,8 @@
     Red() : IteratorType(IteratorTypeT::reduction) {}
   };
 
-  StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
-      : builder(builder), ctx(op.getContext()), loc(op.getLoc()),
+  StructuredGenerator(RewriterBase &b, StructuredOpInterface op)
+      : b(b), ctx(op.getContext()), loc(op.getLoc()),
         iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()),
         op(op) {}
 
@@ -100,7 +100,7 @@
   }
 
 protected:
-  OpBuilder &builder;
+  RewriterBase &b;
   MLIRContext *ctx;
   Location loc;
   SmallVector<IteratorTypeT> iterators;
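Note on the API change above: `RewriterBase` is the common base of `PatternRewriter`, so generators built on `StructuredGenerator` can now be driven directly from rewrite patterns, and any IR they create through the member `b` is visible to the rewriter's listener. The sketch below shows the intended kind of call site for the generators changed in the two .cpp files below; it is a hedged illustration, not part of the patch, and `LowerContractionPattern` is a hypothetical name (the in-tree analogue is `ContractionOpToOuterProductOpLowering`).

struct LowerContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    // PatternRewriter is-a RewriterBase, so it can be forwarded as-is.
    UnrolledOuterProductGenerator gen(rewriter, op);
    FailureOr<Value> result = gen.matmul();
    if (failed(result))
      return failure();
    // Replacement goes through the rewriter so the driver is notified.
    rewriter.replaceOp(op, *result);
    return success();
  }
};

Since `RewriterBase` itself derives from `OpBuilder`, the `b.create<...>` call sites inside the generators keep compiling unchanged; only callers that used to pass a plain `OpBuilder` must now provide a rewriter.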
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -46,7 +46,7 @@
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
 /// Try to vectorize `convOp` as a convolution.
-static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
+static FailureOr<Operation *> vectorizeConvolution(RewriterBase &b,
                                                    LinalgOp convOp);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
@@ -163,7 +163,7 @@
 
 /// Broadcast `value` to a vector of `shape` if possible. Return value
 /// otherwise.
-static Value broadcastIfNeeded(OpBuilder &b, Value value,
+static Value broadcastIfNeeded(RewriterBase &b, Value value,
                                ArrayRef<int64_t> shape) {
   // If no shape to broadcast to, just return `value`.
   if (shape.empty())
@@ -180,7 +180,7 @@
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
 /// assumes that `reductionOp` has two operands and one of them is the reduction
 /// initial value.
-static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
+static Operation *buildMultiDimReduce(RewriterBase &b, Operation *reduceOp,
                                       Value valueToReduce, Value acc,
                                       const SmallVector<bool> &reductionMask) {
   auto maybeKind = getCombinerOpKind(reduceOp);
@@ -198,7 +198,7 @@
 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
 /// currently being vectorized. If `dest` has null rank, build an memref.store.
 /// Return the produced value or null if no value is produced.
-static Value buildVectorWrite(OpBuilder &b, Value value,
+static Value buildVectorWrite(RewriterBase &b, Value value,
                               OpOperand *outputOperand) {
   Operation *write;
   Location loc = value.getLoc();
@@ -232,27 +232,27 @@
   return Value();
 }
 
-// Custom vectorization precondition function type. This is intented to be used
-// with CustomVectorizationHook. Returns success if the correpsonding custom
-// hook can vectorize the op.
+// Custom vectorization precondition function type. This is intended to be
+// used with CustomVectorizationHook. Returns success if the corresponding
+// custom hook can vectorize the op.
 using CustomVectorizationPrecondition =
     std::function<LogicalResult(Operation *)>;
 
 // Custom vectorization function type. Produce a vector form of Operation*
-// assuming all its vectorized operands are already in the BlockAndValueMapping.
-// Return nullptr if the Operation cannot be vectorized.
+// assuming all its vectorized operands are already in the
+// BlockAndValueMapping. Return nullptr if the Operation cannot be vectorized.
 using CustomVectorizationHook = std::function<VectorizationResult(
     Operation *, const BlockAndValueMapping &)>;
 
 /// Helper function to vectorize the terminator of a `linalgOp`. New result
 /// vector values are appended to `newResults`. Return
-/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
-/// should not try to map produced operations and instead return the results
-/// using the `newResults` vector making them available to the
+/// VectorizationStatus::NoReplace to signal the vectorization algorithm that
+/// it should not try to map produced operations and instead return the
+/// results using the `newResults` vector making them available to the
 /// vectorization algorithm for RAUW. This function is meant to be used as a
 /// CustomVectorizationHook.
 static VectorizationResult
-vectorizeLinalgYield(OpBuilder &b, Operation *op,
+vectorizeLinalgYield(RewriterBase &b, Operation *op,
                      const BlockAndValueMapping &bvm, LinalgOp linalgOp,
                      SmallVectorImpl<Value> &newResults) {
   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
@@ -274,7 +274,7 @@
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
+static VectorizationResult vectorizeLinalgIndex(RewriterBase &b, Operation *op,
                                                 LinalgOp linalgOp) {
   IndexOp indexOp = dyn_cast<IndexOp>(op);
   if (!indexOp)
@@ -288,8 +288,8 @@
   auto constantOp =
       b.create<arith::ConstantOp>(loc, b.getIndexVectorAttr(constantSeq));
   // Return the one-dimensional index vector if it lives in the trailing
-  // dimension of the iteration space since the vectorization algorithm in this
-  // case can handle the broadcast.
+  // dimension of the iteration space since the vectorization algorithm in
+  // this case can handle the broadcast.
   if (indexOp.getDim() == targetShape.size() - 1)
     return VectorizationResult{VectorizationStatus::NewOp, constantOp};
   // Otherwise permute the targetShape to move the index dimension last,
@@ -334,7 +334,7 @@
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
 static VectorizationResult
-vectorizeTensorExtract(OpBuilder &b, Operation *op, LinalgOp linalgOp,
+vectorizeTensorExtract(RewriterBase &b, Operation *op, LinalgOp linalgOp,
                        const BlockAndValueMapping &bvm) {
   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
   if (!extractOp)
@@ -369,10 +369,11 @@
   return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
 }
 
-/// Emit reduction operations if the shapes of the value to reduce is different
-/// that the result shape.
-static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
-                                 Value reduceValue, Value initialValue,
+/// Emit reduction operations if the shape of the value to reduce is
+/// different from the result shape.
+static Operation *reduceIfNeeded(RewriterBase &b, LinalgOp linalgOp,
+                                 Operation *op, Value reduceValue,
+                                 Value initialValue,
                                  const BlockAndValueMapping &bvm) {
   Value reduceVec = bvm.lookup(reduceValue);
   Value outputVec = bvm.lookup(initialValue);
@@ -387,8 +388,8 @@
   return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask);
 }
 
-/// Generic vectorization for a single operation `op`, given already vectorized
-/// operands carried by `bvm`. Vectorization occurs as follows:
+/// Generic vectorization for a single operation `op`, given already
+/// vectorized operands carried by `bvm`. Vectorization occurs as follows:
 /// 1. Try to apply any of the `customVectorizationHooks` and return its
 ///    result on success.
 /// 2. Clone any constant in the current scope without vectorization: each
@@ -396,18 +397,18 @@
 ///    constant needs to be broadcast to.
 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
 ///    of the `customVectorizationHooks` to cover such cases.
-/// 4. Clone `op` in vector form to a vector of shape prescribed by the first
-///    operand of maximal rank. Other operands have smaller rank and are
+/// 4. Clone `op` in vector form to a vector of shape prescribed by the
+///    first operand of maximal rank. Other operands have smaller rank and are
 ///    broadcast accordingly. It is assumed this broadcast is always legal,
 ///    otherwise, it means one of the `customVectorizationHooks` is incorrect.
 ///
 /// This function assumes all operands of `op` have been vectorized and are in
-/// the `bvm` mapping. As a consequence, this function is meant to be called on
-/// a topologically-sorted list of ops.
-/// This function does not update `bvm` but returns a VectorizationStatus that
-/// instructs the caller what `bvm` update needs to occur.
+/// the `bvm` mapping. As a consequence, this function is meant to be called
+/// on a topologically-sorted list of ops. This function does not update `bvm`
+/// but returns a VectorizationStatus that instructs the caller what `bvm`
+/// update needs to occur.
 static VectorizationResult
-vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
+vectorizeOneOp(RewriterBase &b, LinalgOp linalgOp, Operation *op,
                const BlockAndValueMapping &bvm,
                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
   LDBG("vectorize op " << *op);
@@ -422,8 +423,8 @@
     }
   }
 
-  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
-  // Clone so that the constant is not confined to the linalgOp block .
+  // 2. Constant ops don't get vectorized but rather broadcasted at their
+  // users. Clone so that the constant is not confined to the linalgOp block.
   if (isa<arith::ConstantOp, func::ConstantOp>(op))
     return VectorizationResult{VectorizationStatus::NewOp, b.clone(*op)};
 
@@ -468,7 +469,8 @@
                    ? bvm.lookup(v)
                    : broadcastIfNeeded(b, bvm.lookup(v), firstMaxRankedShape);
       });
-  // c. for elementwise, the result is the vector with the firstMaxRankedShape
+  // c. for elementwise, the result is the vector with the
+  // firstMaxRankedShape
   auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
     return firstMaxRankedShape.empty()
                ? t
@@ -488,8 +490,8 @@
 /// 1. Verify the `linalgOp` has one non-empty region.
 /// 2. Values defined above the region are mapped to themselves and will be
 ///    broadcasted on a per-need basis by their consumers.
-/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
-///    load).
+/// 3. Each region argument is vectorized into a vector.transfer_read (or
+///    0-d load).
 ///    TODO: Reuse opportunities for RAR dependencies.
 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
 /// 4b. Register CustomVectorizationHook for IndexOp to access the iteration
@@ -503,15 +505,15 @@
 /// broadcasting makes it trivial to detrmine where broadcast, transposes and
 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
 /// the absence of good canonicalizations, the amount of work increases.
-/// This is not deemed a problem as we expect canonicalizations and foldings to
-/// aggressively clean up the useless work.
+/// This is not deemed a problem as we expect canonicalizations and foldings
+/// to aggressively clean up the useless work.
 static LogicalResult
-vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
+vectorizeAsLinalgGeneric(RewriterBase &b, LinalgOp linalgOp,
                          SmallVectorImpl<Value> &newResults) {
   Block *block = linalgOp.getBlock();
 
-  // 2. Values defined above the region can only be broadcast for now. Make them
-  // map to themselves.
+  // 2. Values defined above the region can only be broadcast for now. Make
+  // them map to themselves.
   BlockAndValueMapping bvm;
   SetVector<Value> valuesSet;
   mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
@@ -656,9 +658,9 @@
   }
   if (isElementwise(op))
     return success();
-  // TODO: isaConvolutionOpInterface that can also infer from generic features.
-  // But we will still need stride/dilation attributes that will be annoying to
-  // reverse-engineer...
+  // TODO: isaConvolutionOpInterface that can also infer from generic
+  // features. But we will still need stride/dilation attributes that will be
+  // annoying to reverse-engineer...
   if (isa<ConvolutionOpInterface>(op.getOperation()))
     return success();
   // TODO: the common vector shape is equal to the static loop sizes only when
@@ -760,14 +762,14 @@
 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
 /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
 /// not supported.
-static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
+static SmallVector<Value> ofrToIndexValues(RewriterBase &b, Location loc,
                                            ArrayRef<OpFoldResult> ofrs) {
   SmallVector<Value> result;
   for (auto o : ofrs) {
     if (auto val = o.template dyn_cast<Value>()) {
       result.push_back(val);
     } else {
-      result.push_back(builder.create<arith::ConstantIndexOp>(
+      result.push_back(b.create<arith::ConstantIndexOp>(
           loc, getIntFromAttr(o.template get<Attribute>())));
     }
   }
@@ -1422,9 +1424,9 @@
 /// kw is unrolled, w is unrolled iff dilationW > 1.
 struct Conv1DGenerator
     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
-  Conv1DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
+  Conv1DGenerator(RewriterBase &b, LinalgOp linalgOp, int strideW,
                   int dilationW)
-      : StructuredGenerator<LinalgOp, utils::IteratorType>(builder, linalgOp),
+      : StructuredGenerator<LinalgOp, utils::IteratorType>(b, linalgOp),
         strideW(strideW), dilationW(dilationW) {
     // Determine whether `linalgOp` can be generated with this generator
     if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
@@ -1525,7 +1527,7 @@
     }
 
     vector::TransferWriteOp write;
-    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
 
     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
     // When strideW == 1, we can batch the contiguous loads and avoid
@@ -1540,18 +1542,18 @@
     auto resType = VectorType::get(resShape, resEltType);
     // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
     // 0].
-    Value lhs = builder.create<vector::TransferReadOp>(
-        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+    Value lhs = b.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
+                                                 ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
-    Value rhs = builder.create<vector::TransferReadOp>(
-        loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
+    Value rhs = b.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                 ValueRange{zero, zero, zero});
    // Read res slice of size {n, w, f} @ [0, 0, 0].
-    Value res = builder.create<vector::TransferReadOp>(
-        loc, resType, resShaped, ValueRange{zero, zero, zero});
+    Value res = b.create<vector::TransferReadOp>(loc, resType, resShaped,
+                                                 ValueRange{zero, zero, zero});
 
-    // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output:
-    // {n,w,f}. To reuse the base pattern vectorization case, we do pre
-    // transpose on input, weight, and output.
+    // The base vectorization case is input: {n,w,c}, weight: {kw,c,f},
+    // output: {n,w,f}. To reuse the base pattern vectorization case, we
+    // pre-transpose the input, weight, and output.
     switch (conv1DOpOrder) {
     case Conv1DOpOrder::Nwc:
       // Base case, so no transposes necessary.
@@ -1560,13 +1562,13 @@
       // To match base vectorization case, we pre-transpose current case.
       // ncw -> nwc
       static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
-      lhs = builder.create<vector::TransposeOp>(loc, lhs, permLhs);
+      lhs = b.create<vector::TransposeOp>(loc, lhs, permLhs);
       // fcw -> wcf
       static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
-      rhs = builder.create<vector::TransposeOp>(loc, rhs, permRhs);
+      rhs = b.create<vector::TransposeOp>(loc, rhs, permRhs);
       // nfw -> nwf
       static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
-      res = builder.create<vector::TransposeOp>(loc, res, permRes);
+      res = b.create<vector::TransposeOp>(loc, res, permRes);
       break;
     }
     }
@@ -1579,7 +1581,7 @@
     // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+        lhsVals.push_back(b.create<vector::ExtractStridedSliceOp>(
             loc, lhs,
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
@@ -1588,12 +1590,12 @@
     }
     // Extract rhs slice of size {c, f} @ [kw].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
-      rhsVals.push_back(builder.create<vector::ExtractOp>(
+      rhsVals.push_back(b.create<vector::ExtractOp>(
           loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
     }
     // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+      resVals.push_back(b.create<vector::ExtractStridedSliceOp>(
          loc, res,
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
@@ -1608,14 +1610,14 @@
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         resVals[w] = conv1dSliceAsContraction(
-            builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
+            b, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
       }
     }
 
     // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      res = builder.create<vector::InsertStridedSliceOp>(
+      res = b.create<vector::InsertStridedSliceOp>(
           loc, resVals[w], res,
           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
@@ -1634,26 +1636,26 @@
     case Conv1DOpOrder::Ncw: {
       // nwf -> nfw
       static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
-      res = builder.create<vector::TransposeOp>(loc, res, perm);
+      res = b.create<vector::TransposeOp>(loc, res, perm);
       break;
     }
     }
 
     // Write back res slice of size {n, w, f} @ [0, 0, 0].
-    return builder
+    return b
         .create<vector::TransferWriteOp>(loc, res, resShaped,
                                          ValueRange{zero, zero, zero})
         .getOperation();
   }
 
   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
-  Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs,
+  Value conv1dSliceAsContraction(RewriterBase &b, Location loc, Value lhs,
                                  Value rhs, Value res) {
     vector::IteratorType par = vector::IteratorType::parallel;
     vector::IteratorType red = vector::IteratorType::reduction;
     AffineExpr n, w, f, c;
     bindDims(ctx, n, w, f, c);
-    return builder.create<vector::ContractionOp>(
+    return b.create<vector::ContractionOp>(
         loc, lhs, rhs, res,
         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
@@ -1679,7 +1681,7 @@
     bindShapeDims(resShapedType, nSize, wSize);
 
     vector::TransferWriteOp write;
-    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
 
     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
     // When strideW == 1, we can batch the contiguous loads and avoid
@@ -1701,14 +1703,14 @@
 
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
-    Value lhs = builder.create<vector::TransferReadOp>(
-        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+    Value lhs = b.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
+                                                 ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c} @ [0, 0].
-    Value rhs = builder.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
-                                                       ValueRange{zero, zero});
+    Value rhs = b.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                 ValueRange{zero, zero});
     // Read res slice of size {n, w, c} @ [0, 0, 0].
-    Value res = builder.create<vector::TransferReadOp>(
-        loc, resType, resShaped, ValueRange{zero, zero, zero});
+    Value res = b.create<vector::TransferReadOp>(loc, resType, resShaped,
+                                                 ValueRange{zero, zero, zero});
 
     //===------------------------------------------------------------------===//
     // Begin vector-only rewrite part
@@ -1719,7 +1721,7 @@
     // @ [0, sw * w + dw * kw, 0].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+        lhsVals.push_back(b.create<vector::ExtractStridedSliceOp>(
             loc, lhs,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
@@ -1728,12 +1730,12 @@
     }
     // Extract rhs slice of size {c} @ [kw].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
-      rhsVals.push_back(builder.create<vector::ExtractOp>(
+      rhsVals.push_back(b.create<vector::ExtractOp>(
           loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
     }
     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+      resVals.push_back(b.create<vector::ExtractStridedSliceOp>(
          loc, res,
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
@@ -1748,7 +1750,7 @@
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         resVals[w] = depthwiseConv1dSliceAsMulAcc(
-            builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
+            b, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
       }
     }
 
@@ -1761,7 +1763,7 @@
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      res = builder.create<vector::InsertStridedSliceOp>(
+      res = b.create<vector::InsertStridedSliceOp>(
           loc, resVals[w], res,
           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
@@ -1771,14 +1773,14 @@
     //===------------------------------------------------------------------===//
 
     // Write back res slice of size {n, w, c} @ [0, 0, 0].
-    return builder
+    return b
         .create<vector::TransferWriteOp>(loc, res, resShaped,
                                          ValueRange{zero, zero, zero})
         .getOperation();
   }
 
   // Take a value of element type T and widen to the destination type.
-  Value promote(OpBuilder &b, Location loc, Value val, Type ty) {
+  Value promote(RewriterBase &b, Location loc, Value val, Type ty) {
     if (val.getType() == ty)
       return val;
 
@@ -1787,16 +1789,16 @@
     const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth();
 
     if (getElementTypeOrSelf(ty).isa<FloatType>() && srcWidth < destWidth)
-      return builder.create<arith::ExtFOp>(loc, ty, val);
+      return b.create<arith::ExtFOp>(loc, ty, val);
 
     if (getElementTypeOrSelf(ty).isa<IntegerType>() && srcWidth < destWidth)
-      return builder.create<arith::ExtSIOp>(loc, ty, val);
+      return b.create<arith::ExtSIOp>(loc, ty, val);
 
     return nullptr;
   }
 
   /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
-  Value depthwiseConv1dSliceAsMulAcc(OpBuilder &b, Location loc, Value lhs,
+  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &b, Location loc, Value lhs,
                                      Value rhs, Value res) {
     auto rhsTy = rhs.getType().cast<ShapedType>();
     auto resTy = res.getType().cast<ShapedType>();
@@ -1804,7 +1806,7 @@
 
     // TODO(suderman): Change this to use a vector.ima intrinsic.
     lhs = promote(b, loc, lhs, resTy);
 
-    rhs = builder.create<vector::BroadcastOp>(
+    rhs = b.create<vector::BroadcastOp>(
         loc, resTy.clone(rhsTy.getElementType()), rhs);
 
     rhs = promote(b, loc, rhs, resTy);
@@ -1876,11 +1878,12 @@
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
+static FailureOr<Operation *> vectorizeConvolution(RewriterBase &b,
+                                                   LinalgOp op) {
   // The ConvolutionOpInterface gives us guarantees of existence for
-  // strides/dilations. However, we do not need to rely on those, we can simply
-  // use them if present, otherwise use the default and let the generic conv.
-  // matcher in the ConvGenerator succeed or fail.
+  // strides/dilations. However, we do not need to rely on those, we can
+  // simply use them if present, otherwise use the default and let the generic
+  // conv. matcher in the ConvGenerator succeed or fail.
   auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
   auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
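For readers of the Conv1DGenerator hunks above: the offset `w * strideW + kw * dilationW` used by the extract_strided_slice ops is ordinary strided, dilated 1-D convolution indexing, with the `kw` loop fully unrolled and the `w` loop unrolled in `wSizeStep` chunks. A scalar reference model of the same decomposition (an illustrative sketch with our own names, not code from the patch):

#include <cstddef>
#include <vector>

// Computes res[w] = sum_kw lhs[w * strideW + kw * dilationW] * rhs[kw],
// mirroring the kw-outer / w-inner accumulation order of the generator.
std::vector<float> conv1dRef(const std::vector<float> &lhs,  // input
                             const std::vector<float> &rhs,  // kernel
                             int strideW, int dilationW) {
  int lhsSize = static_cast<int>(lhs.size());
  int kwSize = static_cast<int>(rhs.size());
  int wSize = (lhsSize - dilationW * (kwSize - 1) - 1) / strideW + 1;
  std::vector<float> res(static_cast<std::size_t>(wSize), 0.0f);
  for (int kw = 0; kw < kwSize; ++kw)
    for (int w = 0; w < wSize; ++w)
      res[w] += lhs[w * strideW + kw * dilationW] * rhs[kw];
  return res;
}

The regular path contracts each lhs/rhs slice pair with a `vector.contract`; the depthwise path lowers the same slice pairs to multiply-accumulate instead, but the index arithmetic is identical.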
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1523,15 +1523,14 @@
 /// This unrolls outer-products along the reduction dimension.
 struct UnrolledOuterProductGenerator
     : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
-  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
-      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(
-            builder, op),
+  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
+      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
         kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
         res(op.getAcc()), lhsType(op.getLhsType()) {}
 
   Value t(Value v) {
     static constexpr std::array<int64_t, 2> perm = {1, 0};
-    return builder.create<vector::TransposeOp>(loc, v, perm);
+    return b.create<vector::TransposeOp>(loc, v, perm);
   }
 
   Value promote(Value v, Type dstElementType) {
@@ -1545,20 +1544,20 @@
     if (vecType)
       promotedType = VectorType::get(vecType.getShape(), promotedType);
     if (dstElementType.isa<FloatType>())
-      return builder.create<arith::ExtFOp>(loc, promotedType, v);
-    return builder.create<arith::ExtSIOp>(loc, promotedType, v);
+      return b.create<arith::ExtFOp>(loc, promotedType, v);
+    return b.create<arith::ExtSIOp>(loc, promotedType, v);
   }
 
   Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
     assert(reductionSize > 0);
     Type resElementType = res.getType().cast<VectorType>().getElementType();
     for (int64_t k = 0; k < reductionSize; ++k) {
-      Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
-      Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
-      a = promote(a, resElementType);
-      b = promote(b, resElementType);
-      res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
-                                                   res, kind);
+      Value extractA = b.create<vector::ExtractOp>(loc, lhs, k);
+      Value extractB = b.create<vector::ExtractOp>(loc, rhs, k);
+      extractA = promote(extractA, resElementType);
+      extractB = promote(extractB, resElementType);
+      res = b.create<vector::OuterProductOp>(loc, res.getType(), extractA,
+                                             extractB, res, kind);
     }
     return res;
   }
@@ -1569,7 +1568,7 @@
       return failure();
     // Set up the parallel/reduction structure in the right form.
     AffineExpr m, n, k;
-    bindDims(builder.getContext(), m, n, k);
+    bindDims(b.getContext(), m, n, k);
 
     // Classical row-major matmul: Just permute the lhs.
     if (layout({{m, k}, {k, n}, {m, n}}))
       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
@@ -1605,7 +1604,7 @@
     if (!iters({Par(), Red()}))
       return failure();
     AffineExpr m, k;
-    bindDims(builder.getContext(), m, k);
+    bindDims(b.getContext(), m, k);
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
@@ -1629,7 +1628,7 @@
     if (!iters({Red(), Par()}))
      return failure();
     AffineExpr k, m;
-    bindDims(builder.getContext(), k, m);
+    bindDims(b.getContext(), k, m);
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
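Closing note on the UnrolledOuterProductGenerator hunks: the rename of the locals `a`/`b` to `extractA`/`extractB` is forced by the base-class member rename, since the extracted rhs slice would otherwise shadow the rewriter `b`; the computation itself is unchanged. What `outerProd` emits is a contraction decomposed along the reduction dimension K into K rank-1 updates, one `vector.outerproduct` per `k`. A plain C++ analogue (our illustrative sketch, not library code):

#include <array>
#include <cstddef>

// C += A * B as K rank-1 updates: extracting "slice k" must yield A's k-th
// column and B's k-th row, which is why matmul() transposes the lhs for the
// row-major {m, k} layout before calling outerProd(t(lhs), rhs, res, K).
template <std::size_t M, std::size_t N, std::size_t K>
void matmulAsOuterProducts(const std::array<std::array<float, K>, M> &a,
                           const std::array<std::array<float, N>, K> &b,
                           std::array<std::array<float, N>, M> &c) {
  for (std::size_t k = 0; k < K; ++k)  // one vector.outerproduct per k
    for (std::size_t m = 0; m < M; ++m)
      for (std::size_t n = 0; n < N; ++n)
        c[m][n] += a[m][k] * b[k][n];  // rank-1 update: c += A[:,k] * B[k,:]
}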