diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1317,6 +1317,7 @@
 def Vector_MaskedLoadOp :
   Vector_Op<"maskedload">,
     Arguments<(ins AnyMemRef:$base,
+               Variadic<Index>:$indices,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$pass_thru)>,
     Results<(outs VectorOfRank<[1]>:$result)> {
@@ -1325,12 +1326,12 @@
   let description = [{
     The masked load reads elements from memory into a 1-D vector as defined
-    by a base and a 1-D mask vector. When the mask is set, the element is read
-    from memory. Otherwise, the corresponding element is taken from a 1-D
-    pass-through vector. Informally the semantics are:
+    by a base with indices and a 1-D mask vector. When the mask is set, the
+    element is read from memory. Otherwise, the corresponding element is taken
+    from a 1-D pass-through vector. Informally the semantics are:
     ```
-    result[0] := mask[0] ? MEM[base+0] : pass_thru[0]
-    result[1] := mask[1] ? MEM[base+1] : pass_thru[1]
+    result[0] := mask[0] ? base[i+0] : pass_thru[0]
+    result[1] := mask[1] ? base[i+1] : pass_thru[1]
     etc.
     ```
     The masked load can be used directly where applicable, or can be used
@@ -1342,7 +1343,7 @@
     Example:

     ```mlir
-    %0 = vector.maskedload %base, %mask, %pass_thru
+    %0 = vector.maskedload %base[%i], %mask, %pass_thru
        : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
     ```
   }];
@@ -1360,7 +1361,7 @@
       return result().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
+  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
       "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
   let hasCanonicalizer = 1;
 }
@@ -1368,6 +1369,7 @@
 def Vector_MaskedStoreOp :
   Vector_Op<"maskedstore">,
     Arguments<(ins AnyMemRef:$base,
+               Variadic<Index>:$indices,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$value)> {
@@ -1375,12 +1377,12 @@
   let description = [{
     The masked store operation writes elements from a 1-D vector into memory
-    as defined by a base and a 1-D mask vector. When the mask is set, the
-    corresponding element from the vector is written to memory. Otherwise,
+    as defined by a base with indices and a 1-D mask vector. When the mask is
+    set, the corresponding element from the vector is written to memory. Otherwise,
     no action is taken for the element. Informally the semantics are:
     ```
-    if (mask[0]) MEM[base+0] = value[0]
-    if (mask[1]) MEM[base+1] = value[1]
+    if (mask[0]) base[i+0] = value[0]
+    if (mask[1]) base[i+1] = value[1]
     etc.
     ```
     The masked store can be used directly where applicable, or can be used
@@ -1392,7 +1394,7 @@
     Example:

     ```mlir
-    vector.maskedstore %base, %mask, %value
+    vector.maskedstore %base[%i], %mask, %value
        : memref<?xf32>, vector<8xi1>, vector<8xf32>
     ```
   }];
@@ -1407,8 +1409,8 @@
       return value().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
-      "type($mask) `,` type($value) `into` type($base)";
+  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
+      "type($base) `,` type($mask) `,` type($value)";
   let hasCanonicalizer = 1;
 }
@@ -1430,8 +1432,8 @@
     semantics are:
     ```
     if (!defined(pass_thru)) pass_thru = [undef, .., undef]
-    result[0] := mask[0] ? MEM[base + index[0]] : pass_thru[0]
-    result[1] := mask[1] ? MEM[base + index[1]] : pass_thru[1]
+    result[0] := mask[0] ? base[index[0]] : pass_thru[0]
+    result[1] := mask[1] ? base[index[1]] : pass_thru[1]
     etc.
     ```
     The vector dialect leaves out-of-bounds behavior undefined.
@@ -1487,8 +1489,8 @@
     bit in a 1-D mask vector is set. Otherwise, no action is taken for that
     element. Informally the semantics are:
     ```
-    if (mask[0]) MEM[base + index[0]] = value[0]
-    if (mask[1]) MEM[base + index[1]] = value[1]
+    if (mask[0]) base[index[0]] = value[0]
+    if (mask[1]) base[index[1]] = value[1]
     etc.
     ```
     The vector dialect leaves out-of-bounds and repeated index behavior
@@ -1531,6 +1533,7 @@
 def Vector_ExpandLoadOp :
   Vector_Op<"expandload">,
     Arguments<(ins AnyMemRef:$base,
+               Variadic<Index>:$indices,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$pass_thru)>,
     Results<(outs VectorOfRank<[1]>:$result)> {
@@ -1539,13 +1542,13 @@
   let description = [{
     The expand load reads elements from memory into a 1-D vector as defined
-    by a base and a 1-D mask vector. When the mask is set, the next element
-    is read from memory. Otherwise, the corresponding element is taken from
-    a 1-D pass-through vector. Informally the semantics are:
+    by a base with indices and a 1-D mask vector. When the mask is set, the
+    next element is read from memory. Otherwise, the corresponding element
+    is taken from a 1-D pass-through vector. Informally the semantics are:
     ```
-    index = base
-    result[0] := mask[0] ? MEM[index++] : pass_thru[0]
-    result[1] := mask[1] ? MEM[index++] : pass_thru[1]
+    index = i
+    result[0] := mask[0] ? base[index++] : pass_thru[0]
+    result[1] := mask[1] ? base[index++] : pass_thru[1]
     etc.
     ```
     Note that the index increment is done conditionally.
@@ -1559,7 +1562,7 @@
     Example:

     ```mlir
-    %0 = vector.expandload %base, %mask, %pass_thru
+    %0 = vector.expandload %base[%i], %mask, %pass_thru
        : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
     ```
   }];
@@ -1577,7 +1580,7 @@
       return result().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
+  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
       "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
   let hasCanonicalizer = 1;
 }
@@ -1585,6 +1588,7 @@
 def Vector_CompressStoreOp :
   Vector_Op<"compressstore">,
     Arguments<(ins AnyMemRef:$base,
+               Variadic<Index>:$indices,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$value)> {
@@ -1592,13 +1596,13 @@
   let description = [{
     The compress store operation writes elements from a 1-D vector into memory
-    as defined by a base and a 1-D mask vector. When the mask is set, the
-    corresponding element from the vector is written next to memory. Otherwise,
-    no action is taken for the element. Informally the semantics are:
+    as defined by a base with indices and a 1-D mask vector. When the mask is
+    set, the corresponding element from the vector is written next to memory.
+    Otherwise, no action is taken for the element. Informally the semantics are:
     ```
-    index = base
-    if (mask[0]) MEM[index++] = value[0]
-    if (mask[1]) MEM[index++] = value[1]
+    index = i
+    if (mask[0]) base[index++] = value[0]
+    if (mask[1]) base[index++] = value[1]
     etc.
     ```
     Note that the index increment is done conditionally.
@@ -1612,7 +1616,7 @@ Example: ```mlir - vector.compressstore %base, %mask, %value + vector.compressstore %base[%i], %mask, %value : memref, vector<8xi1>, vector<8xf32> ``` }]; @@ -1627,7 +1631,7 @@ return value().getType().cast(); } }]; - let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` " + let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` " "type($base) `,` type($mask) `,` type($value)"; let hasCanonicalizer = 1; } diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir --- a/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir +++ b/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir @@ -5,7 +5,16 @@ func @compress16(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) { - vector.compressstore %base, %mask, %value + %c0 = constant 0: index + vector.compressstore %base[%c0], %mask, %value + : memref, vector<16xi1>, vector<16xf32> + return +} + +func @compress16_at8(%base: memref, + %mask: vector<16xi1>, %value: vector<16xf32>) { + %c8 = constant 8: index + vector.compressstore %base[%c8], %mask, %value : memref, vector<16xi1>, vector<16xf32> return } @@ -86,5 +95,10 @@ call @printmem16(%A) : (memref) -> () // CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 ) + call @compress16_at8(%A, %some1, %value) + : (memref, vector<16xi1>, vector<16xf32>) -> () + call @printmem16(%A) : (memref) -> () + // CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 0, 1, 2, 3, 12, 13, 14, 15 ) + return } diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir --- a/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir +++ b/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir @@ -5,8 +5,18 @@ func @expand16(%base: memref, %mask: vector<16xi1>, - %pass_thru: vector<16xf32>) -> vector<16xf32> { - %e = vector.expandload %base, %mask, %pass_thru + %pass_thru: vector<16xf32>) -> vector<16xf32> { + %c0 = constant 0: index + %e = vector.expandload %base[%c0], %mask, %pass_thru + : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + return %e : vector<16xf32> +} + +func @expand16_at8(%base: memref, + %mask: vector<16xi1>, + %pass_thru: vector<16xf32>) -> vector<16xf32> { + %c8 = constant 8: index + %e = vector.expandload %base[%c8], %mask, %pass_thru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> return %e : vector<16xf32> } @@ -78,5 +88,10 @@ vector.print %e6 : vector<16xf32> // CHECK-NEXT: ( -7, 0, 7.7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, 7.7, 5 ) + %e7 = call @expand16_at8(%A, %some1, %pass) + : (memref, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>) + vector.print %e7 : vector<16xf32> + // CHECK-NEXT: ( 8, 9, 10, 11, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 ) + return } diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir --- a/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir +++ b/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir @@ -5,7 +5,16 @@ func @maskedload16(%base: memref, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> { - %ld = vector.maskedload %base, %mask, %pass_thru + %c0 = constant 0: index + %ld = vector.maskedload %base[%c0], %mask, %pass_thru + : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + return %ld : vector<16xf32> +} + +func @maskedload16_at8(%base: memref, %mask: 
vector<16xi1>, + %pass_thru: vector<16xf32>) -> vector<16xf32> { + %c8 = constant 8: index + %ld = vector.maskedload %base[%c8], %mask, %pass_thru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> return %ld : vector<16xf32> } @@ -61,6 +70,11 @@ vector.print %l4 : vector<16xf32> // CHECK: ( -7, 1, 2, 3, 4, 5, 6, 7, -7, -7, -7, -7, -7, 13, 14, -7 ) + %l5 = call @maskedload16_at8(%A, %some, %pass) + : (memref, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>) + vector.print %l5 : vector<16xf32> + // CHECK: ( 8, 9, 10, 11, 12, 13, 14, 15, -7, -7, -7, -7, -7, -7, -7, -7 ) + return } diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir --- a/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir +++ b/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir @@ -5,8 +5,17 @@ func @maskedstore16(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) { - vector.maskedstore %base, %mask, %value - : vector<16xi1>, vector<16xf32> into memref + %c0 = constant 0: index + vector.maskedstore %base[%c0], %mask, %value + : memref, vector<16xi1>, vector<16xf32> + return +} + +func @maskedstore16_at8(%base: memref, + %mask: vector<16xi1>, %value: vector<16xf32>) { + %c8 = constant 8: index + vector.maskedstore %base[%c8], %mask, %value + : memref, vector<16xi1>, vector<16xf32> return } @@ -85,5 +94,10 @@ call @printmem16(%A) : (memref) -> () // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 ) + call @maskedstore16_at8(%A, %some, %val) + : (memref, vector<16xi1>, vector<16xf32>) -> () + call @printmem16(%A) : (memref) -> () + // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7 ) + return } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -173,33 +173,7 @@ return success(); } -// Helper that returns a pointer given a memref base. -static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, - Location loc, Value memref, - MemRefType memRefType, Value &ptr) { - Value base; - if (failed(getBase(rewriter, loc, memref, memRefType, base))) - return failure(); - auto pType = MemRefDescriptor(memref).getElementPtrType(); - ptr = rewriter.create(loc, pType, base); - return success(); -} - -// Helper that returns a bit-casted pointer given a memref base. -static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, - Location loc, Value memref, - MemRefType memRefType, Type type, Value &ptr) { - Value base; - if (failed(getBase(rewriter, loc, memref, memRefType, base))) - return failure(); - auto pType = LLVM::LLVMPointerType::get(type); - base = rewriter.create(loc, pType, base); - ptr = rewriter.create(loc, pType, base); - return success(); -} - -// Helper that returns vector of pointers given a memref base and an index -// vector. +// Helper that returns vector of pointers given a memref base with index vector. static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, Value memref, Value indices, MemRefType memRefType, VectorType vType, @@ -213,6 +187,18 @@ return success(); } +// Casts a strided element pointer to a vector pointer. The vector pointer +// would always be on address space 0, therefore addrspacecast shall be +// used when source/dst memrefs are not on address space 0. 
+static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, + Value ptr, MemRefType memRefType, Type vt) { + auto pType = + LLVM::LLVMPointerType::get(vt.template cast()); + if (memRefType.getMemorySpace() == 0) + return rewriter.create(loc, pType, ptr); + return rewriter.create(loc, pType, ptr); +} + static LogicalResult replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, @@ -343,18 +329,18 @@ ConversionPatternRewriter &rewriter) const override { auto loc = load->getLoc(); auto adaptor = vector::MaskedLoadOpAdaptor(operands); + MemRefType memRefType = load.getMemRefType(); // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(), - align))) + if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) return failure(); + // Resolve address. auto vtype = typeConverter->convertType(load.getResultVectorType()); - Value ptr; - if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(), - vtype, ptr))) - return failure(); + Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), + adaptor.indices(), rewriter); + Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype); rewriter.replaceOpWithNewOp( load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(), @@ -374,18 +360,18 @@ ConversionPatternRewriter &rewriter) const override { auto loc = store->getLoc(); auto adaptor = vector::MaskedStoreOpAdaptor(operands); + MemRefType memRefType = store.getMemRefType(); // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(), - align))) + if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) return failure(); + // Resolve address. auto vtype = typeConverter->convertType(store.getValueVectorType()); - Value ptr; - if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(), - vtype, ptr))) - return failure(); + Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), + adaptor.indices(), rewriter); + Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype); rewriter.replaceOpWithNewOp( store, adaptor.value(), ptr, adaptor.mask(), @@ -473,16 +459,15 @@ ConversionPatternRewriter &rewriter) const override { auto loc = expand->getLoc(); auto adaptor = vector::ExpandLoadOpAdaptor(operands); + MemRefType memRefType = expand.getMemRefType(); - Value ptr; - if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(), - ptr))) - return failure(); + // Resolve address. + auto vtype = typeConverter->convertType(expand.getResultVectorType()); + Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), + adaptor.indices(), rewriter); - auto vType = expand.getResultVectorType(); rewriter.replaceOpWithNewOp( - expand, typeConverter->convertType(vType), ptr, adaptor.mask(), - adaptor.pass_thru()); + expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru()); return success(); } }; @@ -498,11 +483,11 @@ ConversionPatternRewriter &rewriter) const override { auto loc = compress->getLoc(); auto adaptor = vector::CompressStoreOpAdaptor(operands); + MemRefType memRefType = compress.getMemRefType(); - Value ptr; - if (failed(getBasePtr(rewriter, loc, adaptor.base(), - compress.getMemRefType(), ptr))) - return failure(); + // Resolve address. 
+ Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), + adaptor.indices(), rewriter); rewriter.replaceOpWithNewOp( compress, adaptor.value(), ptr, adaptor.mask()); @@ -1223,21 +1208,11 @@ } // 1. Get the source/dst address as an LLVM vector pointer. - // The vector pointer would always be on address space 0, therefore - // addrspacecast shall be used when source/dst memrefs are not on - // address space 0. - // TODO: support alignment when possible. + VectorType vtp = xferOp.getVectorType(); Value dataPtr = this->getStridedElementPtr( loc, memRefType, adaptor.source(), adaptor.indices(), rewriter); - auto vecTy = toLLVMTy(xferOp.getVectorType()) - .template cast(); - Value vectorDataPtr; - if (memRefType.getMemorySpace() == 0) - vectorDataPtr = rewriter.create( - loc, LLVM::LLVMPointerType::get(vecTy), dataPtr); - else - vectorDataPtr = rewriter.create( - loc, LLVM::LLVMPointerType::get(vecTy), dataPtr); + Value vectorDataPtr = + castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp)); if (!xferOp.isMaskedDim(0)) return replaceTransferOpWithLoadOrStore(rewriter, @@ -1251,7 +1226,7 @@ // // TODO: when the leaf transfer rank is k > 1, we need the last `k` // dimensions here. - unsigned vecWidth = vecTy.getNumElements(); + unsigned vecWidth = vtp.getNumElements(); unsigned lastIndex = llvm::size(xferOp.indices()) - 1; Value off = xferOp.indices()[lastIndex]; Value dim = rewriter.create(loc, xferOp.source(), lastIndex); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -76,20 +76,6 @@ return MaskFormat::Unknown; } -/// Helper method to cast a 1-D memref<10xf32> "base" into a -/// memref> in the output parameter "newBase", -/// using the 'element' vector type "vt". Returns true on success. -static bool castedToMemRef(Location loc, Value base, MemRefType mt, - VectorType vt, PatternRewriter &rewriter, - Value &newBase) { - // The vector.type_cast operation does not accept unknown memref. 
-  // TODO: generalize the cast and accept this case too
-  if (!mt.hasStaticShape())
-    return false;
-  newBase = rewriter.create<TypeCastOp>(loc, MemRefType::get({}, vt), base);
-  return true;
-}
-
 //===----------------------------------------------------------------------===//
 // VectorDialect
 //===----------------------------------------------------------------------===//
@@ -2380,13 +2366,10 @@
   using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedLoadOp load,
                                 PatternRewriter &rewriter) const override {
-    Value newBase;
     switch (get1DMaskFormat(load.mask())) {
     case MaskFormat::AllTrue:
-      if (!castedToMemRef(load.getLoc(), load.base(), load.getMemRefType(),
-                          load.getResultVectorType(), rewriter, newBase))
-        return failure();
-      rewriter.replaceOpWithNewOp<LoadOp>(load, newBase);
+      rewriter.replaceOpWithNewOp<TransferReadOp>(
+          load, load.getType(), load.base(), load.indices(), false);
       return success();
     case MaskFormat::AllFalse:
       rewriter.replaceOp(load, load.pass_thru());
@@ -2426,13 +2409,10 @@
   using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedStoreOp store,
                                 PatternRewriter &rewriter) const override {
-    Value newBase;
     switch (get1DMaskFormat(store.mask())) {
     case MaskFormat::AllTrue:
-      if (!castedToMemRef(store.getLoc(), store.base(), store.getMemRefType(),
-                          store.getValueVectorType(), rewriter, newBase))
-        return failure();
-      rewriter.replaceOpWithNewOp<StoreOp>(store, store.value(), newBase);
+      rewriter.replaceOpWithNewOp<TransferWriteOp>(
+          store, store.value(), store.base(), store.indices(), false);
       return success();
     case MaskFormat::AllFalse:
       rewriter.eraseOp(store);
@@ -2568,14 +2548,10 @@
   using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                 PatternRewriter &rewriter) const override {
-    Value newBase;
     switch (get1DMaskFormat(expand.mask())) {
     case MaskFormat::AllTrue:
-      if (!castedToMemRef(expand.getLoc(), expand.base(),
-                          expand.getMemRefType(), expand.getResultVectorType(),
-                          rewriter, newBase))
-        return failure();
-      rewriter.replaceOpWithNewOp<LoadOp>(expand, newBase);
+      rewriter.replaceOpWithNewOp<TransferReadOp>(
+          expand, expand.getType(), expand.base(), expand.indices(), false);
       return success();
     case MaskFormat::AllFalse:
       rewriter.replaceOp(expand, expand.pass_thru());
@@ -2615,14 +2591,11 @@
   using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(CompressStoreOp compress,
                                 PatternRewriter &rewriter) const override {
-    Value newBase;
     switch (get1DMaskFormat(compress.mask())) {
     case MaskFormat::AllTrue:
-      if (!castedToMemRef(compress.getLoc(), compress.base(),
-                          compress.getMemRefType(),
-                          compress.getValueVectorType(), rewriter, newBase))
-        return failure();
-      rewriter.replaceOpWithNewOp<StoreOp>(compress, compress.value(), newBase);
+      rewriter.replaceOpWithNewOp<TransferWriteOp>(
+          compress, compress.value(), compress.base(), compress.indices(),
+          false);
       return success();
     case MaskFormat::AllFalse:
       rewriter.eraseOp(compress);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1070,23 +1070,29 @@
 // CHECK: llvm.return %[[T]] : !llvm.vec<16 x float>

 func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
-  %0 = vector.maskedload %arg0, %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %c0 = constant 0: index
+  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %0 : vector<16xf32> } // CHECK-LABEL: func @masked_load_op -// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr>) -> !llvm.ptr> -// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr>, !llvm.vec<16 x i1>, !llvm.vec<16 x float>) -> !llvm.vec<16 x float> +// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK: %[[B:.*]] = llvm.bitcast %[[P]] : !llvm.ptr to !llvm.ptr> +// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[B]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr>, !llvm.vec<16 x i1>, !llvm.vec<16 x float>) -> !llvm.vec<16 x float> // CHECK: llvm.return %[[L]] : !llvm.vec<16 x float> func @masked_store_op(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>) { - vector.maskedstore %arg0, %arg1, %arg2 : vector<16xi1>, vector<16xf32> into memref + %c0 = constant 0: index + vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xf32> return } // CHECK-LABEL: func @masked_store_op -// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr>) -> !llvm.ptr> -// CHECK: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x float>, !llvm.vec<16 x i1> into !llvm.ptr> +// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK: %[[B:.*]] = llvm.bitcast %[[P]] : !llvm.ptr to !llvm.ptr> +// CHECK: llvm.intr.masked.store %{{.*}}, %[[B]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x float>, !llvm.vec<16 x i1> into !llvm.ptr> // CHECK: llvm.return func @gather_op(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> { @@ -1110,21 +1116,25 @@ // CHECK: llvm.return func @expand_load_op(%arg0: memref, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> { - %0 = vector.expandload %arg0, %arg1, %arg2 : memref, vector<11xi1>, vector<11xf32> into vector<11xf32> + %c0 = constant 0: index + %0 = vector.expandload %arg0[%c0], %arg1, %arg2 : memref, vector<11xi1>, vector<11xf32> into vector<11xf32> return %0 : vector<11xf32> } // CHECK-LABEL: func @expand_load_op -// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr) -> !llvm.ptr +// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr // CHECK: %[[E:.*]] = "llvm.intr.masked.expandload"(%[[P]], %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.vec<11 x i1>, !llvm.vec<11 x float>) -> !llvm.vec<11 x float> // CHECK: llvm.return %[[E]] : !llvm.vec<11 x float> func @compress_store_op(%arg0: memref, %arg1: vector<11xi1>, %arg2: vector<11xf32>) { - vector.compressstore %arg0, %arg1, %arg2 : memref, vector<11xi1>, vector<11xf32> + %c0 = constant 0: index + vector.compressstore %arg0[%c0], %arg1, %arg2 : memref, vector<11xi1>, vector<11xf32> return } // CHECK-LABEL: func @compress_store_op -// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr) -> !llvm.ptr +// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr // CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %[[P]], %{{.*}}) : (!llvm.vec<11 x float>, !llvm.ptr, !llvm.vec<11 x i1>) -> () // CHECK: llvm.return diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- 
a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1199,36 +1199,41 @@ // ----- func @maskedload_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %pass: vector<16xf32>) { + %c0 = constant 0 : index // expected-error@+1 {{'vector.maskedload' op base and result element type should match}} - %0 = vector.maskedload %base, %mask, %pass : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.maskedload %base[%c0], %mask, %pass : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> } // ----- func @maskedload_dim_mask_mismatch(%base: memref, %mask: vector<15xi1>, %pass: vector<16xf32>) { + %c0 = constant 0 : index // expected-error@+1 {{'vector.maskedload' op expected result dim to match mask dim}} - %0 = vector.maskedload %base, %mask, %pass : memref, vector<15xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.maskedload %base[%c0], %mask, %pass : memref, vector<15xi1>, vector<16xf32> into vector<16xf32> } // ----- func @maskedload_pass_thru_type_mask_mismatch(%base: memref, %mask: vector<16xi1>, %pass: vector<16xi32>) { + %c0 = constant 0 : index // expected-error@+1 {{'vector.maskedload' op expected pass_thru of same type as result type}} - %0 = vector.maskedload %base, %mask, %pass : memref, vector<16xi1>, vector<16xi32> into vector<16xf32> + %0 = vector.maskedload %base[%c0], %mask, %pass : memref, vector<16xi1>, vector<16xi32> into vector<16xf32> } // ----- func @maskedstore_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) { + %c0 = constant 0 : index // expected-error@+1 {{'vector.maskedstore' op base and value element type should match}} - vector.maskedstore %base, %mask, %value : vector<16xi1>, vector<16xf32> into memref + vector.maskedstore %base[%c0], %mask, %value : memref, vector<16xi1>, vector<16xf32> } // ----- func @maskedstore_dim_mask_mismatch(%base: memref, %mask: vector<15xi1>, %value: vector<16xf32>) { + %c0 = constant 0 : index // expected-error@+1 {{'vector.maskedstore' op expected value dim to match mask dim}} - vector.maskedstore %base, %mask, %value : vector<15xi1>, vector<16xf32> into memref + vector.maskedstore %base[%c0], %mask, %value : memref, vector<15xi1>, vector<16xf32> } // ----- @@ -1297,36 +1302,41 @@ // ----- func @expand_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { + %c0 = constant 0 : index // expected-error@+1 {{'vector.expandload' op base and result element type should match}} - %0 = vector.expandload %base, %mask, %pass_thru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> } // ----- func @expand_dim_mask_mismatch(%base: memref, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) { + %c0 = constant 0 : index // expected-error@+1 {{'vector.expandload' op expected result dim to match mask dim}} - %0 = vector.expandload %base, %mask, %pass_thru : memref, vector<17xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref, vector<17xi1>, vector<16xf32> into vector<16xf32> } // ----- func @expand_pass_thru_mismatch(%base: memref, %mask: vector<16xi1>, %pass_thru: vector<17xf32>) { + %c0 = constant 0 : index // expected-error@+1 {{'vector.expandload' op expected pass_thru of same type as result type}} - %0 = vector.expandload %base, %mask, %pass_thru : memref, vector<16xi1>, vector<17xf32> into vector<16xf32> + %0 = vector.expandload %base[%c0], %mask, 
%pass_thru : memref, vector<16xi1>, vector<17xf32> into vector<16xf32> } // ----- func @compress_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) { + %c0 = constant 0 : index // expected-error@+1 {{'vector.compressstore' op base and value element type should match}} - vector.compressstore %base, %mask, %value : memref, vector<16xi1>, vector<16xf32> + vector.compressstore %base[%c0], %mask, %value : memref, vector<16xi1>, vector<16xf32> } // ----- func @compress_dim_mask_mismatch(%base: memref, %mask: vector<17xi1>, %value: vector<16xf32>) { + %c0 = constant 0 : index // expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}} - vector.compressstore %base, %mask, %value : memref, vector<17xi1>, vector<16xf32> + vector.compressstore %base[%c0], %mask, %value : memref, vector<17xi1>, vector<16xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -452,10 +452,11 @@ // CHECK-LABEL: @masked_load_and_store func @masked_load_and_store(%base: memref, %mask: vector<16xi1>, %passthru: vector<16xf32>) { - // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}, %{{.*}}, %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> - %0 = vector.maskedload %base, %mask, %passthru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> - // CHECK: vector.maskedstore %{{.*}}, %{{.*}}, %[[X]] : vector<16xi1>, vector<16xf32> into memref - vector.maskedstore %base, %mask, %0 : vector<16xi1>, vector<16xf32> into memref + %c0 = constant 0 : index + // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.maskedload %base[%c0], %mask, %passthru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + // CHECK: vector.maskedstore %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xi1>, vector<16xf32> + vector.maskedstore %base[%c0], %mask, %0 : memref, vector<16xi1>, vector<16xf32> return } @@ -472,10 +473,11 @@ // CHECK-LABEL: @expand_and_compress func @expand_and_compress(%base: memref, %mask: vector<16xi1>, %passthru: vector<16xf32>) { - // CHECK: %[[X:.*]] = vector.expandload %{{.*}}, %{{.*}}, %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> - %0 = vector.expandload %base, %mask, %passthru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> - // CHECK: vector.compressstore %{{.*}}, %{{.*}}, %[[X]] : memref, vector<16xi1>, vector<16xf32> - vector.compressstore %base, %mask, %0 : memref, vector<16xi1>, vector<16xf32> + %c0 = constant 0 : index + // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.expandload %base[%c0], %mask, %passthru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + // CHECK: vector.compressstore %{{.*}}[{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xi1>, vector<16xf32> + vector.compressstore %base[%c0], %mask, %0 : memref, vector<16xi1>, vector<16xf32> return } diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir @@ -1,82 +1,93 @@ // RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s -// -// TODO: optimize this one too! 
-// -// CHECK-LABEL: func @maskedload0( -// CHECK-SAME: %[[A0:.*]]: memref, -// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -// CHECK-NEXT: %[[M:.*]] = vector.constant_mask -// CHECK-NEXT: %[[T:.*]] = vector.maskedload %[[A0]], %[[M]], %[[A1]] : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK-NEXT: return %[[T]] : vector<16xf32> - +// CHECK-LABEL: func @maskedload0( +// CHECK-SAME: %[[A0:.*]]: memref, +// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> { +// CHECK-DAG: %[[C:.*]] = constant 0 : index +// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref, vector<16xf32> +// CHECK-NEXT: return %[[T]] : vector<16xf32> func @maskedload0(%base: memref, %pass_thru: vector<16xf32>) -> vector<16xf32> { + %c0 = constant 0 : index %mask = vector.constant_mask [16] : vector<16xi1> - %ld = vector.maskedload %base, %mask, %pass_thru + %ld = vector.maskedload %base[%c0], %mask, %pass_thru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> return %ld : vector<16xf32> } -// CHECK-LABEL: func @maskedload1( -// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, -// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref> -// CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref> -// CHECK-NEXT: return %[[T1]] : vector<16xf32> - +// CHECK-LABEL: func @maskedload1( +// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, +// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> { +// CHECK-DAG: %[[C:.*]] = constant 0 : index +// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32> +// CHECK-NEXT: return %[[T]] : vector<16xf32> func @maskedload1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> { + %c0 = constant 0 : index %mask = vector.constant_mask [16] : vector<16xi1> - %ld = vector.maskedload %base, %mask, %pass_thru + %ld = vector.maskedload %base[%c0], %mask, %pass_thru : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32> return %ld : vector<16xf32> } -// CHECK-LABEL: func @maskedload2( -// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, -// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -// CHECK-NEXT: return %[[A1]] : vector<16xf32> - +// CHECK-LABEL: func @maskedload2( +// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, +// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> { +// CHECK-NEXT: return %[[A1]] : vector<16xf32> func @maskedload2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> { + %c0 = constant 0 : index %mask = vector.constant_mask [0] : vector<16xi1> - %ld = vector.maskedload %base, %mask, %pass_thru + %ld = vector.maskedload %base[%c0], %mask, %pass_thru : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32> return %ld : vector<16xf32> } -// CHECK-LABEL: func @maskedstore1( -// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, -// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref> -// CHECK-NEXT: store %[[A1]], %[[T0]][] : memref> -// CHECK-NEXT: return +// CHECK-LABEL: func @maskedload3( +// CHECK-SAME: %[[A0:.*]]: memref, +// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> { +// CHECK-DAG: %[[C:.*]] = constant 8 : index +// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : 
memref, vector<16xf32> +// CHECK-NEXT: return %[[T]] : vector<16xf32> +func @maskedload3(%base: memref, %pass_thru: vector<16xf32>) -> vector<16xf32> { + %c8 = constant 8 : index + %mask = vector.constant_mask [16] : vector<16xi1> + %ld = vector.maskedload %base[%c8], %mask, %pass_thru + : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + return %ld : vector<16xf32> +} +// CHECK-LABEL: func @maskedstore1( +// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, +// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) { +// CHECK-NEXT: %[[C:.*]] = constant 0 : index +// CHECK-NEXT: vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32> +// CHECK-NEXT: return func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) { + %c0 = constant 0 : index %mask = vector.constant_mask [16] : vector<16xi1> - vector.maskedstore %base, %mask, %value - : vector<16xi1>, vector<16xf32> into memref<16xf32> + vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32> return } -// CHECK-LABEL: func @maskedstore2( -// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, -// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -// CHECK-NEXT: return - +// CHECK-LABEL: func @maskedstore2( +// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, +// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) { +// CHECK-NEXT: return func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) { + %c0 = constant 0 : index %mask = vector.constant_mask [0] : vector<16xi1> - vector.maskedstore %base, %mask, %value - : vector<16xi1>, vector<16xf32> into memref<16xf32> + vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32> return } -// CHECK-LABEL: func @gather1( -// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, -// CHECK-SAME: %[[A1:.*]]: vector<16xi32>, -// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -// CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1> -// CHECK-NEXT: %[[T1:.*]] = vector.gather %[[A0]], %[[A1]], %[[T0]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32> -// CHECK-NEXT: return %1 : vector<16xf32> - +// CHECK-LABEL: func @gather1( +// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, +// CHECK-SAME: %[[A1:.*]]: vector<16xi32>, +// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> { +// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1> +// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]], %[[A1]], %[[M]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32> +// CHECK-NEXT: return %[[G]] : vector<16xf32> func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> { %mask = vector.constant_mask [16] : vector<16xi1> %ld = vector.gather %base, %indices, %mask, %pass_thru @@ -84,12 +95,11 @@ return %ld : vector<16xf32> } -// CHECK-LABEL: func @gather2( -// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, -// CHECK-SAME: %[[A1:.*]]: vector<16xi32>, -// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -// CHECK-NEXT: return %[[A2]] : vector<16xf32> - +// CHECK-LABEL: func @gather2( +// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, +// CHECK-SAME: %[[A1:.*]]: vector<16xi32>, +// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> { +// CHECK-NEXT: return %[[A2]] : vector<16xf32> func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> { %mask = vector.constant_mask [0] : vector<16xi1> %ld = vector.gather %base, %indices, %mask, %pass_thru @@ -97,14 +107,13 @@ return %ld : 
vector<16xf32> } -// CHECK-LABEL: func @scatter1( -// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, -// CHECK-SAME: %[[A1:.*]]: vector<16xi32>, -// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -// CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1> -// CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[T0]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32> -// CHECK-NEXT: return - +// CHECK-LABEL: func @scatter1( +// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, +// CHECK-SAME: %[[A1:.*]]: vector<16xi32>, +// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) { +// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1> +// CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[M]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32> +// CHECK-NEXT: return func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) { %mask = vector.constant_mask [16] : vector<16xi1> vector.scatter %base, %indices, %mask, %value @@ -112,12 +121,11 @@ return } -// CHECK-LABEL: func @scatter2( -// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, -// CHECK-SAME: %[[A1:.*]]: vector<16xi32>, -// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -// CHECK-NEXT: return - +// CHECK-LABEL: func @scatter2( +// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, +// CHECK-SAME: %[[A1:.*]]: vector<16xi32>, +// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) { +// CHECK-NEXT: return func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) { %0 = vector.type_cast %base : memref<16xf32> to memref> %mask = vector.constant_mask [0] : vector<16xi1> @@ -126,52 +134,53 @@ return } -// CHECK-LABEL: func @expand1( -// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, -// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref> -// CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref> -// CHECK-NEXT: return %[[T1]] : vector<16xf32> - +// CHECK-LABEL: func @expand1( +// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, +// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> { +// CHECK-DAG: %[[C:.*]] = constant 0 : index +// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32> +// CHECK-NEXT: return %[[T]] : vector<16xf32> func @expand1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> { + %c0 = constant 0 : index %mask = vector.constant_mask [16] : vector<16xi1> - %ld = vector.expandload %base, %mask, %pass_thru + %ld = vector.expandload %base[%c0], %mask, %pass_thru : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32> return %ld : vector<16xf32> } -// CHECK-LABEL: func @expand2( -// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, -// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -// CHECK-NEXT: return %[[A1]] : vector<16xf32> - +// CHECK-LABEL: func @expand2( +// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, +// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> { +// CHECK-NEXT: return %[[A1]] : vector<16xf32> func @expand2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> { + %c0 = constant 0 : index %mask = vector.constant_mask [0] : vector<16xi1> - %ld = vector.expandload %base, %mask, %pass_thru + %ld = vector.expandload %base[%c0], %mask, %pass_thru : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32> return %ld : vector<16xf32> } -// CHECK-LABEL: func @compress1( -// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, -// CHECK-SAME: %[[A1:.*]]: 
vector<16xf32>) -// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref> -// CHECK-NEXT: store %[[A1]], %[[T0]][] : memref> -// CHECK-NEXT: return - +// CHECK-LABEL: func @compress1( +// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, +// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) { +// CHECK-NEXT: %[[C:.*]] = constant 0 : index +// CHECK-NEXT: vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32> +// CHECK-NEXT: return func @compress1(%base: memref<16xf32>, %value: vector<16xf32>) { + %c0 = constant 0 : index %mask = vector.constant_mask [16] : vector<16xi1> - vector.compressstore %base, %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32> + vector.compressstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32> return } -// CHECK-LABEL: func @compress2( -// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, -// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -// CHECK-NEXT: return - +// CHECK-LABEL: func @compress2( +// CHECK-SAME: %[[A0:.*]]: memref<16xf32>, +// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) { +// CHECK-NEXT: return func @compress2(%base: memref<16xf32>, %value: vector<16xf32>) { + %c0 = constant 0 : index %mask = vector.constant_mask [0] : vector<16xi1> - vector.compressstore %base, %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32> + vector.compressstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32> return } diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s +// RUN: mlir-opt %s -test-vector-to-vector-conversion="unroll" | FileCheck %s // CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)> diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -24,13 +24,25 @@ struct TestVectorToVectorConversion : public PassWrapper { + TestVectorToVectorConversion() = default; + TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), + llvm::cl::init(false)}; + void runOnFunction() override { OwningRewritePatternList patterns; auto *ctx = &getContext(); - patterns.insert( - ctx, - UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( - filter)); + if (unroll) { + patterns.insert( + ctx, + UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( + filter)); + } populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
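
For reference, a minimal usage sketch of the indexed addressing form this patch gives the four masked memory operations, mirroring the updated ops.mlir test; the `%buffer` and `%idx` names are illustrative only:

```mlir
func @indexed_masked_ops(%buffer: memref<?xf32>, %mask: vector<16xi1>,
                         %pass: vector<16xf32>) {
  // Start all accesses at element 8 of the base buffer.
  %idx = constant 8 : index
  %l = vector.maskedload %buffer[%idx], %mask, %pass
      : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  vector.maskedstore %buffer[%idx], %mask, %l
      : memref<?xf32>, vector<16xi1>, vector<16xf32>
  %e = vector.expandload %buffer[%idx], %mask, %pass
      : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  vector.compressstore %buffer[%idx], %mask, %e
      : memref<?xf32>, vector<16xi1>, vector<16xf32>
  return
}
```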