diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_ #define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_ +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" @@ -30,12 +31,14 @@ class VectorTransferOpInterface; namespace vector { -class TransferWriteOp; -class TransferReadOp; - /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim); + +/// Combine `v1` and `v2` (scalars or vectors) with the arith op implementing +/// `kind`. Both element types must be int/index or float as `kind` requires. +Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, + Value v1, Value v2); } // namespace vector /// Return the number of elements of basis, `0` if empty. 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp @@ -243,47 +243,8 @@ for (int64_t i = 1; i < srcShape[0]; i++) { auto operand = rewriter.create(loc, multiReductionOp.source(), i); - switch (multiReductionOp.kind()) { - case vector::CombiningKind::ADD: - if (elementType.isIntOrIndex()) - result = rewriter.create(loc, operand, result); - else - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MUL: - if (elementType.isIntOrIndex()) - result = rewriter.create(loc, operand, result); - else - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MINUI: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MINSI: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MINF: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MAXUI: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MAXSI: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::MAXF: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::AND: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::OR: - result = rewriter.create(loc, operand, result); - break; - case vector::CombiningKind::XOR: - result = rewriter.create(loc, operand, result); - break; - } + result = makeArithReduction(rewriter, loc, multiReductionOp.kind(), + operand, result); } rewriter.replaceOp(multiReductionOp, result); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- 
a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -10,6 +10,8 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" + #include #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -18,8 +20,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" - -#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -514,40 +515,11 @@ if (!acc) return Optional(mul); - Value combinedResult; - switch (kind) { - case CombiningKind::ADD: - combinedResult = rewriter.create(loc, mul, acc); - break; - case CombiningKind::MUL: - combinedResult = rewriter.create(loc, mul, acc); - break; - case CombiningKind::MINUI: - combinedResult = rewriter.create(loc, mul, acc); - break; - case CombiningKind::MINSI: - combinedResult = rewriter.create(loc, mul, acc); - break; - case CombiningKind::MAXUI: - combinedResult = rewriter.create(loc, mul, acc); - break; - case CombiningKind::MAXSI: - combinedResult = rewriter.create(loc, mul, acc); - break; - case CombiningKind::AND: - combinedResult = rewriter.create(loc, mul, acc); - break; - case CombiningKind::OR: - combinedResult = rewriter.create(loc, mul, acc); - break; - case CombiningKind::XOR: - combinedResult = rewriter.create(loc, mul, acc); - break; - case CombiningKind::MINF: // Only valid for floating point types. - case CombiningKind::MAXF: // Only valid for floating point types. + if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF) + // MINF and MAXF are only valid for floating point types. 
return Optional(); - } - return Optional(combinedResult); + + return makeArithReduction(rewriter, loc, kind, mul, acc); } static Optional genMultF(Location loc, Value x, Value y, Value acc, @@ -565,28 +537,14 @@ if (!acc) return Optional(mul); - Value combinedResult; - switch (kind) { - case CombiningKind::MUL: - combinedResult = rewriter.create(loc, mul, acc); - break; - case CombiningKind::MINF: - combinedResult = rewriter.create(loc, mul, acc); - break; - case CombiningKind::MAXF: - combinedResult = rewriter.create(loc, mul, acc); - break; - case CombiningKind::ADD: // Already handled this special case above. - case CombiningKind::AND: // Only valid for integer types. - case CombiningKind::MINUI: // Only valid for integer types. - case CombiningKind::MINSI: // Only valid for integer types. - case CombiningKind::MAXUI: // Only valid for integer types. - case CombiningKind::MAXSI: // Only valid for integer types. - case CombiningKind::OR: // Only valid for integer types. - case CombiningKind::XOR: // Only valid for integer types. + if (kind == CombiningKind::ADD || kind == CombiningKind::AND || + kind == CombiningKind::MINUI || kind == CombiningKind::MINSI || + kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI || + kind == CombiningKind::OR || kind == CombiningKind::XOR) + // ADD was handled above; the rest are only valid for integer types. 
return Optional(); - } - return Optional(combinedResult); + + return makeArithReduction(rewriter, loc, kind, mul, acc); } }; diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/MathExtras.h" #include @@ -42,6 +43,56 @@ llvm_unreachable("Expected MemRefType or TensorType"); } +Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, + CombiningKind kind, Value v1, Value v2) { + Type t1 = getElementTypeOrSelf(v1.getType()); + Type t2 = getElementTypeOrSelf(v2.getType()); + switch (kind) { + case CombiningKind::ADD: + if (t1.isIntOrIndex() && t2.isIntOrIndex()) + return b.createOrFold(loc, v1, v2); + else if (t1.isa() && t2.isa()) + return b.createOrFold(loc, v1, v2); + llvm_unreachable("invalid value types for ADD reduction"); + case CombiningKind::AND: + assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MAXF: + assert(t1.isa() && t2.isa() && + "expected float values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MINF: + assert(t1.isa() && t2.isa() && + "expected float values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MAXSI: + assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MINSI: + assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MAXUI: + assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MINUI: + assert(t1.isIntOrIndex() && 
t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MUL: + if (t1.isIntOrIndex() && t2.isIntOrIndex()) + return b.createOrFold(loc, v1, v2); + else if (t1.isa() && t2.isa()) + return b.createOrFold(loc, v1, v2); + llvm_unreachable("invalid value types for MUL reduction"); + case CombiningKind::OR: + assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::XOR: + assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + }; + llvm_unreachable("unknown CombiningKind"); +} + /// Return the number of elements of basis, `0` if empty. int64_t mlir::computeMaxLinearIndex(ArrayRef basis) { if (basis.empty())