diff --git a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
--- a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
+++ b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
@@ -44,6 +44,11 @@
 void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *ctx);
 
+/// Collect a set of patterns to convert vector-related Affine ops to the
+/// Vector dialect.
+void populateAffineToVectorConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 /// Emit code that computes the lower bound of the given affine loop using
 /// standard arithmetic operations.
 Value lowerAffineLowerBound(AffineForOp op, OpBuilder &builder);
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -370,7 +370,44 @@
   let hasFolder = 1;
 }
 
-def AffineLoadOp : Affine_Op<"load", []> {
+class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
+    Affine_Op<mnemonic, traits> {
+  let arguments = (ins Arg<AnyMemRef, "the reference to load from",
+      [MemRead]>:$memref,
+      Variadic<Index>:$indices);
+
+  code extraClassDeclarationBase = [{
+    /// Returns the operand index of the memref.
+    unsigned getMemRefOperandIndex() { return 0; }
+
+    /// Get memref operand.
+    Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
+    void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
+    MemRefType getMemRefType() {
+      return getMemRef().getType().cast<MemRefType>();
+    }
+
+    /// Get affine map operands.
+    operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
+
+    /// Returns the affine map used to index the memref for this operation.
+    AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
+    AffineMapAttr getAffineMapAttr() {
+      return getAttr(getMapAttrName()).cast<AffineMapAttr>();
+    }
+
+    /// Returns the AffineMapAttr associated with 'memref'.
+    NamedAttribute getAffineMapAttrForMemRef(Value memref) {
+      assert(memref == getMemRef());
+      return {Identifier::get(getMapAttrName(), getContext()),
+              getAffineMapAttr()};
+    }
+
+    static StringRef getMapAttrName() { return "map"; }
+  }];
+}
+
+def AffineLoadOp : AffineLoadOpBase<"load", []> {
   let summary = "affine load operation";
   let description = [{
     The "affine.load" op reads an element from a memref, where the index
@@ -393,9 +430,6 @@
     ```
   }];
 
-  let arguments = (ins Arg<AnyMemRef, "the reference to load from",
-      [MemRead]>:$memref,
-      Variadic<Index>:$indices);
   let results = (outs AnyType:$result);
 
   let builders = [
@@ -410,35 +444,7 @@
                "AffineMap map, ValueRange mapOperands">
   ];
 
-  let extraClassDeclaration = [{
-    /// Returns the operand index of the memref.
-    unsigned getMemRefOperandIndex() { return 0; }
-
-    /// Get memref operand.
-    Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
-    void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
-    MemRefType getMemRefType() {
-      return getMemRef().getType().cast<MemRefType>();
-    }
-
-    /// Get affine map operands.
-    operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
-
-    /// Returns the affine map used to index the memref for this operation.
-    AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
-    AffineMapAttr getAffineMapAttr() {
-      return getAttr(getMapAttrName()).cast<AffineMapAttr>();
-    }
-
-    /// Returns the AffineMapAttr associated with 'memref'.
-    NamedAttribute getAffineMapAttrForMemRef(Value memref) {
-      assert(memref == getMemRef());
-      return {Identifier::get(getMapAttrName(), getContext()),
-              getAffineMapAttr()};
-    }
-
-    static StringRef getMapAttrName() { return "map"; }
-  }];
+  let extraClassDeclaration = extraClassDeclarationBase;
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
@@ -659,7 +665,45 @@
   let hasFolder = 1;
 }
 
-def AffineStoreOp : Affine_Op<"store", []> {
+class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
+    Affine_Op<mnemonic, traits> {
+
+  code extraClassDeclarationBase = [{
+    /// Get value to be stored by store operation.
+    Value getValueToStore() { return getOperand(0); }
+
+    /// Returns the operand index of the memref.
+    unsigned getMemRefOperandIndex() { return 1; }
+
+    /// Get memref operand.
+    Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
+    void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
+
+    MemRefType getMemRefType() {
+      return getMemRef().getType().cast<MemRefType>();
+    }
+
+    /// Get affine map operands.
+    operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
+
+    /// Returns the affine map used to index the memref for this operation.
+    AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
+    AffineMapAttr getAffineMapAttr() {
+      return getAttr(getMapAttrName()).cast<AffineMapAttr>();
+    }
+
+    /// Returns the AffineMapAttr associated with 'memref'.
+    NamedAttribute getAffineMapAttrForMemRef(Value memref) {
+      assert(memref == getMemRef());
+      return {Identifier::get(getMapAttrName(), getContext()),
+              getAffineMapAttr()};
+    }
+
+    static StringRef getMapAttrName() { return "map"; }
+  }];
+}
+
+def AffineStoreOp : AffineStoreOpBase<"store", []> {
   let summary = "affine store operation";
   let description = [{
     The "affine.store" op writes an element to a memref, where the index
@@ -686,7 +730,6 @@
                        [MemWrite]>:$memref,
       Variadic<Index>:$indices);
 
-  let skipDefaultBuilders = 1;
   let builders = [
       OpBuilder<"OpBuilder &builder, OperationState &result, "
@@ -696,39 +739,7 @@
                "ValueRange mapOperands">
   ];
 
-  let extraClassDeclaration = [{
-    /// Get value to be stored by store operation.
-    Value getValueToStore() { return getOperand(0); }
-
-    /// Returns the operand index of the memref.
-    unsigned getMemRefOperandIndex() { return 1; }
-
-    /// Get memref operand.
-    Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
-    void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
-
-    MemRefType getMemRefType() {
-      return getMemRef().getType().cast<MemRefType>();
-    }
-
-    /// Get affine map operands.
-    operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
-
-    /// Returns the affine map used to index the memref for this operation.
-    AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
-    AffineMapAttr getAffineMapAttr() {
-      return getAttr(getMapAttrName()).cast<AffineMapAttr>();
-    }
-
-    /// Returns the AffineMapAttr associated with 'memref'.
-    NamedAttribute getAffineMapAttrForMemRef(Value memref) {
-      assert(memref == getMemRef());
-      return {Identifier::get(getMapAttrName(), getContext()),
-              getAffineMapAttr()};
-    }
-
-    static StringRef getMapAttrName() { return "map"; }
-  }];
+  let extraClassDeclaration = extraClassDeclarationBase;
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
@@ -765,4 +776,107 @@
   let verifier = ?;
 }
 
+def AffineVectorLoadOp : AffineLoadOpBase<"vector_load", []> {
+  let summary = "affine vector load operation";
+  let description = [{
+    The "affine.vector_load" op is the vector counterpart of
+    [affine.load](#affineload-operation).
+    It reads a slice from a [MemRef](../LangRef.md#memref-type), supplied as
+    its first operand, into a [vector](../LangRef.md#vector-type) of the same
+    base elemental type. The index for each memref dimension is an affine
+    expression of loop induction variables and symbols. These indices determine
+    the start position of the read within the memref. The shape of the return
+    vector type determines the shape of the slice read from the memref. This
+    slice is contiguous along the respective dimensions of the shape. Strided
+    vector loads will be supported in the future.
+    An affine expression of loop IVs and symbols must be specified for each
+    dimension of the memref. The keyword 'symbol' can be used to indicate SSA
+    identifiers which are symbolic.
+
+    Example 1: 8-wide f32 vector load.
+
+    ```mlir
+    %1 = affine.vector_load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<8xf32>
+    ```
+
+    Example 2: 4-wide f32 vector load. Uses 'symbol' keyword for symbols '%n' and '%m'.
+
+    ```mlir
+    %1 = affine.vector_load %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32>, vector<4xf32>
+    ```
+
+    Example 3: 2-dim f32 vector load.
+
+    ```mlir
+    %1 = affine.vector_load %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+    ```
+
+    TODOs:
+    * Add support for strided vector loads.
+    * Consider adding a permutation map to permute the slice that is read from memory
+      (see [vector.transfer_read](../Vector/#vectortransfer_read-vectortransferreadop)).
+  }];
+
+  let results = (outs AnyVector:$result);
+
+  let extraClassDeclaration = extraClassDeclarationBase # [{
+    VectorType getVectorType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+}
+
+def AffineVectorStoreOp : AffineStoreOpBase<"vector_store", []> {
+  let summary = "affine vector store operation";
+  let description = [{
+    The "affine.vector_store" op is the vector counterpart of
+    [affine.store](#affinestore-affinestoreop). It writes a
+    [vector](../LangRef.md#vector-type), supplied as its first operand,
+    into a slice within a [MemRef](../LangRef.md#memref-type) of the same base
+    elemental type, supplied as its second operand.
+    The index for each memref dimension is an affine expression of loop
+    induction variables and symbols. These indices determine the start position
+    of the write within the memref. The shape of the input vector determines
+    the shape of the slice written to the memref. This slice is contiguous
+    along the respective dimensions of the shape. Strided vector stores will be
+    supported in the future.
+    An affine expression of loop IVs and symbols must be specified for each
+    dimension of the memref. The keyword 'symbol' can be used to indicate SSA
+    identifiers which are symbolic.
+
+    Example 1: 8-wide f32 vector store.
+
+    ```mlir
+    affine.vector_store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<8xf32>
+    ```
+
+    Example 2: 4-wide f32 vector store. Uses 'symbol' keyword for symbols '%n' and '%m'.
+
+    ```mlir
+    affine.vector_store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)] : memref<100x100xf32>, vector<4xf32>
+    ```
+
+    Example 3: 2-dim f32 vector store.
+
+    ```mlir
+    affine.vector_store %v0, %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+    ```
+
+    TODOs:
+    * Add support for strided vector stores.
+    * Consider adding a permutation map to permute the slice that is written to memory
+      (see [vector.transfer_write](../Vector/#vectortransfer_write-vectortransferwriteop)).
+  }];
+
+  let arguments = (ins AnyVector:$value,
+      Arg<AnyMemRef, "the reference to store to",
+          [MemWrite]>:$memref,
+      Variadic<Index>:$indices);
+
+  let extraClassDeclaration = extraClassDeclarationBase # [{
+    VectorType getVectorType() {
+      return value().getType().cast<VectorType>();
+    }
+  }];
+}
+
 #endif // AFFINE_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -994,6 +994,13 @@
     ```
   }];
 
+  let builders = [
+    // Builder that sets permutation map and padding to 'getMinorIdentityMap'
+    // and zero, respectively, by default.
+    OpBuilder<"OpBuilder &builder, OperationState &result, VectorType vector, "
+              "Value memref, ValueRange indices">
+  ];
+
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return memref().getType().cast<MemRefType>();
@@ -1058,6 +1065,13 @@
     ```
   }];
 
+  let builders = [
+    // Builder that sets the permutation map to 'getMinorIdentityMap' by
+    // default.
+    OpBuilder<"OpBuilder &builder, OperationState &result, Value vector, "
+              "Value memref, ValueRange indices">
+  ];
+
   let extraClassDeclaration = [{
     VectorType getVectorType() {
       return vector().getType().cast<VectorType>();
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
@@ -27,6 +28,7 @@
 #include "mlir/Transforms/Passes.h"
 
 using namespace mlir;
+using namespace mlir::vector;
 
 namespace {
 /// Visit affine expressions recursively and build the sequence of operations
@@ -556,6 +558,51 @@
   }
 };
 
+/// Apply the affine map from an 'affine.vector_load' operation to its
+/// operands, and feed the results to a newly created 'vector.transfer_read'
+/// operation (which replaces the original 'affine.vector_load').
+class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
+public:
+  using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineVectorLoadOp op,
+                                PatternRewriter &rewriter) const override {
+    // Expand affine map from 'affineVectorLoadOp'.
+    SmallVector<Value, 8> indices(op.getMapOperands());
+    auto resultOperands =
+        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+    if (!resultOperands)
+      return failure();
+
+    // Build vector.transfer_read memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<TransferReadOp>(
+        op, op.getVectorType(), op.getMemRef(), *resultOperands);
+    return success();
+  }
+};
+
+/// Apply the affine map from an 'affine.vector_store' operation to its
+/// operands, and feed the results to a newly created 'vector.transfer_write'
+/// operation (which replaces the original 'affine.vector_store').
+class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
+public:
+  using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineVectorStoreOp op,
+                                PatternRewriter &rewriter) const override {
+    // Expand affine map from 'affineVectorStoreOp'.
+    SmallVector<Value, 8> indices(op.getMapOperands());
+    auto maybeExpandedMap =
+        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+    if (!maybeExpandedMap)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<TransferWriteOp>(
+        op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
+    return success();
+  }
+};
+
 } // end namespace
 
 void mlir::populateAffineToStdConversionPatterns(
@@ -576,13 +623,24 @@
   // clang-format on
 }
 
+void mlir::populateAffineToVectorConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  // clang-format off
+  patterns.insert<
+      AffineVectorLoadLowering,
+      AffineVectorStoreLowering>(ctx);
+  // clang-format on
+}
+
 namespace {
 class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     populateAffineToStdConversionPatterns(patterns, &getContext());
+    populateAffineToVectorConversionPatterns(patterns, &getContext());
     ConversionTarget target(getContext());
-    target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
+    target
+        .addLegalDialect<scf::SCFDialect, StandardOpsDialect, VectorDialect>();
     if (failed(applyPartialConversion(getFunction(), target, patterns)))
       signalPassFailure();
   }
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1912,32 +1912,47 @@
   p << " : " << op.getMemRefType();
 }
 
-LogicalResult verify(AffineLoadOp op) {
-  if (op.getType() != op.getMemRefType().getElementType())
-    return op.emitOpError("result type must match element type of memref");
-
-  auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
+/// Verify common indexing invariants of affine.load, affine.store,
+/// affine.vector_load and affine.vector_store.
+static LogicalResult
+verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
+                       Operation::operand_range mapOperands,
+                       MemRefType memrefType, unsigned numIndexOperands) {
   if (mapAttr) {
-    AffineMap map =
-        op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()).getValue();
-    if (map.getNumResults() != op.getMemRefType().getRank())
-      return op.emitOpError("affine.load affine map num results must equal"
-                            " memref rank");
-    if (map.getNumInputs() != op.getNumOperands() - 1)
-      return op.emitOpError("expects as many subscripts as affine map inputs");
+    AffineMap map = mapAttr.getValue();
+    if (map.getNumResults() != memrefType.getRank())
+      return op->emitOpError("affine map num results must equal memref rank");
+    if (map.getNumInputs() != numIndexOperands)
+      return op->emitOpError("expects as many subscripts as affine map inputs");
   } else {
-    if (op.getMemRefType().getRank() != op.getNumOperands() - 1)
-      return op.emitOpError(
+    if (memrefType.getRank() != numIndexOperands)
+      return op->emitOpError(
           "expects the number of subscripts to be equal to memref rank");
   }
 
   Region *scope = getAffineScope(op);
-  for (auto idx : op.getMapOperands()) {
+  for (auto idx : mapOperands) {
     if (!idx.getType().isIndex())
-      return op.emitOpError("index to load must have 'index' type");
+      return op->emitOpError("index to load must have 'index' type");
     if (!isValidAffineIndexOperand(idx, scope))
-      return op.emitOpError("index must be a dimension or symbol identifier");
+      return op->emitOpError("index must be a dimension or symbol identifier");
   }
+
+  return success();
+}
+
+LogicalResult verify(AffineLoadOp op) {
+  auto memrefType = op.getMemRefType();
+  if (op.getType() != memrefType.getElementType())
+    return op.emitOpError("result type must match element type of memref");
+
+  if (failed(verifyMemoryOpIndexing(
+          op.getOperation(),
+          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
+          op.getMapOperands(), memrefType,
+          /*numIndexOperands=*/op.getNumOperands() - 1)))
+    return failure();
+
   return success();
 }
 
@@ -2014,31 +2029,18 @@
 LogicalResult verify(AffineStoreOp op) {
   // First operand must have same type as memref element type.
-  if (op.getValueToStore().getType() != op.getMemRefType().getElementType())
+  auto memrefType = op.getMemRefType();
+  if (op.getValueToStore().getType() != memrefType.getElementType())
     return op.emitOpError(
         "first operand must have same type memref element type");
 
-  auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
-  if (mapAttr) {
-    AffineMap map = mapAttr.getValue();
-    if (map.getNumResults() != op.getMemRefType().getRank())
-      return op.emitOpError("affine.store affine map num results must equal"
-                            " memref rank");
-    if (map.getNumInputs() != op.getNumOperands() - 2)
-      return op.emitOpError("expects as many subscripts as affine map inputs");
-  } else {
-    if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
-      return op.emitOpError(
-          "expects the number of subscripts to be equal to memref rank");
-  }
+  if (failed(verifyMemoryOpIndexing(
+          op.getOperation(),
+          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
+          op.getMapOperands(), memrefType,
+          /*numIndexOperands=*/op.getNumOperands() - 2)))
+    return failure();
 
-  Region *scope = getAffineScope(op);
-  for (auto idx : op.getMapOperands()) {
-    if (!idx.getType().isIndex())
-      return op.emitOpError("index to store must have 'index' type");
-    if (!isValidAffineIndexOperand(idx, scope))
-      return op.emitOpError("index must be a dimension or symbol identifier");
-  }
   return success();
 }
 
@@ -2494,6 +2496,125 @@
 }
 
 //===----------------------------------------------------------------------===//
+// AffineVectorLoadOp
+//===----------------------------------------------------------------------===//
+
+ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  auto &builder = parser.getBuilder();
+  auto indexTy = builder.getIndexType();
+
+  MemRefType memrefType;
+  VectorType resultType;
+  OpAsmParser::OperandType memrefInfo;
+  AffineMapAttr mapAttr;
+  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+  return failure(
+      parser.parseOperand(memrefInfo) ||
+      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
+                                    AffineVectorLoadOp::getMapAttrName(),
+                                    result.attributes) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(memrefType) || parser.parseComma() ||
+      parser.parseType(resultType) ||
+      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
+      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
+      parser.addTypeToList(resultType, result.types));
+}
+
+void print(OpAsmPrinter &p, AffineVectorLoadOp op) {
+  p << "affine.vector_load " << op.getMemRef() << '[';
+  if (AffineMapAttr mapAttr =
+          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
+    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
+  p << ']';
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
+  p << " : " << op.getMemRefType() << ", " << op.getType();
+}
+
+/// Verify common invariants of affine.vector_load and affine.vector_store.
+static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
+                                          VectorType vectorType) {
+  // Check that memref and vector element types match.
+  if (memrefType.getElementType() != vectorType.getElementType())
+    return op->emitOpError(
+        "requires memref and vector types of the same elemental type");
+
+  return success();
+}
+
+static LogicalResult verify(AffineVectorLoadOp op) {
+  MemRefType memrefType = op.getMemRefType();
+  if (failed(verifyMemoryOpIndexing(
+          op.getOperation(),
+          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
+          op.getMapOperands(), memrefType,
+          /*numIndexOperands=*/op.getNumOperands() - 1)))
+    return failure();
+
+  if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
+                                  op.getVectorType())))
+    return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AffineVectorStoreOp
+//===----------------------------------------------------------------------===//
+
+ParseResult parseAffineVectorStoreOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  auto indexTy = parser.getBuilder().getIndexType();
+
+  MemRefType memrefType;
+  VectorType resultType;
+  OpAsmParser::OperandType storeValueInfo;
+  OpAsmParser::OperandType memrefInfo;
+  AffineMapAttr mapAttr;
+  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+  return failure(
+      parser.parseOperand(storeValueInfo) || parser.parseComma() ||
+      parser.parseOperand(memrefInfo) ||
+      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
+                                    AffineVectorStoreOp::getMapAttrName(),
+                                    result.attributes) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(memrefType) || parser.parseComma() ||
+      parser.parseType(resultType) ||
+      parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
+      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
+      parser.resolveOperands(mapOperands, indexTy, result.operands));
+}
+
+void print(OpAsmPrinter &p, AffineVectorStoreOp op) {
+  p << "affine.vector_store " << op.getValueToStore();
+  p << ", " << op.getMemRef() << '[';
+  if (AffineMapAttr mapAttr =
+          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
+    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
+  p << ']';
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
+  p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType();
+}
+
+static LogicalResult verify(AffineVectorStoreOp op) {
+  MemRefType memrefType = op.getMemRefType();
+  if (failed(verifyMemoryOpIndexing(
+          op.getOperation(),
+          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
+          op.getMapOperands(), memrefType,
+          /*numIndexOperands=*/op.getNumOperands() - 2)))
+    return failure();
+
+  if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
+                                  op.getVectorType())))
+    return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1284,6 +1284,21 @@
   return success();
 }
 
+/// Builder that sets permutation map and padding to 'getMinorIdentityMap' and
+/// zero, respectively, by default.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType vector, Value memref,
+                           ValueRange indices) {
+  auto permMap = AffineMap::getMinorIdentityMap(
+      memref.getType().cast<MemRefType>().getRank(), vector.getRank(),
+      builder.getContext());
+  Type elemType = vector.getElementType();
+  Value padding = builder.create<ConstantOp>(result.location, elemType,
+                                             builder.getZeroAttr(elemType));
+
+  build(builder, result, vector, memref, indices, permMap, padding);
+}
+
 static void print(OpAsmPrinter &p, TransferReadOp op) {
   p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
     << "], " << op.padding() << " ";
@@ -1361,6 +1376,17 @@
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
 
+/// Builder that sets the permutation map to 'getMinorIdentityMap' by default.
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+                            Value vector, Value memref, ValueRange indices) {
+  auto vectorType = vector.getType().cast<VectorType>();
+  auto permMap = AffineMap::getMinorIdentityMap(
+      memref.getType().cast<MemRefType>().getRank(), vectorType.getRank(),
+      builder.getContext());
+  build(builder, result, vector, memref, indices, permMap);
+}
+
 static LogicalResult verify(TransferWriteOp op) {
   // Consistency of elemental types in memref and vector.
   MemRefType memrefType = op.getMemRefType();
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt -lower-affine --split-input-file %s | FileCheck %s
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @affine_vector_load
+func @affine_vector_load(%arg0 : index) {
+  %0 = alloc() : memref<100xf32>
+  affine.for %i0 = 0 to 16 {
+    %1 = affine.vector_load %0[%i0 + symbol(%arg0) + 7] : memref<100xf32>, vector<8xf32>
+  }
+// CHECK:       %[[buf:.*]] = alloc
+// CHECK:       %[[a:.*]] = addi %{{.*}}, %{{.*}} : index
+// CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
+// CHECK-NEXT:  %[[b:.*]] = addi %[[a]], %[[c7]] : index
+// CHECK-NEXT:  %[[pad:.*]] = constant 0.0
+// CHECK-NEXT:  vector.transfer_read %[[buf]][%[[b]]], %[[pad]] {permutation_map = #[[perm_map]]} : memref<100xf32>, vector<8xf32>
+  return
+}
+
+// -----
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @affine_vector_store
+func @affine_vector_store(%arg0 : index) {
+  %0 = alloc() : memref<100xf32>
+  %1 = constant dense<11.0> : vector<4xf32>
+  affine.for %i0 = 0 to 16 {
+    affine.vector_store %1, %0[%i0 - symbol(%arg0) + 7] : memref<100xf32>, vector<4xf32>
+  }
+// CHECK:       %[[buf:.*]] = alloc
+// CHECK:       %[[val:.*]] = constant dense
+// CHECK:       %[[c_1:.*]] = constant -1 : index
+// CHECK-NEXT:  %[[a:.*]] = muli %arg0, %[[c_1]] : index
+// CHECK-NEXT:  %[[b:.*]] = addi %{{.*}}, %[[a]] : index
+// CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
+// CHECK-NEXT:  %[[c:.*]] = addi %[[b]], %[[c7]] : index
+// CHECK-NEXT:  vector.transfer_write %[[val]], %[[buf]][%[[c]]] {permutation_map = #[[perm_map]]} : vector<4xf32>, memref<100xf32>
+  return
+}
+
+// -----
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @vector_load_2d
+func @vector_load_2d() {
+  %0 = alloc() : memref<100x100xf32>
+  affine.for %i0 = 0 to 16 step 2 {
+    affine.for %i1 = 0 to 16 step 8 {
+      %1 = affine.vector_load %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK:      %[[buf:.*]] = alloc
+// CHECK:      scf.for %[[i0:.*]] =
+// CHECK:        scf.for %[[i1:.*]] =
+// CHECK-NEXT:     %[[pad:.*]] = constant 0.0
+// CHECK-NEXT:     vector.transfer_read %[[buf]][%[[i0]], %[[i1]]], %[[pad]] {permutation_map = #[[perm_map]]} : memref<100x100xf32>, vector<2x8xf32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK: #[[perm_map:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @vector_store_2d
+func @vector_store_2d() {
+  %0 = alloc() : memref<100x100xf32>
+  %1 = constant dense<11.0> : vector<2x8xf32>
+  affine.for %i0 = 0 to 16 step 2 {
+    affine.for %i1 = 0 to 16 step 8 {
+      affine.vector_store %1, %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK:      %[[buf:.*]] = alloc
+// CHECK:      %[[val:.*]] = constant dense
+// CHECK:      scf.for %[[i0:.*]] =
+// CHECK:        scf.for %[[i1:.*]] =
+// CHECK-NEXT:     vector.transfer_write %[[val]], %[[buf]][%[[i0]], %[[i1]]] {permutation_map = #[[perm_map]]} : vector<2x8xf32>, memref<100x100xf32>
+    }
+  }
+  return
+}
+
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -262,3 +262,49 @@
   }
   return
 }
+
+// -----
+
+func @vector_load_invalid_vector_type() {
+  %0 = alloc() : memref<100xf32>
+  affine.for %i0 = 0 to 16 step 8 {
+    // expected-error@+1 {{requires memref and vector types of the same elemental type}}
+    %1 = affine.vector_load %0[%i0] : memref<100xf32>, vector<8xf64>
+  }
+  return
+}
+
+// -----
+
+func @vector_store_invalid_vector_type() {
+  %0 = alloc() : memref<100xf32>
+  %1 = constant dense<7.0> : vector<8xf64>
+  affine.for %i0 = 0 to 16 step 8 {
+    // expected-error@+1 {{requires memref and vector types of the same elemental type}}
+    affine.vector_store %1, %0[%i0] : memref<100xf32>, vector<8xf64>
+  }
+  return
+}
+
+// -----
+
+func @vector_load_vector_memref() {
+  %0 = alloc() : memref<100xvector<8xf32>>
+  affine.for %i0 = 0 to 4 {
+    // expected-error@+1 {{requires memref and vector types of the same elemental type}}
+    %1 = affine.vector_load %0[%i0] : memref<100xvector<8xf32>>, vector<8xf32>
+  }
+  return
+}
+
+// -----
+
+func @vector_store_vector_memref() {
+  %0 = alloc() : memref<100xvector<8xf32>>
+  %1 = constant dense<7.0> : vector<8xf32>
+  affine.for %i0 = 0 to 4 {
+    // expected-error@+1 {{requires memref and vector types of the same elemental type}}
+    affine.vector_store %1, %0[%i0] : memref<100xvector<8xf32>>, vector<8xf32>
+  }
+  return
+}
diff --git a/mlir/test/Dialect/Affine/load-store.mlir b/mlir/test/Dialect/Affine/load-store.mlir
--- a/mlir/test/Dialect/Affine/load-store.mlir
+++ b/mlir/test/Dialect/Affine/load-store.mlir
@@ -214,3 +214,65 @@
   }
   return
 }
+
+// -----
+
+// CHECK: [[MAP_ID:#map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// Test with just loop IVs.
+func @vector_load_vector_store_iv() {
+  %0 = alloc() : memref<100x100xf32>
+  affine.for %i0 = 0 to 16 {
+    affine.for %i1 = 0 to 16 step 8 {
+      %1 = affine.vector_load %0[%i0, %i1] : memref<100x100xf32>, vector<8xf32>
+      affine.vector_store %1, %0[%i0, %i1] : memref<100x100xf32>, vector<8xf32>
+// CHECK:      %[[buf:.*]] = alloc
+// CHECK-NEXT: affine.for %[[i0:.*]] = 0
+// CHECK-NEXT:   affine.for %[[i1:.*]] = 0
+// CHECK-NEXT:     %[[val:.*]] = affine.vector_load %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<8xf32>
+// CHECK-NEXT:     affine.vector_store %[[val]], %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<8xf32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> (d0 + 3, d1 + 7)>
+
+// Test with loop IVs and constants.
+func @vector_load_vector_store_iv_constant() {
+  %0 = alloc() : memref<100x100xf32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 16 step 4 {
+      %1 = affine.vector_load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<4xf32>
+      affine.vector_store %1, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<4xf32>
+// CHECK:      %[[buf:.*]] = alloc
+// CHECK-NEXT: affine.for %[[i0:.*]] = 0
+// CHECK-NEXT:   affine.for %[[i1:.*]] = 0
+// CHECK-NEXT:     %[[val:.*]] = affine.vector_load %{{.*}}[%{{.*}} + 3, %{{.*}} + 7] : memref<100x100xf32>, vector<4xf32>
+// CHECK-NEXT:     affine.vector_store %[[val]], %[[buf]][%[[i0]] + 3, %[[i1]] + 7] : memref<100x100xf32>, vector<4xf32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+func @vector_load_vector_store_2d() {
+  %0 = alloc() : memref<100x100xf32>
+  affine.for %i0 = 0 to 16 step 2 {
+    affine.for %i1 = 0 to 16 step 8 {
+      %1 = affine.vector_load %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+      affine.vector_store %1, %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK:      %[[buf:.*]] = alloc
+// CHECK-NEXT: affine.for %[[i0:.*]] = 0
+// CHECK-NEXT:   affine.for %[[i1:.*]] = 0
+// CHECK-NEXT:     %[[val:.*]] = affine.vector_load %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK-NEXT:     affine.vector_store %[[val]], %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
+    }
+  }
+  return
+}
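For quick reference, here is a condensed sketch of the end-to-end behavior this patch introduces, assembled from the tests above; the lowered IR is shown in comments, and SSA names such as `%cst` are illustrative rather than verbatim pass output:

```mlir
func @example() {
  %0 = alloc() : memref<100x100xf32>
  affine.for %i0 = 0 to 16 step 2 {
    affine.for %i1 = 0 to 16 step 8 {
      // Reads a contiguous 2x8 slice starting at (%i0, %i1), then writes it back.
      %v = affine.vector_load %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
      affine.vector_store %v, %0[%i0, %i1] : memref<100x100xf32>, vector<2x8xf32>
    }
  }
  return
}

// After -lower-affine, the loops become scf.for and the ops become vector
// transfers; the new default TransferReadOp builder supplies the zero padding
// and the minor identity permutation map:
//
//   %cst = constant 0.000000e+00 : f32
//   %v = vector.transfer_read %0[%i0, %i1], %cst
//          {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
//        : memref<100x100xf32>, vector<2x8xf32>
//   vector.transfer_write %v, %0[%i0, %i1]
//          {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
//        : vector<2x8xf32>, memref<100x100xf32>
```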