diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -551,10 +551,22 @@ // finding the minimal number of overlapping live ranges. This is essentially // a simplified form of register allocation where we don't necessarily have a // limited number of registers, but we still want to minimize the number used. - DenseMap<Operation *, unsigned> opToIndex; - matcherFunc.getBody().walk([&](Operation *op) { - opToIndex.insert(std::make_pair(op, opToIndex.size())); - }); + DenseMap<Operation *, unsigned> opToFirstIndex; + DenseMap<Operation *, unsigned> opToLastIndex; + + // A custom walk that assigns each operation a first and a last index. The + // first index opens the operation's liveness range, nested operations are + // numbered next, and the last index closes the range after all of them. + unsigned index = 0; + llvm::unique_function<void(Operation *)> walk = [&](Operation *op) { + opToFirstIndex.try_emplace(op, index++); + for (Region &region : op->getRegions()) + for (Block &block : region.getBlocks()) + for (Operation &nested : block) + walk(&nested); + opToLastIndex.try_emplace(op, index++); + }; + walk(matcherFunc); // Liveness info for each of the defs within the matcher. ByteCodeLiveRange::Allocator allocator; @@ -578,8 +590,8 @@ // Set indices for the range of this block that the value is used. auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; defRangeIt->second.liveness->insert( - opToIndex[firstUseOrDef], - opToIndex[info->getEndOperation(value, firstUseOrDef)], + opToFirstIndex[firstUseOrDef], + opToLastIndex[info->getEndOperation(value, firstUseOrDef)], /*dummyValue*/ 0); // Check to see if this value is a range type. 
diff --git a/mlir/test/Integration/Dialect/PDL/CPU/multiroot.mlir b/mlir/test/Integration/Dialect/PDL/CPU/multiroot.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/PDL/CPU/multiroot.mlir @@ -0,0 +1,294 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -test-pdl-bytecode-pass -split-input-file | FileCheck %s + +// ----- + +//===----------------------------------------------------------------------===// +// 1-layer perceptron with split fwd/bwd operations +//===----------------------------------------------------------------------===// + +module @patterns { + // fc_fwd + pdl.pattern : benefit(1) { + %in_type = pdl.type + %out_type = pdl.type + %weight_type = pdl.type + %rxact = pdl.operand : %in_type + %weight = pdl.operand : %weight_type + + %attr0 = pdl.attribute false + %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type) + + pdl.rewrite %op0 { + %op1 = pdl.operation "kernel.FcFwd" (%rxact, %weight : !pdl.value, !pdl.value) -> (%out_type : !pdl.type) + %val1 = pdl.result 0 of %op1 // txact + pdl.replace %op0 with (%val1 : !pdl.value) // tf.MatMul + } + } + + // fc_bwd + pdl.pattern : benefit(4) { + %in_type = pdl.type + %out_type = pdl.type + %weight_type = pdl.type + %const_type = pdl.type + %rxact = pdl.operand : %in_type + %rxdelta = pdl.operand : %out_type + %weight = pdl.operand : %weight_type + + %attr0 = pdl.attribute true + %attr1 = pdl.attribute false + %op0 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%weight_type : !pdl.type) + %val0 = pdl.result 0 of %op0 + %op1 = pdl.operation "tf.Const" -> (%const_type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %op2 = pdl.operation "tf.Mul" (%val0, %val1 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %op3 = pdl.operation "tf.Sub" (%weight, %val2 : !pdl.value, 
!pdl.value) -> (%weight_type : !pdl.type) + + pdl.rewrite %op3 { + %op4 = pdl.operation "kernel.FcBwd" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%weight_type : !pdl.type) + %val4 = pdl.result 0 of %op4 // weight_out + pdl.replace %op3 with (%val4 : !pdl.value) // tf.Sub + pdl.erase %op2 // tf.Mul + pdl.erase %op1 // tf.Const + pdl.erase %op0 // tf.MatMul + } + } + + // softmax_cross_entropy + pdl.pattern : benefit(6) { + %in_type = pdl.type + %label_type = pdl.type + %loss_type = pdl.type + %mean_loss_type = pdl.type + %mean_const_type = pdl.type + %mul_const_type = pdl.type + %rxact = pdl.operand : %in_type + %rxlabel = pdl.operand : %label_type + + %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type) + %val0_0 = pdl.result 0 of %op0 // loss + %val0_1 = pdl.result 1 of %op0 // gradient + %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type) + %val3 = pdl.result 0 of %op3 + %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type) + %val4 = pdl.result 0 of %op4 + %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type) + + pdl.rewrite { // roots: %op2, %op5 + %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type) + %val6_0 = pdl.result 0 of %op6 // txloss + %val6_1 = pdl.result 1 of %op6 // txdelta + pdl.replace %op5 with (%val6_1 : !pdl.value) // tf.Mul + pdl.erase %op4 // tf.Const + pdl.erase %op3 // tf.PreventGradient + pdl.replace %op2 with (%val6_0 : !pdl.value) // tf.Mean + pdl.erase %op1 // tf.Const + pdl.erase %op0 // 
tf.SparseSoftmaxCrossEntropyWithLogits + } + } +} + +// CHECK-LABEL: test.mlp_split +// CHECK: %[[FWD:.*]] = "kernel.FcFwd"(%arg0, %arg2) : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32> +// CHECK: %[[SM:.*]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FWD]], %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>) +// CHECK: %[[BWD:.*]] = "kernel.FcBwd"(%arg0, %[[SM]]#1, %arg2) : (tensor<2x20xf32>, tensor<2x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32> +// CHECK: return %[[SM]]#0, %[[BWD]] : tensor<f32>, tensor<20x10xf32> +module @ir attributes { test.mlp_split } { + func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<20x10xf32>) -> (tensor<f32>, tensor<20x10xf32>) { + %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32> + %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32> + %3 = "tf.MatMul"(%arg0, %arg2) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32> + %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%3, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>) + %4 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32> + %5 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32> + %6 = "tf.Mul"(%5, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32> + %7 = "tf.MatMul"(%arg0, %6) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x10xf32>) -> tensor<20x10xf32> + %8 = "tf.Mul"(%7, %1) : (tensor<20x10xf32>, tensor<f32>) -> tensor<20x10xf32> + %9 = "tf.Sub"(%arg2, %8) : (tensor<20x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32> + return %4, %9 : tensor<f32>, tensor<20x10xf32> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// 2-layer perceptron with fused 
fwd/bwd operations 
+//===----------------------------------------------------------------------===// + +module @patterns { + + // gradient descent + pdl.pattern : benefit(3) { + %const_type = pdl.type + %param_type = pdl.type + %param = pdl.operand : %param_type + %gradient = pdl.operand : %param_type + + %attr0 = pdl.attribute + %op0 = pdl.operation "tf.Const" {"value" = %attr0} -> (%const_type : !pdl.type) + %val0 = pdl.result 0 of %op0 + %op1 = pdl.operation "tf.Mul" (%gradient, %val0 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %op2 = pdl.operation "tf.Sub" (%param, %val1 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type) + + pdl.rewrite %op2 { + %op3 = pdl.operation "kernel.GD" (%param, %gradient : !pdl.value, !pdl.value) -> (%param_type : !pdl.type) + %val3 = pdl.result 0 of %op3 + pdl.replace %op2 with (%val3 : !pdl.value) // tf.Sub + pdl.erase %op1 // tf.Mul + } + } + + // first FC + pdl.pattern : benefit(8) { + %in_type = pdl.type + %out_type = pdl.type + %weight_type = pdl.type + %bias_type = pdl.type + %rxact = pdl.operand : %in_type + %rxdelta = pdl.operand : %out_type + %weight = pdl.operand : %weight_type + %bias = pdl.operand : %bias_type + + %attr0 = pdl.attribute false + %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type) + %val0 = pdl.result 0 of %op0 + %op1 = pdl.operation "tf.BiasAdd" (%val0, %bias : !pdl.value, !pdl.value) -> (%out_type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %op2 = pdl.operation "tf.Relu" (%val1 : !pdl.value) -> (%out_type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %op3 = pdl.operation "tf.ReluGrad" (%rxdelta, %val2 : !pdl.value, !pdl.value) -> (%out_type : !pdl.type) + %val3 = pdl.result 0 of %op3 + %attr1 = pdl.attribute true + %op4 = pdl.operation "tf.MatMul" (%rxact, %val3 : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type) + %val4 = 
pdl.result 0 of %op4 + %op5 = pdl.operation "kernel.GD" (%weight, %val4 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type) + %op6 = pdl.operation "tf.BiasAddGrad" (%val3 : !pdl.value) -> (%bias_type : !pdl.type) + %val6 = pdl.result 0 of %op6 + %op7 = pdl.operation "kernel.GD" (%bias, %val6 : !pdl.value, !pdl.value) -> (%bias_type : !pdl.type) + + pdl.rewrite { // roots: %op2, %op5, %op7 + %op8 = pdl.operation "kernel.FcWithBias" (%rxact, %rxdelta, %weight, %bias : !pdl.value, !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %weight_type, %bias_type : !pdl.type, !pdl.type, !pdl.type) + %val8_0 = pdl.result 0 of %op8 // txact + %val8_1 = pdl.result 1 of %op8 // weight_out + %val8_2 = pdl.result 2 of %op8 // bias_out + pdl.replace %op7 with (%val8_2 : !pdl.value) // kernel.GD + pdl.erase %op6 // tf.BiasAddGrad + pdl.replace %op5 with (%val8_1 : !pdl.value) // kernel.GD + pdl.erase %op4 // tf.MatMul + pdl.erase %op3 // tf.ReluGrad + pdl.replace %op2 with (%val8_0 : !pdl.value) // tf.Relu + pdl.erase %op1 // tf.BiasAdd + pdl.erase %op0 // tf.MatMul + } + } + + // second FC + pdl.pattern : benefit(4) { + %in_type = pdl.type + %out_type = pdl.type + %weight_type = pdl.type + %rxact = pdl.operand : %in_type + %rxdelta = pdl.operand : %out_type + %weight = pdl.operand : %weight_type + + %attr0 = pdl.attribute false + %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type) + %attr1 = pdl.attribute true + %op1 = pdl.operation "tf.MatMul" (%rxdelta, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%in_type : !pdl.type) + %op2 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %op3 = pdl.operation "kernel.GD" (%weight, %val2 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type) + + pdl.rewrite { // roots: 
%op0, %op1, %op3 + %op4 = pdl.operation "kernel.Fc" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %in_type, %weight_type : !pdl.type, !pdl.type, !pdl.type) + %val4_0 = pdl.result 0 of %op4 // txact + %val4_1 = pdl.result 1 of %op4 // txdelta + %val4_2 = pdl.result 2 of %op4 // weight_out + pdl.replace %op3 with (%val4_2 : !pdl.value) // kernel.GD + pdl.erase %op2 // tf.MatMul + pdl.replace %op1 with (%val4_1 : !pdl.value) // tf.MatMul + pdl.replace %op0 with (%val4_0 : !pdl.value) // tf.MatMul + } + } + + // softmax_cross_entropy + pdl.pattern : benefit(6) { + %in_type = pdl.type + %label_type = pdl.type + %loss_type = pdl.type + %mean_loss_type = pdl.type + %mean_const_type = pdl.type + %mul_const_type = pdl.type + %rxact = pdl.operand : %in_type + %rxlabel = pdl.operand : %label_type + + %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type) + %val0_0 = pdl.result 0 of %op0 // loss + %val0_1 = pdl.result 1 of %op0 // gradient + %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type) + %val1 = pdl.result 0 of %op1 + %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type) + %val2 = pdl.result 0 of %op2 + %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type) + %val3 = pdl.result 0 of %op3 + %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type) + %val4 = pdl.result 0 of %op4 + %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type) + + pdl.rewrite { // roots: %op2, %op5 + %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type) + %val6_0 = pdl.result 0 of %op6 // txloss + %val6_1 = pdl.result 1 of %op6 // txdelta + pdl.replace %op5 with (%val6_1 : !pdl.value) // tf.Mul + pdl.erase %op4 // tf.Const + 
pdl.erase %op3 // tf.PreventGradient + pdl.replace %op2 with (%val6_0 : !pdl.value) // tf.Mean + pdl.erase %op1 // tf.Const + pdl.erase %op0 // tf.SparseSoftmaxCrossEntropyWithLogits + } + } +} + +// CHECK-LABEL: test.mlp_fused +// CHECK: %[[FC2:.*]]:3 = "kernel.Fc"(%[[FC1:.*]]#0, %[[SM:.*]]#1, %arg4) : (tensor<2x256xf32>, tensor<2x10xf32>, tensor<256x10xf32>) -> (tensor<2x10xf32>, tensor<2x256xf32>, tensor<256x10xf32>) +// CHECK: %[[SM]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FC2]]#0, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>) +// CHECK: %[[FC1]]:3 = "kernel.FcWithBias"(%arg0, %[[FC2]]#1, %arg3, %arg2) : (tensor<2x20xf32>, tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>) -> (tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>) +module @ir attributes { test.mlp_fused } { + func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<256xf32>, %arg3: tensor<20x256xf32>, %arg4: tensor<256x10xf32>) -> () { // tensor<f32>, tensor<256xf32>, tensor<20x256xf32>, tensor<256x10xf32>) { + // The replacement operations fuse forward and backward passes; therefore, the + // resulting graph is not a DAG. To address this, we wrap the operations in + // a graph region. 
+ "test.graph_region"() ({ + %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32> + %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32> + %3 = "tf.MatMul"(%arg0, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x256xf32>) -> tensor<2x256xf32> + %4 = "tf.BiasAdd"(%3, %arg2) {data_format = "NHWC"} : (tensor<2x256xf32>, tensor<256xf32>) -> tensor<2x256xf32> + %5 = "tf.Relu"(%4) : (tensor<2x256xf32>) -> tensor<2x256xf32> + %6 = "tf.MatMul"(%5, %arg4) {transpose_a = false, transpose_b = false} : (tensor<2x256xf32>, tensor<256x10xf32>) -> tensor<2x10xf32> + %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%6, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>) + %7 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32> + %8 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32> + %9 = "tf.Mul"(%8, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32> + %10 = "tf.MatMul"(%9, %arg4) {transpose_a = false, transpose_b = true} : (tensor<2x10xf32>, tensor<256x10xf32>) -> tensor<2x256xf32> + %11 = "tf.MatMul"(%5, %9) {transpose_a = true, transpose_b = false} : (tensor<2x256xf32>, tensor<2x10xf32>) -> tensor<256x10xf32> + %12 = "tf.ReluGrad"(%10, %5) : (tensor<2x256xf32>, tensor<2x256xf32>) -> tensor<2x256xf32> + %13 = "tf.BiasAddGrad"(%12) {data_format = "NHWC"} : (tensor<2x256xf32>) -> tensor<256xf32> + %14 = "tf.MatMul"(%arg0, %12) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x256xf32>) -> tensor<20x256xf32> + %15 = "tf.Mul"(%14, %1) : (tensor<20x256xf32>, tensor<f32>) -> tensor<20x256xf32> + %16 = "tf.Sub"(%arg3, %15) : (tensor<20x256xf32>, tensor<20x256xf32>) -> tensor<20x256xf32> + %17 = "tf.Mul"(%13, %1) : (tensor<256xf32>, tensor<f32>) -> tensor<256xf32> + %18 = "tf.Sub"(%arg2, 
%17) : 
(tensor<256xf32>, tensor<256xf32>) -> tensor<256xf32> + %19 = "tf.Mul"(%11, %1) : (tensor<256x10xf32>, tensor<f32>) -> tensor<256x10xf32> + %20 = "tf.Sub"(%arg4, %19) : (tensor<256x10xf32>, tensor<256x10xf32>) -> tensor<256x10xf32> + }) : () -> () + return + } +} diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp --- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -75,6 +76,11 @@ StringRef getDescription() const final { return "Test PDL ByteCode functionality"; } + void getDependentDialects(DialectRegistry &registry) const override { + // Mark the pdl_interp dialect as a dependent. This is needed because we + // create ops from that dialect as a part of the PDL-to-PDLInterp lowering. + registry.insert<pdl_interp::PDLInterpDialect>(); + } void runOnOperation() final { ModuleOp module = getOperation();