diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1198,6 +1198,7 @@
   }];
   let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
      "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_MaskedStoreOp :
@@ -1244,6 +1245,7 @@
   }];
   let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
      "type($mask) `,` type($value) `into` type($base)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_GatherOp :
@@ -1303,6 +1305,7 @@
     }
   }];
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_ScatterOp :
@@ -1358,6 +1361,7 @@
   }];
   let assemblyFormat = "$base `,` $indices `,` $mask `,` $value attr-dict `:` "
      "type($indices) `,` type($mask) `,` type($value) `into` type($base)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_ExpandLoadOp :
@@ -1411,6 +1415,7 @@
   }];
   let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
      "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_CompressStoreOp :
@@ -1460,6 +1465,7 @@
   }];
   let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
      "type($base) `,` type($mask) `,` type($value)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_ShapeCastOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -30,6 +30,58 @@
 using namespace mlir;
 using namespace mlir::vector;
 
+namespace {
+
+// Helper enum and method to classify mask value.
+enum class MaskFormat {
+  AllTrue = 0,
+  AllFalse = 1,
+  Unknown = 2,
+};
+MaskFormat get1DMaskFormat(Value mask) {
+  if (auto c = mask.getDefiningOp<ConstantOp>()) {
+    // Inspect constant dense values.
+    if (auto denseElts = c.value().dyn_cast<DenseIntElementsAttr>()) {
+      int64_t val = 0;
+      for (auto b : denseElts)
+        if (b.getBoolValue() && val >= 0)
+          val++;
+        else if (!b.getBoolValue() && val <= 0)
+          val--;
+        else
+          return MaskFormat::Unknown;
+      if (val > 0)
+        return MaskFormat::AllTrue;
+      if (val < 0)
+        return MaskFormat::AllFalse;
+    }
+  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
+    // Inspect constant mask index.
+    auto masks = m.mask_dim_sizes();
+    assert(masks.size() == 1);
+    int64_t i = masks[0].cast<IntegerAttr>().getInt();
+    int64_t u = m.getType().cast<VectorType>().getDimSize(0);
+    if (i >= u)
+      return MaskFormat::AllTrue;
+    if (i <= 0)
+      return MaskFormat::AllFalse;
+  }
+  return MaskFormat::Unknown;
+}
+
+// Helper method to cast 1-D memref<10xf32> into memref<vector<10xf32>>.
+bool canCastMemRef(Location loc, Value base, MemRefType mt, VectorType vt,
+                   PatternRewriter &rewriter, Value &newBase) {
+  // The vector.type_cast operation does not accept unknown memref.
+  // TODO: generalize the cast and accept this case too
+  if (!mt.hasStaticShape())
+    return false;
+  newBase = rewriter.create<TypeCastOp>(loc, MemRefType::get({}, vt), base);
+  return true;
+}
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // VectorDialect
 //===----------------------------------------------------------------------===//
@@ -1869,6 +1921,35 @@
   return success();
 }
 
+namespace {
+class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
+public:
+  using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(MaskedLoadOp load,
+                                PatternRewriter &rewriter) const override {
+    Value newBase;
+    switch (get1DMaskFormat(load.mask())) {
+    case MaskFormat::AllTrue:
+      if (!canCastMemRef(load.getLoc(), load.base(), load.getMemRefType(),
+                         load.getResultVectorType(), rewriter, newBase))
+        return failure();
+      rewriter.replaceOpWithNewOp<LoadOp>(load, newBase);
+      return success();
+    case MaskFormat::AllFalse:
+      rewriter.replaceOp(load, load.pass_thru());
+      return success();
+    case MaskFormat::Unknown:
+      return failure();
+    }
+  }
+};
+} // namespace
+
+void MaskedLoadOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<MaskedLoadFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedStoreOp
 //===----------------------------------------------------------------------===//
@@ -1885,6 +1966,35 @@
   return success();
 }
 
+namespace {
+class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
+public:
+  using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(MaskedStoreOp store,
+                                PatternRewriter &rewriter) const override {
+    Value newBase;
+    switch (get1DMaskFormat(store.mask())) {
+    case MaskFormat::AllTrue:
+      if (!canCastMemRef(store.getLoc(), store.base(), store.getMemRefType(),
+                         store.getValueVectorType(), rewriter, newBase))
+        return failure();
+      rewriter.replaceOpWithNewOp<StoreOp>(store, store.value(), newBase);
+      return success();
+    case MaskFormat::AllFalse:
+      rewriter.eraseOp(store);
+      return success();
+    case MaskFormat::Unknown:
+      return failure();
+    }
+  }
+};
+} // namespace
+
+void MaskedStoreOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<MaskedStoreFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // GatherOp
 //===----------------------------------------------------------------------===//
@@ -1909,6 +2019,30 @@
   return success();
 }
 
+namespace {
+class GatherFolder final : public OpRewritePattern<GatherOp> {
+public:
+  using OpRewritePattern<GatherOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherOp gather,
+                                PatternRewriter &rewriter) const override {
+    switch (get1DMaskFormat(gather.mask())) {
+    case MaskFormat::AllTrue:
+      return failure(); // no unmasked equivalent
+    case MaskFormat::AllFalse:
+      rewriter.replaceOp(gather, gather.pass_thru());
+      return success();
+    case MaskFormat::Unknown:
+      return failure();
+    }
+  }
+};
+} // namespace
+
+void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<GatherFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
@@ -1928,6 +2062,30 @@
   return success();
 }
 
+namespace {
+class ScatterFolder final : public OpRewritePattern<ScatterOp> {
+public:
+  using OpRewritePattern<ScatterOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ScatterOp scatter,
+                                PatternRewriter &rewriter) const override {
+    switch (get1DMaskFormat(scatter.mask())) {
+    case MaskFormat::AllTrue:
+      return failure(); // no unmasked equivalent
+    case MaskFormat::AllFalse:
+      rewriter.eraseOp(scatter);
+      return success();
+    case MaskFormat::Unknown:
+      return failure();
+    }
+  }
+};
+} // namespace
+
+void ScatterOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  results.insert<ScatterFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ExpandLoadOp
 //===----------------------------------------------------------------------===//
@@ -1947,6 +2105,35 @@
   return success();
 }
 
+namespace {
+class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
+public:
+  using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ExpandLoadOp expand,
+                                PatternRewriter &rewriter) const override {
+    Value newBase;
+    switch (get1DMaskFormat(expand.mask())) {
+    case MaskFormat::AllTrue:
+      if (!canCastMemRef(expand.getLoc(), expand.base(), expand.getMemRefType(),
+                         expand.getResultVectorType(), rewriter, newBase))
+        return failure();
+      rewriter.replaceOpWithNewOp<LoadOp>(expand, newBase);
+      return success();
+    case MaskFormat::AllFalse:
+      rewriter.replaceOp(expand, expand.pass_thru());
+      return success();
+    case MaskFormat::Unknown:
+      return failure();
+    }
+  }
+};
+} // namespace
+
+void ExpandLoadOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ExpandLoadFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CompressStoreOp
 //===----------------------------------------------------------------------===//
@@ -1963,6 +2150,36 @@
   return success();
 }
 
+namespace {
+class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
+public:
+  using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(CompressStoreOp compress,
+                                PatternRewriter &rewriter) const override {
+    Value newBase;
+    switch (get1DMaskFormat(compress.mask())) {
+    case MaskFormat::AllTrue:
+      if (!canCastMemRef(compress.getLoc(), compress.base(),
+                         compress.getMemRefType(),
+                         compress.getValueVectorType(), rewriter, newBase))
+        return failure();
+      rewriter.replaceOpWithNewOp<StoreOp>(compress, compress.value(), newBase);
+      return success();
+    case MaskFormat::AllFalse:
+      rewriter.eraseOp(compress);
+      return success();
+    case MaskFormat::Unknown:
+      return failure();
+    }
+  }
+};
+} // namespace
+
+void CompressStoreOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<CompressStoreFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//
@@ -2390,7 +2607,9 @@
 
 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder,
-                  TransposeFolder>(context);
+  patterns.insert<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder,
+                  GatherFolder, ScatterFolder, ExpandLoadFolder,
+                  CompressStoreFolder, StridedSliceConstantMaskFolder,
+                  TransposeFolder>(context);
 }
diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -0,0 +1,177 @@
+// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
+
+//
+// TODO: optimize this one too!
+//
+// CHECK-LABEL: func @maskedload0(
+//  CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
+//  CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+//  CHECK-NEXT: %[[M:.*]] = vector.constant_mask
+//  CHECK-NEXT: %[[T:.*]] = vector.maskedload %[[A0]], %[[M]], %[[A1]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+//  CHECK-NEXT: return %[[T]] : vector<16xf32>
+
+func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  %ld = vector.maskedload %base, %mask, %pass_thru
+    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// CHECK-LABEL: func @maskedload1(
+//  CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+//  CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+//  CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
+//  CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref<vector<16xf32>>
+//  CHECK-NEXT: return %[[T1]] : vector<16xf32>
+
+func @maskedload1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  %ld = vector.maskedload %base, %mask, %pass_thru
+    : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// CHECK-LABEL: func @maskedload2(
+//  CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+//  CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+//  CHECK-NEXT: return %[[A1]] : vector<16xf32>
+
+func @maskedload2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [0] : vector<16xi1>
+  %ld = vector.maskedload %base, %mask, %pass_thru
+    : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// CHECK-LABEL: func @maskedstore1(
+//  CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+//  CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+//  CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
+//  CHECK-NEXT: store %[[A1]], %[[T0]][] : memref<vector<16xf32>>
+//  CHECK-NEXT: return
+
+func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  vector.maskedstore %base, %mask, %value
+    : vector<16xi1>, vector<16xf32> into memref<16xf32>
+  return
+}
+
+// CHECK-LABEL: func @maskedstore2(
+//  CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+//  CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+//  CHECK-NEXT: return
+
+func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) {
+  %mask = vector.constant_mask [0] : vector<16xi1>
+  vector.maskedstore %base, %mask, %value
+    : vector<16xi1>, vector<16xf32> into memref<16xf32>
+  return
+}
+
+// CHECK-LABEL: func @gather1(
+//  CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+//  CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
+//  CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
+//  CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1>
+//  CHECK-NEXT: %[[T1:.*]] = vector.gather %[[A0]], %[[A1]], %[[T0]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+//  CHECK-NEXT: return %[[T1]] : vector<16xf32>
+
+func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  %ld = vector.gather %base, %indices, %mask, %pass_thru
+    : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// CHECK-LABEL: func @gather2(
+//  CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+//  CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
+//  CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
+//  CHECK-NEXT: return %[[A2]] : vector<16xf32>
+
+func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [0] : vector<16xi1>
+  %ld = vector.gather %base, %indices, %mask, %pass_thru
+    : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// CHECK-LABEL: func @scatter1(
+//  CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+//  CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
+//  CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
+//  CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1>
+//  CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[T0]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+//  CHECK-NEXT: return
+
+func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  vector.scatter %base, %indices, %mask, %value
+    : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter2(
+//  CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+//  CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
+//  CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
+//  CHECK-NEXT: return
+
+func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
+  %0 = vector.type_cast %base : memref<16xf32> to memref<vector<16xf32>>
+  %mask = vector.constant_mask [0] : vector<16xi1>
+  vector.scatter %base, %indices, %mask, %value
+    : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+  return
+}
+
+// CHECK-LABEL: func @expand1(
+//  CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+//  CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+//  CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
+//  CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref<vector<16xf32>>
+//  CHECK-NEXT: return %[[T1]] : vector<16xf32>
+
+func @expand1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  %ld = vector.expandload %base, %mask, %pass_thru
+    : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// CHECK-LABEL: func @expand2(
+//  CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+//  CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+//  CHECK-NEXT: return %[[A1]] : vector<16xf32>
+
+func @expand2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [0] : vector<16xi1>
+  %ld = vector.expandload %base, %mask, %pass_thru
+    : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// CHECK-LABEL: func @compress1(
+//  CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+//  CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+//  CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
+//  CHECK-NEXT: store %[[A1]], %[[T0]][] : memref<vector<16xf32>>
+//  CHECK-NEXT: return
+
+func @compress1(%base: memref<16xf32>, %value: vector<16xf32>) {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  vector.compressstore %base, %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+// CHECK-LABEL: func @compress2(
+//  CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+//  CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+//  CHECK-NEXT: return
+
+func @compress2(%base: memref<16xf32>, %value: vector<16xf32>) {
+  %mask = vector.constant_mask [0] : vector<16xi1>
+  vector.compressstore %base, %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
+  return
+}