diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2009,7 +2009,7 @@
 // GlobalMemrefOp
 //===----------------------------------------------------------------------===//
 
-def GlobalMemrefOp : Std_Op<"global_memref", [NoSideEffect, Symbol]> {
+def GlobalMemrefOp : Std_Op<"global_memref", [Symbol]> {
   let summary = "declare or define a global memref variable";
   let description = [{
     The `global_memref` operation declares or defines a named global variable.
@@ -2092,6 +2092,12 @@
   let results = (outs AnyStaticShapeMemRef:$result);
   let assemblyFormat = "$name `:` type($result) attr-dict";
 
+  let extraClassDeclaration = [{
+    /// Returns the GlobalMemrefOp named by `name`, looked up from the nearest
+    /// enclosing symbol table, or null if no such GlobalMemrefOp is found.
+    GlobalMemrefOp getGlobalVariable();
+  }];
+
   // `GetGlobalMemrefOp` is fully verified by its traits.
   let verifier = ?;
 }
@@ -3617,6 +3623,7 @@
 
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2234,6 +2234,10 @@
 // GetGlobalMemrefOp
 //===----------------------------------------------------------------------===//
 
+GlobalMemrefOp GetGlobalMemrefOp::getGlobalVariable() {
+  return SymbolTable::lookupNearestSymbolFrom<GlobalMemrefOp>(*this, name());
+}
+
 LogicalResult
 GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   // Verify that the result type is same as the type of the referenced
@@ -4013,6 +4017,43 @@
   return {};
 }
 
+namespace {
+
+/// Rewrites `tensor_load(get_global_memref @g)` into a `constant` holding the
+/// initial value of `@g`. Applies only when `@g` resolves to a constant
+/// global_memref whose initial value is an elements attribute.
+struct TensorLoadConstantGlobalFolder : public OpRewritePattern<TensorLoadOp> {
+  using OpRewritePattern<TensorLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorLoadOp tensorLoadOp,
+                                PatternRewriter &rewriter) const override {
+    auto getGlobalMemref =
+        tensorLoadOp.memref().getDefiningOp<GetGlobalMemrefOp>();
+    if (!getGlobalMemref)
+      return failure();
+    GlobalMemrefOp global = getGlobalMemref.getGlobalVariable();
+    if (!global || !global.constant())
+      return failure();
+    // `constant` does not imply an initializer: an external declaration has
+    // no initial value (so it must not be dereferenced), and an
+    // `uninitialized` definition has no elements payload to materialize.
+    // Only fold when the initial value is an elements attribute.
+    Optional<Attribute> initialValue = global.initial_value();
+    if (!initialValue || !initialValue->isa<ElementsAttr>())
+      return failure();
+    rewriter.replaceOpWithNewOp<ConstantOp>(
+        tensorLoadOp, tensorLoadOp.result().getType(), *initialValue);
+    return success();
+  }
+};
+
+} // end anonymous namespace
+
+void TensorLoadOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<TensorLoadConstantGlobalFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TensorToMemrefOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -31,3 +31,35 @@
   %1 = tensor_to_memref %0 : memref<?xf32>
   return %1 : memref<?xf32>
 }
+
+// Test case: Folding of tensor_load(get_global_memref(constant global_memref)) -> initial value
+// CHECK: global_memref "public" constant @gv : memref<4xf32>
+global_memref "public" constant @gv : memref<4xf32> = dense<[0.0, 1.0, 2.0, 3.0]>
+
+// CHECK-LABEL: func @fold_tensor_load_of_constant_global_memref
+// CHECK: %[[CONST:.*]] = constant dense<{{.*}}> : tensor<4xf32>
+// CHECK-NOT: tensor_load
+// CHECK: return %[[CONST]]
+func @fold_tensor_load_of_constant_global_memref() -> tensor<4xf32> {
+  %0 = get_global_memref @gv : memref<4xf32>
+  %1 = tensor_load %0 : memref<4xf32>
+  return %1 : tensor<4xf32>
+}
+
+// Test case: No folding when a constant global_memref has no elements
+// initializer (external declaration or `uninitialized` definition).
+global_memref "public" constant @external_gv : memref<4xf32>
+global_memref "public" constant @uninitialized_gv : memref<4xf32> = uninitialized
+
+// CHECK-LABEL: func @no_fold_tensor_load_of_constant_global_memref_without_initializer
+// CHECK: %[[EXT:.*]] = get_global_memref @external_gv
+// CHECK: tensor_load %[[EXT]]
+// CHECK: %[[UNINIT:.*]] = get_global_memref @uninitialized_gv
+// CHECK: tensor_load %[[UNINIT]]
+func @no_fold_tensor_load_of_constant_global_memref_without_initializer() -> (tensor<4xf32>, tensor<4xf32>) {
+  %0 = get_global_memref @external_gv : memref<4xf32>
+  %1 = tensor_load %0 : memref<4xf32>
+  %2 = get_global_memref @uninitialized_gv : memref<4xf32>
+  %3 = tensor_load %2 : memref<4xf32>
+  return %1, %3 : tensor<4xf32>, tensor<4xf32>
+}
diff --git a/mlir/test/Transforms/test-symbol-dce.mlir b/mlir/test/Transforms/test-symbol-dce.mlir
--- a/mlir/test/Transforms/test-symbol-dce.mlir
+++ b/mlir/test/Transforms/test-symbol-dce.mlir
@@ -24,6 +24,12 @@
 
   // CHECK: func @public_function_explicit
   func @public_function_explicit() attributes { sym_visibility = "public" }
+
+  // CHECK: global_memref "public" @gv
+  global_memref "public" @gv : memref<3xf32> = uninitialized
+
+  // CHECK-NOT: global_memref "private" @unused_gv
+  global_memref "private" @unused_gv : memref<3xi16> = uninitialized
 }
 
 // -----