diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1320,6 +1320,156 @@
   let hasFolder = 1;
 }
 
+def Vector_LoadOp : Vector_Op<"load"> {
+  let summary = "reads an n-D slice of memory into an n-D vector";
+  let description = [{
+    The 'vector.load' operation reads an n-D slice of memory into an n-D
+    vector. It takes a 'base' memref, an index for each memref dimension, and
+    a result vector type as arguments. It returns a value of the result vector
+    type. The 'base' memref and indices determine the start memory address
+    from which to read. Each index provides an offset for each memref
+    dimension based on the element type of the memref. The shape of the result
+    vector type determines the shape of the slice read from the start memory
+    address. The elements along each dimension of the slice are strided by the
+    memref strides. Only memrefs with default strides are allowed. These
+    constraints guarantee that elements read along the first dimension of the
+    slice are contiguous in memory.
+
+    The memref element type can be a scalar or a vector type. If the memref
+    element type is a scalar, it should match the element type of the result
+    vector. If the memref element type is a vector, it should match the result
+    vector type.
+
+    Example 1: 1-D vector load on a scalar memref.
+    ```mlir
+    %result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8xf32>
+    ```
+
+    Example 2: 1-D vector load on a vector memref.
+    ```mlir
+    %result = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+    ```
+
+    Example 3: 2-D vector load on a scalar memref.
+    ```mlir
+    %result = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+    ```
+
+    Example 4: 2-D vector load on a vector memref.
+    ```mlir
+    %result = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+    ```
+
+    Representation-wise, the 'vector.load' operation permits out-of-bounds
+    reads. Support and implementation of out-of-bounds vector loads are
+    target-specific. No assumptions should be made on the value of elements
+    loaded out of bounds. Not all targets may support out-of-bounds vector
+    loads.
+
+    Example 5: Potential out-of-bounds vector load.
+    ```mlir
+    %result = vector.load %memref[%index] : memref<?xf32>, vector<8xf32>
+    ```
+
+    Example 6: Explicit out-of-bounds vector load.
+    ```mlir
+    %result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
+    ```
+  }];
+
+  let arguments = (ins Arg<AnyMemRef, "the reference to load from",
+                           [MemRead]>:$base,
+                       Variadic<Index>:$indices);
+  let results = (outs AnyVector:$result);
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+
+    VectorType getVectorType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+
+  let assemblyFormat =
+      "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+}
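To make the stride rule above concrete, here is a small usage sketch (SSA names are illustrative only) of the 2-D case from Example 3:

```mlir
// Assuming %A : memref<200x100xf32> with the default row-major strides
// [100, 1], this load reads the elements %A[%i + r][%j + c] for r in [0, 4)
// and c in [0, 8). Each row of 8 elements is contiguous in memory, and
// consecutive rows of the slice are 100 elements apart.
%slice = vector.load %A[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
```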
+
+def Vector_StoreOp : Vector_Op<"store"> {
+  let summary = "writes an n-D vector to an n-D slice of memory";
+  let description = [{
+    The 'vector.store' operation writes an n-D vector to an n-D slice of
+    memory. It takes the vector value to be stored, a 'base' memref, and an
+    index for each memref dimension as arguments. The 'base' memref and
+    indices determine the start memory address at which to write. Each index
+    provides an offset for each memref dimension based on the element type of
+    the memref. The shape of the vector value to store determines the shape of
+    the slice written starting at that address. The elements along each
+    dimension of the slice are strided by the memref strides. Only memrefs
+    with default strides are allowed. These constraints guarantee that
+    elements written along the first dimension of the slice are contiguous in
+    memory.
+
+    The memref element type can be a scalar or a vector type. If the memref
+    element type is a scalar, it should match the element type of the value
+    to store. If the memref element type is a vector, it should match the type
+    of the value to store.
+
+    Example 1: 1-D vector store on a scalar memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+    ```
+
+    Example 2: 1-D vector store on a vector memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+    ```
+
+    Example 3: 2-D vector store on a scalar memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+    ```
+
+    Example 4: 2-D vector store on a vector memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+    ```
+
+    Representation-wise, the 'vector.store' operation permits out-of-bounds
+    writes. Support and implementation of out-of-bounds vector stores are
+    target-specific. No assumptions should be made on the memory written out
+    of bounds. Not all targets may support out-of-bounds vector stores.
+
+    Example 5: Potential out-of-bounds vector store.
+    ```mlir
+    vector.store %valueToStore, %memref[%index] : memref<?xf32>, vector<8xf32>
+    ```
+
+    Example 6: Explicit out-of-bounds vector store.
+    ```mlir
+    vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyVector:$valueToStore,
+                       Arg<AnyMemRef, "the reference to store to",
+                           [MemWrite]>:$base,
+                       Variadic<Index>:$indices);
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+
+    VectorType getVectorType() {
+      return valueToStore().getType().cast<VectorType>();
+    }
+  }];
+
+  let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
+                       "`:` type($base) `,` type($valueToStore)";
+}
+
 def Vector_MaskedLoadOp :
   Vector_Op<"maskedload">,
     Arguments<(ins Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
                    Variadic<Index>:$indices,
@@ -1363,7 +1513,7 @@
     VectorType getPassThruVectorType() {
       return pass_thru().getType().cast<VectorType>();
     }
-    VectorType getResultVectorType() {
+    VectorType getVectorType() {
       return result().getType().cast<VectorType>();
     }
   }];
@@ -1377,7 +1527,7 @@
     Arguments<(ins Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
                    Variadic<Index>:$indices,
                    VectorOfRankAndType<[1], [I1]>:$mask,
-                   VectorOfRank<[1]>:$value)> {
+                   VectorOfRank<[1]>:$valueToStore)> {
 
   let summary = "stores elements from a vector into memory as defined by a mask vector";
@@ -1411,12 +1561,13 @@
     VectorType getMaskVectorType() {
       return mask().getType().cast<VectorType>();
     }
-    VectorType getValueVectorType() {
-      return value().getType().cast<VectorType>();
+    VectorType getVectorType() {
+      return valueToStore().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
-                       "type($base) `,` type($mask) `,` type($value)";
+  let assemblyFormat =
+      "$base `[` $indices `]` `,` $mask `,` $valueToStore "
+      "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
 }
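One point worth calling out from both op descriptions: when the memref element type is itself a vector, each index addresses a whole vector element, so the loaded or stored type must equal the memref element type and no extra inner index is used. A minimal sketch, with illustrative names:

```mlir
// %B has vector<8xf32> elements, so %B[%i, %j] designates one whole
// 8-element vector rather than a single scalar.
%v = vector.load %B[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
vector.store %v, %B[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
```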
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -578,8 +578,9 @@
     if (!resultOperands)
       return failure();
 
-    // Build std.load memref[expandedMap.results].
-    rewriter.replaceOpWithNewOp(op, op.getMemRef(), *resultOperands);
+    // Build vector.load memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp(op, op.getMemRef(),
+                                *resultOperands);
     return success();
   }
 };
@@ -625,8 +626,8 @@
       return failure();
 
     // Build std.store valueToStore, memref[expandedMap.results].
-    rewriter.replaceOpWithNewOp(op, op.getValueToStore(),
-                                op.getMemRef(), *maybeExpandedMap);
+    rewriter.replaceOpWithNewOp(
+        op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
     return success();
   }
 };
@@ -695,8 +696,8 @@
 };
 
 /// Apply the affine map from an 'affine.vector_load' operation to its operands,
-/// and feed the results to a newly created 'vector.transfer_read' operation
-/// (which replaces the original 'affine.vector_load').
+/// and feed the results to a newly created 'vector.load' operation (which
+/// replaces the original 'affine.vector_load').
 class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
 public:
   using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
@@ -710,16 +711,16 @@
     if (!resultOperands)
       return failure();
 
-    // Build vector.transfer_read memref[expandedMap.results].
-    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+    // Build vector.load memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<vector::LoadOp>(
         op, op.getVectorType(), op.getMemRef(), *resultOperands);
     return success();
   }
 };
 
 /// Apply the affine map from an 'affine.vector_store' operation to its
-/// operands, and feed the results to a newly created 'vector.transfer_write'
-/// operation (which replaces the original 'affine.vector_store').
+/// operands, and feed the results to a newly created 'vector.store' operation
+/// (which replaces the original 'affine.vector_store').
 class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
 public:
   using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
@@ -733,7 +734,7 @@
     if (!maybeExpandedMap)
       return failure();
 
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
         op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
     return success();
   }
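The pattern updates in this file amount to the following rewrite, shown as a hedged before/after sketch (names and constants are illustrative and mirror the lower-affine-to-vector.mlir checks updated further down):

```mlir
// Before -lower-affine: a vector access with an affine subscript.
%v = affine.vector_load %buf[%i0 + symbol(%s) + 7] : memref<100xf32>, vector<8xf32>

// After -lower-affine: the affine map is expanded into explicit index
// arithmetic and the access becomes a plain vector.load; unlike the old
// vector.transfer_read lowering, no padding value is materialized.
%a = addi %i0, %s : index
%c7 = constant 7 : index
%b = addi %a, %c7 : index
%v = vector.load %buf[%b] : memref<100xf32>, vector<8xf32>
```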
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -357,64 +357,72 @@
   }
 };
 
-/// Conversion pattern for a vector.maskedload.
-class VectorMaskedLoadOpConversion
-    : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
-public:
-  using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = load->getLoc();
-    auto adaptor = vector::MaskedLoadOpAdaptor(operands);
-    MemRefType memRefType = load.getMemRefType();
+/// Overloaded utility that replaces a vector.load, vector.store,
+/// vector.maskedload, and vector.maskedstore with their respective LLVM
+/// counterparts.
+static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
+                                 vector::LoadOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
+}
 
-    // Resolve alignment.
-    unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
-      return failure();
+static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
+                                 vector::MaskedLoadOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
+      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
+}
 
-    // Resolve address.
-    auto vtype = typeConverter->convertType(load.getResultVectorType());
-    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
-                                               adaptor.indices(), rewriter);
-    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
+static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
+                                 vector::StoreOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
+                                             ptr, align);
+}
 
-    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
-        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
-        rewriter.getI32IntegerAttr(align));
-    return success();
-  }
-};
+static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
+                                 vector::MaskedStoreOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
+      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
+}
 
-/// Conversion pattern for a vector.maskedstore.
-class VectorMaskedStoreOpConversion
-    : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
+/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
+/// vector.maskedstore.
+template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
+class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 public:
-  using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
+  matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = store->getLoc();
-    auto adaptor = vector::MaskedStoreOpAdaptor(operands);
-    MemRefType memRefType = store.getMemRefType();
+    // Only 1-D vectors can be lowered to LLVM.
+    VectorType vectorTy = loadOrStoreOp.getVectorType();
+    if (vectorTy.getRank() > 1)
+      return failure();
+
+    auto loc = loadOrStoreOp->getLoc();
+    auto adaptor = LoadOrStoreOpAdaptor(operands);
+    MemRefType memRefTy = loadOrStoreOp.getMemRefType();
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
+    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
       return failure();
 
     // Resolve address.
-    auto vtype = typeConverter->convertType(store.getValueVectorType());
-    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
+    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
+                     .template cast<VectorType>();
+    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                                adaptor.indices(), rewriter);
-    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
+    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
 
-    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
-        store, adaptor.value(), ptr, adaptor.mask(),
-        rewriter.getI32IntegerAttr(align));
+    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
     return success();
   }
 };
@@ -1511,8 +1519,14 @@
                   VectorInsertOpConversion,
                   VectorPrintOpConversion,
                   VectorTypeCastOpConversion,
-                  VectorMaskedLoadOpConversion,
-                  VectorMaskedStoreOpConversion,
+                  VectorLoadStoreConversion<vector::LoadOp,
+                                            vector::LoadOpAdaptor>,
+                  VectorLoadStoreConversion<vector::MaskedLoadOp,
+                                            vector::MaskedLoadOpAdaptor>,
+                  VectorLoadStoreConversion<vector::StoreOp,
+                                            vector::StoreOpAdaptor>,
+                  VectorLoadStoreConversion<vector::MaskedStoreOp,
+                                            vector::MaskedStoreOpAdaptor>,
                   VectorGatherOpConversion,
                   VectorScatterOpConversion,
                   VectorExpandLoadOpConversion,
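For orientation, the net effect of the shared pattern on a 1-D vector.load is sketched below in LLVM dialect form (SSA names are illustrative; the pointer syntax assumes the typed !llvm.ptr<...> style used by the vector-to-llvm.mlir checks further down). Loads and stores of rank greater than one are deliberately rejected by matchAndRewrite so that earlier vector transformations can unroll them to 1-D first.

```mlir
// vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
// becomes, roughly: linearize the indices using the memref strides, GEP to
// the start address, bitcast to a vector pointer, and emit an aligned load.
%c100 = llvm.mlir.constant(100 : index) : i64
%mul = llvm.mul %i, %c100 : i64
%add = llvm.add %mul, %j : i64
%gep = llvm.getelementptr %base[%add] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
%cast = llvm.bitcast %gep : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
%val = llvm.load %cast {alignment = 4 : i64} : !llvm.ptr<vector<8xf32>>
```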
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2373,6 +2373,67 @@
                            SideEffects::DefaultResource::get());
 }
 
+//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
+                                                 MemRefType memRefTy) {
+  auto affineMaps = memRefTy.getAffineMaps();
+  if (!affineMaps.empty())
+    return op->emitOpError("base memref should have a default identity layout");
+  return success();
+}
+
+static LogicalResult verify(vector::LoadOp op) {
+  VectorType resVecTy = op.getVectorType();
+  MemRefType memRefTy = op.getMemRefType();
+
+  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
+    return failure();
+
+  // Checks for vector memrefs.
+  Type memElemTy = memRefTy.getElementType();
+  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
+    if (memVecTy != resVecTy)
+      return op.emitOpError("base memref and result vector types should match");
+    memElemTy = memVecTy.getElementType();
+  }
+
+  if (resVecTy.getElementType() != memElemTy)
+    return op.emitOpError("base and result element types should match");
+  if (llvm::size(op.indices()) != memRefTy.getRank())
+    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(vector::StoreOp op) {
+  VectorType valueVecTy = op.getVectorType();
+  MemRefType memRefTy = op.getMemRefType();
+
+  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
+    return failure();
+
+  // Checks for vector memrefs.
+  Type memElemTy = memRefTy.getElementType();
+  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
+    if (memVecTy != valueVecTy)
+      return op.emitOpError(
+          "base memref and valueToStore vector types should match");
+    memElemTy = memVecTy.getElementType();
+  }
+
+  if (valueVecTy.getElementType() != memElemTy)
+    return op.emitOpError("base and valueToStore element type should match");
+  if (llvm::size(op.indices()) != memRefTy.getRank())
+    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedLoadOp
 //===----------------------------------------------------------------------===//
@@ -2380,7 +2441,7 @@
 static LogicalResult verify(MaskedLoadOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType passVType = op.getPassThruVectorType();
-  VectorType resVType = op.getResultVectorType();
+  VectorType resVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (resVType.getElementType() != memType.getElementType())
@@ -2427,15 +2488,15 @@
 static LogicalResult verify(MaskedStoreOp op) {
   VectorType maskVType = op.getMaskVectorType();
-  VectorType valueVType = op.getValueVectorType();
+  VectorType valueVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (valueVType.getElementType() != memType.getElementType())
-    return op.emitOpError("base and value element type should match");
+    return op.emitOpError("base and valueToStore element type should match");
   if (llvm::size(op.indices()) != memType.getRank())
    return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
-    return op.emitOpError("expected value dim to match mask dim");
+    return op.emitOpError("expected valueToStore dim to match mask dim");
   return success();
 }
 
@@ -2448,7 +2509,7 @@
   switch (get1DMaskFormat(store.mask())) {
   case MaskFormat::AllTrue:
     rewriter.replaceOpWithNewOp(
-        store, store.value(), store.base(), store.indices(), false);
+        store, store.valueToStore(), store.base(), store.indices(), false);
     return success();
   case MaskFormat::AllFalse:
     rewriter.eraseOp(store);
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
--- a/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine-to-vector.mlir
@@ -1,41 +1,5 @@
 // RUN: mlir-opt -lower-affine --split-input-file %s | FileCheck %s
 
-// CHECK-LABEL: func @affine_vector_load
-func @affine_vector_load(%arg0 : index) {
-  %0 = alloc() : memref<100xf32>
-  affine.for %i0 = 0 to 16 {
-    %1 = affine.vector_load %0[%i0 + symbol(%arg0) + 7] : memref<100xf32>, vector<8xf32>
-  }
-// CHECK:       %[[buf:.*]] = alloc
-// CHECK:       %[[a:.*]] = addi %{{.*}}, %{{.*}} : index
-// CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
-// CHECK-NEXT:  %[[b:.*]] = addi %[[a]], %[[c7]] : index
-// CHECK-NEXT:  %[[pad:.*]] = constant 0.0
-// CHECK-NEXT:  vector.transfer_read %[[buf]][%[[b]]], %[[pad]] : memref<100xf32>, vector<8xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @affine_vector_store
-func @affine_vector_store(%arg0 : index) {
-  %0 = alloc() : memref<100xf32>
-  %1 = constant dense<11.0> : vector<4xf32>
-  affine.for %i0 = 0 to 16 {
-    affine.vector_store %1, %0[%i0 - symbol(%arg0) + 7] : memref<100xf32>, vector<4xf32>
-}
-// CHECK:       %[[buf:.*]] = alloc
-// CHECK:       %[[val:.*]] = constant dense
-// CHECK:       %[[c_1:.*]] = constant -1 :
index -// CHECK-NEXT: %[[a:.*]] = muli %arg0, %[[c_1]] : index -// CHECK-NEXT: %[[b:.*]] = addi %{{.*}}, %[[a]] : index -// CHECK-NEXT: %[[c7:.*]] = constant 7 : index -// CHECK-NEXT: %[[c:.*]] = addi %[[b]], %[[c7]] : index -// CHECK-NEXT: vector.transfer_write %[[val]], %[[buf]][%[[c]]] : vector<4xf32>, memref<100xf32> - return -} - -// ----- // CHECK-LABEL: func @affine_vector_load func @affine_vector_load(%arg0 : index) { @@ -47,8 +11,7 @@ // CHECK: %[[a:.*]] = addi %{{.*}}, %{{.*}} : index // CHECK-NEXT: %[[c7:.*]] = constant 7 : index // CHECK-NEXT: %[[b:.*]] = addi %[[a]], %[[c7]] : index -// CHECK-NEXT: %[[pad:.*]] = constant 0.0 -// CHECK-NEXT: vector.transfer_read %[[buf]][%[[b]]], %[[pad]] : memref<100xf32>, vector<8xf32> +// CHECK-NEXT: vector.load %[[buf]][%[[b]]] : memref<100xf32>, vector<8xf32> return } @@ -68,7 +31,7 @@ // CHECK-NEXT: %[[b:.*]] = addi %{{.*}}, %[[a]] : index // CHECK-NEXT: %[[c7:.*]] = constant 7 : index // CHECK-NEXT: %[[c:.*]] = addi %[[b]], %[[c7]] : index -// CHECK-NEXT: vector.transfer_write %[[val]], %[[buf]][%[[c]]] : vector<4xf32>, memref<100xf32> +// CHECK-NEXT: vector.store %[[val]], %[[buf]][%[[c]]] : memref<100xf32>, vector<4xf32> return } @@ -83,8 +46,7 @@ // CHECK: %[[buf:.*]] = alloc // CHECK: scf.for %[[i0:.*]] = // CHECK: scf.for %[[i1:.*]] = -// CHECK-NEXT: %[[pad:.*]] = constant 0.0 -// CHECK-NEXT: vector.transfer_read %[[buf]][%[[i0]], %[[i1]]], %[[pad]] : memref<100x100xf32>, vector<2x8xf32> +// CHECK-NEXT: vector.load %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32> } } return @@ -103,9 +65,8 @@ // CHECK: %[[val:.*]] = constant dense // CHECK: scf.for %[[i0:.*]] = // CHECK: scf.for %[[i1:.*]] = -// CHECK-NEXT: vector.transfer_write %[[val]], %[[buf]][%[[i0]], %[[i1]]] : vector<2x8xf32>, memref<100x100xf32> +// CHECK-NEXT: vector.store %[[val]], %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32> } } return } - diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -23,6 +23,7 @@ // ----- + func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { %0 = vector.broadcast %arg0 : f32 to vector<2xf32> return %0 : vector<2xf32> @@ -1242,6 +1243,33 @@ // ----- +func @vector_load_op(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> { + %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @vector_load_op +// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64 +// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64 +// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64 +// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr to !llvm.ptr> +// CHECK: llvm.load %[[bcast]] {alignment = 4 : i64} : !llvm.ptr> + +func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) { + %val = constant dense<11.0> : vector<4xf32> + vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32> + return +} + +// CHECK-LABEL: func @vector_store_op +// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64 +// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64 +// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64 +// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr, i64) -> 
!llvm.ptr +// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr to !llvm.ptr> +// CHECK: llvm.store %{{.*}}, %[[bcast]] {alignment = 4 : i64} : !llvm.ptr> + func @masked_load_op(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> { %c0 = constant 0: index %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1198,6 +1198,38 @@ // ----- +func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>, + %i : index, %j : index, %value : vector<8xf32>) { + // expected-error@+1 {{'vector.store' op base memref should have a default identity layout}} + vector.store %value, %memref[%i, %j] : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>, + vector<8xf32> +} + +// ----- + +func @vector_memref_mismatch(%memref : memref<200x100xvector<4xf32>>, %i : index, + %j : index, %value : vector<8xf32>) { + // expected-error@+1 {{'vector.store' op base memref and valueToStore vector types should match}} + vector.store %value, %memref[%i, %j] : memref<200x100xvector<4xf32>>, vector<8xf32> +} + +// ----- + +func @store_base_type_mismatch(%base : memref, %value : vector<16xf32>) { + %c0 = constant 0 : index + // expected-error@+1 {{'vector.store' op base and valueToStore element type should match}} + vector.store %value, %base[%c0] : memref, vector<16xf32> +} + +// ----- + +func @store_memref_index_mismatch(%base : memref, %value : vector<16xf32>) { + // expected-error@+1 {{'vector.store' op requires 1 indices}} + vector.store %value, %base[] : memref, vector<16xf32> +} + +// ----- + func @maskedload_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %pass: vector<16xf32>) { %c0 = constant 0 : index // expected-error@+1 {{'vector.maskedload' op base and result element type should match}} @@ -1231,7 +1263,7 @@ func @maskedstore_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) { %c0 = constant 0 : index - // expected-error@+1 {{'vector.maskedstore' op base and value element type should match}} + // expected-error@+1 {{'vector.maskedstore' op base and valueToStore element type should match}} vector.maskedstore %base[%c0], %mask, %value : memref, vector<16xi1>, vector<16xf32> } @@ -1239,7 +1271,7 @@ func @maskedstore_dim_mask_mismatch(%base: memref, %mask: vector<15xi1>, %value: vector<16xf32>) { %c0 = constant 0 : index - // expected-error@+1 {{'vector.maskedstore' op expected value dim to match mask dim}} + // expected-error@+1 {{'vector.maskedstore' op expected valueToStore dim to match mask dim}} vector.maskedstore %base[%c0], %mask, %value : memref, vector<15xi1>, vector<16xf32> } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -450,6 +450,56 @@ return %0 : vector<16xi32> } +// CHECK-LABEL: @vector_load_and_store_1d_scalar_memref +func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>, + %i : index, %j : index) { + // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<8xf32> + %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32> + // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<8xf32> + vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32> + return 
+} + +// CHECK-LABEL: @vector_load_and_store_1d_vector_memref +func @vector_load_and_store_1d_vector_memref(%memref : memref<200x100xvector<8xf32>>, + %i : index, %j : index) { + // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32> + %0 = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32> + // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32> + vector.store %0, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32> + return +} + +// CHECK-LABEL: @vector_load_and_store_out_of_bounds +func @vector_load_and_store_out_of_bounds(%memref : memref<7xf32>) { + %c0 = constant 0 : index + // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<7xf32>, vector<8xf32> + %0 = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32> + // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<7xf32>, vector<8xf32> + vector.store %0, %memref[%c0] : memref<7xf32>, vector<8xf32> + return +} + +// CHECK-LABEL: @vector_load_and_store_2d_scalar_memref +func @vector_load_and_store_2d_scalar_memref(%memref : memref<200x100xf32>, + %i : index, %j : index) { + // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<4x8xf32> + %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32> + // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<4x8xf32> + vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32> + return +} + +// CHECK-LABEL: @vector_load_and_store_2d_vector_memref +func @vector_load_and_store_2d_vector_memref(%memref : memref<200x100xvector<4x8xf32>>, + %i : index, %j : index) { + // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32> + %0 = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32> + // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32> + vector.store %0, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32> + return +} + // CHECK-LABEL: @masked_load_and_store func @masked_load_and_store(%base: memref, %mask: vector<16xi1>, %passthru: vector<16xf32>) { %c0 = constant 0 : index