diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -155,44 +155,57 @@
 } // namespace
 
 namespace {
-class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
+class ConvertIfOpTypes
+    : public Structural1ToNConversionPattern<IfOp, ConvertIfOpTypes> {
 public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(IfOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // TODO: Generalize this to any type conversion, not just 1:1.
-    //
-    // We need to implement something more sophisticated here that tracks
-    // which types convert to which other types and does the appropriate
-    // materialization logic.
-    // For example, it's possible that one result type converts to 0 types and
-    // another to 2 types, so newResultTypes would at least be the right size
-    // to not crash in the llvm::zip call below, but then we would set the the
-    // wrong type on the SSA values! These edge cases are also why we cannot
-    // safely use the TypeConverter::convertTypes helper here.
-    SmallVector<Type> newResultTypes;
-    for (auto type : op.getResultTypes()) {
-      Type newType = typeConverter->convertType(type);
-      if (!newType)
-        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
-      newResultTypes.push_back(newType);
-    }
+  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
 
-    // See comments in the ForOp pattern for why we clone without regions and
-    // then inline.
-    IfOp newOp = cast<IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
+  Optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 TypeRange dstTypes) const {
+
+    IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
+                                       adaptor.getCondition(), true);
+    newOp->setAttrs(op->getAttrs());
+
+    // We do not need the empty blocks created by the rewriter.
+    rewriter.eraseBlock(newOp.elseBlock());
+    rewriter.eraseBlock(newOp.thenBlock());
+
+    // Inline the blocks from the original operation.
     rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
                                 newOp.getThenRegion().end());
     rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
                                 newOp.getElseRegion().end());
-
-    // Update the operands and types.
-    newOp->setOperands(adaptor.getOperands());
-    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
-      std::get<0>(t).setType(std::get<1>(t));
-    rewriter.replaceOp(op, newOp.getResults());
-    return success();
+    return newOp;
+  }
+};
+} // namespace
+
+namespace {
+class ConvertWhileOpTypes
+    : public Structural1ToNConversionPattern<WhileOp, ConvertWhileOpTypes> {
+public:
+  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
+
+  Optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter,
+                                    TypeRange dstTypes) const {
+    // Unpack the iteration arguments.
+    SmallVector<Value> flatArgs;
+    for (Value arg : adaptor.getOperands())
+      unpackUnrealizedConversionCast(arg, flatArgs);
+
+    auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
+
+    for (auto i : {0u, 1u}) {
+      if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
+        return llvm::None;
+      auto &dstRegion = newOp.getRegion(i);
+      rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
+    }
+    return newOp;
   }
 };
 } // namespace
@@ -217,34 +230,6 @@
 };
 } // namespace
 
-namespace {
-class ConvertWhileOpTypes : public OpConversionPattern<WhileOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(WhileOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto *converter = getTypeConverter();
-    assert(converter);
-    SmallVector<Type> newResultTypes;
-    if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes)))
-      return failure();
-
-    auto newOp = rewriter.create<WhileOp>(op.getLoc(), newResultTypes,
-                                          adaptor.getOperands());
-    for (auto i : {0u, 1u}) {
-      auto &dstRegion = newOp.getRegion(i);
-      rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
-      if (failed(rewriter.convertRegionTypes(&dstRegion, *converter)))
-        return rewriter.notifyMatchFailure(op, "could not convert body types");
-    }
-    rewriter.replaceOp(op, newOp.getResults());
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
 public:
@@ -252,8 +237,11 @@
   LogicalResult
   matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.updateRootInPlace(
-        op, [&]() { op->setOperands(adaptor.getOperands()); });
+    SmallVector<Value> unpackedYield;
+    for (Value operand : adaptor.getOperands())
+      unpackUnrealizedConversionCast(operand, unpackedYield);
+
+    rewriter.updateRootInPlace(op, [&]() { op->setOperands(unpackedYield); });
     return success();
   }
 };
diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
--- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
@@ -30,3 +30,68 @@
 
   return %1 : tensor<1024xf32, #SparseVector>
 }
+
+// CHECK-LABEL: func @if(
+// CHECK-SAME:  %[[DIM_SIZE:.*0]]: memref<1xindex>,
+// CHECK-SAME:  %[[DIM_CURSOR:.*1]]: memref<1xindex>,
+// CHECK-SAME:  %[[MEM_SIZE:.*2]]: memref<3xindex>,
+// CHECK-SAME:  %[[POINTER:.*3]]: memref<?xindex>,
+// CHECK-SAME:  %[[INDICES:.*4]]: memref<?xindex>,
+// CHECK-SAME:  %[[VALUE:.*5]]: memref<?xf32>,
+// CHECK-SAME:  %[[DIM_SIZE_1:.*6]]: memref<1xindex>,
+// CHECK-SAME:  %[[DIM_CURSOR_1:.*7]]: memref<1xindex>,
+// CHECK-SAME:  %[[MEM_SIZE_1:.*8]]: memref<3xindex>,
+// CHECK-SAME:  %[[POINTER_1:.*9]]: memref<?xindex>,
+// CHECK-SAME:  %[[INDICES_1:.*10]]: memref<?xindex>,
+// CHECK-SAME:  %[[VALUE_1:.*11]]: memref<?xf32>,
+// CHECK-SAME:  %[[TMP_arg12:.*12]]: i1) ->
+// CHECK-SAME:  (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+// CHECK:       %[[SV:.*]]:6 = scf.if %[[TMP_arg12]] -> (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+// CHECK:         scf.yield %[[DIM_SIZE]], %[[DIM_CURSOR]], %[[MEM_SIZE]], %[[POINTER]], %[[INDICES]], %[[VALUE]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK:       } else {
+// CHECK:         scf.yield %[[DIM_SIZE_1]], %[[DIM_CURSOR_1]], %[[MEM_SIZE_1]], %[[POINTER_1]], %[[INDICES_1]], %[[VALUE_1]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK:       }
+// CHECK:       return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+func.func @if(%t: tensor<1024xf32, #SparseVector>,
+              %f: tensor<1024xf32, #SparseVector>,
+              %c: i1) -> tensor<1024xf32, #SparseVector> {
+  %1 = scf.if %c -> tensor<1024xf32, #SparseVector> {
+    scf.yield %t : tensor<1024xf32, #SparseVector>
+  } else {
+    scf.yield %f : tensor<1024xf32, #SparseVector>
+  }
+
+  return %1 : tensor<1024xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @while(
+// CHECK-SAME:  %[[DIM_SIZE:.*0]]: memref<1xindex>,
+// CHECK-SAME:  %[[DIM_CURSOR:.*1]]: memref<1xindex>,
+// CHECK-SAME:  %[[MEM_SIZE:.*2]]: memref<3xindex>,
+// CHECK-SAME:  %[[POINTER:.*3]]: memref<?xindex>,
+// CHECK-SAME:  %[[INDICES:.*4]]: memref<?xindex>,
+// CHECK-SAME:  %[[VALUE:.*5]]: memref<?xf32>,
+// CHECK-SAME:  %[[TMP_arg6:.*6]]: i1) ->
+// CHECK-SAME:  (memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+// CHECK:       %[[SV:.*]]:6 = scf.while (
+// CHECK-SAME:    %[[TMP_arg7:.*]] = %[[DIM_SIZE]],
+// CHECK-SAME:    %[[TMP_arg8:.*]] = %[[DIM_CURSOR]],
+// CHECK-SAME:    %[[TMP_arg9:.*]] = %[[MEM_SIZE]],
+// CHECK-SAME:    %[[TMP_arg10:.*]] = %[[POINTER]],
+// CHECK-SAME:    %[[TMP_arg11:.*]] = %[[INDICES]],
+// CHECK-SAME:    %[[TMP_arg12:.*]] = %[[VALUE]])
+// CHECK:         scf.condition(%[[TMP_arg6]]) %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK:       } do {
+// CHECK:       ^bb0(%[[TMP_arg7]]: memref<1xindex>, %[[TMP_arg8]]: memref<1xindex>, %[[TMP_arg9]]: memref<3xindex>, %[[TMP_arg10]]: memref<?xindex>, %[[TMP_arg11]]: memref<?xindex>, %[[TMP_arg12]]: memref<?xf32>):
+// CHECK:         scf.yield %[[TMP_arg7]], %[[TMP_arg8]], %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK:       }
+// CHECK:       return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4, %[[SV]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> {
+  %0 = scf.while (%arg4 = %arg0) : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> {
+    scf.condition(%c) %arg4 : tensor<1024xf32, #SparseVector>
+  } do {
+  ^bb0(%arg7: tensor<1024xf32, #SparseVector>):
+    scf.yield %arg7 : tensor<1024xf32, #SparseVector>
+  }
+  return %0: tensor<1024xf32, #SparseVector>
+}
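
Note: the new ConvertIfOpTypes/ConvertWhileOpTypes patterns and the updated ConvertConditionOpTypes all call unpackUnrealizedConversionCast, a file-local helper defined earlier in StructuralTypeConversions.cpp and not shown in the hunks above. Below is a minimal sketch of what such a helper could look like, assuming it flattens the operands of a 1:N builtin.unrealized_conversion_cast produced by the type converter and forwards any other value unchanged; the in-tree definition may differ in detail.

// Sketch only: assumed shape of the helper referenced by this patch.
static void unpackUnrealizedConversionCast(Value v,
                                           SmallVectorImpl<Value> &unpacked) {
  if (auto cast =
          dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
    if (cast.getInputs().size() != 1) {
      // 1:N conversion: forward the original source values.
      unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
      return;
    }
  }
  // 1:1 conversion (or a plain value): forward it as is.
  unpacked.push_back(v);
}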