diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -38,20 +38,25 @@
 }

 // The "kind" of combining function for contractions and reductions.
-def COMBINING_KIND_ADD : BitEnumAttrCase<"ADD", 0x1, "add">;
-def COMBINING_KIND_MUL : BitEnumAttrCase<"MUL", 0x2, "mul">;
-def COMBINING_KIND_MIN : BitEnumAttrCase<"MIN", 0x4, "min">;
-def COMBINING_KIND_MAX : BitEnumAttrCase<"MAX", 0x8, "max">;
-def COMBINING_KIND_AND : BitEnumAttrCase<"AND", 0x10, "and">;
-def COMBINING_KIND_OR  : BitEnumAttrCase<"OR", 0x20, "or">;
-def COMBINING_KIND_XOR : BitEnumAttrCase<"XOR", 0x40, "xor">;
+def COMBINING_KIND_ADD   : BitEnumAttrCase<"ADD", 0x1, "add">;
+def COMBINING_KIND_MUL   : BitEnumAttrCase<"MUL", 0x2, "mul">;
+def COMBINING_KIND_MINUI : BitEnumAttrCase<"MINUI", 0x4, "minui">;
+def COMBINING_KIND_MINSI : BitEnumAttrCase<"MINSI", 0x8, "minsi">;
+def COMBINING_KIND_MINF  : BitEnumAttrCase<"MINF", 0x10, "minf">;
+def COMBINING_KIND_MAXUI : BitEnumAttrCase<"MAXUI", 0x20, "maxui">;
+def COMBINING_KIND_MAXSI : BitEnumAttrCase<"MAXSI", 0x40, "maxsi">;
+def COMBINING_KIND_MAXF  : BitEnumAttrCase<"MAXF", 0x80, "maxf">;
+def COMBINING_KIND_AND   : BitEnumAttrCase<"AND", 0x100, "and">;
+def COMBINING_KIND_OR    : BitEnumAttrCase<"OR", 0x200, "or">;
+def COMBINING_KIND_XOR   : BitEnumAttrCase<"XOR", 0x400, "xor">;

 def CombiningKind : BitEnumAttr<
     "CombiningKind",
     "Kind of combining function for contractions and reductions",
-    [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MIN,
-     COMBINING_KIND_MAX, COMBINING_KIND_AND, COMBINING_KIND_OR,
-     COMBINING_KIND_XOR]> {
+    [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,
+     COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI,
+     COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND,
+     COMBINING_KIND_OR, COMBINING_KIND_XOR]> {
   let cppNamespace = "::mlir::vector";
   let genSpecializedAttr = 0;
 }
@@ -337,7 +342,7 @@ static SmallVector<int64_t> inferDestShape(
       ArrayRef<int64_t> shape, ArrayRef<bool> reducedDimsMask) {
-    assert(shape.size() == reducedDimsMask.size() &&
+    assert(shape.size() == reducedDimsMask.size() &&
           "shape and mask of different sizes");
     SmallVector<int64_t> res;
     for (auto it : llvm::zip(reducedDimsMask, shape))
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -434,18 +434,16 @@
     else if (kind == "mul")
       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp, llvmType, operand);
-    else if (kind == "min" &&
-             (eltType.isIndex() || eltType.isUnsignedInteger()))
+    else if (kind == "minui")
       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
           reductionOp, llvmType, operand);
-    else if (kind == "min")
+    else if (kind == "minsi")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
          reductionOp, llvmType, operand);
-    else if (kind == "max" &&
-             (eltType.isIndex() || eltType.isUnsignedInteger()))
+    else if (kind == "maxui")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
          reductionOp, llvmType, operand);
-    else if (kind == "max")
+    else if (kind == "maxsi")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
          reductionOp, llvmType, operand);
     else if (kind == "and")
@@ -486,10 +484,14 @@
       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
           reductionOp, llvmType, acc, operand,
           rewriter.getBoolAttr(reassociateFPReductions));
-    } else if (kind == "min")
+    } else if (kind == "minf")
+      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
+      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp, llvmType, operand);
-    else if (kind == "max")
+    else if (kind == "maxf")
+      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
+      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp, llvmType, operand);
     else
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -111,24 +111,40 @@
   return VectorType::get(st.getShape(), st.getElementType());
 }

+static llvm::Optional<vector::CombiningKind>
+getKindForOp(Operation *reductionOp) {
+  if (!reductionOp)
+    return llvm::None;
+  return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
+             reductionOp)
+      .Case<AddIOp, AddFOp>([&](auto op) { return vector::CombiningKind::ADD; })
+      .Case<MaxSIOp>([&](auto op) { return vector::CombiningKind::MAXSI; })
+      .Case<MaxFOp>([&](auto op) { return vector::CombiningKind::MAXF; })
+      .Case<MinSIOp>([&](auto op) { return vector::CombiningKind::MINSI; })
+      .Case<MinFOp>([&](auto op) { return vector::CombiningKind::MINF; })
+      .Default([&](auto op) { return llvm::None; });
+}
+
 /// Check whether `outputOperand` is a reduction with a single combiner
-/// operation. Return the combiner operation of the reduction, which is assumed
-/// to be a binary operation. Multiple reduction operations would impose an
-/// ordering between reduction dimensions and is currently unsupported in
-/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
+/// operation. Return the combiner operation kind of the reduction, if
+/// supported. Return llvm::None otherwise. Multiple reduction operations would
+/// impose an ordering between reduction dimensions and are currently
+/// unsupported in Linalg. This limitation is motivated by the fact that
+/// e.g. min(max(X)) !=
 /// max(min(X))
 // TODO: use in LinalgOp verification, there is a circular dependency atm.
-static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) {
+static llvm::Optional<vector::CombiningKind>
+matchLinalgReduction(OpOperand *outputOperand) {
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
   unsigned outputPos =
       outputOperand->getOperandNumber() - linalgOp.getNumInputs();
+  // Only single combiner operations are supported for now.
   SmallVector<Operation *, 4> combinerOps;
   if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
       combinerOps.size() != 1)
-    return nullptr;
+    return llvm::None;

-  // TODO: also assert no other subsequent ops break the reduction.
-  return combinerOps[0];
+  // Return the combiner operation kind, if supported.
+  return getKindForOp(combinerOps[0]);
 }

 /// If `value` of assumed VectorType has a shape different than `shape`, try to
@@ -151,19 +167,6 @@
       newVecType, value);
 }

-static llvm::Optional<vector::CombiningKind>
-getKindForOp(Operation *reductionOp) {
-  if (!reductionOp)
-    return llvm::None;
-  return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
-             reductionOp)
-      .Case<AddIOp, AddFOp>([&](auto op) {
-        return llvm::Optional<vector::CombiningKind>{
-            vector::CombiningKind::ADD};
-      })
-      .Default([&](auto op) { return llvm::None; });
-}
-
 /// If value of assumed VectorType has a shape different than `shape`, build and
 /// return a new vector.broadcast to `shape`.
 /// Otherwise, just return value.
@@ -173,9 +176,7 @@
   auto vecType = value.getType().dyn_cast<VectorType>();
   if (!vecType || vecType.getShape() == targetVectorType.getShape())
     return value;
-  // At this point, we know we need to reduce. Detect the reduction operator.
-  // TODO: Use the generic reduction detection util.
-  Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand);
+
   unsigned pos = 0;
   MLIRContext *ctx = b.getContext();
   SmallVector<AffineExpr, 4> exprs;
@@ -183,8 +184,9 @@
     if (isParallelIterator(s))
       exprs.push_back(getAffineDimExpr(pos++, ctx));
   auto loc = value.getLoc();
-  // TODO: reuse common CombiningKing logic and support more than add.
-  auto maybeKind = getKindForOp(reductionOp);
+
+  // At this point, we know we need to reduce. Detect the reduction operator.
+  auto maybeKind = matchLinalgReduction(outputOperand);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
   unsigned idx = 0;
   SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
@@ -597,8 +599,7 @@
   if (llvm::none_of(op.iterator_types(), isReductionIterator))
     return failure();
   for (OpOperand *opOperand : op.getOutputOperands()) {
-    Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand);
-    if (!getKindForOp(reductionOp))
+    if (!matchLinalgReduction(opOperand))
       return failure();
   }
   return success();
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -92,13 +92,18 @@
   switch (combiningKind) {
   case CombiningKind::ADD:
   case CombiningKind::MUL:
-  case CombiningKind::MIN:
-  case CombiningKind::MAX:
     return elementType.isIntOrIndexOrFloat();
+  case CombiningKind::MINUI:
+  case CombiningKind::MINSI:
+  case CombiningKind::MAXUI:
+  case CombiningKind::MAXSI:
   case CombiningKind::AND:
   case CombiningKind::OR:
   case CombiningKind::XOR:
     return elementType.isIntOrIndex();
+  case CombiningKind::MINF:
+  case CombiningKind::MAXF:
+    return elementType.isa<FloatType>();
   }
   return false;
 }
@@ -151,8 +156,12 @@
       // clang-format off
       CombiningKind::ADD,
       CombiningKind::MUL,
-      CombiningKind::MIN,
-      CombiningKind::MAX,
+      CombiningKind::MINUI,
+      CombiningKind::MINSI,
+      CombiningKind::MINF,
+      CombiningKind::MAXUI,
+      CombiningKind::MAXSI,
+      CombiningKind::MAXF,
       CombiningKind::AND,
       CombiningKind::OR,
       CombiningKind::XOR,
@@ -291,22 +300,20 @@
     return op.emitOpError("unsupported reduction rank: ") << rank;

   // Verify supported reduction kind.
-  auto kind = op.kind();
+  StringRef strKind = op.kind();
+  auto maybeKind = symbolizeCombiningKind(strKind);
+  if (!maybeKind)
+    return op.emitOpError("unknown reduction kind: ") << strKind;
+
   Type eltType = op.dest().getType();
-  if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
-    if (!eltType.isIntOrIndexOrFloat())
-      return op.emitOpError("unsupported reduction type");
-  } else if (kind == "and" || kind == "or" || kind == "xor") {
-    if (!eltType.isIntOrIndex())
-      return op.emitOpError("unsupported reduction type");
-  } else {
-    return op.emitOpError("unknown reduction kind: ") << kind;
-  }
+  if (!isSupportedCombiningKind(*maybeKind, eltType))
+    return op.emitOpError("unsupported reduction type '")
+           << eltType << "' for kind '" << op.kind() << "'";

   // Verify optional accumulator.
   if (!op.acc().empty()) {
-    if (kind != "add" && kind != "mul")
-      return op.emitOpError("no accumulator for reduction kind: ") << kind;
+    if (strKind != "add" && strKind != "mul")
+      return op.emitOpError("no accumulator for reduction kind: ") << strKind;
     if (!eltType.isa<FloatType>())
       return op.emitOpError("no accumulator for type: ") << eltType;
   }
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -821,15 +821,17 @@
   case CombiningKind::MUL:
     combinedResult = rewriter.create<MulIOp>(loc, mul, acc);
     break;
-  case CombiningKind::MIN:
-    combinedResult = rewriter.create<SelectOp>(
-        loc, rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, mul, acc), mul,
-        acc);
+  case CombiningKind::MINUI:
+    combinedResult = rewriter.create<MinUIOp>(loc, mul, acc);
     break;
-  case CombiningKind::MAX:
-    combinedResult = rewriter.create<SelectOp>(
-        loc, rewriter.create<CmpIOp>(loc, CmpIPredicate::sge, mul, acc), mul,
-        acc);
+  case CombiningKind::MINSI:
+    combinedResult = rewriter.create<MinSIOp>(loc, mul, acc);
+    break;
+  case CombiningKind::MAXUI:
+    combinedResult = rewriter.create<MaxUIOp>(loc, mul, acc);
+    break;
+  case CombiningKind::MAXSI:
+    combinedResult = rewriter.create<MaxSIOp>(loc, mul, acc);
     break;
   case CombiningKind::AND:
     combinedResult = rewriter.create<AndOp>(loc, mul, acc);
@@ -840,6 +842,9 @@
   case CombiningKind::XOR:
     combinedResult = rewriter.create<XOrOp>(loc, mul, acc);
     break;
+  case CombiningKind::MINF: // Only valid for floating point types.
+  case CombiningKind::MAXF: // Only valid for floating point types.
+    return Optional<Value>();
   }
   return Optional<Value>(combinedResult);
 }
@@ -864,18 +869,18 @@
   case CombiningKind::MUL:
     combinedResult = rewriter.create<MulFOp>(loc, mul, acc);
     break;
-  case CombiningKind::MIN:
-    combinedResult = rewriter.create<SelectOp>(
-        loc, rewriter.create<CmpFOp>(loc, CmpFPredicate::OLE, mul, acc), mul,
-        acc);
+  case CombiningKind::MINF:
+    combinedResult = rewriter.create<MinFOp>(loc, mul, acc);
     break;
-  case CombiningKind::MAX:
-    combinedResult = rewriter.create<SelectOp>(
-        loc, rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, mul, acc), mul,
-        acc);
+  case CombiningKind::MAXF:
+    combinedResult = rewriter.create<MaxFOp>(loc, mul, acc);
     break;
   case CombiningKind::ADD:   // Already handled this special case above.
   case CombiningKind::AND:   // Only valid for integer types.
+  case CombiningKind::MINUI: // Only valid for integer types.
+  case CombiningKind::MINSI: // Only valid for integer types.
+  case CombiningKind::MAXUI: // Only valid for integer types.
+  case CombiningKind::MAXSI: // Only valid for integer types.
   case CombiningKind::OR:    // Only valid for integer types.
   case CombiningKind::XOR:   // Only valid for integer types.
    return Optional<Value>();
@@ -3697,23 +3702,23 @@
     else
       result = rewriter.create<MulFOp>(loc, operand, result);
     break;
-  case vector::CombiningKind::MIN:
-    if (elementType.isIntOrIndex())
-      condition =
-          rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, operand, result);
-    else
-      condition =
-          rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, result);
-    result = rewriter.create<SelectOp>(loc, condition, operand, result);
+  case vector::CombiningKind::MINUI:
+    result = rewriter.create<MinUIOp>(loc, operand, result);
     break;
-  case vector::CombiningKind::MAX:
-    if (elementType.isIntOrIndex())
-      condition =
-          rewriter.create<CmpIOp>(loc, CmpIPredicate::sge, operand, result);
-    else
-      condition =
-          rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, operand, result);
-    result = rewriter.create<SelectOp>(loc, condition, operand, result);
+  case vector::CombiningKind::MINSI:
+    result = rewriter.create<MinSIOp>(loc, operand, result);
+    break;
+  case vector::CombiningKind::MINF:
+    result = rewriter.create<MinFOp>(loc, operand, result);
+    break;
+  case vector::CombiningKind::MAXUI:
+    result = rewriter.create<MaxUIOp>(loc, operand, result);
+    break;
+  case vector::CombiningKind::MAXSI:
+    result = rewriter.create<MaxSIOp>(loc, operand, result);
+    break;
+  case vector::CombiningKind::MAXF:
+    result = rewriter.create<MaxFOp>(loc, operand, result);
     break;
   case vector::CombiningKind::AND:
     result = rewriter.create<AndOp>(loc, operand, result);
@@ -3771,10 +3776,18 @@
     return "add";
   case vector::CombiningKind::MUL:
     return "mul";
-  case vector::CombiningKind::MIN:
-    return "min";
-  case vector::CombiningKind::MAX:
-    return "max";
+  case vector::CombiningKind::MINUI:
+    return "minui";
+  case vector::CombiningKind::MINSI:
+    return "minsi";
+  case vector::CombiningKind::MINF:
+    return "minf";
+  case vector::CombiningKind::MAXUI:
+    return "maxui";
+  case vector::CombiningKind::MAXSI:
+    return "maxsi";
+  case vector::CombiningKind::MAXF:
+    return "maxf";
   case vector::CombiningKind::AND:
     return "and";
   case vector::CombiningKind::OR:
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -806,3 +806,54 @@
   } -> tensor<5x2xf32>
   return %0 : tensor<5x2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @red_max_2d(
+func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+  // CHECK: linalg.init_tensor [4] : tensor<4xf32>
+  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+  // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
+  // CHECK: vector.transfer_read {{.*}} : tensor<4xf32>, vector<4x4xf32>
+  // CHECK: maxf {{.*}} : vector<4x4xf32>
+  // CHECK: vector.multi_reduction #vector.kind<maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+  %minf32 = constant -3.40282e+38 : f32
+  %init = linalg.init_tensor [4] : tensor<4xf32>
+  %fill = linalg.fill(%minf32, %init) : f32, tensor<4xf32> -> tensor<4xf32>
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+                         iterator_types = ["parallel", "reduction"]}
+                         ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
+  ^bb0(%in0: f32, %out0: f32):  // no predecessors
+    %max = maxf %in0, %out0 : f32
+    linalg.yield %max : f32
+  } -> tensor<4xf32>
+  return %red : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @red_min_2d(
+func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+  // CHECK: linalg.init_tensor [4] : tensor<4xf32>
+  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>,
tensor<4xf32> + // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> + // CHECK: vector.transfer_read {{.*}} : tensor<4xf32>, vector<4x4xf32> + // CHECK: minf {{.*}} : vector<4x4xf32> + // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> + %maxf32 = constant 3.40282e+38 : f32 + %init = linalg.init_tensor [4] : tensor<4xf32> + %fill = linalg.fill(%maxf32, %init) : f32, tensor<4xf32> -> tensor<4xf32> + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) { + ^bb0(%in0: f32, %out0: f32): // no predecessors + %min = minf %in0, %out0 : f32 + linalg.yield %min : f32 + } -> tensor<4xf32> + return %red : tensor<4xf32> +} + diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1019,7 +1019,7 @@ func @reduce_unsupported_accumulator_kind(%arg0: vector<16xf32>, %arg1: f32) -> f32 { // expected-error@+1 {{'vector.reduction' op no accumulator for reduction kind: min}} - %0 = vector.reduction "min", %arg0, %arg1 : vector<16xf32> into f32 + %0 = vector.reduction "minf", %arg0, %arg1 : vector<16xf32> into f32 } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -243,13 +243,13 @@ #contraction_to_scalar_max_trait = { indexing_maps = #contraction_to_scalar_max_accesses, iterator_types = ["reduction"], - kind = #vector.kind + kind = #vector.kind } // CHECK-LABEL: @contraction_to_scalar_with_max func @contraction_to_scalar_with_max(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { // CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32 %f0 = constant 0.0: f32 - // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 %0 = vector.contract #contraction_to_scalar_max_trait %arg0, %arg1, %f0 : vector<10xf32>, vector<10xf32> into f32 // CHECK: return %[[X]] : f32 @@ -281,7 +281,7 @@ #contraction_trait2 = { indexing_maps = #contraction_accesses1, iterator_types = #iterator_types1, - kind = #vector.kind + kind = #vector.kind } // CHECK-LABEL: @contraction func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>, @@ -309,7 +309,7 @@ %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3 : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> // Test contraction with "max" instead of "add". 
- // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> %4 = vector.contract #contraction_trait2 %arg0, %arg1, %arg3 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> return @@ -432,10 +432,10 @@ vector.reduction "mul", %arg0 : vector<16xf32> into f32 // CHECK: vector.reduction "mul", %{{.*}}, %{{.*}} : vector<16xf32> into f32 vector.reduction "mul", %arg0, %arg1 : vector<16xf32> into f32 - // CHECK: vector.reduction "min", %{{.*}} : vector<16xf32> into f32 - vector.reduction "min", %arg0 : vector<16xf32> into f32 - // CHECK: %[[X:.*]] = vector.reduction "max", %{{.*}} : vector<16xf32> into f32 - %0 = vector.reduction "max", %arg0 : vector<16xf32> into f32 + // CHECK: vector.reduction "minf", %{{.*}} : vector<16xf32> into f32 + vector.reduction "minf", %arg0 : vector<16xf32> into f32 + // CHECK: %[[X:.*]] = vector.reduction "maxf", %{{.*}} : vector<16xf32> into f32 + %0 = vector.reduction "maxf", %arg0 : vector<16xf32> into f32 // CHECK: return %[[X]] : f32 return %0 : f32 } @@ -446,10 +446,14 @@ vector.reduction "add", %arg0 : vector<16xi32> into i32 // CHECK: vector.reduction "mul", %{{.*}} : vector<16xi32> into i32 vector.reduction "mul", %arg0 : vector<16xi32> into i32 - // CHECK: vector.reduction "min", %{{.*}} : vector<16xi32> into i32 - vector.reduction "min", %arg0 : vector<16xi32> into i32 - // CHECK: vector.reduction "max", %{{.*}} : vector<16xi32> into i32 - vector.reduction "max", %arg0 : vector<16xi32> into i32 + // CHECK: vector.reduction "minui", %{{.*}} : vector<16xi32> into i32 + vector.reduction "minui", %arg0 : vector<16xi32> into i32 + // CHECK: vector.reduction "minsi", %{{.*}} : vector<16xi32> into i32 + vector.reduction "minsi", %arg0 : vector<16xi32> into i32 + // CHECK: vector.reduction "maxui", %{{.*}} : vector<16xi32> into i32 + vector.reduction "maxui", %arg0 : vector<16xi32> into i32 + // CHECK: vector.reduction "maxsi", %{{.*}} : vector<16xi32> into i32 + vector.reduction "maxsi", %arg0 : vector<16xi32> into i32 // CHECK: vector.reduction "and", %{{.*}} : vector<16xi32> into i32 vector.reduction "and", %arg0 : vector<16xi32> into i32 // CHECK: vector.reduction "or", %{{.*}} : vector<16xi32> into i32 diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -12,7 +12,7 @@ #matvecmax_trait = { indexing_maps = #matvec_accesses, iterator_types = ["parallel", "reduction"], - kind = #vector.kind + kind = #vector.kind } #mattransvec_accesses = [ @@ -91,10 +91,10 @@ // CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32> // CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32> -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind} : vector<2xf32>, f32 +// CHECK: %[[T6:.*]] = 
vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind} : vector<2xf32>, f32 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32> // CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32> -// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind} : vector<2xf32>, f32 +// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind} : vector<2xf32>, f32 // CHECK: memref.store %[[T9]], %[[C]][] : memref> // CHECK: return func @matvecmax2x2(%arg0: memref>, %arg1: memref>, diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -18,7 +18,7 @@ // CHECK: return %[[RESULT_VEC]] : vector<2xf32> func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } @@ -27,18 +27,15 @@ // CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32> // CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32> // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32> -// CHECK: %[[C0:.+]] = cmpf olt, %[[V1]], %[[V0]] : vector<2xf32> -// CHECK: %[[RV01:.+]] = select %[[C0]], %[[V1]], %[[V0]] : vector<2xi1>, vector<2xf32> +// CHECK: %[[RV01:.+]] = minf %[[V1]], %[[V0]] : vector<2xf32> // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32> -// CHECK: %[[C1:.+]] = cmpf olt, %[[V2]], %[[RV01]] : vector<2xf32> -// CHECK: %[[RV012:.+]] = select %[[C1]], %[[V2]], %[[RV01]] : vector<2xi1>, vector<2xf32> +// CHECK: %[[RV012:.+]] = minf %[[V2]], %[[RV01]] : vector<2xf32> // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32> -// CHECK: %[[C2:.+]] = cmpf olt, %[[V3]], %[[RV012]] : vector<2xf32> -// CHECK: %[[RESULT_VEC:.+]] = select %[[C2]], %[[V3]], %[[RV012]] : vector<2xi1>, vector<2xf32> +// CHECK: %[[RESULT_VEC:.+]] = minf %[[V3]], %[[RV012]] : vector<2xf32> // CHECK: return %[[RESULT_VEC]] : vector<2xf32> func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } @@ -47,14 +44,11 @@ // CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32> // CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32> // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32> -// CHECK: %[[C0:.+]] = cmpf oge, %[[V1]], %[[V0]] : vector<2xf32> -// CHECK: %[[RV01:.+]] = select %[[C0]], %[[V1]], %[[V0]] : vector<2xi1>, vector<2xf32> +// CHECK: %[[RV01:.+]] = maxf %[[V1]], %[[V0]] : vector<2xf32> // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32> -// CHECK: %[[C1:.+]] = cmpf oge, %[[V2]], %[[RV01]] : vector<2xf32> -// CHECK: %[[RV012:.+]] = select %[[C1]], %[[V2]], %[[RV01]] : vector<2xi1>, vector<2xf32> +// CHECK: %[[RV012:.+]] = maxf %[[V2]], %[[RV01]] : vector<2xf32> // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32> -// CHECK: %[[C2:.+]] = 
cmpf oge, %[[V3]], %[[RV012]] : vector<2xf32> -// CHECK: %[[RESULT_VEC:.+]] = select %[[C2]], %[[V3]], %[[RV012]] : vector<2xi1>, vector<2xf32> +// CHECK: %[[RESULT_VEC:.+]] = maxf %[[V3]], %[[RV012]] : vector<2xf32> // CHECK: return %[[RESULT_VEC]] : vector<2xf32> func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> { diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32-reassoc.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32-reassoc.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32-reassoc.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32-reassoc.mlir @@ -27,10 +27,10 @@ %1 = vector.reduction "mul", %v2 : vector<64xf32> into f32 vector.print %1 : f32 // CHECK: 6 - %2 = vector.reduction "min", %v2 : vector<64xf32> into f32 + %2 = vector.reduction "minf", %v2 : vector<64xf32> into f32 vector.print %2 : f32 // CHECK: 1 - %3 = vector.reduction "max", %v2 : vector<64xf32> into f32 + %3 = vector.reduction "maxf", %v2 : vector<64xf32> into f32 vector.print %3 : f32 // CHECK: 3 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32.mlir @@ -39,10 +39,10 @@ %1 = vector.reduction "mul", %v9 : vector<10xf32> into f32 vector.print %1 : f32 // CHECK: -5760 - %2 = vector.reduction "min", %v9 : vector<10xf32> into f32 + %2 = vector.reduction "minf", %v9 : vector<10xf32> into f32 vector.print %2 : f32 // CHECK: -16 - %3 = vector.reduction "max", %v9 : vector<10xf32> into f32 + %3 = vector.reduction "maxf", %v9 : vector<10xf32> into f32 vector.print %3 : f32 // CHECK: 5 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64-reassoc.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64-reassoc.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64-reassoc.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64-reassoc.mlir @@ -27,10 +27,10 @@ %1 = vector.reduction "mul", %v2 : vector<64xf64> into f64 vector.print %1 : f64 // CHECK: 6 - %2 = vector.reduction "min", %v2 : vector<64xf64> into f64 + %2 = vector.reduction "minf", %v2 : vector<64xf64> into f64 vector.print %2 : f64 // CHECK: 1 - %3 = vector.reduction "max", %v2 : vector<64xf64> into f64 + %3 = vector.reduction "maxf", %v2 : vector<64xf64> into f64 vector.print %3 : f64 // CHECK: 3 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64.mlir @@ -39,10 +39,10 @@ %1 = vector.reduction "mul", %v9 : vector<10xf64> into f64 vector.print %1 : f64 // CHECK: -5760 - %2 = vector.reduction "min", %v9 : vector<10xf64> into f64 + %2 = vector.reduction "minf", %v9 : vector<10xf64> into f64 vector.print %2 : f64 // CHECK: -16 - %3 = vector.reduction "max", %v9 : vector<10xf64> into f64 + %3 = vector.reduction "maxf", %v9 : vector<10xf64> into f64 vector.print %3 : f64 // CHECK: 5 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i32.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i32.mlir +++ 
b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i32.mlir @@ -39,10 +39,10 @@ %1 = vector.reduction "mul", %v9 : vector<10xi32> into i32 vector.print %1 : i32 // CHECK: -1228800 - %2 = vector.reduction "min", %v9 : vector<10xi32> into i32 + %2 = vector.reduction "minsi", %v9 : vector<10xi32> into i32 vector.print %2 : i32 // CHECK: -80 - %3 = vector.reduction "max", %v9 : vector<10xi32> into i32 + %3 = vector.reduction "maxsi", %v9 : vector<10xi32> into i32 vector.print %3 : i32 // CHECK: 5 %4 = vector.reduction "and", %v9 : vector<10xi32> into i32 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i4.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i4.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i4.mlir @@ -20,11 +20,11 @@ vector.print %1 : i4 // CHECK: 0 - %2 = vector.reduction "min", %v : vector<24xi4> into i4 + %2 = vector.reduction "minsi", %v : vector<24xi4> into i4 vector.print %2 : i4 // CHECK: -8 - %3 = vector.reduction "max", %v : vector<24xi4> into i4 + %3 = vector.reduction "maxsi", %v : vector<24xi4> into i4 vector.print %3 : i4 // CHECK: 7 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i64.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i64.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i64.mlir @@ -39,10 +39,10 @@ %1 = vector.reduction "mul", %v9 : vector<10xi64> into i64 vector.print %1 : i64 // CHECK: -1228800 - %2 = vector.reduction "min", %v9 : vector<10xi64> into i64 + %2 = vector.reduction "minsi", %v9 : vector<10xi64> into i64 vector.print %2 : i64 // CHECK: -80 - %3 = vector.reduction "max", %v9 : vector<10xi64> into i64 + %3 = vector.reduction "maxsi", %v9 : vector<10xi64> into i64 vector.print %3 : i64 // CHECK: 5 %4 = vector.reduction "and", %v9 : vector<10xi64> into i64 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-si4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-si4.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-si4.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-si4.mlir @@ -19,11 +19,11 @@ vector.print %1 : si4 // CHECK: 0 - %2 = vector.reduction "min", %v : vector<16xsi4> into si4 + %2 = vector.reduction "minsi", %v : vector<16xsi4> into si4 vector.print %2 : si4 // CHECK: -8 - %3 = vector.reduction "max", %v : vector<16xsi4> into si4 + %3 = vector.reduction "maxsi", %v : vector<16xsi4> into si4 vector.print %3 : si4 // CHECK: 7 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-ui4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-ui4.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-ui4.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-ui4.mlir @@ -19,11 +19,11 @@ vector.print %1 : ui4 // CHECK: 0 - %2 = vector.reduction "min", %v : vector<16xui4> into ui4 + %2 = vector.reduction "minui", %v : vector<16xui4> into ui4 vector.print %2 : ui4 // CHECK: 0 - %3 = vector.reduction "max", %v : vector<16xui4> into ui4 + %3 = vector.reduction "maxui", %v : vector<16xui4> into ui4 vector.print %3 : ui4 // CHECK: 15
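

For reference, a minimal standalone sketch (not part of the patch) of how the split reduction kinds are spelled after this change. The syntax mirrors the ops.mlir and integration tests above; the function name and vector shapes are invented for illustration:

func @reduction_kinds_example(%i: vector<16xi32>, %f: vector<16xf32>) -> f32 {
  // Integer reductions now state the signedness explicitly.
  %0 = vector.reduction "minui", %i : vector<16xi32> into i32
  %1 = vector.reduction "minsi", %i : vector<16xi32> into i32
  %2 = vector.reduction "maxui", %i : vector<16xi32> into i32
  %3 = vector.reduction "maxsi", %i : vector<16xi32> into i32
  // Floating-point min/max use the dedicated "minf"/"maxf" kinds.
  %4 = vector.reduction "minf", %f : vector<16xf32> into f32
  %5 = vector.reduction "maxf", %f : vector<16xf32> into f32
  return %5 : f32
}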
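A second sketch, again illustrative only and with hypothetical names and shapes, showing the renamed kind attribute on vector.contract and vector.multi_reduction (the old #vector.kind<max> spelling becomes #vector.kind<maxf> for floating-point element types):

#example_accesses = [
  affine_map<(d0, d1) -> (d0, d1)>,
  affine_map<(d0, d1) -> (d0, d1)>,
  affine_map<(d0, d1) -> (d0)>
]
#example_trait = {
  indexing_maps = #example_accesses,
  iterator_types = ["parallel", "reduction"],
  kind = #vector.kind<maxf>
}
func @kind_attr_example(%a: vector<4x8xf32>, %b: vector<4x8xf32>,
                        %acc: vector<4xf32>, %m: vector<2x4xf32>)
    -> (vector<4xf32>, vector<2xf32>) {
  // Elementwise multiply combined with a floating-point max reduction over d1.
  %0 = vector.contract #example_trait %a, %b, %acc
    : vector<4x8xf32>, vector<4x8xf32> into vector<4xf32>
  // The same named kinds are used by vector.multi_reduction.
  %1 = vector.multi_reduction #vector.kind<minf>, %m [1] : vector<2x4xf32> to vector<2xf32>
  return %0, %1 : vector<4xf32>, vector<2xf32>
}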