diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -284,6 +284,67 @@
   newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
 }
 
+namespace lowering_n_d {
+
+/// Helper data structure for data and mask buffers.
+struct BufferAllocs {
+  /// Rank-0 memref that holds the transferred vector.
+  Value dataBuffer;
+  /// Mask value re-loaded from its own rank-0 memref; only set when the
+  /// transfer op has a mask.
+  Value maskBuffer;
+};
+
+/// Allocate temporary buffers for data (vector) and mask (if present).
+///
+/// The allocas are created at the start of the first block of region 0 of an
+/// enclosing op with the AutomaticAllocationScope trait. The mask is stored
+/// and re-loaded right before `xferOp`, so that these ops are dominated by
+/// the definition of the mask.
+/// TODO: Parallelism and threadlocal considerations.
+template <typename OpTy>
+static BufferAllocs allocBuffers(OpTy xferOp) {
+  auto &b = ScopedContext::getBuilderRef();
+  OpBuilder::InsertionGuard guard(b);
+  Operation *scope =
+      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+  assert(scope && "Expected op to be inside automatic allocation scope");
+  b.setInsertionPointToStart(&scope->getRegion(0).front());
+
+  BufferAllocs result;
+  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
+  result.dataBuffer = memref_alloca(bufferType).value;
+
+  if (xferOp.mask()) {
+    auto maskType = MemRefType::get({}, xferOp.mask().getType());
+    auto maskBuffer = memref_alloca(maskType).value;
+    // The mask may be defined anywhere before `xferOp`, so only the alloca is
+    // hoisted; the store/load pair must not precede the mask's definition.
+    b.setInsertionPoint(xferOp);
+    memref_store(xferOp.mask(), maskBuffer);
+    result.maskBuffer = memref_load(maskBuffer);
+  }
+
+  return result;
+}
+
+/// Given a MemRefType with VectorType element type, unpack one dimension from
+/// the VectorType into the MemRefType.
+///
+/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
+static MemRefType unpackOneDim(MemRefType type) {
+  // `cast` asserts on a non-vector element type instead of yielding a null
+  // type that would be dereferenced below.
+  auto vectorType = type.getElementType().cast<VectorType>();
+  auto memrefShape = type.getShape();
+  SmallVector<int64_t, 8> newMemrefShape;
+  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
+  newMemrefShape.push_back(vectorType.getDimSize(0));
+  return MemRefType::get(newMemrefShape,
+                         VectorType::get(vectorType.getShape().drop_front(),
+                                         vectorType.getElementType()));
+}
+
 /// Given a transfer op, find the memref from which the mask is loaded. This
 /// is similar to Strategy<TransferWriteOp>::getBuffer.
 template <typename OpTy>
@@ -688,6 +749,10 @@
   }
 };
 
+} // namespace lowering_n_d
+
+namespace lowering_n_d_unrolled {
+
 /// If the original transfer op has a mask, compute the mask of the new transfer
 /// op (for the current iteration `i`) and assign it.
 template <typename OpTy>
@@ -959,6 +1024,10 @@
   }
 };
 
+} // namespace lowering_n_d_unrolled
+
+namespace lowering_1_d {
+
 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
 /// part of TransferOp1dConversion. Return the memref dimension on which
 /// the transfer is operating. A return value of None indicates a broadcast.
@@ -1119,6 +1188,7 @@
   }
 };
 
+} // namespace lowering_1_d
 } // namespace
 
 namespace mlir {
@@ -1126,19 +1196,21 @@
 void populateVectorToSCFConversionPatterns(
     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
   if (options.unroll) {
-    patterns.add<UnrollTransferReadConversion, UnrollTransferWriteConversion>(
+    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
+                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
         patterns.getContext(), options);
   } else {
-    patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
-                 TransferOpConversion<TransferReadOp>,
-                 TransferOpConversion<TransferWriteOp>>(patterns.getContext(),
-                                                        options);
+    patterns.add<lowering_n_d::PrepareTransferReadConversion,
+                 lowering_n_d::PrepareTransferWriteConversion,
+                 lowering_n_d::TransferOpConversion<TransferReadOp>,
+                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
+        patterns.getContext(), options);
   }
 
   if (options.targetRank == 1) {
-    patterns.add<TransferOp1dConversion<TransferReadOp>,
-                 TransferOp1dConversion<TransferWriteOp>>(patterns.getContext(),
-                                                          options);
+    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
+                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
+        patterns.getContext(), options);
   }
 }
 