diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -3065,13 +3065,35 @@
   matchAndRewrite(CmpIOp cmpiOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     CmpIOpAdaptor transformed(operands);
+    auto operandType = transformed.lhs().getType();
+    auto resultType = cmpiOp.getResult().getType();
 
-    rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
-        cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()),
-        convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
-        transformed.lhs(), transformed.rhs());
+    // Handle the scalar and 1D vector cases.
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
+          cmpiOp, typeConverter->convertType(resultType),
+          convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
+          transformed.lhs(), transformed.rhs());
+      return success();
+    }
 
-    return success();
+    // N-D vectors reach this pattern as LLVM arrays of 1D vectors (e.g.
+    // vector<4x3xi32> -> !llvm.array<4 x vector<3xi32>>); unroll the
+    // comparison into 1D `llvm.icmp` ops on the array elements.
+    auto vectorType = resultType.dyn_cast<VectorType>();
+    if (!vectorType)
+      return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type");
+
+    return handleMultidimensionalVectors(
+        cmpiOp.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands1D) {
+          CmpIOpAdaptor adaptor1D(operands1D);
+          return rewriter.create<LLVM::ICmpOp>(
+              cmpiOp.getLoc(), llvm1DVectorTy,
+              convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
+              adaptor1D.lhs(), adaptor1D.rhs());
+        },
+        rewriter);
   }
 };
 
@@ -3082,13 +3104,35 @@
   matchAndRewrite(CmpFOp cmpfOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     CmpFOpAdaptor transformed(operands);
+    auto operandType = transformed.lhs().getType();
+    auto resultType = cmpfOp.getResult().getType();
 
-    rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
-        cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
-        convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
-        transformed.lhs(), transformed.rhs());
+    // Handle the scalar and 1D vector cases.
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
+          cmpfOp, typeConverter->convertType(resultType),
+          convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
+          transformed.lhs(), transformed.rhs());
+      return success();
+    }
 
-    return success();
+    // N-D vectors reach this pattern as LLVM arrays of 1D vectors (e.g.
+    // vector<4x3xf32> -> !llvm.array<4 x vector<3xf32>>); unroll the
+    // comparison into 1D `llvm.fcmp` ops on the array elements.
+    auto vectorType = resultType.dyn_cast<VectorType>();
+    if (!vectorType)
+      return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type");
+
+    return handleMultidimensionalVectors(
+        cmpfOp.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands1D) {
+          CmpFOpAdaptor adaptor1D(operands1D);
+          return rewriter.create<LLVM::FCmpOp>(
+              cmpfOp.getLoc(), llvm1DVectorTy,
+              convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
+              adaptor1D.lhs(), adaptor1D.rhs());
+        },
+        rewriter);
   }
 };
 
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -274,3 +274,26 @@
   std.return
 }
 
+// -----
+
+// CHECK-LABEL: func @cmpf_2dvector(
+func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
+  // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xf32>>
+  // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %arg1[0] : !llvm.array<4 x vector<3xf32>>
+  // CHECK: %[[CMP:.*]] = llvm.fcmp "olt" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xf32>
+  // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[CMP]], %0[0] : !llvm.array<4 x vector<3xi1>>
+  %0 = cmpf olt, %arg0, %arg1 : vector<4x3xf32>
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cmpi_2dvector(
+func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
+  // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xi32>>
+  // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %arg1[0] : !llvm.array<4 x vector<3xi32>>
+  // CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xi32>
+  // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[CMP]], %0[0] : !llvm.array<4 x vector<3xi1>>
+  %0 = cmpi ult, %arg0, %arg1 : vector<4x3xi32>
+  std.return
+}