diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -110,7 +110,7 @@ return builder.create(loc, type, value.cast()); if (type.isa()) return builder.create(loc, type, value.cast()); - if (type.isa()) + if (ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, value); return nullptr; } diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -13,7 +13,6 @@ #include "mlir/Transforms/FoldUtils.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Operation.h" @@ -60,11 +59,6 @@ assert(matchPattern(constOp, m_Constant())); return constOp; } - - // If the dialect is unable to materialize a constant, check to see if the - // standard constant can be used. - if (ConstantOp::isBuildableWith(value, type)) - return builder.create(loc, type, value); return nullptr; } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -178,6 +178,11 @@ allowUnknownOperations(); } +Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + return builder.create(loc, type, value); +} + static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, llvm::SetVector &stack) { StringRef typeTag; diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -23,6 +23,7 @@ def Test_Dialect : Dialect { let name = "test"; let cppNamespace = "::mlir::test"; + let hasConstantMaterializer = 1; let hasOperationAttrVerify = 1; let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -254,7 +254,7 @@ // CHECK-LABEL: testConstOp func @testConstOp() -> (i32) { - // CHECK-NEXT: [[C0:%.+]] = constant 1 + // CHECK-NEXT: [[C0:%.+]] = "test.constant"() {value = 1 %0 = "test.constant"() {value = 1 : i32} : () -> i32 // CHECK-NEXT: return [[C0]] @@ -263,7 +263,7 @@ // CHECK-LABEL: testConstOpUsed func @testConstOpUsed() -> (i32) { - // CHECK-NEXT: [[C0:%.+]] = constant 1 + // CHECK-NEXT: [[C0:%.+]] = "test.constant"() {value = 1 %0 = "test.constant"() {value = 1 : i32} : () -> i32 // CHECK-NEXT: [[V0:%.+]] = "test.op_s"([[C0]]) @@ -275,7 +275,7 @@ // CHECK-LABEL: testConstOpReplaced func @testConstOpReplaced() -> (i32) { - // CHECK-NEXT: [[C0:%.+]] = constant 1 + // CHECK-NEXT: [[C0:%.+]] = "test.constant"() {value = 1 %0 = "test.constant"() {value = 1 : i32} : () -> i32 %1 = "test.constant"() {value = 2 : i32} : () -> i32 @@ -288,10 +288,10 @@ // CHECK-LABEL: testConstOpMatchFailure func @testConstOpMatchFailure() -> (i64) { - // CHECK-DAG: [[C0:%.+]] = constant 1 + // CHECK-DAG: [[C0:%.+]] = "test.constant"() {value = 1 %0 = "test.constant"() {value = 1 : i64} : () -> i64 - // CHECK-DAG: [[C1:%.+]] = constant 2 + // CHECK-DAG: [[C1:%.+]] = "test.constant"() {value = 2 %1 = "test.constant"() {value = 2 : i64} : () -> i64 // CHECK: [[V0:%.+]] = "test.op_r"([[C0]], [[C1]]) @@ -303,7 +303,7 @@ // CHECK-LABEL: testConstOpMatchNonConst func @testConstOpMatchNonConst(%arg0 : i32) -> (i32) { - // CHECK-DAG: [[C0:%.+]] = constant 1 + // CHECK-DAG: [[C0:%.+]] = "test.constant"() {value = 1 %0 = "test.constant"() {value = 1 : i32} : () -> i32 // CHECK: [[V0:%.+]] = "test.op_r"([[C0]], %arg0)