diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -49,52 +49,6 @@
   VectorTransferToSCFOptions options;
 };
 
-/// Given a MemRefType with VectorType element type, unpack one dimension from
-/// the VectorType into the MemRefType.
-///
-/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
-static MemRefType unpackOneDim(MemRefType type) {
-  auto vectorType = type.getElementType().dyn_cast<VectorType>();
-  auto memrefShape = type.getShape();
-  SmallVector<int64_t, 8> newMemrefShape;
-  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
-  newMemrefShape.push_back(vectorType.getDimSize(0));
-  return MemRefType::get(newMemrefShape,
-                         VectorType::get(vectorType.getShape().drop_front(),
-                                         vectorType.getElementType()));
-}
-
-/// Helper data structure for data and mask buffers.
-struct BufferAllocs {
-  Value dataBuffer;
-  Value maskBuffer;
-};
-
-/// Allocate temporary buffers for data (vector) and mask (if present).
-/// TODO: Parallelism and threadlocal considerations.
-template <typename OpTy>
-static BufferAllocs allocBuffers(OpTy xferOp) {
-  auto &b = ScopedContext::getBuilderRef();
-  OpBuilder::InsertionGuard guard(b);
-  Operation *scope =
-      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
-  assert(scope && "Expected op to be inside automatic allocation scope");
-  b.setInsertionPointToStart(&scope->getRegion(0).front());
-
-  BufferAllocs result;
-  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
-  result.dataBuffer = memref_alloca(bufferType).value;
-
-  if (xferOp.mask()) {
-    auto maskType = MemRefType::get({}, xferOp.mask().getType());
-    Value maskBuffer = memref_alloca(maskType);
-    memref_store(xferOp.mask(), maskBuffer);
-    result.maskBuffer = memref_load(maskBuffer);
-  }
-
-  return result;
-}
-
 /// Given a vector transfer op, calculate which dimension of the `source`
 /// memref should be unpacked in the next application of TransferOpConversion.
 /// A return value of None indicates a broadcast.
@@ -284,6 +238,54 @@
   newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
 }
 
+namespace lowering_n_d {
+
+/// Helper data structure for data and mask buffers.
+struct BufferAllocs {
+  Value dataBuffer;
+  Value maskBuffer;
+};
+
+/// Allocate temporary buffers for data (vector) and mask (if present).
+/// TODO: Parallelism and threadlocal considerations.
+template <typename OpTy>
+static BufferAllocs allocBuffers(OpTy xferOp) {
+  auto &b = ScopedContext::getBuilderRef();
+  OpBuilder::InsertionGuard guard(b);
+  Operation *scope =
+      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+  assert(scope && "Expected op to be inside automatic allocation scope");
+  b.setInsertionPointToStart(&scope->getRegion(0).front());
+
+  BufferAllocs result;
+  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
+  result.dataBuffer = memref_alloca(bufferType).value;
+
+  if (xferOp.mask()) {
+    auto maskType = MemRefType::get({}, xferOp.mask().getType());
+    auto maskBuffer = memref_alloca(maskType).value;
+    memref_store(xferOp.mask(), maskBuffer);
+    result.maskBuffer = memref_load(maskBuffer);
+  }
+
+  return result;
+}
+
+/// Given a MemRefType with VectorType element type, unpack one dimension from
+/// the VectorType into the MemRefType.
+///
+/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
+static MemRefType unpackOneDim(MemRefType type) {
+  auto vectorType = type.getElementType().dyn_cast<VectorType>();
+  auto memrefShape = type.getShape();
+  SmallVector<int64_t, 8> newMemrefShape;
+  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
+  newMemrefShape.push_back(vectorType.getDimSize(0));
+  return MemRefType::get(newMemrefShape,
+                         VectorType::get(vectorType.getShape().drop_front(),
+                                         vectorType.getElementType()));
+}
+
 /// Given a transfer op, find the memref from which the mask is loaded. This
 /// is similar to Strategy<TransferWriteOp>::getBuffer.
 template <typename OpTy>
@@ -688,6 +690,10 @@
   }
 };
 
+} // namespace lowering_n_d
+
+namespace lowering_n_d_unrolled {
+
 /// If the original transfer op has a mask, compute the mask of the new transfer
 /// op (for the current iteration `i`) and assign it.
 template <typename OpTy>
@@ -954,6 +960,10 @@
   }
 };
 
+} // namespace lowering_n_d_unrolled
+
+namespace lowering_1_d {
+
 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
 /// part of TransferOp1dConversion. Return the memref dimension on which
 /// the transfer is operating. A return value of None indicates a broadcast.
@@ -1114,6 +1124,7 @@
   }
 };
 
+} // namespace lowering_1_d
 } // namespace
 
 namespace mlir {
@@ -1121,19 +1132,21 @@
 void populateVectorToSCFConversionPatterns(
     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
   if (options.unroll) {
-    patterns.add<UnrollTransferReadConversion, UnrollTransferWriteConversion>(
+    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
+                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
         patterns.getContext(), options);
   } else {
-    patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
-                 TransferOpConversion<TransferReadOp>,
-                 TransferOpConversion<TransferWriteOp>>(patterns.getContext(),
-                                                        options);
+    patterns.add<lowering_n_d::PrepareTransferReadConversion,
+                 lowering_n_d::PrepareTransferWriteConversion,
+                 lowering_n_d::TransferOpConversion<TransferReadOp>,
+                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
+        patterns.getContext(), options);
   }
 
   if (options.targetRank == 1) {
-    patterns.add<TransferOp1dConversion<TransferReadOp>,
-                 TransferOp1dConversion<TransferWriteOp>>(patterns.getContext(),
-                                                          options);
+    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
+                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
+        patterns.getContext(), options);
  }
 }
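
Note (not part of the patch): the sketch below illustrates how the now namespace-qualified patterns are typically consumed. The pass class `ExampleVectorToSCFPass` is hypothetical, named here only for illustration; it assumes the greedy rewrite driver and the `FunctionPass` wrapper available at the time of this change. In-tree, the same `populateVectorToSCFConversionPatterns` entry point is exercised by the `-convert-vector-to-scf` pass.

```cpp
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
/// Hypothetical example pass: populates the vector-to-scf lowering patterns
/// and applies them greedily to a function.
struct ExampleVectorToSCFPass
    : public PassWrapper<ExampleVectorToSCFPass, FunctionPass> {
  void runOnFunction() override {
    VectorTransferToSCFOptions options;
    // options.unroll == true selects the lowering_n_d_unrolled patterns;
    // otherwise the lowering_n_d patterns run. With options.targetRank == 1,
    // the lowering_1_d patterns are added as well.
    RewritePatternSet patterns(&getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
} // namespace
```

Since the namespaces only group the pattern definitions inside this file's anonymous namespace, external callers are unaffected: the qualification is visible only at the population site above.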