diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -219,25 +219,6 @@ return success(); } -namespace { -/// Type conversions. -class ShapeTypeConverter : public TypeConverter { -public: - using TypeConverter::convertType; - - ShapeTypeConverter(MLIRContext *ctx) { - // Add default pass-through conversion. - addConversion([&](Type type) { return type; }); - - addConversion([ctx](SizeType type) { return IndexType::get(ctx); }); - addConversion([ctx](ShapeType type) { - return RankedTensorType::get({ShapedType::kDynamicSize}, - IndexType::get(ctx)); - }); - } -}; -} // namespace - namespace { /// Conversion pass. class ConvertShapeToStandardPass @@ -248,23 +229,15 @@ } // namespace void ConvertShapeToStandardPass::runOnOperation() { - // Setup type conversion. - MLIRContext &ctx = getContext(); - ShapeTypeConverter typeConverter(&ctx); - // Setup target legality. + MLIRContext &ctx = getContext(); ConversionTarget target(ctx); - target.addLegalDialect(); - target.addLegalOp(); - target.addDynamicallyLegalOp([&](FuncOp op) { - return typeConverter.isSignatureLegal(op.getType()) && - typeConverter.isLegal(&op.getBody()); - }); + target.addLegalDialect(); + target.addLegalOp(); // Setup conversion patterns. OwningRewritePatternList patterns; populateShapeToStandardConversionPatterns(patterns, &ctx); - populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter); // Apply conversion. 
auto module = getOperation(); diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -1,40 +1,11 @@ // RUN: mlir-opt --split-input-file --convert-shape-to-std --verify-diagnostics %s | FileCheck %s -// Convert `size` to `index` type. -// CHECK-LABEL: @size_id -// CHECK-SAME: (%[[SIZE:.*]]: index) -func @size_id(%size : !shape.size) -> !shape.size { - // CHECK: return %[[SIZE]] : index - return %size : !shape.size -} - -// ----- - -// Convert `shape` to `tensor<?xindex>` type. -// CHECK-LABEL: @shape_id -// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -func @shape_id(%shape : !shape.shape) -> !shape.shape { - // CHECK: return %[[SHAPE]] : tensor<?xindex> - return %shape : !shape.shape -} - -// ----- - -// Lower binary ops. -// CHECK-LABEL: @binary_ops -// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index) -func @binary_ops(%lhs : !shape.size, %rhs : !shape.size) { - // CHECK: addi %[[LHS]], %[[RHS]] : index - %sum = "shape.add"(%lhs, %rhs) : (!shape.size, !shape.size) -> !shape.size - return -} - -// ----- - // Lower binary ops. // CHECK-LABEL: @binary_ops // CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index) func @binary_ops(%lhs : index, %rhs : index) { + // CHECK: addi %[[LHS]], %[[RHS]] : index + %sum = shape.add %lhs, %rhs : index, index -> index // CHECK: muli %[[LHS]], %[[RHS]] : index %product = shape.mul %lhs, %rhs : index, index -> index return